diff --git a/ethereum/jsonrpc/src/main/java/tech/pegasys/pantheon/ethereum/jsonrpc/websocket/subscription/SubscriptionManager.java b/ethereum/jsonrpc/src/main/java/tech/pegasys/pantheon/ethereum/jsonrpc/websocket/subscription/SubscriptionManager.java index 40a273fc6c..12d4fa277a 100644 --- a/ethereum/jsonrpc/src/main/java/tech/pegasys/pantheon/ethereum/jsonrpc/websocket/subscription/SubscriptionManager.java +++ b/ethereum/jsonrpc/src/main/java/tech/pegasys/pantheon/ethereum/jsonrpc/websocket/subscription/SubscriptionManager.java @@ -70,17 +70,26 @@ public class SubscriptionManager extends AbstractVerticle { } public boolean unsubscribe(final UnsubscribeRequest request) { - LOG.debug("Unsubscribe request subscriptionId = {}", request.getSubscriptionId()); + final Long subscriptionId = request.getSubscriptionId(); + final String connectionId = request.getConnectionId(); - if (!subscriptions.containsKey(request.getSubscriptionId())) { - throw new SubscriptionNotFoundException(request.getSubscriptionId()); + LOG.debug("Unsubscribe request subscriptionId = {}", subscriptionId); + + if (!subscriptions.containsKey(subscriptionId) + || !connectionOwnsSubscription(subscriptionId, connectionId)) { + throw new SubscriptionNotFoundException(subscriptionId); } - destroySubscription(request.getSubscriptionId(), request.getConnectionId()); + destroySubscription(subscriptionId, connectionId); return true; } + private boolean connectionOwnsSubscription(final Long subscriptionId, final String connectionId) { + return connectionSubscriptionsMap.get(connectionId) != null + && connectionSubscriptionsMap.get(connectionId).contains(subscriptionId); + } + private void destroySubscription(final long subscriptionId, final String connectionId) { subscriptions.remove(subscriptionId); diff --git a/ethereum/jsonrpc/src/test/java/tech/pegasys/pantheon/ethereum/jsonrpc/websocket/subscription/SubscriptionManagerTest.java b/ethereum/jsonrpc/src/test/java/tech/pegasys/pantheon/ethereum/jsonrpc/websocket/subscription/SubscriptionManagerTest.java index 9be74b7b47..227636af83 100644 --- a/ethereum/jsonrpc/src/test/java/tech/pegasys/pantheon/ethereum/jsonrpc/websocket/subscription/SubscriptionManagerTest.java +++ b/ethereum/jsonrpc/src/test/java/tech/pegasys/pantheon/ethereum/jsonrpc/websocket/subscription/SubscriptionManagerTest.java @@ -1,6 +1,7 @@ package tech.pegasys.pantheon.ethereum.jsonrpc.websocket.subscription; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.catchThrowable; import static org.hamcrest.CoreMatchers.both; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.instanceOf; @@ -13,6 +14,7 @@ import tech.pegasys.pantheon.ethereum.jsonrpc.websocket.subscription.request.Uns import tech.pegasys.pantheon.ethereum.jsonrpc.websocket.subscription.syncing.SyncingSubscription; import java.util.List; +import java.util.UUID; import org.junit.Before; import org.junit.Rule; @@ -33,8 +35,7 @@ public class SubscriptionManagerTest { @Test public void subscribeShouldCreateSubscription() { - final SubscribeRequest subscribeRequest = - new SubscribeRequest(SubscriptionType.SYNCING, null, null, CONNECTION_ID); + final SubscribeRequest subscribeRequest = subscribeRequest(CONNECTION_ID); final Long subscriptionId = subscriptionManager.subscribe(subscribeRequest); @@ -49,8 +50,7 @@ public class SubscriptionManagerTest { @Test public void unsubscribeExistingSubscriptionShouldDestroySubscription() { - final SubscribeRequest subscribeRequest = - new SubscribeRequest(SubscriptionType.SYNCING, null, null, CONNECTION_ID); + final SubscribeRequest subscribeRequest = subscribeRequest(CONNECTION_ID); final Long subscriptionId = subscriptionManager.subscribe(subscribeRequest); assertThat(subscriptionManager.subscriptions().get(subscriptionId)).isNotNull(); @@ -76,8 +76,7 @@ public class SubscriptionManagerTest { @Test public void shouldAddSubscriptionToNewConnection() { - final SubscribeRequest subscribeRequest = - new SubscribeRequest(SubscriptionType.SYNCING, null, null, CONNECTION_ID); + final SubscribeRequest subscribeRequest = subscribeRequest(CONNECTION_ID); subscriptionManager.subscribe(subscribeRequest); @@ -90,8 +89,7 @@ public class SubscriptionManagerTest { @Test public void shouldAddSubscriptionToExistingConnection() { - final SubscribeRequest subscribeRequest = - new SubscribeRequest(SubscriptionType.SYNCING, null, null, CONNECTION_ID); + final SubscribeRequest subscribeRequest = subscribeRequest(CONNECTION_ID); subscriptionManager.subscribe(subscribeRequest); @@ -110,8 +108,7 @@ public class SubscriptionManagerTest { @Test public void shouldRemoveSubscriptionFromExistingConnection() { - final SubscribeRequest subscribeRequest = - new SubscribeRequest(SubscriptionType.SYNCING, null, null, CONNECTION_ID); + final SubscribeRequest subscribeRequest = subscribeRequest(CONNECTION_ID); final Long subscriptionId1 = subscriptionManager.subscribe(subscribeRequest); @@ -140,8 +137,7 @@ public class SubscriptionManagerTest { @Test public void shouldRemoveConnectionWithSingleSubscriptions() { - final SubscribeRequest subscribeRequest = - new SubscribeRequest(SubscriptionType.SYNCING, null, null, CONNECTION_ID); + final SubscribeRequest subscribeRequest = subscribeRequest(CONNECTION_ID); final Long subscriptionId1 = subscriptionManager.subscribe(subscribeRequest); @@ -188,17 +184,37 @@ public class SubscriptionManagerTest { } @Test - public void unsubscribeWithUnknownConnectionId() { - final SubscribeRequest subscribeRequestOne = - new SubscribeRequest(SubscriptionType.NEW_BLOCK_HEADERS, null, true, CONNECTION_ID); - final long subscriptionId = subscriptionManager.subscribe(subscribeRequestOne); + public void unsubscribeOthersSubscriptionsNotHavingOwnSubscriptionShouldReturnNotFound() { + SubscribeRequest subscribeRequest = subscribeRequest(CONNECTION_ID); + Long subscriptionId = subscriptionManager.subscribe(subscribeRequest); - final boolean success = - subscriptionManager.unsubscribe( - new UnsubscribeRequest(subscriptionId, "unknown-connection-id")); + UnsubscribeRequest unsubscribeRequest = + new UnsubscribeRequest(subscriptionId, UUID.randomUUID().toString()); - assertThat(success).isTrue(); + final Throwable thrown = + catchThrowable(() -> subscriptionManager.unsubscribe(unsubscribeRequest)); + assertThat(thrown).isInstanceOf(SubscriptionNotFoundException.class); } - // TODO vertx event bus testing for response + @Test + public void unsubscribeOthersSubscriptionsHavingOwnSubscriptionShouldReturnNotFound() { + String ownConnectionId = UUID.randomUUID().toString(); + SubscribeRequest ownSubscribeRequest = subscribeRequest(ownConnectionId); + subscriptionManager.subscribe(ownSubscribeRequest); + + String otherConnectionId = UUID.randomUUID().toString(); + SubscribeRequest otherSubscribeRequest = subscribeRequest(otherConnectionId); + Long otherSubscriptionId = subscriptionManager.subscribe(otherSubscribeRequest); + + UnsubscribeRequest unsubscribeRequest = + new UnsubscribeRequest(otherSubscriptionId, ownConnectionId); + + final Throwable thrown = + catchThrowable(() -> subscriptionManager.unsubscribe(unsubscribeRequest)); + assertThat(thrown).isInstanceOf(SubscriptionNotFoundException.class); + } + + private SubscribeRequest subscribeRequest(final String connectionId) { + return new SubscribeRequest(SubscriptionType.SYNCING, null, null, connectionId); + } }