From f80d97fc9be837c06c6f574ae08e789cd4012da3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Endre=20S=C3=A1ndor=20Varga?= Date: Fri, 14 Nov 2014 12:27:08 +0100 Subject: [PATCH] !str: IteratorInterpreter and ByteStringBatcher also Fixing TerminationDirective return types --- .../impl/fusing/InterpreterSpecKit.scala | 6 +- .../impl/fusing/IteratorInterpreterSpec.scala | 155 ++++++++++++++++++ .../stream/impl/fusing/ActorInterpreter.scala | 6 +- .../akka/stream/impl/fusing/Interpreter.scala | 12 +- .../impl/fusing/IteratorInterpreter.scala | 83 ++++++++++ .../scala/akka/stream/impl/fusing/Ops.scala | 8 +- 6 files changed, 254 insertions(+), 16 deletions(-) create mode 100644 akka-stream-tests/src/test/scala/akka/stream/impl/fusing/IteratorInterpreterSpec.scala create mode 100644 akka-stream/src/main/scala/akka/stream/impl/fusing/IteratorInterpreter.scala diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpecKit.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpecKit.scala index 09028e5248..40fcb8c72f 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpecKit.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpecKit.scala @@ -47,7 +47,7 @@ trait InterpreterSpecKit extends AkkaSpec { class UpstreamProbe extends BoundaryOp { - override def onDownstreamFinish(ctxt: BoundaryContext): Directive = { + override def onDownstreamFinish(ctxt: BoundaryContext): TerminationDirective = { lastEvent += Cancel ctxt.exit() } @@ -72,12 +72,12 @@ trait InterpreterSpecKit extends AkkaSpec { ctxt.exit() } - override def onUpstreamFinish(ctxt: BoundaryContext): Directive = { + override def onUpstreamFinish(ctxt: BoundaryContext): TerminationDirective = { lastEvent += OnComplete ctxt.exit() } - override def onFailure(cause: Throwable, ctxt: BoundaryContext): Directive = { + override def onFailure(cause: Throwable, ctxt: BoundaryContext): TerminationDirective = { lastEvent += OnError(cause) ctxt.exit() } diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/IteratorInterpreterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/IteratorInterpreterSpec.scala new file mode 100644 index 0000000000..f69b71cd29 --- /dev/null +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/IteratorInterpreterSpec.scala @@ -0,0 +1,155 @@ +package akka.stream.impl.fusing + +import akka.stream.testkit.AkkaSpec +import akka.util.ByteString + +import scala.collection.immutable + +class IteratorInterpreterSpec extends AkkaSpec { + + "IteratorInterpreter" must { + + "work in the happy case" in { + val itr = new IteratorInterpreter[Int, Int]((1 to 10).iterator, Seq( + Map((x: Int) ⇒ x + 1))).iterator + + itr.toSeq should be(2 to 11) + } + + "hasNext should not affect elements" in { + val itr = new IteratorInterpreter[Int, Int]((1 to 10).iterator, Seq( + Map((x: Int) ⇒ x))).iterator + + itr.hasNext should be(true) + itr.hasNext should be(true) + itr.hasNext should be(true) + itr.hasNext should be(true) + itr.hasNext should be(true) + + itr.toSeq should be(1 to 10) + } + + "work with ops that need extra pull for complete" in { + val itr = new IteratorInterpreter[Int, Int]((1 to 10).iterator, Seq(NaiveTake(1))).iterator + + itr.toSeq should be(Seq(1)) + } + + "throw exceptions on empty iterator" in { + val itr = new IteratorInterpreter[Int, Int](List(1).iterator, Seq( + Map((x: Int) ⇒ x))).iterator + + itr.next() should be(1) + a[NoSuchElementException] should be thrownBy { itr.next() } + } + + "throw exceptions when chain fails" in { + val itr = new IteratorInterpreter[Int, Int](List(1, 2, 3).iterator, Seq( + new TransitivePullOp[Int, Int] { + override def onPush(elem: Int, ctxt: Context[Int]): Directive = { + if (elem == 2) ctxt.fail(new ArithmeticException()) + else ctxt.push(elem) + } + })).iterator + + itr.next() should be(1) + a[ArithmeticException] should be thrownBy { itr.next() } + } + + "throw exceptions when op in chain throws" in { + val itr = new IteratorInterpreter[Int, Int](List(1, 2, 3).iterator, Seq( + new TransitivePullOp[Int, Int] { + override def onPush(elem: Int, ctxt: Context[Int]): Directive = { + if (elem == 2) throw new ArithmeticException() + else ctxt.push(elem) + } + })).iterator + + itr.next() should be(1) + a[ArithmeticException] should be thrownBy { itr.next() } + } + + "work with an empty iterator" in { + val itr = new IteratorInterpreter[Int, Int](Iterator.empty, Seq( + Map((x: Int) ⇒ x + 1))).iterator + + itr.hasNext should be(false) + a[NoSuchElementException] should be thrownBy { itr.next() } + } + + "able to implement a ByteStringBatcher" in { + val testBytes = (1 to 10).map(ByteString(_)) + + def newItr(threshold: Int) = + new IteratorInterpreter[ByteString, ByteString](testBytes.iterator, Seq(ByteStringBatcher(threshold))).iterator + + val itr1 = newItr(20) + itr1.next() should be(ByteString(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)) + itr1.hasNext should be(false) + + val itr2 = newItr(10) + itr2.next() should be(ByteString(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)) + itr2.hasNext should be(false) + + val itr3 = newItr(5) + itr3.next() should be(ByteString(1, 2, 3, 4, 5)) + (6 to 10) foreach { i ⇒ + itr3.hasNext should be(true) + itr3.next() should be(ByteString(i)) + } + itr3.hasNext should be(false) + + val itr4 = + new IteratorInterpreter[ByteString, ByteString](Iterator.empty, Seq(ByteStringBatcher(10))).iterator + + itr4.hasNext should be(false) + } + + } + + // This op needs an extra pull round to finish + case class NaiveTake[T](count: Int) extends DeterministicOp[T, T] { + private var left: Int = count + + override def onPush(elem: T, ctxt: Context[T]): Directive = { + left -= 1 + ctxt.push(elem) + } + + override def onPull(ctxt: Context[T]): Directive = { + if (left == 0) ctxt.finish() + else ctxt.pull() + } + } + + case class ByteStringBatcher(threshold: Int, compact: Boolean = true) extends DeterministicOp[ByteString, ByteString] { + require(threshold > 0, "Threshold must be positive") + + private var buf = ByteString.empty + private var passthrough = false + + override def onPush(elem: ByteString, ctxt: Context[ByteString]): Directive = { + if (passthrough) ctxt.push(elem) + else { + buf = buf ++ elem + if (buf.size >= threshold) { + val batch = if (compact) buf.compact else buf + passthrough = true + buf = ByteString.empty + ctxt.push(batch) + } else ctxt.pull() + } + } + + override def onPull(ctxt: Context[ByteString]): Directive = { + if (isFinishing) ctxt.pushAndFinish(buf) + else ctxt.pull() + } + + override def onUpstreamFinish(ctxt: Context[ByteString]): TerminationDirective = { + if (passthrough || buf.isEmpty) ctxt.finish() + else ctxt.absorbTermination() + } + } + +} diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorInterpreter.scala index 9cfef375cd..3e8eed3d38 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorInterpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorInterpreter.scala @@ -76,7 +76,7 @@ private[akka] class BatchingActorInputBoundary(val size: Int) extends BoundaryOp } } - override def onDownstreamFinish(ctxt: BoundaryContext): Directive = { + override def onDownstreamFinish(ctxt: BoundaryContext): TerminationDirective = { cancel() ctxt.exit() } @@ -190,12 +190,12 @@ private[akka] class ActorOutputBoundary(val actor: ActorRef) extends BoundaryOp override def onPull(ctxt: BoundaryContext): Directive = throw new UnsupportedOperationException("BUG: Cannot pull the downstream boundary") - override def onUpstreamFinish(ctxt: BoundaryContext): Directive = { + override def onUpstreamFinish(ctxt: BoundaryContext): TerminationDirective = { complete() ctxt.finish() } - override def onFailure(cause: Throwable, ctxt: BoundaryContext): Directive = { + override def onFailure(cause: Throwable, ctxt: BoundaryContext): TerminationDirective = { fail(cause) ctxt.fail(cause) } diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Interpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Interpreter.scala index 7ba9f3f573..51b7765dd8 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Interpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Interpreter.scala @@ -20,9 +20,9 @@ trait Op[In, Out, PushD <: Directive, PullD <: Directive, Ctxt <: Context[Out]] def isFinishing: Boolean = terminationPending def onPush(elem: In, ctxt: Ctxt): PushD def onPull(ctxt: Ctxt): PullD - def onUpstreamFinish(ctxt: Ctxt): Directive = ctxt.finish() - def onDownstreamFinish(ctxt: Ctxt): Directive = ctxt.finish() - def onFailure(cause: Throwable, ctxt: Ctxt): Directive = ctxt.fail(cause) + def onUpstreamFinish(ctxt: Ctxt): TerminationDirective = ctxt.finish() + def onDownstreamFinish(ctxt: Ctxt): TerminationDirective = ctxt.finish() + def onFailure(cause: Throwable, ctxt: Ctxt): TerminationDirective = ctxt.fail(cause) } trait DeterministicOp[In, Out] extends Op[In, Out, Directive, Directive, Context[Out]] @@ -73,9 +73,9 @@ object OneBoundedInterpreter { private[akka] object Finished extends BoundaryOp { override def onPush(elem: Any, ctxt: BoundaryContext): UpstreamDirective = ctxt.finish() override def onPull(ctxt: BoundaryContext): DownstreamDirective = ctxt.finish() - override def onUpstreamFinish(ctxt: BoundaryContext): Directive = ctxt.exit() - override def onDownstreamFinish(ctxt: BoundaryContext): Directive = ctxt.exit() - override def onFailure(cause: Throwable, ctxt: BoundaryContext): Directive = ctxt.exit() + override def onUpstreamFinish(ctxt: BoundaryContext): TerminationDirective = ctxt.exit() + override def onDownstreamFinish(ctxt: BoundaryContext): TerminationDirective = ctxt.exit() + override def onFailure(cause: Throwable, ctxt: BoundaryContext): TerminationDirective = ctxt.exit() } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/IteratorInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/IteratorInterpreter.scala new file mode 100644 index 0000000000..480afec88a --- /dev/null +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/IteratorInterpreter.scala @@ -0,0 +1,83 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ +package akka.stream.impl.fusing + +object IteratorInterpreter { + case class IteratorUpstream[T](input: Iterator[T]) extends DeterministicOp[T, T] { + private var hasNext = input.hasNext + + override def onPush(elem: T, ctxt: Context[T]): Directive = + throw new UnsupportedOperationException("IteratorUpstream operates as a source, it cannot be pushed") + + override def onPull(ctxt: Context[T]): Directive = { + if (!hasNext) ctxt.finish() + else { + val elem = input.next() + hasNext = input.hasNext + if (!hasNext) ctxt.pushAndFinish(elem) + else ctxt.push(elem) + } + + } + } + + case class IteratorDownstream[T]() extends BoundaryOp with Iterator[T] { + private var done = false + private var nextElem: T = _ + private var needsPull = true + private var lastError: Throwable = null + + override def onPush(elem: Any, ctxt: BoundaryContext): Directive = { + nextElem = elem.asInstanceOf[T] + needsPull = false + ctxt.exit() + } + + override def onPull(ctxt: BoundaryContext): Directive = + throw new UnsupportedOperationException("IteratorDownstream operates as a sink, it cannot be pulled") + + override def onUpstreamFinish(ctxt: BoundaryContext): TerminationDirective = { + done = true + ctxt.finish() + } + + override def onFailure(cause: Throwable, ctxt: BoundaryContext): TerminationDirective = { + done = true + lastError = cause + ctxt.finish() + } + + private def pullIfNeeded(): Unit = { + if (needsPull) { + enter().pull() // will eventually result in a finish, or an onPush which exits + } + } + + override def hasNext: Boolean = { + if (!done) pullIfNeeded() + !(done && needsPull) + } + + override def next(): T = { + if (!hasNext) { + if (lastError != null) throw lastError + else Iterator.empty.next() + } + needsPull = true + nextElem + } + + } +} + +class IteratorInterpreter[I, O](val input: Iterator[I], val ops: Seq[DeterministicOp[_, _]]) { + import akka.stream.impl.fusing.IteratorInterpreter._ + + private val upstream = IteratorUpstream(input) + private val downstream = IteratorDownstream[O]() + private val interpreter = new OneBoundedInterpreter(upstream +: ops.asInstanceOf[Seq[Op[_, _, _, _, _]]] :+ downstream) + interpreter.init() + + def iterator: Iterator[O] = downstream +} diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala index 92386c6f16..f4a4b5d9c7 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala @@ -81,7 +81,7 @@ private[akka] case class Fold[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends if (isFinishing) ctxt.pushAndFinish(aggregator) else ctxt.pull() - override def onUpstreamFinish(ctxt: Context[Out]): Directive = ctxt.absorbTermination() + override def onUpstreamFinish(ctxt: Context[Out]): TerminationDirective = ctxt.absorbTermination() } /** @@ -103,7 +103,7 @@ private[akka] case class Grouped[T](n: Int) extends DeterministicOp[T, immutable if (isFinishing) ctxt.pushAndFinish(buf) else ctxt.pull() - override def onUpstreamFinish(ctxt: Context[immutable.Seq[T]]): Directive = + override def onUpstreamFinish(ctxt: Context[immutable.Seq[T]]): TerminationDirective = if (buf.isEmpty) ctxt.finish() else ctxt.absorbTermination() } @@ -130,7 +130,7 @@ private[akka] case class Buffer[T](size: Int, overflowStrategy: OverflowStrategy else ctxt.push(buffer.dequeue().asInstanceOf[T]) } - override def onUpstreamFinish(ctxt: DetachedContext[T]): Directive = + override def onUpstreamFinish(ctxt: DetachedContext[T]): TerminationDirective = if (buffer.isEmpty) ctxt.finish() else ctxt.absorbTermination() @@ -208,7 +208,7 @@ private[akka] case class Conflate[In, Out](seed: In ⇒ Out, aggregate: (Out, In } } - override def onUpstreamFinish(ctxt: DetachedContext[Out]): Directive = ctxt.absorbTermination() + override def onUpstreamFinish(ctxt: DetachedContext[Out]): TerminationDirective = ctxt.absorbTermination() } /**