Add support for reading all bytes, or reading up until a delimiter
This commit is contained in:
parent
4cc901c708
commit
ec4e7f7b03
2 changed files with 63 additions and 37 deletions
|
|
@ -15,23 +15,23 @@ object IOActorSpec {
|
|||
|
||||
class SimpleEchoServer(host: String, port: Int, ioManager: ActorRef) extends Actor with IO {
|
||||
|
||||
var serverHandle: Option[IO.Handle] = None
|
||||
var clientHandles: Set[IO.Handle] = Set.empty
|
||||
var server: Option[IO.Handle] = None
|
||||
var clients: Set[IO.Handle] = Set.empty
|
||||
|
||||
override def preStart = {
|
||||
serverHandle = Some(listen(ioManager, host, port))
|
||||
server = Some(listen(ioManager, host, port))
|
||||
}
|
||||
|
||||
def receive = {
|
||||
case IO.NewConnection(handle) ⇒
|
||||
println("S: Client connected")
|
||||
clientHandles += accept(handle, self)
|
||||
clients += accept(handle, self)
|
||||
case IO.Read(handle, bytes) ⇒
|
||||
println("S: Echoing data")
|
||||
write(handle, bytes)
|
||||
case IO.Closed(handle) ⇒
|
||||
println("S: Connection closed")
|
||||
clientHandles -= handle
|
||||
clients -= handle
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
package akka.actor
|
||||
|
||||
import akka.config.Supervision.Permanent
|
||||
import akka.util.ByteString
|
||||
import akka.util.{ ByteString, ByteRope }
|
||||
import akka.dispatch.MessageInvocation
|
||||
|
||||
import java.net.InetSocketAddress
|
||||
|
|
@ -85,9 +85,14 @@ trait IO {
|
|||
}
|
||||
|
||||
object IOActor {
|
||||
class HandleState(val messages: mutable.Queue[MessageInvocation], val readBytes: mutable.Queue[ByteString], var readBytesLength: Int) {
|
||||
def this() = this(mutable.Queue.empty, mutable.Queue.empty, 0)
|
||||
class HandleState(val messages: mutable.Queue[MessageInvocation], var readBytes: ByteRope, var connected: Boolean) {
|
||||
def this() = this(mutable.Queue.empty, ByteRope.empty, false)
|
||||
}
|
||||
|
||||
sealed trait IOContinuation[A] { def continuation: (A) ⇒ Unit }
|
||||
case class ByteStringLength(continuation: (ByteString) ⇒ Unit, length: Int) extends IOContinuation[ByteString]
|
||||
case class ByteStringDelimited(continuation: (ByteString) ⇒ Unit, delimter: ByteString, inclusive: Boolean, scanned: Int) extends IOContinuation[ByteString]
|
||||
case class ByteStringAny(continuation: (ByteString) ⇒ Unit) extends IOContinuation[ByteString]
|
||||
}
|
||||
|
||||
trait IOActor extends Actor with IO {
|
||||
|
|
@ -99,7 +104,7 @@ trait IOActor extends Actor with IO {
|
|||
|
||||
private var _state: Map[IO.Handle, HandleState] = Map.empty
|
||||
|
||||
private var _continuations: Map[MessageInvocation, (Int, ByteString ⇒ Unit)] = Map.empty
|
||||
private var _continuations: Map[MessageInvocation, IOContinuation[_]] = Map.empty
|
||||
|
||||
private def state(handle: IO.Handle): HandleState = _state.get(handle) match {
|
||||
case Some(s) ⇒ s
|
||||
|
|
@ -111,21 +116,31 @@ trait IOActor extends Actor with IO {
|
|||
|
||||
protected def read(handle: IO.Handle, len: Int): ByteString @suspendable = shift { cont: (ByteString ⇒ Unit) ⇒
|
||||
state(handle).messages enqueue self.currentMessage
|
||||
_continuations += (self.currentMessage -> (len, cont))
|
||||
_continuations += (self.currentMessage -> ByteStringLength(cont, len))
|
||||
run(handle)
|
||||
}
|
||||
|
||||
// TODO: read(handle): ByteString, to read at least 1 byte
|
||||
// TODO: read(handle, until: ByteString): ByteString, to read until match
|
||||
protected def read(handle: IO.Handle): ByteString @suspendable = shift { cont: (ByteString ⇒ Unit) ⇒
|
||||
state(handle).messages enqueue self.currentMessage
|
||||
_continuations += (self.currentMessage -> ByteStringAny(cont))
|
||||
run(handle)
|
||||
}
|
||||
|
||||
protected def read(handle: IO.Handle, delimiter: ByteString, inclusive: Boolean = true): ByteString @suspendable = shift { cont: (ByteString ⇒ Unit) ⇒
|
||||
state(handle).messages enqueue self.currentMessage
|
||||
_continuations += (self.currentMessage -> ByteStringDelimited(cont, delimiter, inclusive, 0))
|
||||
run(handle)
|
||||
}
|
||||
|
||||
final def receive = {
|
||||
case IO.Read(handle, newBytes) ⇒
|
||||
val st = state(handle)
|
||||
st.readBytes enqueue newBytes
|
||||
st.readBytesLength += newBytes.length
|
||||
st.readBytes :+= newBytes
|
||||
run(handle)
|
||||
case IO.Connected(handle) ⇒ ()
|
||||
case IO.Closed(handle) ⇒ _state -= handle // TODO: clean up better
|
||||
case IO.Connected(handle) ⇒
|
||||
state(handle).connected = true
|
||||
case IO.Closed(handle) ⇒
|
||||
_state -= handle // TODO: clean up better
|
||||
case msg if sequentialIO && _continuations.nonEmpty ⇒
|
||||
_messages enqueue self.currentMessage
|
||||
case msg if _receiveIO.isDefinedAt(msg) ⇒
|
||||
|
|
@ -143,27 +158,38 @@ trait IOActor extends Actor with IO {
|
|||
if (st.messages.nonEmpty) {
|
||||
val msg = st.messages.head
|
||||
self.currentMessage = msg
|
||||
val Some((waitingFor, continuation)) = _continuations.get(msg)
|
||||
if (st.readBytesLength >= waitingFor) {
|
||||
st.messages.dequeue
|
||||
var left = waitingFor
|
||||
var take: List[ByteString] = Nil
|
||||
while (left > 0 && left >= st.readBytes.head.length) {
|
||||
val bytes = st.readBytes.dequeue
|
||||
st.readBytesLength -= bytes.length
|
||||
left -= bytes.length
|
||||
take ::= bytes
|
||||
}
|
||||
if (left > 0) {
|
||||
val bytes = st.readBytes.dequeue
|
||||
take ::= bytes take left
|
||||
(bytes drop left) +=: st.readBytes
|
||||
st.readBytesLength -= left
|
||||
}
|
||||
val bytes = ByteString.concat(take.reverse: _*)
|
||||
_continuations -= msg
|
||||
continuation(bytes)
|
||||
run(handle)
|
||||
_continuations.get(msg) match {
|
||||
case Some(ByteStringLength(continuation, waitingFor)) ⇒
|
||||
if (st.readBytes.length >= waitingFor) {
|
||||
st.messages.dequeue
|
||||
val bytes = st.readBytes.take(waitingFor).toByteString
|
||||
st.readBytes = st.readBytes.drop(waitingFor)
|
||||
_continuations -= msg
|
||||
continuation(bytes)
|
||||
run(handle)
|
||||
}
|
||||
case Some(ByteStringDelimited(continuation, delimiter, inclusive, scanned)) ⇒
|
||||
val idx = st.readBytes.indexOfSlice(delimiter, scanned)
|
||||
if (idx >= 0) {
|
||||
st.messages.dequeue
|
||||
val index = if (inclusive) idx + delimiter.length else idx
|
||||
val bytes = st.readBytes.take(index).toByteString
|
||||
st.readBytes = st.readBytes.drop(index)
|
||||
_continuations -= msg
|
||||
continuation(bytes)
|
||||
run(handle)
|
||||
} else {
|
||||
_continuations += (msg -> ByteStringDelimited(continuation, delimiter, inclusive, math.min(idx - delimiter.length, 0)))
|
||||
}
|
||||
case Some(ByteStringAny(continuation)) ⇒
|
||||
if (st.readBytes.length > 0) {
|
||||
st.messages.dequeue
|
||||
val bytes = st.readBytes.toByteString
|
||||
st.readBytes = ByteRope.empty
|
||||
_continuations -= msg
|
||||
continuation(bytes)
|
||||
run(handle)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
while ((_continuations.isEmpty || !sequentialIO) && _messages.nonEmpty) {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue