#2355 - Fixing issue with takeUntil and friends on EOF plus generalizing repeat

This commit is contained in:
Viktor Klang 2012-08-13 13:43:06 +02:00
parent e1a6e23d7f
commit 4f20f10884
3 changed files with 143 additions and 120 deletions

View file

@ -13,8 +13,8 @@ import scala.concurrent.util.duration._
import scala.util.continuations._
import akka.testkit._
import akka.dispatch.MessageDispatcher
import java.net.{ SocketAddress }
import akka.pattern.ask
import java.net.{ Socket, InetSocketAddress, InetAddress, SocketAddress }
object IOActorSpec {
@ -58,24 +58,22 @@ object IOActorSpec {
case bytes: ByteString
val source = sender
socket write bytes
state flatMap { _
IO take bytes.length map (source ! _) recover {
case e source ! Status.Failure(e)
}
}
state flatMap { _ IO take bytes.length map (source ! _) }
case IO.Read(`socket`, bytes)
state(IO Chunk bytes)
case IO.Closed(`socket`, cause)
state(IO EOF cause)
throw (cause getOrElse new RuntimeException("Socket closed"))
state(cause)
throw cause match {
case IO.Error(e) e
case _ new RuntimeException("Socket closed")
}
}
override def postStop {
socket.close()
state(IO EOF None)
state(IO EOF)
}
}
@ -180,29 +178,29 @@ object IOActorSpec {
val source = sender
socket write cmd.bytes
state flatMap { _
readResult map (source !) recover {
case e source ! Status.Failure(e)
}
readResult map (source !)
}
case IO.Read(`socket`, bytes)
state(IO Chunk bytes)
case IO.Closed(`socket`, cause)
state(IO EOF cause)
throw (cause getOrElse new RuntimeException("Socket closed"))
state(cause)
throw cause match {
case IO.Error(t) t
case _ new RuntimeException("Socket closed")
}
}
override def postStop {
socket.close()
state(IO EOF None)
state(IO.EOF)
}
def readResult: IO.Iteratee[Any] = {
IO take 1 map (_.utf8String) flatMap {
case "+" IO takeUntil EOL map (msg msg.utf8String)
case "-" IO takeUntil EOL flatMap (err IO throwErr new RuntimeException(err.utf8String))
case "-" IO takeUntil EOL flatMap (err IO.Failure(new RuntimeException(err.utf8String)))
case "$"
IO takeUntil EOL map (_.utf8String.toInt) flatMap {
case -1 IO Done None
@ -221,10 +219,10 @@ object IOActorSpec {
case (Right(m), List(Some(k: String), Some(v: String))) Right(m + (k -> v))
case (Right(_), _) Left("Unexpected Response")
case (left, _) left
} fold (msg IO throwErr new RuntimeException(msg), IO Done _)
} fold (msg IO.Failure(new RuntimeException(msg)), IO Done _)
}
}
case _ IO throwErr new RuntimeException("Unexpected Response")
case _ IO.Failure(new RuntimeException("Unexpected Response"))
}
}
}
@ -331,6 +329,46 @@ class IOActorSpec extends AkkaSpec with DefaultTimeout {
system.stop(server)
}
}
"takeUntil must fail on EOF before predicate when used with repeat" in {
val CRLF = ByteString("\r\n")
val dest = new InetSocketAddress(InetAddress.getLocalHost.getHostAddress, { val s = new java.net.ServerSocket(0); try s.getLocalPort finally s.close() })
val a = system.actorOf(Props(new Actor {
val state = IO.IterateeRef.Map.async[IO.Handle]()(context.dispatcher)
override def preStart {
IOManager(context.system) listen dest
}
def receive = {
case _: IO.Listening testActor ! "Wejkipejki"
case IO.NewClient(server)
val socket = server.accept()
state(socket) flatMap (_ IO.repeat(for (input IO.takeUntil(CRLF)) yield testActor ! input.utf8String))
case IO.Read(socket, bytes)
state(socket)(IO Chunk bytes)
case IO.Closed(socket, cause)
state(socket)(IO EOF)
state -= socket
}
}))
expectMsg("Wejkipejki")
val s = new Socket(dest.getAddress, dest.getPort)
try {
val expected = Seq("ole", "dole", "doff", "kinke", "lane", "koff", "ole", "dole", "dinke", "dane", "ole", "dole", "doff")
val out = s.getOutputStream
out.write(expected.mkString("", CRLF.utf8String, CRLF.utf8String).getBytes("UTF-8"))
out.flush()
for (word expected) expectMsg(word)
s.close()
expectNoMsg(500.millis)
} finally {
if (!s.isClosed) s.close()
}
}
}
}

