Merge pull request #608 from akka/wip-fix-io-√

2355 - Wip fix io √
This commit is contained in:
Viktor Klang (√) 2012-08-13 08:57:49 -07:00
commit a9b1158bde
3 changed files with 145 additions and 120 deletions

View file

@ -13,8 +13,8 @@ import scala.concurrent.util.duration._
import scala.util.continuations._
import akka.testkit._
import akka.dispatch.MessageDispatcher
import java.net.{ SocketAddress }
import akka.pattern.ask
import java.net.{ Socket, InetSocketAddress, InetAddress, SocketAddress }
object IOActorSpec {
@ -58,24 +58,22 @@ object IOActorSpec {
case bytes: ByteString
val source = sender
socket write bytes
state flatMap { _
IO take bytes.length map (source ! _) recover {
case e source ! Status.Failure(e)
}
}
state flatMap { _ IO take bytes.length map (source ! _) }
case IO.Read(`socket`, bytes)
state(IO Chunk bytes)
case IO.Closed(`socket`, cause)
state(IO EOF cause)
throw (cause getOrElse new RuntimeException("Socket closed"))
state(cause)
throw cause match {
case IO.Error(e) e
case _ new RuntimeException("Socket closed")
}
}
override def postStop {
socket.close()
state(IO EOF None)
state(IO EOF)
}
}
@ -180,29 +178,29 @@ object IOActorSpec {
val source = sender
socket write cmd.bytes
state flatMap { _
readResult map (source !) recover {
case e source ! Status.Failure(e)
}
readResult map (source !)
}
case IO.Read(`socket`, bytes)
state(IO Chunk bytes)
case IO.Closed(`socket`, cause)
state(IO EOF cause)
throw (cause getOrElse new RuntimeException("Socket closed"))
state(cause)
throw cause match {
case IO.Error(t) t
case _ new RuntimeException("Socket closed")
}
}
override def postStop {
socket.close()
state(IO EOF None)
state(IO.EOF)
}
def readResult: IO.Iteratee[Any] = {
IO take 1 map (_.utf8String) flatMap {
case "+" IO takeUntil EOL map (msg msg.utf8String)
case "-" IO takeUntil EOL flatMap (err IO throwErr new RuntimeException(err.utf8String))
case "-" IO takeUntil EOL flatMap (err IO.Failure(new RuntimeException(err.utf8String)))
case "$"
IO takeUntil EOL map (_.utf8String.toInt) flatMap {
case -1 IO Done None
@ -221,10 +219,10 @@ object IOActorSpec {
case (Right(m), List(Some(k: String), Some(v: String))) Right(m + (k -> v))
case (Right(_), _) Left("Unexpected Response")
case (left, _) left
} fold (msg IO throwErr new RuntimeException(msg), IO Done _)
} fold (msg IO.Failure(new RuntimeException(msg)), IO Done _)
}
}
case _ IO throwErr new RuntimeException("Unexpected Response")
case _ IO.Failure(new RuntimeException("Unexpected Response"))
}
}
}
@ -331,6 +329,48 @@ class IOActorSpec extends AkkaSpec with DefaultTimeout {
system.stop(server)
}
}
"takeUntil must fail on EOF before predicate when used with repeat" in {
val CRLF = ByteString("\r\n")
val dest = new InetSocketAddress(InetAddress.getLocalHost.getHostAddress, { val s = new java.net.ServerSocket(0); try s.getLocalPort finally s.close() })
val a = system.actorOf(Props(new Actor {
val state = IO.IterateeRef.Map.async[IO.Handle]()(context.dispatcher)
override def preStart {
IOManager(context.system) listen dest
}
def receive = {
case _: IO.Listening testActor ! "Wejkipejki"
case IO.NewClient(server)
val socket = server.accept()
state(socket) flatMap (_ IO.repeat(for (input IO.takeUntil(CRLF)) yield testActor ! input.utf8String))
case IO.Read(socket, bytes)
state(socket)(IO Chunk bytes)
case IO.Closed(socket, cause)
state(socket)(IO EOF)
state -= socket
testActor ! "eof"
}
}))
expectMsg("Wejkipejki")
val s = new Socket(dest.getAddress, dest.getPort)
try {
val expectedReceive = Seq("ole", "dole", "doff", "kinke", "lane", "koff", "ole", "dole", "dinke", "dane", "ole", "dole")
val expectedSend = expectedReceive ++ Seq("doff")
val out = s.getOutputStream
out.write(expectedSend.mkString(CRLF.utf8String).getBytes("UTF-8"))
out.flush()
for (word expectedReceive) expectMsg(word)
s.close()
expectMsg("eof")
} finally {
if (!s.isClosed) s.close()
}
}
}
}

