diff --git a/akka-actor-tests/src/test/scala/akka/io/TcpConnectionSpec.scala b/akka-actor-tests/src/test/scala/akka/io/TcpConnectionSpec.scala index b34fa5fe53..6044a56015 100644 --- a/akka-actor-tests/src/test/scala/akka/io/TcpConnectionSpec.scala +++ b/akka-actor-tests/src/test/scala/akka/io/TcpConnectionSpec.scala @@ -389,6 +389,39 @@ class TcpConnectionSpec extends AkkaSpec("akka.io.tcp.register-timeout = 500ms") selector.send(connectionActor, ChannelReadable) connectionHandler.expectMsg(PeerClosed) + connectionHandler.send(connectionActor, Close) + + assertThisConnectionActorTerminated() + } + "report when peer closed the connection but allow further writes and acknowledge normal close" in withEstablishedConnection() { setup ⇒ + import setup._ + + closeServerSideAndWaitForClientReadable(fullClose = false) // send EOF (fin) from the server side + + selector.send(connectionActor, ChannelReadable) + connectionHandler.expectMsg(PeerClosed) + object Ack + connectionHandler.send(connectionActor, writeCmd(Ack)) + pullFromServerSide(TestSize) + connectionHandler.expectMsg(Ack) + connectionHandler.send(connectionActor, Close) + connectionHandler.expectMsg(Closed) + + assertThisConnectionActorTerminated() + } + "report when peer closed the connection but allow further writes and acknowledge confirmed close" in withEstablishedConnection() { setup ⇒ + import setup._ + + closeServerSideAndWaitForClientReadable(fullClose = false) // send EOF (fin) from the server side + + selector.send(connectionActor, ChannelReadable) + connectionHandler.expectMsg(PeerClosed) + object Ack + connectionHandler.send(connectionActor, writeCmd(Ack)) + pullFromServerSide(TestSize) + connectionHandler.expectMsg(Ack) + connectionHandler.send(connectionActor, ConfirmedClose) + connectionHandler.expectMsg(ConfirmedClosed) assertThisConnectionActorTerminated() } @@ -535,8 +568,8 @@ class TcpConnectionSpec extends AkkaSpec("akka.io.tcp.register-timeout = 500ms") val clientSelectionKey = registerChannel(clientSideChannel, "client") val serverSelectionKey = registerChannel(serverSideChannel, "server") - def closeServerSideAndWaitForClientReadable(): Unit = { - serverSideChannel.close() + def closeServerSideAndWaitForClientReadable(fullClose: Boolean = true): Unit = { + if (fullClose) serverSideChannel.close() else serverSideChannel.socket.shutdownOutput() checkFor(clientSelectionKey, SelectionKey.OP_READ, 3.seconds.toMillis.toInt) must be(true) } diff --git a/akka-actor-tests/src/test/scala/akka/io/TcpIntegrationSpec.scala b/akka-actor-tests/src/test/scala/akka/io/TcpIntegrationSpec.scala index cac2ab5a88..9b7326d5a6 100644 --- a/akka-actor-tests/src/test/scala/akka/io/TcpIntegrationSpec.scala +++ b/akka-actor-tests/src/test/scala/akka/io/TcpIntegrationSpec.scala @@ -23,6 +23,8 @@ class TcpIntegrationSpec extends AkkaSpec("akka.loglevel = INFO") with TcpIntegr clientHandler.send(clientConnection, Close) clientHandler.expectMsg(Closed) serverHandler.expectMsg(PeerClosed) + serverHandler.send(serverConnection, Close) + serverHandler.expectMsg(Closed) verifyActorTermination(clientConnection) verifyActorTermination(serverConnection) } @@ -52,6 +54,8 @@ class TcpIntegrationSpec extends AkkaSpec("akka.loglevel = INFO") with TcpIntegr serverHandler.send(serverConnection, Close) serverHandler.expectMsg(Closed) clientHandler.expectMsg(PeerClosed) + clientHandler.send(clientConnection, Close) + clientHandler.expectMsg(Closed) verifyActorTermination(clientConnection) verifyActorTermination(serverConnection) diff --git a/akka-actor/src/main/scala/akka/io/TcpConnection.scala b/akka-actor/src/main/scala/akka/io/TcpConnection.scala index c5cad36f7c..7bb66f64b4 100644 --- a/akka-actor/src/main/scala/akka/io/TcpConnection.scala +++ b/akka-actor/src/main/scala/akka/io/TcpConnection.scala @@ -61,25 +61,16 @@ private[io] abstract class TcpConnection(val channel: SocketChannel, } /** normal connected state */ - def connected(handler: ActorRef): Receive = { - case StopReading ⇒ selector ! DisableReadInterest - case ResumeReading ⇒ selector ! ReadInterest - case ChannelReadable ⇒ doRead(handler, None) + def connected(handler: ActorRef): Receive = handleWriteMessages(handler) orElse { + case StopReading ⇒ selector ! DisableReadInterest + case ResumeReading ⇒ selector ! ReadInterest + case ChannelReadable ⇒ doRead(handler, None) - case write: Write if writePending ⇒ - if (TraceLogging) log.debug("Dropping write because queue is full") - sender ! write.failureMessage - - case write: Write if write.data.isEmpty ⇒ - if (write.wantsAck) - sender ! write.ack - - case write: Write ⇒ - pendingWrite = createWrite(write) - doWrite(handler) - - case ChannelWritable ⇒ if (writePending) doWrite(handler) + case cmd: CloseCommand ⇒ handleClose(handler, Some(sender), cmd.event) + } + /** the peer sent EOF first, but we may still want to send */ + def peerSentEOF(handler: ActorRef): Receive = handleWriteMessages(handler) orElse { case cmd: CloseCommand ⇒ handleClose(handler, Some(sender), cmd.event) } @@ -105,6 +96,22 @@ private[io] abstract class TcpConnection(val channel: SocketChannel, case Abort ⇒ handleClose(handler, Some(sender), Aborted) } + def handleWriteMessages(handler: ActorRef): Receive = { + case ChannelWritable ⇒ if (writePending) doWrite(handler) + + case write: Write if writePending ⇒ + if (TraceLogging) log.debug("Dropping write because queue is full") + sender ! write.failureMessage + + case write: Write if write.data.isEmpty ⇒ + if (write.wantsAck) + sender ! write.ack + + case write: Write ⇒ + pendingWrite = createWrite(write) + doWrite(handler) + } + // AUXILIARIES and IMPLEMENTATION /** used in subclasses to start the common machinery above once a channel is connected */ @@ -147,9 +154,12 @@ private[io] abstract class TcpConnection(val channel: SocketChannel, try innerRead(buffer, ReceivedMessageSizeLimit) match { case AllRead ⇒ selector ! ReadInterest case MoreDataWaiting ⇒ self ! ChannelReadable + case EndOfStream if channel.socket.isOutputShutdown ⇒ + if (TraceLogging) log.debug("Read returned end-of-stream, our side already closed") + doCloseConnection(handler, closeCommander, ConfirmedClosed) case EndOfStream ⇒ - if (TraceLogging) log.debug("Read returned end-of-stream") - doCloseConnection(handler, closeCommander, closeReason) + if (TraceLogging) log.debug("Read returned end-of-stream, our side not yet closed") + handleClose(handler, closeCommander, PeerClosed) } catch { case e: IOException ⇒ handleError(handler, e) } finally bufferPool.release(buffer) @@ -190,7 +200,12 @@ private[io] abstract class TcpConnection(val channel: SocketChannel, if (closedEvent == Aborted) { // close instantly if (TraceLogging) log.debug("Got Abort command. RESETing connection.") doCloseConnection(handler, closeCommander, closedEvent) - + } else if (closedEvent == PeerClosed) { + // report that peer closed the connection + handler ! PeerClosed + // used to check if peer already closed its side later + channel.socket().shutdownInput() + context.become(peerSentEOF(handler)) } else if (writePending) { // finish writing first if (TraceLogging) log.debug("Got Close command but write is still pending.") context.become(closingWithPendingWrite(handler, closeCommander, closedEvent)) @@ -198,8 +213,10 @@ private[io] abstract class TcpConnection(val channel: SocketChannel, } else if (closedEvent == ConfirmedClosed) { // shutdown output and wait for confirmation if (TraceLogging) log.debug("Got ConfirmedClose command, sending FIN.") channel.socket.shutdownOutput() - context.become(closing(handler, closeCommander)) + if (channel.socket().isInputShutdown) // if peer closed first, the socket is now fully closed + doCloseConnection(handler, closeCommander, closedEvent) + else context.become(closing(handler, closeCommander)) } else { // close now if (TraceLogging) log.debug("Got Close command, closing connection.") doCloseConnection(handler, closeCommander, closedEvent)