diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowScanAsyncSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowScanAsyncSpec.scala index 30c89a4493..a858830a01 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowScanAsyncSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowScanAsyncSpec.scala @@ -11,10 +11,12 @@ import akka.stream.testkit.TestSubscriber.Probe import akka.stream.testkit.Utils.TE import akka.stream.testkit._ import akka.stream.testkit.scaladsl._ +import akka.stream.{ ActorAttributes, ActorMaterializer, Supervision } import scala.collection.immutable -import scala.concurrent.Future +import scala.concurrent.{ Future, Promise } import scala.concurrent.duration._ +import scala.util.{ Failure, Success } class FlowScanAsyncSpec extends StreamSpec { @@ -48,6 +50,21 @@ class FlowScanAsyncSpec extends StreamSpec { sub.expectComplete() } + "complete after stream has been consumed and pending futures resolved" in { + val (pub, sub) = + TestSource.probe[Int] + .via(Flow[Int].scanAsync(0)((acc, in) ⇒ Future.successful(acc + in))) + .toMat(TestSink.probe)(Keep.both) + .run() + + pub.sendNext(1) + sub.request(10) + sub.expectNext(0) + sub.expectNext(1) + pub.sendComplete() + sub.expectComplete() + } + "fail after zero-element has been consumed" in { val (pub, sub) = TestSource.probe[Int] @@ -118,6 +135,28 @@ class FlowScanAsyncSpec extends StreamSpec { .expectNext(1, 1) .expectComplete() } + + "skip error values and handle stage completion after future get resolved" in { + val promises = Promise[Int].success(1) :: Promise[Int] :: Nil + val (pub, sub) = whenEventualFuture(promises, 0, decider = Supervision.restartingDecider) + pub.sendNext(0) + sub.expectNext(0, 1) + pub.sendNext(1) + promises(1).complete(Failure(TE("bang"))) + pub.sendComplete() + sub.expectComplete() + } + + "skip error values and handle stage completion before future get resolved" in { + val promises = Promise[Int].success(1) :: Promise[Int] :: Nil + val (pub, sub) = whenEventualFuture(promises, 0, decider = Supervision.restartingDecider) + pub.sendNext(0) + sub.expectNext(0, 1) + pub.sendNext(1) + pub.sendComplete() + promises(1).complete(Failure(TE("bang"))) + sub.expectComplete() + } } "with the resuming decider" should { @@ -134,6 +173,28 @@ class FlowScanAsyncSpec extends StreamSpec { .expectNext(1, 2) .expectComplete() } + + "skip error values and handle stage completion after future get resolved" in { + val promises = Promise[Int].success(1) :: Promise[Int] :: Nil + val (pub, sub) = whenEventualFuture(promises, 0, decider = Supervision.resumingDecider) + pub.sendNext(0) + sub.expectNext(0, 1) + pub.sendNext(1) + promises(1).complete(Failure(TE("bang"))) + pub.sendComplete() + sub.expectComplete() + } + + "skip error values and handle stage completion before future get resolved" in { + val promises = Promise[Int].success(1) :: Promise[Int] :: Nil + val (pub, sub) = whenEventualFuture(promises, 0, decider = Supervision.resumingDecider) + pub.sendNext(0) + sub.expectNext(0, 1) + pub.sendNext(1) + pub.sendComplete() + promises(1).complete(Failure(TE("bang"))) + sub.expectComplete() + } } "with the stopping decider" should { @@ -178,6 +239,27 @@ class FlowScanAsyncSpec extends StreamSpec { .expectNext(zero) } + def whenEventualFuture( + promises: immutable.Seq[Promise[Int]], + zero: Int, + decider: Supervision.Decider = Supervision.stoppingDecider + ): (TestPublisher.Probe[Int], TestSubscriber.Probe[Int]) = { + require(promises.nonEmpty, "must be at least one promise") + val promiseScanFlow = Flow[Int].scanAsync(zero) { (accumulator: Int, next: Int) ⇒ + promises(next).future + } + + val (pub, sub) = TestSource.probe[Int] + .via(promiseScanFlow) + .withAttributes(ActorAttributes.supervisionStrategy(decider)) + .toMat(TestSink.probe)(Keep.both) + .run() + + sub.request(promises.size + 1) + + (pub, sub) + } + def whenFailedFuture( elements: immutable.Seq[Int], zero: Int, diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala index 061dfd4799..382249a8fb 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala @@ -427,7 +427,7 @@ private[stream] object Collect { new GraphStageLogic(shape) with InHandler with OutHandler { self ⇒ private var current: Out = zero - private var eventualCurrent: Future[Out] = Future.successful(current) + private var elementHandled: Boolean = false private def ec = ExecutionContexts.sameThreadExecutionContext @@ -438,6 +438,7 @@ private[stream] object Collect { throw new IllegalStateException("No push should happen before zero value has been consumed") override def onPull(): Unit = { + elementHandled = true push(out, current) setHandlers(in, out, self) } @@ -452,21 +453,22 @@ private[stream] object Collect { private def onRestart(t: Throwable): Unit = { current = zero + elementHandled = false } private def safePull(): Unit = { - if (!hasBeenPulled(in)) { - tryPull(in) + if (isClosed(in)) { + completeStage() + } else if (isAvailable(out)) { + if (!hasBeenPulled(in)) { + tryPull(in) + } } } private def pushAndPullOrFinish(update: Out): Unit = { push(out, update) - if (isClosed(in)) { - completeStage() - } else if (isAvailable(out)) { - safePull() - } + safePull() } private def doSupervision(t: Throwable): Unit = { @@ -477,12 +479,14 @@ private[stream] object Collect { onRestart(t) safePull() } + elementHandled = true } private val futureCB = getAsyncCallback[Try[Out]] { case Success(next) if next != null ⇒ current = next pushAndPullOrFinish(next) + elementHandled = true case Success(null) ⇒ doSupervision(ReactiveStreamsCompliance.elementMustNotBeNullException) case Failure(t) ⇒ doSupervision(t) }.invoke _ @@ -493,7 +497,9 @@ private[stream] object Collect { def onPush(): Unit = { try { - eventualCurrent = f(current, grab(in)) + elementHandled = false + + val eventualCurrent = f(current, grab(in)) eventualCurrent.value match { case Some(result) ⇒ futureCB(result) @@ -507,21 +513,17 @@ private[stream] object Collect { case Supervision.Resume ⇒ () } tryPull(in) + elementHandled = true } } override def onUpstreamFinish(): Unit = { - if (current == zero) { - eventualCurrent.value match { - case Some(Success(`zero`)) ⇒ - // #24036 upstream completed without emitting anything but after zero was emitted downstream - completeStage() - case _ ⇒ // in all other cases we will get a complete when the future completes - } + if (elementHandled) { + completeStage() } } - override val toString: String = s"ScanAsync.Logic(completed=${eventualCurrent.isCompleted})" + override val toString: String = s"ScanAsync.Logic(completed=$elementHandled)" } }