diff --git a/consensus/ibft/src/test/java/tech/pegasys/pantheon/consensus/ibft/statemachine/RoundChangeManagerTest.java b/consensus/ibft/src/test/java/tech/pegasys/pantheon/consensus/ibft/statemachine/RoundChangeManagerTest.java index 8633004524..490e77a22c 100644 --- a/consensus/ibft/src/test/java/tech/pegasys/pantheon/consensus/ibft/statemachine/RoundChangeManagerTest.java +++ b/consensus/ibft/src/test/java/tech/pegasys/pantheon/consensus/ibft/statemachine/RoundChangeManagerTest.java @@ -27,6 +27,7 @@ import tech.pegasys.pantheon.consensus.ibft.ibftmessagedata.ProposalPayload; import tech.pegasys.pantheon.consensus.ibft.ibftmessagedata.RoundChangePayload; import tech.pegasys.pantheon.consensus.ibft.ibftmessagedata.SignedData; import tech.pegasys.pantheon.consensus.ibft.validation.MessageValidator; +import tech.pegasys.pantheon.consensus.ibft.validation.RoundChangeMessageValidator; import tech.pegasys.pantheon.crypto.SECP256K1.KeyPair; import tech.pegasys.pantheon.ethereum.ProtocolContext; import tech.pegasys.pantheon.ethereum.chain.MutableBlockchain; @@ -39,13 +40,11 @@ import tech.pegasys.pantheon.ethereum.mainnet.BlockHeaderValidator; import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; import com.google.common.base.Preconditions; import com.google.common.collect.Lists; -import com.google.common.collect.Maps; import org.junit.Before; import org.junit.Test; @@ -80,39 +79,41 @@ public class RoundChangeManagerTest { when(headerValidator.validateHeader(any(), any(), any(), any())).thenReturn(true); BlockHeader parentHeader = mock(BlockHeader.class); - Map messageValidators = Maps.newHashMap(); - - messageValidators.put( - ri1, - new MessageValidator( - validators, - Util.publicKeyToAddress(proposerKey.getPublicKey()), - ri1, - headerValidator, - protocolContext, - parentHeader)); - - messageValidators.put( - ri2, - new MessageValidator( - validators, - Util.publicKeyToAddress(validator1Key.getPublicKey()), - ri2, - headerValidator, - protocolContext, - parentHeader)); - - messageValidators.put( - ri3, - new MessageValidator( - validators, - Util.publicKeyToAddress(validator2Key.getPublicKey()), - ri3, - headerValidator, - protocolContext, - parentHeader)); - - manager = new RoundChangeManager(2, validators, messageValidators::get); + RoundChangeMessageValidator.MessageValidatorForHeightFactory messageValidatorFactory = + mock(RoundChangeMessageValidator.MessageValidatorForHeightFactory.class); + + when(messageValidatorFactory.createAt(ri1)) + .thenAnswer( + invocation -> + new MessageValidator( + validators, + Util.publicKeyToAddress(proposerKey.getPublicKey()), + ri1, + headerValidator, + protocolContext, + parentHeader)); + when(messageValidatorFactory.createAt(ri2)) + .thenAnswer( + invocation -> + new MessageValidator( + validators, + Util.publicKeyToAddress(validator1Key.getPublicKey()), + ri2, + headerValidator, + protocolContext, + parentHeader)); + when(messageValidatorFactory.createAt(ri3)) + .thenAnswer( + invocation -> + new MessageValidator( + validators, + Util.publicKeyToAddress(validator2Key.getPublicKey()), + ri3, + headerValidator, + protocolContext, + parentHeader)); + + manager = new RoundChangeManager(2, validators, messageValidatorFactory); } private SignedData makeRoundChangeMessage(