View file

@ -28,6 +28,7 @@ import scala.collection.generic.CanBuildFrom
import java.util.UUID
import java.io.{ EOFException, IOException }
import akka.actor.IOManager.Settings
import akka.actor.IO.Chunk
/**
* IO messages and iteratees.
@ -303,7 +304,7 @@ object IO {
*
* No action is required by the receiving [[akka.actor.Actor]].
*/
case class Closed(handle: Handle, cause: Option[Exception]) extends IOMessage
case class Closed(handle: Handle, cause: Input) extends IOMessage
/**
* Message from an [[akka.actor.IOManager]] that contains bytes read from
@ -334,6 +335,9 @@ object IO {
}
object Chunk {
/**
* Represents the empty Chunk
*/
val empty: Chunk = new Chunk(ByteString.empty)
}
@ -342,10 +346,11 @@ object IO {
*/
case class Chunk(bytes: ByteString) extends Input {
final override def ++(that: Input): Input = that match {
case Chunk(more) if more.isEmpty this
case c: Chunk if bytes.isEmpty c
case Chunk(more) Chunk(bytes ++ more)
case _: EOF that
case c @ Chunk(more)
if (more.isEmpty) this
else if (bytes.isEmpty) c
else Chunk(bytes ++ more)
case other other
}
}
@ -354,12 +359,17 @@ object IO {
* stream.
*
* This will cause the [[akka.actor.IO.Iteratee]] that processes it
* to terminate early. If a cause is defined it can be 'caught' by
* Iteratee.recover() in order to handle it properly.
* to terminate early.
*/
case class EOF(cause: Option[Exception]) extends Input {
final override def ++(that: Input): Input = that
}
case object EOF extends Input { final override def ++(that: Input): Input = that }
/**
* Part of an [[akka.actor.IO.Input]] stream that represents an error in the stream.
*
* This will cause the [[akka.actor.IO.Iteratee]] that processes it
* to terminate early.
*/
case class Error(cause: Throwable) extends Input { final override def ++(that: Input): Input = that }
object Iteratee {
/**
@ -395,7 +405,7 @@ object IO {
* Iteratee and the remaining Input.
*/
final def apply(input: Input): (Iteratee[A], Input) = this match {
case Cont(f, None) f(input)
case Next(f) f(input)
case iter (iter, input)
}
@ -408,10 +418,10 @@ object IO {
* If this Iteratee is not well behaved (does not return a result on EOF)
* then a "Divergent Iteratee" Exception will be thrown.
*/
final def get: A = this(EOF(None))._1 match {
final def get: A = this(EOF)._1 match {
case Done(value) value
case Cont(_, None) throw new DivergentIterateeException
case Cont(_, Some(err)) throw err
case Next(_) throw new DivergentIterateeException
case Failure(t) throw t
}
/**
@ -423,8 +433,9 @@ object IO {
*/
final def flatMap[B](f: A Iteratee[B]): Iteratee[B] = this match {
case Done(value) f(value)
case Cont(k: Chain[_], err) Cont(k :+ f, err)
case Cont(k, err) Cont(Chain(k, f), err)
case Next(k: Chain[_]) Next(k :+ f)
case Next(k) Next(Chain(k, f))
case f: Failure f
}
/**
@ -432,23 +443,10 @@ object IO {
*/
final def map[B](f: A B): Iteratee[B] = this match {
case Done(value) Done(f(value))
case Cont(k: Chain[_], err) Cont(k :+ ((a: A) Done(f(a))), err)
case Cont(k, err) Cont(Chain(k, (a: A) Done(f(a))), err)
case Next(k: Chain[_]) Next(k :+ ((a: A) Done(f(a))))
case Next(k) Next(Chain(k, (a: A) Done(f(a))))
case f: Failure f
}
/**
* Provides a handler for any matching errors that may have occured while
* running this Iteratee.
*
* Errors are usually raised within the Iteratee with [[akka.actor.IO]].throwErr
* or by processing an [[akka.actor.IO.EOF]] that contains an Exception.
*/
def recover[B >: A](pf: PartialFunction[Exception, B]): Iteratee[B] = this match {
case done @ Done(_) done
case Cont(_, Some(err)) if pf isDefinedAt err Done(pf(err))
case Cont(k, err) Cont((more k(more) match { case (iter, rest) (iter recover pf, rest) }), err)
}
}
/**
@ -460,14 +458,14 @@ object IO {
/**
* An [[akka.actor.IO.Iteratee]] that still requires more input to calculate
* it's result. It may also contain an optional error, which can be handled
* with 'recover()'.
*
* It is possible to recover from an error and continue processing this
* Iteratee without losing the continuation, although that has not yet
* been tested. An example use case of this is resuming a failed download.
* it's result.
*/
final case class Cont[+A](f: Input (Iteratee[A], Input), error: Option[Exception] = None) extends Iteratee[A]
final case class Next[+A](f: Input (Iteratee[A], Input)) extends Iteratee[A]
/**
* An [[akka.actor.IO.Iteratee]] that represents an erronous end state.
*/
final case class Failure(cause: Throwable) extends Iteratee[Nothing]
//FIXME general description of what an IterateeRef is and how it is used, potentially with link to docs
object IterateeRef {
@ -598,12 +596,6 @@ object IO {
def future: Future[(Iteratee[A], Input)] = _value
}
/**
* An [[akka.actor.IO.Iteratee]] that contains an Exception. The Exception
* can be handled with Iteratee.recover().
*/
final def throwErr(err: Exception): Iteratee[Nothing] = Cont(input (throwErr(err), input), Some(err))
/**
* An Iteratee that returns the ByteString prefix up until the supplied delimiter.
* The delimiter is dropped by default, but it can be returned with the result by
@ -618,13 +610,13 @@ object IO {
val endIdx = startIdx + delimiter.length
(Done(bytes take (if (inclusive) endIdx else startIdx)), Chunk(bytes drop endIdx))
} else {
(Cont(step(bytes)), Chunk.empty)
(Next(step(bytes)), Chunk.empty)
}
case eof @ EOF(None) (Done(taken), eof)
case eof @ EOF(cause) (Cont(step(taken), cause), eof)
case EOF (Failure(new EOFException("Unexpected EOF")), EOF)
case e @ Error(cause) (Failure(cause), e)
}
Cont(step(ByteString.empty))
Next(step(ByteString.empty))
}
/**
@ -635,14 +627,14 @@ object IO {
case Chunk(more)
val (found, rest) = more span p
if (rest.isEmpty)
(Cont(step(taken ++ found)), Chunk.empty)
(Next(step(taken ++ found)), Chunk.empty)
else
(Done(taken ++ found), Chunk(rest))
case eof @ EOF(None) (Done(taken), eof)
case eof @ EOF(cause) (Cont(step(taken), cause), eof)
case EOF (Failure(new EOFException("Unexpected EOF")), EOF)
case e @ Error(cause) (Failure(cause), e)
}
Cont(step(ByteString.empty))
Next(step(ByteString.empty))
}
/**
@ -655,12 +647,12 @@ object IO {
if (bytes.length >= length)
(Done(bytes.take(length)), Chunk(bytes.drop(length)))
else
(Cont(step(bytes)), Chunk.empty)
case eof @ EOF(None) (Done(taken), eof)
case eof @ EOF(cause) (Cont(step(taken), cause), eof)
(Next(step(bytes)), Chunk.empty)
case EOF (Failure(new EOFException("Unexpected EOF")), EOF)
case e @ Error(cause) (Failure(cause), e)
}
Cont(step(ByteString.empty))
Next(step(ByteString.empty))
}
/**
@ -670,14 +662,14 @@ object IO {
def step(left: Int)(input: Input): (Iteratee[Unit], Input) = input match {
case Chunk(more)
if (left > more.length)
(Cont(step(left - more.length)), Chunk.empty)
(Next(step(left - more.length)), Chunk.empty)
else
(Done(()), Chunk(more drop left))
case eof @ EOF(None) (Done(()), eof)
case eof @ EOF(cause) (Cont(step(left), cause), eof)
case EOF (Failure(new EOFException("Unexpected EOF")), EOF)
case e @ Error(cause) (Failure(cause), e)
}
Cont(step(length))
Next(step(length))
}
/**
@ -687,22 +679,22 @@ object IO {
def step(taken: ByteString)(input: Input): (Iteratee[ByteString], Input) = input match {
case Chunk(more)
val bytes = taken ++ more
(Cont(step(bytes)), Chunk.empty)
case eof @ EOF(None) (Done(taken), eof)
case eof @ EOF(cause) (Cont(step(taken), cause), eof)
(Next(step(bytes)), Chunk.empty)
case EOF (Done(taken), EOF)
case e @ Error(cause) (Failure(cause), e)
}
Cont(step(ByteString.empty))
Next(step(ByteString.empty))
}
/**
* An Iteratee that returns any input it receives
*/
val takeAny: Iteratee[ByteString] = Cont {
val takeAny: Iteratee[ByteString] = Next {
case Chunk(bytes) if bytes.nonEmpty (Done(bytes), Chunk.empty)
case Chunk(bytes) (takeAny, Chunk.empty)
case eof @ EOF(None) (Done(ByteString.empty), eof)
case eof @ EOF(cause) (Cont(more (Done(ByteString.empty), more), cause), eof)
case EOF (Done(ByteString.empty), EOF)
case e @ Error(cause) (Failure(cause), e)
}
/**
@ -727,20 +719,18 @@ object IO {
if (bytes.length >= length)
(Done(bytes.take(length)), Chunk(bytes))
else
(Cont(step(bytes)), Chunk.empty)
case eof @ EOF(None) (Done(taken), eof)
case eof @ EOF(cause) (Cont(step(taken), cause), eof)
(Next(step(bytes)), Chunk.empty)
case EOF (Done(taken), EOF)
case e @ Error(cause) (Failure(cause), e)
}
Cont(step(ByteString.empty))
Next(step(ByteString.empty))
}
/**
* An Iteratee that continually repeats an Iteratee.
*
* FIXME TODO: Should terminate on EOF
*/
def repeat(iter: Iteratee[Unit]): Iteratee[Unit] = iter flatMap (_ repeat(iter))
def repeat[T](iter: Iteratee[T]): Iteratee[T] = iter flatMap (_ repeat(iter))
/**
* An Iteratee that applies an Iteratee to each element of a Traversable
@ -782,12 +772,10 @@ object IO {
} else result match {
case (Done(value), rest)
queueOut.head(value) match {
//case Cont(Chain(f, q)) run(f(rest), q ++ tail) <- can cause big slowdown, need to test if needed
case Cont(f, None) run(f(rest), queueOut.tail, queueIn)
case Next(f) run(f(rest), queueOut.tail, queueIn)
case iter run((iter, rest), queueOut.tail, queueIn)
}
case (Cont(f, None), rest)
(Cont(new Chain(f, queueOut, queueIn)), rest)
case (Next(f), rest) (Next(new Chain(f, queueOut, queueIn)), rest)
case _ result
}
}
@ -815,7 +803,7 @@ object IO {
* An IOManager does not need to be manually stopped when not in use as it will
* automatically enter an idle state when it has no channels to manage.
*/
final class IOManager private (system: ExtendedActorSystem) extends Extension { //FIXME how about taking an ActorContext
final class IOManager private (system: ExtendedActorSystem) extends Extension { //FIXME how about taking an ActorNextext
val settings: Settings = {
val c = system.settings.config.getConfig("akka.io")
Settings(
@ -962,7 +950,7 @@ final class IOManagerActor(val settings: Settings) extends Actor with ActorLoggi
private var fastSelect = false
/** unique message that is sent to ourself to initiate the next select */
private object Select
private case object Select
/** This method should be called after receiving any message */
private def run() {
@ -1080,16 +1068,16 @@ final class IOManagerActor(val settings: Settings) extends Actor with ActorLoggi
case IO.Close(handle: IO.WriteHandle)
//If we still have pending writes, add to set of closing handles
if (writes get handle exists (_.isEmpty == false)) closing += handle
else cleanup(handle, None)
else cleanup(handle, IO.EOF)
run()
case IO.Close(handle)
cleanup(handle, None)
cleanup(handle, IO.EOF)
run()
}
override def postStop {
channels.keys foreach (handle cleanup(handle, None))
channels.keys foreach (handle cleanup(handle, IO.EOF))
selector.close
}
@ -1102,11 +1090,11 @@ final class IOManagerActor(val settings: Settings) extends Actor with ActorLoggi
if (key.isWritable) key.channel match { case channel: WriteChannel try write(handle.asWritable, channel) catch { case e: IOException } } // ignore, let it fail on read to ensure nothing left in read buffer.
} catch {
case e @ (_: ClassCastException | _: CancelledKeyException | _: IOException | _: ActorInitializationException)
cleanup(handle, Some(e.asInstanceOf[Exception])) //Scala patmat is broken
cleanup(handle, IO.Error(e)) //Scala patmat is broken
}
}
private def cleanup(handle: IO.Handle, cause: Option[Exception]) {
private def cleanup(handle: IO.Handle, cause: IO.Input) {
closing -= handle
handle match {
case server: IO.ServerHandle accepted -= server
@ -1138,7 +1126,7 @@ final class IOManagerActor(val settings: Settings) extends Actor with ActorLoggi
removeOps(socket, OP_CONNECT)
socket.owner ! IO.Connected(socket, channel.socket.getRemoteSocketAddress())
} else {
cleanup(socket, Some(new IllegalStateException("Channel for socket handle [%s] didn't finish connect" format socket)))
cleanup(socket, IO.Error(new IllegalStateException("Channel for socket handle [%s] didn't finish connect" format socket)))
}
}
@ -1167,7 +1155,7 @@ final class IOManagerActor(val settings: Settings) extends Actor with ActorLoggi
buffer.clear
val readLen = channel read buffer
if (readLen == -1) {
cleanup(handle, Some(new EOFException("Elvis has left the building")))
cleanup(handle, IO.EOF)
} else if (readLen > 0) {
buffer.flip
handle.owner ! IO.Read(handle, ByteString(buffer))
@ -1179,11 +1167,8 @@ final class IOManagerActor(val settings: Settings) extends Actor with ActorLoggi
val queue = writes(handle)
queue write channel
if (queue.isEmpty) {
if (closing(handle)) {
cleanup(handle, None)
} else {
removeOps(handle, OP_WRITE)
}
if (closing(handle)) cleanup(handle, IO.EOF)
else removeOps(handle, OP_WRITE)
}
}
}

View file

@ -30,7 +30,7 @@ class HttpServer(port: Int) extends Actor {
state(socket)(IO Chunk bytes)
case IO.Closed(socket, cause)
state(socket)(IO EOF None)
state(socket)(IO EOF)
state -= socket
}