#3205 - Fixing a race condition in TcpConnectionSpec

This commit is contained in:
Viktor Klang 2013-04-20 21:24:45 -07:00
parent a4b485968c
commit 01a013e1c7

View file

@ -164,8 +164,9 @@ class TcpConnectionSpec extends AkkaSpec("akka.io.tcp.register-timeout = 500ms")
serverSideChannel.read(buffer) must be(0) serverSideChannel.read(buffer) must be(0)
writer.send(connectionActor, ackedWrite) writer.send(connectionActor, ackedWrite)
writer.expectMsg(Ack) writer.expectMsg(Ack)
serverSideChannel.read(buffer) must be(8) pullFromServerSide(remaining = 8, into = buffer)
buffer.flip() buffer.flip()
buffer.limit must be(8)
// not reply to write commander for writes without Ack // not reply to write commander for writes without Ack
val unackedWrite = Write(ByteString("morestuff!")) val unackedWrite = Write(ByteString("morestuff!"))
@ -173,8 +174,9 @@ class TcpConnectionSpec extends AkkaSpec("akka.io.tcp.register-timeout = 500ms")
serverSideChannel.read(buffer) must be(0) serverSideChannel.read(buffer) must be(0)
writer.send(connectionActor, unackedWrite) writer.send(connectionActor, unackedWrite)
writer.expectNoMsg(500.millis) writer.expectNoMsg(500.millis)
serverSideChannel.read(buffer) must be(10) pullFromServerSide(remaining = 10, into = buffer)
buffer.flip() buffer.flip()
buffer.limit must be(10)
ByteString(buffer).take(10).decodeString("ASCII") must be("morestuff!") ByteString(buffer).take(10).decodeString("ASCII") must be("morestuff!")
} }
@ -752,13 +754,13 @@ class TcpConnectionSpec extends AkkaSpec("akka.io.tcp.register-timeout = 500ms")
(sel, key) (sel, key)
} }
val buffer = ByteBuffer.allocate(TestSize) val defaultbuffer = ByteBuffer.allocate(TestSize)
/** /**
* Tries to simultaneously act on client and server side to read from the server * Tries to simultaneously act on client and server side to read from the server
* all pending data from the client. * all pending data from the client.
*/ */
@tailrec final def pullFromServerSide(remaining: Int, remainingTries: Int = 1000): Unit = @tailrec final def pullFromServerSide(remaining: Int, remainingTries: Int = 1000, into: ByteBuffer = defaultbuffer): Unit =
if (remainingTries <= 0) if (remainingTries <= 0)
throw new AssertionError("Pulling took too many loops, remaining data: " + remaining) throw new AssertionError("Pulling took too many loops, remaining data: " + remaining)
else if (remaining > 0) { else if (remaining > 0) {
@ -775,25 +777,20 @@ class TcpConnectionSpec extends AkkaSpec("akka.io.tcp.register-timeout = 500ms")
} }
val read = val read =
if (nioSelector.selectedKeys().contains(serverSelectionKey)) tryReading() if (nioSelector.selectedKeys().contains(serverSelectionKey)) {
else 0 if (into eq defaultbuffer) into.clear()
serverSideChannel.read(into) match {
case -1 throw new IllegalStateException("Connection was closed unexpectedly with remaining bytes " + remaining)
case 0 throw new IllegalStateException("Made no progress")
case other other
}
} else 0
nioSelector.selectedKeys().clear() nioSelector.selectedKeys().clear()
pullFromServerSide(remaining - read, remainingTries - 1) pullFromServerSide(remaining - read, remainingTries - 1, into)
} }
private def tryReading(): Int = {
buffer.clear()
val read = serverSideChannel.read(buffer)
if (read == 0)
throw new IllegalStateException("Made no progress")
else if (read == -1)
throw new IllegalStateException("Connection was closed unexpectedly with remaining bytes " + remaining)
else read
}
@tailrec final def expectReceivedString(data: String): Unit = { @tailrec final def expectReceivedString(data: String): Unit = {
data.length must be > 0 data.length must be > 0