stream: Improve half-closing of outgoing TCP connections (#28624)

Notably fixes the case where upstream finished before the connection
was successfully established, and avoids RSTing the incoming stream
when the outgoing stream is done (which is now possible due to the
cancellation reason being propagated).
This commit is contained in:
Arnout Engelen 2020-02-20 13:28:21 +01:00 committed by GitHub
parent cff15cf40d
commit 8a354ec3f0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 75 additions and 31 deletions

View file

@ -142,6 +142,7 @@ private[io] abstract class TcpConnection(val tcp: TcpExt, val channel: SocketCha
case SuspendReading => suspendReading(info)
case ResumeReading => resumeReading(info)
case ChannelReadable => doRead(info, closeCommander)
case Close => doCloseConnection(info.handler, closeCommander, Close.event)
case Abort => handleClose(info, Some(sender()), Aborted)
}

View file

@ -469,14 +469,36 @@ class TcpSpec extends StreamSpec("""
server.close()
}
"properly half-close by default" in assertAllStagesStopped {
val writeButDontRead: Flow[ByteString, ByteString, NotUsed] =
Flow.fromSinkAndSource(Sink.cancelled, Source.single(ByteString("Early response")))
val binding =
Tcp()
.bind("127.0.0.1", 0, halfClose = true)
.toMat(Sink.foreach { conn =>
conn.flow.join(writeButDontRead).run()
})(Keep.left)
.run()
.futureValue
val result = Source.empty
.via(Tcp().outgoingConnection(binding.localAddress))
.toMat(Sink.fold(ByteString.empty)(_ ++ _))(Keep.right)
.run()
result.futureValue should ===(ByteString("Early response"))
binding.unbind()
}
"properly full-close if requested" in assertAllStagesStopped {
val serverAddress = temporaryServerAddress()
val writeButIgnoreRead: Flow[ByteString, ByteString, NotUsed] =
Flow.fromSinkAndSourceMat(Sink.ignore, Source.single(ByteString("Early response")))(Keep.right)
val binding =
Tcp()
.bind(serverAddress.getHostString, serverAddress.getPort, halfClose = false)
.bind("127.0.0.1", 0, halfClose = false)
.toMat(Sink.foreach { conn =>
conn.flow.join(writeButIgnoreRead).run()
})(Keep.left)
@ -485,7 +507,7 @@ class TcpSpec extends StreamSpec("""
val (promise, result) = Source
.maybe[ByteString]
.via(Tcp().outgoingConnection(serverAddress.getHostString, serverAddress.getPort))
.via(Tcp().outgoingConnection(binding.localAddress))
.toMat(Sink.fold(ByteString.empty)(_ ++ _))(Keep.both)
.run()

View file

@ -240,6 +240,9 @@ private[stream] object ConnectionSourceStage {
private def bytesIn = shape.in
private def bytesOut = shape.out
// Set once (in preStart for inbound connections, in 'connecting' for outbound connections)
// After that remains immutable
private var connection: ActorRef = _
@silent("deprecated")
@ -250,7 +253,10 @@ private[stream] object ConnectionSourceStage {
.size
private var writeBuffer = ByteString.empty
// there is data in-flight that we accepted from upstream but haven't successfully written to the connection yet
private var writeInProgress = false
// upstream already finished but are still writing the last data to the connection
private var connectionClosePending = false
// No reading until role have been decided
@ -274,6 +280,7 @@ private[stream] object ConnectionSourceStage {
}
}
// Only used for outbound connections
private def connecting(ob: Outbound)(evt: (ActorRef, Any)): Unit = {
val sender = evt._1
val msg = evt._2
@ -290,10 +297,12 @@ private[stream] object ConnectionSourceStage {
stageActor.watch(connection)
connection ! Register(self, keepOpenOnPeerClosed = true, useResumeWriting = false)
if (isAvailable(bytesOut)) connection ! ResumeReading
pull(bytesIn)
if (isClosed(bytesIn)) connection ! ConfirmedClose
else pull(bytesIn)
}
}
// Used for both inbound and outbound connections
private def connected(evt: (ActorRef, Any)): Unit = {
val msg = evt._2
msg match {
@ -312,8 +321,7 @@ private[stream] object ConnectionSourceStage {
}
if (!writeInProgress && connectionClosePending) {
// continue onUpstreamFinish
closeConnection()
closeConnectionUpstreamFinished()
}
if (!isClosed(bytesIn) && !hasBeenPulled(bytesIn))
@ -327,12 +335,10 @@ private[stream] object ConnectionSourceStage {
case Closed => completeStage()
case ConfirmedClosed => completeStage()
case PeerClosed => complete(bytesOut)
}
}
private def closeConnection(): Unit = {
// Note that if there are pending bytes in the writeBuffer those must be written first.
private def closeConnectionUpstreamFinished(): Unit = {
if (isClosed(bytesOut) || !role.halfClose) {
// Reading has stopped before, either because of cancel, or PeerClosed, so just Close now
// (or half-close is turned off)
@ -346,7 +352,27 @@ private[stream] object ConnectionSourceStage {
connectionClosePending = true // will continue when WriteAck is received and writeBuffer drained
else
connection ! ConfirmedClose
} else completeStage()
}
// Otherwise, this is an outbound connection with half-close enabled for which upstream finished
// before the connection was even established.
// In that case we half-close the connection as soon as it's connected
}
private def closeConnectionDownstreamFinished(): Unit = {
if (connection == null) {
// This is an outbound connection for which downstream finished
// before the connection was even established.
// In that case we close the connection as soon as upstream finishes
} else {
if (role.halfClose) {
if (isClosed(bytesIn) && !writeInProgress)
connection ! Close
else
connection ! ResumeReading
} else if (!writeInProgress) {
connection ! Close
}
}
}
val readHandler = new OutHandler {
@ -355,25 +381,21 @@ private[stream] object ConnectionSourceStage {
}
override def onDownstreamFinish(cause: Throwable): Unit = {
if (!isClosed(bytesIn)) connection ! ResumeReading
else {
if (log.isDebugEnabled) {
cause match {
case _: SubscriptionWithCancelException.NonFailureCancellation =>
log.debug(
"Aborting connection from {}:{} because downstream cancelled stream",
"Not aborting connection from {}:{} because downstream cancelled stream without failure",
remoteAddress.getHostString,
remoteAddress.getPort)
closeConnectionDownstreamFinished()
case ex =>
log.debug(
"Aborting connection from {}:{} because of downstream failure: {}",
remoteAddress.getHostString,
remoteAddress.getPort,
ex)
}
}
connection ! Abort
completeStage()
failStage(cause)
}
}
}
@ -393,11 +415,10 @@ private[stream] object ConnectionSourceStage {
}
if (writeBuffer.size < writeBufferSize)
pull(bytesIn)
}
override def onUpstreamFinish(): Unit =
closeConnection()
closeConnectionUpstreamFinished()
override def onUpstreamFailure(ex: Throwable): Unit = {
if (connection != null) {