Update IO to new iteratee based api

This commit is contained in:
Derek Williams 2011-12-31 09:18:37 -07:00
parent e8f16952a6
commit a1f9b81ce9
3 changed files with 870 additions and 479 deletions

View file

@ -4,8 +4,6 @@
package akka.actor
import org.scalatest.BeforeAndAfterEach
import akka.util.ByteString
import akka.util.cps._
import scala.util.continuations._
@ -13,214 +11,237 @@ import akka.testkit._
import akka.dispatch.{ Await, Future }
object IOActorSpec {
import IO._
class SimpleEchoServer(host: String, port: Int, ioManager: ActorRef, started: TestLatch) extends Actor {
class SimpleEchoServer(host: String, port: Int, started: TestLatch) extends Actor {
import context.dispatcher
implicit val timeout = context.system.settings.ActorTimeout
IO listen (host, port)
override def preStart = {
listen(ioManager, host, port)
started.open()
}
started.open
def createWorker = context.actorOf(Props(new Actor with IO {
def receiveIO = {
case NewClient(server)
val socket = server.accept()
loopC {
val bytes = socket.read()
socket write bytes
}
}
}))
val state = IO.IterateeRef.Map.sync[IO.Handle]()
def receive = {
case msg: NewClient
createWorker forward msg
}
}
class SimpleEchoClient(host: String, port: Int, ioManager: ActorRef) extends Actor with IO {
lazy val socket: SocketHandle = connect(ioManager, host, port)(reader)
lazy val reader: ActorRef = context.actorOf(Props({
new Actor with IO {
def receiveIO = {
case length: Int
val bytes = socket.read(length)
sender ! bytes
}
}
}))
def receiveIO = {
case bytes: ByteString
case IO.NewClient(server)
val socket = server.accept()
state(socket) flatMap { _
IO repeat {
IO.takeAny map { bytes
socket write bytes
reader forward bytes.length
}
}
}
case IO.Read(socket, bytes)
state(socket)(IO Chunk bytes)
case IO.Closed(socket, cause)
state -= socket
}
}
class SimpleEchoClient(host: String, port: Int) extends Actor {
val socket = IO connect (host, port)
val state = IO.IterateeRef.sync()
def receive = {
case bytes: ByteString
val source = sender
socket write bytes
for {
_ state
bytes IO take bytes.length
} yield source ! bytes
case IO.Read(socket, bytes)
state(IO Chunk bytes)
case IO.Connected(socket)
case IO.Closed(socket, cause)
}
}
sealed trait KVCommand {
def bytes: ByteString
}
case class KVSet(key: String, value: String) extends KVCommand {
val bytes = ByteString("SET " + key + " " + value.length + "\r\n" + value + "\r\n")
}
case class KVGet(key: String) extends KVCommand {
val bytes = ByteString("GET " + key + "\r\n")
}
case object KVGetAll extends KVCommand {
val bytes = ByteString("GETALL\r\n")
}
// Basic Redis-style protocol
class KVStore(host: String, port: Int, ioManager: ActorRef, started: TestLatch) extends Actor {
class KVStore(host: String, port: Int, started: TestLatch) extends Actor {
import context.dispatcher
implicit val timeout = context.system.settings.ActorTimeout
val state = IO.IterateeRef.Map.sync[IO.Handle]()
var kvs: Map[String, ByteString] = Map.empty
var kvs: Map[String, String] = Map.empty
override def preStart = {
listen(ioManager, host, port)
started.open()
}
IO listen (host, port)
def createWorker = context.actorOf(Props(new Actor with IO {
def receiveIO = {
case NewClient(server)
val socket = server.accept()
loopC {
val cmd = socket.read(ByteString("\r\n")).utf8String
val result = matchC(cmd.split(' ')) {
case Array("SET", key, length)
val value = socket read length.toInt
server.owner ? (('set, key, value)) map ((x: Any) ByteString("+OK\r\n"))
case Array("GET", key)
server.owner ? (('get, key)) map {
case Some(b: ByteString) ByteString("$" + b.length + "\r\n") ++ b
case None ByteString("$-1\r\n")
}
case Array("GETALL")
server.owner ? 'getall map {
case m: Map[_, _]
(ByteString("*" + (m.size * 2) + "\r\n") /: m) {
case (result, (k: String, v: ByteString))
val kBytes = ByteString(k)
result ++ ByteString("$" + kBytes.length + "\r\n") ++ kBytes ++ ByteString("$" + v.length + "\r\n") ++ v
}
}
}
result recover {
case e ByteString("-" + e.getClass.toString + "\r\n")
} foreach { bytes
socket write bytes
}
}
}
}))
started.open
val EOL = ByteString("\r\n")
def receive = {
case msg: NewClient createWorker forward msg
case ('set, key: String, value: ByteString)
kvs += (key -> value)
sender.tell((), self)
case ('get, key: String) sender.tell(kvs.get(key), self)
case 'getall sender.tell(kvs, self)
case IO.NewClient(server)
val socket = server.accept()
state(socket) flatMap { _
IO repeat {
IO takeUntil EOL map (_.utf8String split ' ') flatMap {
case Array("SET", key, length)
for {
value IO take length.toInt
_ IO takeUntil EOL
} yield {
kvs += (key -> value.utf8String)
ByteString("+OK\r\n")
}
case Array("GET", key)
IO Iteratee {
kvs get key map { value
ByteString("$" + value.length + "\r\n" + value + "\r\n")
} getOrElse ByteString("$-1\r\n")
}
case Array("GETALL")
IO Iteratee {
(ByteString("*" + (kvs.size * 2) + "\r\n") /: kvs) {
case (result, (k, v))
val kBytes = ByteString(k)
val vBytes = ByteString(v)
result ++
ByteString("$" + kBytes.length) ++ EOL ++
kBytes ++ EOL ++
ByteString("$" + vBytes.length) ++ EOL ++
vBytes ++ EOL
}
}
} map (socket write)
}
}
case IO.Read(socket, bytes)
state(socket)(IO Chunk bytes)
case IO.Closed(socket, cause)
state -= socket
}
}
class KVClient(host: String, port: Int, ioManager: ActorRef) extends Actor with IO {
class KVClient(host: String, port: Int) extends Actor {
import context.dispatcher
implicit val timeout = context.system.settings.ActorTimeout
val socket = IO connect (host, port)
var socket: SocketHandle = _
val state = IO.IterateeRef.sync()
val EOL = ByteString("\r\n")
def receive = {
case cmd: KVCommand
val source = sender
socket write cmd.bytes
for {
_ state
result readResult
} yield result.fold(err source ! Status.Failure(new RuntimeException(err)), source !)
case IO.Read(socket, bytes)
state(IO Chunk bytes)
case IO.Connected(socket)
case IO.Closed(socket, cause)
override def preStart {
socket = connect(ioManager, host, port)
}
def reply(msg: Any) = sender.tell(msg, self)
def receiveIO = {
case ('set, key: String, value: ByteString)
socket write (ByteString("SET " + key + " " + value.length + "\r\n") ++ value)
reply(readResult)
case ('get, key: String)
socket write ByteString("GET " + key + "\r\n")
reply(readResult)
case 'getall
socket write ByteString("GETALL\r\n")
reply(readResult)
}
def readResult = {
val resultType = socket.read(1).utf8String
resultType match {
case "+" socket.read(ByteString("\r\n")).utf8String
case "-" sys error socket.read(ByteString("\r\n")).utf8String
def readResult: IO.Iteratee[Either[String, Any]] = {
IO take 1 map (_.utf8String) flatMap {
case "+" IO takeUntil EOL map (msg Right(msg.utf8String))
case "-" IO takeUntil EOL map (err Left(err.utf8String))
case "$"
val length = socket.read(ByteString("\r\n")).utf8String
socket.read(length.toInt)
IO takeUntil EOL map (_.utf8String.toInt) flatMap {
case -1 IO Iteratee Right(None)
case length
for {
value IO take length
_ IO takeUntil EOL
} yield Right(Some(value.utf8String))
}
case "*"
val count = socket.read(ByteString("\r\n")).utf8String
var result: Map[String, ByteString] = Map.empty
repeatC(count.toInt / 2) {
val k = readBytes
val v = readBytes
result += (k.utf8String -> v)
}
result
case _ sys error "Unexpected response"
IO takeUntil EOL map (_.utf8String.toInt) flatMap {
case -1 IO Iteratee Right(None)
case length
IO.takeList(length)(readResult) map { list
((Right(Map()): Either[String, Map[String, String]]) /: list.grouped(2)) {
case (Right(m), List(Right(Some(k: String)), Right(Some(v: String)))) Right(m + (k -> v))
case (Right(_), _) Left("Unexpected Response")
case (left, _) left
}
}
def readBytes = {
val resultType = socket.read(1).utf8String
if (resultType != "$") sys error "Unexpected response"
val length = socket.read(ByteString("\r\n")).utf8String
socket.read(length.toInt)
}
case _ IO Iteratee Left("Unexpected Response")
}
}
}
}
@org.junit.runner.RunWith(classOf[org.scalatest.junit.JUnitRunner])
class IOActorSpec extends AkkaSpec with BeforeAndAfterEach with DefaultTimeout {
class IOActorSpec extends AkkaSpec with DefaultTimeout {
import IOActorSpec._
"an IO Actor" must {
"run echo server" in {
val started = TestLatch(1)
val ioManager = system.actorOf(Props(new IOManager(2))) // teeny tiny buffer
val server = system.actorOf(Props(new SimpleEchoServer("localhost", 8064, ioManager, started)))
Await.ready(started, timeout.duration)
val client = system.actorOf(Props(new SimpleEchoClient("localhost", 8064, ioManager)))
val server = system.actorOf(Props(new SimpleEchoServer("localhost", 8064, started)))
Await.ready(started, TestLatch.DefaultTimeout)
val client = system.actorOf(Props(new SimpleEchoClient("localhost", 8064)))
val f1 = client ? ByteString("Hello World!1")
val f2 = client ? ByteString("Hello World!2")
val f3 = client ? ByteString("Hello World!3")
Await.result(f1, timeout.duration) must equal(ByteString("Hello World!1"))
Await.result(f2, timeout.duration) must equal(ByteString("Hello World!2"))
Await.result(f3, timeout.duration) must equal(ByteString("Hello World!3"))
system.stop(client)
system.stop(server)
system.stop(ioManager)
Await.result(f1, TestLatch.DefaultTimeout) must equal(ByteString("Hello World!1"))
Await.result(f2, TestLatch.DefaultTimeout) must equal(ByteString("Hello World!2"))
Await.result(f3, TestLatch.DefaultTimeout) must equal(ByteString("Hello World!3"))
}
"run echo server under high load" in {
val started = TestLatch(1)
val ioManager = system.actorOf(Props(new IOManager()))
val server = system.actorOf(Props(new SimpleEchoServer("localhost", 8065, ioManager, started)))
Await.ready(started, timeout.duration)
val client = system.actorOf(Props(new SimpleEchoClient("localhost", 8065, ioManager)))
val server = system.actorOf(Props(new SimpleEchoServer("localhost", 8065, started)))
Await.ready(started, TestLatch.DefaultTimeout)
val client = system.actorOf(Props(new SimpleEchoClient("localhost", 8065)))
val list = List.range(0, 1000)
val f = Future.traverse(list)(i client ? ByteString(i.toString))
assert(Await.result(f, timeout.duration).size === 1000)
system.stop(client)
system.stop(server)
system.stop(ioManager)
assert(Await.result(f, TestLatch.DefaultTimeout).size === 1000)
}
// Not currently configurable at runtime
/*
"run echo server under high load with small buffer" in {
val started = TestLatch(1)
val ioManager = system.actorOf(Props(new IOManager(2)))
val server = system.actorOf(Props(new SimpleEchoServer("localhost", 8066, ioManager, started)))
Await.ready(started, timeout.duration)
val client = system.actorOf(Props(new SimpleEchoClient("localhost", 8066, ioManager)))
val ioManager = actorOf(new IOManager(2))
val server = actorOf(new SimpleEchoServer("localhost", 8066, ioManager, started))
started.await
val client = actorOf(new SimpleEchoClient("localhost", 8066, ioManager))
val list = List.range(0, 1000)
val f = Future.traverse(list)(i client ? ByteString(i.toString))
assert(Await.result(f, timeout.duration).size === 1000)
@ -228,32 +249,28 @@ class IOActorSpec extends AkkaSpec with BeforeAndAfterEach with DefaultTimeout {
system.stop(server)
system.stop(ioManager)
}
*/
"run key-value store" in {
val started = TestLatch(1)
val ioManager = system.actorOf(Props(new IOManager(2))) // teeny tiny buffer
val server = system.actorOf(Props(new KVStore("localhost", 8067, ioManager, started)))
Await.ready(started, timeout.duration)
val client1 = system.actorOf(Props(new KVClient("localhost", 8067, ioManager)))
val client2 = system.actorOf(Props(new KVClient("localhost", 8067, ioManager)))
val f1 = client1 ? (('set, "hello", ByteString("World")))
val f2 = client1 ? (('set, "test", ByteString("No one will read me")))
val f3 = client1 ? (('get, "hello"))
Await.ready(f2, timeout.duration)
val f4 = client2 ? (('set, "test", ByteString("I'm a test!")))
Await.ready(f4, timeout.duration)
val f5 = client1 ? (('get, "test"))
val f6 = client2 ? 'getall
Await.result(f1, timeout.duration) must equal("OK")
Await.result(f2, timeout.duration) must equal("OK")
Await.result(f3, timeout.duration) must equal(ByteString("World"))
Await.result(f4, timeout.duration) must equal("OK")
Await.result(f5, timeout.duration) must equal(ByteString("I'm a test!"))
Await.result(f6, timeout.duration) must equal(Map("hello" -> ByteString("World"), "test" -> ByteString("I'm a test!")))
system.stop(client1)
system.stop(client2)
system.stop(server)
system.stop(ioManager)
val server = system.actorOf(Props(new KVStore("localhost", 8067, started)))
Await.ready(started, TestLatch.DefaultTimeout)
val client1 = system.actorOf(Props(new KVClient("localhost", 8067)))
val client2 = system.actorOf(Props(new KVClient("localhost", 8067)))
val f1 = client1 ? KVSet("hello", "World")
val f2 = client1 ? KVSet("test", "No one will read me")
val f3 = client1 ? KVGet("hello")
Await.ready(f2, TestLatch.DefaultTimeout)
val f4 = client2 ? KVSet("test", "I'm a test!")
Await.ready(f4, TestLatch.DefaultTimeout)
val f5 = client1 ? KVGet("test")
val f6 = client2 ? KVGetAll
Await.result(f1, TestLatch.DefaultTimeout) must equal("OK")
Await.result(f2, TestLatch.DefaultTimeout) must equal("OK")
Await.result(f3, TestLatch.DefaultTimeout) must equal(Some("World"))
Await.result(f4, TestLatch.DefaultTimeout) must equal("OK")
Await.result(f5, TestLatch.DefaultTimeout) must equal(Some("I'm a test!"))
Await.result(f6, TestLatch.DefaultTimeout) must equal(Map("hello" -> "World", "test" -> "I'm a test!"))
}
}

View file

@ -4,10 +4,8 @@
package akka.actor
import akka.util.ByteString
import akka.dispatch.Envelope
import java.net.InetSocketAddress
import java.io.IOException
import java.util.concurrent.atomic.AtomicReference
import java.nio.ByteBuffer
import java.nio.channels.{
SelectableChannel,
@ -20,9 +18,8 @@ import java.nio.channels.{
CancelledKeyException
}
import scala.collection.mutable
import scala.collection.immutable.Queue
import scala.annotation.tailrec
import scala.util.continuations._
import scala.collection.generic.CanBuildFrom
import com.eaio.uuid.UUID
object IO {
@ -44,18 +41,6 @@ object IO {
sealed trait ReadHandle extends Handle with Product {
override def asReadable = this
def read(len: Int)(implicit actor: Actor with IO): ByteString @cps[IOSuspendable[Any]] = shift { cont: (ByteString IOSuspendable[Any])
ByteStringLength(cont, this, actor.context.asInstanceOf[ActorCell].currentMessage, len)
}
def read()(implicit actor: Actor with IO): ByteString @cps[IOSuspendable[Any]] = shift { cont: (ByteString IOSuspendable[Any])
ByteStringAny(cont, this, actor.context.asInstanceOf[ActorCell].currentMessage)
}
def read(delimiter: ByteString, inclusive: Boolean = false)(implicit actor: Actor with IO): ByteString @cps[IOSuspendable[Any]] = shift { cont: (ByteString IOSuspendable[Any])
ByteStringDelimited(cont, this, actor.context.asInstanceOf[ActorCell].currentMessage, delimiter, inclusive, 0)
}
}
sealed trait WriteHandle extends Handle with Product {
@ -89,259 +74,542 @@ object IO {
case class Read(handle: ReadHandle, bytes: ByteString) extends IOMessage
case class Write(handle: WriteHandle, bytes: ByteString) extends IOMessage
def listen(ioManager: ActorRef, address: InetSocketAddress)(implicit owner: ActorRef): ServerHandle = {
def listen(address: InetSocketAddress)(implicit context: ActorContext, owner: ActorRef): ServerHandle = {
val ioManager = IOManager.start()(context.system)
val server = ServerHandle(owner, ioManager)
ioManager ! Listen(server, address)
server
}
def listen(ioManager: ActorRef, host: String, port: Int)(implicit owner: ActorRef): ServerHandle =
listen(ioManager, new InetSocketAddress(host, port))
def listen(host: String, port: Int)(implicit context: ActorContext, owner: ActorRef): ServerHandle =
listen(new InetSocketAddress(host, port))(context, owner)
def connect(ioManager: ActorRef, address: InetSocketAddress)(implicit owner: ActorRef): SocketHandle = {
def listen(address: InetSocketAddress, owner: ActorRef)(implicit context: ActorContext): ServerHandle =
listen(address)(context, owner)
def listen(host: String, port: Int, owner: ActorRef)(implicit context: ActorContext): ServerHandle =
listen(new InetSocketAddress(host, port))(context, owner)
def connect(address: InetSocketAddress)(implicit context: ActorContext, owner: ActorRef): SocketHandle = {
val ioManager = IOManager.start()(context.system)
val socket = SocketHandle(owner, ioManager)
ioManager ! Connect(socket, address)
socket
}
def connect(ioManager: ActorRef, host: String, port: Int)(implicit sender: ActorRef): SocketHandle =
connect(ioManager, new InetSocketAddress(host, port))
def connect(host: String, port: Int)(implicit context: ActorContext, owner: ActorRef): SocketHandle =
connect(new InetSocketAddress(host, port))(context, owner)
private class HandleState(var readBytes: ByteString, var connected: Boolean) {
def this() = this(ByteString.empty, false)
def connect(address: InetSocketAddress, owner: ActorRef)(implicit context: ActorContext): SocketHandle =
connect(address)(context, owner)
def connect(host: String, port: Int, owner: ActorRef)(implicit context: ActorContext): SocketHandle =
connect(new InetSocketAddress(host, port))(context, owner)
sealed trait Input {
def ++(that: Input): Input
}
sealed trait IOSuspendable[+A]
sealed trait CurrentMessage { def message: Envelope }
private case class ByteStringLength(continuation: (ByteString) IOSuspendable[Any], handle: Handle, message: Envelope, length: Int) extends IOSuspendable[ByteString] with CurrentMessage
private case class ByteStringDelimited(continuation: (ByteString) IOSuspendable[Any], handle: Handle, message: Envelope, delimter: ByteString, inclusive: Boolean, scanned: Int) extends IOSuspendable[ByteString] with CurrentMessage
private case class ByteStringAny(continuation: (ByteString) IOSuspendable[Any], handle: Handle, message: Envelope) extends IOSuspendable[ByteString] with CurrentMessage
private case class Retry(message: Envelope) extends IOSuspendable[Nothing]
private case object Idle extends IOSuspendable[Nothing]
object Chunk {
val empty = Chunk(ByteString.empty)
}
trait IO {
this: Actor
import IO._
type ReceiveIO = PartialFunction[Any, Any @cps[IOSuspendable[Any]]]
implicit protected def ioActor: Actor with IO = this
private val _messages: mutable.Queue[Envelope] = mutable.Queue.empty
private var _state: Map[Handle, HandleState] = Map.empty
private var _next: IOSuspendable[Any] = Idle
private def state(handle: Handle): HandleState = _state.get(handle) match {
case Some(s) s
case _
val s = new HandleState()
_state += (handle -> s)
s
}
final def receive: Receive = {
case Read(handle, newBytes)
val st = state(handle)
st.readBytes ++= newBytes
run()
case Connected(socket)
state(socket).connected = true
run()
case msg @ Closed(handle, _)
_state -= handle // TODO: clean up better
if (_receiveIO.isDefinedAt(msg)) {
_next = reset { _receiveIO(msg); Idle }
}
run()
case msg if _next ne Idle
_messages enqueue context.asInstanceOf[ActorCell].currentMessage
case msg if _receiveIO.isDefinedAt(msg)
_next = reset { _receiveIO(msg); Idle }
run()
}
def receiveIO: ReceiveIO
def retry(): Any @cps[IOSuspendable[Any]] =
shift { _: (Any IOSuspendable[Any])
_next match {
case n: CurrentMessage Retry(n.message)
case _ Idle
case class Chunk(bytes: ByteString) extends Input {
def ++(that: Input) = that match {
case Chunk(more) Chunk(bytes ++ more)
case _: EOF that
}
}
private lazy val _receiveIO = receiveIO
// only reinvoke messages from the original message to avoid stack overflow
private var reinvoked = false
private def reinvoke() {
if (!reinvoked && (_next eq Idle) && _messages.nonEmpty) {
try {
reinvoked = true
while ((_next eq Idle) && _messages.nonEmpty) self.asInstanceOf[LocalActorRef].underlying invoke _messages.dequeue
} finally {
reinvoked = false
case class EOF(cause: Option[Exception]) extends Input {
def ++(that: Input) = this
}
object Iteratee {
def apply[A](value: A): Iteratee[A] = Done(value)
def apply(): Iteratee[Unit] = unit
val unit: Iteratee[Unit] = Done(())
}
/**
* A basic Iteratee implementation of Oleg's Iteratee (http://okmij.org/ftp/Streams.html).
* No support for Enumerator or Input types other then ByteString at the moment.
*/
sealed abstract class Iteratee[+A] {
/**
* Applies the given input to the Iteratee, returning the resulting Iteratee
* and the unused Input.
*/
final def apply(input: Input): (Iteratee[A], Input) = this match {
case Cont(f) f(input)
case iter (iter, input)
}
final def get: A = this(EOF(None))._1 match {
case Done(value) value
case Cont(_) sys.error("Divergent Iteratee")
case Failure(e) throw e
}
final def flatMap[B](f: A Iteratee[B]): Iteratee[B] = this match {
case Done(value) f(value)
case Cont(k: Chain[_]) Cont(k :+ f)
case Cont(k) Cont(Chain(k, f))
case failure: Failure failure
}
final def map[B](f: A B): Iteratee[B] = this match {
case Done(value) Done(f(value))
case Cont(k: Chain[_]) Cont(k :+ ((a: A) Done(f(a))))
case Cont(k) Cont(Chain(k, (a: A) Done(f(a))))
case failure: Failure failure
}
}
/**
* An Iteratee representing a result and the remaining ByteString. Also used to
* wrap any constants or precalculated values that need to be composed with
* other Iteratees.
*/
final case class Done[+A](result: A) extends Iteratee[A]
/**
* An Iteratee that still requires more input to calculate it's result.
*/
final case class Cont[+A](f: Input (Iteratee[A], Input)) extends Iteratee[A]
/**
* An Iteratee representing a failure to calcualte a result.
* FIXME: move into 'Cont' as in Oleg's implementation
*/
final case class Failure(exception: Throwable) extends Iteratee[Nothing]
object IterateeRef {
def sync[A](initial: Iteratee[A]): IterateeRefSync[A] = new IterateeRefSync(initial)
def sync(): IterateeRefSync[Unit] = new IterateeRefSync(Iteratee.unit)
def async[A](initial: Iteratee[A])(implicit app: ActorSystem): IterateeRefAsync[A] = new IterateeRefAsync(initial)
def async()(implicit app: ActorSystem): IterateeRefAsync[Unit] = new IterateeRefAsync(Iteratee.unit)
class Map[K, V] private (refFactory: IterateeRef[V], underlying: mutable.Map[K, IterateeRef[V]] = mutable.Map.empty[K, IterateeRef[V]]) extends mutable.Map[K, IterateeRef[V]] {
def get(key: K) = Some(underlying.getOrElseUpdate(key, refFactory))
def iterator = underlying.iterator
def +=(kv: (K, IterateeRef[V])) = { underlying += kv; this }
def -=(key: K) = { underlying -= key; this }
override def empty = new Map[K, V](refFactory)
}
object Map {
def apply[K, V](refFactory: IterateeRef[V]): IterateeRef.Map[K, V] = new Map(refFactory)
def sync[K](): IterateeRef.Map[K, Unit] = new Map(IterateeRef.sync())
def async[K]()(implicit app: ActorSystem): IterateeRef.Map[K, Unit] = new Map(IterateeRef.async())
}
}
@tailrec
private def run() {
_next match {
case ByteStringLength(continuation, handle, message, waitingFor)
context.asInstanceOf[ActorCell].currentMessage = message
val st = state(handle)
if (st.readBytes.length >= waitingFor) {
val bytes = st.readBytes.take(waitingFor) //.compact
st.readBytes = st.readBytes.drop(waitingFor)
_next = continuation(bytes)
run()
/**
* A mutable reference to an Iteratee. Not thread safe.
*
* Designed for use within an Actor.
*
* Includes mutable implementations of flatMap, map, and apply which
* update the internal reference and return Unit.
*/
trait IterateeRef[A] {
def flatMap(f: A Iteratee[A]): Unit
def map(f: A A): Unit
def apply(input: Input): Unit
}
case bsd @ ByteStringDelimited(continuation, handle, message, delimiter, inclusive, scanned)
context.asInstanceOf[ActorCell].currentMessage = message
val st = state(handle)
val idx = st.readBytes.indexOfSlice(delimiter, scanned)
if (idx >= 0) {
val index = if (inclusive) idx + delimiter.length else idx
val bytes = st.readBytes.take(index) //.compact
st.readBytes = st.readBytes.drop(idx + delimiter.length)
_next = continuation(bytes)
run()
final class IterateeRefSync[A](initial: Iteratee[A]) extends IterateeRef[A] {
private var _value: (Iteratee[A], Input) = (initial, Chunk.empty)
def flatMap(f: A Iteratee[A]): Unit = _value = _value match {
case (iter, chunk @ Chunk(bytes)) if bytes.nonEmpty (iter flatMap f)(chunk)
case (iter, input) (iter flatMap f, input)
}
def map(f: A A): Unit = _value = (_value._1 map f, _value._2)
def apply(input: Input): Unit = _value = _value._1(_value._2 ++ input)
def value: (Iteratee[A], Input) = _value
}
final class IterateeRefAsync[A](initial: Iteratee[A])(implicit app: ActorSystem) extends IterateeRef[A] {
import akka.dispatch.Future
private var _value: Future[(Iteratee[A], Input)] = Future((initial, Chunk.empty))
def flatMap(f: A Iteratee[A]): Unit = _value = _value map {
case (iter, chunk @ Chunk(bytes)) if bytes.nonEmpty (iter flatMap f)(chunk)
case (iter, input) (iter flatMap f, input)
}
def map(f: A A): Unit = _value = _value map (v (v._1 map f, v._2))
def apply(input: Input): Unit = _value = _value map (v v._1(v._2 ++ input))
def future: Future[(Iteratee[A], Input)] = _value
}
/**
* An Iteratee that returns the ByteString prefix up until the supplied delimiter.
* The delimiter is dropped by default, but it can be returned with the result by
* setting 'inclusive' to be 'true'.
*/
def takeUntil(delimiter: ByteString, inclusive: Boolean = false): Iteratee[ByteString] = {
def step(taken: ByteString)(input: Input): (Iteratee[ByteString], Input) = input match {
case Chunk(more)
val bytes = taken ++ more
val startIdx = bytes.indexOfSlice(delimiter, math.max(taken.length - delimiter.length, 0))
if (startIdx >= 0) {
val endIdx = startIdx + delimiter.length
(Done(bytes take (if (inclusive) endIdx else startIdx)), Chunk(bytes drop endIdx))
} else {
_next = bsd.copy(scanned = math.min(idx - delimiter.length, 0))
(Cont(step(bytes)), Chunk.empty)
}
case ByteStringAny(continuation, handle, message)
context.asInstanceOf[ActorCell].currentMessage = message
val st = state(handle)
if (st.readBytes.length > 0) {
val bytes = st.readBytes //.compact
st.readBytes = ByteString.empty
_next = continuation(bytes)
run()
case eof (Cont(step(taken)), eof)
}
case Retry(message)
message +=: _messages
_next = Idle
run()
case Idle reinvoke()
Cont(step(ByteString.empty))
}
def takeWhile(p: (Byte) Boolean): Iteratee[ByteString] = {
def step(taken: ByteString)(input: Input): (Iteratee[ByteString], Input) = input match {
case Chunk(more)
val (found, rest) = more span p
if (rest.isEmpty)
(Cont(step(taken ++ found)), Chunk.empty)
else
(Done(taken ++ found), Chunk(rest))
case eof (Done(taken), eof)
}
Cont(step(ByteString.empty))
}
/**
* An Iteratee that returns a ByteString of the requested length.
*/
def take(length: Int): Iteratee[ByteString] = {
def step(taken: ByteString)(input: Input): (Iteratee[ByteString], Input) = input match {
case Chunk(more)
val bytes = taken ++ more
if (bytes.length >= length)
(Done(bytes.take(length)), Chunk(bytes.drop(length)))
else
(Cont(step(bytes)), Chunk.empty)
case eof (Cont(step(taken)), eof)
}
Cont(step(ByteString.empty))
}
/**
* An Iteratee that ignores the specified number of bytes.
*/
def drop(length: Int): Iteratee[Unit] = {
def step(left: Int)(input: Input): (Iteratee[Unit], Input) = input match {
case Chunk(more)
if (left > more.length)
(Cont(step(left - more.length)), Chunk.empty)
else
(Done(), Chunk(more drop left))
case eof (Done(), eof)
}
Cont(step(length))
}
/**
* An Iteratee that returns the remaining ByteString until an EOF is given.
*/
val takeAll: Iteratee[ByteString] = {
def step(taken: ByteString)(input: Input): (Iteratee[ByteString], Input) = input match {
case Chunk(more)
val bytes = taken ++ more
(Cont(step(bytes)), Chunk.empty)
case eof (Done(taken), eof)
}
Cont(step(ByteString.empty))
}
/**
* An Iteratee that returns any input it receives
*/
val takeAny: Iteratee[ByteString] = Cont {
case Chunk(bytes) if bytes.nonEmpty (Done(bytes), Chunk.empty)
case Chunk(bytes) (takeAny, Chunk.empty)
case eof (Done(ByteString.empty), eof)
}
def takeList[A](length: Int)(iter: Iteratee[A]): Iteratee[List[A]] = {
def step(left: Int, list: List[A]): Iteratee[List[A]] =
if (left == 0) Done(list.reverse)
else iter flatMap (a step(left - 1, a :: list))
step(length, Nil)
}
def peek(length: Int): Iteratee[ByteString] = {
def step(taken: ByteString)(input: Input): (Iteratee[ByteString], Input) = input match {
case Chunk(more)
val bytes = taken ++ more
if (bytes.length >= length)
(Done(bytes.take(length)), Chunk(bytes))
else
(Cont(step(bytes)), Chunk.empty)
case eof (Cont(step(taken)), eof)
}
Cont(step(ByteString.empty))
}
def repeat(iter: Iteratee[Unit]): Iteratee[Unit] =
iter flatMap (_ repeat(iter))
def traverse[A, B, M[A] <: Traversable[A]](in: M[A])(f: A Iteratee[B])(implicit cbf: CanBuildFrom[M[A], B, M[B]]): Iteratee[M[B]] =
fold(cbf(in), in)((b, a) f(a) map (b += _)) map (_.result)
def fold[A, B, M[A] <: Traversable[A]](initial: B, in: M[A])(f: (B, A) Iteratee[B]): Iteratee[B] =
(Iteratee(initial) /: in)((ib, a) ib flatMap (b f(b, a)))
// private api
private object Chain {
def apply[A](f: Input (Iteratee[A], Input)) = new Chain[A](f, Nil, Nil)
def apply[A, B](f: Input (Iteratee[A], Input), k: A Iteratee[B]) = new Chain[B](f, List(k.asInstanceOf[Any Iteratee[Any]]), Nil)
}
/**
* A function 'ByteString => Iteratee[A]' that composes with 'A => Iteratee[B]' functions
* in a stack-friendly manner.
*
* For internal use within Iteratee.
*/
private final case class Chain[A] private (cur: Input (Iteratee[Any], Input), queueOut: List[Any Iteratee[Any]], queueIn: List[Any Iteratee[Any]]) extends (Input (Iteratee[A], Input)) {
def :+[B](f: A Iteratee[B]) = new Chain[B](cur, queueOut, f.asInstanceOf[Any Iteratee[Any]] :: queueIn)
def apply(input: Input): (Iteratee[A], Input) = {
@tailrec
def run(result: (Iteratee[Any], Input), queueOut: List[Any Iteratee[Any]], queueIn: List[Any Iteratee[Any]]): (Iteratee[Any], Input) = {
if (queueOut.isEmpty) {
if (queueIn.isEmpty) result
else run(result, queueIn.reverse, Nil)
} else result match {
case (Done(value), rest)
queueOut.head(value) match {
//case Cont(Chain(f, q)) run(f(rest), q ++ tail) <- can cause big slowdown, need to test if needed
case Cont(f) run(f(rest), queueOut.tail, queueIn)
case iter run((iter, rest), queueOut.tail, queueIn)
}
case (Cont(f), rest)
(Cont(new Chain(f, queueOut, queueIn)), rest)
case _ result
}
}
run(cur(input), queueOut, queueIn).asInstanceOf[(Iteratee[A], Input)]
}
}
}
object IOManager {
def start()(implicit system: ActorSystem): ActorRef = {
// TODO: Replace with better "get or create" if/when available
val ref = system.actorFor(system / "io-manager")
if (!ref.isInstanceOf[EmptyLocalActorRef]) ref else try {
system.actorOf(Props[IOManager], "io-manager")
} catch {
case _: InvalidActorNameException ref
}
}
def stop()(implicit system: ActorSystem): Unit = {
// TODO: send shutdown message to IOManager
}
}
final class WriteBuffer(bufferSize: Int) {
private val _queue = new java.util.ArrayDeque[ByteString]
private val _buffer = ByteBuffer.allocate(bufferSize)
private var _length = 0
private def fillBuffer(): Boolean = {
while (!_queue.isEmpty && _buffer.hasRemaining) {
val next = _queue.pollFirst
val rest = next.drop(next.copyToBuffer(_buffer))
if (rest.nonEmpty) _queue.offerFirst(rest)
}
!_buffer.hasRemaining
}
def enqueue(elem: ByteString): this.type = {
_length += elem.length
val rest = elem.drop(elem.copyToBuffer(_buffer))
if (rest.nonEmpty) _queue.offerLast(rest)
this
}
def length = _length
def isEmpty = _length == 0
def write(channel: WritableByteChannel with SelectableChannel): Int = {
@tailrec
def run(total: Int): Int = {
if (this.isEmpty) total
else {
val written = try {
_buffer.flip()
channel write _buffer
} finally {
// don't leave buffer in wrong state
_buffer.compact()
fillBuffer()
}
_length -= written
if (_buffer.position > 0) {
total + written
} else {
run(total + written)
}
}
}
class IOManager(bufferSize: Int = 8192) extends Actor {
run(0)
}
}
// TODO: Support a pool of workers
final class IOManager extends Actor {
import SelectionKey.{ OP_READ, OP_WRITE, OP_ACCEPT, OP_CONNECT }
import IOWorker._
var worker: IOWorker = _
val bufferSize = 8192 // TODO: make buffer size configurable
override def preStart {
worker = new IOWorker(context.system, self, bufferSize)
worker.start()
type ReadChannel = ReadableByteChannel with SelectableChannel
type WriteChannel = WritableByteChannel with SelectableChannel
val selector: Selector = Selector open ()
val channels = mutable.Map.empty[IO.Handle, SelectableChannel]
val accepted = mutable.Map.empty[IO.ServerHandle, mutable.Queue[SelectableChannel]]
val writes = mutable.Map.empty[IO.WriteHandle, WriteBuffer]
val closing = mutable.Set.empty[IO.Handle]
val buffer = ByteBuffer.allocate(bufferSize)
var lastSelect = 0
val selectAt = 100 // TODO: determine best value, perhaps based on throughput? Other triggers (like write queue size)?
//val selectEveryNanos = 1000000 // nanos
//var lastSelectNanos = System.nanoTime
var running = false
var selectSent = false
var fastSelect = false
object Select
def run() {
if (!running) {
running = true
if (!selectSent) {
selectSent = true
self ! Select
}
}
lastSelect += 1
if (lastSelect >= selectAt /* || (lastSelectNanos + selectEveryNanos) < System.nanoTime */ ) select()
}
def select() {
if (selector.isOpen) {
// TODO: Make select behaviour configurable.
// Blocking 1ms reduces allocations during idle times, non blocking gives better performance.
if (fastSelect) selector.selectNow else selector.select(1)
val keys = selector.selectedKeys.iterator
fastSelect = keys.hasNext
while (keys.hasNext) {
val key = keys.next()
keys.remove()
if (key.isValid) { process(key) }
}
if (channels.isEmpty) running = false
} else {
running = false
}
//lastSelectNanos = System.nanoTime
lastSelect = 0
}
def receive = {
case Select
select()
if (running) self ! Select
selectSent = running
case IO.Listen(server, address)
val channel = ServerSocketChannel open ()
channel configureBlocking false
channel.socket bind address
worker(Register(server, channel, OP_ACCEPT))
channel.socket bind (address, 1000) // TODO: make backlog configurable
channels update (server, channel)
channel register (selector, OP_ACCEPT, server)
run()
case IO.Connect(socket, address)
val channel = SocketChannel open ()
channel configureBlocking false
channel connect address
worker(Register(socket, channel, OP_CONNECT | OP_READ))
channels update (socket, channel)
channel register (selector, OP_CONNECT | OP_READ, socket)
run()
case IO.Accept(socket, server) worker(Accepted(socket, server))
case IO.Write(handle, data) worker(Write(handle, data.asByteBuffer))
case IO.Close(handle) worker(Close(handle))
case IO.Accept(socket, server)
val queue = accepted(server)
val channel = queue.dequeue()
channels update (socket, channel)
channel register (selector, OP_READ, socket)
run()
case IO.Write(handle, data)
if (channels contains handle) {
val queue = {
val existing = writes get handle
if (existing.isDefined) existing.get
else {
val q = new WriteBuffer(bufferSize)
writes update (handle, q)
q
}
}
if (queue.isEmpty) addOps(handle, OP_WRITE)
queue enqueue data
if (queue.length >= bufferSize) write(handle, channels(handle).asInstanceOf[WriteChannel])
}
run()
case IO.Close(handle: IO.WriteHandle)
if (writes get handle filterNot (_.isEmpty) isDefined) {
closing += handle
} else {
cleanup(handle, None)
}
run()
case IO.Close(handle)
cleanup(handle, None)
run()
}
override def postStop {
worker(Shutdown)
}
}
private[akka] object IOWorker {
sealed trait Request
case class Register(handle: IO.Handle, channel: SelectableChannel, ops: Int) extends Request
case class Accepted(socket: IO.SocketHandle, server: IO.ServerHandle) extends Request
case class Write(handle: IO.WriteHandle, data: ByteBuffer) extends Request
case class Close(handle: IO.Handle) extends Request
case object Shutdown extends Request
}
private[akka] class IOWorker(system: ActorSystem, ioManager: ActorRef, val bufferSize: Int) {
import SelectionKey.{ OP_READ, OP_WRITE, OP_ACCEPT, OP_CONNECT }
import IOWorker._
type ReadChannel = ReadableByteChannel with SelectableChannel
type WriteChannel = WritableByteChannel with SelectableChannel
implicit val optionIOManager: Some[ActorRef] = Some(ioManager)
def apply(request: Request): Unit =
addRequest(request)
def start(): Unit =
thread.start()
// private
private val selector: Selector = Selector open ()
private val _requests = new AtomicReference(List.empty[Request])
private var accepted = Map.empty[IO.ServerHandle, Queue[SelectableChannel]].withDefaultValue(Queue.empty)
private var channels = Map.empty[IO.Handle, SelectableChannel]
private var writes = Map.empty[IO.WriteHandle, Queue[ByteBuffer]].withDefaultValue(Queue.empty)
private val buffer = ByteBuffer.allocate(bufferSize)
private val thread = new Thread("io-worker") {
override def run() {
while (selector.isOpen) {
selector select ()
val keys = selector.selectedKeys.iterator
while (keys.hasNext) {
val key = keys next ()
keys remove ()
if (key.isValid) { process(key) }
}
_requests.getAndSet(Nil).reverse foreach {
case Register(handle, channel, ops)
channels += (handle -> channel)
channel register (selector, ops, handle)
case Accepted(socket, server)
val (channel, rest) = accepted(server).dequeue
if (rest.isEmpty) accepted -= server
else accepted += (server -> rest)
channels += (socket -> channel)
channel register (selector, OP_READ, socket)
case Write(handle, data)
if (channels contains handle) {
val queue = writes(handle)
if (queue.isEmpty) addOps(handle, OP_WRITE)
writes += (handle -> queue.enqueue(data))
}
case Close(handle)
cleanup(handle, None)
case Shutdown
channels.values foreach (_.close)
channels.keys foreach (handle cleanup(handle, None))
selector.close
}
}
}
}
private def process(key: SelectionKey) {
def process(key: SelectionKey) {
val handle = key.attachment.asInstanceOf[IO.Handle]
try {
if (key.isConnectable) key.channel match {
@ -369,7 +637,8 @@ private[akka] class IOWorker(system: ActorSystem, ioManager: ActorRef, val buffe
}
}
private def cleanup(handle: IO.Handle, cause: Option[Exception]) {
def cleanup(handle: IO.Handle, cause: Option[Exception]) {
closing -= handle
handle match {
case server: IO.ServerHandle accepted -= server
case writable: IO.WriteHandle writes -= writable
@ -378,28 +647,27 @@ private[akka] class IOWorker(system: ActorSystem, ioManager: ActorRef, val buffe
case Some(channel)
channel.close
channels -= handle
// TODO: what if handle.owner is no longer running?
handle.owner ! IO.Closed(handle, cause)
if (!handle.owner.isTerminated) handle.owner ! IO.Closed(handle, cause)
case None
}
}
private def setOps(handle: IO.Handle, ops: Int): Unit =
def setOps(handle: IO.Handle, ops: Int): Unit =
channels(handle) keyFor selector interestOps ops
private def addOps(handle: IO.Handle, ops: Int) {
def addOps(handle: IO.Handle, ops: Int) {
val key = channels(handle) keyFor selector
val cur = key.interestOps
key interestOps (cur | ops)
}
private def removeOps(handle: IO.Handle, ops: Int) {
def removeOps(handle: IO.Handle, ops: Int) {
val key = channels(handle) keyFor selector
val cur = key.interestOps
key interestOps (cur - (cur & ops))
}
private def connect(socket: IO.SocketHandle, channel: SocketChannel) {
def connect(socket: IO.SocketHandle, channel: SocketChannel) {
if (channel.finishConnect) {
removeOps(socket, OP_CONNECT)
socket.owner ! IO.Connected(socket)
@ -409,18 +677,27 @@ private[akka] class IOWorker(system: ActorSystem, ioManager: ActorRef, val buffe
}
@tailrec
private def accept(server: IO.ServerHandle, channel: ServerSocketChannel) {
def accept(server: IO.ServerHandle, channel: ServerSocketChannel) {
val socket = channel.accept
if (socket ne null) {
socket configureBlocking false
accepted += (server -> (accepted(server) enqueue socket))
val queue = {
val existing = accepted get server
if (existing.isDefined) existing.get
else {
val q = mutable.Queue[SelectableChannel]()
accepted update (server, q)
q
}
}
queue += socket
server.owner ! IO.NewClient(server)
accept(server, channel)
}
}
@tailrec
private def read(handle: IO.ReadHandle, channel: ReadChannel) {
def read(handle: IO.ReadHandle, channel: ReadChannel) {
buffer.clear
val readLen = channel read buffer
if (readLen == -1) {
@ -432,30 +709,16 @@ private[akka] class IOWorker(system: ActorSystem, ioManager: ActorRef, val buffe
}
}
@tailrec
private def write(handle: IO.WriteHandle, channel: WriteChannel) {
def write(handle: IO.WriteHandle, channel: WriteChannel) {
val queue = writes(handle)
if (queue.nonEmpty) {
val (buf, bufs) = queue.dequeue
val writeLen = channel write buf
if (buf.remaining == 0) {
if (bufs.isEmpty) {
writes -= handle
removeOps(handle, OP_WRITE)
queue write channel
if (queue.isEmpty) {
if (closing(handle)) {
cleanup(handle, None)
} else {
writes += (handle -> bufs)
write(handle, channel)
}
removeOps(handle, OP_WRITE)
}
}
}
@tailrec
private def addRequest(req: Request) {
val requests = _requests.get
if (_requests compareAndSet (requests, req :: requests))
selector wakeup ()
else
addRequest(req)
}
}

View file

@ -1,10 +1,11 @@
package akka.util
import java.nio.ByteBuffer
import java.nio.charset.Charset
import scala.collection.IndexedSeqOptimized
import scala.collection.mutable.{ Builder, ArrayBuilder }
import scala.collection.immutable.IndexedSeq
import scala.collection.{ IndexedSeqOptimized, LinearSeq }
import scala.collection.mutable.{ Builder, ArrayBuilder, WrappedArray }
import scala.collection.immutable.{ IndexedSeq, VectorBuilder }
import scala.collection.generic.{ CanBuildFrom, GenericCompanion }
object ByteString {
@ -30,35 +31,43 @@ object ByteString {
def apply(string: String, charset: String): ByteString = ByteString1(string.getBytes(charset))
def fromArray(array: Array[Byte], offset: Int, length: Int): ByteString = {
val copyOffset = math.max(offset, 0)
val copyLength = math.max(math.min(array.length - copyOffset, length), 0)
if (copyLength == 0) empty
else {
val copyArray = new Array[Byte](copyLength)
Array.copy(array, copyOffset, copyArray, 0, copyLength)
ByteString1(copyArray)
}
}
val empty: ByteString = ByteString1(Array.empty[Byte])
def newBuilder: Builder[Byte, ByteString] = new ArrayBuilder.ofByte mapResult apply
def newBuilder = new ByteStringBuilder
implicit def canBuildFrom = new CanBuildFrom[TraversableOnce[Byte], Byte, ByteString] {
def apply(from: TraversableOnce[Byte]) = newBuilder
def apply() = newBuilder
}
private object ByteString1 {
private[akka] object ByteString1 {
def apply(bytes: Array[Byte]) = new ByteString1(bytes)
}
final class ByteString1 private (bytes: Array[Byte], startIndex: Int, endIndex: Int) extends ByteString {
final class ByteString1 private (private val bytes: Array[Byte], private val startIndex: Int, val length: Int) extends ByteString {
private def this(bytes: Array[Byte]) = this(bytes, 0, bytes.length)
def apply(idx: Int): Byte = bytes(checkRangeConvert(idx))
private def checkRangeConvert(index: Int) = {
val idx = index + startIndex
if (0 <= index && idx < endIndex)
idx
if (0 <= index && length > index)
index + startIndex
else
throw new IndexOutOfBoundsException(index.toString)
}
def length: Int = endIndex - startIndex
def toArray: Array[Byte] = {
val ar = new Array[Byte](length)
Array.copy(bytes, startIndex, ar, 0, length)
@ -68,8 +77,7 @@ object ByteString {
override def clone: ByteString = new ByteString1(toArray)
def compact: ByteString =
if (startIndex == 0 && endIndex == bytes.length) this
else clone
if (length == bytes.length) this else clone
def asByteBuffer: ByteBuffer = {
val buffer = ByteBuffer.wrap(bytes, startIndex, length).asReadOnlyBuffer
@ -77,8 +85,8 @@ object ByteString {
else buffer
}
def utf8String: String =
new String(if (startIndex == 0 && endIndex == bytes.length) bytes else toArray, "UTF-8")
def decodeString(charset: String): String =
new String(if (length == bytes.length) bytes else toArray, charset)
def ++(that: ByteString): ByteString = that match {
case b: ByteString1 ByteStrings(this, b)
@ -87,42 +95,50 @@ object ByteString {
override def slice(from: Int, until: Int): ByteString = {
val newStartIndex = math.max(from, 0) + startIndex
val newEndIndex = math.min(until, length) + startIndex
if (newEndIndex <= newStartIndex) ByteString.empty
else new ByteString1(bytes, newStartIndex, newEndIndex)
val newLength = math.min(until, length) - from
if (newLength <= 0) ByteString.empty
else new ByteString1(bytes, newStartIndex, newLength)
}
override def copyToArray[A >: Byte](xs: Array[A], start: Int, len: Int): Unit =
Array.copy(bytes, startIndex, xs, start, math.min(math.min(length, len), xs.length - start))
def copyToBuffer(buffer: ByteBuffer): Int = {
val copyLength = math.min(buffer.remaining, length)
if (copyLength > 0) buffer.put(bytes, startIndex, copyLength)
copyLength
}
private object ByteStrings {
def apply(bytestrings: Vector[ByteString1]): ByteString = new ByteStrings(bytestrings)
}
private[akka] object ByteStrings {
def apply(bytestrings: Vector[ByteString1]): ByteString = new ByteStrings(bytestrings, (0 /: bytestrings)(_ + _.length))
def apply(bytestrings: Vector[ByteString1], length: Int): ByteString = new ByteStrings(bytestrings, length)
def apply(b1: ByteString1, b2: ByteString1): ByteString = compare(b1, b2) match {
case 3 new ByteStrings(Vector(b1, b2))
case 3 new ByteStrings(Vector(b1, b2), b1.length + b2.length)
case 2 b2
case 1 b1
case 0 ByteString.empty
}
def apply(b: ByteString1, bs: ByteStrings): ByteString = compare(b, bs) match {
case 3 new ByteStrings(b +: bs.bytestrings)
case 3 new ByteStrings(b +: bs.bytestrings, bs.length + b.length)
case 2 bs
case 1 b
case 0 ByteString.empty
}
def apply(bs: ByteStrings, b: ByteString1): ByteString = compare(bs, b) match {
case 3 new ByteStrings(bs.bytestrings :+ b)
case 3 new ByteStrings(bs.bytestrings :+ b, bs.length + b.length)
case 2 b
case 1 bs
case 0 ByteString.empty
}
def apply(bs1: ByteStrings, bs2: ByteStrings): ByteString = compare(bs1, bs2) match {
case 3 new ByteStrings(bs1.bytestrings ++ bs2.bytestrings)
case 3 new ByteStrings(bs1.bytestrings ++ bs2.bytestrings, bs1.length + bs2.length)
case 2 bs2
case 1 bs1
case 0 ByteString.empty
@ -133,9 +149,10 @@ object ByteString {
if (b1.length == 0)
if (b2.length == 0) 0 else 2
else if (b2.length == 0) 1 else 3
}
final class ByteStrings private (private val bytestrings: Vector[ByteString1]) extends ByteString {
final class ByteStrings private (val bytestrings: Vector[ByteString1], val length: Int) extends ByteString {
def apply(idx: Int): Byte =
if (0 <= idx && idx < length) {
@ -148,37 +165,39 @@ object ByteString {
bytestrings(pos)(idx - seen)
} else throw new IndexOutOfBoundsException(idx.toString)
val length: Int = (0 /: bytestrings)(_ + _.length)
override def slice(from: Int, until: Int): ByteString = {
val start = math.max(from, 0)
val end = math.min(until, length)
if (end <= start)
ByteString.empty
else {
val iter = bytestrings.iterator
var cur = iter.next
var pos = 0
var seen = 0
while (from >= seen + bytestrings(pos).length) {
seen += bytestrings(pos).length
while (from >= seen + cur.length) {
seen += cur.length
pos += 1
cur = iter.next
}
val startpos = pos
val startidx = start - seen
while (until > seen + bytestrings(pos).length) {
seen += bytestrings(pos).length
while (until > seen + cur.length) {
seen += cur.length
pos += 1
cur = iter.next
}
val endpos = pos
val endidx = end - seen
if (startpos == endpos)
bytestrings(startpos).slice(startidx, endidx)
cur.slice(startidx, endidx)
else {
val first = bytestrings(startpos).drop(startidx).asInstanceOf[ByteString1]
val last = bytestrings(endpos).take(endidx).asInstanceOf[ByteString1]
val last = cur.take(endidx).asInstanceOf[ByteString1]
if ((endpos - startpos) == 1)
new ByteStrings(Vector(first, last))
new ByteStrings(Vector(first, last), until - from)
else
new ByteStrings(first +: bytestrings.slice(startpos + 1, endpos) :+ last)
new ByteStrings(first +: bytestrings.slice(startpos + 1, endpos) :+ last, until - from)
}
}
}
@ -200,7 +219,16 @@ object ByteString {
def asByteBuffer: ByteBuffer = compact.asByteBuffer
def utf8String: String = compact.utf8String
def decodeString(charset: String): String = compact.decodeString(charset)
def copyToBuffer(buffer: ByteBuffer): Int = {
val copyLength = math.min(buffer.remaining, length)
val iter = bytestrings.iterator
while (iter.hasNext && buffer.hasRemaining) {
iter.next.copyToBuffer(buffer)
}
copyLength
}
}
}
@ -208,9 +236,92 @@ object ByteString {
sealed trait ByteString extends IndexedSeq[Byte] with IndexedSeqOptimized[Byte, ByteString] {
override protected[this] def newBuilder = ByteString.newBuilder
def ++(that: ByteString): ByteString
def copyToBuffer(buffer: ByteBuffer): Int
def compact: ByteString
def asByteBuffer: ByteBuffer
def toByteBuffer: ByteBuffer = ByteBuffer.wrap(toArray)
def utf8String: String
def mapI(f: Byte Int): ByteString = map(f andThen (_.toByte))
final def toByteBuffer: ByteBuffer = ByteBuffer.wrap(toArray)
final def utf8String: String = decodeString("UTF-8")
def decodeString(charset: String): String
final def mapI(f: Byte Int): ByteString = map(f andThen (_.toByte))
}
final class ByteStringBuilder extends Builder[Byte, ByteString] {
import ByteString.{ ByteString1, ByteStrings }
private var _length = 0
private val _builder = new VectorBuilder[ByteString1]()
private var _temp: Array[Byte] = _
private var _tempLength = 0
private var _tempCapacity = 0
private def clearTemp() {
if (_tempLength > 0) {
val arr = new Array[Byte](_tempLength)
Array.copy(_temp, 0, arr, 0, _tempLength)
_builder += ByteString1(arr)
_tempLength = 0
}
}
private def resizeTemp(size: Int) {
val newtemp = new Array[Byte](size)
if (_tempLength > 0) Array.copy(_temp, 0, newtemp, 0, _tempLength)
_temp = newtemp
}
private def ensureTempSize(size: Int) {
if (_tempCapacity < size || _tempCapacity == 0) {
var newSize = if (_tempCapacity == 0) 16 else _tempCapacity * 2
while (newSize < size) newSize *= 2
resizeTemp(newSize)
}
}
def +=(elem: Byte): this.type = {
ensureTempSize(_tempLength + 1)
_temp(_tempLength) = elem
_tempLength += 1
_length += 1
this
}
override def ++=(xs: TraversableOnce[Byte]): this.type = {
xs match {
case b: ByteString1
clearTemp()
_builder += b
_length += b.length
case bs: ByteStrings
clearTemp()
_builder ++= bs.bytestrings
_length += bs.length
case xs: WrappedArray.ofByte
clearTemp()
_builder += ByteString1(xs.array.clone)
_length += xs.length
case _: collection.IndexedSeq[_]
ensureTempSize(_tempLength + xs.size)
xs.copyToArray(_temp, _tempLength)
case _
super.++=(xs)
}
this
}
def clear() {
_builder.clear
_length = 0
_tempLength = 0
}
def result: ByteString =
if (_length == 0) ByteString.empty
else {
clearTemp()
val bytestrings = _builder.result
if (bytestrings.size == 1)
bytestrings.head
else
ByteStrings(bytestrings, _length)
}
}