diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala
index 4e02890aae..2f8ba7542d 100644
--- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala
+++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala
@@ -327,10 +327,10 @@ class InterpreterSpec extends InterpreterSpecKit {
     "work with expand-expand" in new TestSetup(Seq(
       Expand(
         (in: Int) ⇒ in,
-        (agg: Int) ⇒ (agg, agg)),
+        (agg: Int) ⇒ (agg, agg + 1)),
       Expand(
         (in: Int) ⇒ in,
-        (agg: Int) ⇒ (agg, agg)))) {
+        (agg: Int) ⇒ (agg, agg + 1)))) {
 
       lastEvents() should be(Set(RequestOne))
 
@@ -341,22 +341,24 @@ class InterpreterSpec extends InterpreterSpecKit {
       lastEvents() should be(Set(OnNext(0)))
 
       downstream.requestOne()
-      lastEvents() should be(Set(OnNext(0)))
+      lastEvents() should be(Set(OnNext(1)))
 
-      upstream.onNext(1)
+      upstream.onNext(10)
       lastEvents() should be(Set.empty)
 
      downstream.requestOne()
-      lastEvents() should be(Set(RequestOne, OnNext(0))) // One zero is still in the pipeline
+      lastEvents() should be(Set(RequestOne, OnNext(2))) // One element is still in the pipeline
 
       downstream.requestOne()
-      lastEvents() should be(Set(OnNext(1)))
+      lastEvents() should be(Set(OnNext(10)))
 
       downstream.requestOne()
-      lastEvents() should be(Set(OnNext(1)))
+      lastEvents() should be(Set(OnNext(11)))
 
       upstream.onComplete()
-      lastEvents() should be(Set(OnComplete))
+      downstream.requestOne()
+      // This is correct! If you don't believe it, run the interpreter with Debug on
+      lastEvents() should be(Set(OnComplete, OnNext(12)))
     }
 
     "implement conflate-expand" in new TestSetup(Seq(
diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowExpandSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowExpandSpec.scala
index c9d6837dad..3eb7bf64b8 100644
--- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowExpandSpec.scala
+++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowExpandSpec.scala
@@ -66,6 +66,28 @@ class FlowExpandSpec extends AkkaSpec {
       sub.cancel()
     }
 
+    "not drop the last element" in {
+      val publisher = StreamTestKit.PublisherProbe[Int]()
+      val subscriber = StreamTestKit.SubscriberProbe[Int]()
+
+      // Simply repeat the last element as an extrapolation step
+      Source(publisher).expand(seed = i ⇒ i)(extrapolate = i ⇒ (i, i)).runWith(Sink(subscriber))
+
+      val autoPublisher = new StreamTestKit.AutoPublisher(publisher)
+      val sub = subscriber.expectSubscription()
+
+      autoPublisher.sendNext(1)
+      sub.request(1)
+      subscriber.expectNext(1)
+
+      autoPublisher.sendNext(2)
+      autoPublisher.sendComplete()
+
+      sub.request(1)
+      subscriber.expectNext(2)
+      subscriber.expectComplete()
+    }
+
     "work on a variable rate chain" in {
       val future = Source(1 to 100)
         .map { i ⇒ if (ThreadLocalRandom.current().nextBoolean()) Thread.sleep(10); i }
diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Interpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Interpreter.scala
index f10fc03b65..4bbb7482ee 100644
--- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Interpreter.scala
+++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Interpreter.scala
@@ -36,7 +36,7 @@ private[akka] abstract class BoundaryStage extends AbstractStage[Any, Any, Direc
  * INTERNAL API
  */
 private[akka] object OneBoundedInterpreter {
-  final val PhantomDirective = null
+  final val Debug = false
 
   /**
    * INTERNAL API
    */
@@ -142,7 +142,7 @@ private[akka] class OneBoundedInterpreter(ops: Seq[Stage[_, _]], val forkLimit:
   type UntypedOp = AbstractStage[Any, Any, Directive, Directive, Context[Any]]
   require(ops.nonEmpty, "OneBoundedInterpreter cannot be created without at least one Op")
 
-  private val pipeline: Array[UntypedOp] = ops.map(_.asInstanceOf[UntypedOp])(breakOut)
+  private final val pipeline: Array[UntypedOp] = ops.map(_.asInstanceOf[UntypedOp])(breakOut)
 
   /**
    * This table is used to accelerate demand propagation upstream. All ops that implement PushStage are guaranteed
@@ -152,10 +152,10 @@ private[akka] class OneBoundedInterpreter(ops: Seq[Stage[_, _]], val forkLimit:
    * This table maintains the positions where execution should jump from a current position when a pull event is to
    * be executed.
    */
-  private val jumpBacks: Array[Int] = calculateJumpBacks
+  private final val jumpBacks: Array[Int] = calculateJumpBacks
 
-  private val Upstream = 0
-  private val Downstream = pipeline.length - 1
+  private final val Upstream = 0
+  private final val Downstream = pipeline.length - 1
 
   // Var to hold the current element if pushing. The only reason why this var is needed is to avoid allocations and
   // make it possible for the Pushing state to be an object
@@ -210,20 +210,20 @@ private[akka] class OneBoundedInterpreter(ops: Seq[Stage[_, _]], val forkLimit:
       currentOp.allowedToPush = false
       elementInFlight = elem
       state = Pushing
-      PhantomDirective
+      null
     }
 
     override def pull(): UpstreamDirective = {
       if (currentOp.holding) throw new IllegalStateException("Cannot pull while holding, only pushAndPull")
       currentOp.allowedToPush = !currentOp.isInstanceOf[DetachedStage[_, _]]
       state = Pulling
-      PhantomDirective
+      null
     }
 
     override def finish(): FreeDirective = {
       fork(Completing)
       state = Cancelling
-      PhantomDirective
+      null
     }
 
     def isFinishing: Boolean = currentOp.terminationPending
@@ -244,7 +244,7 @@ private[akka] class OneBoundedInterpreter(ops: Seq[Stage[_, _]], val forkLimit:
     override def fail(cause: Throwable): FreeDirective = {
       fork(Failing(cause))
       state = Cancelling
-      PhantomDirective
+      null
     }
 
     override def hold(): FreeDirective = {
@@ -260,7 +260,7 @@ private[akka] class OneBoundedInterpreter(ops: Seq[Stage[_, _]], val forkLimit:
       currentOp.holding = false
       fork(Pushing, elem)
       state = Pulling
-      PhantomDirective
+      null
     }
 
     override def absorbTermination(): TerminationDirective = {
@@ -271,32 +271,32 @@ private[akka] class OneBoundedInterpreter(ops: Seq[Stage[_, _]], val forkLimit:
     override def exit(): FreeDirective = {
       elementInFlight = null
       activeOpIndex = -1
-      PhantomDirective
+      null
     }
   }
 
-  private object Pushing extends State {
+  private final val Pushing: State = new State {
     override def advance(): Unit = activeOpIndex += 1
     override def run(): Unit = currentOp.onPush(elementInFlight, ctx = this)
   }
 
-  private object PushFinish extends State {
+  private final val PushFinish: State = new State {
     override def advance(): Unit = activeOpIndex += 1
     override def run(): Unit = currentOp.onPush(elementInFlight, ctx = this)
 
     override def pushAndFinish(elem: Any): DownstreamDirective = {
       elementInFlight = elem
       state = PushFinish
-      PhantomDirective
+      null
     }
 
     override def finish(): FreeDirective = {
       state = Completing
-      PhantomDirective
+      null
     }
   }
 
-  private object Pulling extends State {
+  private final val Pulling: State = new State {
     override def advance(): Unit = {
       elementInFlight = null
       activeOpIndex = jumpBacks(activeOpIndex)
@@ -310,7 +310,7 @@ private[akka] class OneBoundedInterpreter(ops: Seq[Stage[_, _]], val forkLimit:
     }
   }
 
-  private object Completing extends State {
+  private final val Completing: State = new State {
     override def advance(): Unit = {
       elementInFlight = null
       pipeline(activeOpIndex) = Finished.asInstanceOf[UntypedOp]
@@ -324,7 +324,7 @@ private[akka] class OneBoundedInterpreter(ops: Seq[Stage[_, _]], val forkLimit:
 
     override def finish(): FreeDirective = {
       state = Completing
-      PhantomDirective
+      null
     }
 
     override def absorbTermination(): TerminationDirective = {
@@ -333,11 +333,11 @@ private[akka] class OneBoundedInterpreter(ops: Seq[Stage[_, _]], val forkLimit:
       // FIXME: This state is potentially corrupted by the jumpBackTable (not updated when jumping over)
       if (currentOp.allowedToPush) currentOp.onPull(ctx = Pulling) else exit()
-      PhantomDirective
+      null
     }
   }
 
-  private object Cancelling extends State {
+  private final val Cancelling: State = new State {
     override def advance(): Unit = {
       elementInFlight = null
       pipeline(activeOpIndex) = Finished.asInstanceOf[UntypedOp]
@@ -351,7 +351,7 @@ private[akka] class OneBoundedInterpreter(ops: Seq[Stage[_, _]], val forkLimit:
 
     override def finish(): FreeDirective = {
       state = Cancelling
-      PhantomDirective
+      null
     }
   }
 
@@ -369,15 +369,31 @@ private[akka] class OneBoundedInterpreter(ops: Seq[Stage[_, _]], val forkLimit:
       currentOp.holding = false
       if (currentOp.allowedToPush) currentOp.onPull(ctx = Pulling) else exit()
-      PhantomDirective
+      null
     }
   }
 
   private def inside: Boolean = activeOpIndex > -1 && activeOpIndex < pipeline.length
 
+  private def printDebug(): Unit = {
+    val padding = " " * activeOpIndex
+    val icon: String = state match {
+      case Pushing | PushFinish ⇒ padding + s"---> $elementInFlight"
+      case Pulling ⇒
+        (" " * jumpBacks(activeOpIndex)) +
+          "<---" +
+          ("----" * (activeOpIndex - jumpBacks(activeOpIndex) - 1))
+      case Completing ⇒ padding + "---|"
+      case Cancelling ⇒ padding + "|---"
+      case Failing(e) ⇒ padding + s"---X ${e.getMessage}"
+    }
+    println(icon)
+  }
+
   @tailrec private def execute(): Unit = {
     while (inside) {
       try {
+        if (Debug) printDebug()
         state.progress()
       } catch {
         case NonFatal(e) if lastOpFailing != activeOpIndex ⇒
@@ -451,42 +467,42 @@ private[akka] class OneBoundedInterpreter(ops: Seq[Stage[_, _]], val forkLimit:
       activeOpIndex = entryPoint
       super.push(elem)
       execute()
-      PhantomDirective
+      null
     }
 
     override def pull(): UpstreamDirective = {
       activeOpIndex = entryPoint
       super.pull()
       execute()
-      PhantomDirective
+      null
     }
 
     override def finish(): FreeDirective = {
       activeOpIndex = entryPoint
       super.finish()
       execute()
-      PhantomDirective
+      null
     }
 
     override def fail(cause: Throwable): FreeDirective = {
       activeOpIndex = entryPoint
       super.fail(cause)
       execute()
-      PhantomDirective
+      null
     }
 
     override def hold(): FreeDirective = {
      activeOpIndex = entryPoint
       super.hold()
       execute()
-      PhantomDirective
+      null
     }
 
     override def pushAndPull(elem: Any): FreeDirective = {
       activeOpIndex = entryPoint
       super.pushAndPull(elem)
       execute()
-      PhantomDirective
+      null
     }
   }
 }
diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala
index b1f9b0d030..83e2a4ecf1 100644
--- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala
+++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala
@@ -267,25 +267,39 @@ private[akka] final case class Conflate[In, Out](seed: In ⇒ Out, aggregate: (O
  * INTERNAL API
  */
 private[akka] final case class Expand[In, Out, Seed](seed: In ⇒ Seed, extrapolate: Seed ⇒ (Out, Seed)) extends DetachedStage[In, Out] {
-  private var s: Any = null
+  private var s: Seed = _
+  private var started: Boolean = false
+  private var expanded: Boolean = false
 
   override def onPush(elem: In, ctx: DetachedContext[Out]): UpstreamDirective = {
     s = seed(elem)
+    started = true
+    expanded = false
     if (ctx.isHolding) {
-      val (emit, newS) = extrapolate(s.asInstanceOf[Seed])
+      val (emit, newS) = extrapolate(s)
       s = newS
+      expanded = true
       ctx.pushAndPull(emit)
     } else ctx.hold()
   }
 
   override def onPull(ctx: DetachedContext[Out]): DownstreamDirective = {
-    if (s == null) ctx.hold()
+    if (ctx.isFinishing) {
+      if (!started) ctx.finish()
+      else ctx.pushAndFinish(extrapolate(s)._1)
+    } else if (!started) ctx.hold()
     else {
-      val (emit, newS) = extrapolate(s.asInstanceOf[Seed])
+      val (emit, newS) = extrapolate(s)
       s = newS
+      expanded = true
      if (ctx.isHolding) ctx.pushAndPull(emit)
       else ctx.push(emit)
     }
   }
+
+  override def onUpstreamFinish(ctx: DetachedContext[Out]): TerminationDirective = {
+    if (expanded) ctx.finish()
+    else ctx.absorbTermination()
+  }
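For context, a minimal self-contained Scala sketch (not Akka code; the names Expander and safeToFinish are made up for illustration) of the termination rule the Ops.scala change encodes: when upstream completes, Expand may finish immediately only if the latest seeded element has already been extrapolated at least once; otherwise termination is absorbed so the element is still emitted instead of being dropped.

    // Simplified model of Expand's bookkeeping: the current seed plus an `expanded`
    // flag recording whether that seed has been emitted downstream at least once.
    object ExpandLastElementSketch extends App {
      final class Expander[S, O](seed: Int => S, extrapolate: S => (O, S)) {
        private var state: Option[S] = None
        private var expanded = false

        def onPush(elem: Int): Unit = { state = Some(seed(elem)); expanded = false }

        def onPull(): Option[O] = state.map { s =>
          val (out, next) = extrapolate(s)
          state = Some(next)
          expanded = true
          out
        }

        // Upstream completed: finishing is safe only if nothing is pending or the
        // pending element was already emitted; otherwise it must not be dropped.
        def safeToFinish: Boolean = expanded || state.isEmpty
      }

      val e = new Expander[Int, Int](identity, s => (s, s)) // repeat-last-element extrapolation
      e.onPush(1)
      println(e.onPull())     // Some(1)
      e.onPush(2)
      println(e.safeToFinish) // false: 2 has not been emitted yet, so termination must be absorbed
      println(e.onPull())     // Some(2): the last element is still delivered before completing
    }

This mirrors the patched stage's expanded flag: onUpstreamFinish calls ctx.finish() only when the current element was already extrapolated, and otherwise ctx.absorbTermination(), so the next onPull can deliver it via ctx.pushAndFinish.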