diff --git a/akka-actor-tests/src/test/scala/akka/actor/actor/IOActor.scala b/akka-actor-tests/src/test/scala/akka/actor/actor/IOActor.scala index ea6aabf780..7fb54ea53d 100644 --- a/akka-actor-tests/src/test/scala/akka/actor/actor/IOActor.scala +++ b/akka-actor-tests/src/test/scala/akka/actor/actor/IOActor.scala @@ -15,24 +15,26 @@ object IOActorSpec { class SimpleEchoServer(host: String, port: Int, ioManager: ActorRef) extends IOActor { - sequentialIO = false - idleWakeup = true - override def preStart = { listen(ioManager, host, port) } + def createWorker = Actor.actorOf(new IOActor { + def receiveIO = { + case IO.NewConnection(handle) ⇒ + val client = accept(handle) + loop(write(client, read(client))) + } + }) + def receiveIO = { - case IO.NewConnection(handle) ⇒ accept(handle) - case IO.WakeUp(handle) ⇒ write(handle, read(handle)) + case msg: IO.NewConnection ⇒ self startLink createWorker forward msg } } class SimpleEchoClient(host: String, port: Int, ioManager: ActorRef) extends IOActor { - sequentialIO = false - var handle: IO.Handle = _ override def preStart: Unit = { @@ -49,40 +51,53 @@ object IOActorSpec { // Basic Redis-style protocol class KVStore(host: String, port: Int, ioManager: ActorRef) extends IOActor { - sequentialIO = false - idleWakeup = true - var kvs: Map[String, ByteString] = Map.empty override def preStart = { listen(ioManager, host, port) } + def createWorker = Actor.actorOf(new IOActor { + def receiveIO = { + case IO.NewConnection(handle) ⇒ + val server = handle.owner + val client = accept(handle) + loop { + val cmd = read(client, ByteString(" ")).utf8String + cmd match { + case "SET" ⇒ + val key = read(client, ByteString(" ")).utf8String + val len = read(client, ByteString("\r\n")).utf8String + val value = read(client, len.toInt) + server ! ('set, key, value) + write(client, ByteString("+OK\r\n")) + case "GET" ⇒ + val key = read(client, ByteString("\r\n")).utf8String + server !!! (('get, key)) map { value: Option[ByteString] ⇒ + value map { bytes ⇒ + ByteString("$" + bytes.length + "\r\n") ++ bytes + } getOrElse ByteString("$-1\r\n") + } failure { + case e ⇒ ByteString("-" + e.getClass.toString + "\r\n") + } foreach { bytes: ByteString ⇒ + write(client, bytes) + } + } + } + } + }) + def receiveIO = { - case IO.NewConnection(handle) ⇒ - accept(handle) - case IO.WakeUp(handle) ⇒ - val cmd = read(handle, ByteString(" ")).utf8String - cmd match { - case "SET" ⇒ - val key = read(handle, ByteString(" ")).utf8String - val len = read(handle, ByteString("\r\n")).utf8String - val value = read(handle, len.toInt) - kvs += (key -> value) - write(handle, ByteString("+OK\r\n")) - case "GET" ⇒ - val key = read(handle, ByteString("\r\n")).utf8String - write(handle, kvs.get(key).map(v ⇒ ByteString("$" + v.length + "\r\n") ++ v).getOrElse(ByteString("$-1\r\n"))) - } + case msg: IO.NewConnection ⇒ self startLink createWorker forward msg + case ('set, key: String, value: ByteString) ⇒ kvs += (key -> value) + case ('get, key: String) ⇒ self reply_? kvs.get(key) + } } class KVClient(host: String, port: Int, ioManager: ActorRef) extends IOActor { - // FIXME: should prioritize reads from first message - // sequentialIO = false - var handle: IO.Handle = _ override def preStart: Unit = { @@ -131,18 +146,22 @@ class IOActorSpec extends WordSpec with MustMatchers with BeforeAndAfterEach { "run key-value store" in { val ioManager = Actor.actorOf(new IOManager(2)).start // teeny tiny buffer val server = Actor.actorOf(new KVStore("localhost", 8064, ioManager)).start - val client = Actor.actorOf(new KVClient("localhost", 8064, ioManager)).start - val promise1 = client !!! (('set, "hello", ByteString("World"))) - val promise2 = client !!! (('set, "test", ByteString("No one will read me"))) - val promise3 = client !!! (('get, "hello")) - val promise4 = client !!! (('set, "test", ByteString("I'm a test!"))) - val promise5 = client !!! (('get, "test")) + val client1 = Actor.actorOf(new KVClient("localhost", 8064, ioManager)).start + val client2 = Actor.actorOf(new KVClient("localhost", 8064, ioManager)).start + val promise1 = client1 !!! (('set, "hello", ByteString("World"))) + val promise2 = client1 !!! (('set, "test", ByteString("No one will read me"))) + val promise3 = client1 !!! (('get, "hello")) + promise2.await + val promise4 = client2 !!! (('set, "test", ByteString("I'm a test!"))) + promise4.await + val promise5 = client1 !!! (('get, "test")) (promise1.get: ByteString) must equal(ByteString("OK")) (promise2.get: ByteString) must equal(ByteString("OK")) (promise3.get: ByteString) must equal(ByteString("World")) (promise4.get: ByteString) must equal(ByteString("OK")) (promise5.get: ByteString) must equal(ByteString("I'm a test!")) - client.stop + client1.stop + client2.stop server.stop ioManager.stop } diff --git a/akka-actor/src/main/scala/akka/actor/IO.scala b/akka-actor/src/main/scala/akka/actor/IO.scala index b693b8be3e..2a1f8d0520 100644 --- a/akka-actor/src/main/scala/akka/actor/IO.scala +++ b/akka-actor/src/main/scala/akka/actor/IO.scala @@ -96,15 +96,19 @@ object IOActor { case class ByteStringLength(continuation: (ByteString) ⇒ Any, length: Int) extends IOContinuation[ByteString] case class ByteStringDelimited(continuation: (ByteString) ⇒ Any, delimter: ByteString, inclusive: Boolean, scanned: Int) extends IOContinuation[ByteString] case class ByteStringAny(continuation: (ByteString) ⇒ Any) extends IOContinuation[ByteString] + + sealed trait TailRec[A] + case class Return[A](result: A) extends TailRec[A] + case class Call[A](thunk: () ⇒ TailRec[A] @cps[Any]) extends TailRec[A] + def tailrec[A](comp: TailRec[A]): A @cps[Any] = comp match { + case Call(thunk) ⇒ tailrec(thunk()) + case Return(x) ⇒ x + } } trait IOActor extends Actor with IO { import IOActor._ - protected var sequentialIO = true - - protected var idleWakeup = false - private val _messages: mutable.Queue[MessageInvocation] = mutable.Queue.empty private var _state: Map[IO.Handle, HandleState] = Map.empty @@ -137,17 +141,36 @@ trait IOActor extends Actor with IO { run(handle) } + protected def loop(block: ⇒ Any @cps[Any]): Unit @cps[Any] = { + def f(): TailRec[Unit] @cps[Any] = { + block + Call(() ⇒ f()) + } + tailrec(f()) + } + + protected def loopWhile(test: ⇒ Boolean)(block: ⇒ Any @cps[Any]): Unit @cps[Any] = { + def f(): TailRec[Unit] @cps[Any] = { + if (test) { + block + Call(() ⇒ f()) + } else Return(()) + } + tailrec(f()) + } + final def receive = { case IO.Read(handle, newBytes) ⇒ val st = state(handle) st.readBytes :+= newBytes - if (st.messages.isEmpty && idleWakeup) reset { _receiveIO(IO.WakeUp(handle)) } - else run(handle) - case IO.Connected(handle) ⇒ + run(handle) + case msg@IO.Connected(handle) ⇒ state(handle).connected = true - case IO.Closed(handle) ⇒ + if (_receiveIO.isDefinedAt(msg)) reset { _receiveIO(msg) } + case msg@IO.Closed(handle) ⇒ _state -= handle // TODO: clean up better - case msg if sequentialIO && _continuations.nonEmpty ⇒ + if (_receiveIO.isDefinedAt(msg)) reset { _receiveIO(msg) } + case msg if _continuations.nonEmpty ⇒ _messages enqueue self.currentMessage case msg if _receiveIO.isDefinedAt(msg) ⇒ reset { _receiveIO(msg) } @@ -198,7 +221,7 @@ trait IOActor extends Actor with IO { } } } else { - while ((_continuations.isEmpty || !sequentialIO) && _messages.nonEmpty) { + while (_continuations.isEmpty && _messages.nonEmpty) { self invoke _messages.dequeue } }