From 0e60020c58ddc1bb19c6fcd7bb839d7901b496e6 Mon Sep 17 00:00:00 2001 From: Vsevolod Belousov Date: Mon, 17 Oct 2016 08:02:54 +0100 Subject: [PATCH] fold async emits zero given an empty stream #21562 --- .../stream/scaladsl/FlowFoldAsyncSpec.scala | 36 +++-- .../akka/stream/scaladsl/FlowFoldSpec.scala | 8 + .../scala/akka/stream/impl/fusing/Ops.scala | 140 +++++++++++++----- 3 files changed, 132 insertions(+), 52 deletions(-) diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowFoldAsyncSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowFoldAsyncSpec.scala index 411c692257..545a6ec8b6 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowFoldAsyncSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowFoldAsyncSpec.scala @@ -3,22 +3,20 @@ */ package akka.stream.scaladsl -import scala.util.control.NoStackTrace - -import scala.concurrent.{ Await, Future } -import scala.concurrent.duration._ - import akka.NotUsed -import akka.stream.ActorMaterializer import akka.stream.ActorAttributes.supervisionStrategy -import akka.stream.Supervision.{ restartingDecider, resumingDecider } +import akka.stream.ActorMaterializer +import akka.stream.Supervision.{restartingDecider, resumingDecider} import akka.stream.impl.ReactiveStreamsCompliance - -import akka.testkit.{ AkkaSpec, TestLatch } -import akka.stream.testkit._, Utils._ - +import akka.stream.testkit.Utils._ +import akka.stream.testkit._ +import akka.testkit.TestLatch import org.scalatest.concurrent.PatienceConfiguration.Timeout +import scala.concurrent.duration._ +import scala.concurrent.{Await, Future} +import scala.util.control.NoStackTrace + class FlowFoldAsyncSpec extends StreamSpec { implicit val materializer = ActorMaterializer() implicit def ec = materializer.executionContext @@ -261,6 +259,22 @@ class FlowFoldAsyncSpec extends StreamSpec { upstream.expectCancellation() } + + "complete future and return zero given an empty stream" in assertAllStagesStopped { + val futureValue = + Source.fromIterator[Int](() ⇒ Iterator.empty) + .runFoldAsync(0)((acc, elem) ⇒ Future.successful(acc + elem)) + + Await.result(futureValue, remainingOrDefault) should be(0) + } + + "complete future and return zero + item given a stream of one item" in assertAllStagesStopped { + val futureValue = + Source.single(100) + .runFoldAsync(5)((acc, elem) ⇒ Future.successful(acc + elem)) + + Await.result(futureValue, remainingOrDefault) should be(105) + } } // Keep diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowFoldSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowFoldSpec.scala index a71eda4145..4c0bac1783 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowFoldSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowFoldSpec.scala @@ -56,6 +56,14 @@ class FlowFoldSpec extends StreamSpec { the[Exception] thrownBy Await.result(future, 3.seconds) should be(error) } + "complete future and return zero given an empty stream" in assertAllStagesStopped { + val futureValue = + Source.fromIterator[Int](() ⇒ Iterator.empty) + .runFold(0)(_ + _) + + Await.result(futureValue, 3.seconds) should be(0) + } + } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala index 57d63a9ee6..7c92df8113 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala @@ -30,6 +30,7 @@ final case class Map[In, Out](f: In ⇒ Out) extends GraphStage[FlowShape[In, Ou val in = Inlet[In]("Map.in") val out = Outlet[Out]("Map.out") override val shape = FlowShape(in, out) + override def initialAttributes: Attributes = DefaultAttributes.map override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = @@ -131,6 +132,7 @@ final case class DropWhile[T](p: T ⇒ Boolean) extends GraphStage[FlowShape[T, val in = Inlet[T]("DropWhile.in") val out = Outlet[T]("DropWhile.out") override val shape = FlowShape(in, out) + override def initialAttributes: Attributes = DefaultAttributes.dropWhile def createLogic(inheritedAttributes: Attributes) = new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler { @@ -150,9 +152,12 @@ final case class DropWhile[T](p: T ⇒ Boolean) extends GraphStage[FlowShape[T, } override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in) + override def onPull(): Unit = pull(in) + setHandlers(in, out, this) } + override def toString = "DropWhile" } @@ -163,7 +168,9 @@ abstract class SupervisedGraphStageLogic(inheritedAttributes: Attributes, shape: private lazy val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider) def withSupervision[T](f: () ⇒ T): Option[T] = - try { Some(f()) } catch { + try { + Some(f()) + } catch { case NonFatal(ex) ⇒ decider(ex) match { case Supervision.Stop ⇒ onStop(ex) @@ -174,7 +181,9 @@ abstract class SupervisedGraphStageLogic(inheritedAttributes: Attributes, shape: } def onResume(t: Throwable): Unit + def onStop(t: Throwable): Unit = failStage(t) + def onRestart(t: Throwable): Unit = onResume(t) } @@ -192,10 +201,13 @@ final case class Collect[In, Out](pf: PartialFunction[In, Out]) extends GraphSta val in = Inlet[In]("Collect.in") val out = Outlet[Out]("Collect.out") override val shape = FlowShape(in, out) + override def initialAttributes: Attributes = DefaultAttributes.collect def createLogic(inheritedAttributes: Attributes) = new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler { + import Collect.NotApplied + val wrappedPf = () ⇒ pf.applyOrElse(grab(in), NotApplied) override def onPush(): Unit = withSupervision(wrappedPf) match { @@ -207,9 +219,12 @@ final case class Collect[In, Out](pf: PartialFunction[In, Out]) extends GraphSta } override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in) + override def onPull(): Unit = pull(in) + setHandlers(in, out, this) } + override def toString = "Collect" } @@ -224,6 +239,7 @@ final case class Recover[T](pf: PartialFunction[Throwable, T]) extends GraphStag override protected val initialAttributes: Attributes = DefaultAttributes.recover override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler { + import Collect.NotApplied var recovered: Option[T] = None @@ -320,6 +336,7 @@ final case class Scan[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends GraphSta override val shape = FlowShape[In, Out](Inlet("Scan.in"), Outlet("Scan.out")) override def initialAttributes: Attributes = DefaultAttributes.scan + override def toString: String = "Scan" override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = @@ -342,6 +359,7 @@ final case class Scan[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends GraphSta setHandler(in, new InHandler { override def onPush(): Unit = () + override def onUpstreamFinish(): Unit = setHandler(out, new OutHandler { override def onPull(): Unit = { push(out, aggregator) @@ -351,6 +369,7 @@ final case class Scan[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends GraphSta }) override def onPull(): Unit = pull(in) + override def onPush(): Unit = { try { aggregator = f(aggregator, grab(in)) @@ -426,6 +445,7 @@ final case class Fold[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends GraphSta * INTERNAL API */ final class FoldAsync[In, Out](zero: Out, f: (Out, In) ⇒ Future[Out]) extends GraphStage[FlowShape[In, Out]] { + import akka.dispatch.ExecutionContexts val in = Inlet[In]("FoldAsync.in") @@ -436,30 +456,29 @@ final class FoldAsync[In, Out](zero: Out, f: (Out, In) ⇒ Future[Out]) extends override val initialAttributes = DefaultAttributes.foldAsync - def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler { - val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider) + def createLogic(inheritedAttributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) with InHandler with OutHandler { + val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider) - private var aggregator: Out = zero - private var aggregating: Future[Out] = Future.successful(aggregator) + private var aggregator: Out = zero + private var aggregating: Future[Out] = Future.successful(aggregator) - private def onRestart(t: Throwable): Unit = { - aggregator = zero - } + private def onRestart(t: Throwable): Unit = { + aggregator = zero + } - private def ec = ExecutionContexts.sameThreadExecutionContext + private def ec = ExecutionContexts.sameThreadExecutionContext - private val futureCB = getAsyncCallback[Try[Out]]((result: Try[Out]) ⇒ { - result match { - case Success(update) if update != null ⇒ { + private val futureCB = getAsyncCallback[Try[Out]] { + case Success(update) if update != null ⇒ aggregator = update if (isClosed(in)) { push(out, update) completeStage() } else if (isAvailable(out) && !hasBeenPulled(in)) tryPull(in) - } - case other ⇒ { + case other ⇒ val ex = other match { case Failure(t) ⇒ t case Success(s) if s == null ⇒ @@ -476,42 +495,45 @@ final class FoldAsync[In, Out](zero: Out, f: (Out, In) ⇒ Future[Out]) extends completeStage() } else if (isAvailable(out) && !hasBeenPulled(in)) tryPull(in) } + }.invoke _ + + def onPush(): Unit = { + try { + aggregating = f(aggregator, grab(in)) + handleAggregatingValue() + } catch { + case NonFatal(ex) ⇒ decider(ex) match { + case Supervision.Stop ⇒ failStage(ex) + case supervision ⇒ { + supervision match { + case Supervision.Restart ⇒ onRestart(ex) + case _ ⇒ () // just ignore on Resume + } + + tryPull(in) + } + } } } - }).invoke _ - def onPush(): Unit = { - try { - aggregating = f(aggregator, grab(in)) + override def onUpstreamFinish(): Unit = { + handleAggregatingValue() + } + def onPull(): Unit = if (!hasBeenPulled(in)) tryPull(in) + + private def handleAggregatingValue(): Unit = { aggregating.value match { case Some(result) ⇒ futureCB(result) // already completed case _ ⇒ aggregating.onComplete(futureCB)(ec) } - } catch { - case NonFatal(ex) ⇒ decider(ex) match { - case Supervision.Stop ⇒ failStage(ex) - case supervision ⇒ { - supervision match { - case Supervision.Restart ⇒ onRestart(ex) - case _ ⇒ () // just ignore on Resume - } - - tryPull(in) - } - } } + + setHandlers(in, out, this) + + override def toString = + s"FoldAsync.Logic(completed=${aggregating.isCompleted})" } - - override def onUpstreamFinish(): Unit = {} - - def onPull(): Unit = if (!hasBeenPulled(in)) tryPull(in) - - setHandlers(in, out, this) - - override def toString = - s"FoldAsync.Logic(completed=${aggregating.isCompleted})" - } } /** @@ -619,6 +641,7 @@ final case class LimitWeighted[T](val n: Long, val costFn: T ⇒ Long) extends G val in = Inlet[T]("LimitWeighted.in") val out = Outlet[T]("LimitWeighted.out") override val shape = FlowShape(in, out) + override def initialAttributes: Attributes = DefaultAttributes.limitWeighted def createLogic(inheritedAttributes: Attributes) = new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler { @@ -633,14 +656,19 @@ final case class LimitWeighted[T](val n: Long, val costFn: T ⇒ Long) extends G case None ⇒ //do nothing } } + override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in) + override def onRestart(t: Throwable): Unit = { left = n if (!hasBeenPulled(in)) pull(in) } + override def onPull(): Unit = pull(in) + setHandlers(in, out, this) } + override def toString = "LimitWeighted" } @@ -892,6 +920,7 @@ final class Expand[In, Out](val extrapolate: In ⇒ Iterator[Out]) extends Graph private val out = Outlet[Out]("expand.out") override def initialAttributes = DefaultAttributes.expand + override val shape = FlowShape(in, out) override def createLogic(attr: Attributes) = new GraphStageLogic(shape) with InHandler with OutHandler { @@ -941,6 +970,7 @@ final class Expand[In, Out](val extrapolate: In ⇒ Iterator[Out]) extends Graph * INTERNAL API */ private[akka] object MapAsync { + final class Holder[T](var elem: Try[T], val cb: AsyncCallback[Holder[T]]) extends (Try[T] ⇒ Unit) { def setElem(t: Try[T]): Unit = elem = t match { @@ -953,6 +983,7 @@ private[akka] object MapAsync { cb.invoke(this) } } + val NotYetThere = Failure(new Exception) } @@ -968,6 +999,7 @@ final case class MapAsync[In, Out](parallelism: Int, f: In ⇒ Future[Out]) private val out = Outlet[Out]("MapAsync.out") override def initialAttributes = DefaultAttributes.mapAsync + override val shape = FlowShape(in, out) override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = @@ -984,6 +1016,7 @@ final case class MapAsync[In, Out](parallelism: Int, f: In ⇒ Future[Out]) case _ ⇒ if (isAvailable(out)) pushOne() } } + val futureCB = getAsyncCallback[Holder[Out]](holderCompleted) private[this] def todo = buffer.used @@ -1023,6 +1056,7 @@ final case class MapAsync[In, Out](parallelism: Int, f: In ⇒ Future[Out]) } if (todo < parallelism && !hasBeenPulled(in)) tryPull(in) } + override def onUpstreamFinish(): Unit = if (todo == 0) completeStage() override def onPull(): Unit = pushOne() @@ -1041,6 +1075,7 @@ final case class MapAsyncUnordered[In, Out](parallelism: Int, f: In ⇒ Future[O private val out = Outlet[Out]("MapAsyncUnordered.out") override def initialAttributes = DefaultAttributes.mapAsyncUnordered + override val shape = FlowShape(in, out) override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = @@ -1052,6 +1087,7 @@ final case class MapAsyncUnordered[In, Out](parallelism: Int, f: In ⇒ Future[O private var inFlight = 0 private var buffer: BufferImpl[Out] = _ + private[this] def todo = inFlight + buffer.used override def preStart(): Unit = buffer = BufferImpl(parallelism, materializer) @@ -1074,6 +1110,7 @@ final case class MapAsyncUnordered[In, Out](parallelism: Int, f: In ⇒ Future[O else if (!hasBeenPulled(in)) tryPull(in) } } + private val futureCB = getAsyncCallback(futureCompleted) private val invokeFutureCB: Try[Out] ⇒ Unit = futureCB.invoke @@ -1119,6 +1156,7 @@ final case class Log[T]( // TODO more optimisations can be done here - prepare logOnPush function etc override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with OutHandler with InHandler { + import Log._ private var logLevels: LogLevels = _ @@ -1221,9 +1259,13 @@ private[akka] object Log { * INTERNAL API */ private[stream] object TimerKeys { + case object TakeWithinTimerKey + case object DropWithinTimerKey + case object GroupedWithinTimerKey + } final class GroupedWithin[T](val n: Int, val d: FiniteDuration) extends GraphStage[FlowShape[T, immutable.Seq[T]]] { @@ -1232,7 +1274,9 @@ final class GroupedWithin[T](val n: Int, val d: FiniteDuration) extends GraphSta val in = Inlet[T]("in") val out = Outlet[immutable.Seq[T]]("out") + override def initialAttributes = DefaultAttributes.groupedWithin + val shape = FlowShape(in, out) override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) with InHandler with OutHandler { @@ -1304,7 +1348,9 @@ final class GroupedWithin[T](val n: Int, val d: FiniteDuration) extends GraphSta final class Delay[T](val d: FiniteDuration, val strategy: DelayOverflowStrategy) extends SimpleLinearGraphStage[T] { private[this] def timerName = "DelayedTimer" + override def initialAttributes: Attributes = DefaultAttributes.delay + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) with InHandler with OutHandler { val size = inheritedAttributes.get[InputBuffer] match { @@ -1418,6 +1464,7 @@ final class TakeWithin[T](val timeout: FiniteDuration) extends SimpleLinearGraph override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) with InHandler with OutHandler { def onPush(): Unit = push(out, grab(in)) + def onPull(): Unit = pull(in) setHandler(in, this) @@ -1461,7 +1508,8 @@ final class DropWithin[T](val timeout: FiniteDuration) extends SimpleLinearGraph final class Reduce[T](val f: (T, T) ⇒ T) extends SimpleLinearGraphStage[T] { override def initialAttributes: Attributes = DefaultAttributes.reduce - override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler { self ⇒ + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler { + self ⇒ override def toString = s"Reduce.Logic(aggregator=$aggregator)" var aggregator: T = _ @@ -1505,6 +1553,7 @@ private[stream] object RecoverWith { final class RecoverWith[T, M](val maximumRetries: Int, val pf: PartialFunction[Throwable, Graph[SourceShape[T], M]]) extends SimpleLinearGraphStage[T] { require(maximumRetries >= -1, "number of retries must be non-negative or equal to -1") + override def initialAttributes = DefaultAttributes.recoverWith override def createLogic(attr: Attributes) = new GraphStageLogic(shape) { @@ -1512,6 +1561,7 @@ final class RecoverWith[T, M](val maximumRetries: Int, val pf: PartialFunction[T setHandler(in, new InHandler { override def onPush(): Unit = push(out, grab(in)) + override def onUpstreamFailure(ex: Throwable) = onFailure(ex) }) @@ -1531,12 +1581,15 @@ final class RecoverWith[T, M](val maximumRetries: Int, val pf: PartialFunction[T sinkIn.setHandler(new InHandler { override def onPush(): Unit = push(out, sinkIn.grab()) + override def onUpstreamFinish(): Unit = completeStage() + override def onUpstreamFailure(ex: Throwable) = onFailure(ex) }) val outHandler = new OutHandler { override def onPull(): Unit = sinkIn.pull() + override def onDownstreamFinish(): Unit = sinkIn.cancel() } @@ -1556,13 +1609,16 @@ final class StatefulMapConcat[In, Out](val f: () ⇒ In ⇒ immutable.Iterable[O val in = Inlet[In]("StatefulMapConcat.in") val out = Outlet[Out]("StatefulMapConcat.out") override val shape = FlowShape(in, out) + override def initialAttributes: Attributes = DefaultAttributes.statefulMapConcat def createLogic(inheritedAttributes: Attributes) = new GraphStageLogic(shape) with InHandler with OutHandler { lazy val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider) var currentIterator: Iterator[Out] = _ var plainFun = f() + def hasNext = if (currentIterator != null) currentIterator.hasNext else false + setHandlers(in, out, this) def pushPull(): Unit = @@ -1590,6 +1646,7 @@ final class StatefulMapConcat[In, Out](val f: () ⇒ In ⇒ immutable.Iterable[O } override def onUpstreamFinish(): Unit = onFinish() + override def onPull(): Unit = pushPull() private def restartState(): Unit = { @@ -1597,6 +1654,7 @@ final class StatefulMapConcat[In, Out](val f: () ⇒ In ⇒ immutable.Iterable[O currentIterator = null } } + override def toString = "StatefulMapConcat" }