diff --git a/akka-remote/src/main/scala/akka/remote/Endpoint.scala b/akka-remote/src/main/scala/akka/remote/Endpoint.scala index d7a873cec8..62412692cb 100644 --- a/akka-remote/src/main/scala/akka/remote/Endpoint.scala +++ b/akka-remote/src/main/scala/akka/remote/Endpoint.scala @@ -243,7 +243,8 @@ private[remote] class ReliableDeliverySupervisor( case Terminated(_) ⇒ currentHandle = None context.become(idle) - case GotUid(u) ⇒ uid = Some(u) + case GotUid(u) ⇒ uid = Some(u) + case s: EndpointWriter.StopReading ⇒ writer forward s } def gated: Receive = { @@ -255,7 +256,8 @@ private[remote] class ReliableDeliverySupervisor( } else context.become(idle) case s @ Send(msg: SystemMessage, _, _, _) ⇒ tryBuffer(s.copy(seqOpt = Some(nextSeq()))) case s: Send ⇒ context.system.deadLetters ! s - case FlushAndStop ⇒ context.stop(self) + case EndpointWriter.FlushAndStop ⇒ context.stop(self) + case EndpointWriter.StopReading(w) ⇒ sender ! EndpointWriter.StoppedReading(w) case _ ⇒ // Ignore } @@ -265,7 +267,8 @@ private[remote] class ReliableDeliverySupervisor( resendAll() handleSend(s) context.become(receive) - case FlushAndStop ⇒ context.stop(self) + case EndpointWriter.FlushAndStop ⇒ context.stop(self) + case EndpointWriter.StopReading(w) ⇒ sender ! EndpointWriter.StoppedReading(w) } def flushWait: Receive = { @@ -362,6 +365,8 @@ private[remote] object EndpointWriter { case object BackoffTimer case object FlushAndStop case object AckIdleCheckTimer + case class StopReading(writer: ActorRef) + case class StoppedReading(writer: ActorRef) case class OutboundAck(ack: Ack) @@ -533,15 +538,23 @@ private[remote] class EndpointWriter( stash() stay() + case _: StopReading ⇒ + stash() + stay() } whenUnhandled { case Event(Terminated(r), _) if r == reader.orNull ⇒ publishAndThrow(new EndpointDisassociatedException("Disassociated")) + case Event(s: StopReading, _) ⇒ + reader match { + case Some(r) ⇒ r forward s + case None ⇒ stash() + } + stay() case Event(TakeOver(newHandle), _) ⇒ // Shutdown old reader handle foreach { _.disassociate() } - reader foreach context.stop handle = Some(newHandle) goto(Handoff) case Event(FlushAndStop, _) ⇒ @@ -632,7 +645,7 @@ private[remote] class EndpointReader( val reliableDeliverySupervisor: Option[ActorRef], val receiveBuffers: ConcurrentHashMap[Link, AckedReceiveBuffer[Message]]) extends EndpointActor(localAddress, remoteAddress, transport, settings, codec) { - import EndpointWriter.OutboundAck + import EndpointWriter.{ OutboundAck, StopReading, StoppedReading } val provider = RARP(context.system).provider var ackedReceiveBuffer = new AckedReceiveBuffer[Message] @@ -646,8 +659,9 @@ private[remote] class EndpointReader( } } - override def postStop(): Unit = { + override def postStop(): Unit = saveState() + def saveState(): Unit = { @tailrec def updateSavedState(key: Link, expectedState: AckedReceiveBuffer[Message]): Unit = { if (expectedState eq null) { @@ -661,7 +675,8 @@ private[remote] class EndpointReader( } override def receive: Receive = { - case Disassociated ⇒ context.stop(self) + case Disassociated ⇒ + context.stop(self) case InboundPayload(p) if p.size <= transport.maximumPayloadBytes ⇒ val (ackOption, msgOption) = tryDecodeMessageAndAck(p) @@ -682,6 +697,24 @@ private[remote] class EndpointReader( publishError(new OversizedPayloadException(s"Discarding oversized payload received: " + s"max allowed size [${transport.maximumPayloadBytes}] bytes, actual size [${oversized.size}] bytes.")) + case StopReading(writer) ⇒ + saveState() + context.become(notReading) + sender ! StoppedReading(writer) + + } + + def notReading: Receive = { + case Disassociated ⇒ context.stop(self) + + case StopReading(newHandle) ⇒ + sender ! StoppedReading(newHandle) + + case InboundPayload(p) ⇒ + val (ackOption, _) = tryDecodeMessageAndAck(p) + for (ack ← ackOption; reliableDelivery ← reliableDeliverySupervisor) reliableDelivery ! ack + + case _ ⇒ } private def deliverAndAck(): Unit = { diff --git a/akka-remote/src/main/scala/akka/remote/Remoting.scala b/akka-remote/src/main/scala/akka/remote/Remoting.scala index b5edd7381e..d66750425c 100644 --- a/akka-remote/src/main/scala/akka/remote/Remoting.scala +++ b/akka-remote/src/main/scala/akka/remote/Remoting.scala @@ -374,6 +374,8 @@ private[remote] class EndpointManager(conf: Config, log: LoggingAdapter) extends Some(context.system.scheduler.schedule(pruneInterval, pruneInterval, self, Prune)) else None + var pendingReadHandoffs = Map[ActorRef, AkkaProtocolHandle]() + override val supervisorStrategy = OneForOneStrategy(loggingEnabled = false) { case InvalidAssociation(localAddress, remoteAddress, _) ⇒ @@ -505,34 +507,42 @@ private[remote] class EndpointManager(conf: Config, log: LoggingAdapter) extends } case InboundAssociation(handle: AkkaProtocolHandle) ⇒ endpoints.readOnlyEndpointFor(handle.remoteAddress) match { - case Some(endpoint) ⇒ endpoint ! EndpointWriter.TakeOver(handle) + case Some(endpoint) ⇒ + endpoint ! EndpointWriter.TakeOver(handle) case None ⇒ if (endpoints.isQuarantined(handle.remoteAddress, handle.handshakeInfo.uid)) handle.disassociate() - else { - val writing = settings.UsePassiveConnections && !endpoints.hasWritableEndpointFor(handle.remoteAddress) - eventPublisher.notifyListeners(AssociatedEvent(handle.localAddress, handle.remoteAddress, true)) - val endpoint = createEndpoint( - handle.remoteAddress, - handle.localAddress, - transportMapping(handle.localAddress), - settings, - Some(handle), - writing) - if (writing) - endpoints.registerWritableEndpoint(handle.remoteAddress, endpoint) - else { - endpoints.registerReadOnlyEndpoint(handle.remoteAddress, endpoint) - endpoints.writableEndpointWithPolicyFor(handle.remoteAddress) match { - case Some(Pass(_)) ⇒ // Leave it alone - case _ ⇒ - // Since we just communicated with the guy we can lift gate, quarantine, etc. New writer will be - // opened at first write. - endpoints.removePolicy(handle.remoteAddress) + else endpoints.writableEndpointWithPolicyFor(handle.remoteAddress) match { + case Some(Pass(ep)) ⇒ + pendingReadHandoffs += ep -> handle + ep ! EndpointWriter.StopReading(ep) + case _ ⇒ + val writing = settings.UsePassiveConnections && !endpoints.hasWritableEndpointFor(handle.remoteAddress) + eventPublisher.notifyListeners(AssociatedEvent(handle.localAddress, handle.remoteAddress, true)) + val endpoint = createEndpoint( + handle.remoteAddress, + handle.localAddress, + transportMapping(handle.localAddress), + settings, + Some(handle), + writing) + if (writing) + endpoints.registerWritableEndpoint(handle.remoteAddress, endpoint) + else { + endpoints.registerReadOnlyEndpoint(handle.remoteAddress, endpoint) + endpoints.writableEndpointWithPolicyFor(handle.remoteAddress) match { + case Some(Pass(_)) ⇒ // Leave it alone + case _ ⇒ + // Since we just communicated with the guy we can lift gate, quarantine, etc. New writer will be + // opened at first write. + endpoints.removePolicy(handle.remoteAddress) + } } - } } } + case EndpointWriter.StoppedReading(endpoint) ⇒ + acceptPendingReader(takingOverFrom = endpoint) case Terminated(endpoint) ⇒ + acceptPendingReader(takingOverFrom = endpoint) endpoints.unregisterEndpoint(endpoint) case Prune ⇒ endpoints.prune() @@ -609,6 +619,22 @@ private[remote] class EndpointManager(conf: Config, log: LoggingAdapter) extends }) } + private def acceptPendingReader(takingOverFrom: ActorRef): Unit = { + if (pendingReadHandoffs.contains(takingOverFrom)) { + val handle = pendingReadHandoffs(takingOverFrom) + pendingReadHandoffs -= takingOverFrom + eventPublisher.notifyListeners(AssociatedEvent(handle.localAddress, handle.remoteAddress, true)) + val endpoint = createEndpoint( + handle.remoteAddress, + handle.localAddress, + transportMapping(handle.localAddress), + settings, + Some(handle), + writing = false) + endpoints.registerReadOnlyEndpoint(handle.remoteAddress, endpoint) + } + } + private def createEndpoint(remoteAddress: Address, localAddress: Address, transport: Transport,