From 8a354ec3f07bb4028f74df9155dfd7aa415fd1dc Mon Sep 17 00:00:00 2001 From: Arnout Engelen Date: Thu, 20 Feb 2020 13:28:21 +0100 Subject: [PATCH] stream: Improve half-closing of outgoing TCP connections (#28624) Notably fixes the case where upstream finished before the connection was successfully established, and avoids RSTing the incoming stream when the outgoing stream is done (which is now possible due to the cancellation reason being propagated). --- .../main/scala/akka/io/TcpConnection.scala | 1 + .../test/scala/akka/stream/io/TcpSpec.scala | 28 ++++++- .../scala/akka/stream/impl/io/TcpStages.scala | 77 ++++++++++++------- 3 files changed, 75 insertions(+), 31 deletions(-) diff --git a/akka-actor/src/main/scala/akka/io/TcpConnection.scala b/akka-actor/src/main/scala/akka/io/TcpConnection.scala index 089880f7f3..4384444b0e 100644 --- a/akka-actor/src/main/scala/akka/io/TcpConnection.scala +++ b/akka-actor/src/main/scala/akka/io/TcpConnection.scala @@ -142,6 +142,7 @@ private[io] abstract class TcpConnection(val tcp: TcpExt, val channel: SocketCha case SuspendReading => suspendReading(info) case ResumeReading => resumeReading(info) case ChannelReadable => doRead(info, closeCommander) + case Close => doCloseConnection(info.handler, closeCommander, Close.event) case Abort => handleClose(info, Some(sender()), Aborted) } diff --git a/akka-stream-tests/src/test/scala/akka/stream/io/TcpSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/io/TcpSpec.scala index f457c50799..3288c8523f 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/io/TcpSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/io/TcpSpec.scala @@ -469,14 +469,36 @@ class TcpSpec extends StreamSpec(""" server.close() } + "properly half-close by default" in assertAllStagesStopped { + val writeButDontRead: Flow[ByteString, ByteString, NotUsed] = + Flow.fromSinkAndSource(Sink.cancelled, Source.single(ByteString("Early response"))) + + val binding = + Tcp() + .bind("127.0.0.1", 0, halfClose = true) + .toMat(Sink.foreach { conn => + conn.flow.join(writeButDontRead).run() + })(Keep.left) + .run() + .futureValue + + val result = Source.empty + .via(Tcp().outgoingConnection(binding.localAddress)) + .toMat(Sink.fold(ByteString.empty)(_ ++ _))(Keep.right) + .run() + + result.futureValue should ===(ByteString("Early response")) + + binding.unbind() + } + "properly full-close if requested" in assertAllStagesStopped { - val serverAddress = temporaryServerAddress() val writeButIgnoreRead: Flow[ByteString, ByteString, NotUsed] = Flow.fromSinkAndSourceMat(Sink.ignore, Source.single(ByteString("Early response")))(Keep.right) val binding = Tcp() - .bind(serverAddress.getHostString, serverAddress.getPort, halfClose = false) + .bind("127.0.0.1", 0, halfClose = false) .toMat(Sink.foreach { conn => conn.flow.join(writeButIgnoreRead).run() })(Keep.left) @@ -485,7 +507,7 @@ class TcpSpec extends StreamSpec(""" val (promise, result) = Source .maybe[ByteString] - .via(Tcp().outgoingConnection(serverAddress.getHostString, serverAddress.getPort)) + .via(Tcp().outgoingConnection(binding.localAddress)) .toMat(Sink.fold(ByteString.empty)(_ ++ _))(Keep.both) .run() diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala b/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala index b951f924b2..b95e37dd9f 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala @@ -240,6 +240,9 @@ private[stream] object ConnectionSourceStage { private def bytesIn = shape.in private def bytesOut = shape.out + + // Set once (in preStart for inbound connections, in 'connecting' for outbound connections) + // After that remains immutable private var connection: ActorRef = _ @silent("deprecated") @@ -250,7 +253,10 @@ private[stream] object ConnectionSourceStage { .size private var writeBuffer = ByteString.empty + + // there is data in-flight that we accepted from upstream but haven't successfully written to the connection yet private var writeInProgress = false + // upstream already finished but are still writing the last data to the connection private var connectionClosePending = false // No reading until role have been decided @@ -274,6 +280,7 @@ private[stream] object ConnectionSourceStage { } } + // Only used for outbound connections private def connecting(ob: Outbound)(evt: (ActorRef, Any)): Unit = { val sender = evt._1 val msg = evt._2 @@ -290,10 +297,12 @@ private[stream] object ConnectionSourceStage { stageActor.watch(connection) connection ! Register(self, keepOpenOnPeerClosed = true, useResumeWriting = false) if (isAvailable(bytesOut)) connection ! ResumeReading - pull(bytesIn) + if (isClosed(bytesIn)) connection ! ConfirmedClose + else pull(bytesIn) } } + // Used for both inbound and outbound connections private def connected(evt: (ActorRef, Any)): Unit = { val msg = evt._2 msg match { @@ -312,8 +321,7 @@ private[stream] object ConnectionSourceStage { } if (!writeInProgress && connectionClosePending) { - // continue onUpstreamFinish - closeConnection() + closeConnectionUpstreamFinished() } if (!isClosed(bytesIn) && !hasBeenPulled(bytesIn)) @@ -327,12 +335,10 @@ private[stream] object ConnectionSourceStage { case Closed => completeStage() case ConfirmedClosed => completeStage() case PeerClosed => complete(bytesOut) - } } - private def closeConnection(): Unit = { - // Note that if there are pending bytes in the writeBuffer those must be written first. + private def closeConnectionUpstreamFinished(): Unit = { if (isClosed(bytesOut) || !role.halfClose) { // Reading has stopped before, either because of cancel, or PeerClosed, so just Close now // (or half-close is turned off) @@ -346,7 +352,27 @@ private[stream] object ConnectionSourceStage { connectionClosePending = true // will continue when WriteAck is received and writeBuffer drained else connection ! ConfirmedClose - } else completeStage() + } + // Otherwise, this is an outbound connection with half-close enabled for which upstream finished + // before the connection was even established. + // In that case we half-close the connection as soon as it's connected + } + + private def closeConnectionDownstreamFinished(): Unit = { + if (connection == null) { + // This is an outbound connection for which downstream finished + // before the connection was even established. + // In that case we close the connection as soon as upstream finishes + } else { + if (role.halfClose) { + if (isClosed(bytesIn) && !writeInProgress) + connection ! Close + else + connection ! ResumeReading + } else if (!writeInProgress) { + connection ! Close + } + } } val readHandler = new OutHandler { @@ -355,25 +381,21 @@ private[stream] object ConnectionSourceStage { } override def onDownstreamFinish(cause: Throwable): Unit = { - if (!isClosed(bytesIn)) connection ! ResumeReading - else { - if (log.isDebugEnabled) { - cause match { - case _: SubscriptionWithCancelException.NonFailureCancellation => - log.debug( - "Aborting connection from {}:{} because downstream cancelled stream", - remoteAddress.getHostString, - remoteAddress.getPort) - case ex => - log.debug( - "Aborting connection from {}:{} because of downstream failure: {}", - remoteAddress.getHostString, - remoteAddress.getPort, - ex) - } - } - connection ! Abort - completeStage() + cause match { + case _: SubscriptionWithCancelException.NonFailureCancellation => + log.debug( + "Not aborting connection from {}:{} because downstream cancelled stream without failure", + remoteAddress.getHostString, + remoteAddress.getPort) + closeConnectionDownstreamFinished() + case ex => + log.debug( + "Aborting connection from {}:{} because of downstream failure: {}", + remoteAddress.getHostString, + remoteAddress.getPort, + ex) + connection ! Abort + failStage(cause) } } } @@ -393,11 +415,10 @@ private[stream] object ConnectionSourceStage { } if (writeBuffer.size < writeBufferSize) pull(bytesIn) - } override def onUpstreamFinish(): Unit = - closeConnection() + closeConnectionUpstreamFinished() override def onUpstreamFailure(ex: Throwable): Unit = { if (connection != null) {