View file

@ -28,6 +28,7 @@ import scala.collection.generic.CanBuildFrom
import java.util.UUID
import java.io.{ EOFException, IOException }
import akka.actor.IOManager.Settings
import akka.actor.IO.Chunk
/**
* IO messages and iteratees.
@ -303,7 +304,7 @@ object IO {
*
* No action is required by the receiving [[akka.actor.Actor]].
*/
case class Closed(handle: Handle, cause: Option[Exception]) extends IOMessage
case class Closed(handle: Handle, cause: Input) extends IOMessage
/**
* Message from an [[akka.actor.IOManager]] that contains bytes read from
@ -334,6 +335,9 @@ object IO {
}
object Chunk {
/**
* Represents the empty Chunk
*/
val empty: Chunk = new Chunk(ByteString.empty)
}
@ -342,10 +346,11 @@ object IO {
*/
case class Chunk(bytes: ByteString) extends Input {
final override def ++(that: Input): Input = that match {
case Chunk(more) if more.isEmpty this
case c: Chunk if bytes.isEmpty c
case Chunk(more) Chunk(bytes ++ more)
case _: EOF that
case c @ Chunk(more)
if (more.isEmpty) this
else if (bytes.isEmpty) c
else Chunk(bytes ++ more)
case other other
}
}
@ -354,12 +359,17 @@ object IO {
* stream.
*
* This will cause the [[akka.actor.IO.Iteratee]] that processes it
* to terminate early. If a cause is defined it can be 'caught' by
* Iteratee.recover() in order to handle it properly.
* to terminate early.
*/
case class EOF(cause: Option[Exception]) extends Input {
final override def ++(that: Input): Input = that
}
case object EOF extends Input { final override def ++(that: Input): Input = that }
/**
* Part of an [[akka.actor.IO.Input]] stream that represents an error in the stream.
*
* This will cause the [[akka.actor.IO.Iteratee]] that processes it
* to terminate early.
*/
case class Error(cause: Throwable) extends Input { final override def ++(that: Input): Input = that }
object Iteratee {
/**
@ -395,8 +405,8 @@ object IO {
* Iteratee and the remaining Input.
*/
final def apply(input: Input): (Iteratee[A], Input) = this match {
case Cont(f, None) f(input)
case iter (iter, input)
case Next(f) f(input)
case iter (iter, input)
}
/**
@ -408,10 +418,10 @@ object IO {
* If this Iteratee is not well behaved (does not return a result on EOF)
* then a "Divergent Iteratee" Exception will be thrown.
*/
final def get: A = this(EOF(None))._1 match {
case Done(value) value
case Cont(_, None) throw new DivergentIterateeException
case Cont(_, Some(err)) throw err
final def get: A = this(EOF)._1 match {
case Done(value) value
case Next(_) throw new DivergentIterateeException
case Failure(t) throw t
}
/**
@ -422,33 +432,21 @@ object IO {
* an Input stream.
*/
final def flatMap[B](f: A Iteratee[B]): Iteratee[B] = this match {
case Done(value) f(value)
case Cont(k: Chain[_], err) Cont(k :+ f, err)
case Cont(k, err) Cont(Chain(k, f), err)
case Done(value) f(value)
case Next(k: Chain[_]) Next(k :+ f)
case Next(k) Next(Chain(k, f))
case f: Failure f
}
/**
* Applies a function to transform the result of this Iteratee.
*/
final def map[B](f: A B): Iteratee[B] = this match {
case Done(value) Done(f(value))
case Cont(k: Chain[_], err) Cont(k :+ ((a: A) Done(f(a))), err)
case Cont(k, err) Cont(Chain(k, (a: A) Done(f(a))), err)
case Done(value) Done(f(value))
case Next(k: Chain[_]) Next(k :+ ((a: A) Done(f(a))))
case Next(k) Next(Chain(k, (a: A) Done(f(a))))
case f: Failure f
}
/**
* Provides a handler for any matching errors that may have occured while
* running this Iteratee.
*
* Errors are usually raised within the Iteratee with [[akka.actor.IO]].throwErr
* or by processing an [[akka.actor.IO.EOF]] that contains an Exception.
*/
def recover[B >: A](pf: PartialFunction[Exception, B]): Iteratee[B] = this match {
case done @ Done(_) done
case Cont(_, Some(err)) if pf isDefinedAt err Done(pf(err))
case Cont(k, err) Cont((more k(more) match { case (iter, rest) (iter recover pf, rest) }), err)
}
}
/**
@ -460,14 +458,14 @@ object IO {
/**
* An [[akka.actor.IO.Iteratee]] that still requires more input to calculate
* it's result. It may also contain an optional error, which can be handled
* with 'recover()'.
*
* It is possible to recover from an error and continue processing this
* Iteratee without losing the continuation, although that has not yet
* been tested. An example use case of this is resuming a failed download.
* it's result.
*/
final case class Cont[+A](f: Input (Iteratee[A], Input), error: Option[Exception] = None) extends Iteratee[A]
final case class Next[+A](f: Input (Iteratee[A], Input)) extends Iteratee[A]
/**
* An [[akka.actor.IO.Iteratee]] that represents an erronous end state.
*/
final case class Failure(cause: Throwable) extends Iteratee[Nothing]
//FIXME general description of what an IterateeRef is and how it is used, potentially with link to docs
object IterateeRef {
@ -598,12 +596,6 @@ object IO {
def future: Future[(Iteratee[A], Input)] = _value
}
/**
* An [[akka.actor.IO.Iteratee]] that contains an Exception. The Exception
* can be handled with Iteratee.recover().
*/
final def throwErr(err: Exception): Iteratee[Nothing] = Cont(input (throwErr(err), input), Some(err))
/**
* An Iteratee that returns the ByteString prefix up until the supplied delimiter.
* The delimiter is dropped by default, but it can be returned with the result by
@ -618,13 +610,13 @@ object IO {
val endIdx = startIdx + delimiter.length
(Done(bytes take (if (inclusive) endIdx else startIdx)), Chunk(bytes drop endIdx))
} else {
(Cont(step(bytes)), Chunk.empty)
(Next(step(bytes)), Chunk.empty)
}
case eof @ EOF(None) (Done(taken), eof)
case eof @ EOF(cause) (Cont(step(taken), cause), eof)
case EOF (Failure(new EOFException("Unexpected EOF")), EOF)
case e @ Error(cause) (Failure(cause), e)
}
Cont(step(ByteString.empty))
Next(step(ByteString.empty))
}
/**
@ -635,14 +627,14 @@ object IO {
case Chunk(more)
val (found, rest) = more span p
if (rest.isEmpty)
(Cont(step(taken ++ found)), Chunk.empty)
(Next(step(taken ++ found)), Chunk.empty)
else
(Done(taken ++ found), Chunk(rest))
case eof @ EOF(None) (Done(taken), eof)
case eof @ EOF(cause) (Cont(step(taken), cause), eof)
case EOF (Failure(new EOFException("Unexpected EOF")), EOF)
case e @ Error(cause) (Failure(cause), e)
}
Cont(step(ByteString.empty))
Next(step(ByteString.empty))
}
/**
@ -655,12 +647,12 @@ object IO {
if (bytes.length >= length)
(Done(bytes.take(length)), Chunk(bytes.drop(length)))
else
(Cont(step(bytes)), Chunk.empty)
case eof @ EOF(None) (Done(taken), eof)
case eof @ EOF(cause) (Cont(step(taken), cause), eof)
(Next(step(bytes)), Chunk.empty)
case EOF (Failure(new EOFException("Unexpected EOF")), EOF)
case e @ Error(cause) (Failure(cause), e)
}
Cont(step(ByteString.empty))
Next(step(ByteString.empty))
}
/**
@ -670,14 +662,14 @@ object IO {
def step(left: Int)(input: Input): (Iteratee[Unit], Input) = input match {
case Chunk(more)
if (left > more.length)
(Cont(step(left - more.length)), Chunk.empty)
(Next(step(left - more.length)), Chunk.empty)
else
(Done(()), Chunk(more drop left))
case eof @ EOF(None) (Done(()), eof)
case eof @ EOF(cause) (Cont(step(left), cause), eof)
case EOF (Failure(new EOFException("Unexpected EOF")), EOF)
case e @ Error(cause) (Failure(cause), e)
}
Cont(step(length))
Next(step(length))
}
/**
@ -687,22 +679,22 @@ object IO {
def step(taken: ByteString)(input: Input): (Iteratee[ByteString], Input) = input match {
case Chunk(more)
val bytes = taken ++ more
(Cont(step(bytes)), Chunk.empty)
case eof @ EOF(None) (Done(taken), eof)
case eof @ EOF(cause) (Cont(step(taken), cause), eof)
(Next(step(bytes)), Chunk.empty)
case EOF (Done(taken), EOF)
case e @ Error(cause) (Failure(cause), e)
}
Cont(step(ByteString.empty))
Next(step(ByteString.empty))
}
/**
* An Iteratee that returns any input it receives
*/
val takeAny: Iteratee[ByteString] = Cont {
val takeAny: Iteratee[ByteString] = Next {
case Chunk(bytes) if bytes.nonEmpty (Done(bytes), Chunk.empty)
case Chunk(bytes) (takeAny, Chunk.empty)
case eof @ EOF(None) (Done(ByteString.empty), eof)
case eof @ EOF(cause) (Cont(more (Done(ByteString.empty), more), cause), eof)
case EOF (Done(ByteString.empty), EOF)
case e @ Error(cause) (Failure(cause), e)
}
/**
@ -727,20 +719,18 @@ object IO {
if (bytes.length >= length)
(Done(bytes.take(length)), Chunk(bytes))
else
(Cont(step(bytes)), Chunk.empty)
case eof @ EOF(None) (Done(taken), eof)
case eof @ EOF(cause) (Cont(step(taken), cause), eof)
(Next(step(bytes)), Chunk.empty)
case EOF (Done(taken), EOF)
case e @ Error(cause) (Failure(cause), e)
}
Cont(step(ByteString.empty))
Next(step(ByteString.empty))
}
/**
* An Iteratee that continually repeats an Iteratee.
*
* FIXME TODO: Should terminate on EOF
*/
def repeat(iter: Iteratee[Unit]): Iteratee[Unit] = iter flatMap (_ repeat(iter))
def repeat[T](iter: Iteratee[T]): Iteratee[T] = iter flatMap (_ repeat(iter))
/**
* An Iteratee that applies an Iteratee to each element of a Traversable
@ -782,13 +772,11 @@ object IO {
} else result match {
case (Done(value), rest)
queueOut.head(value) match {
//case Cont(Chain(f, q)) run(f(rest), q ++ tail) <- can cause big slowdown, need to test if needed
case Cont(f, None) run(f(rest), queueOut.tail, queueIn)
case iter run((iter, rest), queueOut.tail, queueIn)
case Next(f) run(f(rest), queueOut.tail, queueIn)
case iter run((iter, rest), queueOut.tail, queueIn)
}
case (Cont(f, None), rest)
(Cont(new Chain(f, queueOut, queueIn)), rest)
case _ result
case (Next(f), rest) (Next(new Chain(f, queueOut, queueIn)), rest)
case _ result
}
}
run(cur(input), queueOut, queueIn).asInstanceOf[(Iteratee[A], Input)]
@ -815,7 +803,7 @@ object IO {
* An IOManager does not need to be manually stopped when not in use as it will
* automatically enter an idle state when it has no channels to manage.
*/
final class IOManager private (system: ExtendedActorSystem) extends Extension { //FIXME how about taking an ActorContext
final class IOManager private (system: ExtendedActorSystem) extends Extension { //FIXME how about taking an ActorNextext
val settings: Settings = {
val c = system.settings.config.getConfig("akka.io")
Settings(
@ -962,7 +950,7 @@ final class IOManagerActor(val settings: Settings) extends Actor with ActorLoggi
private var fastSelect = false
/** unique message that is sent to ourself to initiate the next select */
private object Select
private case object Select
/** This method should be called after receiving any message */
private def run() {
@ -1080,16 +1068,16 @@ final class IOManagerActor(val settings: Settings) extends Actor with ActorLoggi
case IO.Close(handle: IO.WriteHandle)
//If we still have pending writes, add to set of closing handles
if (writes get handle exists (_.isEmpty == false)) closing += handle
else cleanup(handle, None)
else cleanup(handle, IO.EOF)
run()
case IO.Close(handle)
cleanup(handle, None)
cleanup(handle, IO.EOF)
run()
}
override def postStop {
channels.keys foreach (handle cleanup(handle, None))
channels.keys foreach (handle cleanup(handle, IO.EOF))
selector.close
}
@ -1102,11 +1090,11 @@ final class IOManagerActor(val settings: Settings) extends Actor with ActorLoggi
if (key.isWritable) key.channel match { case channel: WriteChannel try write(handle.asWritable, channel) catch { case e: IOException } } // ignore, let it fail on read to ensure nothing left in read buffer.
} catch {
case e @ (_: ClassCastException | _: CancelledKeyException | _: IOException | _: ActorInitializationException)
cleanup(handle, Some(e.asInstanceOf[Exception])) //Scala patmat is broken
cleanup(handle, IO.Error(e)) //Scala patmat is broken
}
}
private def cleanup(handle: IO.Handle, cause: Option[Exception]) {
private def cleanup(handle: IO.Handle, cause: IO.Input) {
closing -= handle
handle match {
case server: IO.ServerHandle accepted -= server
@ -1138,7 +1126,7 @@ final class IOManagerActor(val settings: Settings) extends Actor with ActorLoggi
removeOps(socket, OP_CONNECT)
socket.owner ! IO.Connected(socket, channel.socket.getRemoteSocketAddress())
} else {
cleanup(socket, Some(new IllegalStateException("Channel for socket handle [%s] didn't finish connect" format socket)))
cleanup(socket, IO.Error(new IllegalStateException("Channel for socket handle [%s] didn't finish connect" format socket)))
}
}
@ -1167,7 +1155,7 @@ final class IOManagerActor(val settings: Settings) extends Actor with ActorLoggi
buffer.clear
val readLen = channel read buffer
if (readLen == -1) {
cleanup(handle, Some(new EOFException("Elvis has left the building")))
cleanup(handle, IO.EOF)
} else if (readLen > 0) {
buffer.flip
handle.owner ! IO.Read(handle, ByteString(buffer))
@ -1179,11 +1167,8 @@ final class IOManagerActor(val settings: Settings) extends Actor with ActorLoggi
val queue = writes(handle)
queue write channel
if (queue.isEmpty) {
if (closing(handle)) {
cleanup(handle, None)
} else {
removeOps(handle, OP_WRITE)
}
if (closing(handle)) cleanup(handle, IO.EOF)
else removeOps(handle, OP_WRITE)
}
}
}

View file

@ -30,7 +30,7 @@ class HttpServer(port: Int) extends Actor {
state(socket)(IO Chunk bytes)
case IO.Closed(socket, cause)
state(socket)(IO EOF None)
state(socket)(IO EOF)
state -= socket
}