From a1f9b81ce9ad45aa05ef88ee1a0362141736d702 Mon Sep 17 00:00:00 2001 From: Derek Williams Date: Sat, 31 Dec 2011 09:18:37 -0700 Subject: [PATCH] Update IO to new iteratee based api --- .../src/test/scala/akka/actor/IOActor.scala | 389 ++++----- akka-actor/src/main/scala/akka/actor/IO.scala | 771 ++++++++++++------ .../src/main/scala/akka/util/ByteString.scala | 189 ++++- 3 files changed, 870 insertions(+), 479 deletions(-) diff --git a/akka-actor-tests/src/test/scala/akka/actor/IOActor.scala b/akka-actor-tests/src/test/scala/akka/actor/IOActor.scala index 757acb1fd0..06d3235409 100644 --- a/akka-actor-tests/src/test/scala/akka/actor/IOActor.scala +++ b/akka-actor-tests/src/test/scala/akka/actor/IOActor.scala @@ -4,8 +4,6 @@ package akka.actor -import org.scalatest.BeforeAndAfterEach - import akka.util.ByteString import akka.util.cps._ import scala.util.continuations._ @@ -13,214 +11,237 @@ import akka.testkit._ import akka.dispatch.{ Await, Future } object IOActorSpec { - import IO._ - class SimpleEchoServer(host: String, port: Int, ioManager: ActorRef, started: TestLatch) extends Actor { + class SimpleEchoServer(host: String, port: Int, started: TestLatch) extends Actor { - import context.dispatcher - implicit val timeout = context.system.settings.ActorTimeout + IO listen (host, port) - override def preStart = { - listen(ioManager, host, port) - started.open() - } + started.open - def createWorker = context.actorOf(Props(new Actor with IO { - def receiveIO = { - case NewClient(server) ⇒ - val socket = server.accept() - loopC { - val bytes = socket.read() - socket write bytes - } - } - })) + val state = IO.IterateeRef.Map.sync[IO.Handle]() def receive = { - case msg: NewClient ⇒ - createWorker forward msg - } - } - - class SimpleEchoClient(host: String, port: Int, ioManager: ActorRef) extends Actor with IO { - - lazy val socket: SocketHandle = connect(ioManager, host, port)(reader) - lazy val reader: ActorRef = context.actorOf(Props({ - new Actor with IO { - def receiveIO = { - case length: Int ⇒ - val bytes = socket.read(length) - sender ! bytes - } - } - })) - - def receiveIO = { - case bytes: ByteString ⇒ - socket write bytes - reader forward bytes.length - } - } - - // Basic Redis-style protocol - class KVStore(host: String, port: Int, ioManager: ActorRef, started: TestLatch) extends Actor { - - import context.dispatcher - implicit val timeout = context.system.settings.ActorTimeout - - var kvs: Map[String, ByteString] = Map.empty - - override def preStart = { - listen(ioManager, host, port) - started.open() - } - - def createWorker = context.actorOf(Props(new Actor with IO { - def receiveIO = { - case NewClient(server) ⇒ - val socket = server.accept() - loopC { - val cmd = socket.read(ByteString("\r\n")).utf8String - val result = matchC(cmd.split(' ')) { - case Array("SET", key, length) ⇒ - val value = socket read length.toInt - server.owner ? (('set, key, value)) map ((x: Any) ⇒ ByteString("+OK\r\n")) - case Array("GET", key) ⇒ - server.owner ? (('get, key)) map { - case Some(b: ByteString) ⇒ ByteString("$" + b.length + "\r\n") ++ b - case None ⇒ ByteString("$-1\r\n") - } - case Array("GETALL") ⇒ - server.owner ? 'getall map { - case m: Map[_, _] ⇒ - (ByteString("*" + (m.size * 2) + "\r\n") /: m) { - case (result, (k: String, v: ByteString)) ⇒ - val kBytes = ByteString(k) - result ++ ByteString("$" + kBytes.length + "\r\n") ++ kBytes ++ ByteString("$" + v.length + "\r\n") ++ v - } - } - } - result recover { - case e ⇒ ByteString("-" + e.getClass.toString + "\r\n") - } foreach { bytes ⇒ + case IO.NewClient(server) ⇒ + val socket = server.accept() + state(socket) flatMap { _ ⇒ + IO repeat { + IO.takeAny map { bytes ⇒ socket write bytes } } - } - })) + } + + case IO.Read(socket, bytes) ⇒ + state(socket)(IO Chunk bytes) + + case IO.Closed(socket, cause) ⇒ + state -= socket + + } + + } + + class SimpleEchoClient(host: String, port: Int) extends Actor { + + val socket = IO connect (host, port) + + val state = IO.IterateeRef.sync() def receive = { - case msg: NewClient ⇒ createWorker forward msg - case ('set, key: String, value: ByteString) ⇒ - kvs += (key -> value) - sender.tell((), self) - case ('get, key: String) ⇒ sender.tell(kvs.get(key), self) - case 'getall ⇒ sender.tell(kvs, self) + + case bytes: ByteString ⇒ + val source = sender + socket write bytes + for { + _ ← state + bytes ← IO take bytes.length + } yield source ! bytes + + case IO.Read(socket, bytes) ⇒ + state(IO Chunk bytes) + + case IO.Connected(socket) ⇒ + + case IO.Closed(socket, cause) ⇒ + + } + } + + sealed trait KVCommand { + def bytes: ByteString + } + + case class KVSet(key: String, value: String) extends KVCommand { + val bytes = ByteString("SET " + key + " " + value.length + "\r\n" + value + "\r\n") + } + + case class KVGet(key: String) extends KVCommand { + val bytes = ByteString("GET " + key + "\r\n") + } + + case object KVGetAll extends KVCommand { + val bytes = ByteString("GETALL\r\n") + } + + // Basic Redis-style protocol + class KVStore(host: String, port: Int, started: TestLatch) extends Actor { + + val state = IO.IterateeRef.Map.sync[IO.Handle]() + + var kvs: Map[String, String] = Map.empty + + IO listen (host, port) + + started.open + + val EOL = ByteString("\r\n") + + def receive = { + + case IO.NewClient(server) ⇒ + val socket = server.accept() + state(socket) flatMap { _ ⇒ + IO repeat { + IO takeUntil EOL map (_.utf8String split ' ') flatMap { + + case Array("SET", key, length) ⇒ + for { + value ← IO take length.toInt + _ ← IO takeUntil EOL + } yield { + kvs += (key -> value.utf8String) + ByteString("+OK\r\n") + } + + case Array("GET", key) ⇒ + IO Iteratee { + kvs get key map { value ⇒ + ByteString("$" + value.length + "\r\n" + value + "\r\n") + } getOrElse ByteString("$-1\r\n") + } + + case Array("GETALL") ⇒ + IO Iteratee { + (ByteString("*" + (kvs.size * 2) + "\r\n") /: kvs) { + case (result, (k, v)) ⇒ + val kBytes = ByteString(k) + val vBytes = ByteString(v) + result ++ + ByteString("$" + kBytes.length) ++ EOL ++ + kBytes ++ EOL ++ + ByteString("$" + vBytes.length) ++ EOL ++ + vBytes ++ EOL + } + } + + } map (socket write) + } + } + + case IO.Read(socket, bytes) ⇒ + state(socket)(IO Chunk bytes) + + case IO.Closed(socket, cause) ⇒ + state -= socket + } } - class KVClient(host: String, port: Int, ioManager: ActorRef) extends Actor with IO { + class KVClient(host: String, port: Int) extends Actor { - import context.dispatcher - implicit val timeout = context.system.settings.ActorTimeout + val socket = IO connect (host, port) - var socket: SocketHandle = _ + val state = IO.IterateeRef.sync() + + val EOL = ByteString("\r\n") + + def receive = { + case cmd: KVCommand ⇒ + val source = sender + socket write cmd.bytes + for { + _ ← state + result ← readResult + } yield result.fold(err ⇒ source ! Status.Failure(new RuntimeException(err)), source !) + + case IO.Read(socket, bytes) ⇒ + state(IO Chunk bytes) + + case IO.Connected(socket) ⇒ + + case IO.Closed(socket, cause) ⇒ - override def preStart { - socket = connect(ioManager, host, port) } - def reply(msg: Any) = sender.tell(msg, self) - - def receiveIO = { - case ('set, key: String, value: ByteString) ⇒ - socket write (ByteString("SET " + key + " " + value.length + "\r\n") ++ value) - reply(readResult) - - case ('get, key: String) ⇒ - socket write ByteString("GET " + key + "\r\n") - reply(readResult) - - case 'getall ⇒ - socket write ByteString("GETALL\r\n") - reply(readResult) - } - - def readResult = { - val resultType = socket.read(1).utf8String - resultType match { - case "+" ⇒ socket.read(ByteString("\r\n")).utf8String - case "-" ⇒ sys error socket.read(ByteString("\r\n")).utf8String + def readResult: IO.Iteratee[Either[String, Any]] = { + IO take 1 map (_.utf8String) flatMap { + case "+" ⇒ IO takeUntil EOL map (msg ⇒ Right(msg.utf8String)) + case "-" ⇒ IO takeUntil EOL map (err ⇒ Left(err.utf8String)) case "$" ⇒ - val length = socket.read(ByteString("\r\n")).utf8String - socket.read(length.toInt) - case "*" ⇒ - val count = socket.read(ByteString("\r\n")).utf8String - var result: Map[String, ByteString] = Map.empty - repeatC(count.toInt / 2) { - val k = readBytes - val v = readBytes - result += (k.utf8String -> v) + IO takeUntil EOL map (_.utf8String.toInt) flatMap { + case -1 ⇒ IO Iteratee Right(None) + case length ⇒ + for { + value ← IO take length + _ ← IO takeUntil EOL + } yield Right(Some(value.utf8String)) } - result - case _ ⇒ sys error "Unexpected response" + case "*" ⇒ + IO takeUntil EOL map (_.utf8String.toInt) flatMap { + case -1 ⇒ IO Iteratee Right(None) + case length ⇒ + IO.takeList(length)(readResult) map { list ⇒ + ((Right(Map()): Either[String, Map[String, String]]) /: list.grouped(2)) { + case (Right(m), List(Right(Some(k: String)), Right(Some(v: String)))) ⇒ Right(m + (k -> v)) + case (Right(_), _) ⇒ Left("Unexpected Response") + case (left, _) ⇒ left + } + } + } + case _ ⇒ IO Iteratee Left("Unexpected Response") } } - - def readBytes = { - val resultType = socket.read(1).utf8String - if (resultType != "$") sys error "Unexpected response" - val length = socket.read(ByteString("\r\n")).utf8String - socket.read(length.toInt) - } } - } @org.junit.runner.RunWith(classOf[org.scalatest.junit.JUnitRunner]) -class IOActorSpec extends AkkaSpec with BeforeAndAfterEach with DefaultTimeout { +class IOActorSpec extends AkkaSpec with DefaultTimeout { import IOActorSpec._ "an IO Actor" must { "run echo server" in { val started = TestLatch(1) - val ioManager = system.actorOf(Props(new IOManager(2))) // teeny tiny buffer - val server = system.actorOf(Props(new SimpleEchoServer("localhost", 8064, ioManager, started))) - Await.ready(started, timeout.duration) - val client = system.actorOf(Props(new SimpleEchoClient("localhost", 8064, ioManager))) + val server = system.actorOf(Props(new SimpleEchoServer("localhost", 8064, started))) + Await.ready(started, TestLatch.DefaultTimeout) + val client = system.actorOf(Props(new SimpleEchoClient("localhost", 8064))) val f1 = client ? ByteString("Hello World!1") val f2 = client ? ByteString("Hello World!2") val f3 = client ? ByteString("Hello World!3") - Await.result(f1, timeout.duration) must equal(ByteString("Hello World!1")) - Await.result(f2, timeout.duration) must equal(ByteString("Hello World!2")) - Await.result(f3, timeout.duration) must equal(ByteString("Hello World!3")) - system.stop(client) - system.stop(server) - system.stop(ioManager) + Await.result(f1, TestLatch.DefaultTimeout) must equal(ByteString("Hello World!1")) + Await.result(f2, TestLatch.DefaultTimeout) must equal(ByteString("Hello World!2")) + Await.result(f3, TestLatch.DefaultTimeout) must equal(ByteString("Hello World!3")) } "run echo server under high load" in { val started = TestLatch(1) - val ioManager = system.actorOf(Props(new IOManager())) - val server = system.actorOf(Props(new SimpleEchoServer("localhost", 8065, ioManager, started))) - Await.ready(started, timeout.duration) - val client = system.actorOf(Props(new SimpleEchoClient("localhost", 8065, ioManager))) + val server = system.actorOf(Props(new SimpleEchoServer("localhost", 8065, started))) + Await.ready(started, TestLatch.DefaultTimeout) + val client = system.actorOf(Props(new SimpleEchoClient("localhost", 8065))) val list = List.range(0, 1000) val f = Future.traverse(list)(i ⇒ client ? ByteString(i.toString)) - assert(Await.result(f, timeout.duration).size === 1000) - system.stop(client) - system.stop(server) - system.stop(ioManager) + assert(Await.result(f, TestLatch.DefaultTimeout).size === 1000) } + // Not currently configurable at runtime + /* "run echo server under high load with small buffer" in { val started = TestLatch(1) - val ioManager = system.actorOf(Props(new IOManager(2))) - val server = system.actorOf(Props(new SimpleEchoServer("localhost", 8066, ioManager, started))) - Await.ready(started, timeout.duration) - val client = system.actorOf(Props(new SimpleEchoClient("localhost", 8066, ioManager))) + val ioManager = actorOf(new IOManager(2)) + val server = actorOf(new SimpleEchoServer("localhost", 8066, ioManager, started)) + started.await + val client = actorOf(new SimpleEchoClient("localhost", 8066, ioManager)) val list = List.range(0, 1000) val f = Future.traverse(list)(i ⇒ client ? ByteString(i.toString)) assert(Await.result(f, timeout.duration).size === 1000) @@ -228,32 +249,28 @@ class IOActorSpec extends AkkaSpec with BeforeAndAfterEach with DefaultTimeout { system.stop(server) system.stop(ioManager) } + */ "run key-value store" in { val started = TestLatch(1) - val ioManager = system.actorOf(Props(new IOManager(2))) // teeny tiny buffer - val server = system.actorOf(Props(new KVStore("localhost", 8067, ioManager, started))) - Await.ready(started, timeout.duration) - val client1 = system.actorOf(Props(new KVClient("localhost", 8067, ioManager))) - val client2 = system.actorOf(Props(new KVClient("localhost", 8067, ioManager))) - val f1 = client1 ? (('set, "hello", ByteString("World"))) - val f2 = client1 ? (('set, "test", ByteString("No one will read me"))) - val f3 = client1 ? (('get, "hello")) - Await.ready(f2, timeout.duration) - val f4 = client2 ? (('set, "test", ByteString("I'm a test!"))) - Await.ready(f4, timeout.duration) - val f5 = client1 ? (('get, "test")) - val f6 = client2 ? 'getall - Await.result(f1, timeout.duration) must equal("OK") - Await.result(f2, timeout.duration) must equal("OK") - Await.result(f3, timeout.duration) must equal(ByteString("World")) - Await.result(f4, timeout.duration) must equal("OK") - Await.result(f5, timeout.duration) must equal(ByteString("I'm a test!")) - Await.result(f6, timeout.duration) must equal(Map("hello" -> ByteString("World"), "test" -> ByteString("I'm a test!"))) - system.stop(client1) - system.stop(client2) - system.stop(server) - system.stop(ioManager) + val server = system.actorOf(Props(new KVStore("localhost", 8067, started))) + Await.ready(started, TestLatch.DefaultTimeout) + val client1 = system.actorOf(Props(new KVClient("localhost", 8067))) + val client2 = system.actorOf(Props(new KVClient("localhost", 8067))) + val f1 = client1 ? KVSet("hello", "World") + val f2 = client1 ? KVSet("test", "No one will read me") + val f3 = client1 ? KVGet("hello") + Await.ready(f2, TestLatch.DefaultTimeout) + val f4 = client2 ? KVSet("test", "I'm a test!") + Await.ready(f4, TestLatch.DefaultTimeout) + val f5 = client1 ? KVGet("test") + val f6 = client2 ? KVGetAll + Await.result(f1, TestLatch.DefaultTimeout) must equal("OK") + Await.result(f2, TestLatch.DefaultTimeout) must equal("OK") + Await.result(f3, TestLatch.DefaultTimeout) must equal(Some("World")) + Await.result(f4, TestLatch.DefaultTimeout) must equal("OK") + Await.result(f5, TestLatch.DefaultTimeout) must equal(Some("I'm a test!")) + Await.result(f6, TestLatch.DefaultTimeout) must equal(Map("hello" -> "World", "test" -> "I'm a test!")) } } diff --git a/akka-actor/src/main/scala/akka/actor/IO.scala b/akka-actor/src/main/scala/akka/actor/IO.scala index 28bad4f85e..89922c94ca 100644 --- a/akka-actor/src/main/scala/akka/actor/IO.scala +++ b/akka-actor/src/main/scala/akka/actor/IO.scala @@ -4,10 +4,8 @@ package akka.actor import akka.util.ByteString -import akka.dispatch.Envelope import java.net.InetSocketAddress import java.io.IOException -import java.util.concurrent.atomic.AtomicReference import java.nio.ByteBuffer import java.nio.channels.{ SelectableChannel, @@ -20,9 +18,8 @@ import java.nio.channels.{ CancelledKeyException } import scala.collection.mutable -import scala.collection.immutable.Queue import scala.annotation.tailrec -import scala.util.continuations._ +import scala.collection.generic.CanBuildFrom import com.eaio.uuid.UUID object IO { @@ -44,18 +41,6 @@ object IO { sealed trait ReadHandle extends Handle with Product { override def asReadable = this - - def read(len: Int)(implicit actor: Actor with IO): ByteString @cps[IOSuspendable[Any]] = shift { cont: (ByteString ⇒ IOSuspendable[Any]) ⇒ - ByteStringLength(cont, this, actor.context.asInstanceOf[ActorCell].currentMessage, len) - } - - def read()(implicit actor: Actor with IO): ByteString @cps[IOSuspendable[Any]] = shift { cont: (ByteString ⇒ IOSuspendable[Any]) ⇒ - ByteStringAny(cont, this, actor.context.asInstanceOf[ActorCell].currentMessage) - } - - def read(delimiter: ByteString, inclusive: Boolean = false)(implicit actor: Actor with IO): ByteString @cps[IOSuspendable[Any]] = shift { cont: (ByteString ⇒ IOSuspendable[Any]) ⇒ - ByteStringDelimited(cont, this, actor.context.asInstanceOf[ActorCell].currentMessage, delimiter, inclusive, 0) - } } sealed trait WriteHandle extends Handle with Product { @@ -89,259 +74,542 @@ object IO { case class Read(handle: ReadHandle, bytes: ByteString) extends IOMessage case class Write(handle: WriteHandle, bytes: ByteString) extends IOMessage - def listen(ioManager: ActorRef, address: InetSocketAddress)(implicit owner: ActorRef): ServerHandle = { + def listen(address: InetSocketAddress)(implicit context: ActorContext, owner: ActorRef): ServerHandle = { + val ioManager = IOManager.start()(context.system) val server = ServerHandle(owner, ioManager) ioManager ! Listen(server, address) server } - def listen(ioManager: ActorRef, host: String, port: Int)(implicit owner: ActorRef): ServerHandle = - listen(ioManager, new InetSocketAddress(host, port)) + def listen(host: String, port: Int)(implicit context: ActorContext, owner: ActorRef): ServerHandle = + listen(new InetSocketAddress(host, port))(context, owner) - def connect(ioManager: ActorRef, address: InetSocketAddress)(implicit owner: ActorRef): SocketHandle = { + def listen(address: InetSocketAddress, owner: ActorRef)(implicit context: ActorContext): ServerHandle = + listen(address)(context, owner) + + def listen(host: String, port: Int, owner: ActorRef)(implicit context: ActorContext): ServerHandle = + listen(new InetSocketAddress(host, port))(context, owner) + + def connect(address: InetSocketAddress)(implicit context: ActorContext, owner: ActorRef): SocketHandle = { + val ioManager = IOManager.start()(context.system) val socket = SocketHandle(owner, ioManager) ioManager ! Connect(socket, address) socket } - def connect(ioManager: ActorRef, host: String, port: Int)(implicit sender: ActorRef): SocketHandle = - connect(ioManager, new InetSocketAddress(host, port)) + def connect(host: String, port: Int)(implicit context: ActorContext, owner: ActorRef): SocketHandle = + connect(new InetSocketAddress(host, port))(context, owner) - private class HandleState(var readBytes: ByteString, var connected: Boolean) { - def this() = this(ByteString.empty, false) + def connect(address: InetSocketAddress, owner: ActorRef)(implicit context: ActorContext): SocketHandle = + connect(address)(context, owner) + + def connect(host: String, port: Int, owner: ActorRef)(implicit context: ActorContext): SocketHandle = + connect(new InetSocketAddress(host, port))(context, owner) + + sealed trait Input { + def ++(that: Input): Input } - sealed trait IOSuspendable[+A] - sealed trait CurrentMessage { def message: Envelope } - private case class ByteStringLength(continuation: (ByteString) ⇒ IOSuspendable[Any], handle: Handle, message: Envelope, length: Int) extends IOSuspendable[ByteString] with CurrentMessage - private case class ByteStringDelimited(continuation: (ByteString) ⇒ IOSuspendable[Any], handle: Handle, message: Envelope, delimter: ByteString, inclusive: Boolean, scanned: Int) extends IOSuspendable[ByteString] with CurrentMessage - private case class ByteStringAny(continuation: (ByteString) ⇒ IOSuspendable[Any], handle: Handle, message: Envelope) extends IOSuspendable[ByteString] with CurrentMessage - private case class Retry(message: Envelope) extends IOSuspendable[Nothing] - private case object Idle extends IOSuspendable[Nothing] - -} - -trait IO { - this: Actor ⇒ - import IO._ - - type ReceiveIO = PartialFunction[Any, Any @cps[IOSuspendable[Any]]] - - implicit protected def ioActor: Actor with IO = this - - private val _messages: mutable.Queue[Envelope] = mutable.Queue.empty - - private var _state: Map[Handle, HandleState] = Map.empty - - private var _next: IOSuspendable[Any] = Idle - - private def state(handle: Handle): HandleState = _state.get(handle) match { - case Some(s) ⇒ s - case _ ⇒ - val s = new HandleState() - _state += (handle -> s) - s + object Chunk { + val empty = Chunk(ByteString.empty) } - final def receive: Receive = { - case Read(handle, newBytes) ⇒ - val st = state(handle) - st.readBytes ++= newBytes - run() - case Connected(socket) ⇒ - state(socket).connected = true - run() - case msg @ Closed(handle, _) ⇒ - _state -= handle // TODO: clean up better - if (_receiveIO.isDefinedAt(msg)) { - _next = reset { _receiveIO(msg); Idle } - } - run() - case msg if _next ne Idle ⇒ - _messages enqueue context.asInstanceOf[ActorCell].currentMessage - case msg if _receiveIO.isDefinedAt(msg) ⇒ - _next = reset { _receiveIO(msg); Idle } - run() - } - - def receiveIO: ReceiveIO - - def retry(): Any @cps[IOSuspendable[Any]] = - shift { _: (Any ⇒ IOSuspendable[Any]) ⇒ - _next match { - case n: CurrentMessage ⇒ Retry(n.message) - case _ ⇒ Idle - } - } - - private lazy val _receiveIO = receiveIO - - // only reinvoke messages from the original message to avoid stack overflow - private var reinvoked = false - private def reinvoke() { - if (!reinvoked && (_next eq Idle) && _messages.nonEmpty) { - try { - reinvoked = true - while ((_next eq Idle) && _messages.nonEmpty) self.asInstanceOf[LocalActorRef].underlying invoke _messages.dequeue - } finally { - reinvoked = false - } + case class Chunk(bytes: ByteString) extends Input { + def ++(that: Input) = that match { + case Chunk(more) ⇒ Chunk(bytes ++ more) + case _: EOF ⇒ that } } - @tailrec - private def run() { - _next match { - case ByteStringLength(continuation, handle, message, waitingFor) ⇒ - context.asInstanceOf[ActorCell].currentMessage = message - val st = state(handle) - if (st.readBytes.length >= waitingFor) { - val bytes = st.readBytes.take(waitingFor) //.compact - st.readBytes = st.readBytes.drop(waitingFor) - _next = continuation(bytes) - run() - } - case bsd @ ByteStringDelimited(continuation, handle, message, delimiter, inclusive, scanned) ⇒ - context.asInstanceOf[ActorCell].currentMessage = message - val st = state(handle) - val idx = st.readBytes.indexOfSlice(delimiter, scanned) - if (idx >= 0) { - val index = if (inclusive) idx + delimiter.length else idx - val bytes = st.readBytes.take(index) //.compact - st.readBytes = st.readBytes.drop(idx + delimiter.length) - _next = continuation(bytes) - run() + case class EOF(cause: Option[Exception]) extends Input { + def ++(that: Input) = this + } + + object Iteratee { + def apply[A](value: A): Iteratee[A] = Done(value) + def apply(): Iteratee[Unit] = unit + val unit: Iteratee[Unit] = Done(()) + } + + /** + * A basic Iteratee implementation of Oleg's Iteratee (http://okmij.org/ftp/Streams.html). + * No support for Enumerator or Input types other then ByteString at the moment. + */ + sealed abstract class Iteratee[+A] { + + /** + * Applies the given input to the Iteratee, returning the resulting Iteratee + * and the unused Input. + */ + final def apply(input: Input): (Iteratee[A], Input) = this match { + case Cont(f) ⇒ f(input) + case iter ⇒ (iter, input) + } + + final def get: A = this(EOF(None))._1 match { + case Done(value) ⇒ value + case Cont(_) ⇒ sys.error("Divergent Iteratee") + case Failure(e) ⇒ throw e + } + + final def flatMap[B](f: A ⇒ Iteratee[B]): Iteratee[B] = this match { + case Done(value) ⇒ f(value) + case Cont(k: Chain[_]) ⇒ Cont(k :+ f) + case Cont(k) ⇒ Cont(Chain(k, f)) + case failure: Failure ⇒ failure + } + + final def map[B](f: A ⇒ B): Iteratee[B] = this match { + case Done(value) ⇒ Done(f(value)) + case Cont(k: Chain[_]) ⇒ Cont(k :+ ((a: A) ⇒ Done(f(a)))) + case Cont(k) ⇒ Cont(Chain(k, (a: A) ⇒ Done(f(a)))) + case failure: Failure ⇒ failure + } + + } + + /** + * An Iteratee representing a result and the remaining ByteString. Also used to + * wrap any constants or precalculated values that need to be composed with + * other Iteratees. + */ + final case class Done[+A](result: A) extends Iteratee[A] + + /** + * An Iteratee that still requires more input to calculate it's result. + */ + final case class Cont[+A](f: Input ⇒ (Iteratee[A], Input)) extends Iteratee[A] + + /** + * An Iteratee representing a failure to calcualte a result. + * FIXME: move into 'Cont' as in Oleg's implementation + */ + final case class Failure(exception: Throwable) extends Iteratee[Nothing] + + object IterateeRef { + def sync[A](initial: Iteratee[A]): IterateeRefSync[A] = new IterateeRefSync(initial) + def sync(): IterateeRefSync[Unit] = new IterateeRefSync(Iteratee.unit) + + def async[A](initial: Iteratee[A])(implicit app: ActorSystem): IterateeRefAsync[A] = new IterateeRefAsync(initial) + def async()(implicit app: ActorSystem): IterateeRefAsync[Unit] = new IterateeRefAsync(Iteratee.unit) + + class Map[K, V] private (refFactory: ⇒ IterateeRef[V], underlying: mutable.Map[K, IterateeRef[V]] = mutable.Map.empty[K, IterateeRef[V]]) extends mutable.Map[K, IterateeRef[V]] { + def get(key: K) = Some(underlying.getOrElseUpdate(key, refFactory)) + def iterator = underlying.iterator + def +=(kv: (K, IterateeRef[V])) = { underlying += kv; this } + def -=(key: K) = { underlying -= key; this } + override def empty = new Map[K, V](refFactory) + } + object Map { + def apply[K, V](refFactory: ⇒ IterateeRef[V]): IterateeRef.Map[K, V] = new Map(refFactory) + def sync[K](): IterateeRef.Map[K, Unit] = new Map(IterateeRef.sync()) + def async[K]()(implicit app: ActorSystem): IterateeRef.Map[K, Unit] = new Map(IterateeRef.async()) + } + } + + /** + * A mutable reference to an Iteratee. Not thread safe. + * + * Designed for use within an Actor. + * + * Includes mutable implementations of flatMap, map, and apply which + * update the internal reference and return Unit. + */ + trait IterateeRef[A] { + def flatMap(f: A ⇒ Iteratee[A]): Unit + def map(f: A ⇒ A): Unit + def apply(input: Input): Unit + } + + final class IterateeRefSync[A](initial: Iteratee[A]) extends IterateeRef[A] { + private var _value: (Iteratee[A], Input) = (initial, Chunk.empty) + def flatMap(f: A ⇒ Iteratee[A]): Unit = _value = _value match { + case (iter, chunk @ Chunk(bytes)) if bytes.nonEmpty ⇒ (iter flatMap f)(chunk) + case (iter, input) ⇒ (iter flatMap f, input) + } + def map(f: A ⇒ A): Unit = _value = (_value._1 map f, _value._2) + def apply(input: Input): Unit = _value = _value._1(_value._2 ++ input) + def value: (Iteratee[A], Input) = _value + } + + final class IterateeRefAsync[A](initial: Iteratee[A])(implicit app: ActorSystem) extends IterateeRef[A] { + import akka.dispatch.Future + private var _value: Future[(Iteratee[A], Input)] = Future((initial, Chunk.empty)) + def flatMap(f: A ⇒ Iteratee[A]): Unit = _value = _value map { + case (iter, chunk @ Chunk(bytes)) if bytes.nonEmpty ⇒ (iter flatMap f)(chunk) + case (iter, input) ⇒ (iter flatMap f, input) + } + def map(f: A ⇒ A): Unit = _value = _value map (v ⇒ (v._1 map f, v._2)) + def apply(input: Input): Unit = _value = _value map (v ⇒ v._1(v._2 ++ input)) + def future: Future[(Iteratee[A], Input)] = _value + } + + /** + * An Iteratee that returns the ByteString prefix up until the supplied delimiter. + * The delimiter is dropped by default, but it can be returned with the result by + * setting 'inclusive' to be 'true'. + */ + def takeUntil(delimiter: ByteString, inclusive: Boolean = false): Iteratee[ByteString] = { + def step(taken: ByteString)(input: Input): (Iteratee[ByteString], Input) = input match { + case Chunk(more) ⇒ + val bytes = taken ++ more + val startIdx = bytes.indexOfSlice(delimiter, math.max(taken.length - delimiter.length, 0)) + if (startIdx >= 0) { + val endIdx = startIdx + delimiter.length + (Done(bytes take (if (inclusive) endIdx else startIdx)), Chunk(bytes drop endIdx)) } else { - _next = bsd.copy(scanned = math.min(idx - delimiter.length, 0)) + (Cont(step(bytes)), Chunk.empty) } - case ByteStringAny(continuation, handle, message) ⇒ - context.asInstanceOf[ActorCell].currentMessage = message - val st = state(handle) - if (st.readBytes.length > 0) { - val bytes = st.readBytes //.compact - st.readBytes = ByteString.empty - _next = continuation(bytes) - run() + case eof ⇒ (Cont(step(taken)), eof) + } + + Cont(step(ByteString.empty)) + } + + def takeWhile(p: (Byte) ⇒ Boolean): Iteratee[ByteString] = { + def step(taken: ByteString)(input: Input): (Iteratee[ByteString], Input) = input match { + case Chunk(more) ⇒ + val (found, rest) = more span p + if (rest.isEmpty) + (Cont(step(taken ++ found)), Chunk.empty) + else + (Done(taken ++ found), Chunk(rest)) + case eof ⇒ (Done(taken), eof) + } + + Cont(step(ByteString.empty)) + } + + /** + * An Iteratee that returns a ByteString of the requested length. + */ + def take(length: Int): Iteratee[ByteString] = { + def step(taken: ByteString)(input: Input): (Iteratee[ByteString], Input) = input match { + case Chunk(more) ⇒ + val bytes = taken ++ more + if (bytes.length >= length) + (Done(bytes.take(length)), Chunk(bytes.drop(length))) + else + (Cont(step(bytes)), Chunk.empty) + case eof ⇒ (Cont(step(taken)), eof) + } + + Cont(step(ByteString.empty)) + } + + /** + * An Iteratee that ignores the specified number of bytes. + */ + def drop(length: Int): Iteratee[Unit] = { + def step(left: Int)(input: Input): (Iteratee[Unit], Input) = input match { + case Chunk(more) ⇒ + if (left > more.length) + (Cont(step(left - more.length)), Chunk.empty) + else + (Done(), Chunk(more drop left)) + case eof ⇒ (Done(), eof) + } + + Cont(step(length)) + } + + /** + * An Iteratee that returns the remaining ByteString until an EOF is given. + */ + val takeAll: Iteratee[ByteString] = { + def step(taken: ByteString)(input: Input): (Iteratee[ByteString], Input) = input match { + case Chunk(more) ⇒ + val bytes = taken ++ more + (Cont(step(bytes)), Chunk.empty) + case eof ⇒ (Done(taken), eof) + } + + Cont(step(ByteString.empty)) + } + + /** + * An Iteratee that returns any input it receives + */ + val takeAny: Iteratee[ByteString] = Cont { + case Chunk(bytes) if bytes.nonEmpty ⇒ (Done(bytes), Chunk.empty) + case Chunk(bytes) ⇒ (takeAny, Chunk.empty) + case eof ⇒ (Done(ByteString.empty), eof) + } + + def takeList[A](length: Int)(iter: Iteratee[A]): Iteratee[List[A]] = { + def step(left: Int, list: List[A]): Iteratee[List[A]] = + if (left == 0) Done(list.reverse) + else iter flatMap (a ⇒ step(left - 1, a :: list)) + + step(length, Nil) + } + + def peek(length: Int): Iteratee[ByteString] = { + def step(taken: ByteString)(input: Input): (Iteratee[ByteString], Input) = input match { + case Chunk(more) ⇒ + val bytes = taken ++ more + if (bytes.length >= length) + (Done(bytes.take(length)), Chunk(bytes)) + else + (Cont(step(bytes)), Chunk.empty) + case eof ⇒ (Cont(step(taken)), eof) + } + + Cont(step(ByteString.empty)) + } + + def repeat(iter: Iteratee[Unit]): Iteratee[Unit] = + iter flatMap (_ ⇒ repeat(iter)) + + def traverse[A, B, M[A] <: Traversable[A]](in: M[A])(f: A ⇒ Iteratee[B])(implicit cbf: CanBuildFrom[M[A], B, M[B]]): Iteratee[M[B]] = + fold(cbf(in), in)((b, a) ⇒ f(a) map (b += _)) map (_.result) + + def fold[A, B, M[A] <: Traversable[A]](initial: B, in: M[A])(f: (B, A) ⇒ Iteratee[B]): Iteratee[B] = + (Iteratee(initial) /: in)((ib, a) ⇒ ib flatMap (b ⇒ f(b, a))) + + // private api + + private object Chain { + def apply[A](f: Input ⇒ (Iteratee[A], Input)) = new Chain[A](f, Nil, Nil) + def apply[A, B](f: Input ⇒ (Iteratee[A], Input), k: A ⇒ Iteratee[B]) = new Chain[B](f, List(k.asInstanceOf[Any ⇒ Iteratee[Any]]), Nil) + } + + /** + * A function 'ByteString => Iteratee[A]' that composes with 'A => Iteratee[B]' functions + * in a stack-friendly manner. + * + * For internal use within Iteratee. + */ + private final case class Chain[A] private (cur: Input ⇒ (Iteratee[Any], Input), queueOut: List[Any ⇒ Iteratee[Any]], queueIn: List[Any ⇒ Iteratee[Any]]) extends (Input ⇒ (Iteratee[A], Input)) { + + def :+[B](f: A ⇒ Iteratee[B]) = new Chain[B](cur, queueOut, f.asInstanceOf[Any ⇒ Iteratee[Any]] :: queueIn) + + def apply(input: Input): (Iteratee[A], Input) = { + @tailrec + def run(result: (Iteratee[Any], Input), queueOut: List[Any ⇒ Iteratee[Any]], queueIn: List[Any ⇒ Iteratee[Any]]): (Iteratee[Any], Input) = { + if (queueOut.isEmpty) { + if (queueIn.isEmpty) result + else run(result, queueIn.reverse, Nil) + } else result match { + case (Done(value), rest) ⇒ + queueOut.head(value) match { + //case Cont(Chain(f, q)) ⇒ run(f(rest), q ++ tail) <- can cause big slowdown, need to test if needed + case Cont(f) ⇒ run(f(rest), queueOut.tail, queueIn) + case iter ⇒ run((iter, rest), queueOut.tail, queueIn) + } + case (Cont(f), rest) ⇒ + (Cont(new Chain(f, queueOut, queueIn)), rest) + case _ ⇒ result } - case Retry(message) ⇒ - message +=: _messages - _next = Idle - run() - case Idle ⇒ reinvoke() + } + run(cur(input), queueOut, queueIn).asInstanceOf[(Iteratee[A], Input)] } } + } -class IOManager(bufferSize: Int = 8192) extends Actor { +object IOManager { + def start()(implicit system: ActorSystem): ActorRef = { + // TODO: Replace with better "get or create" if/when available + val ref = system.actorFor(system / "io-manager") + if (!ref.isInstanceOf[EmptyLocalActorRef]) ref else try { + system.actorOf(Props[IOManager], "io-manager") + } catch { + case _: InvalidActorNameException ⇒ ref + } + } + def stop()(implicit system: ActorSystem): Unit = { + // TODO: send shutdown message to IOManager + } + +} + +final class WriteBuffer(bufferSize: Int) { + private val _queue = new java.util.ArrayDeque[ByteString] + private val _buffer = ByteBuffer.allocate(bufferSize) + private var _length = 0 + + private def fillBuffer(): Boolean = { + while (!_queue.isEmpty && _buffer.hasRemaining) { + val next = _queue.pollFirst + val rest = next.drop(next.copyToBuffer(_buffer)) + if (rest.nonEmpty) _queue.offerFirst(rest) + } + !_buffer.hasRemaining + } + + def enqueue(elem: ByteString): this.type = { + _length += elem.length + val rest = elem.drop(elem.copyToBuffer(_buffer)) + if (rest.nonEmpty) _queue.offerLast(rest) + this + } + + def length = _length + + def isEmpty = _length == 0 + + def write(channel: WritableByteChannel with SelectableChannel): Int = { + @tailrec + def run(total: Int): Int = { + if (this.isEmpty) total + else { + val written = try { + _buffer.flip() + channel write _buffer + } finally { + // don't leave buffer in wrong state + _buffer.compact() + fillBuffer() + } + _length -= written + if (_buffer.position > 0) { + total + written + } else { + run(total + written) + } + } + } + + run(0) + } + +} + +// TODO: Support a pool of workers +final class IOManager extends Actor { import SelectionKey.{ OP_READ, OP_WRITE, OP_ACCEPT, OP_CONNECT } - import IOWorker._ - var worker: IOWorker = _ + val bufferSize = 8192 // TODO: make buffer size configurable - override def preStart { - worker = new IOWorker(context.system, self, bufferSize) - worker.start() + type ReadChannel = ReadableByteChannel with SelectableChannel + type WriteChannel = WritableByteChannel with SelectableChannel + + val selector: Selector = Selector open () + + val channels = mutable.Map.empty[IO.Handle, SelectableChannel] + + val accepted = mutable.Map.empty[IO.ServerHandle, mutable.Queue[SelectableChannel]] + + val writes = mutable.Map.empty[IO.WriteHandle, WriteBuffer] + + val closing = mutable.Set.empty[IO.Handle] + + val buffer = ByteBuffer.allocate(bufferSize) + + var lastSelect = 0 + + val selectAt = 100 // TODO: determine best value, perhaps based on throughput? Other triggers (like write queue size)? + + //val selectEveryNanos = 1000000 // nanos + + //var lastSelectNanos = System.nanoTime + + var running = false + + var selectSent = false + + var fastSelect = false + + object Select + + def run() { + if (!running) { + running = true + if (!selectSent) { + selectSent = true + self ! Select + } + } + lastSelect += 1 + if (lastSelect >= selectAt /* || (lastSelectNanos + selectEveryNanos) < System.nanoTime */ ) select() + } + + def select() { + if (selector.isOpen) { + // TODO: Make select behaviour configurable. + // Blocking 1ms reduces allocations during idle times, non blocking gives better performance. + if (fastSelect) selector.selectNow else selector.select(1) + val keys = selector.selectedKeys.iterator + fastSelect = keys.hasNext + while (keys.hasNext) { + val key = keys.next() + keys.remove() + if (key.isValid) { process(key) } + } + if (channels.isEmpty) running = false + } else { + running = false + } + //lastSelectNanos = System.nanoTime + lastSelect = 0 } def receive = { + case Select ⇒ + select() + if (running) self ! Select + selectSent = running + case IO.Listen(server, address) ⇒ val channel = ServerSocketChannel open () channel configureBlocking false - channel.socket bind address - worker(Register(server, channel, OP_ACCEPT)) + channel.socket bind (address, 1000) // TODO: make backlog configurable + channels update (server, channel) + channel register (selector, OP_ACCEPT, server) + run() case IO.Connect(socket, address) ⇒ val channel = SocketChannel open () channel configureBlocking false channel connect address - worker(Register(socket, channel, OP_CONNECT | OP_READ)) + channels update (socket, channel) + channel register (selector, OP_CONNECT | OP_READ, socket) + run() - case IO.Accept(socket, server) ⇒ worker(Accepted(socket, server)) - case IO.Write(handle, data) ⇒ worker(Write(handle, data.asByteBuffer)) - case IO.Close(handle) ⇒ worker(Close(handle)) + case IO.Accept(socket, server) ⇒ + val queue = accepted(server) + val channel = queue.dequeue() + channels update (socket, channel) + channel register (selector, OP_READ, socket) + run() + + case IO.Write(handle, data) ⇒ + if (channels contains handle) { + val queue = { + val existing = writes get handle + if (existing.isDefined) existing.get + else { + val q = new WriteBuffer(bufferSize) + writes update (handle, q) + q + } + } + if (queue.isEmpty) addOps(handle, OP_WRITE) + queue enqueue data + if (queue.length >= bufferSize) write(handle, channels(handle).asInstanceOf[WriteChannel]) + } + run() + + case IO.Close(handle: IO.WriteHandle) ⇒ + if (writes get handle filterNot (_.isEmpty) isDefined) { + closing += handle + } else { + cleanup(handle, None) + } + run() + + case IO.Close(handle) ⇒ + cleanup(handle, None) + run() } override def postStop { - worker(Shutdown) + channels.keys foreach (handle ⇒ cleanup(handle, None)) + selector.close } -} - -private[akka] object IOWorker { - sealed trait Request - case class Register(handle: IO.Handle, channel: SelectableChannel, ops: Int) extends Request - case class Accepted(socket: IO.SocketHandle, server: IO.ServerHandle) extends Request - case class Write(handle: IO.WriteHandle, data: ByteBuffer) extends Request - case class Close(handle: IO.Handle) extends Request - case object Shutdown extends Request -} - -private[akka] class IOWorker(system: ActorSystem, ioManager: ActorRef, val bufferSize: Int) { - import SelectionKey.{ OP_READ, OP_WRITE, OP_ACCEPT, OP_CONNECT } - import IOWorker._ - - type ReadChannel = ReadableByteChannel with SelectableChannel - type WriteChannel = WritableByteChannel with SelectableChannel - - implicit val optionIOManager: Some[ActorRef] = Some(ioManager) - - def apply(request: Request): Unit = - addRequest(request) - - def start(): Unit = - thread.start() - - // private - - private val selector: Selector = Selector open () - - private val _requests = new AtomicReference(List.empty[Request]) - - private var accepted = Map.empty[IO.ServerHandle, Queue[SelectableChannel]].withDefaultValue(Queue.empty) - - private var channels = Map.empty[IO.Handle, SelectableChannel] - - private var writes = Map.empty[IO.WriteHandle, Queue[ByteBuffer]].withDefaultValue(Queue.empty) - - private val buffer = ByteBuffer.allocate(bufferSize) - - private val thread = new Thread("io-worker") { - override def run() { - while (selector.isOpen) { - selector select () - val keys = selector.selectedKeys.iterator - while (keys.hasNext) { - val key = keys next () - keys remove () - if (key.isValid) { process(key) } - } - _requests.getAndSet(Nil).reverse foreach { - case Register(handle, channel, ops) ⇒ - channels += (handle -> channel) - channel register (selector, ops, handle) - case Accepted(socket, server) ⇒ - val (channel, rest) = accepted(server).dequeue - if (rest.isEmpty) accepted -= server - else accepted += (server -> rest) - channels += (socket -> channel) - channel register (selector, OP_READ, socket) - case Write(handle, data) ⇒ - if (channels contains handle) { - val queue = writes(handle) - if (queue.isEmpty) addOps(handle, OP_WRITE) - writes += (handle -> queue.enqueue(data)) - } - case Close(handle) ⇒ - cleanup(handle, None) - case Shutdown ⇒ - channels.values foreach (_.close) - selector.close - } - } - } - } - - private def process(key: SelectionKey) { + def process(key: SelectionKey) { val handle = key.attachment.asInstanceOf[IO.Handle] try { if (key.isConnectable) key.channel match { @@ -369,7 +637,8 @@ private[akka] class IOWorker(system: ActorSystem, ioManager: ActorRef, val buffe } } - private def cleanup(handle: IO.Handle, cause: Option[Exception]) { + def cleanup(handle: IO.Handle, cause: Option[Exception]) { + closing -= handle handle match { case server: IO.ServerHandle ⇒ accepted -= server case writable: IO.WriteHandle ⇒ writes -= writable @@ -378,28 +647,27 @@ private[akka] class IOWorker(system: ActorSystem, ioManager: ActorRef, val buffe case Some(channel) ⇒ channel.close channels -= handle - // TODO: what if handle.owner is no longer running? - handle.owner ! IO.Closed(handle, cause) + if (!handle.owner.isTerminated) handle.owner ! IO.Closed(handle, cause) case None ⇒ } } - private def setOps(handle: IO.Handle, ops: Int): Unit = + def setOps(handle: IO.Handle, ops: Int): Unit = channels(handle) keyFor selector interestOps ops - private def addOps(handle: IO.Handle, ops: Int) { + def addOps(handle: IO.Handle, ops: Int) { val key = channels(handle) keyFor selector val cur = key.interestOps key interestOps (cur | ops) } - private def removeOps(handle: IO.Handle, ops: Int) { + def removeOps(handle: IO.Handle, ops: Int) { val key = channels(handle) keyFor selector val cur = key.interestOps key interestOps (cur - (cur & ops)) } - private def connect(socket: IO.SocketHandle, channel: SocketChannel) { + def connect(socket: IO.SocketHandle, channel: SocketChannel) { if (channel.finishConnect) { removeOps(socket, OP_CONNECT) socket.owner ! IO.Connected(socket) @@ -409,18 +677,27 @@ private[akka] class IOWorker(system: ActorSystem, ioManager: ActorRef, val buffe } @tailrec - private def accept(server: IO.ServerHandle, channel: ServerSocketChannel) { + def accept(server: IO.ServerHandle, channel: ServerSocketChannel) { val socket = channel.accept if (socket ne null) { socket configureBlocking false - accepted += (server -> (accepted(server) enqueue socket)) + val queue = { + val existing = accepted get server + if (existing.isDefined) existing.get + else { + val q = mutable.Queue[SelectableChannel]() + accepted update (server, q) + q + } + } + queue += socket server.owner ! IO.NewClient(server) accept(server, channel) } } @tailrec - private def read(handle: IO.ReadHandle, channel: ReadChannel) { + def read(handle: IO.ReadHandle, channel: ReadChannel) { buffer.clear val readLen = channel read buffer if (readLen == -1) { @@ -432,30 +709,16 @@ private[akka] class IOWorker(system: ActorSystem, ioManager: ActorRef, val buffe } } - @tailrec - private def write(handle: IO.WriteHandle, channel: WriteChannel) { + def write(handle: IO.WriteHandle, channel: WriteChannel) { val queue = writes(handle) - if (queue.nonEmpty) { - val (buf, bufs) = queue.dequeue - val writeLen = channel write buf - if (buf.remaining == 0) { - if (bufs.isEmpty) { - writes -= handle - removeOps(handle, OP_WRITE) - } else { - writes += (handle -> bufs) - write(handle, channel) - } + queue write channel + if (queue.isEmpty) { + if (closing(handle)) { + cleanup(handle, None) + } else { + removeOps(handle, OP_WRITE) } } } - @tailrec - private def addRequest(req: Request) { - val requests = _requests.get - if (_requests compareAndSet (requests, req :: requests)) - selector wakeup () - else - addRequest(req) - } } diff --git a/akka-actor/src/main/scala/akka/util/ByteString.scala b/akka-actor/src/main/scala/akka/util/ByteString.scala index 948489c335..fb41df2429 100644 --- a/akka-actor/src/main/scala/akka/util/ByteString.scala +++ b/akka-actor/src/main/scala/akka/util/ByteString.scala @@ -1,10 +1,11 @@ package akka.util import java.nio.ByteBuffer +import java.nio.charset.Charset -import scala.collection.IndexedSeqOptimized -import scala.collection.mutable.{ Builder, ArrayBuilder } -import scala.collection.immutable.IndexedSeq +import scala.collection.{ IndexedSeqOptimized, LinearSeq } +import scala.collection.mutable.{ Builder, ArrayBuilder, WrappedArray } +import scala.collection.immutable.{ IndexedSeq, VectorBuilder } import scala.collection.generic.{ CanBuildFrom, GenericCompanion } object ByteString { @@ -30,35 +31,43 @@ object ByteString { def apply(string: String, charset: String): ByteString = ByteString1(string.getBytes(charset)) + def fromArray(array: Array[Byte], offset: Int, length: Int): ByteString = { + val copyOffset = math.max(offset, 0) + val copyLength = math.max(math.min(array.length - copyOffset, length), 0) + if (copyLength == 0) empty + else { + val copyArray = new Array[Byte](copyLength) + Array.copy(array, copyOffset, copyArray, 0, copyLength) + ByteString1(copyArray) + } + } + val empty: ByteString = ByteString1(Array.empty[Byte]) - def newBuilder: Builder[Byte, ByteString] = new ArrayBuilder.ofByte mapResult apply + def newBuilder = new ByteStringBuilder implicit def canBuildFrom = new CanBuildFrom[TraversableOnce[Byte], Byte, ByteString] { def apply(from: TraversableOnce[Byte]) = newBuilder def apply() = newBuilder } - private object ByteString1 { + private[akka] object ByteString1 { def apply(bytes: Array[Byte]) = new ByteString1(bytes) } - final class ByteString1 private (bytes: Array[Byte], startIndex: Int, endIndex: Int) extends ByteString { + final class ByteString1 private (private val bytes: Array[Byte], private val startIndex: Int, val length: Int) extends ByteString { private def this(bytes: Array[Byte]) = this(bytes, 0, bytes.length) def apply(idx: Int): Byte = bytes(checkRangeConvert(idx)) private def checkRangeConvert(index: Int) = { - val idx = index + startIndex - if (0 <= index && idx < endIndex) - idx + if (0 <= index && length > index) + index + startIndex else throw new IndexOutOfBoundsException(index.toString) } - def length: Int = endIndex - startIndex - def toArray: Array[Byte] = { val ar = new Array[Byte](length) Array.copy(bytes, startIndex, ar, 0, length) @@ -68,8 +77,7 @@ object ByteString { override def clone: ByteString = new ByteString1(toArray) def compact: ByteString = - if (startIndex == 0 && endIndex == bytes.length) this - else clone + if (length == bytes.length) this else clone def asByteBuffer: ByteBuffer = { val buffer = ByteBuffer.wrap(bytes, startIndex, length).asReadOnlyBuffer @@ -77,8 +85,8 @@ object ByteString { else buffer } - def utf8String: String = - new String(if (startIndex == 0 && endIndex == bytes.length) bytes else toArray, "UTF-8") + def decodeString(charset: String): String = + new String(if (length == bytes.length) bytes else toArray, charset) def ++(that: ByteString): ByteString = that match { case b: ByteString1 ⇒ ByteStrings(this, b) @@ -87,42 +95,50 @@ object ByteString { override def slice(from: Int, until: Int): ByteString = { val newStartIndex = math.max(from, 0) + startIndex - val newEndIndex = math.min(until, length) + startIndex - if (newEndIndex <= newStartIndex) ByteString.empty - else new ByteString1(bytes, newStartIndex, newEndIndex) + val newLength = math.min(until, length) - from + if (newLength <= 0) ByteString.empty + else new ByteString1(bytes, newStartIndex, newLength) } override def copyToArray[A >: Byte](xs: Array[A], start: Int, len: Int): Unit = Array.copy(bytes, startIndex, xs, start, math.min(math.min(length, len), xs.length - start)) + def copyToBuffer(buffer: ByteBuffer): Int = { + val copyLength = math.min(buffer.remaining, length) + if (copyLength > 0) buffer.put(bytes, startIndex, copyLength) + copyLength + } + } - private object ByteStrings { - def apply(bytestrings: Vector[ByteString1]): ByteString = new ByteStrings(bytestrings) + private[akka] object ByteStrings { + def apply(bytestrings: Vector[ByteString1]): ByteString = new ByteStrings(bytestrings, (0 /: bytestrings)(_ + _.length)) + + def apply(bytestrings: Vector[ByteString1], length: Int): ByteString = new ByteStrings(bytestrings, length) def apply(b1: ByteString1, b2: ByteString1): ByteString = compare(b1, b2) match { - case 3 ⇒ new ByteStrings(Vector(b1, b2)) + case 3 ⇒ new ByteStrings(Vector(b1, b2), b1.length + b2.length) case 2 ⇒ b2 case 1 ⇒ b1 case 0 ⇒ ByteString.empty } def apply(b: ByteString1, bs: ByteStrings): ByteString = compare(b, bs) match { - case 3 ⇒ new ByteStrings(b +: bs.bytestrings) + case 3 ⇒ new ByteStrings(b +: bs.bytestrings, bs.length + b.length) case 2 ⇒ bs case 1 ⇒ b case 0 ⇒ ByteString.empty } def apply(bs: ByteStrings, b: ByteString1): ByteString = compare(bs, b) match { - case 3 ⇒ new ByteStrings(bs.bytestrings :+ b) + case 3 ⇒ new ByteStrings(bs.bytestrings :+ b, bs.length + b.length) case 2 ⇒ b case 1 ⇒ bs case 0 ⇒ ByteString.empty } def apply(bs1: ByteStrings, bs2: ByteStrings): ByteString = compare(bs1, bs2) match { - case 3 ⇒ new ByteStrings(bs1.bytestrings ++ bs2.bytestrings) + case 3 ⇒ new ByteStrings(bs1.bytestrings ++ bs2.bytestrings, bs1.length + bs2.length) case 2 ⇒ bs2 case 1 ⇒ bs1 case 0 ⇒ ByteString.empty @@ -133,9 +149,10 @@ object ByteString { if (b1.length == 0) if (b2.length == 0) 0 else 2 else if (b2.length == 0) 1 else 3 + } - final class ByteStrings private (private val bytestrings: Vector[ByteString1]) extends ByteString { + final class ByteStrings private (val bytestrings: Vector[ByteString1], val length: Int) extends ByteString { def apply(idx: Int): Byte = if (0 <= idx && idx < length) { @@ -148,37 +165,39 @@ object ByteString { bytestrings(pos)(idx - seen) } else throw new IndexOutOfBoundsException(idx.toString) - val length: Int = (0 /: bytestrings)(_ + _.length) - override def slice(from: Int, until: Int): ByteString = { val start = math.max(from, 0) val end = math.min(until, length) if (end <= start) ByteString.empty else { + val iter = bytestrings.iterator + var cur = iter.next var pos = 0 var seen = 0 - while (from >= seen + bytestrings(pos).length) { - seen += bytestrings(pos).length + while (from >= seen + cur.length) { + seen += cur.length pos += 1 + cur = iter.next } val startpos = pos val startidx = start - seen - while (until > seen + bytestrings(pos).length) { - seen += bytestrings(pos).length + while (until > seen + cur.length) { + seen += cur.length pos += 1 + cur = iter.next } val endpos = pos val endidx = end - seen if (startpos == endpos) - bytestrings(startpos).slice(startidx, endidx) + cur.slice(startidx, endidx) else { val first = bytestrings(startpos).drop(startidx).asInstanceOf[ByteString1] - val last = bytestrings(endpos).take(endidx).asInstanceOf[ByteString1] + val last = cur.take(endidx).asInstanceOf[ByteString1] if ((endpos - startpos) == 1) - new ByteStrings(Vector(first, last)) + new ByteStrings(Vector(first, last), until - from) else - new ByteStrings(first +: bytestrings.slice(startpos + 1, endpos) :+ last) + new ByteStrings(first +: bytestrings.slice(startpos + 1, endpos) :+ last, until - from) } } } @@ -200,7 +219,16 @@ object ByteString { def asByteBuffer: ByteBuffer = compact.asByteBuffer - def utf8String: String = compact.utf8String + def decodeString(charset: String): String = compact.decodeString(charset) + + def copyToBuffer(buffer: ByteBuffer): Int = { + val copyLength = math.min(buffer.remaining, length) + val iter = bytestrings.iterator + while (iter.hasNext && buffer.hasRemaining) { + iter.next.copyToBuffer(buffer) + } + copyLength + } } } @@ -208,9 +236,92 @@ object ByteString { sealed trait ByteString extends IndexedSeq[Byte] with IndexedSeqOptimized[Byte, ByteString] { override protected[this] def newBuilder = ByteString.newBuilder def ++(that: ByteString): ByteString + def copyToBuffer(buffer: ByteBuffer): Int def compact: ByteString def asByteBuffer: ByteBuffer - def toByteBuffer: ByteBuffer = ByteBuffer.wrap(toArray) - def utf8String: String - def mapI(f: Byte ⇒ Int): ByteString = map(f andThen (_.toByte)) + final def toByteBuffer: ByteBuffer = ByteBuffer.wrap(toArray) + final def utf8String: String = decodeString("UTF-8") + def decodeString(charset: String): String + final def mapI(f: Byte ⇒ Int): ByteString = map(f andThen (_.toByte)) +} + +final class ByteStringBuilder extends Builder[Byte, ByteString] { + import ByteString.{ ByteString1, ByteStrings } + private var _length = 0 + private val _builder = new VectorBuilder[ByteString1]() + private var _temp: Array[Byte] = _ + private var _tempLength = 0 + private var _tempCapacity = 0 + + private def clearTemp() { + if (_tempLength > 0) { + val arr = new Array[Byte](_tempLength) + Array.copy(_temp, 0, arr, 0, _tempLength) + _builder += ByteString1(arr) + _tempLength = 0 + } + } + + private def resizeTemp(size: Int) { + val newtemp = new Array[Byte](size) + if (_tempLength > 0) Array.copy(_temp, 0, newtemp, 0, _tempLength) + _temp = newtemp + } + + private def ensureTempSize(size: Int) { + if (_tempCapacity < size || _tempCapacity == 0) { + var newSize = if (_tempCapacity == 0) 16 else _tempCapacity * 2 + while (newSize < size) newSize *= 2 + resizeTemp(newSize) + } + } + + def +=(elem: Byte): this.type = { + ensureTempSize(_tempLength + 1) + _temp(_tempLength) = elem + _tempLength += 1 + _length += 1 + this + } + + override def ++=(xs: TraversableOnce[Byte]): this.type = { + xs match { + case b: ByteString1 ⇒ + clearTemp() + _builder += b + _length += b.length + case bs: ByteStrings ⇒ + clearTemp() + _builder ++= bs.bytestrings + _length += bs.length + case xs: WrappedArray.ofByte ⇒ + clearTemp() + _builder += ByteString1(xs.array.clone) + _length += xs.length + case _: collection.IndexedSeq[_] ⇒ + ensureTempSize(_tempLength + xs.size) + xs.copyToArray(_temp, _tempLength) + case _ ⇒ + super.++=(xs) + } + this + } + + def clear() { + _builder.clear + _length = 0 + _tempLength = 0 + } + + def result: ByteString = + if (_length == 0) ByteString.empty + else { + clearTemp() + val bytestrings = _builder.result + if (bytestrings.size == 1) + bytestrings.head + else + ByteStrings(bytestrings, _length) + } + }