From bcd7c44fd98deb676f4a88ccf277ae21358225f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Antonsson?= Date: Mon, 20 May 2013 13:53:25 +0200 Subject: [PATCH] Make associate in TestTransport wait for the association event linstener. See #3363 --- .../akka/remote/transport/TestTransport.scala | 45 ++++++++++--------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/akka-remote/src/main/scala/akka/remote/transport/TestTransport.scala b/akka-remote/src/main/scala/akka/remote/transport/TestTransport.scala index 280d0e0099..2505cef5a6 100644 --- a/akka-remote/src/main/scala/akka/remote/transport/TestTransport.scala +++ b/akka-remote/src/main/scala/akka/remote/transport/TestTransport.scala @@ -44,36 +44,37 @@ class TestTransport( private val associationListenerPromise = Promise[AssociationEventListener]() private def defaultListen: Future[(Address, Promise[AssociationEventListener])] = { - associationListenerPromise.future.onSuccess { - case listener: AssociationEventListener ⇒ registry.registerTransport(this, listener) - } + registry.registerTransport(this, associationListenerPromise.future) Future.successful((localAddress, associationListenerPromise)) } private def defaultAssociate(remoteAddress: Address): Future[AssociationHandle] = { registry.transportFor(remoteAddress) match { - case Some((remoteTransport, remoteListener)) ⇒ + case Some((remoteTransport, remoteListenerFuture)) ⇒ val (localHandle, remoteHandle) = createHandlePair(remoteTransport, remoteAddress) localHandle.writable = false remoteHandle.writable = false // Pass a non-writable handle to remote first - remoteListener notify InboundAssociation(remoteHandle) - val remoteHandlerFuture = remoteHandle.readHandlerPromise.future + remoteListenerFuture flatMap { + case listener ⇒ + listener notify InboundAssociation(remoteHandle) + val remoteHandlerFuture = remoteHandle.readHandlerPromise.future - // Registration of reader at local finishes the registration and enables communication - for { - remoteListener ← remoteHandlerFuture - localListener ← localHandle.readHandlerPromise.future - } { - registry.registerListenerPair(localHandle.key, (localListener, remoteListener)) - localHandle.writable = true - remoteHandle.writable = true + // Registration of reader at local finishes the registration and enables communication + for { + remoteListener ← remoteHandlerFuture + localListener ← localHandle.readHandlerPromise.future + } { + registry.registerListenerPair(localHandle.key, (localListener, remoteListener)) + localHandle.writable = true + remoteHandle.writable = true + } + + remoteHandlerFuture.map { _ ⇒ localHandle } } - remoteHandlerFuture.map { _ ⇒ localHandle } - case None ⇒ Future.failed(new InvalidAssociationException(s"No registered transport: $remoteAddress", null)) } @@ -285,7 +286,7 @@ object TestTransport { class AssociationRegistry { private val activityLog = new CopyOnWriteArrayList[Activity]() - private val transportTable = new ConcurrentHashMap[Address, (TestTransport, AssociationEventListener)]() + private val transportTable = new ConcurrentHashMap[Address, (TestTransport, Future[AssociationEventListener])]() private val listenersTable = new ConcurrentHashMap[(Address, Address), (HandleEventListener, HandleEventListener)]() /** @@ -336,11 +337,11 @@ object TestTransport { * * @param transport * The transport that is to be registered. The address of this transport will be used as key. - * @param associationEventListener - * The listener that will handle the events for the given transport. + * @param associationEventListenerFuture + * The future that will be completed with the listener that will handle the events for the given transport. */ - def registerTransport(transport: TestTransport, associationEventListener: AssociationEventListener): Unit = { - transportTable.put(transport.localAddress, (transport, associationEventListener)) + def registerTransport(transport: TestTransport, associationEventListenerFuture: Future[AssociationEventListener]): Unit = { + transportTable.put(transport.localAddress, (transport, associationEventListenerFuture)) } /** @@ -410,7 +411,7 @@ object TestTransport { * @param address The address bound to the transport. * @return The transport if exists. */ - def transportFor(address: Address): Option[(TestTransport, AssociationEventListener)] = + def transportFor(address: Address): Option[(TestTransport, Future[AssociationEventListener])] = Option(transportTable.get(address)) /**