fold async emits zero given an empty stream #21562

This commit is contained in:
Vsevolod Belousov 2016-10-17 08:02:54 +01:00 committed by Johan Andrén
parent 0b9a5026a0
commit 0e60020c58
3 changed files with 132 additions and 52 deletions

View file

@ -3,22 +3,20 @@
*/
package akka.stream.scaladsl
import scala.util.control.NoStackTrace
import scala.concurrent.{ Await, Future }
import scala.concurrent.duration._
import akka.NotUsed
import akka.stream.ActorMaterializer
import akka.stream.ActorAttributes.supervisionStrategy
import akka.stream.Supervision.{ restartingDecider, resumingDecider }
import akka.stream.ActorMaterializer
import akka.stream.Supervision.{restartingDecider, resumingDecider}
import akka.stream.impl.ReactiveStreamsCompliance
import akka.testkit.{ AkkaSpec, TestLatch }
import akka.stream.testkit._, Utils._
import akka.stream.testkit.Utils._
import akka.stream.testkit._
import akka.testkit.TestLatch
import org.scalatest.concurrent.PatienceConfiguration.Timeout
import scala.concurrent.duration._
import scala.concurrent.{Await, Future}
import scala.util.control.NoStackTrace
class FlowFoldAsyncSpec extends StreamSpec {
implicit val materializer = ActorMaterializer()
implicit def ec = materializer.executionContext
@ -261,6 +259,22 @@ class FlowFoldAsyncSpec extends StreamSpec {
upstream.expectCancellation()
}
"complete future and return zero given an empty stream" in assertAllStagesStopped {
val futureValue =
Source.fromIterator[Int](() Iterator.empty)
.runFoldAsync(0)((acc, elem) Future.successful(acc + elem))
Await.result(futureValue, remainingOrDefault) should be(0)
}
"complete future and return zero + item given a stream of one item" in assertAllStagesStopped {
val futureValue =
Source.single(100)
.runFoldAsync(5)((acc, elem) Future.successful(acc + elem))
Await.result(futureValue, remainingOrDefault) should be(105)
}
}
// Keep

View file

@ -56,6 +56,14 @@ class FlowFoldSpec extends StreamSpec {
the[Exception] thrownBy Await.result(future, 3.seconds) should be(error)
}
"complete future and return zero given an empty stream" in assertAllStagesStopped {
val futureValue =
Source.fromIterator[Int](() Iterator.empty)
.runFold(0)(_ + _)
Await.result(futureValue, 3.seconds) should be(0)
}
}
}

View file

