diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/IteratorInterpreterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/IteratorInterpreterSpec.scala index 6363c34c37..275a5d907a 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/IteratorInterpreterSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/IteratorInterpreterSpec.scala @@ -6,7 +6,8 @@ package akka.stream.impl.fusing import akka.testkit.AkkaSpec import akka.util.ByteString import akka.stream.stage._ -import akka.stream.Supervision +import akka.stream.{ Attributes, Supervision } +import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage class IteratorInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { import Supervision.stoppingDecider @@ -34,7 +35,7 @@ class IteratorInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { } "work with ops that need extra pull for complete" in { - val itr = new IteratorInterpreter[Int, Int]((1 to 10).iterator, Seq(NaiveTake(1).toGS)).iterator + val itr = new IteratorInterpreter[Int, Int]((1 to 10).iterator, Seq(NaiveTake(1))).iterator itr.toSeq should be(Seq(1)) } @@ -47,29 +48,9 @@ class IteratorInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { a[NoSuchElementException] should be thrownBy { itr.next() } } - "throw exceptions when chain fails" in { - val stage = new PushStage[Int, Int] { - override def onPush(elem: Int, ctx: Context[Int]): SyncDirective = { - if (elem == 2) ctx.fail(new ArithmeticException()) - else ctx.push(elem) - } - } - val itr = new IteratorInterpreter[Int, Int](List(1, 2, 3).iterator, Seq(stage.toGS)).iterator - - itr.next() should be(1) - itr.hasNext should be(true) - a[ArithmeticException] should be thrownBy { itr.next() } - itr.hasNext should be(false) - } - "throw exceptions when op in chain throws" in { - val stage = new PushStage[Int, Int] { - override def onPush(elem: Int, ctx: Context[Int]): SyncDirective = { - if (elem == 2) throw new ArithmeticException() - else ctx.push(elem) - } - } - val itr = new IteratorInterpreter[Int, Int](List(1, 2, 3).iterator, Seq(stage.toGS)).iterator + val itr = new IteratorInterpreter[Int, Int](List(1, 2, 3).iterator, Seq( + Map((n: Int) ⇒ if (n == 2) throw new ArithmeticException() else n, stoppingDecider).toGS)).iterator itr.next() should be(1) itr.hasNext should be(true) @@ -90,7 +71,7 @@ class IteratorInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { def newItr(threshold: Int) = new IteratorInterpreter[ByteString, ByteString](testBytes.iterator, Seq( - ByteStringBatcher(threshold).toGS)).iterator + ByteStringBatcher(threshold))).iterator val itr1 = newItr(20) itr1.next() should be(ByteString(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)) @@ -110,7 +91,7 @@ class IteratorInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { val itr4 = new IteratorInterpreter[ByteString, ByteString](Iterator.empty, Seq( - ByteStringBatcher(10).toGS)).iterator + ByteStringBatcher(10))).iterator itr4.hasNext should be(false) } @@ -118,48 +99,65 @@ class IteratorInterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { } // This op needs an extra pull round to finish - case class NaiveTake[T](count: Int) extends PushPullStage[T, T] { - private var left: Int = count + case class NaiveTake[T](count: Int) extends SimpleLinearGraphStage[T] { - override def onPush(elem: T, ctx: Context[T]): SyncDirective = { - left -= 1 - ctx.push(elem) - } + override def createLogic(attributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) with InHandler with OutHandler { + private var left: Int = count - override def onPull(ctx: Context[T]): SyncDirective = { - if (left == 0) ctx.finish() - else ctx.pull() - } + override def onPush(): Unit = { + left -= 1 + push(out, grab(in)) + } + + override def onPull(): Unit = { + if (left == 0) completeStage() + else pull(in) + } + + setHandlers(in, out, this) + } + + override def toString = "NaiveTake" } - case class ByteStringBatcher(threshold: Int, compact: Boolean = true) extends PushPullStage[ByteString, ByteString] { + case class ByteStringBatcher(threshold: Int, compact: Boolean = true) extends SimpleLinearGraphStage[ByteString] { require(threshold > 0, "Threshold must be positive") - private var buf = ByteString.empty - private var passthrough = false + override def createLogic(attributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) with InHandler with OutHandler { + private var buf: ByteString = ByteString.empty + private var passthrough: Boolean = false - override def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective = { - if (passthrough) ctx.push(elem) - else { - buf = buf ++ elem - if (buf.size >= threshold) { - val batch = if (compact) buf.compact else buf - passthrough = true - buf = ByteString.empty - ctx.push(batch) - } else ctx.pull() + override def onPush(): Unit = { + val elem = grab(in) + if (passthrough) push(out, elem) + else { + buf = buf ++ elem + if (buf.size >= threshold) { + val batch = if (compact) buf.compact else buf + passthrough = true + buf = ByteString.empty + push(out, batch) + } else pull(in) + } + } + + override def onPull(): Unit = { + if (isClosed(in)) { + push(out, buf) + completeStage() + } else pull(in) + } + + override def onUpstreamFinish(): Unit = { + if (passthrough || buf.isEmpty) completeStage() + else if (isAvailable(out)) onPull() + } + + setHandlers(in, out, this) } - } - override def onPull(ctx: Context[ByteString]): SyncDirective = { - if (ctx.isFinishing) ctx.pushAndFinish(buf) - else ctx.pull() - } - - override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = { - if (passthrough || buf.isEmpty) ctx.finish() - else ctx.absorbTermination() - } + override def toString = "ByteStringBatcher" } - }