diff --git a/akka-actor-tests/src/test/scala/akka/io/TcpConnectionSpec.scala b/akka-actor-tests/src/test/scala/akka/io/TcpConnectionSpec.scala index 4bdfac145e..2243367a81 100644 --- a/akka-actor-tests/src/test/scala/akka/io/TcpConnectionSpec.scala +++ b/akka-actor-tests/src/test/scala/akka/io/TcpConnectionSpec.scala @@ -333,12 +333,25 @@ class TcpConnectionSpec extends AkkaSpec(""" "respect StopReading and ResumeReading" in new EstablishedConnectionTest() { run { + serverSideChannel.write(ByteBuffer.wrap("testdata".getBytes("ASCII"))) connectionHandler.send(connectionActor, SuspendReading) - // the selector interprets StopReading to deregister interest for reading + // the selector interprets SuspendReading to deregister interest for reading interestCallReceiver.expectMsg(-OP_READ) + + // this simulates a race condition where ChannelReadable was already underway while + // processing SuspendReading + selector.send(connectionActor, ChannelReadable) + + // this ChannelReadable should be properly ignored, even if data is already pending + interestCallReceiver.expectNoMsg(100.millis) + connectionHandler.expectNoMsg(100.millis) + connectionHandler.send(connectionActor, ResumeReading) interestCallReceiver.expectMsg(OP_READ) + + // data should be received only after ResumeReading + expectReceivedString("testdata") } } diff --git a/akka-actor/src/main/scala/akka/io/TcpConnection.scala b/akka-actor/src/main/scala/akka/io/TcpConnection.scala index a2b7dc8525..5ee96da553 100644 --- a/akka-actor/src/main/scala/akka/io/TcpConnection.scala +++ b/akka-actor/src/main/scala/akka/io/TcpConnection.scala @@ -35,6 +35,7 @@ private[io] abstract class TcpConnection(val tcp: TcpExt, val channel: SocketCha private[this] var pendingWrite: PendingWrite = EmptyPendingWrite private[this] var peerClosed = false private[this] var writingSuspended = false + private[this] var readingSuspended = false private[this] var interestedInResume: Option[ActorRef] = None var closedMessage: CloseInformation = _ // for ConnectionClosed message in postStop @@ -72,8 +73,8 @@ private[io] abstract class TcpConnection(val tcp: TcpExt, val channel: SocketCha /** normal connected state */ def connected(info: ConnectionInfo): Receive = handleWriteMessages(info) orElse { - case SuspendReading ⇒ info.registration.disableInterest(OP_READ) - case ResumeReading ⇒ info.registration.enableInterest(OP_READ) + case SuspendReading ⇒ suspendReading(info) + case ResumeReading ⇒ resumeReading(info) case ChannelReadable ⇒ doRead(info, None) case cmd: CloseCommand ⇒ handleClose(info, Some(sender), cmd.event) } @@ -87,8 +88,8 @@ private[io] abstract class TcpConnection(val tcp: TcpExt, val channel: SocketCha /** connection is closing but a write has to be finished first */ def closingWithPendingWrite(info: ConnectionInfo, closeCommander: Option[ActorRef], closedEvent: ConnectionClosed): Receive = { - case SuspendReading ⇒ info.registration.disableInterest(OP_READ) - case ResumeReading ⇒ info.registration.enableInterest(OP_READ) + case SuspendReading ⇒ suspendReading(info) + case ResumeReading ⇒ resumeReading(info) case ChannelReadable ⇒ doRead(info, closeCommander) case ChannelWritable ⇒ @@ -108,8 +109,8 @@ private[io] abstract class TcpConnection(val tcp: TcpExt, val channel: SocketCha /** connection is closed on our side and we're waiting from confirmation from the other side */ def closing(info: ConnectionInfo, closeCommander: Option[ActorRef]): Receive = { - case SuspendReading ⇒ info.registration.disableInterest(OP_READ) - case ResumeReading ⇒ info.registration.enableInterest(OP_READ) + case SuspendReading ⇒ suspendReading(info) + case ResumeReading ⇒ resumeReading(info) case ChannelReadable ⇒ doRead(info, closeCommander) case Abort ⇒ handleClose(info, Some(sender), Aborted) } @@ -180,42 +181,52 @@ private[io] abstract class TcpConnection(val tcp: TcpExt, val channel: SocketCha context.become(waitingForRegistration(registration, commander)) } - def doRead(info: ConnectionInfo, closeCommander: Option[ActorRef]): Unit = { - @tailrec def innerRead(buffer: ByteBuffer, remainingLimit: Int): ReadResult = - if (remainingLimit > 0) { - // never read more than the configured limit - buffer.clear() - val maxBufferSpace = math.min(DirectBufferSize, remainingLimit) - buffer.limit(maxBufferSpace) - val readBytes = channel.read(buffer) - buffer.flip() - - if (TraceLogging) log.debug("Read [{}] bytes.", readBytes) - if (readBytes > 0) info.handler ! Received(ByteString(buffer)) - - readBytes match { - case `maxBufferSpace` ⇒ innerRead(buffer, remainingLimit - maxBufferSpace) - case x if x >= 0 ⇒ AllRead - case -1 ⇒ EndOfStream - case _ ⇒ - throw new IllegalStateException("Unexpected value returned from read: " + readBytes) - } - } else MoreDataWaiting - - val buffer = bufferPool.acquire() - try innerRead(buffer, ReceivedMessageSizeLimit) match { - case AllRead ⇒ info.registration.enableInterest(OP_READ) - case MoreDataWaiting ⇒ self ! ChannelReadable - case EndOfStream if channel.socket.isOutputShutdown ⇒ - if (TraceLogging) log.debug("Read returned end-of-stream, our side already closed") - doCloseConnection(info.handler, closeCommander, ConfirmedClosed) - case EndOfStream ⇒ - if (TraceLogging) log.debug("Read returned end-of-stream, our side not yet closed") - handleClose(info, closeCommander, PeerClosed) - } catch { - case e: IOException ⇒ handleError(info.handler, e) - } finally bufferPool.release(buffer) + def suspendReading(info: ConnectionInfo): Unit = { + readingSuspended = true + info.registration.disableInterest(OP_READ) } + def resumeReading(info: ConnectionInfo): Unit = { + readingSuspended = false + info.registration.enableInterest(OP_READ) + } + + def doRead(info: ConnectionInfo, closeCommander: Option[ActorRef]): Unit = + if (!readingSuspended) { + @tailrec def innerRead(buffer: ByteBuffer, remainingLimit: Int): ReadResult = + if (remainingLimit > 0) { + // never read more than the configured limit + buffer.clear() + val maxBufferSpace = math.min(DirectBufferSize, remainingLimit) + buffer.limit(maxBufferSpace) + val readBytes = channel.read(buffer) + buffer.flip() + + if (TraceLogging) log.debug("Read [{}] bytes.", readBytes) + if (readBytes > 0) info.handler ! Received(ByteString(buffer)) + + readBytes match { + case `maxBufferSpace` ⇒ innerRead(buffer, remainingLimit - maxBufferSpace) + case x if x >= 0 ⇒ AllRead + case -1 ⇒ EndOfStream + case _ ⇒ + throw new IllegalStateException("Unexpected value returned from read: " + readBytes) + } + } else MoreDataWaiting + + val buffer = bufferPool.acquire() + try innerRead(buffer, ReceivedMessageSizeLimit) match { + case AllRead ⇒ info.registration.enableInterest(OP_READ) + case MoreDataWaiting ⇒ self ! ChannelReadable + case EndOfStream if channel.socket.isOutputShutdown ⇒ + if (TraceLogging) log.debug("Read returned end-of-stream, our side already closed") + doCloseConnection(info.handler, closeCommander, ConfirmedClosed) + case EndOfStream ⇒ + if (TraceLogging) log.debug("Read returned end-of-stream, our side not yet closed") + handleClose(info, closeCommander, PeerClosed) + } catch { + case e: IOException ⇒ handleError(info.handler, e) + } finally bufferPool.release(buffer) + } def doWrite(info: ConnectionInfo): Unit = pendingWrite = pendingWrite.doWrite(info)