diff --git a/ethereum/p2p/src/main/java/tech/pegasys/pantheon/ethereum/p2p/network/netty/NettyPeerConnection.java b/ethereum/p2p/src/main/java/tech/pegasys/pantheon/ethereum/p2p/network/netty/NettyPeerConnection.java index 53c9f85e1e..b99e90fd27 100644 --- a/ethereum/p2p/src/main/java/tech/pegasys/pantheon/ethereum/p2p/network/netty/NettyPeerConnection.java +++ b/ethereum/p2p/src/main/java/tech/pegasys/pantheon/ethereum/p2p/network/netty/NettyPeerConnection.java @@ -49,7 +49,6 @@ final class NettyPeerConnection implements PeerConnection { private final PeerInfo peerInfo; private final Set agreedCapabilities; private final Map protocolToCapability = new HashMap<>(); - private final AtomicBoolean disconnectDispatched = new AtomicBoolean(false); private final AtomicBoolean disconnected = new AtomicBoolean(false); private final Callbacks callbacks; private final CapabilityMultiplexer multiplexer; @@ -81,6 +80,10 @@ final class NettyPeerConnection implements PeerConnection { if (isDisconnected()) { throw new PeerNotConnected("Attempt to send message to a closed peer connection"); } + doSend(capability, message); + } + + private void doSend(final Capability capability, final MessageData message) { if (capability != null) { // Validate message is valid for this capability final SubProtocol subProtocol = multiplexer.subProtocol(capability); @@ -133,10 +136,9 @@ final class NettyPeerConnection implements PeerConnection { @Override public void terminateConnection(final DisconnectReason reason, final boolean peerInitiated) { - if (disconnectDispatched.compareAndSet(false, true)) { + if (disconnected.compareAndSet(false, true)) { LOG.debug("Disconnected ({}) from {}", reason, peerInfo); callbacks.invokeDisconnect(this, reason, peerInitiated); - disconnected.set(true); } // Always ensure the context gets closed immediately even if we previously sent a disconnect // message and are waiting to close. @@ -145,16 +147,10 @@ final class NettyPeerConnection implements PeerConnection { @Override public void disconnect(final DisconnectReason reason) { - if (disconnectDispatched.compareAndSet(false, true)) { + if (disconnected.compareAndSet(false, true)) { LOG.debug("Disconnecting ({}) from {}", reason, peerInfo); callbacks.invokeDisconnect(this, reason, false); - try { - send(null, DisconnectMessage.create(reason)); - } catch (final PeerNotConnected e) { - // The connection has already been closed - nothing left to do - return; - } - disconnected.set(true); + doSend(null, DisconnectMessage.create(reason)); ctx.channel().eventLoop().schedule((Callable) ctx::close, 2L, SECONDS); } } diff --git a/ethereum/p2p/src/test/java/tech/pegasys/pantheon/ethereum/p2p/network/netty/NettyPeerConnectionTest.java b/ethereum/p2p/src/test/java/tech/pegasys/pantheon/ethereum/p2p/network/netty/NettyPeerConnectionTest.java index d061b59c0c..6087e3fcfb 100644 --- a/ethereum/p2p/src/test/java/tech/pegasys/pantheon/ethereum/p2p/network/netty/NettyPeerConnectionTest.java +++ b/ethereum/p2p/src/test/java/tech/pegasys/pantheon/ethereum/p2p/network/netty/NettyPeerConnectionTest.java @@ -13,7 +13,12 @@ package tech.pegasys.pantheon.ethereum.p2p.network.netty; import static java.util.Collections.emptyList; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import tech.pegasys.pantheon.ethereum.p2p.api.PeerConnection.PeerNotConnected; @@ -30,7 +35,6 @@ import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.EventLoop; -import org.assertj.core.api.Assertions; import org.junit.Before; import org.junit.Test; @@ -71,7 +75,21 @@ public class NettyPeerConnectionTest { @Test public void shouldThrowExceptionWhenAttemptingToSendMessageOnClosedConnection() { connection.disconnect(DisconnectReason.SUBPROTOCOL_TRIGGERED); - Assertions.assertThatThrownBy(() -> connection.send(null, HelloMessage.create(peerInfo))) + assertThatThrownBy(() -> connection.send(null, HelloMessage.create(peerInfo))) .isInstanceOfAny(PeerNotConnected.class); } + + @Test + public void shouldThrowExceptionWhenAttemptingToSendMessageWhileDisconnecting() { + doAnswer( + invocation -> + assertThatThrownBy(() -> connection.send(null, HelloMessage.create(peerInfo))) + .isInstanceOfAny(PeerNotConnected.class)) + .when(callbacks) + .invokeDisconnect(any(), any(), anyBoolean()); + + connection.disconnect(DisconnectReason.USELESS_PEER); + + verify(callbacks).invokeDisconnect(connection, DisconnectReason.USELESS_PEER, false); + } }