diff --git a/ethereum/p2p/src/test/java/org/hyperledger/besu/ethereum/p2p/discovery/internal/PeerDiscoveryTableRefreshTest.java b/ethereum/p2p/src/test/java/org/hyperledger/besu/ethereum/p2p/discovery/internal/PeerDiscoveryTableRefreshTest.java index 2355defcb0..cd6a308ab1 100644 --- a/ethereum/p2p/src/test/java/org/hyperledger/besu/ethereum/p2p/discovery/internal/PeerDiscoveryTableRefreshTest.java +++ b/ethereum/p2p/src/test/java/org/hyperledger/besu/ethereum/p2p/discovery/internal/PeerDiscoveryTableRefreshTest.java @@ -15,8 +15,10 @@ package org.hyperledger.besu.ethereum.p2p.discovery.internal; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; @@ -32,6 +34,7 @@ import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Optional; +import java.util.function.Consumer; import java.util.stream.Collectors; import org.apache.tuweni.bytes.Bytes; @@ -47,11 +50,14 @@ public class PeerDiscoveryTableRefreshTest { final List nodeKeys = PeerDiscoveryTestHelper.generateNodeKeys(2); final List peers = helper.createDiscoveryPeers(nodeKeys); final DiscoveryPeer localPeer = peers.get(0); + final DiscoveryPeer remotePeer = peers.get(1); final NodeKey localKeyPair = nodeKeys.get(0); + final NodeKey remoteKeyPair = nodeKeys.get(1); // Create and start the PeerDiscoveryController final OutboundMessageHandler outboundMessageHandler = mock(OutboundMessageHandler.class); final MockTimerUtil timer = new MockTimerUtil(); + final PeerDiscoveryController controller = spy( PeerDiscoveryController.builder() @@ -67,45 +73,44 @@ public class PeerDiscoveryTableRefreshTest { .build()); controller.start(); - // Send a PING, so as to add a Peer in the controller. - final PingPacketData ping = - PingPacketData.create(peers.get(1).getEndpoint(), peers.get(0).getEndpoint()); - final Packet pingPacket = Packet.create(PacketType.PING, ping, nodeKeys.get(1)); - controller.onMessage(pingPacket, peers.get(1)); + final PingPacketData mockPing = + PingPacketData.create(localPeer.getEndpoint(), remotePeer.getEndpoint()); + final Packet mockPingPacket = Packet.create(PacketType.PING, mockPing, localKeyPair); - final PingPacketData data = - PingPacketData.create(peers.get(0).getEndpoint(), peers.get(1).getEndpoint()); - final Packet packet = Packet.create(PacketType.PING, data, nodeKeys.get(0)); + doAnswer( + invocation -> { + final Consumer handler = invocation.getArgument(2); + handler.accept(mockPingPacket); + return null; + }) + .when(controller) + .createPacket(eq(PacketType.PING), any(), any()); - // Simulate a PONG message from peer 0. - final PongPacketData pong = PongPacketData.create(peers.get(0).getEndpoint(), packet.getHash()); - final Packet pongPacket = Packet.create(PacketType.PONG, pong, nodeKeys.get(1)); + // Send a PING, so as to add a Peer in the controller. + final PingPacketData ping = + PingPacketData.create(remotePeer.getEndpoint(), localPeer.getEndpoint()); + final Packet pingPacket = Packet.create(PacketType.PING, ping, remoteKeyPair); + controller.onMessage(pingPacket, remotePeer); - controller.onMessage(pongPacket, peers.get(1)); + // Answer localPeer PING to complete bonding + final PongPacketData pong = + PongPacketData.create(localPeer.getEndpoint(), mockPingPacket.getHash()); + final Packet pongPacket = Packet.create(PacketType.PONG, pong, remoteKeyPair); + controller.onMessage(pongPacket, remotePeer); // Wait until the controller has added the newly found peer. assertThat(controller.streamDiscoveredPeers()).hasSize(1); final ArgumentCaptor captor = ArgumentCaptor.forClass(Packet.class); for (int i = 0; i < 5; i++) { - controller.onMessage(pingPacket, peers.get(1)); - - final PingPacketData refreshData = - PingPacketData.create(peers.get(0).getEndpoint(), peers.get(1).getEndpoint()); - final Packet refreshPacket = Packet.create(PacketType.PING, refreshData, nodeKeys.get(0)); - - final PongPacketData refreshPong = - PongPacketData.create(peers.get(0).getEndpoint(), refreshPacket.getHash()); - final Packet refreshPongPacket = Packet.create(PacketType.PONG, refreshPong, nodeKeys.get(1)); - - controller.onMessage(refreshPongPacket, peers.get(1)); - controller.getRecursivePeerRefreshState().cancel(); timer.runPeriodicHandlers(); controller.streamDiscoveredPeers().forEach(p -> p.setStatus(PeerDiscoveryStatus.KNOWN)); + controller.onMessage(pingPacket, remotePeer); + controller.onMessage(pongPacket, remotePeer); } - verify(outboundMessageHandler, atLeast(5)).send(eq(peers.get(1)), captor.capture()); + verify(outboundMessageHandler, atLeast(5)).send(eq(remotePeer), captor.capture()); final List capturedFindNeighborsPackets = captor.getAllValues().stream() .filter(p -> p.getType().equals(PacketType.FIND_NEIGHBORS)) @@ -122,8 +127,6 @@ public class PeerDiscoveryTableRefreshTest { targets.add(neighborsData.getTarget()); } - assertThat(targets.size()).isEqualTo(5); - // All targets are unique. assertThat(targets.size()).isEqualTo(new HashSet<>(targets).size()); }