Update IO to new iteratee based api
This commit is contained in:
parent
e8f16952a6
commit
a1f9b81ce9
3 changed files with 870 additions and 479 deletions
|
|
@ -4,8 +4,6 @@
|
|||
|
||||
package akka.actor
|
||||
|
||||
import org.scalatest.BeforeAndAfterEach
|
||||
|
||||
import akka.util.ByteString
|
||||
import akka.util.cps._
|
||||
import scala.util.continuations._
|
||||
|
|
@ -13,214 +11,237 @@ import akka.testkit._
|
|||
import akka.dispatch.{ Await, Future }
|
||||
|
||||
object IOActorSpec {
|
||||
import IO._
|
||||
|
||||
class SimpleEchoServer(host: String, port: Int, ioManager: ActorRef, started: TestLatch) extends Actor {
|
||||
class SimpleEchoServer(host: String, port: Int, started: TestLatch) extends Actor {
|
||||
|
||||
import context.dispatcher
|
||||
implicit val timeout = context.system.settings.ActorTimeout
|
||||
IO listen (host, port)
|
||||
|
||||
override def preStart = {
|
||||
listen(ioManager, host, port)
|
||||
started.open()
|
||||
}
|
||||
started.open
|
||||
|
||||
def createWorker = context.actorOf(Props(new Actor with IO {
|
||||
def receiveIO = {
|
||||
case NewClient(server) ⇒
|
||||
val socket = server.accept()
|
||||
loopC {
|
||||
val bytes = socket.read()
|
||||
socket write bytes
|
||||
}
|
||||
}
|
||||
}))
|
||||
val state = IO.IterateeRef.Map.sync[IO.Handle]()
|
||||
|
||||
def receive = {
|
||||
case msg: NewClient ⇒
|
||||
createWorker forward msg
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
class SimpleEchoClient(host: String, port: Int, ioManager: ActorRef) extends Actor with IO {
|
||||
|
||||
lazy val socket: SocketHandle = connect(ioManager, host, port)(reader)
|
||||
lazy val reader: ActorRef = context.actorOf(Props({
|
||||
new Actor with IO {
|
||||
def receiveIO = {
|
||||
case length: Int ⇒
|
||||
val bytes = socket.read(length)
|
||||
sender ! bytes
|
||||
}
|
||||
}
|
||||
}))
|
||||
|
||||
def receiveIO = {
|
||||
case bytes: ByteString ⇒
|
||||
case IO.NewClient(server) ⇒
|
||||
val socket = server.accept()
|
||||
state(socket) flatMap { _ ⇒
|
||||
IO repeat {
|
||||
IO.takeAny map { bytes ⇒
|
||||
socket write bytes
|
||||
reader forward bytes.length
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case IO.Read(socket, bytes) ⇒
|
||||
state(socket)(IO Chunk bytes)
|
||||
|
||||
case IO.Closed(socket, cause) ⇒
|
||||
state -= socket
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
class SimpleEchoClient(host: String, port: Int) extends Actor {
|
||||
|
||||
val socket = IO connect (host, port)
|
||||
|
||||
val state = IO.IterateeRef.sync()
|
||||
|
||||
def receive = {
|
||||
|
||||
case bytes: ByteString ⇒
|
||||
val source = sender
|
||||
socket write bytes
|
||||
for {
|
||||
_ ← state
|
||||
bytes ← IO take bytes.length
|
||||
} yield source ! bytes
|
||||
|
||||
case IO.Read(socket, bytes) ⇒
|
||||
state(IO Chunk bytes)
|
||||
|
||||
case IO.Connected(socket) ⇒
|
||||
|
||||
case IO.Closed(socket, cause) ⇒
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
sealed trait KVCommand {
|
||||
def bytes: ByteString
|
||||
}
|
||||
|
||||
case class KVSet(key: String, value: String) extends KVCommand {
|
||||
val bytes = ByteString("SET " + key + " " + value.length + "\r\n" + value + "\r\n")
|
||||
}
|
||||
|
||||
case class KVGet(key: String) extends KVCommand {
|
||||
val bytes = ByteString("GET " + key + "\r\n")
|
||||
}
|
||||
|
||||
case object KVGetAll extends KVCommand {
|
||||
val bytes = ByteString("GETALL\r\n")
|
||||
}
|
||||
|
||||
// Basic Redis-style protocol
|
||||
class KVStore(host: String, port: Int, ioManager: ActorRef, started: TestLatch) extends Actor {
|
||||
class KVStore(host: String, port: Int, started: TestLatch) extends Actor {
|
||||
|
||||
import context.dispatcher
|
||||
implicit val timeout = context.system.settings.ActorTimeout
|
||||
val state = IO.IterateeRef.Map.sync[IO.Handle]()
|
||||
|
||||
var kvs: Map[String, ByteString] = Map.empty
|
||||
var kvs: Map[String, String] = Map.empty
|
||||
|
||||
override def preStart = {
|
||||
listen(ioManager, host, port)
|
||||
started.open()
|
||||
}
|
||||
IO listen (host, port)
|
||||
|
||||
def createWorker = context.actorOf(Props(new Actor with IO {
|
||||
def receiveIO = {
|
||||
case NewClient(server) ⇒
|
||||
val socket = server.accept()
|
||||
loopC {
|
||||
val cmd = socket.read(ByteString("\r\n")).utf8String
|
||||
val result = matchC(cmd.split(' ')) {
|
||||
case Array("SET", key, length) ⇒
|
||||
val value = socket read length.toInt
|
||||
server.owner ? (('set, key, value)) map ((x: Any) ⇒ ByteString("+OK\r\n"))
|
||||
case Array("GET", key) ⇒
|
||||
server.owner ? (('get, key)) map {
|
||||
case Some(b: ByteString) ⇒ ByteString("$" + b.length + "\r\n") ++ b
|
||||
case None ⇒ ByteString("$-1\r\n")
|
||||
}
|
||||
case Array("GETALL") ⇒
|
||||
server.owner ? 'getall map {
|
||||
case m: Map[_, _] ⇒
|
||||
(ByteString("*" + (m.size * 2) + "\r\n") /: m) {
|
||||
case (result, (k: String, v: ByteString)) ⇒
|
||||
val kBytes = ByteString(k)
|
||||
result ++ ByteString("$" + kBytes.length + "\r\n") ++ kBytes ++ ByteString("$" + v.length + "\r\n") ++ v
|
||||
}
|
||||
}
|
||||
}
|
||||
result recover {
|
||||
case e ⇒ ByteString("-" + e.getClass.toString + "\r\n")
|
||||
} foreach { bytes ⇒
|
||||
socket write bytes
|
||||
}
|
||||
}
|
||||
}
|
||||
}))
|
||||
started.open
|
||||
|
||||
val EOL = ByteString("\r\n")
|
||||
|
||||
def receive = {
|
||||
case msg: NewClient ⇒ createWorker forward msg
|
||||
case ('set, key: String, value: ByteString) ⇒
|
||||
kvs += (key -> value)
|
||||
sender.tell((), self)
|
||||
case ('get, key: String) ⇒ sender.tell(kvs.get(key), self)
|
||||
case 'getall ⇒ sender.tell(kvs, self)
|
||||
|
||||
case IO.NewClient(server) ⇒
|
||||
val socket = server.accept()
|
||||
state(socket) flatMap { _ ⇒
|
||||
IO repeat {
|
||||
IO takeUntil EOL map (_.utf8String split ' ') flatMap {
|
||||
|
||||
case Array("SET", key, length) ⇒
|
||||
for {
|
||||
value ← IO take length.toInt
|
||||
_ ← IO takeUntil EOL
|
||||
} yield {
|
||||
kvs += (key -> value.utf8String)
|
||||
ByteString("+OK\r\n")
|
||||
}
|
||||
|
||||
case Array("GET", key) ⇒
|
||||
IO Iteratee {
|
||||
kvs get key map { value ⇒
|
||||
ByteString("$" + value.length + "\r\n" + value + "\r\n")
|
||||
} getOrElse ByteString("$-1\r\n")
|
||||
}
|
||||
|
||||
case Array("GETALL") ⇒
|
||||
IO Iteratee {
|
||||
(ByteString("*" + (kvs.size * 2) + "\r\n") /: kvs) {
|
||||
case (result, (k, v)) ⇒
|
||||
val kBytes = ByteString(k)
|
||||
val vBytes = ByteString(v)
|
||||
result ++
|
||||
ByteString("$" + kBytes.length) ++ EOL ++
|
||||
kBytes ++ EOL ++
|
||||
ByteString("$" + vBytes.length) ++ EOL ++
|
||||
vBytes ++ EOL
|
||||
}
|
||||
}
|
||||
|
||||
} map (socket write)
|
||||
}
|
||||
}
|
||||
|
||||
case IO.Read(socket, bytes) ⇒
|
||||
state(socket)(IO Chunk bytes)
|
||||
|
||||
case IO.Closed(socket, cause) ⇒
|
||||
state -= socket
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
class KVClient(host: String, port: Int, ioManager: ActorRef) extends Actor with IO {
|
||||
class KVClient(host: String, port: Int) extends Actor {
|
||||
|
||||
import context.dispatcher
|
||||
implicit val timeout = context.system.settings.ActorTimeout
|
||||
val socket = IO connect (host, port)
|
||||
|
||||
var socket: SocketHandle = _
|
||||
val state = IO.IterateeRef.sync()
|
||||
|
||||
val EOL = ByteString("\r\n")
|
||||
|
||||
def receive = {
|
||||
case cmd: KVCommand ⇒
|
||||
val source = sender
|
||||
socket write cmd.bytes
|
||||
for {
|
||||
_ ← state
|
||||
result ← readResult
|
||||
} yield result.fold(err ⇒ source ! Status.Failure(new RuntimeException(err)), source !)
|
||||
|
||||
case IO.Read(socket, bytes) ⇒
|
||||
state(IO Chunk bytes)
|
||||
|
||||
case IO.Connected(socket) ⇒
|
||||
|
||||
case IO.Closed(socket, cause) ⇒
|
||||
|
||||
override def preStart {
|
||||
socket = connect(ioManager, host, port)
|
||||
}
|
||||
|
||||
def reply(msg: Any) = sender.tell(msg, self)
|
||||
|
||||
def receiveIO = {
|
||||
case ('set, key: String, value: ByteString) ⇒
|
||||
socket write (ByteString("SET " + key + " " + value.length + "\r\n") ++ value)
|
||||
reply(readResult)
|
||||
|
||||
case ('get, key: String) ⇒
|
||||
socket write ByteString("GET " + key + "\r\n")
|
||||
reply(readResult)
|
||||
|
||||
case 'getall ⇒
|
||||
socket write ByteString("GETALL\r\n")
|
||||
reply(readResult)
|
||||
}
|
||||
|
||||
def readResult = {
|
||||
val resultType = socket.read(1).utf8String
|
||||
resultType match {
|
||||
case "+" ⇒ socket.read(ByteString("\r\n")).utf8String
|
||||
case "-" ⇒ sys error socket.read(ByteString("\r\n")).utf8String
|
||||
def readResult: IO.Iteratee[Either[String, Any]] = {
|
||||
IO take 1 map (_.utf8String) flatMap {
|
||||
case "+" ⇒ IO takeUntil EOL map (msg ⇒ Right(msg.utf8String))
|
||||
case "-" ⇒ IO takeUntil EOL map (err ⇒ Left(err.utf8String))
|
||||
case "$" ⇒
|
||||
val length = socket.read(ByteString("\r\n")).utf8String
|
||||
socket.read(length.toInt)
|
||||
IO takeUntil EOL map (_.utf8String.toInt) flatMap {
|
||||
case -1 ⇒ IO Iteratee Right(None)
|
||||
case length ⇒
|
||||
for {
|
||||
value ← IO take length
|
||||
_ ← IO takeUntil EOL
|
||||
} yield Right(Some(value.utf8String))
|
||||
}
|
||||
case "*" ⇒
|
||||
val count = socket.read(ByteString("\r\n")).utf8String
|
||||
var result: Map[String, ByteString] = Map.empty
|
||||
repeatC(count.toInt / 2) {
|
||||
val k = readBytes
|
||||
val v = readBytes
|
||||
result += (k.utf8String -> v)
|
||||
}
|
||||
result
|
||||
case _ ⇒ sys error "Unexpected response"
|
||||
IO takeUntil EOL map (_.utf8String.toInt) flatMap {
|
||||
case -1 ⇒ IO Iteratee Right(None)
|
||||
case length ⇒
|
||||
IO.takeList(length)(readResult) map { list ⇒
|
||||
((Right(Map()): Either[String, Map[String, String]]) /: list.grouped(2)) {
|
||||
case (Right(m), List(Right(Some(k: String)), Right(Some(v: String)))) ⇒ Right(m + (k -> v))
|
||||
case (Right(_), _) ⇒ Left("Unexpected Response")
|
||||
case (left, _) ⇒ left
|
||||
}
|
||||
}
|
||||
|
||||
def readBytes = {
|
||||
val resultType = socket.read(1).utf8String
|
||||
if (resultType != "$") sys error "Unexpected response"
|
||||
val length = socket.read(ByteString("\r\n")).utf8String
|
||||
socket.read(length.toInt)
|
||||
}
|
||||
case _ ⇒ IO Iteratee Left("Unexpected Response")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@org.junit.runner.RunWith(classOf[org.scalatest.junit.JUnitRunner])
|
||||
class IOActorSpec extends AkkaSpec with BeforeAndAfterEach with DefaultTimeout {
|
||||
class IOActorSpec extends AkkaSpec with DefaultTimeout {
|
||||
import IOActorSpec._
|
||||
|
||||
"an IO Actor" must {
|
||||
"run echo server" in {
|
||||
val started = TestLatch(1)
|
||||
val ioManager = system.actorOf(Props(new IOManager(2))) // teeny tiny buffer
|
||||
val server = system.actorOf(Props(new SimpleEchoServer("localhost", 8064, ioManager, started)))
|
||||
Await.ready(started, timeout.duration)
|
||||
val client = system.actorOf(Props(new SimpleEchoClient("localhost", 8064, ioManager)))
|
||||
val server = system.actorOf(Props(new SimpleEchoServer("localhost", 8064, started)))
|
||||
Await.ready(started, TestLatch.DefaultTimeout)
|
||||
val client = system.actorOf(Props(new SimpleEchoClient("localhost", 8064)))
|
||||
val f1 = client ? ByteString("Hello World!1")
|
||||
val f2 = client ? ByteString("Hello World!2")
|
||||
val f3 = client ? ByteString("Hello World!3")
|
||||
Await.result(f1, timeout.duration) must equal(ByteString("Hello World!1"))
|
||||
Await.result(f2, timeout.duration) must equal(ByteString("Hello World!2"))
|
||||
Await.result(f3, timeout.duration) must equal(ByteString("Hello World!3"))
|
||||
system.stop(client)
|
||||
system.stop(server)
|
||||
system.stop(ioManager)
|
||||
Await.result(f1, TestLatch.DefaultTimeout) must equal(ByteString("Hello World!1"))
|
||||
Await.result(f2, TestLatch.DefaultTimeout) must equal(ByteString("Hello World!2"))
|
||||
Await.result(f3, TestLatch.DefaultTimeout) must equal(ByteString("Hello World!3"))
|
||||
}
|
||||
|
||||
"run echo server under high load" in {
|
||||
val started = TestLatch(1)
|
||||
val ioManager = system.actorOf(Props(new IOManager()))
|
||||
val server = system.actorOf(Props(new SimpleEchoServer("localhost", 8065, ioManager, started)))
|
||||
Await.ready(started, timeout.duration)
|
||||
val client = system.actorOf(Props(new SimpleEchoClient("localhost", 8065, ioManager)))
|
||||
val server = system.actorOf(Props(new SimpleEchoServer("localhost", 8065, started)))
|
||||
Await.ready(started, TestLatch.DefaultTimeout)
|
||||
val client = system.actorOf(Props(new SimpleEchoClient("localhost", 8065)))
|
||||
val list = List.range(0, 1000)
|
||||
val f = Future.traverse(list)(i ⇒ client ? ByteString(i.toString))
|
||||
assert(Await.result(f, timeout.duration).size === 1000)
|
||||
system.stop(client)
|
||||
system.stop(server)
|
||||
system.stop(ioManager)
|
||||
assert(Await.result(f, TestLatch.DefaultTimeout).size === 1000)
|
||||
}
|
||||
|
||||
// Not currently configurable at runtime
|
||||
/*
|
||||
"run echo server under high load with small buffer" in {
|
||||
val started = TestLatch(1)
|
||||
val ioManager = system.actorOf(Props(new IOManager(2)))
|
||||
val server = system.actorOf(Props(new SimpleEchoServer("localhost", 8066, ioManager, started)))
|
||||
Await.ready(started, timeout.duration)
|
||||
val client = system.actorOf(Props(new SimpleEchoClient("localhost", 8066, ioManager)))
|
||||
val ioManager = actorOf(new IOManager(2))
|
||||
val server = actorOf(new SimpleEchoServer("localhost", 8066, ioManager, started))
|
||||
started.await
|
||||
val client = actorOf(new SimpleEchoClient("localhost", 8066, ioManager))
|
||||
val list = List.range(0, 1000)
|
||||
val f = Future.traverse(list)(i ⇒ client ? ByteString(i.toString))
|
||||
assert(Await.result(f, timeout.duration).size === 1000)
|
||||
|
|
@ -228,32 +249,28 @@ class IOActorSpec extends AkkaSpec with BeforeAndAfterEach with DefaultTimeout {
|
|||
system.stop(server)
|
||||
system.stop(ioManager)
|
||||
}
|
||||
*/
|
||||
|
||||
"run key-value store" in {
|
||||
val started = TestLatch(1)
|
||||
val ioManager = system.actorOf(Props(new IOManager(2))) // teeny tiny buffer
|
||||
val server = system.actorOf(Props(new KVStore("localhost", 8067, ioManager, started)))
|
||||
Await.ready(started, timeout.duration)
|
||||
val client1 = system.actorOf(Props(new KVClient("localhost", 8067, ioManager)))
|
||||
val client2 = system.actorOf(Props(new KVClient("localhost", 8067, ioManager)))
|
||||
val f1 = client1 ? (('set, "hello", ByteString("World")))
|
||||
val f2 = client1 ? (('set, "test", ByteString("No one will read me")))
|
||||
val f3 = client1 ? (('get, "hello"))
|
||||
Await.ready(f2, timeout.duration)
|
||||
val f4 = client2 ? (('set, "test", ByteString("I'm a test!")))
|
||||
Await.ready(f4, timeout.duration)
|
||||
val f5 = client1 ? (('get, "test"))
|
||||
val f6 = client2 ? 'getall
|
||||
Await.result(f1, timeout.duration) must equal("OK")
|
||||
Await.result(f2, timeout.duration) must equal("OK")
|
||||
Await.result(f3, timeout.duration) must equal(ByteString("World"))
|
||||
Await.result(f4, timeout.duration) must equal("OK")
|
||||
Await.result(f5, timeout.duration) must equal(ByteString("I'm a test!"))
|
||||
Await.result(f6, timeout.duration) must equal(Map("hello" -> ByteString("World"), "test" -> ByteString("I'm a test!")))
|
||||
system.stop(client1)
|
||||
system.stop(client2)
|
||||
system.stop(server)
|
||||
system.stop(ioManager)
|
||||
val server = system.actorOf(Props(new KVStore("localhost", 8067, started)))
|
||||
Await.ready(started, TestLatch.DefaultTimeout)
|
||||
val client1 = system.actorOf(Props(new KVClient("localhost", 8067)))
|
||||
val client2 = system.actorOf(Props(new KVClient("localhost", 8067)))
|
||||
val f1 = client1 ? KVSet("hello", "World")
|
||||
val f2 = client1 ? KVSet("test", "No one will read me")
|
||||
val f3 = client1 ? KVGet("hello")
|
||||
Await.ready(f2, TestLatch.DefaultTimeout)
|
||||
val f4 = client2 ? KVSet("test", "I'm a test!")
|
||||
Await.ready(f4, TestLatch.DefaultTimeout)
|
||||
val f5 = client1 ? KVGet("test")
|
||||
val f6 = client2 ? KVGetAll
|
||||
Await.result(f1, TestLatch.DefaultTimeout) must equal("OK")
|
||||
Await.result(f2, TestLatch.DefaultTimeout) must equal("OK")
|
||||
Await.result(f3, TestLatch.DefaultTimeout) must equal(Some("World"))
|
||||
Await.result(f4, TestLatch.DefaultTimeout) must equal("OK")
|
||||
Await.result(f5, TestLatch.DefaultTimeout) must equal(Some("I'm a test!"))
|
||||
Await.result(f6, TestLatch.DefaultTimeout) must equal(Map("hello" -> "World", "test" -> "I'm a test!"))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -4,10 +4,8 @@
|
|||
package akka.actor
|
||||
|
||||
import akka.util.ByteString
|
||||
import akka.dispatch.Envelope
|
||||
import java.net.InetSocketAddress
|
||||
import java.io.IOException
|
||||
import java.util.concurrent.atomic.AtomicReference
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.channels.{
|
||||
SelectableChannel,
|
||||
|
|
@ -20,9 +18,8 @@ import java.nio.channels.{
|
|||
CancelledKeyException
|
||||
}
|
||||
import scala.collection.mutable
|
||||
import scala.collection.immutable.Queue
|
||||
import scala.annotation.tailrec
|
||||
import scala.util.continuations._
|
||||
import scala.collection.generic.CanBuildFrom
|
||||
import com.eaio.uuid.UUID
|
||||
|
||||
object IO {
|
||||
|
|
@ -44,18 +41,6 @@ object IO {
|
|||
|
||||
sealed trait ReadHandle extends Handle with Product {
|
||||
override def asReadable = this
|
||||
|
||||
def read(len: Int)(implicit actor: Actor with IO): ByteString @cps[IOSuspendable[Any]] = shift { cont: (ByteString ⇒ IOSuspendable[Any]) ⇒
|
||||
ByteStringLength(cont, this, actor.context.asInstanceOf[ActorCell].currentMessage, len)
|
||||
}
|
||||
|
||||
def read()(implicit actor: Actor with IO): ByteString @cps[IOSuspendable[Any]] = shift { cont: (ByteString ⇒ IOSuspendable[Any]) ⇒
|
||||
ByteStringAny(cont, this, actor.context.asInstanceOf[ActorCell].currentMessage)
|
||||
}
|
||||
|
||||
def read(delimiter: ByteString, inclusive: Boolean = false)(implicit actor: Actor with IO): ByteString @cps[IOSuspendable[Any]] = shift { cont: (ByteString ⇒ IOSuspendable[Any]) ⇒
|
||||
ByteStringDelimited(cont, this, actor.context.asInstanceOf[ActorCell].currentMessage, delimiter, inclusive, 0)
|
||||
}
|
||||
}
|
||||
|
||||
sealed trait WriteHandle extends Handle with Product {
|
||||
|
|
@ -89,259 +74,542 @@ object IO {
|
|||
case class Read(handle: ReadHandle, bytes: ByteString) extends IOMessage
|
||||
case class Write(handle: WriteHandle, bytes: ByteString) extends IOMessage
|
||||
|
||||
def listen(ioManager: ActorRef, address: InetSocketAddress)(implicit owner: ActorRef): ServerHandle = {
|
||||
def listen(address: InetSocketAddress)(implicit context: ActorContext, owner: ActorRef): ServerHandle = {
|
||||
val ioManager = IOManager.start()(context.system)
|
||||
val server = ServerHandle(owner, ioManager)
|
||||
ioManager ! Listen(server, address)
|
||||
server
|
||||
}
|
||||
|
||||
def listen(ioManager: ActorRef, host: String, port: Int)(implicit owner: ActorRef): ServerHandle =
|
||||
listen(ioManager, new InetSocketAddress(host, port))
|
||||
def listen(host: String, port: Int)(implicit context: ActorContext, owner: ActorRef): ServerHandle =
|
||||
listen(new InetSocketAddress(host, port))(context, owner)
|
||||
|
||||
def connect(ioManager: ActorRef, address: InetSocketAddress)(implicit owner: ActorRef): SocketHandle = {
|
||||
def listen(address: InetSocketAddress, owner: ActorRef)(implicit context: ActorContext): ServerHandle =
|
||||
listen(address)(context, owner)
|
||||
|
||||
def listen(host: String, port: Int, owner: ActorRef)(implicit context: ActorContext): ServerHandle =
|
||||
listen(new InetSocketAddress(host, port))(context, owner)
|
||||
|
||||
def connect(address: InetSocketAddress)(implicit context: ActorContext, owner: ActorRef): SocketHandle = {
|
||||
val ioManager = IOManager.start()(context.system)
|
||||
val socket = SocketHandle(owner, ioManager)
|
||||
ioManager ! Connect(socket, address)
|
||||
socket
|
||||
}
|
||||
|
||||
def connect(ioManager: ActorRef, host: String, port: Int)(implicit sender: ActorRef): SocketHandle =
|
||||
connect(ioManager, new InetSocketAddress(host, port))
|
||||
def connect(host: String, port: Int)(implicit context: ActorContext, owner: ActorRef): SocketHandle =
|
||||
connect(new InetSocketAddress(host, port))(context, owner)
|
||||
|
||||
private class HandleState(var readBytes: ByteString, var connected: Boolean) {
|
||||
def this() = this(ByteString.empty, false)
|
||||
def connect(address: InetSocketAddress, owner: ActorRef)(implicit context: ActorContext): SocketHandle =
|
||||
connect(address)(context, owner)
|
||||
|
||||
def connect(host: String, port: Int, owner: ActorRef)(implicit context: ActorContext): SocketHandle =
|
||||
connect(new InetSocketAddress(host, port))(context, owner)
|
||||
|
||||
sealed trait Input {
|
||||
def ++(that: Input): Input
|
||||
}
|
||||
|
||||
sealed trait IOSuspendable[+A]
|
||||
sealed trait CurrentMessage { def message: Envelope }
|
||||
private case class ByteStringLength(continuation: (ByteString) ⇒ IOSuspendable[Any], handle: Handle, message: Envelope, length: Int) extends IOSuspendable[ByteString] with CurrentMessage
|
||||
private case class ByteStringDelimited(continuation: (ByteString) ⇒ IOSuspendable[Any], handle: Handle, message: Envelope, delimter: ByteString, inclusive: Boolean, scanned: Int) extends IOSuspendable[ByteString] with CurrentMessage
|
||||
private case class ByteStringAny(continuation: (ByteString) ⇒ IOSuspendable[Any], handle: Handle, message: Envelope) extends IOSuspendable[ByteString] with CurrentMessage
|
||||
private case class Retry(message: Envelope) extends IOSuspendable[Nothing]
|
||||
private case object Idle extends IOSuspendable[Nothing]
|
||||
|
||||
object Chunk {
|
||||
val empty = Chunk(ByteString.empty)
|
||||
}
|
||||
|
||||
trait IO {
|
||||
this: Actor ⇒
|
||||
import IO._
|
||||
|
||||
type ReceiveIO = PartialFunction[Any, Any @cps[IOSuspendable[Any]]]
|
||||
|
||||
implicit protected def ioActor: Actor with IO = this
|
||||
|
||||
private val _messages: mutable.Queue[Envelope] = mutable.Queue.empty
|
||||
|
||||
private var _state: Map[Handle, HandleState] = Map.empty
|
||||
|
||||
private var _next: IOSuspendable[Any] = Idle
|
||||
|
||||
private def state(handle: Handle): HandleState = _state.get(handle) match {
|
||||
case Some(s) ⇒ s
|
||||
case _ ⇒
|
||||
val s = new HandleState()
|
||||
_state += (handle -> s)
|
||||
s
|
||||
}
|
||||
|
||||
final def receive: Receive = {
|
||||
case Read(handle, newBytes) ⇒
|
||||
val st = state(handle)
|
||||
st.readBytes ++= newBytes
|
||||
run()
|
||||
case Connected(socket) ⇒
|
||||
state(socket).connected = true
|
||||
run()
|
||||
case msg @ Closed(handle, _) ⇒
|
||||
_state -= handle // TODO: clean up better
|
||||
if (_receiveIO.isDefinedAt(msg)) {
|
||||
_next = reset { _receiveIO(msg); Idle }
|
||||
}
|
||||
run()
|
||||
case msg if _next ne Idle ⇒
|
||||
_messages enqueue context.asInstanceOf[ActorCell].currentMessage
|
||||
case msg if _receiveIO.isDefinedAt(msg) ⇒
|
||||
_next = reset { _receiveIO(msg); Idle }
|
||||
run()
|
||||
}
|
||||
|
||||
def receiveIO: ReceiveIO
|
||||
|
||||
def retry(): Any @cps[IOSuspendable[Any]] =
|
||||
shift { _: (Any ⇒ IOSuspendable[Any]) ⇒
|
||||
_next match {
|
||||
case n: CurrentMessage ⇒ Retry(n.message)
|
||||
case _ ⇒ Idle
|
||||
case class Chunk(bytes: ByteString) extends Input {
|
||||
def ++(that: Input) = that match {
|
||||
case Chunk(more) ⇒ Chunk(bytes ++ more)
|
||||
case _: EOF ⇒ that
|
||||
}
|
||||
}
|
||||
|
||||
private lazy val _receiveIO = receiveIO
|
||||
|
||||
// only reinvoke messages from the original message to avoid stack overflow
|
||||
private var reinvoked = false
|
||||
private def reinvoke() {
|
||||
if (!reinvoked && (_next eq Idle) && _messages.nonEmpty) {
|
||||
try {
|
||||
reinvoked = true
|
||||
while ((_next eq Idle) && _messages.nonEmpty) self.asInstanceOf[LocalActorRef].underlying invoke _messages.dequeue
|
||||
} finally {
|
||||
reinvoked = false
|
||||
case class EOF(cause: Option[Exception]) extends Input {
|
||||
def ++(that: Input) = this
|
||||
}
|
||||
|
||||
object Iteratee {
|
||||
def apply[A](value: A): Iteratee[A] = Done(value)
|
||||
def apply(): Iteratee[Unit] = unit
|
||||
val unit: Iteratee[Unit] = Done(())
|
||||
}
|
||||
|
||||
/**
|
||||
* A basic Iteratee implementation of Oleg's Iteratee (http://okmij.org/ftp/Streams.html).
|
||||
* No support for Enumerator or Input types other then ByteString at the moment.
|
||||
*/
|
||||
sealed abstract class Iteratee[+A] {
|
||||
|
||||
/**
|
||||
* Applies the given input to the Iteratee, returning the resulting Iteratee
|
||||
* and the unused Input.
|
||||
*/
|
||||
final def apply(input: Input): (Iteratee[A], Input) = this match {
|
||||
case Cont(f) ⇒ f(input)
|
||||
case iter ⇒ (iter, input)
|
||||
}
|
||||
|
||||
final def get: A = this(EOF(None))._1 match {
|
||||
case Done(value) ⇒ value
|
||||
case Cont(_) ⇒ sys.error("Divergent Iteratee")
|
||||
case Failure(e) ⇒ throw e
|
||||
}
|
||||
|
||||
final def flatMap[B](f: A ⇒ Iteratee[B]): Iteratee[B] = this match {
|
||||
case Done(value) ⇒ f(value)
|
||||
case Cont(k: Chain[_]) ⇒ Cont(k :+ f)
|
||||
case Cont(k) ⇒ Cont(Chain(k, f))
|
||||
case failure: Failure ⇒ failure
|
||||
}
|
||||
|
||||
final def map[B](f: A ⇒ B): Iteratee[B] = this match {
|
||||
case Done(value) ⇒ Done(f(value))
|
||||
case Cont(k: Chain[_]) ⇒ Cont(k :+ ((a: A) ⇒ Done(f(a))))
|
||||
case Cont(k) ⇒ Cont(Chain(k, (a: A) ⇒ Done(f(a))))
|
||||
case failure: Failure ⇒ failure
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* An Iteratee representing a result and the remaining ByteString. Also used to
|
||||
* wrap any constants or precalculated values that need to be composed with
|
||||
* other Iteratees.
|
||||
*/
|
||||
final case class Done[+A](result: A) extends Iteratee[A]
|
||||
|
||||
/**
|
||||
* An Iteratee that still requires more input to calculate it's result.
|
||||
*/
|
||||
final case class Cont[+A](f: Input ⇒ (Iteratee[A], Input)) extends Iteratee[A]
|
||||
|
||||
/**
|
||||
* An Iteratee representing a failure to calcualte a result.
|
||||
* FIXME: move into 'Cont' as in Oleg's implementation
|
||||
*/
|
||||
final case class Failure(exception: Throwable) extends Iteratee[Nothing]
|
||||
|
||||
object IterateeRef {
|
||||
def sync[A](initial: Iteratee[A]): IterateeRefSync[A] = new IterateeRefSync(initial)
|
||||
def sync(): IterateeRefSync[Unit] = new IterateeRefSync(Iteratee.unit)
|
||||
|
||||
def async[A](initial: Iteratee[A])(implicit app: ActorSystem): IterateeRefAsync[A] = new IterateeRefAsync(initial)
|
||||
def async()(implicit app: ActorSystem): IterateeRefAsync[Unit] = new IterateeRefAsync(Iteratee.unit)
|
||||
|
||||
class Map[K, V] private (refFactory: ⇒ IterateeRef[V], underlying: mutable.Map[K, IterateeRef[V]] = mutable.Map.empty[K, IterateeRef[V]]) extends mutable.Map[K, IterateeRef[V]] {
|
||||
def get(key: K) = Some(underlying.getOrElseUpdate(key, refFactory))
|
||||
def iterator = underlying.iterator
|
||||
def +=(kv: (K, IterateeRef[V])) = { underlying += kv; this }
|
||||
def -=(key: K) = { underlying -= key; this }
|
||||
override def empty = new Map[K, V](refFactory)
|
||||
}
|
||||
object Map {
|
||||
def apply[K, V](refFactory: ⇒ IterateeRef[V]): IterateeRef.Map[K, V] = new Map(refFactory)
|
||||
def sync[K](): IterateeRef.Map[K, Unit] = new Map(IterateeRef.sync())
|
||||
def async[K]()(implicit app: ActorSystem): IterateeRef.Map[K, Unit] = new Map(IterateeRef.async())
|
||||
}
|
||||
}
|
||||
|
||||
@tailrec
|
||||
private def run() {
|
||||
_next match {
|
||||
case ByteStringLength(continuation, handle, message, waitingFor) ⇒
|
||||
context.asInstanceOf[ActorCell].currentMessage = message
|
||||
val st = state(handle)
|
||||
if (st.readBytes.length >= waitingFor) {
|
||||
val bytes = st.readBytes.take(waitingFor) //.compact
|
||||
st.readBytes = st.readBytes.drop(waitingFor)
|
||||
_next = continuation(bytes)
|
||||
run()
|
||||
/**
|
||||
* A mutable reference to an Iteratee. Not thread safe.
|
||||
*
|
||||
* Designed for use within an Actor.
|
||||
*
|
||||
* Includes mutable implementations of flatMap, map, and apply which
|
||||
* update the internal reference and return Unit.
|
||||
*/
|
||||
trait IterateeRef[A] {
|
||||
def flatMap(f: A ⇒ Iteratee[A]): Unit
|
||||
def map(f: A ⇒ A): Unit
|
||||
def apply(input: Input): Unit
|
||||
}
|
||||
case bsd @ ByteStringDelimited(continuation, handle, message, delimiter, inclusive, scanned) ⇒
|
||||
context.asInstanceOf[ActorCell].currentMessage = message
|
||||
val st = state(handle)
|
||||
val idx = st.readBytes.indexOfSlice(delimiter, scanned)
|
||||
if (idx >= 0) {
|
||||
val index = if (inclusive) idx + delimiter.length else idx
|
||||
val bytes = st.readBytes.take(index) //.compact
|
||||
st.readBytes = st.readBytes.drop(idx + delimiter.length)
|
||||
_next = continuation(bytes)
|
||||
run()
|
||||
|
||||
final class IterateeRefSync[A](initial: Iteratee[A]) extends IterateeRef[A] {
|
||||
private var _value: (Iteratee[A], Input) = (initial, Chunk.empty)
|
||||
def flatMap(f: A ⇒ Iteratee[A]): Unit = _value = _value match {
|
||||
case (iter, chunk @ Chunk(bytes)) if bytes.nonEmpty ⇒ (iter flatMap f)(chunk)
|
||||
case (iter, input) ⇒ (iter flatMap f, input)
|
||||
}
|
||||
def map(f: A ⇒ A): Unit = _value = (_value._1 map f, _value._2)
|
||||
def apply(input: Input): Unit = _value = _value._1(_value._2 ++ input)
|
||||
def value: (Iteratee[A], Input) = _value
|
||||
}
|
||||
|
||||
final class IterateeRefAsync[A](initial: Iteratee[A])(implicit app: ActorSystem) extends IterateeRef[A] {
|
||||
import akka.dispatch.Future
|
||||
private var _value: Future[(Iteratee[A], Input)] = Future((initial, Chunk.empty))
|
||||
def flatMap(f: A ⇒ Iteratee[A]): Unit = _value = _value map {
|
||||
case (iter, chunk @ Chunk(bytes)) if bytes.nonEmpty ⇒ (iter flatMap f)(chunk)
|
||||
case (iter, input) ⇒ (iter flatMap f, input)
|
||||
}
|
||||
def map(f: A ⇒ A): Unit = _value = _value map (v ⇒ (v._1 map f, v._2))
|
||||
def apply(input: Input): Unit = _value = _value map (v ⇒ v._1(v._2 ++ input))
|
||||
def future: Future[(Iteratee[A], Input)] = _value
|
||||
}
|
||||
|
||||
/**
|
||||
* An Iteratee that returns the ByteString prefix up until the supplied delimiter.
|
||||
* The delimiter is dropped by default, but it can be returned with the result by
|
||||
* setting 'inclusive' to be 'true'.
|
||||
*/
|
||||
def takeUntil(delimiter: ByteString, inclusive: Boolean = false): Iteratee[ByteString] = {
|
||||
def step(taken: ByteString)(input: Input): (Iteratee[ByteString], Input) = input match {
|
||||
case Chunk(more) ⇒
|
||||
val bytes = taken ++ more
|
||||
val startIdx = bytes.indexOfSlice(delimiter, math.max(taken.length - delimiter.length, 0))
|
||||
if (startIdx >= 0) {
|
||||
val endIdx = startIdx + delimiter.length
|
||||
(Done(bytes take (if (inclusive) endIdx else startIdx)), Chunk(bytes drop endIdx))
|
||||
} else {
|
||||
_next = bsd.copy(scanned = math.min(idx - delimiter.length, 0))
|
||||
(Cont(step(bytes)), Chunk.empty)
|
||||
}
|
||||
case ByteStringAny(continuation, handle, message) ⇒
|
||||
context.asInstanceOf[ActorCell].currentMessage = message
|
||||
val st = state(handle)
|
||||
if (st.readBytes.length > 0) {
|
||||
val bytes = st.readBytes //.compact
|
||||
st.readBytes = ByteString.empty
|
||||
_next = continuation(bytes)
|
||||
run()
|
||||
case eof ⇒ (Cont(step(taken)), eof)
|
||||
}
|
||||
case Retry(message) ⇒
|
||||
message +=: _messages
|
||||
_next = Idle
|
||||
run()
|
||||
case Idle ⇒ reinvoke()
|
||||
|
||||
Cont(step(ByteString.empty))
|
||||
}
|
||||
|
||||
def takeWhile(p: (Byte) ⇒ Boolean): Iteratee[ByteString] = {
|
||||
def step(taken: ByteString)(input: Input): (Iteratee[ByteString], Input) = input match {
|
||||
case Chunk(more) ⇒
|
||||
val (found, rest) = more span p
|
||||
if (rest.isEmpty)
|
||||
(Cont(step(taken ++ found)), Chunk.empty)
|
||||
else
|
||||
(Done(taken ++ found), Chunk(rest))
|
||||
case eof ⇒ (Done(taken), eof)
|
||||
}
|
||||
|
||||
Cont(step(ByteString.empty))
|
||||
}
|
||||
|
||||
/**
|
||||
* An Iteratee that returns a ByteString of the requested length.
|
||||
*/
|
||||
def take(length: Int): Iteratee[ByteString] = {
|
||||
def step(taken: ByteString)(input: Input): (Iteratee[ByteString], Input) = input match {
|
||||
case Chunk(more) ⇒
|
||||
val bytes = taken ++ more
|
||||
if (bytes.length >= length)
|
||||
(Done(bytes.take(length)), Chunk(bytes.drop(length)))
|
||||
else
|
||||
(Cont(step(bytes)), Chunk.empty)
|
||||
case eof ⇒ (Cont(step(taken)), eof)
|
||||
}
|
||||
|
||||
Cont(step(ByteString.empty))
|
||||
}
|
||||
|
||||
/**
|
||||
* An Iteratee that ignores the specified number of bytes.
|
||||
*/
|
||||
def drop(length: Int): Iteratee[Unit] = {
|
||||
def step(left: Int)(input: Input): (Iteratee[Unit], Input) = input match {
|
||||
case Chunk(more) ⇒
|
||||
if (left > more.length)
|
||||
(Cont(step(left - more.length)), Chunk.empty)
|
||||
else
|
||||
(Done(), Chunk(more drop left))
|
||||
case eof ⇒ (Done(), eof)
|
||||
}
|
||||
|
||||
Cont(step(length))
|
||||
}
|
||||
|
||||
/**
|
||||
* An Iteratee that returns the remaining ByteString until an EOF is given.
|
||||
*/
|
||||
val takeAll: Iteratee[ByteString] = {
|
||||
def step(taken: ByteString)(input: Input): (Iteratee[ByteString], Input) = input match {
|
||||
case Chunk(more) ⇒
|
||||
val bytes = taken ++ more
|
||||
(Cont(step(bytes)), Chunk.empty)
|
||||
case eof ⇒ (Done(taken), eof)
|
||||
}
|
||||
|
||||
Cont(step(ByteString.empty))
|
||||
}
|
||||
|
||||
/**
|
||||
* An Iteratee that returns any input it receives
|
||||
*/
|
||||
val takeAny: Iteratee[ByteString] = Cont {
|
||||
case Chunk(bytes) if bytes.nonEmpty ⇒ (Done(bytes), Chunk.empty)
|
||||
case Chunk(bytes) ⇒ (takeAny, Chunk.empty)
|
||||
case eof ⇒ (Done(ByteString.empty), eof)
|
||||
}
|
||||
|
||||
def takeList[A](length: Int)(iter: Iteratee[A]): Iteratee[List[A]] = {
|
||||
def step(left: Int, list: List[A]): Iteratee[List[A]] =
|
||||
if (left == 0) Done(list.reverse)
|
||||
else iter flatMap (a ⇒ step(left - 1, a :: list))
|
||||
|
||||
step(length, Nil)
|
||||
}
|
||||
|
||||
def peek(length: Int): Iteratee[ByteString] = {
|
||||
def step(taken: ByteString)(input: Input): (Iteratee[ByteString], Input) = input match {
|
||||
case Chunk(more) ⇒
|
||||
val bytes = taken ++ more
|
||||
if (bytes.length >= length)
|
||||
(Done(bytes.take(length)), Chunk(bytes))
|
||||
else
|
||||
(Cont(step(bytes)), Chunk.empty)
|
||||
case eof ⇒ (Cont(step(taken)), eof)
|
||||
}
|
||||
|
||||
Cont(step(ByteString.empty))
|
||||
}
|
||||
|
||||
def repeat(iter: Iteratee[Unit]): Iteratee[Unit] =
|
||||
iter flatMap (_ ⇒ repeat(iter))
|
||||
|
||||
def traverse[A, B, M[A] <: Traversable[A]](in: M[A])(f: A ⇒ Iteratee[B])(implicit cbf: CanBuildFrom[M[A], B, M[B]]): Iteratee[M[B]] =
|
||||
fold(cbf(in), in)((b, a) ⇒ f(a) map (b += _)) map (_.result)
|
||||
|
||||
def fold[A, B, M[A] <: Traversable[A]](initial: B, in: M[A])(f: (B, A) ⇒ Iteratee[B]): Iteratee[B] =
|
||||
(Iteratee(initial) /: in)((ib, a) ⇒ ib flatMap (b ⇒ f(b, a)))
|
||||
|
||||
// private api
|
||||
|
||||
private object Chain {
|
||||
def apply[A](f: Input ⇒ (Iteratee[A], Input)) = new Chain[A](f, Nil, Nil)
|
||||
def apply[A, B](f: Input ⇒ (Iteratee[A], Input), k: A ⇒ Iteratee[B]) = new Chain[B](f, List(k.asInstanceOf[Any ⇒ Iteratee[Any]]), Nil)
|
||||
}
|
||||
|
||||
/**
|
||||
* A function 'ByteString => Iteratee[A]' that composes with 'A => Iteratee[B]' functions
|
||||
* in a stack-friendly manner.
|
||||
*
|
||||
* For internal use within Iteratee.
|
||||
*/
|
||||
private final case class Chain[A] private (cur: Input ⇒ (Iteratee[Any], Input), queueOut: List[Any ⇒ Iteratee[Any]], queueIn: List[Any ⇒ Iteratee[Any]]) extends (Input ⇒ (Iteratee[A], Input)) {
|
||||
|
||||
def :+[B](f: A ⇒ Iteratee[B]) = new Chain[B](cur, queueOut, f.asInstanceOf[Any ⇒ Iteratee[Any]] :: queueIn)
|
||||
|
||||
def apply(input: Input): (Iteratee[A], Input) = {
|
||||
@tailrec
|
||||
def run(result: (Iteratee[Any], Input), queueOut: List[Any ⇒ Iteratee[Any]], queueIn: List[Any ⇒ Iteratee[Any]]): (Iteratee[Any], Input) = {
|
||||
if (queueOut.isEmpty) {
|
||||
if (queueIn.isEmpty) result
|
||||
else run(result, queueIn.reverse, Nil)
|
||||
} else result match {
|
||||
case (Done(value), rest) ⇒
|
||||
queueOut.head(value) match {
|
||||
//case Cont(Chain(f, q)) ⇒ run(f(rest), q ++ tail) <- can cause big slowdown, need to test if needed
|
||||
case Cont(f) ⇒ run(f(rest), queueOut.tail, queueIn)
|
||||
case iter ⇒ run((iter, rest), queueOut.tail, queueIn)
|
||||
}
|
||||
case (Cont(f), rest) ⇒
|
||||
(Cont(new Chain(f, queueOut, queueIn)), rest)
|
||||
case _ ⇒ result
|
||||
}
|
||||
}
|
||||
run(cur(input), queueOut, queueIn).asInstanceOf[(Iteratee[A], Input)]
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
object IOManager {
|
||||
def start()(implicit system: ActorSystem): ActorRef = {
|
||||
// TODO: Replace with better "get or create" if/when available
|
||||
val ref = system.actorFor(system / "io-manager")
|
||||
if (!ref.isInstanceOf[EmptyLocalActorRef]) ref else try {
|
||||
system.actorOf(Props[IOManager], "io-manager")
|
||||
} catch {
|
||||
case _: InvalidActorNameException ⇒ ref
|
||||
}
|
||||
}
|
||||
def stop()(implicit system: ActorSystem): Unit = {
|
||||
// TODO: send shutdown message to IOManager
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
final class WriteBuffer(bufferSize: Int) {
|
||||
private val _queue = new java.util.ArrayDeque[ByteString]
|
||||
private val _buffer = ByteBuffer.allocate(bufferSize)
|
||||
private var _length = 0
|
||||
|
||||
private def fillBuffer(): Boolean = {
|
||||
while (!_queue.isEmpty && _buffer.hasRemaining) {
|
||||
val next = _queue.pollFirst
|
||||
val rest = next.drop(next.copyToBuffer(_buffer))
|
||||
if (rest.nonEmpty) _queue.offerFirst(rest)
|
||||
}
|
||||
!_buffer.hasRemaining
|
||||
}
|
||||
|
||||
def enqueue(elem: ByteString): this.type = {
|
||||
_length += elem.length
|
||||
val rest = elem.drop(elem.copyToBuffer(_buffer))
|
||||
if (rest.nonEmpty) _queue.offerLast(rest)
|
||||
this
|
||||
}
|
||||
|
||||
def length = _length
|
||||
|
||||
def isEmpty = _length == 0
|
||||
|
||||
def write(channel: WritableByteChannel with SelectableChannel): Int = {
|
||||
@tailrec
|
||||
def run(total: Int): Int = {
|
||||
if (this.isEmpty) total
|
||||
else {
|
||||
val written = try {
|
||||
_buffer.flip()
|
||||
channel write _buffer
|
||||
} finally {
|
||||
// don't leave buffer in wrong state
|
||||
_buffer.compact()
|
||||
fillBuffer()
|
||||
}
|
||||
_length -= written
|
||||
if (_buffer.position > 0) {
|
||||
total + written
|
||||
} else {
|
||||
run(total + written)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class IOManager(bufferSize: Int = 8192) extends Actor {
|
||||
run(0)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// TODO: Support a pool of workers
|
||||
final class IOManager extends Actor {
|
||||
import SelectionKey.{ OP_READ, OP_WRITE, OP_ACCEPT, OP_CONNECT }
|
||||
import IOWorker._
|
||||
|
||||
var worker: IOWorker = _
|
||||
val bufferSize = 8192 // TODO: make buffer size configurable
|
||||
|
||||
override def preStart {
|
||||
worker = new IOWorker(context.system, self, bufferSize)
|
||||
worker.start()
|
||||
type ReadChannel = ReadableByteChannel with SelectableChannel
|
||||
type WriteChannel = WritableByteChannel with SelectableChannel
|
||||
|
||||
val selector: Selector = Selector open ()
|
||||
|
||||
val channels = mutable.Map.empty[IO.Handle, SelectableChannel]
|
||||
|
||||
val accepted = mutable.Map.empty[IO.ServerHandle, mutable.Queue[SelectableChannel]]
|
||||
|
||||
val writes = mutable.Map.empty[IO.WriteHandle, WriteBuffer]
|
||||
|
||||
val closing = mutable.Set.empty[IO.Handle]
|
||||
|
||||
val buffer = ByteBuffer.allocate(bufferSize)
|
||||
|
||||
var lastSelect = 0
|
||||
|
||||
val selectAt = 100 // TODO: determine best value, perhaps based on throughput? Other triggers (like write queue size)?
|
||||
|
||||
//val selectEveryNanos = 1000000 // nanos
|
||||
|
||||
//var lastSelectNanos = System.nanoTime
|
||||
|
||||
var running = false
|
||||
|
||||
var selectSent = false
|
||||
|
||||
var fastSelect = false
|
||||
|
||||
object Select
|
||||
|
||||
def run() {
|
||||
if (!running) {
|
||||
running = true
|
||||
if (!selectSent) {
|
||||
selectSent = true
|
||||
self ! Select
|
||||
}
|
||||
}
|
||||
lastSelect += 1
|
||||
if (lastSelect >= selectAt /* || (lastSelectNanos + selectEveryNanos) < System.nanoTime */ ) select()
|
||||
}
|
||||
|
||||
def select() {
|
||||
if (selector.isOpen) {
|
||||
// TODO: Make select behaviour configurable.
|
||||
// Blocking 1ms reduces allocations during idle times, non blocking gives better performance.
|
||||
if (fastSelect) selector.selectNow else selector.select(1)
|
||||
val keys = selector.selectedKeys.iterator
|
||||
fastSelect = keys.hasNext
|
||||
while (keys.hasNext) {
|
||||
val key = keys.next()
|
||||
keys.remove()
|
||||
if (key.isValid) { process(key) }
|
||||
}
|
||||
if (channels.isEmpty) running = false
|
||||
} else {
|
||||
running = false
|
||||
}
|
||||
//lastSelectNanos = System.nanoTime
|
||||
lastSelect = 0
|
||||
}
|
||||
|
||||
def receive = {
|
||||
case Select ⇒
|
||||
select()
|
||||
if (running) self ! Select
|
||||
selectSent = running
|
||||
|
||||
case IO.Listen(server, address) ⇒
|
||||
val channel = ServerSocketChannel open ()
|
||||
channel configureBlocking false
|
||||
channel.socket bind address
|
||||
worker(Register(server, channel, OP_ACCEPT))
|
||||
channel.socket bind (address, 1000) // TODO: make backlog configurable
|
||||
channels update (server, channel)
|
||||
channel register (selector, OP_ACCEPT, server)
|
||||
run()
|
||||
|
||||
case IO.Connect(socket, address) ⇒
|
||||
val channel = SocketChannel open ()
|
||||
channel configureBlocking false
|
||||
channel connect address
|
||||
worker(Register(socket, channel, OP_CONNECT | OP_READ))
|
||||
channels update (socket, channel)
|
||||
channel register (selector, OP_CONNECT | OP_READ, socket)
|
||||
run()
|
||||
|
||||
case IO.Accept(socket, server) ⇒ worker(Accepted(socket, server))
|
||||
case IO.Write(handle, data) ⇒ worker(Write(handle, data.asByteBuffer))
|
||||
case IO.Close(handle) ⇒ worker(Close(handle))
|
||||
case IO.Accept(socket, server) ⇒
|
||||
val queue = accepted(server)
|
||||
val channel = queue.dequeue()
|
||||
channels update (socket, channel)
|
||||
channel register (selector, OP_READ, socket)
|
||||
run()
|
||||
|
||||
case IO.Write(handle, data) ⇒
|
||||
if (channels contains handle) {
|
||||
val queue = {
|
||||
val existing = writes get handle
|
||||
if (existing.isDefined) existing.get
|
||||
else {
|
||||
val q = new WriteBuffer(bufferSize)
|
||||
writes update (handle, q)
|
||||
q
|
||||
}
|
||||
}
|
||||
if (queue.isEmpty) addOps(handle, OP_WRITE)
|
||||
queue enqueue data
|
||||
if (queue.length >= bufferSize) write(handle, channels(handle).asInstanceOf[WriteChannel])
|
||||
}
|
||||
run()
|
||||
|
||||
case IO.Close(handle: IO.WriteHandle) ⇒
|
||||
if (writes get handle filterNot (_.isEmpty) isDefined) {
|
||||
closing += handle
|
||||
} else {
|
||||
cleanup(handle, None)
|
||||
}
|
||||
run()
|
||||
|
||||
case IO.Close(handle) ⇒
|
||||
cleanup(handle, None)
|
||||
run()
|
||||
}
|
||||
|
||||
override def postStop {
|
||||
worker(Shutdown)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private[akka] object IOWorker {
|
||||
sealed trait Request
|
||||
case class Register(handle: IO.Handle, channel: SelectableChannel, ops: Int) extends Request
|
||||
case class Accepted(socket: IO.SocketHandle, server: IO.ServerHandle) extends Request
|
||||
case class Write(handle: IO.WriteHandle, data: ByteBuffer) extends Request
|
||||
case class Close(handle: IO.Handle) extends Request
|
||||
case object Shutdown extends Request
|
||||
}
|
||||
|
||||
private[akka] class IOWorker(system: ActorSystem, ioManager: ActorRef, val bufferSize: Int) {
|
||||
import SelectionKey.{ OP_READ, OP_WRITE, OP_ACCEPT, OP_CONNECT }
|
||||
import IOWorker._
|
||||
|
||||
type ReadChannel = ReadableByteChannel with SelectableChannel
|
||||
type WriteChannel = WritableByteChannel with SelectableChannel
|
||||
|
||||
implicit val optionIOManager: Some[ActorRef] = Some(ioManager)
|
||||
|
||||
def apply(request: Request): Unit =
|
||||
addRequest(request)
|
||||
|
||||
def start(): Unit =
|
||||
thread.start()
|
||||
|
||||
// private
|
||||
|
||||
private val selector: Selector = Selector open ()
|
||||
|
||||
private val _requests = new AtomicReference(List.empty[Request])
|
||||
|
||||
private var accepted = Map.empty[IO.ServerHandle, Queue[SelectableChannel]].withDefaultValue(Queue.empty)
|
||||
|
||||
private var channels = Map.empty[IO.Handle, SelectableChannel]
|
||||
|
||||
private var writes = Map.empty[IO.WriteHandle, Queue[ByteBuffer]].withDefaultValue(Queue.empty)
|
||||
|
||||
private val buffer = ByteBuffer.allocate(bufferSize)
|
||||
|
||||
private val thread = new Thread("io-worker") {
|
||||
override def run() {
|
||||
while (selector.isOpen) {
|
||||
selector select ()
|
||||
val keys = selector.selectedKeys.iterator
|
||||
while (keys.hasNext) {
|
||||
val key = keys next ()
|
||||
keys remove ()
|
||||
if (key.isValid) { process(key) }
|
||||
}
|
||||
_requests.getAndSet(Nil).reverse foreach {
|
||||
case Register(handle, channel, ops) ⇒
|
||||
channels += (handle -> channel)
|
||||
channel register (selector, ops, handle)
|
||||
case Accepted(socket, server) ⇒
|
||||
val (channel, rest) = accepted(server).dequeue
|
||||
if (rest.isEmpty) accepted -= server
|
||||
else accepted += (server -> rest)
|
||||
channels += (socket -> channel)
|
||||
channel register (selector, OP_READ, socket)
|
||||
case Write(handle, data) ⇒
|
||||
if (channels contains handle) {
|
||||
val queue = writes(handle)
|
||||
if (queue.isEmpty) addOps(handle, OP_WRITE)
|
||||
writes += (handle -> queue.enqueue(data))
|
||||
}
|
||||
case Close(handle) ⇒
|
||||
cleanup(handle, None)
|
||||
case Shutdown ⇒
|
||||
channels.values foreach (_.close)
|
||||
channels.keys foreach (handle ⇒ cleanup(handle, None))
|
||||
selector.close
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def process(key: SelectionKey) {
|
||||
def process(key: SelectionKey) {
|
||||
val handle = key.attachment.asInstanceOf[IO.Handle]
|
||||
try {
|
||||
if (key.isConnectable) key.channel match {
|
||||
|
|
@ -369,7 +637,8 @@ private[akka] class IOWorker(system: ActorSystem, ioManager: ActorRef, val buffe
|
|||
}
|
||||
}
|
||||
|
||||
private def cleanup(handle: IO.Handle, cause: Option[Exception]) {
|
||||
def cleanup(handle: IO.Handle, cause: Option[Exception]) {
|
||||
closing -= handle
|
||||
handle match {
|
||||
case server: IO.ServerHandle ⇒ accepted -= server
|
||||
case writable: IO.WriteHandle ⇒ writes -= writable
|
||||
|
|
@ -378,28 +647,27 @@ private[akka] class IOWorker(system: ActorSystem, ioManager: ActorRef, val buffe
|
|||
case Some(channel) ⇒
|
||||
channel.close
|
||||
channels -= handle
|
||||
// TODO: what if handle.owner is no longer running?
|
||||
handle.owner ! IO.Closed(handle, cause)
|
||||
if (!handle.owner.isTerminated) handle.owner ! IO.Closed(handle, cause)
|
||||
case None ⇒
|
||||
}
|
||||
}
|
||||
|
||||
private def setOps(handle: IO.Handle, ops: Int): Unit =
|
||||
def setOps(handle: IO.Handle, ops: Int): Unit =
|
||||
channels(handle) keyFor selector interestOps ops
|
||||
|
||||
private def addOps(handle: IO.Handle, ops: Int) {
|
||||
def addOps(handle: IO.Handle, ops: Int) {
|
||||
val key = channels(handle) keyFor selector
|
||||
val cur = key.interestOps
|
||||
key interestOps (cur | ops)
|
||||
}
|
||||
|
||||
private def removeOps(handle: IO.Handle, ops: Int) {
|
||||
def removeOps(handle: IO.Handle, ops: Int) {
|
||||
val key = channels(handle) keyFor selector
|
||||
val cur = key.interestOps
|
||||
key interestOps (cur - (cur & ops))
|
||||
}
|
||||
|
||||
private def connect(socket: IO.SocketHandle, channel: SocketChannel) {
|
||||
def connect(socket: IO.SocketHandle, channel: SocketChannel) {
|
||||
if (channel.finishConnect) {
|
||||
removeOps(socket, OP_CONNECT)
|
||||
socket.owner ! IO.Connected(socket)
|
||||
|
|
@ -409,18 +677,27 @@ private[akka] class IOWorker(system: ActorSystem, ioManager: ActorRef, val buffe
|
|||
}
|
||||
|
||||
@tailrec
|
||||
private def accept(server: IO.ServerHandle, channel: ServerSocketChannel) {
|
||||
def accept(server: IO.ServerHandle, channel: ServerSocketChannel) {
|
||||
val socket = channel.accept
|
||||
if (socket ne null) {
|
||||
socket configureBlocking false
|
||||
accepted += (server -> (accepted(server) enqueue socket))
|
||||
val queue = {
|
||||
val existing = accepted get server
|
||||
if (existing.isDefined) existing.get
|
||||
else {
|
||||
val q = mutable.Queue[SelectableChannel]()
|
||||
accepted update (server, q)
|
||||
q
|
||||
}
|
||||
}
|
||||
queue += socket
|
||||
server.owner ! IO.NewClient(server)
|
||||
accept(server, channel)
|
||||
}
|
||||
}
|
||||
|
||||
@tailrec
|
||||
private def read(handle: IO.ReadHandle, channel: ReadChannel) {
|
||||
def read(handle: IO.ReadHandle, channel: ReadChannel) {
|
||||
buffer.clear
|
||||
val readLen = channel read buffer
|
||||
if (readLen == -1) {
|
||||
|
|
@ -432,30 +709,16 @@ private[akka] class IOWorker(system: ActorSystem, ioManager: ActorRef, val buffe
|
|||
}
|
||||
}
|
||||
|
||||
@tailrec
|
||||
private def write(handle: IO.WriteHandle, channel: WriteChannel) {
|
||||
def write(handle: IO.WriteHandle, channel: WriteChannel) {
|
||||
val queue = writes(handle)
|
||||
if (queue.nonEmpty) {
|
||||
val (buf, bufs) = queue.dequeue
|
||||
val writeLen = channel write buf
|
||||
if (buf.remaining == 0) {
|
||||
if (bufs.isEmpty) {
|
||||
writes -= handle
|
||||
removeOps(handle, OP_WRITE)
|
||||
queue write channel
|
||||
if (queue.isEmpty) {
|
||||
if (closing(handle)) {
|
||||
cleanup(handle, None)
|
||||
} else {
|
||||
writes += (handle -> bufs)
|
||||
write(handle, channel)
|
||||
}
|
||||
removeOps(handle, OP_WRITE)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@tailrec
|
||||
private def addRequest(req: Request) {
|
||||
val requests = _requests.get
|
||||
if (_requests compareAndSet (requests, req :: requests))
|
||||
selector wakeup ()
|
||||
else
|
||||
addRequest(req)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
package akka.util
|
||||
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.charset.Charset
|
||||
|
||||
import scala.collection.IndexedSeqOptimized
|
||||
import scala.collection.mutable.{ Builder, ArrayBuilder }
|
||||
import scala.collection.immutable.IndexedSeq
|
||||
import scala.collection.{ IndexedSeqOptimized, LinearSeq }
|
||||
import scala.collection.mutable.{ Builder, ArrayBuilder, WrappedArray }
|
||||
import scala.collection.immutable.{ IndexedSeq, VectorBuilder }
|
||||
import scala.collection.generic.{ CanBuildFrom, GenericCompanion }
|
||||
|
||||
object ByteString {
|
||||
|
|
@ -30,35 +31,43 @@ object ByteString {
|
|||
|
||||
def apply(string: String, charset: String): ByteString = ByteString1(string.getBytes(charset))
|
||||
|
||||
def fromArray(array: Array[Byte], offset: Int, length: Int): ByteString = {
|
||||
val copyOffset = math.max(offset, 0)
|
||||
val copyLength = math.max(math.min(array.length - copyOffset, length), 0)
|
||||
if (copyLength == 0) empty
|
||||
else {
|
||||
val copyArray = new Array[Byte](copyLength)
|
||||
Array.copy(array, copyOffset, copyArray, 0, copyLength)
|
||||
ByteString1(copyArray)
|
||||
}
|
||||
}
|
||||
|
||||
val empty: ByteString = ByteString1(Array.empty[Byte])
|
||||
|
||||
def newBuilder: Builder[Byte, ByteString] = new ArrayBuilder.ofByte mapResult apply
|
||||
def newBuilder = new ByteStringBuilder
|
||||
|
||||
implicit def canBuildFrom = new CanBuildFrom[TraversableOnce[Byte], Byte, ByteString] {
|
||||
def apply(from: TraversableOnce[Byte]) = newBuilder
|
||||
def apply() = newBuilder
|
||||
}
|
||||
|
||||
private object ByteString1 {
|
||||
private[akka] object ByteString1 {
|
||||
def apply(bytes: Array[Byte]) = new ByteString1(bytes)
|
||||
}
|
||||
|
||||
final class ByteString1 private (bytes: Array[Byte], startIndex: Int, endIndex: Int) extends ByteString {
|
||||
final class ByteString1 private (private val bytes: Array[Byte], private val startIndex: Int, val length: Int) extends ByteString {
|
||||
|
||||
private def this(bytes: Array[Byte]) = this(bytes, 0, bytes.length)
|
||||
|
||||
def apply(idx: Int): Byte = bytes(checkRangeConvert(idx))
|
||||
|
||||
private def checkRangeConvert(index: Int) = {
|
||||
val idx = index + startIndex
|
||||
if (0 <= index && idx < endIndex)
|
||||
idx
|
||||
if (0 <= index && length > index)
|
||||
index + startIndex
|
||||
else
|
||||
throw new IndexOutOfBoundsException(index.toString)
|
||||
}
|
||||
|
||||
def length: Int = endIndex - startIndex
|
||||
|
||||
def toArray: Array[Byte] = {
|
||||
val ar = new Array[Byte](length)
|
||||
Array.copy(bytes, startIndex, ar, 0, length)
|
||||
|
|
@ -68,8 +77,7 @@ object ByteString {
|
|||
override def clone: ByteString = new ByteString1(toArray)
|
||||
|
||||
def compact: ByteString =
|
||||
if (startIndex == 0 && endIndex == bytes.length) this
|
||||
else clone
|
||||
if (length == bytes.length) this else clone
|
||||
|
||||
def asByteBuffer: ByteBuffer = {
|
||||
val buffer = ByteBuffer.wrap(bytes, startIndex, length).asReadOnlyBuffer
|
||||
|
|
@ -77,8 +85,8 @@ object ByteString {
|
|||
else buffer
|
||||
}
|
||||
|
||||
def utf8String: String =
|
||||
new String(if (startIndex == 0 && endIndex == bytes.length) bytes else toArray, "UTF-8")
|
||||
def decodeString(charset: String): String =
|
||||
new String(if (length == bytes.length) bytes else toArray, charset)
|
||||
|
||||
def ++(that: ByteString): ByteString = that match {
|
||||
case b: ByteString1 ⇒ ByteStrings(this, b)
|
||||
|
|
@ -87,42 +95,50 @@ object ByteString {
|
|||
|
||||
override def slice(from: Int, until: Int): ByteString = {
|
||||
val newStartIndex = math.max(from, 0) + startIndex
|
||||
val newEndIndex = math.min(until, length) + startIndex
|
||||
if (newEndIndex <= newStartIndex) ByteString.empty
|
||||
else new ByteString1(bytes, newStartIndex, newEndIndex)
|
||||
val newLength = math.min(until, length) - from
|
||||
if (newLength <= 0) ByteString.empty
|
||||
else new ByteString1(bytes, newStartIndex, newLength)
|
||||
}
|
||||
|
||||
override def copyToArray[A >: Byte](xs: Array[A], start: Int, len: Int): Unit =
|
||||
Array.copy(bytes, startIndex, xs, start, math.min(math.min(length, len), xs.length - start))
|
||||
|
||||
def copyToBuffer(buffer: ByteBuffer): Int = {
|
||||
val copyLength = math.min(buffer.remaining, length)
|
||||
if (copyLength > 0) buffer.put(bytes, startIndex, copyLength)
|
||||
copyLength
|
||||
}
|
||||
|
||||
private object ByteStrings {
|
||||
def apply(bytestrings: Vector[ByteString1]): ByteString = new ByteStrings(bytestrings)
|
||||
}
|
||||
|
||||
private[akka] object ByteStrings {
|
||||
def apply(bytestrings: Vector[ByteString1]): ByteString = new ByteStrings(bytestrings, (0 /: bytestrings)(_ + _.length))
|
||||
|
||||
def apply(bytestrings: Vector[ByteString1], length: Int): ByteString = new ByteStrings(bytestrings, length)
|
||||
|
||||
def apply(b1: ByteString1, b2: ByteString1): ByteString = compare(b1, b2) match {
|
||||
case 3 ⇒ new ByteStrings(Vector(b1, b2))
|
||||
case 3 ⇒ new ByteStrings(Vector(b1, b2), b1.length + b2.length)
|
||||
case 2 ⇒ b2
|
||||
case 1 ⇒ b1
|
||||
case 0 ⇒ ByteString.empty
|
||||
}
|
||||
|
||||
def apply(b: ByteString1, bs: ByteStrings): ByteString = compare(b, bs) match {
|
||||
case 3 ⇒ new ByteStrings(b +: bs.bytestrings)
|
||||
case 3 ⇒ new ByteStrings(b +: bs.bytestrings, bs.length + b.length)
|
||||
case 2 ⇒ bs
|
||||
case 1 ⇒ b
|
||||
case 0 ⇒ ByteString.empty
|
||||
}
|
||||
|
||||
def apply(bs: ByteStrings, b: ByteString1): ByteString = compare(bs, b) match {
|
||||
case 3 ⇒ new ByteStrings(bs.bytestrings :+ b)
|
||||
case 3 ⇒ new ByteStrings(bs.bytestrings :+ b, bs.length + b.length)
|
||||
case 2 ⇒ b
|
||||
case 1 ⇒ bs
|
||||
case 0 ⇒ ByteString.empty
|
||||
}
|
||||
|
||||
def apply(bs1: ByteStrings, bs2: ByteStrings): ByteString = compare(bs1, bs2) match {
|
||||
case 3 ⇒ new ByteStrings(bs1.bytestrings ++ bs2.bytestrings)
|
||||
case 3 ⇒ new ByteStrings(bs1.bytestrings ++ bs2.bytestrings, bs1.length + bs2.length)
|
||||
case 2 ⇒ bs2
|
||||
case 1 ⇒ bs1
|
||||
case 0 ⇒ ByteString.empty
|
||||
|
|
@ -133,9 +149,10 @@ object ByteString {
|
|||
if (b1.length == 0)
|
||||
if (b2.length == 0) 0 else 2
|
||||
else if (b2.length == 0) 1 else 3
|
||||
|
||||
}
|
||||
|
||||
final class ByteStrings private (private val bytestrings: Vector[ByteString1]) extends ByteString {
|
||||
final class ByteStrings private (val bytestrings: Vector[ByteString1], val length: Int) extends ByteString {
|
||||
|
||||
def apply(idx: Int): Byte =
|
||||
if (0 <= idx && idx < length) {
|
||||
|
|
@ -148,37 +165,39 @@ object ByteString {
|
|||
bytestrings(pos)(idx - seen)
|
||||
} else throw new IndexOutOfBoundsException(idx.toString)
|
||||
|
||||
val length: Int = (0 /: bytestrings)(_ + _.length)
|
||||
|
||||
override def slice(from: Int, until: Int): ByteString = {
|
||||
val start = math.max(from, 0)
|
||||
val end = math.min(until, length)
|
||||
if (end <= start)
|
||||
ByteString.empty
|
||||
else {
|
||||
val iter = bytestrings.iterator
|
||||
var cur = iter.next
|
||||
var pos = 0
|
||||
var seen = 0
|
||||
while (from >= seen + bytestrings(pos).length) {
|
||||
seen += bytestrings(pos).length
|
||||
while (from >= seen + cur.length) {
|
||||
seen += cur.length
|
||||
pos += 1
|
||||
cur = iter.next
|
||||
}
|
||||
val startpos = pos
|
||||
val startidx = start - seen
|
||||
while (until > seen + bytestrings(pos).length) {
|
||||
seen += bytestrings(pos).length
|
||||
while (until > seen + cur.length) {
|
||||
seen += cur.length
|
||||
pos += 1
|
||||
cur = iter.next
|
||||
}
|
||||
val endpos = pos
|
||||
val endidx = end - seen
|
||||
if (startpos == endpos)
|
||||
bytestrings(startpos).slice(startidx, endidx)
|
||||
cur.slice(startidx, endidx)
|
||||
else {
|
||||
val first = bytestrings(startpos).drop(startidx).asInstanceOf[ByteString1]
|
||||
val last = bytestrings(endpos).take(endidx).asInstanceOf[ByteString1]
|
||||
val last = cur.take(endidx).asInstanceOf[ByteString1]
|
||||
if ((endpos - startpos) == 1)
|
||||
new ByteStrings(Vector(first, last))
|
||||
new ByteStrings(Vector(first, last), until - from)
|
||||
else
|
||||
new ByteStrings(first +: bytestrings.slice(startpos + 1, endpos) :+ last)
|
||||
new ByteStrings(first +: bytestrings.slice(startpos + 1, endpos) :+ last, until - from)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -200,7 +219,16 @@ object ByteString {
|
|||
|
||||
def asByteBuffer: ByteBuffer = compact.asByteBuffer
|
||||
|
||||
def utf8String: String = compact.utf8String
|
||||
def decodeString(charset: String): String = compact.decodeString(charset)
|
||||
|
||||
def copyToBuffer(buffer: ByteBuffer): Int = {
|
||||
val copyLength = math.min(buffer.remaining, length)
|
||||
val iter = bytestrings.iterator
|
||||
while (iter.hasNext && buffer.hasRemaining) {
|
||||
iter.next.copyToBuffer(buffer)
|
||||
}
|
||||
copyLength
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -208,9 +236,92 @@ object ByteString {
|
|||
sealed trait ByteString extends IndexedSeq[Byte] with IndexedSeqOptimized[Byte, ByteString] {
|
||||
override protected[this] def newBuilder = ByteString.newBuilder
|
||||
def ++(that: ByteString): ByteString
|
||||
def copyToBuffer(buffer: ByteBuffer): Int
|
||||
def compact: ByteString
|
||||
def asByteBuffer: ByteBuffer
|
||||
def toByteBuffer: ByteBuffer = ByteBuffer.wrap(toArray)
|
||||
def utf8String: String
|
||||
def mapI(f: Byte ⇒ Int): ByteString = map(f andThen (_.toByte))
|
||||
final def toByteBuffer: ByteBuffer = ByteBuffer.wrap(toArray)
|
||||
final def utf8String: String = decodeString("UTF-8")
|
||||
def decodeString(charset: String): String
|
||||
final def mapI(f: Byte ⇒ Int): ByteString = map(f andThen (_.toByte))
|
||||
}
|
||||
|
||||
final class ByteStringBuilder extends Builder[Byte, ByteString] {
|
||||
import ByteString.{ ByteString1, ByteStrings }
|
||||
private var _length = 0
|
||||
private val _builder = new VectorBuilder[ByteString1]()
|
||||
private var _temp: Array[Byte] = _
|
||||
private var _tempLength = 0
|
||||
private var _tempCapacity = 0
|
||||
|
||||
private def clearTemp() {
|
||||
if (_tempLength > 0) {
|
||||
val arr = new Array[Byte](_tempLength)
|
||||
Array.copy(_temp, 0, arr, 0, _tempLength)
|
||||
_builder += ByteString1(arr)
|
||||
_tempLength = 0
|
||||
}
|
||||
}
|
||||
|
||||
private def resizeTemp(size: Int) {
|
||||
val newtemp = new Array[Byte](size)
|
||||
if (_tempLength > 0) Array.copy(_temp, 0, newtemp, 0, _tempLength)
|
||||
_temp = newtemp
|
||||
}
|
||||
|
||||
private def ensureTempSize(size: Int) {
|
||||
if (_tempCapacity < size || _tempCapacity == 0) {
|
||||
var newSize = if (_tempCapacity == 0) 16 else _tempCapacity * 2
|
||||
while (newSize < size) newSize *= 2
|
||||
resizeTemp(newSize)
|
||||
}
|
||||
}
|
||||
|
||||
def +=(elem: Byte): this.type = {
|
||||
ensureTempSize(_tempLength + 1)
|
||||
_temp(_tempLength) = elem
|
||||
_tempLength += 1
|
||||
_length += 1
|
||||
this
|
||||
}
|
||||
|
||||
override def ++=(xs: TraversableOnce[Byte]): this.type = {
|
||||
xs match {
|
||||
case b: ByteString1 ⇒
|
||||
clearTemp()
|
||||
_builder += b
|
||||
_length += b.length
|
||||
case bs: ByteStrings ⇒
|
||||
clearTemp()
|
||||
_builder ++= bs.bytestrings
|
||||
_length += bs.length
|
||||
case xs: WrappedArray.ofByte ⇒
|
||||
clearTemp()
|
||||
_builder += ByteString1(xs.array.clone)
|
||||
_length += xs.length
|
||||
case _: collection.IndexedSeq[_] ⇒
|
||||
ensureTempSize(_tempLength + xs.size)
|
||||
xs.copyToArray(_temp, _tempLength)
|
||||
case _ ⇒
|
||||
super.++=(xs)
|
||||
}
|
||||
this
|
||||
}
|
||||
|
||||
def clear() {
|
||||
_builder.clear
|
||||
_length = 0
|
||||
_tempLength = 0
|
||||
}
|
||||
|
||||
def result: ByteString =
|
||||
if (_length == 0) ByteString.empty
|
||||
else {
|
||||
clearTemp()
|
||||
val bytestrings = _builder.result
|
||||
if (bytestrings.size == 1)
|
||||
bytestrings.head
|
||||
else
|
||||
ByteStrings(bytestrings, _length)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue