diff --git a/akka-remote/src/main/scala/akka/remote/Endpoint.scala b/akka-remote/src/main/scala/akka/remote/Endpoint.scala index 2b2676c9eb..639084e812 100644 --- a/akka-remote/src/main/scala/akka/remote/Endpoint.scala +++ b/akka-remote/src/main/scala/akka/remote/Endpoint.scala @@ -134,8 +134,7 @@ private[remote] class EndpointWriter( override val supervisorStrategy = OneForOneStrategy() { case NonFatal(e) ⇒ publishAndThrow(e) } - val msgDispatch = - new DefaultMessageDispatcher(extendedSystem, RARP(extendedSystem).provider, log) + val msgDispatch = new DefaultMessageDispatcher(extendedSystem, RARP(extendedSystem).provider, log) def inbound = handle.isDefined @@ -151,15 +150,17 @@ private[remote] class EndpointWriter( preStart() } - override def preStart(): Unit = { - if (!inbound) { - transport.associate(remoteAddress) pipeTo self - startWith(Initializing, ()) - } else { - startReadEndpoint() - startWith(Writing, ()) - } - } + override def preStart(): Unit = + startWith( + handle match { + case Some(h) ⇒ + reader = startReadEndpoint(h) + Writing + case None ⇒ + transport.associate(remoteAddress) pipeTo self + Initializing + }, + ()) when(Initializing) { case Event(Send(msg, senderOption, recipient), _) ⇒ @@ -172,8 +173,9 @@ private[remote] class EndpointWriter( case Event(Status.Failure(e), _) ⇒ publishAndThrow(new EndpointException(s"Association failed with [$remoteAddress]", e)) case Event(inboundHandle: AssociationHandle, _) ⇒ + // Assert handle == None? handle = Some(inboundHandle) - startReadEndpoint() + reader = startReadEndpoint(inboundHandle) goto(Writing) } @@ -188,19 +190,20 @@ private[remote] class EndpointWriter( when(Writing) { case Event(Send(msg, senderOption, recipient), _) ⇒ - val pdu = codec.constructMessage(recipient.localAddressToUse, recipient, serializeMessage(msg), senderOption) - val success = try { + try { handle match { - case Some(h) ⇒ h.write(pdu) - case None ⇒ throw new EndpointException("Internal error: Endpoint is in state Writing, but no association" + - "handle is present.") + case Some(h) ⇒ + val pdu = codec.constructMessage(recipient.localAddressToUse, recipient, serializeMessage(msg), senderOption) + if (h.write(pdu)) stay() else { + stash() + goto(Buffering) + } + case None ⇒ + throw new EndpointException("Internal error: Endpoint is in state Writing, but no association handle is present.") } } catch { - case NonFatal(e) ⇒ publishAndThrow(new EndpointException("Failed to write message to the transport", e)) - } - if (success) stay() else { - stash() - goto(Buffering) + case NonFatal(e: EndpointException) ⇒ publishAndThrow(e) + case NonFatal(e) ⇒ publishAndThrow(new EndpointException("Failed to write message to the transport", e)) } } @@ -209,14 +212,9 @@ private[remote] class EndpointWriter( case Event(TakeOver(newHandle), _) ⇒ // Shutdown old reader handle foreach { _.disassociate() } - reader match { - case Some(r) ⇒ - context.unwatch(r) - context.stop(r) - case None ⇒ - } + reader foreach { r ⇒ context stop context.unwatch(r) } handle = Some(newHandle) - startReadEndpoint() + reader = startReadEndpoint(newHandle) unstashAll() goto(Writing) } @@ -242,17 +240,15 @@ private[remote] class EndpointWriter( eventPublisher.notifyListeners(DisassociatedEvent(localAddress, remoteAddress, inbound)) } - private def startReadEndpoint(): Unit = handle match { - case Some(h) ⇒ - val readerLocalAddress = h.localAddress - val readerCodec = codec - val readerDispatcher = msgDispatch - reader = Some( - context.watch(context.actorOf(Props(new EndpointReader(readerCodec, readerLocalAddress, readerDispatcher)), - "endpointReader-" + AddressUrlEncoder(remoteAddress) + "-" + readerId.next()))) - h.readHandlerPromise.success(ActorHandleEventListener(reader.get)) - case None ⇒ throw new EndpointException("Internal error: No handle was present during creation of the endpoint" + - "reader.") + private def startReadEndpoint(handle: AssociationHandle): Some[ActorRef] = { + val readerLocalAddress = handle.localAddress + val readerCodec = codec + val readerDispatcher = msgDispatch + val newReader = + context.watch(context.actorOf(Props(new EndpointReader(readerCodec, readerLocalAddress, readerDispatcher)), + "endpointReader-" + AddressUrlEncoder(remoteAddress) + "-" + readerId.next())) + handle.readHandlerPromise.success(ActorHandleEventListener(newReader)) + Some(newReader) } private def serializeMessage(msg: Any): MessageProtocol = handle match {