diff --git a/akka-actor-tests/src/test/scala/akka/actor/IOActor.scala b/akka-actor-tests/src/test/scala/akka/actor/IOActor.scala index 9d2af4ec7f..4ad197c1c3 100644 --- a/akka-actor-tests/src/test/scala/akka/actor/IOActor.scala +++ b/akka-actor-tests/src/test/scala/akka/actor/IOActor.scala @@ -13,8 +13,8 @@ import scala.concurrent.util.duration._ import scala.util.continuations._ import akka.testkit._ import akka.dispatch.MessageDispatcher -import java.net.{ SocketAddress } import akka.pattern.ask +import java.net.{ Socket, InetSocketAddress, InetAddress, SocketAddress } object IOActorSpec { @@ -58,24 +58,22 @@ object IOActorSpec { case bytes: ByteString ⇒ val source = sender socket write bytes - state flatMap { _ ⇒ - IO take bytes.length map (source ! _) recover { - case e ⇒ source ! Status.Failure(e) - } - } + state flatMap { _ ⇒ IO take bytes.length map (source ! _) } case IO.Read(`socket`, bytes) ⇒ state(IO Chunk bytes) case IO.Closed(`socket`, cause) ⇒ - state(IO EOF cause) - throw (cause getOrElse new RuntimeException("Socket closed")) - + state(cause) + throw cause match { + case IO.Error(e) ⇒ e + case _ ⇒ new RuntimeException("Socket closed") + } } override def postStop { socket.close() - state(IO EOF None) + state(IO EOF) } } @@ -180,29 +178,29 @@ object IOActorSpec { val source = sender socket write cmd.bytes state flatMap { _ ⇒ - readResult map (source !) recover { - case e ⇒ source ! Status.Failure(e) - } + readResult map (source !) } case IO.Read(`socket`, bytes) ⇒ state(IO Chunk bytes) case IO.Closed(`socket`, cause) ⇒ - state(IO EOF cause) - throw (cause getOrElse new RuntimeException("Socket closed")) - + state(cause) + throw cause match { + case IO.Error(t) ⇒ t + case _ ⇒ new RuntimeException("Socket closed") + } } override def postStop { socket.close() - state(IO EOF None) + state(IO.EOF) } def readResult: IO.Iteratee[Any] = { IO take 1 map (_.utf8String) flatMap { case "+" ⇒ IO takeUntil EOL map (msg ⇒ msg.utf8String) - case "-" ⇒ IO takeUntil EOL flatMap (err ⇒ IO throwErr new RuntimeException(err.utf8String)) + case "-" ⇒ IO takeUntil EOL flatMap (err ⇒ IO.Failure(new RuntimeException(err.utf8String))) case "$" ⇒ IO takeUntil EOL map (_.utf8String.toInt) flatMap { case -1 ⇒ IO Done None @@ -221,10 +219,10 @@ object IOActorSpec { case (Right(m), List(Some(k: String), Some(v: String))) ⇒ Right(m + (k -> v)) case (Right(_), _) ⇒ Left("Unexpected Response") case (left, _) ⇒ left - } fold (msg ⇒ IO throwErr new RuntimeException(msg), IO Done _) + } fold (msg ⇒ IO.Failure(new RuntimeException(msg)), IO Done _) } } - case _ ⇒ IO throwErr new RuntimeException("Unexpected Response") + case _ ⇒ IO.Failure(new RuntimeException("Unexpected Response")) } } } @@ -331,6 +329,46 @@ class IOActorSpec extends AkkaSpec with DefaultTimeout { system.stop(server) } } + + "takeUntil must fail on EOF before predicate when used with repeat" in { + val CRLF = ByteString("\r\n") + val dest = new InetSocketAddress(InetAddress.getLocalHost.getHostAddress, { val s = new java.net.ServerSocket(0); try s.getLocalPort finally s.close() }) + + val a = system.actorOf(Props(new Actor { + val state = IO.IterateeRef.Map.async[IO.Handle]()(context.dispatcher) + + override def preStart { + IOManager(context.system) listen dest + } + + def receive = { + case _: IO.Listening ⇒ testActor ! "Wejkipejki" + case IO.NewClient(server) ⇒ + val socket = server.accept() + state(socket) flatMap (_ ⇒ IO.repeat(for (input ← IO.takeUntil(CRLF)) yield testActor ! input.utf8String)) + + case IO.Read(socket, bytes) ⇒ + state(socket)(IO Chunk bytes) + + case IO.Closed(socket, cause) ⇒ + state(socket)(IO EOF) + state -= socket + } + })) + expectMsg("Wejkipejki") + val s = new Socket(dest.getAddress, dest.getPort) + try { + val expected = Seq("ole", "dole", "doff", "kinke", "lane", "koff", "ole", "dole", "dinke", "dane", "ole", "dole", "doff") + val out = s.getOutputStream + out.write(expected.mkString("", CRLF.utf8String, CRLF.utf8String).getBytes("UTF-8")) + out.flush() + for (word ← expected) expectMsg(word) + s.close() + expectNoMsg(500.millis) + } finally { + if (!s.isClosed) s.close() + } + } } } diff --git a/akka-actor/src/main/scala/akka/actor/IO.scala b/akka-actor/src/main/scala/akka/actor/IO.scala index db9198591e..8c104e9a18 100644 --- a/akka-actor/src/main/scala/akka/actor/IO.scala +++ b/akka-actor/src/main/scala/akka/actor/IO.scala @@ -28,6 +28,7 @@ import scala.collection.generic.CanBuildFrom import java.util.UUID import java.io.{ EOFException, IOException } import akka.actor.IOManager.Settings +import akka.actor.IO.Chunk /** * IO messages and iteratees. @@ -303,7 +304,7 @@ object IO { * * No action is required by the receiving [[akka.actor.Actor]]. */ - case class Closed(handle: Handle, cause: Option[Exception]) extends IOMessage + case class Closed(handle: Handle, cause: Input) extends IOMessage /** * Message from an [[akka.actor.IOManager]] that contains bytes read from @@ -334,6 +335,9 @@ object IO { } object Chunk { + /** + * Represents the empty Chunk + */ val empty: Chunk = new Chunk(ByteString.empty) } @@ -342,10 +346,11 @@ object IO { */ case class Chunk(bytes: ByteString) extends Input { final override def ++(that: Input): Input = that match { - case Chunk(more) if more.isEmpty ⇒ this - case c: Chunk if bytes.isEmpty ⇒ c - case Chunk(more) ⇒ Chunk(bytes ++ more) - case _: EOF ⇒ that + case c @ Chunk(more) ⇒ + if (more.isEmpty) this + else if (bytes.isEmpty) c + else Chunk(bytes ++ more) + case other ⇒ other } } @@ -354,12 +359,17 @@ object IO { * stream. * * This will cause the [[akka.actor.IO.Iteratee]] that processes it - * to terminate early. If a cause is defined it can be 'caught' by - * Iteratee.recover() in order to handle it properly. + * to terminate early. */ - case class EOF(cause: Option[Exception]) extends Input { - final override def ++(that: Input): Input = that - } + case object EOF extends Input { final override def ++(that: Input): Input = that } + + /** + * Part of an [[akka.actor.IO.Input]] stream that represents an error in the stream. + * + * This will cause the [[akka.actor.IO.Iteratee]] that processes it + * to terminate early. + */ + case class Error(cause: Throwable) extends Input { final override def ++(that: Input): Input = that } object Iteratee { /** @@ -395,8 +405,8 @@ object IO { * Iteratee and the remaining Input. */ final def apply(input: Input): (Iteratee[A], Input) = this match { - case Cont(f, None) ⇒ f(input) - case iter ⇒ (iter, input) + case Next(f) ⇒ f(input) + case iter ⇒ (iter, input) } /** @@ -408,10 +418,10 @@ object IO { * If this Iteratee is not well behaved (does not return a result on EOF) * then a "Divergent Iteratee" Exception will be thrown. */ - final def get: A = this(EOF(None))._1 match { - case Done(value) ⇒ value - case Cont(_, None) ⇒ throw new DivergentIterateeException - case Cont(_, Some(err)) ⇒ throw err + final def get: A = this(EOF)._1 match { + case Done(value) ⇒ value + case Next(_) ⇒ throw new DivergentIterateeException + case Failure(t) ⇒ throw t } /** @@ -422,33 +432,21 @@ object IO { * an Input stream. */ final def flatMap[B](f: A ⇒ Iteratee[B]): Iteratee[B] = this match { - case Done(value) ⇒ f(value) - case Cont(k: Chain[_], err) ⇒ Cont(k :+ f, err) - case Cont(k, err) ⇒ Cont(Chain(k, f), err) + case Done(value) ⇒ f(value) + case Next(k: Chain[_]) ⇒ Next(k :+ f) + case Next(k) ⇒ Next(Chain(k, f)) + case f: Failure ⇒ f } /** * Applies a function to transform the result of this Iteratee. */ final def map[B](f: A ⇒ B): Iteratee[B] = this match { - case Done(value) ⇒ Done(f(value)) - case Cont(k: Chain[_], err) ⇒ Cont(k :+ ((a: A) ⇒ Done(f(a))), err) - case Cont(k, err) ⇒ Cont(Chain(k, (a: A) ⇒ Done(f(a))), err) + case Done(value) ⇒ Done(f(value)) + case Next(k: Chain[_]) ⇒ Next(k :+ ((a: A) ⇒ Done(f(a)))) + case Next(k) ⇒ Next(Chain(k, (a: A) ⇒ Done(f(a)))) + case f: Failure ⇒ f } - - /** - * Provides a handler for any matching errors that may have occured while - * running this Iteratee. - * - * Errors are usually raised within the Iteratee with [[akka.actor.IO]].throwErr - * or by processing an [[akka.actor.IO.EOF]] that contains an Exception. - */ - def recover[B >: A](pf: PartialFunction[Exception, B]): Iteratee[B] = this match { - case done @ Done(_) ⇒ done - case Cont(_, Some(err)) if pf isDefinedAt err ⇒ Done(pf(err)) - case Cont(k, err) ⇒ Cont((more ⇒ k(more) match { case (iter, rest) ⇒ (iter recover pf, rest) }), err) - } - } /** @@ -460,14 +458,14 @@ object IO { /** * An [[akka.actor.IO.Iteratee]] that still requires more input to calculate - * it's result. It may also contain an optional error, which can be handled - * with 'recover()'. - * - * It is possible to recover from an error and continue processing this - * Iteratee without losing the continuation, although that has not yet - * been tested. An example use case of this is resuming a failed download. + * it's result. */ - final case class Cont[+A](f: Input ⇒ (Iteratee[A], Input), error: Option[Exception] = None) extends Iteratee[A] + final case class Next[+A](f: Input ⇒ (Iteratee[A], Input)) extends Iteratee[A] + + /** + * An [[akka.actor.IO.Iteratee]] that represents an erronous end state. + */ + final case class Failure(cause: Throwable) extends Iteratee[Nothing] //FIXME general description of what an IterateeRef is and how it is used, potentially with link to docs object IterateeRef { @@ -598,12 +596,6 @@ object IO { def future: Future[(Iteratee[A], Input)] = _value } - /** - * An [[akka.actor.IO.Iteratee]] that contains an Exception. The Exception - * can be handled with Iteratee.recover(). - */ - final def throwErr(err: Exception): Iteratee[Nothing] = Cont(input ⇒ (throwErr(err), input), Some(err)) - /** * An Iteratee that returns the ByteString prefix up until the supplied delimiter. * The delimiter is dropped by default, but it can be returned with the result by @@ -618,13 +610,13 @@ object IO { val endIdx = startIdx + delimiter.length (Done(bytes take (if (inclusive) endIdx else startIdx)), Chunk(bytes drop endIdx)) } else { - (Cont(step(bytes)), Chunk.empty) + (Next(step(bytes)), Chunk.empty) } - case eof @ EOF(None) ⇒ (Done(taken), eof) - case eof @ EOF(cause) ⇒ (Cont(step(taken), cause), eof) + case EOF ⇒ (Failure(new EOFException("Unexpected EOF")), EOF) + case e @ Error(cause) ⇒ (Failure(cause), e) } - Cont(step(ByteString.empty)) + Next(step(ByteString.empty)) } /** @@ -635,14 +627,14 @@ object IO { case Chunk(more) ⇒ val (found, rest) = more span p if (rest.isEmpty) - (Cont(step(taken ++ found)), Chunk.empty) + (Next(step(taken ++ found)), Chunk.empty) else (Done(taken ++ found), Chunk(rest)) - case eof @ EOF(None) ⇒ (Done(taken), eof) - case eof @ EOF(cause) ⇒ (Cont(step(taken), cause), eof) + case EOF ⇒ (Failure(new EOFException("Unexpected EOF")), EOF) + case e @ Error(cause) ⇒ (Failure(cause), e) } - Cont(step(ByteString.empty)) + Next(step(ByteString.empty)) } /** @@ -655,12 +647,12 @@ object IO { if (bytes.length >= length) (Done(bytes.take(length)), Chunk(bytes.drop(length))) else - (Cont(step(bytes)), Chunk.empty) - case eof @ EOF(None) ⇒ (Done(taken), eof) - case eof @ EOF(cause) ⇒ (Cont(step(taken), cause), eof) + (Next(step(bytes)), Chunk.empty) + case EOF ⇒ (Failure(new EOFException("Unexpected EOF")), EOF) + case e @ Error(cause) ⇒ (Failure(cause), e) } - Cont(step(ByteString.empty)) + Next(step(ByteString.empty)) } /** @@ -670,14 +662,14 @@ object IO { def step(left: Int)(input: Input): (Iteratee[Unit], Input) = input match { case Chunk(more) ⇒ if (left > more.length) - (Cont(step(left - more.length)), Chunk.empty) + (Next(step(left - more.length)), Chunk.empty) else (Done(()), Chunk(more drop left)) - case eof @ EOF(None) ⇒ (Done(()), eof) - case eof @ EOF(cause) ⇒ (Cont(step(left), cause), eof) + case EOF ⇒ (Failure(new EOFException("Unexpected EOF")), EOF) + case e @ Error(cause) ⇒ (Failure(cause), e) } - Cont(step(length)) + Next(step(length)) } /** @@ -687,22 +679,22 @@ object IO { def step(taken: ByteString)(input: Input): (Iteratee[ByteString], Input) = input match { case Chunk(more) ⇒ val bytes = taken ++ more - (Cont(step(bytes)), Chunk.empty) - case eof @ EOF(None) ⇒ (Done(taken), eof) - case eof @ EOF(cause) ⇒ (Cont(step(taken), cause), eof) + (Next(step(bytes)), Chunk.empty) + case EOF ⇒ (Done(taken), EOF) + case e @ Error(cause) ⇒ (Failure(cause), e) } - Cont(step(ByteString.empty)) + Next(step(ByteString.empty)) } /** * An Iteratee that returns any input it receives */ - val takeAny: Iteratee[ByteString] = Cont { + val takeAny: Iteratee[ByteString] = Next { case Chunk(bytes) if bytes.nonEmpty ⇒ (Done(bytes), Chunk.empty) case Chunk(bytes) ⇒ (takeAny, Chunk.empty) - case eof @ EOF(None) ⇒ (Done(ByteString.empty), eof) - case eof @ EOF(cause) ⇒ (Cont(more ⇒ (Done(ByteString.empty), more), cause), eof) + case EOF ⇒ (Done(ByteString.empty), EOF) + case e @ Error(cause) ⇒ (Failure(cause), e) } /** @@ -727,20 +719,18 @@ object IO { if (bytes.length >= length) (Done(bytes.take(length)), Chunk(bytes)) else - (Cont(step(bytes)), Chunk.empty) - case eof @ EOF(None) ⇒ (Done(taken), eof) - case eof @ EOF(cause) ⇒ (Cont(step(taken), cause), eof) + (Next(step(bytes)), Chunk.empty) + case EOF ⇒ (Done(taken), EOF) + case e @ Error(cause) ⇒ (Failure(cause), e) } - Cont(step(ByteString.empty)) + Next(step(ByteString.empty)) } /** * An Iteratee that continually repeats an Iteratee. - * - * FIXME TODO: Should terminate on EOF */ - def repeat(iter: Iteratee[Unit]): Iteratee[Unit] = iter flatMap (_ ⇒ repeat(iter)) + def repeat[T](iter: Iteratee[T]): Iteratee[T] = iter flatMap (_ ⇒ repeat(iter)) /** * An Iteratee that applies an Iteratee to each element of a Traversable @@ -782,13 +772,11 @@ object IO { } else result match { case (Done(value), rest) ⇒ queueOut.head(value) match { - //case Cont(Chain(f, q)) ⇒ run(f(rest), q ++ tail) <- can cause big slowdown, need to test if needed - case Cont(f, None) ⇒ run(f(rest), queueOut.tail, queueIn) - case iter ⇒ run((iter, rest), queueOut.tail, queueIn) + case Next(f) ⇒ run(f(rest), queueOut.tail, queueIn) + case iter ⇒ run((iter, rest), queueOut.tail, queueIn) } - case (Cont(f, None), rest) ⇒ - (Cont(new Chain(f, queueOut, queueIn)), rest) - case _ ⇒ result + case (Next(f), rest) ⇒ (Next(new Chain(f, queueOut, queueIn)), rest) + case _ ⇒ result } } run(cur(input), queueOut, queueIn).asInstanceOf[(Iteratee[A], Input)] @@ -815,7 +803,7 @@ object IO { * An IOManager does not need to be manually stopped when not in use as it will * automatically enter an idle state when it has no channels to manage. */ -final class IOManager private (system: ExtendedActorSystem) extends Extension { //FIXME how about taking an ActorContext +final class IOManager private (system: ExtendedActorSystem) extends Extension { //FIXME how about taking an ActorNextext val settings: Settings = { val c = system.settings.config.getConfig("akka.io") Settings( @@ -962,7 +950,7 @@ final class IOManagerActor(val settings: Settings) extends Actor with ActorLoggi private var fastSelect = false /** unique message that is sent to ourself to initiate the next select */ - private object Select + private case object Select /** This method should be called after receiving any message */ private def run() { @@ -1080,16 +1068,16 @@ final class IOManagerActor(val settings: Settings) extends Actor with ActorLoggi case IO.Close(handle: IO.WriteHandle) ⇒ //If we still have pending writes, add to set of closing handles if (writes get handle exists (_.isEmpty == false)) closing += handle - else cleanup(handle, None) + else cleanup(handle, IO.EOF) run() case IO.Close(handle) ⇒ - cleanup(handle, None) + cleanup(handle, IO.EOF) run() } override def postStop { - channels.keys foreach (handle ⇒ cleanup(handle, None)) + channels.keys foreach (handle ⇒ cleanup(handle, IO.EOF)) selector.close } @@ -1102,11 +1090,11 @@ final class IOManagerActor(val settings: Settings) extends Actor with ActorLoggi if (key.isWritable) key.channel match { case channel: WriteChannel ⇒ try write(handle.asWritable, channel) catch { case e: IOException ⇒ } } // ignore, let it fail on read to ensure nothing left in read buffer. } catch { case e @ (_: ClassCastException | _: CancelledKeyException | _: IOException | _: ActorInitializationException) ⇒ - cleanup(handle, Some(e.asInstanceOf[Exception])) //Scala patmat is broken + cleanup(handle, IO.Error(e)) //Scala patmat is broken } } - private def cleanup(handle: IO.Handle, cause: Option[Exception]) { + private def cleanup(handle: IO.Handle, cause: IO.Input) { closing -= handle handle match { case server: IO.ServerHandle ⇒ accepted -= server @@ -1138,7 +1126,7 @@ final class IOManagerActor(val settings: Settings) extends Actor with ActorLoggi removeOps(socket, OP_CONNECT) socket.owner ! IO.Connected(socket, channel.socket.getRemoteSocketAddress()) } else { - cleanup(socket, Some(new IllegalStateException("Channel for socket handle [%s] didn't finish connect" format socket))) + cleanup(socket, IO.Error(new IllegalStateException("Channel for socket handle [%s] didn't finish connect" format socket))) } } @@ -1167,7 +1155,7 @@ final class IOManagerActor(val settings: Settings) extends Actor with ActorLoggi buffer.clear val readLen = channel read buffer if (readLen == -1) { - cleanup(handle, Some(new EOFException("Elvis has left the building"))) + cleanup(handle, IO.EOF) } else if (readLen > 0) { buffer.flip handle.owner ! IO.Read(handle, ByteString(buffer)) @@ -1179,11 +1167,8 @@ final class IOManagerActor(val settings: Settings) extends Actor with ActorLoggi val queue = writes(handle) queue write channel if (queue.isEmpty) { - if (closing(handle)) { - cleanup(handle, None) - } else { - removeOps(handle, OP_WRITE) - } + if (closing(handle)) cleanup(handle, IO.EOF) + else removeOps(handle, OP_WRITE) } } } diff --git a/akka-docs/scala/code/docs/io/HTTPServer.scala b/akka-docs/scala/code/docs/io/HTTPServer.scala index 5c63eac3c2..172a305408 100644 --- a/akka-docs/scala/code/docs/io/HTTPServer.scala +++ b/akka-docs/scala/code/docs/io/HTTPServer.scala @@ -30,7 +30,7 @@ class HttpServer(port: Int) extends Actor { state(socket)(IO Chunk bytes) case IO.Closed(socket, cause) ⇒ - state(socket)(IO EOF None) + state(socket)(IO EOF) state -= socket }