diff --git a/src/main/java/com/limechain/babe/Authorship.java b/src/main/java/com/limechain/babe/Authorship.java index 73de761e..1f4fc4cc 100644 --- a/src/main/java/com/limechain/babe/Authorship.java +++ b/src/main/java/com/limechain/babe/Authorship.java @@ -191,7 +191,7 @@ private static Integer getSecondarySlotAuthor(byte[] randomness, } // threshold = 2^128 * (1 - (1 - c) ^ (authority_weight / sum(authorities_weights))) - private static BigInteger calculatePrimaryThreshold( + public static BigInteger calculatePrimaryThreshold( Pair constant, List authorities, int authorityIndex) { @@ -259,7 +259,7 @@ private static double getBabeConstant(Pair constant) { return c; } - private static TranscriptData makeTranscript(byte[] randomness, BigInteger slotNumber, BigInteger epochIndex) { + public static TranscriptData makeTranscript(byte[] randomness, BigInteger slotNumber, BigInteger epochIndex) { var transcript = new TranscriptData("BABE".getBytes()); transcript.appendMessage("slot number", LittleEndianUtils.toLittleEndianBytes(slotNumber)); transcript.appendMessage("current epoch", LittleEndianUtils.toLittleEndianBytes(epochIndex)); diff --git a/src/main/java/com/limechain/babe/BabeService.java b/src/main/java/com/limechain/babe/BabeService.java index 2798e6ac..a5e1920e 100644 --- a/src/main/java/com/limechain/babe/BabeService.java +++ b/src/main/java/com/limechain/babe/BabeService.java @@ -28,11 +28,12 @@ public class BabeService implements SlotChangeListener { private final EpochState epochState; private final KeyStore keyStore; private final Map slotToPreRuntimeDigest = new HashedMap<>(); - private final Network network = AppBean.getBean(Network.class); + private final Network network; - public BabeService(EpochState epochState, KeyStore keyStore) { + public BabeService(EpochState epochState, KeyStore keyStore, Network network) { this.epochState = epochState; this.keyStore = keyStore; + this.network = network; } private void executeEpochLottery(BigInteger epochIndex) { diff --git a/src/main/java/com/limechain/babe/BlockProductionVerifier.java b/src/main/java/com/limechain/babe/BlockProductionVerifier.java new file mode 100644 index 00000000..3d02a28a --- /dev/null +++ b/src/main/java/com/limechain/babe/BlockProductionVerifier.java @@ -0,0 +1,32 @@ +package com.limechain.babe; + +import com.limechain.babe.state.EpochState; +import com.limechain.chain.lightsyncstate.Authority; +import com.limechain.rpc.server.AppBean; +import com.limechain.utils.LittleEndianUtils; +import io.emeraldpay.polkaj.merlin.TranscriptData; +import io.emeraldpay.polkaj.schnorrkel.Schnorrkel; +import io.emeraldpay.polkaj.schnorrkel.VrfOutputAndProof; +import lombok.extern.java.Log; + +import java.math.BigInteger; +import java.util.List; +import java.util.logging.Level; + +@Log +public class BlockProductionVerifier { + + private EpochState epochState = AppBean.getBean(EpochState.class); + + public boolean verifySlotWinner(int authorityIndex, BigInteger epochIndex, byte[] randomness, BigInteger slotNumber, VrfOutputAndProof vrfOutputAndProof) { + List authorities = epochState.getCurrentEpochData().getAuthorities(); + Authority verifyingAuthority = authorities.get(authorityIndex); + TranscriptData transcriptData = Authorship.makeTranscript(randomness, slotNumber, epochIndex); + BigInteger threshold = Authorship.calculatePrimaryThreshold(epochState.getCurrentEpochDescriptor().getConstant(), authorities, authorityIndex); + var isBelowThreshold = LittleEndianUtils.fromLittleEndianByteArray(vrfOutputAndProof.getOutput()).compareTo(threshold) < 0; + if (!isBelowThreshold) { + log.log(Level.WARNING, "Block producer is not a winner of the slot"); + } + return Schnorrkel.getInstance().vrfVerify(new Schnorrkel.PublicKey(verifyingAuthority.getPublicKey()), transcriptData, vrfOutputAndProof) && isBelowThreshold; + } +} diff --git a/src/main/java/com/limechain/network/Network.java b/src/main/java/com/limechain/network/Network.java index 93b23155..29e1ac58 100644 --- a/src/main/java/com/limechain/network/Network.java +++ b/src/main/java/com/limechain/network/Network.java @@ -19,6 +19,7 @@ import com.limechain.storage.KVRepository; import com.limechain.utils.Ed25519Utils; import com.limechain.utils.StringUtils; +import com.limechain.utils.async.AsyncExecutor; import io.ipfs.multiaddr.MultiAddress; import io.ipfs.multihash.Multihash; import io.libp2p.core.Host; @@ -49,6 +50,7 @@ public class Network { public static final String LOCAL_IPV4_TCP_ADDRESS = "/ip4/127.0.0.1/tcp/"; private static final int HOST_PORT = 30333; + private static final int ASYNC_EXECUTOR_POOL_SIZE = 10; private static final Random RANDOM = new Random(); @Getter private final Chain chain; @@ -315,9 +317,7 @@ public void sendNeighbourMessage(PeerId peerId) { } public void sendBlockAnnounceMessage(byte[] encodedBlockAnnounceMessage) { - kademliaService.getBootNodePeerIds() - .stream() - .distinct() - .forEach(p -> new Thread(() -> blockAnnounceService.sendBlockAnnounceMessage(this.host, p, encodedBlockAnnounceMessage)).start()); + AsyncExecutor asyncExecutor = AsyncExecutor.withPoolSize(ASYNC_EXECUTOR_POOL_SIZE); + connectionManager.getPeerIds().forEach(p -> asyncExecutor.executeAndForget(() -> blockAnnounceService.sendBlockAnnounceMessage(this.host, p, encodedBlockAnnounceMessage))); } } diff --git a/src/test/java/com/limechain/babe/BlockProductionVerifierTest.java b/src/test/java/com/limechain/babe/BlockProductionVerifierTest.java new file mode 100644 index 00000000..4c8959c0 --- /dev/null +++ b/src/test/java/com/limechain/babe/BlockProductionVerifierTest.java @@ -0,0 +1,127 @@ +package com.limechain.babe; + +import com.limechain.babe.state.EpochData; +import com.limechain.babe.state.EpochDescriptor; +import com.limechain.babe.state.EpochState; +import com.limechain.chain.lightsyncstate.Authority; +import io.emeraldpay.polkaj.merlin.TranscriptData; +import io.emeraldpay.polkaj.schnorrkel.Schnorrkel; +import io.emeraldpay.polkaj.schnorrkel.VrfOutputAndProof; +import org.javatuples.Pair; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.MockitoAnnotations; + +import java.math.BigInteger; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.*; + +class BlockProductionVerifierTest { + + @Mock + private EpochState epochState; + + @Mock + private EpochData currentEpochData; + + @Mock + private EpochDescriptor epochDescriptor; + + @InjectMocks + private BlockProductionVerifier blockProductionVerifier; + + @Mock + private VrfOutputAndProof vrfOutputAndProof; + + private final BigInteger epochIndex = BigInteger.ONE; + private final byte[] randomness = new byte[]{0x01, 0x02, 0x03}; + private final BigInteger slotNumber = BigInteger.TEN; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + } + + + @Test + void testVerifySlotWinner_AboveThresholdAndValidVrf() { + try (MockedStatic mockedSchnorrkel = mockStatic(Schnorrkel.class)) { + Schnorrkel schnorrkelMock = mock(Schnorrkel.class); + mockedSchnorrkel.when(Schnorrkel::getInstance).thenReturn(schnorrkelMock); + + when(schnorrkelMock.vrfVerify(any(Schnorrkel.PublicKey.class), any(TranscriptData.class), eq(vrfOutputAndProof))) + .thenReturn(true); + + when(epochState.getCurrentEpochData()).thenReturn(currentEpochData); + when(epochState.getCurrentEpochDescriptor()).thenReturn(epochDescriptor); + when(epochDescriptor.getConstant()).thenReturn(new Pair<>(BigInteger.ZERO, BigInteger.valueOf(4))); + when(currentEpochData.getAuthorities()).thenReturn(List.of( + new Authority(new byte[]{0x01, 0x02, 0x03}, BigInteger.ONE), + new Authority(new byte[32], BigInteger.ONE), + new Authority(new byte[32], BigInteger.ONE) + )); + when(vrfOutputAndProof.getOutput()).thenReturn(new byte[]{0x01, 0x02}); + + boolean result = blockProductionVerifier.verifySlotWinner(0, epochIndex, randomness, slotNumber, vrfOutputAndProof); + assertFalse(result); + mockedSchnorrkel.verify(Schnorrkel::getInstance, times(1)); + } + } + + @Test + void testVerifySlotWinner_BelowThresholdAndValidVrf() { + try (MockedStatic mockedSchnorrkel = mockStatic(Schnorrkel.class)) { + Schnorrkel schnorrkelMock = mock(Schnorrkel.class); + mockedSchnorrkel.when(Schnorrkel::getInstance).thenReturn(schnorrkelMock); + + when(schnorrkelMock.vrfVerify(any(Schnorrkel.PublicKey.class), any(TranscriptData.class), eq(vrfOutputAndProof))) + .thenReturn(true); + + when(epochState.getCurrentEpochData()).thenReturn(currentEpochData); + when(epochState.getCurrentEpochDescriptor()).thenReturn(epochDescriptor); + when(epochDescriptor.getConstant()).thenReturn(new Pair<>(BigInteger.valueOf(3), BigInteger.valueOf(4))); + when(currentEpochData.getAuthorities()).thenReturn(List.of( + new Authority(new byte[]{0x01, 0x02, 0x03}, BigInteger.valueOf(1000)), + new Authority(new byte[32], BigInteger.valueOf(500)), + new Authority(new byte[32], BigInteger.valueOf(500)) + )); + when(vrfOutputAndProof.getOutput()).thenReturn(new byte[]{0x01, 0x02}); + + boolean result = blockProductionVerifier.verifySlotWinner(0, epochIndex, randomness, slotNumber, vrfOutputAndProof); + assertTrue(result); + mockedSchnorrkel.verify(Schnorrkel::getInstance, times(1)); + } + } + + + @Test + void testVerifySlotWinner_InvalidVrfOutput() { + try (MockedStatic mockedSchnorrkel = mockStatic(Schnorrkel.class)) { + Schnorrkel schnorrkelMock = mock(Schnorrkel.class); + mockedSchnorrkel.when(Schnorrkel::getInstance).thenReturn(schnorrkelMock); + + when(schnorrkelMock.vrfVerify(any(Schnorrkel.PublicKey.class), any(TranscriptData.class), eq(vrfOutputAndProof))) + .thenReturn(false); + + when(epochState.getCurrentEpochData()).thenReturn(currentEpochData); + when(epochState.getCurrentEpochDescriptor()).thenReturn(epochDescriptor); + when(epochDescriptor.getConstant()).thenReturn(new Pair<>(BigInteger.valueOf(3), BigInteger.valueOf(4))); + when(currentEpochData.getAuthorities()).thenReturn(List.of( + new Authority(new byte[]{0x01, 0x02, 0x03}, BigInteger.valueOf(1000)), + new Authority(new byte[32], BigInteger.valueOf(500)), + new Authority(new byte[32], BigInteger.valueOf(500)) + )); + when(vrfOutputAndProof.getOutput()).thenReturn(new byte[]{0x01, 0x02}); + + boolean result = blockProductionVerifier.verifySlotWinner(0, epochIndex, randomness, slotNumber, vrfOutputAndProof); + assertFalse(result); + mockedSchnorrkel.verify(Schnorrkel::getInstance, times(1)); + } + } +}