diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/GraphStageLogicSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/GraphStageLogicSpec.scala index afb98addaf..5ba0fc8496 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/GraphStageLogicSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/GraphStageLogicSpec.scala @@ -3,6 +3,7 @@ */ package akka.stream.impl +import akka.stream.stage.GraphStageLogic.{ EagerTerminateOutput, EagerTerminateInput } import akka.stream.testkit.AkkaSpec import akka.stream._ import akka.stream.Fusing.aggressive @@ -80,8 +81,60 @@ class GraphStageLogicSpec extends AkkaSpec with GraphInterpreterSpecKit with Con } } + final case class ReadNEmitN(n: Int) extends GraphStage[FlowShape[Int, Int]] { + override val shape = FlowShape(Inlet[Int]("readN.in"), Outlet[Int]("readN.out")) + + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) { + setHandler(shape.in, EagerTerminateInput) + setHandler(shape.out, EagerTerminateOutput) + override def preStart(): Unit = readN(shape.in, n)(e ⇒ emitMultiple(shape.out, e.iterator, () ⇒ completeStage()), (_) ⇒ ()) + } + } + + final case class ReadNEmitRestOnComplete(n: Int) extends GraphStage[FlowShape[Int, Int]] { + override val shape = FlowShape(Inlet[Int]("readN.in"), Outlet[Int]("readN.out")) + + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) { + setHandler(shape.in, EagerTerminateInput) + setHandler(shape.out, EagerTerminateOutput) + override def preStart(): Unit = + readN(shape.in, n)( + _ ⇒ failStage(new IllegalStateException("Shouldn't happen!")), + e ⇒ emitMultiple(shape.out, e.iterator, () ⇒ completeStage())) + } + } + "A GraphStageLogic" must { + "read N and emit N before completing" in assertAllStagesStopped { + Source(1 to 10).via(ReadNEmitN(2)).runWith(TestSink.probe) + .request(10) + .expectNext(1, 2) + .expectComplete() + } + + "read N should not emit if upstream completes before N is sent" in assertAllStagesStopped { + Source(1 to 5).via(ReadNEmitN(6)).runWith(TestSink.probe) + .request(10) + .expectComplete() + } + + "read N should not emit if upstream fails before N is sent" in assertAllStagesStopped { + val error = new IllegalArgumentException("Don't argue like that!") + Source(1 to 5).map(x ⇒ if (x > 3) throw error else x).via(ReadNEmitN(6)).runWith(TestSink.probe) + .request(10) + .expectError(error) + } + + "read N should provide elements read if onComplete happens before N elements have been seen" in assertAllStagesStopped { + Source(1 to 5).via(ReadNEmitRestOnComplete(6)).runWith(TestSink.probe) + .request(10) + .expectNext(1, 2, 3, 4, 5) + .expectComplete() + } + "emit all things before completing" in assertAllStagesStopped { Source.empty.via(emit1234.named("testStage")).runWith(TestSink.probe) diff --git a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala index bc375372fe..cf9efda929 100644 --- a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala +++ b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala @@ -514,37 +514,36 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: * Read a number of elements from the given inlet and continue with the given function, * suspending execution if necessary. This action replaces the [[InHandler]] * for the given inlet if suspension is needed and reinstalls the current - * handler upon receiving the last `onPush()` signal (before invoking the `andThen` function). + * handler upon receiving the last `onPush()` signal. + * + * If upstream closes before N elements have been read, + * the `onClose` function is invoked with the elements which were read. */ final protected def readN[T](in: Inlet[T], n: Int)(andThen: Seq[T] ⇒ Unit, onClose: Seq[T] ⇒ Unit): Unit = + //FIXME `onClose` is a poor name for `onComplete` rename this at the earliest possible opportunity if (n < 0) throw new IllegalArgumentException("cannot read negative number of elements") else if (n == 0) andThen(Nil) else { - val result = new ArrayBuffer[T](n) + val result = new Array[AnyRef](n).asInstanceOf[Array[T]] var pos = 0 - def realAndThen = (elem: T) ⇒ { - result(pos) = elem - pos += 1 - if (pos == n) andThen(result) - } - def realOnClose = () ⇒ onClose(result.take(pos)) - if (isAvailable(in)) { - val elem = grab(in) - result(0) = elem - if (n == 1) { - andThen(result) - } else { - pos = 1 - requireNotReading(in) - pull(in) - setHandler(in, new Reading(in, n - 1, getHandler(in))(realAndThen, realOnClose)) - } - } else { + if (isAvailable(in)) { //If we already have data available, then shortcircuit and read the first + result(pos) = grab(in) + pos += 1 + } + + if (n != pos) { // If we aren't already done requireNotReading(in) if (!hasBeenPulled(in)) pull(in) - setHandler(in, new Reading(in, n, getHandler(in))(realAndThen, realOnClose)) - } + setHandler(in, new Reading(in, n - pos, getHandler(in))( + (elem: T) ⇒ { + result(pos) = elem + pos += 1 + if (pos == n) andThen(result) + }, + () ⇒ onClose(result.take(pos))) + ) + } else andThen(result) } /** @@ -554,6 +553,7 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: * handler upon receiving the last `onPush()` signal (before invoking the `andThen` function). */ final protected def readN[T](in: Inlet[T], n: Int, andThen: Procedure[java.util.List[T]], onClose: Procedure[java.util.List[T]]): Unit = { + //FIXME `onClose` is a poor name for `onComplete` rename this at the earliest possible opportunity import collection.JavaConverters._ readN(in, n)(seq ⇒ andThen(seq.asJava), seq ⇒ onClose(seq.asJava)) } @@ -604,22 +604,25 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: throw new IllegalStateException("already reading on inlet " + in) /** - * Caution: for n==1 andThen is called after resetting the handler, for - * other values it is called without resetting the handler. + * Caution: for n == 1 andThen is called after resetting the handler, for + * other values it is called without resetting the handler. n MUST be positive. */ - private class Reading[T](in: Inlet[T], private var n: Int, val previous: InHandler)(andThen: T ⇒ Unit, onClose: () ⇒ Unit) extends InHandler { + private final class Reading[T](in: Inlet[T], private var n: Int, val previous: InHandler)(andThen: T ⇒ Unit, onComplete: () ⇒ Unit) extends InHandler { + require(n > 0, "number of elements to read must be positive!") + override def onPush(): Unit = { val elem = grab(in) - if (n == 1) setHandler(in, previous) - else { - n -= 1 - pull(in) - } + n -= 1 + + if (n > 0) pull(in) + else setHandler(in, previous) + andThen(elem) } + override def onUpstreamFinish(): Unit = { setHandler(in, previous) - onClose() + onComplete() previous.onUpstreamFinish() } override def onUpstreamFailure(ex: Throwable): Unit = {