!str: IteratorInterpreter and ByteStringBatcher

also Fixing TerminationDirective return types
This commit is contained in:
Endre Sándor Varga 2014-11-14 12:27:08 +01:00
parent bb07c20547
commit f80d97fc9b
6 changed files with 254 additions and 16 deletions

View file

@ -47,7 +47,7 @@ trait InterpreterSpecKit extends AkkaSpec {
class UpstreamProbe extends BoundaryOp {
override def onDownstreamFinish(ctxt: BoundaryContext): Directive = {
override def onDownstreamFinish(ctxt: BoundaryContext): TerminationDirective = {
lastEvent += Cancel
ctxt.exit()
}
@ -72,12 +72,12 @@ trait InterpreterSpecKit extends AkkaSpec {
ctxt.exit()
}
override def onUpstreamFinish(ctxt: BoundaryContext): Directive = {
override def onUpstreamFinish(ctxt: BoundaryContext): TerminationDirective = {
lastEvent += OnComplete
ctxt.exit()
}
override def onFailure(cause: Throwable, ctxt: BoundaryContext): Directive = {
override def onFailure(cause: Throwable, ctxt: BoundaryContext): TerminationDirective = {
lastEvent += OnError(cause)
ctxt.exit()
}

View file

@ -0,0 +1,155 @@
package akka.stream.impl.fusing
import akka.stream.testkit.AkkaSpec
import akka.util.ByteString
import scala.collection.immutable
class IteratorInterpreterSpec extends AkkaSpec {
"IteratorInterpreter" must {
"work in the happy case" in {
val itr = new IteratorInterpreter[Int, Int]((1 to 10).iterator, Seq(
Map((x: Int) x + 1))).iterator
itr.toSeq should be(2 to 11)
}
"hasNext should not affect elements" in {
val itr = new IteratorInterpreter[Int, Int]((1 to 10).iterator, Seq(
Map((x: Int) x))).iterator
itr.hasNext should be(true)
itr.hasNext should be(true)
itr.hasNext should be(true)
itr.hasNext should be(true)
itr.hasNext should be(true)
itr.toSeq should be(1 to 10)
}
"work with ops that need extra pull for complete" in {
val itr = new IteratorInterpreter[Int, Int]((1 to 10).iterator, Seq(NaiveTake(1))).iterator
itr.toSeq should be(Seq(1))
}
"throw exceptions on empty iterator" in {
val itr = new IteratorInterpreter[Int, Int](List(1).iterator, Seq(
Map((x: Int) x))).iterator
itr.next() should be(1)
a[NoSuchElementException] should be thrownBy { itr.next() }
}
"throw exceptions when chain fails" in {
val itr = new IteratorInterpreter[Int, Int](List(1, 2, 3).iterator, Seq(
new TransitivePullOp[Int, Int] {
override def onPush(elem: Int, ctxt: Context[Int]): Directive = {
if (elem == 2) ctxt.fail(new ArithmeticException())
else ctxt.push(elem)
}
})).iterator
itr.next() should be(1)
a[ArithmeticException] should be thrownBy { itr.next() }
}
"throw exceptions when op in chain throws" in {
val itr = new IteratorInterpreter[Int, Int](List(1, 2, 3).iterator, Seq(
new TransitivePullOp[Int, Int] {
override def onPush(elem: Int, ctxt: Context[Int]): Directive = {
if (elem == 2) throw new ArithmeticException()
else ctxt.push(elem)
}
})).iterator
itr.next() should be(1)
a[ArithmeticException] should be thrownBy { itr.next() }
}
"work with an empty iterator" in {
val itr = new IteratorInterpreter[Int, Int](Iterator.empty, Seq(
Map((x: Int) x + 1))).iterator
itr.hasNext should be(false)
a[NoSuchElementException] should be thrownBy { itr.next() }
}
"able to implement a ByteStringBatcher" in {
val testBytes = (1 to 10).map(ByteString(_))
def newItr(threshold: Int) =
new IteratorInterpreter[ByteString, ByteString](testBytes.iterator, Seq(ByteStringBatcher(threshold))).iterator
val itr1 = newItr(20)
itr1.next() should be(ByteString(1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
itr1.hasNext should be(false)
val itr2 = newItr(10)
itr2.next() should be(ByteString(1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
itr2.hasNext should be(false)
val itr3 = newItr(5)
itr3.next() should be(ByteString(1, 2, 3, 4, 5))
(6 to 10) foreach { i
itr3.hasNext should be(true)
itr3.next() should be(ByteString(i))
}
itr3.hasNext should be(false)
val itr4 =
new IteratorInterpreter[ByteString, ByteString](Iterator.empty, Seq(ByteStringBatcher(10))).iterator
itr4.hasNext should be(false)
}
}
// This op needs an extra pull round to finish
case class NaiveTake[T](count: Int) extends DeterministicOp[T, T] {
private var left: Int = count
override def onPush(elem: T, ctxt: Context[T]): Directive = {
left -= 1
ctxt.push(elem)
}
override def onPull(ctxt: Context[T]): Directive = {
if (left == 0) ctxt.finish()
else ctxt.pull()
}
}
case class ByteStringBatcher(threshold: Int, compact: Boolean = true) extends DeterministicOp[ByteString, ByteString] {
require(threshold > 0, "Threshold must be positive")
private var buf = ByteString.empty
private var passthrough = false
override def onPush(elem: ByteString, ctxt: Context[ByteString]): Directive = {
if (passthrough) ctxt.push(elem)
else {
buf = buf ++ elem
if (buf.size >= threshold) {
val batch = if (compact) buf.compact else buf
passthrough = true
buf = ByteString.empty
ctxt.push(batch)
} else ctxt.pull()
}
}
override def onPull(ctxt: Context[ByteString]): Directive = {
if (isFinishing) ctxt.pushAndFinish(buf)
else ctxt.pull()
}
override def onUpstreamFinish(ctxt: Context[ByteString]): TerminationDirective = {
if (passthrough || buf.isEmpty) ctxt.finish()
else ctxt.absorbTermination()
}
}
}

View file

@ -76,7 +76,7 @@ private[akka] class BatchingActorInputBoundary(val size: Int) extends BoundaryOp
}
}
override def onDownstreamFinish(ctxt: BoundaryContext): Directive = {
override def onDownstreamFinish(ctxt: BoundaryContext): TerminationDirective = {
cancel()
ctxt.exit()
}
@ -190,12 +190,12 @@ private[akka] class ActorOutputBoundary(val actor: ActorRef) extends BoundaryOp
override def onPull(ctxt: BoundaryContext): Directive =
throw new UnsupportedOperationException("BUG: Cannot pull the downstream boundary")
override def onUpstreamFinish(ctxt: BoundaryContext): Directive = {
override def onUpstreamFinish(ctxt: BoundaryContext): TerminationDirective = {
complete()
ctxt.finish()
}
override def onFailure(cause: Throwable, ctxt: BoundaryContext): Directive = {
override def onFailure(cause: Throwable, ctxt: BoundaryContext): TerminationDirective = {
fail(cause)
ctxt.fail(cause)
}

View file

@ -20,9 +20,9 @@ trait Op[In, Out, PushD <: Directive, PullD <: Directive, Ctxt <: Context[Out]]
def isFinishing: Boolean = terminationPending
def onPush(elem: In, ctxt: Ctxt): PushD
def onPull(ctxt: Ctxt): PullD
def onUpstreamFinish(ctxt: Ctxt): Directive = ctxt.finish()
def onDownstreamFinish(ctxt: Ctxt): Directive = ctxt.finish()
def onFailure(cause: Throwable, ctxt: Ctxt): Directive = ctxt.fail(cause)
def onUpstreamFinish(ctxt: Ctxt): TerminationDirective = ctxt.finish()
def onDownstreamFinish(ctxt: Ctxt): TerminationDirective = ctxt.finish()
def onFailure(cause: Throwable, ctxt: Ctxt): TerminationDirective = ctxt.fail(cause)
}
trait DeterministicOp[In, Out] extends Op[In, Out, Directive, Directive, Context[Out]]
@ -73,9 +73,9 @@ object OneBoundedInterpreter {
private[akka] object Finished extends BoundaryOp {
override def onPush(elem: Any, ctxt: BoundaryContext): UpstreamDirective = ctxt.finish()
override def onPull(ctxt: BoundaryContext): DownstreamDirective = ctxt.finish()
override def onUpstreamFinish(ctxt: BoundaryContext): Directive = ctxt.exit()
override def onDownstreamFinish(ctxt: BoundaryContext): Directive = ctxt.exit()
override def onFailure(cause: Throwable, ctxt: BoundaryContext): Directive = ctxt.exit()
override def onUpstreamFinish(ctxt: BoundaryContext): TerminationDirective = ctxt.exit()
override def onDownstreamFinish(ctxt: BoundaryContext): TerminationDirective = ctxt.exit()
override def onFailure(cause: Throwable, ctxt: BoundaryContext): TerminationDirective = ctxt.exit()
}
}

View file

@ -0,0 +1,83 @@
/**
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.stream.impl.fusing
object IteratorInterpreter {
case class IteratorUpstream[T](input: Iterator[T]) extends DeterministicOp[T, T] {
private var hasNext = input.hasNext
override def onPush(elem: T, ctxt: Context[T]): Directive =
throw new UnsupportedOperationException("IteratorUpstream operates as a source, it cannot be pushed")
override def onPull(ctxt: Context[T]): Directive = {
if (!hasNext) ctxt.finish()
else {
val elem = input.next()
hasNext = input.hasNext
if (!hasNext) ctxt.pushAndFinish(elem)
else ctxt.push(elem)
}
}
}
case class IteratorDownstream[T]() extends BoundaryOp with Iterator[T] {
private var done = false
private var nextElem: T = _
private var needsPull = true
private var lastError: Throwable = null
override def onPush(elem: Any, ctxt: BoundaryContext): Directive = {
nextElem = elem.asInstanceOf[T]
needsPull = false
ctxt.exit()
}
override def onPull(ctxt: BoundaryContext): Directive =
throw new UnsupportedOperationException("IteratorDownstream operates as a sink, it cannot be pulled")
override def onUpstreamFinish(ctxt: BoundaryContext): TerminationDirective = {
done = true
ctxt.finish()
}
override def onFailure(cause: Throwable, ctxt: BoundaryContext): TerminationDirective = {
done = true
lastError = cause
ctxt.finish()
}
private def pullIfNeeded(): Unit = {
if (needsPull) {
enter().pull() // will eventually result in a finish, or an onPush which exits
}
}
override def hasNext: Boolean = {
if (!done) pullIfNeeded()
!(done && needsPull)
}
override def next(): T = {
if (!hasNext) {
if (lastError != null) throw lastError
else Iterator.empty.next()
}
needsPull = true
nextElem
}
}
}
class IteratorInterpreter[I, O](val input: Iterator[I], val ops: Seq[DeterministicOp[_, _]]) {
import akka.stream.impl.fusing.IteratorInterpreter._
private val upstream = IteratorUpstream(input)
private val downstream = IteratorDownstream[O]()
private val interpreter = new OneBoundedInterpreter(upstream +: ops.asInstanceOf[Seq[Op[_, _, _, _, _]]] :+ downstream)
interpreter.init()
def iterator: Iterator[O] = downstream
}

View file

@ -81,7 +81,7 @@ private[akka] case class Fold[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends
if (isFinishing) ctxt.pushAndFinish(aggregator)
else ctxt.pull()
override def onUpstreamFinish(ctxt: Context[Out]): Directive = ctxt.absorbTermination()
override def onUpstreamFinish(ctxt: Context[Out]): TerminationDirective = ctxt.absorbTermination()
}
/**
@ -103,7 +103,7 @@ private[akka] case class Grouped[T](n: Int) extends DeterministicOp[T, immutable
if (isFinishing) ctxt.pushAndFinish(buf)
else ctxt.pull()
override def onUpstreamFinish(ctxt: Context[immutable.Seq[T]]): Directive =
override def onUpstreamFinish(ctxt: Context[immutable.Seq[T]]): TerminationDirective =
if (buf.isEmpty) ctxt.finish()
else ctxt.absorbTermination()
}
@ -130,7 +130,7 @@ private[akka] case class Buffer[T](size: Int, overflowStrategy: OverflowStrategy
else ctxt.push(buffer.dequeue().asInstanceOf[T])
}
override def onUpstreamFinish(ctxt: DetachedContext[T]): Directive =
override def onUpstreamFinish(ctxt: DetachedContext[T]): TerminationDirective =
if (buffer.isEmpty) ctxt.finish()
else ctxt.absorbTermination()
@ -208,7 +208,7 @@ private[akka] case class Conflate[In, Out](seed: In ⇒ Out, aggregate: (Out, In
}
}
override def onUpstreamFinish(ctxt: DetachedContext[Out]): Directive = ctxt.absorbTermination()
override def onUpstreamFinish(ctxt: DetachedContext[Out]): TerminationDirective = ctxt.absorbTermination()
}
/**