From ec4e7f7b035935aa2db8627bc5eef242a4f5d53b Mon Sep 17 00:00:00 2001 From: Derek Williams Date: Fri, 27 May 2011 18:54:16 -0600 Subject: [PATCH] Add support for reading all bytes, or reading up until a delimiter --- .../test/scala/akka/actor/actor/IOActor.scala | 10 +-- akka-actor/src/main/scala/akka/actor/IO.scala | 90 ++++++++++++------- 2 files changed, 63 insertions(+), 37 deletions(-) diff --git a/akka-actor-tests/src/test/scala/akka/actor/actor/IOActor.scala b/akka-actor-tests/src/test/scala/akka/actor/actor/IOActor.scala index e57bf30b53..2b20232c1b 100644 --- a/akka-actor-tests/src/test/scala/akka/actor/actor/IOActor.scala +++ b/akka-actor-tests/src/test/scala/akka/actor/actor/IOActor.scala @@ -15,23 +15,23 @@ object IOActorSpec { class SimpleEchoServer(host: String, port: Int, ioManager: ActorRef) extends Actor with IO { - var serverHandle: Option[IO.Handle] = None - var clientHandles: Set[IO.Handle] = Set.empty + var server: Option[IO.Handle] = None + var clients: Set[IO.Handle] = Set.empty override def preStart = { - serverHandle = Some(listen(ioManager, host, port)) + server = Some(listen(ioManager, host, port)) } def receive = { case IO.NewConnection(handle) ⇒ println("S: Client connected") - clientHandles += accept(handle, self) + clients += accept(handle, self) case IO.Read(handle, bytes) ⇒ println("S: Echoing data") write(handle, bytes) case IO.Closed(handle) ⇒ println("S: Connection closed") - clientHandles -= handle + clients -= handle } } diff --git a/akka-actor/src/main/scala/akka/actor/IO.scala b/akka-actor/src/main/scala/akka/actor/IO.scala index 9dba701d44..d6220f5d85 100644 --- a/akka-actor/src/main/scala/akka/actor/IO.scala +++ b/akka-actor/src/main/scala/akka/actor/IO.scala @@ -4,7 +4,7 @@ package akka.actor import akka.config.Supervision.Permanent -import akka.util.ByteString +import akka.util.{ ByteString, ByteRope } import akka.dispatch.MessageInvocation import java.net.InetSocketAddress @@ -85,9 +85,14 @@ trait IO { } object IOActor { - class HandleState(val messages: mutable.Queue[MessageInvocation], val readBytes: mutable.Queue[ByteString], var readBytesLength: Int) { - def this() = this(mutable.Queue.empty, mutable.Queue.empty, 0) + class HandleState(val messages: mutable.Queue[MessageInvocation], var readBytes: ByteRope, var connected: Boolean) { + def this() = this(mutable.Queue.empty, ByteRope.empty, false) } + + sealed trait IOContinuation[A] { def continuation: (A) ⇒ Unit } + case class ByteStringLength(continuation: (ByteString) ⇒ Unit, length: Int) extends IOContinuation[ByteString] + case class ByteStringDelimited(continuation: (ByteString) ⇒ Unit, delimter: ByteString, inclusive: Boolean, scanned: Int) extends IOContinuation[ByteString] + case class ByteStringAny(continuation: (ByteString) ⇒ Unit) extends IOContinuation[ByteString] } trait IOActor extends Actor with IO { @@ -99,7 +104,7 @@ trait IOActor extends Actor with IO { private var _state: Map[IO.Handle, HandleState] = Map.empty - private var _continuations: Map[MessageInvocation, (Int, ByteString ⇒ Unit)] = Map.empty + private var _continuations: Map[MessageInvocation, IOContinuation[_]] = Map.empty private def state(handle: IO.Handle): HandleState = _state.get(handle) match { case Some(s) ⇒ s @@ -111,21 +116,31 @@ trait IOActor extends Actor with IO { protected def read(handle: IO.Handle, len: Int): ByteString @suspendable = shift { cont: (ByteString ⇒ Unit) ⇒ state(handle).messages enqueue self.currentMessage - _continuations += (self.currentMessage -> (len, cont)) + _continuations += (self.currentMessage -> ByteStringLength(cont, len)) run(handle) } - // TODO: read(handle): ByteString, to read at least 1 byte - // TODO: read(handle, until: ByteString): ByteString, to read until match + protected def read(handle: IO.Handle): ByteString @suspendable = shift { cont: (ByteString ⇒ Unit) ⇒ + state(handle).messages enqueue self.currentMessage + _continuations += (self.currentMessage -> ByteStringAny(cont)) + run(handle) + } + + protected def read(handle: IO.Handle, delimiter: ByteString, inclusive: Boolean = true): ByteString @suspendable = shift { cont: (ByteString ⇒ Unit) ⇒ + state(handle).messages enqueue self.currentMessage + _continuations += (self.currentMessage -> ByteStringDelimited(cont, delimiter, inclusive, 0)) + run(handle) + } final def receive = { case IO.Read(handle, newBytes) ⇒ val st = state(handle) - st.readBytes enqueue newBytes - st.readBytesLength += newBytes.length + st.readBytes :+= newBytes run(handle) - case IO.Connected(handle) ⇒ () - case IO.Closed(handle) ⇒ _state -= handle // TODO: clean up better + case IO.Connected(handle) ⇒ + state(handle).connected = true + case IO.Closed(handle) ⇒ + _state -= handle // TODO: clean up better case msg if sequentialIO && _continuations.nonEmpty ⇒ _messages enqueue self.currentMessage case msg if _receiveIO.isDefinedAt(msg) ⇒ @@ -143,27 +158,38 @@ trait IOActor extends Actor with IO { if (st.messages.nonEmpty) { val msg = st.messages.head self.currentMessage = msg - val Some((waitingFor, continuation)) = _continuations.get(msg) - if (st.readBytesLength >= waitingFor) { - st.messages.dequeue - var left = waitingFor - var take: List[ByteString] = Nil - while (left > 0 && left >= st.readBytes.head.length) { - val bytes = st.readBytes.dequeue - st.readBytesLength -= bytes.length - left -= bytes.length - take ::= bytes - } - if (left > 0) { - val bytes = st.readBytes.dequeue - take ::= bytes take left - (bytes drop left) +=: st.readBytes - st.readBytesLength -= left - } - val bytes = ByteString.concat(take.reverse: _*) - _continuations -= msg - continuation(bytes) - run(handle) + _continuations.get(msg) match { + case Some(ByteStringLength(continuation, waitingFor)) ⇒ + if (st.readBytes.length >= waitingFor) { + st.messages.dequeue + val bytes = st.readBytes.take(waitingFor).toByteString + st.readBytes = st.readBytes.drop(waitingFor) + _continuations -= msg + continuation(bytes) + run(handle) + } + case Some(ByteStringDelimited(continuation, delimiter, inclusive, scanned)) ⇒ + val idx = st.readBytes.indexOfSlice(delimiter, scanned) + if (idx >= 0) { + st.messages.dequeue + val index = if (inclusive) idx + delimiter.length else idx + val bytes = st.readBytes.take(index).toByteString + st.readBytes = st.readBytes.drop(index) + _continuations -= msg + continuation(bytes) + run(handle) + } else { + _continuations += (msg -> ByteStringDelimited(continuation, delimiter, inclusive, math.min(idx - delimiter.length, 0))) + } + case Some(ByteStringAny(continuation)) ⇒ + if (st.readBytes.length > 0) { + st.messages.dequeue + val bytes = st.readBytes.toByteString + st.readBytes = ByteRope.empty + _continuations -= msg + continuation(bytes) + run(handle) + } } } else { while ((_continuations.isEmpty || !sequentialIO) && _messages.nonEmpty) {