diff --git a/akka-remote/src/main/scala/akka/remote/Endpoint.scala b/akka-remote/src/main/scala/akka/remote/Endpoint.scala index d977751433..6cf2f0c88a 100644 --- a/akka-remote/src/main/scala/akka/remote/Endpoint.scala +++ b/akka-remote/src/main/scala/akka/remote/Endpoint.scala @@ -83,6 +83,7 @@ class DefaultMessageDispatcher(private val system: ExtendedActorSystem, object EndpointWriter { + case class TakeOver(handle: AssociationHandle) case object BackoffTimer sealed trait State @@ -107,10 +108,12 @@ private[remote] class EndpointWriter( import context.dispatcher val extendedSystem: ExtendedActorSystem = context.system.asInstanceOf[ExtendedActorSystem] + val eventPublisher = new EventPublisher(context.system, log, settings.LogLifecycleEvents) + var reader: ActorRef = null var handle: AssociationHandle = handleOrActive.getOrElse(null) var inbound = false - val eventPublisher = new EventPublisher(context.system, log, settings.LogLifecycleEvents) + var readerId = 0 override val supervisorStrategy = OneForOneStrategy() { case NonFatal(e) ⇒ @@ -178,13 +181,25 @@ private[remote] class EndpointWriter( case NonFatal(e) ⇒ publishAndThrow("Failed to write message to the transport", e) } if (success) stay else { - stash + stash() goto(Buffering) } } whenUnhandled { - case Event(Terminated(r), _) if r == reader ⇒ stop() + case Event(Terminated(r), _) if r == reader ⇒ publishAndThrow("Disassociated", null) + case Event(TakeOver(newHandle), _) ⇒ + // Shutdown old reader + if (handle ne null) handle.disassociate() + if (reader ne null) { + context.unwatch(reader) + context.stop(reader) + } + handle = newHandle + inbound = true + startReadEndpoint() + unstashAll() + goto(Writing) } onTransition { @@ -199,6 +214,7 @@ private[remote] class EndpointWriter( onTermination { case StopEvent(_, _, _) ⇒ if (handle ne null) { + unstashAll() handle.disassociate() eventPublisher.notifyListeners(DisassociatedEvent(localAddress, remoteAddress, inbound)) } @@ -207,6 +223,7 @@ private[remote] class EndpointWriter( private def startReadEndpoint(): Unit = { reader = context.actorOf(Props(new EndpointReader(codec, handle.localAddress, msgDispatch)), "endpointReader-" + URLEncoder.encode(remoteAddress.toString, "utf-8")) + readerId += 1 handle.readHandlerPromise.success(reader) context.watch(reader) } diff --git a/akka-remote/src/main/scala/akka/remote/Remoting.scala b/akka-remote/src/main/scala/akka/remote/Remoting.scala index 6858a8188f..84733431a3 100644 --- a/akka-remote/src/main/scala/akka/remote/Remoting.scala +++ b/akka-remote/src/main/scala/akka/remote/Remoting.scala @@ -189,24 +189,38 @@ private[remote] object EndpointManager { // Not threadsafe -- only to be used in HeadActor private[EndpointManager] class EndpointRegistry { - @volatile private var addressToEndpointAndPolicy = HashMap[Address, EndpointPolicy]() - @volatile private var endpointToAddress = HashMap[ActorRef, Address]() + private var addressToEndpointAndPolicy = HashMap[Address, EndpointPolicy]() + private var endpointToAddress = HashMap[ActorRef, Address]() + private var addressToPassive = HashMap[Address, ActorRef]() def getEndpointWithPolicy(address: Address): Option[EndpointPolicy] = addressToEndpointAndPolicy.get(address) + def hasActiveEndpointFor(address: Address): Boolean = addressToEndpointAndPolicy.get(address) match { + case Some(Pass(_)) ⇒ true + case _ ⇒ false + } + + def passiveEndpointFor(address: Address): Option[ActorRef] = addressToPassive.get(address) + def prune(pruneAge: Long): Unit = { addressToEndpointAndPolicy = addressToEndpointAndPolicy.filter { - case (_, Pass(_)) ⇒ true case (_, Latched(timeOfFailure)) ⇒ timeOfFailure + pruneAge > System.nanoTime() + case _ ⇒ true } } - def registerEndpoint(address: Address, endpoint: ActorRef): ActorRef = { + def registerActiveEndpoint(address: Address, endpoint: ActorRef): ActorRef = { addressToEndpointAndPolicy = addressToEndpointAndPolicy + (address -> Pass(endpoint)) endpointToAddress = endpointToAddress + (endpoint -> address) endpoint } + def registerPassiveEndpoint(address: Address, endpoint: ActorRef): ActorRef = { + addressToPassive = addressToPassive + (address -> endpoint) + endpointToAddress = endpointToAddress + (endpoint -> address) + endpoint + } + def markFailed(endpoint: ActorRef, timeOfFailure: Long): Unit = { addressToEndpointAndPolicy += endpointToAddress(endpoint) -> Latched(timeOfFailure) endpointToAddress = endpointToAddress - endpoint @@ -285,19 +299,29 @@ private[remote] class EndpointManager(conf: Config, log: LoggingAdapter) extends endpoints.getEndpointWithPolicy(recipientAddress) match { case Some(Pass(endpoint)) ⇒ endpoint ! s - case Some(Latched(timeOfFailure)) ⇒ if (retryLatchOpen(timeOfFailure)) - createEndpoint(recipientAddress, recipientRef.localAddressToUse, None) ! s - else extendedSystem.deadLetters ! message + case Some(Latched(timeOfFailure)) ⇒ if (retryLatchOpen(timeOfFailure)) { + val endpoint = createEndpoint(recipientAddress, recipientRef.localAddressToUse, None) + endpoints.registerActiveEndpoint(recipientAddress, endpoint) + endpoint ! s + } else extendedSystem.deadLetters ! message case Some(Quarantined(_)) ⇒ extendedSystem.deadLetters ! message - case None ⇒ createEndpoint(recipientAddress, recipientRef.localAddressToUse, None) ! s + case None ⇒ + val endpoint = createEndpoint(recipientAddress, recipientRef.localAddressToUse, None) + endpoints.registerActiveEndpoint(recipientAddress, endpoint) + endpoint ! s } - case InboundAssociation(handle) ⇒ - val endpoint = createEndpoint(handle.remoteAddress, handle.localAddress, Some(handle)) - eventPublisher.notifyListeners(AssociatedEvent(handle.localAddress, handle.remoteAddress, true)) - if (settings.UsePassiveConnections) endpoints.registerEndpoint(handle.localAddress, endpoint) - case Terminated(endpoint) ⇒ endpoints.removeIfNotLatched(endpoint) + case InboundAssociation(handle) ⇒ endpoints.passiveEndpointFor(handle.remoteAddress) match { + case Some(endpoint) ⇒ endpoint ! EndpointWriter.TakeOver(handle) + case None ⇒ + val endpoint = createEndpoint(handle.remoteAddress, handle.localAddress, Some(handle)) + eventPublisher.notifyListeners(AssociatedEvent(handle.localAddress, handle.remoteAddress, true)) + if (settings.UsePassiveConnections && !endpoints.hasActiveEndpointFor(handle.remoteAddress)) { + endpoints.registerActiveEndpoint(handle.remoteAddress, endpoint) + } else endpoints.registerPassiveEndpoint(handle.remoteAddress, endpoint) + } + case Terminated(endpoint) ⇒ endpoints.removeIfNotLatched(endpoint); case Prune ⇒ endpoints.prune(settings.RetryLatchClosedFor) } @@ -363,7 +387,8 @@ private[remote] class EndpointManager(conf: Config, log: LoggingAdapter) extends .withDispatcher("akka.remoting.writer-dispatcher"), "endpointWriter-" + URLEncoder.encode(remoteAddress.toString, "utf-8") + "-" + endpointId) - endpoints.registerEndpoint(remoteAddress, endpoint) + context.watch(endpoint) // TODO: see what to do with this + } private def retryLatchOpen(timeOfFailure: Long): Boolean = (timeOfFailure + settings.RetryLatchClosedFor) < System.nanoTime()