diff --git a/akka-actor-tests/src/test/scala/akka/io/TcpConnectionSpec.scala b/akka-actor-tests/src/test/scala/akka/io/TcpConnectionSpec.scala index e898131926..188a441d78 100644 --- a/akka-actor-tests/src/test/scala/akka/io/TcpConnectionSpec.scala +++ b/akka-actor-tests/src/test/scala/akka/io/TcpConnectionSpec.scala @@ -164,8 +164,9 @@ class TcpConnectionSpec extends AkkaSpec("akka.io.tcp.register-timeout = 500ms") serverSideChannel.read(buffer) must be(0) writer.send(connectionActor, ackedWrite) writer.expectMsg(Ack) - serverSideChannel.read(buffer) must be(8) + pullFromServerSide(remaining = 8, into = buffer) buffer.flip() + buffer.limit must be(8) // not reply to write commander for writes without Ack val unackedWrite = Write(ByteString("morestuff!")) @@ -173,8 +174,9 @@ class TcpConnectionSpec extends AkkaSpec("akka.io.tcp.register-timeout = 500ms") serverSideChannel.read(buffer) must be(0) writer.send(connectionActor, unackedWrite) writer.expectNoMsg(500.millis) - serverSideChannel.read(buffer) must be(10) + pullFromServerSide(remaining = 10, into = buffer) buffer.flip() + buffer.limit must be(10) ByteString(buffer).take(10).decodeString("ASCII") must be("morestuff!") } @@ -752,13 +754,13 @@ class TcpConnectionSpec extends AkkaSpec("akka.io.tcp.register-timeout = 500ms") (sel, key) } - val buffer = ByteBuffer.allocate(TestSize) + val defaultbuffer = ByteBuffer.allocate(TestSize) /** * Tries to simultaneously act on client and server side to read from the server * all pending data from the client. */ - @tailrec final def pullFromServerSide(remaining: Int, remainingTries: Int = 1000): Unit = + @tailrec final def pullFromServerSide(remaining: Int, remainingTries: Int = 1000, into: ByteBuffer = defaultbuffer): Unit = if (remainingTries <= 0) throw new AssertionError("Pulling took too many loops, remaining data: " + remaining) else if (remaining > 0) { @@ -775,25 +777,20 @@ class TcpConnectionSpec extends AkkaSpec("akka.io.tcp.register-timeout = 500ms") } val read = - if (nioSelector.selectedKeys().contains(serverSelectionKey)) tryReading() - else 0 + if (nioSelector.selectedKeys().contains(serverSelectionKey)) { + if (into eq defaultbuffer) into.clear() + serverSideChannel.read(into) match { + case -1 ⇒ throw new IllegalStateException("Connection was closed unexpectedly with remaining bytes " + remaining) + case 0 ⇒ throw new IllegalStateException("Made no progress") + case other ⇒ other + } + } else 0 nioSelector.selectedKeys().clear() - pullFromServerSide(remaining - read, remainingTries - 1) + pullFromServerSide(remaining - read, remainingTries - 1, into) } - private def tryReading(): Int = { - buffer.clear() - val read = serverSideChannel.read(buffer) - - if (read == 0) - throw new IllegalStateException("Made no progress") - else if (read == -1) - throw new IllegalStateException("Connection was closed unexpectedly with remaining bytes " + remaining) - else read - } - @tailrec final def expectReceivedString(data: String): Unit = { data.length must be > 0