diff --git a/akka-actor/src/main/scala/akka/io/TcpConnection.scala b/akka-actor/src/main/scala/akka/io/TcpConnection.scala index 089880f7f3..4384444b0e 100644 --- a/akka-actor/src/main/scala/akka/io/TcpConnection.scala +++ b/akka-actor/src/main/scala/akka/io/TcpConnection.scala @@ -142,6 +142,7 @@ private[io] abstract class TcpConnection(val tcp: TcpExt, val channel: SocketCha case SuspendReading => suspendReading(info) case ResumeReading => resumeReading(info) case ChannelReadable => doRead(info, closeCommander) + case Close => doCloseConnection(info.handler, closeCommander, Close.event) case Abort => handleClose(info, Some(sender()), Aborted) } diff --git a/akka-stream-tests/src/test/scala/akka/stream/io/TcpSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/io/TcpSpec.scala index f457c50799..3288c8523f 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/io/TcpSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/io/TcpSpec.scala @@ -469,14 +469,36 @@ class TcpSpec extends StreamSpec(""" server.close() } + "properly half-close by default" in assertAllStagesStopped { + val writeButDontRead: Flow[ByteString, ByteString, NotUsed] = + Flow.fromSinkAndSource(Sink.cancelled, Source.single(ByteString("Early response"))) + + val binding = + Tcp() + .bind("127.0.0.1", 0, halfClose = true) + .toMat(Sink.foreach { conn => + conn.flow.join(writeButDontRead).run() + })(Keep.left) + .run() + .futureValue + + val result = Source.empty + .via(Tcp().outgoingConnection(binding.localAddress)) + .toMat(Sink.fold(ByteString.empty)(_ ++ _))(Keep.right) + .run() + + result.futureValue should ===(ByteString("Early response")) + + binding.unbind() + } + "properly full-close if requested" in assertAllStagesStopped { - val serverAddress = temporaryServerAddress() val writeButIgnoreRead: Flow[ByteString, ByteString, NotUsed] = Flow.fromSinkAndSourceMat(Sink.ignore, Source.single(ByteString("Early response")))(Keep.right) val binding = Tcp() - .bind(serverAddress.getHostString, serverAddress.getPort, halfClose = false) + .bind("127.0.0.1", 0, halfClose = false) .toMat(Sink.foreach { conn => conn.flow.join(writeButIgnoreRead).run() })(Keep.left) @@ -485,7 +507,7 @@ class TcpSpec extends StreamSpec(""" val (promise, result) = Source .maybe[ByteString] - .via(Tcp().outgoingConnection(serverAddress.getHostString, serverAddress.getPort)) + .via(Tcp().outgoingConnection(binding.localAddress)) .toMat(Sink.fold(ByteString.empty)(_ ++ _))(Keep.both) .run() diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala b/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala index b951f924b2..b95e37dd9f 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala @@ -240,6 +240,9 @@ private[stream] object ConnectionSourceStage { private def bytesIn = shape.in private def bytesOut = shape.out + + // Set once (in preStart for inbound connections, in 'connecting' for outbound connections) + // After that remains immutable private var connection: ActorRef = _ @silent("deprecated") @@ -250,7 +253,10 @@ private[stream] object ConnectionSourceStage { .size private var writeBuffer = ByteString.empty + + // there is data in-flight that we accepted from upstream but haven't successfully written to the connection yet private var writeInProgress = false + // upstream already finished but are still writing the last data to the connection private var connectionClosePending = false // No reading until role have been decided @@ -274,6 +280,7 @@ private[stream] object ConnectionSourceStage { } } + // Only used for outbound connections private def connecting(ob: Outbound)(evt: (ActorRef, Any)): Unit = { val sender = evt._1 val msg = evt._2 @@ -290,10 +297,12 @@ private[stream] object ConnectionSourceStage { stageActor.watch(connection) connection ! Register(self, keepOpenOnPeerClosed = true, useResumeWriting = false) if (isAvailable(bytesOut)) connection ! ResumeReading - pull(bytesIn) + if (isClosed(bytesIn)) connection ! ConfirmedClose + else pull(bytesIn) } } + // Used for both inbound and outbound connections private def connected(evt: (ActorRef, Any)): Unit = { val msg = evt._2 msg match { @@ -312,8 +321,7 @@ private[stream] object ConnectionSourceStage { } if (!writeInProgress && connectionClosePending) { - // continue onUpstreamFinish - closeConnection() + closeConnectionUpstreamFinished() } if (!isClosed(bytesIn) && !hasBeenPulled(bytesIn)) @@ -327,12 +335,10 @@ private[stream] object ConnectionSourceStage { case Closed => completeStage() case ConfirmedClosed => completeStage() case PeerClosed => complete(bytesOut) - } } - private def closeConnection(): Unit = { - // Note that if there are pending bytes in the writeBuffer those must be written first. + private def closeConnectionUpstreamFinished(): Unit = { if (isClosed(bytesOut) || !role.halfClose) { // Reading has stopped before, either because of cancel, or PeerClosed, so just Close now // (or half-close is turned off) @@ -346,7 +352,27 @@ private[stream] object ConnectionSourceStage { connectionClosePending = true // will continue when WriteAck is received and writeBuffer drained else connection ! ConfirmedClose - } else completeStage() + } + // Otherwise, this is an outbound connection with half-close enabled for which upstream finished + // before the connection was even established. + // In that case we half-close the connection as soon as it's connected + } + + private def closeConnectionDownstreamFinished(): Unit = { + if (connection == null) { + // This is an outbound connection for which downstream finished + // before the connection was even established. + // In that case we close the connection as soon as upstream finishes + } else { + if (role.halfClose) { + if (isClosed(bytesIn) && !writeInProgress) + connection ! Close + else + connection ! ResumeReading + } else if (!writeInProgress) { + connection ! Close + } + } } val readHandler = new OutHandler { @@ -355,25 +381,21 @@ private[stream] object ConnectionSourceStage { } override def onDownstreamFinish(cause: Throwable): Unit = { - if (!isClosed(bytesIn)) connection ! ResumeReading - else { - if (log.isDebugEnabled) { - cause match { - case _: SubscriptionWithCancelException.NonFailureCancellation => - log.debug( - "Aborting connection from {}:{} because downstream cancelled stream", - remoteAddress.getHostString, - remoteAddress.getPort) - case ex => - log.debug( - "Aborting connection from {}:{} because of downstream failure: {}", - remoteAddress.getHostString, - remoteAddress.getPort, - ex) - } - } - connection ! Abort - completeStage() + cause match { + case _: SubscriptionWithCancelException.NonFailureCancellation => + log.debug( + "Not aborting connection from {}:{} because downstream cancelled stream without failure", + remoteAddress.getHostString, + remoteAddress.getPort) + closeConnectionDownstreamFinished() + case ex => + log.debug( + "Aborting connection from {}:{} because of downstream failure: {}", + remoteAddress.getHostString, + remoteAddress.getPort, + ex) + connection ! Abort + failStage(cause) } } } @@ -393,11 +415,10 @@ private[stream] object ConnectionSourceStage { } if (writeBuffer.size < writeBufferSize) pull(bytesIn) - } override def onUpstreamFinish(): Unit = - closeConnection() + closeConnectionUpstreamFinished() override def onUpstreamFailure(ex: Throwable): Unit = { if (connection != null) {