Add support for reading all bytes, or reading up until a delimiter

This commit is contained in:
Derek Williams 2011-05-27 18:54:16 -06:00
parent 4cc901c708
commit ec4e7f7b03
2 changed files with 63 additions and 37 deletions

View file

@ -15,23 +15,23 @@ object IOActorSpec {
class SimpleEchoServer(host: String, port: Int, ioManager: ActorRef) extends Actor with IO {
var serverHandle: Option[IO.Handle] = None
var clientHandles: Set[IO.Handle] = Set.empty
var server: Option[IO.Handle] = None
var clients: Set[IO.Handle] = Set.empty
override def preStart = {
serverHandle = Some(listen(ioManager, host, port))
server = Some(listen(ioManager, host, port))
}
def receive = {
case IO.NewConnection(handle)
println("S: Client connected")
clientHandles += accept(handle, self)
clients += accept(handle, self)
case IO.Read(handle, bytes)
println("S: Echoing data")
write(handle, bytes)
case IO.Closed(handle)
println("S: Connection closed")
clientHandles -= handle
clients -= handle
}
}

View file

@ -4,7 +4,7 @@
package akka.actor
import akka.config.Supervision.Permanent
import akka.util.ByteString
import akka.util.{ ByteString, ByteRope }
import akka.dispatch.MessageInvocation
import java.net.InetSocketAddress
@ -85,9 +85,14 @@ trait IO {
}
object IOActor {
class HandleState(val messages: mutable.Queue[MessageInvocation], val readBytes: mutable.Queue[ByteString], var readBytesLength: Int) {
def this() = this(mutable.Queue.empty, mutable.Queue.empty, 0)
class HandleState(val messages: mutable.Queue[MessageInvocation], var readBytes: ByteRope, var connected: Boolean) {
def this() = this(mutable.Queue.empty, ByteRope.empty, false)
}
sealed trait IOContinuation[A] { def continuation: (A) Unit }
case class ByteStringLength(continuation: (ByteString) Unit, length: Int) extends IOContinuation[ByteString]
case class ByteStringDelimited(continuation: (ByteString) Unit, delimter: ByteString, inclusive: Boolean, scanned: Int) extends IOContinuation[ByteString]
case class ByteStringAny(continuation: (ByteString) Unit) extends IOContinuation[ByteString]
}
trait IOActor extends Actor with IO {
@ -99,7 +104,7 @@ trait IOActor extends Actor with IO {
private var _state: Map[IO.Handle, HandleState] = Map.empty
private var _continuations: Map[MessageInvocation, (Int, ByteString Unit)] = Map.empty
private var _continuations: Map[MessageInvocation, IOContinuation[_]] = Map.empty
private def state(handle: IO.Handle): HandleState = _state.get(handle) match {
case Some(s) s
@ -111,21 +116,31 @@ trait IOActor extends Actor with IO {
protected def read(handle: IO.Handle, len: Int): ByteString @suspendable = shift { cont: (ByteString Unit)
state(handle).messages enqueue self.currentMessage
_continuations += (self.currentMessage -> (len, cont))
_continuations += (self.currentMessage -> ByteStringLength(cont, len))
run(handle)
}
// TODO: read(handle): ByteString, to read at least 1 byte
// TODO: read(handle, until: ByteString): ByteString, to read until match
protected def read(handle: IO.Handle): ByteString @suspendable = shift { cont: (ByteString Unit)
state(handle).messages enqueue self.currentMessage
_continuations += (self.currentMessage -> ByteStringAny(cont))
run(handle)
}
protected def read(handle: IO.Handle, delimiter: ByteString, inclusive: Boolean = true): ByteString @suspendable = shift { cont: (ByteString Unit)
state(handle).messages enqueue self.currentMessage
_continuations += (self.currentMessage -> ByteStringDelimited(cont, delimiter, inclusive, 0))
run(handle)
}
final def receive = {
case IO.Read(handle, newBytes)
val st = state(handle)
st.readBytes enqueue newBytes
st.readBytesLength += newBytes.length
st.readBytes :+= newBytes
run(handle)
case IO.Connected(handle) ()
case IO.Closed(handle) _state -= handle // TODO: clean up better
case IO.Connected(handle)
state(handle).connected = true
case IO.Closed(handle)
_state -= handle // TODO: clean up better
case msg if sequentialIO && _continuations.nonEmpty
_messages enqueue self.currentMessage
case msg if _receiveIO.isDefinedAt(msg)
@ -143,27 +158,38 @@ trait IOActor extends Actor with IO {
if (st.messages.nonEmpty) {
val msg = st.messages.head
self.currentMessage = msg
val Some((waitingFor, continuation)) = _continuations.get(msg)
if (st.readBytesLength >= waitingFor) {
st.messages.dequeue
var left = waitingFor
var take: List[ByteString] = Nil
while (left > 0 && left >= st.readBytes.head.length) {
val bytes = st.readBytes.dequeue
st.readBytesLength -= bytes.length
left -= bytes.length
take ::= bytes
}
if (left > 0) {
val bytes = st.readBytes.dequeue
take ::= bytes take left
(bytes drop left) +=: st.readBytes
st.readBytesLength -= left
}
val bytes = ByteString.concat(take.reverse: _*)
_continuations -= msg
continuation(bytes)
run(handle)
_continuations.get(msg) match {
case Some(ByteStringLength(continuation, waitingFor))
if (st.readBytes.length >= waitingFor) {
st.messages.dequeue
val bytes = st.readBytes.take(waitingFor).toByteString
st.readBytes = st.readBytes.drop(waitingFor)
_continuations -= msg
continuation(bytes)
run(handle)
}
case Some(ByteStringDelimited(continuation, delimiter, inclusive, scanned))
val idx = st.readBytes.indexOfSlice(delimiter, scanned)
if (idx >= 0) {
st.messages.dequeue
val index = if (inclusive) idx + delimiter.length else idx
val bytes = st.readBytes.take(index).toByteString
st.readBytes = st.readBytes.drop(index)
_continuations -= msg
continuation(bytes)
run(handle)
} else {
_continuations += (msg -> ByteStringDelimited(continuation, delimiter, inclusive, math.min(idx - delimiter.length, 0)))
}
case Some(ByteStringAny(continuation))
if (st.readBytes.length > 0) {
st.messages.dequeue
val bytes = st.readBytes.toByteString
st.readBytes = ByteRope.empty
_continuations -= msg
continuation(bytes)
run(handle)
}
}
} else {
while ((_continuations.isEmpty || !sequentialIO) && _messages.nonEmpty) {