Skip to content

Commit

Permalink
Implemented verify slot winner and fixed network (#618)
Browse files Browse the repository at this point in the history
Implement verify slot winner and fix network block announce pears
  • Loading branch information
osrib authored Nov 22, 2024
1 parent 141ea58 commit d49409a
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 8 deletions.
4 changes: 2 additions & 2 deletions src/main/java/com/limechain/babe/Authorship.java
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ private static Integer getSecondarySlotAuthor(byte[] randomness,
}

// threshold = 2^128 * (1 - (1 - c) ^ (authority_weight / sum(authorities_weights)))
private static BigInteger calculatePrimaryThreshold(
public static BigInteger calculatePrimaryThreshold(
Pair<BigInteger, BigInteger> constant,
List<Authority> authorities,
int authorityIndex) {
Expand Down Expand Up @@ -259,7 +259,7 @@ private static double getBabeConstant(Pair<BigInteger, BigInteger> constant) {
return c;
}

private static TranscriptData makeTranscript(byte[] randomness, BigInteger slotNumber, BigInteger epochIndex) {
public static TranscriptData makeTranscript(byte[] randomness, BigInteger slotNumber, BigInteger epochIndex) {
var transcript = new TranscriptData("BABE".getBytes());
transcript.appendMessage("slot number", LittleEndianUtils.toLittleEndianBytes(slotNumber));
transcript.appendMessage("current epoch", LittleEndianUtils.toLittleEndianBytes(epochIndex));
Expand Down
5 changes: 3 additions & 2 deletions src/main/java/com/limechain/babe/BabeService.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@ public class BabeService implements SlotChangeListener {
private final EpochState epochState;
private final KeyStore keyStore;
private final Map<BigInteger, BabePreDigest> slotToPreRuntimeDigest = new HashedMap<>();
private final Network network = AppBean.getBean(Network.class);
private final Network network;

public BabeService(EpochState epochState, KeyStore keyStore) {
public BabeService(EpochState epochState, KeyStore keyStore, Network network) {
this.epochState = epochState;
this.keyStore = keyStore;
this.network = network;
}

private void executeEpochLottery(BigInteger epochIndex) {
Expand Down
32 changes: 32 additions & 0 deletions src/main/java/com/limechain/babe/BlockProductionVerifier.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package com.limechain.babe;

import com.limechain.babe.state.EpochState;
import com.limechain.chain.lightsyncstate.Authority;
import com.limechain.rpc.server.AppBean;
import com.limechain.utils.LittleEndianUtils;
import io.emeraldpay.polkaj.merlin.TranscriptData;
import io.emeraldpay.polkaj.schnorrkel.Schnorrkel;
import io.emeraldpay.polkaj.schnorrkel.VrfOutputAndProof;
import lombok.extern.java.Log;

import java.math.BigInteger;
import java.util.List;
import java.util.logging.Level;

@Log
public class BlockProductionVerifier {

private EpochState epochState = AppBean.getBean(EpochState.class);

public boolean verifySlotWinner(int authorityIndex, BigInteger epochIndex, byte[] randomness, BigInteger slotNumber, VrfOutputAndProof vrfOutputAndProof) {
List<Authority> authorities = epochState.getCurrentEpochData().getAuthorities();
Authority verifyingAuthority = authorities.get(authorityIndex);
TranscriptData transcriptData = Authorship.makeTranscript(randomness, slotNumber, epochIndex);
BigInteger threshold = Authorship.calculatePrimaryThreshold(epochState.getCurrentEpochDescriptor().getConstant(), authorities, authorityIndex);
var isBelowThreshold = LittleEndianUtils.fromLittleEndianByteArray(vrfOutputAndProof.getOutput()).compareTo(threshold) < 0;
if (!isBelowThreshold) {
log.log(Level.WARNING, "Block producer is not a winner of the slot");
}
return Schnorrkel.getInstance().vrfVerify(new Schnorrkel.PublicKey(verifyingAuthority.getPublicKey()), transcriptData, vrfOutputAndProof) && isBelowThreshold;
}
}
8 changes: 4 additions & 4 deletions src/main/java/com/limechain/network/Network.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.limechain.storage.KVRepository;
import com.limechain.utils.Ed25519Utils;
import com.limechain.utils.StringUtils;
import com.limechain.utils.async.AsyncExecutor;
import io.ipfs.multiaddr.MultiAddress;
import io.ipfs.multihash.Multihash;
import io.libp2p.core.Host;
Expand Down Expand Up @@ -49,6 +50,7 @@
public class Network {
public static final String LOCAL_IPV4_TCP_ADDRESS = "/ip4/127.0.0.1/tcp/";
private static final int HOST_PORT = 30333;
private static final int ASYNC_EXECUTOR_POOL_SIZE = 10;
private static final Random RANDOM = new Random();
@Getter
private final Chain chain;
Expand Down Expand Up @@ -315,9 +317,7 @@ public void sendNeighbourMessage(PeerId peerId) {
}

public void sendBlockAnnounceMessage(byte[] encodedBlockAnnounceMessage) {
kademliaService.getBootNodePeerIds()
.stream()
.distinct()
.forEach(p -> new Thread(() -> blockAnnounceService.sendBlockAnnounceMessage(this.host, p, encodedBlockAnnounceMessage)).start());
AsyncExecutor asyncExecutor = AsyncExecutor.withPoolSize(ASYNC_EXECUTOR_POOL_SIZE);
connectionManager.getPeerIds().forEach(p -> asyncExecutor.executeAndForget(() -> blockAnnounceService.sendBlockAnnounceMessage(this.host, p, encodedBlockAnnounceMessage)));
}
}
127 changes: 127 additions & 0 deletions src/test/java/com/limechain/babe/BlockProductionVerifierTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package com.limechain.babe;

import com.limechain.babe.state.EpochData;
import com.limechain.babe.state.EpochDescriptor;
import com.limechain.babe.state.EpochState;
import com.limechain.chain.lightsyncstate.Authority;
import io.emeraldpay.polkaj.merlin.TranscriptData;
import io.emeraldpay.polkaj.schnorrkel.Schnorrkel;
import io.emeraldpay.polkaj.schnorrkel.VrfOutputAndProof;
import org.javatuples.Pair;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.MockitoAnnotations;

import java.math.BigInteger;
import java.util.List;

import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.*;

class BlockProductionVerifierTest {

@Mock
private EpochState epochState;

@Mock
private EpochData currentEpochData;

@Mock
private EpochDescriptor epochDescriptor;

@InjectMocks
private BlockProductionVerifier blockProductionVerifier;

@Mock
private VrfOutputAndProof vrfOutputAndProof;

private final BigInteger epochIndex = BigInteger.ONE;
private final byte[] randomness = new byte[]{0x01, 0x02, 0x03};
private final BigInteger slotNumber = BigInteger.TEN;

@BeforeEach
void setUp() {
MockitoAnnotations.openMocks(this);
}


@Test
void testVerifySlotWinner_AboveThresholdAndValidVrf() {
try (MockedStatic<Schnorrkel> mockedSchnorrkel = mockStatic(Schnorrkel.class)) {
Schnorrkel schnorrkelMock = mock(Schnorrkel.class);
mockedSchnorrkel.when(Schnorrkel::getInstance).thenReturn(schnorrkelMock);

when(schnorrkelMock.vrfVerify(any(Schnorrkel.PublicKey.class), any(TranscriptData.class), eq(vrfOutputAndProof)))
.thenReturn(true);

when(epochState.getCurrentEpochData()).thenReturn(currentEpochData);
when(epochState.getCurrentEpochDescriptor()).thenReturn(epochDescriptor);
when(epochDescriptor.getConstant()).thenReturn(new Pair<>(BigInteger.ZERO, BigInteger.valueOf(4)));
when(currentEpochData.getAuthorities()).thenReturn(List.of(
new Authority(new byte[]{0x01, 0x02, 0x03}, BigInteger.ONE),
new Authority(new byte[32], BigInteger.ONE),
new Authority(new byte[32], BigInteger.ONE)
));
when(vrfOutputAndProof.getOutput()).thenReturn(new byte[]{0x01, 0x02});

boolean result = blockProductionVerifier.verifySlotWinner(0, epochIndex, randomness, slotNumber, vrfOutputAndProof);
assertFalse(result);
mockedSchnorrkel.verify(Schnorrkel::getInstance, times(1));
}
}

@Test
void testVerifySlotWinner_BelowThresholdAndValidVrf() {
try (MockedStatic<Schnorrkel> mockedSchnorrkel = mockStatic(Schnorrkel.class)) {
Schnorrkel schnorrkelMock = mock(Schnorrkel.class);
mockedSchnorrkel.when(Schnorrkel::getInstance).thenReturn(schnorrkelMock);

when(schnorrkelMock.vrfVerify(any(Schnorrkel.PublicKey.class), any(TranscriptData.class), eq(vrfOutputAndProof)))
.thenReturn(true);

when(epochState.getCurrentEpochData()).thenReturn(currentEpochData);
when(epochState.getCurrentEpochDescriptor()).thenReturn(epochDescriptor);
when(epochDescriptor.getConstant()).thenReturn(new Pair<>(BigInteger.valueOf(3), BigInteger.valueOf(4)));
when(currentEpochData.getAuthorities()).thenReturn(List.of(
new Authority(new byte[]{0x01, 0x02, 0x03}, BigInteger.valueOf(1000)),
new Authority(new byte[32], BigInteger.valueOf(500)),
new Authority(new byte[32], BigInteger.valueOf(500))
));
when(vrfOutputAndProof.getOutput()).thenReturn(new byte[]{0x01, 0x02});

boolean result = blockProductionVerifier.verifySlotWinner(0, epochIndex, randomness, slotNumber, vrfOutputAndProof);
assertTrue(result);
mockedSchnorrkel.verify(Schnorrkel::getInstance, times(1));
}
}


@Test
void testVerifySlotWinner_InvalidVrfOutput() {
try (MockedStatic<Schnorrkel> mockedSchnorrkel = mockStatic(Schnorrkel.class)) {
Schnorrkel schnorrkelMock = mock(Schnorrkel.class);
mockedSchnorrkel.when(Schnorrkel::getInstance).thenReturn(schnorrkelMock);

when(schnorrkelMock.vrfVerify(any(Schnorrkel.PublicKey.class), any(TranscriptData.class), eq(vrfOutputAndProof)))
.thenReturn(false);

when(epochState.getCurrentEpochData()).thenReturn(currentEpochData);
when(epochState.getCurrentEpochDescriptor()).thenReturn(epochDescriptor);
when(epochDescriptor.getConstant()).thenReturn(new Pair<>(BigInteger.valueOf(3), BigInteger.valueOf(4)));
when(currentEpochData.getAuthorities()).thenReturn(List.of(
new Authority(new byte[]{0x01, 0x02, 0x03}, BigInteger.valueOf(1000)),
new Authority(new byte[32], BigInteger.valueOf(500)),
new Authority(new byte[32], BigInteger.valueOf(500))
));
when(vrfOutputAndProof.getOutput()).thenReturn(new byte[]{0x01, 0x02});

boolean result = blockProductionVerifier.verifySlotWinner(0, epochIndex, randomness, slotNumber, vrfOutputAndProof);
assertFalse(result);
mockedSchnorrkel.verify(Schnorrkel::getInstance, times(1));
}
}
}

0 comments on commit d49409a

Please sign in to comment.