@ -30,6 +30,7 @@ final case class Map[In, Out](f: In ⇒ Out) extends GraphStage[FlowShape[In, Ou
val in = Inlet[In]("Map.in")
val out = Outlet[Out]("Map.out")
override val shape = FlowShape(in, out)
override def initialAttributes: Attributes = DefaultAttributes.map
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
@ -131,6 +132,7 @@ final case class DropWhile[T](p: T ⇒ Boolean) extends GraphStage[FlowShape[T,
val in = Inlet[T]("DropWhile.in")
val out = Outlet[T]("DropWhile.out")
override val shape = FlowShape(in, out)
override def initialAttributes: Attributes = DefaultAttributes.dropWhile
def createLogic(inheritedAttributes: Attributes) = new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler {
@ -150,9 +152,12 @@ final case class DropWhile[T](p: T ⇒ Boolean) extends GraphStage[FlowShape[T,
}
override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in)
override def onPull(): Unit = pull(in)
setHandlers(in, out, this)
}
override def toString = "DropWhile"
}
@ -163,7 +168,9 @@ abstract class SupervisedGraphStageLogic(inheritedAttributes: Attributes, shape:
private lazy val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider)
def withSupervision[T](f: () T): Option[T] =
try { Some(f()) } catch {
try {
Some(f())
} catch {
case NonFatal(ex)
decider(ex) match {
case Supervision.Stop onStop(ex)
@ -174,7 +181,9 @@ abstract class SupervisedGraphStageLogic(inheritedAttributes: Attributes, shape:
}
def onResume(t: Throwable): Unit
def onStop(t: Throwable): Unit = failStage(t)
def onRestart(t: Throwable): Unit = onResume(t)
}
@ -192,10 +201,13 @@ final case class Collect[In, Out](pf: PartialFunction[In, Out]) extends GraphSta
val in = Inlet[In]("Collect.in")
val out = Outlet[Out]("Collect.out")
override val shape = FlowShape(in, out)
override def initialAttributes: Attributes = DefaultAttributes.collect
def createLogic(inheritedAttributes: Attributes) = new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler {
import Collect.NotApplied
val wrappedPf = () pf.applyOrElse(grab(in), NotApplied)
override def onPush(): Unit = withSupervision(wrappedPf) match {
@ -207,9 +219,12 @@ final case class Collect[In, Out](pf: PartialFunction[In, Out]) extends GraphSta
}
override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in)
override def onPull(): Unit = pull(in)
setHandlers(in, out, this)
}
override def toString = "Collect"
}
@ -224,6 +239,7 @@ final case class Recover[T](pf: PartialFunction[Throwable, T]) extends GraphStag
override protected val initialAttributes: Attributes = DefaultAttributes.recover
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler {
import Collect.NotApplied
var recovered: Option[T] = None
@ -320,6 +336,7 @@ final case class Scan[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends GraphSta
override val shape = FlowShape[In, Out](Inlet("Scan.in"), Outlet("Scan.out"))
override def initialAttributes: Attributes = DefaultAttributes.scan
override def toString: String = "Scan"
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
@ -342,6 +359,7 @@ final case class Scan[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends GraphSta
setHandler(in, new InHandler {
override def onPush(): Unit = ()
override def onUpstreamFinish(): Unit = setHandler(out, new OutHandler {
override def onPull(): Unit = {
push(out, aggregator)
@ -351,6 +369,7 @@ final case class Scan[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends GraphSta
})
override def onPull(): Unit = pull(in)
override def onPush(): Unit = {
try {
aggregator = f(aggregator, grab(in))
@ -426,6 +445,7 @@ final case class Fold[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends GraphSta
* INTERNAL API
*/
final class FoldAsync[In, Out](zero: Out, f: (Out, In) Future[Out]) extends GraphStage[FlowShape[In, Out]] {
import akka.dispatch.ExecutionContexts
val in = Inlet[In]("FoldAsync.in")
@ -436,30 +456,29 @@ final class FoldAsync[In, Out](zero: Out, f: (Out, In) ⇒ Future[Out]) extends
override val initialAttributes = DefaultAttributes.foldAsync
def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler {
val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider)
def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
new GraphStageLogic(shape) with InHandler with OutHandler {
val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider)
private var aggregator: Out = zero
private var aggregating: Future[Out] = Future.successful(aggregator)
private var aggregator: Out = zero
private var aggregating: Future[Out] = Future.successful(aggregator)
private def onRestart(t: Throwable): Unit = {
aggregator = zero
}
private def onRestart(t: Throwable): Unit = {
aggregator = zero
}
private def ec = ExecutionContexts.sameThreadExecutionContext
private def ec = ExecutionContexts.sameThreadExecutionContext
private val futureCB = getAsyncCallback[Try[Out]]((result: Try[Out]) {
result match {
case Success(update) if update != null {
private val futureCB = getAsyncCallback[Try[Out]] {
case Success(update) if update != null
aggregator = update
if (isClosed(in)) {
push(out, update)
completeStage()
} else if (isAvailable(out) && !hasBeenPulled(in)) tryPull(in)
}
case other {
case other
val ex = other match {
case Failure(t) t
case Success(s) if s == null
@ -476,42 +495,45 @@ final class FoldAsync[In, Out](zero: Out, f: (Out, In) ⇒ Future[Out]) extends
completeStage()
} else if (isAvailable(out) && !hasBeenPulled(in)) tryPull(in)
}
}.invoke _
def onPush(): Unit = {
try {
aggregating = f(aggregator, grab(in))
handleAggregatingValue()
} catch {
case NonFatal(ex) decider(ex) match {
case Supervision.Stop failStage(ex)
case supervision {
supervision match {
case Supervision.Restart onRestart(ex)
case _ () // just ignore on Resume
}
tryPull(in)
}
}
}
}
}).invoke _
def onPush(): Unit = {
try {
aggregating = f(aggregator, grab(in))
override def onUpstreamFinish(): Unit = {
handleAggregatingValue()
}
def onPull(): Unit = if (!hasBeenPulled(in)) tryPull(in)
private def handleAggregatingValue(): Unit = {
aggregating.value match {
case Some(result) futureCB(result) // already completed
case _ aggregating.onComplete(futureCB)(ec)
}
} catch {
case NonFatal(ex) decider(ex) match {
case Supervision.Stop failStage(ex)
case supervision {
supervision match {
case Supervision.Restart onRestart(ex)
case _ () // just ignore on Resume
}
tryPull(in)
}
}
}
setHandlers(in, out, this)
override def toString =
s"FoldAsync.Logic(completed=${aggregating.isCompleted})"
}
override def onUpstreamFinish(): Unit = {}
def onPull(): Unit = if (!hasBeenPulled(in)) tryPull(in)
setHandlers(in, out, this)
override def toString =
s"FoldAsync.Logic(completed=${aggregating.isCompleted})"
}
}
/**
@ -619,6 +641,7 @@ final case class LimitWeighted[T](val n: Long, val costFn: T ⇒ Long) extends G
val in = Inlet[T]("LimitWeighted.in")
val out = Outlet[T]("LimitWeighted.out")
override val shape = FlowShape(in, out)
override def initialAttributes: Attributes = DefaultAttributes.limitWeighted
def createLogic(inheritedAttributes: Attributes) = new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler {
@ -633,14 +656,19 @@ final case class LimitWeighted[T](val n: Long, val costFn: T ⇒ Long) extends G
case None //do nothing
}
}
override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in)
override def onRestart(t: Throwable): Unit = {
left = n
if (!hasBeenPulled(in)) pull(in)
}
override def onPull(): Unit = pull(in)
setHandlers(in, out, this)
}
override def toString = "LimitWeighted"
}
@ -892,6 +920,7 @@ final class Expand[In, Out](val extrapolate: In ⇒ Iterator[Out]) extends Graph
private val out = Outlet[Out]("expand.out")
override def initialAttributes = DefaultAttributes.expand
override val shape = FlowShape(in, out)
override def createLogic(attr: Attributes) = new GraphStageLogic(shape) with InHandler with OutHandler {
@ -941,6 +970,7 @@ final class Expand[In, Out](val extrapolate: In ⇒ Iterator[Out]) extends Graph
* INTERNAL API
*/
private[akka] object MapAsync {
final class Holder[T](var elem: Try[T], val cb: AsyncCallback[Holder[T]]) extends (Try[T] Unit) {
def setElem(t: Try[T]): Unit =
elem = t match {
@ -953,6 +983,7 @@ private[akka] object MapAsync {
cb.invoke(this)
}
}
val NotYetThere = Failure(new Exception)
}
@ -968,6 +999,7 @@ final case class MapAsync[In, Out](parallelism: Int, f: In ⇒ Future[Out])
private val out = Outlet[Out]("MapAsync.out")
override def initialAttributes = DefaultAttributes.mapAsync
override val shape = FlowShape(in, out)
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
@ -984,6 +1016,7 @@ final case class MapAsync[In, Out](parallelism: Int, f: In ⇒ Future[Out])
case _ if (isAvailable(out)) pushOne()
}
}
val futureCB = getAsyncCallback[Holder[Out]](holderCompleted)
private[this] def todo = buffer.used
@ -1023,6 +1056,7 @@ final case class MapAsync[In, Out](parallelism: Int, f: In ⇒ Future[Out])
}
if (todo < parallelism && !hasBeenPulled(in)) tryPull(in)
}
override def onUpstreamFinish(): Unit = if (todo == 0) completeStage()
override def onPull(): Unit = pushOne()
@ -1041,6 +1075,7 @@ final case class MapAsyncUnordered[In, Out](parallelism: Int, f: In ⇒ Future[O
private val out = Outlet[Out]("MapAsyncUnordered.out")
override def initialAttributes = DefaultAttributes.mapAsyncUnordered
override val shape = FlowShape(in, out)
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
@ -1052,6 +1087,7 @@ final case class MapAsyncUnordered[In, Out](parallelism: Int, f: In ⇒ Future[O
private var inFlight = 0
private var buffer: BufferImpl[Out] = _
private[this] def todo = inFlight + buffer.used
override def preStart(): Unit = buffer = BufferImpl(parallelism, materializer)
@ -1074,6 +1110,7 @@ final case class MapAsyncUnordered[In, Out](parallelism: Int, f: In ⇒ Future[O
else if (!hasBeenPulled(in)) tryPull(in)
}
}
private val futureCB = getAsyncCallback(futureCompleted)
private val invokeFutureCB: Try[Out] Unit = futureCB.invoke
@ -1119,6 +1156,7 @@ final case class Log[T](
// TODO more optimisations can be done here - prepare logOnPush function etc
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
new GraphStageLogic(shape) with OutHandler with InHandler {
import Log._
private var logLevels: LogLevels = _
@ -1221,9 +1259,13 @@ private[akka] object Log {
* INTERNAL API
*/
private[stream] object TimerKeys {
case object TakeWithinTimerKey
case object DropWithinTimerKey
case object GroupedWithinTimerKey
}
final class GroupedWithin[T](val n: Int, val d: FiniteDuration) extends GraphStage[FlowShape[T, immutable.Seq[T]]] {
@ -1232,7 +1274,9 @@ final class GroupedWithin[T](val n: Int, val d: FiniteDuration) extends GraphSta
val in = Inlet[T]("in")
val out = Outlet[immutable.Seq[T]]("out")
override def initialAttributes = DefaultAttributes.groupedWithin
val shape = FlowShape(in, out)
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) with InHandler with OutHandler {
@ -1304,7 +1348,9 @@ final class GroupedWithin[T](val n: Int, val d: FiniteDuration) extends GraphSta
final class Delay[T](val d: FiniteDuration, val strategy: DelayOverflowStrategy) extends SimpleLinearGraphStage[T] {
private[this] def timerName = "DelayedTimer"
override def initialAttributes: Attributes = DefaultAttributes.delay
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) with InHandler with OutHandler {
val size =
inheritedAttributes.get[InputBuffer] match {
@ -1418,6 +1464,7 @@ final class TakeWithin[T](val timeout: FiniteDuration) extends SimpleLinearGraph
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) with InHandler with OutHandler {
def onPush(): Unit = push(out, grab(in))
def onPull(): Unit = pull(in)
setHandler(in, this)
@ -1461,7 +1508,8 @@ final class DropWithin[T](val timeout: FiniteDuration) extends SimpleLinearGraph
final class Reduce[T](val f: (T, T) T) extends SimpleLinearGraphStage[T] {
override def initialAttributes: Attributes = DefaultAttributes.reduce
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler { self
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler {
self
override def toString = s"Reduce.Logic(aggregator=$aggregator)"
var aggregator: T = _
@ -1505,6 +1553,7 @@ private[stream] object RecoverWith {
final class RecoverWith[T, M](val maximumRetries: Int, val pf: PartialFunction[Throwable, Graph[SourceShape[T], M]]) extends SimpleLinearGraphStage[T] {
require(maximumRetries >= -1, "number of retries must be non-negative or equal to -1")
override def initialAttributes = DefaultAttributes.recoverWith
override def createLogic(attr: Attributes) = new GraphStageLogic(shape) {
@ -1512,6 +1561,7 @@ final class RecoverWith[T, M](val maximumRetries: Int, val pf: PartialFunction[T
setHandler(in, new InHandler {
override def onPush(): Unit = push(out, grab(in))
override def onUpstreamFailure(ex: Throwable) = onFailure(ex)
})
@ -1531,12 +1581,15 @@ final class RecoverWith[T, M](val maximumRetries: Int, val pf: PartialFunction[T
sinkIn.setHandler(new InHandler {
override def onPush(): Unit = push(out, sinkIn.grab())
override def onUpstreamFinish(): Unit = completeStage()
override def onUpstreamFailure(ex: Throwable) = onFailure(ex)
})
val outHandler = new OutHandler {
override def onPull(): Unit = sinkIn.pull()
override def onDownstreamFinish(): Unit = sinkIn.cancel()
}
@ -1556,13 +1609,16 @@ final class StatefulMapConcat[In, Out](val f: () ⇒ In ⇒ immutable.Iterable[O
val in = Inlet[In]("StatefulMapConcat.in")
val out = Outlet[Out]("StatefulMapConcat.out")
override val shape = FlowShape(in, out)
override def initialAttributes: Attributes = DefaultAttributes.statefulMapConcat
def createLogic(inheritedAttributes: Attributes) = new GraphStageLogic(shape) with InHandler with OutHandler {
lazy val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider)
var currentIterator: Iterator[Out] = _
var plainFun = f()
def hasNext = if (currentIterator != null) currentIterator.hasNext else false
setHandlers(in, out, this)
def pushPull(): Unit =
@ -1590,6 +1646,7 @@ final class StatefulMapConcat[In, Out](val f: () ⇒ In ⇒ immutable.Iterable[O
}
override def onUpstreamFinish(): Unit = onFinish()
override def onPull(): Unit = pushPull()
private def restartState(): Unit = {
@ -1597,6 +1654,7 @@ final class StatefulMapConcat[In, Out](val f: () ⇒ In ⇒ immutable.Iterable[O
currentIterator = null
}
}
override def toString = "StatefulMapConcat"
}