fold async emits zero given an empty stream #21562
This commit is contained in:
parent
0b9a5026a0
commit
0e60020c58
3 changed files with 132 additions and 52 deletions
|
|
@ -30,6 +30,7 @@ final case class Map[In, Out](f: In ⇒ Out) extends GraphStage[FlowShape[In, Ou
|
|||
val in = Inlet[In]("Map.in")
|
||||
val out = Outlet[Out]("Map.out")
|
||||
override val shape = FlowShape(in, out)
|
||||
|
||||
override def initialAttributes: Attributes = DefaultAttributes.map
|
||||
|
||||
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
|
||||
|
|
@ -131,6 +132,7 @@ final case class DropWhile[T](p: T ⇒ Boolean) extends GraphStage[FlowShape[T,
|
|||
val in = Inlet[T]("DropWhile.in")
|
||||
val out = Outlet[T]("DropWhile.out")
|
||||
override val shape = FlowShape(in, out)
|
||||
|
||||
override def initialAttributes: Attributes = DefaultAttributes.dropWhile
|
||||
|
||||
def createLogic(inheritedAttributes: Attributes) = new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler {
|
||||
|
|
@ -150,9 +152,12 @@ final case class DropWhile[T](p: T ⇒ Boolean) extends GraphStage[FlowShape[T,
|
|||
}
|
||||
|
||||
override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in)
|
||||
|
||||
override def onPull(): Unit = pull(in)
|
||||
|
||||
setHandlers(in, out, this)
|
||||
}
|
||||
|
||||
override def toString = "DropWhile"
|
||||
}
|
||||
|
||||
|
|
@ -163,7 +168,9 @@ abstract class SupervisedGraphStageLogic(inheritedAttributes: Attributes, shape:
|
|||
private lazy val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider)
|
||||
|
||||
def withSupervision[T](f: () ⇒ T): Option[T] =
|
||||
try { Some(f()) } catch {
|
||||
try {
|
||||
Some(f())
|
||||
} catch {
|
||||
case NonFatal(ex) ⇒
|
||||
decider(ex) match {
|
||||
case Supervision.Stop ⇒ onStop(ex)
|
||||
|
|
@ -174,7 +181,9 @@ abstract class SupervisedGraphStageLogic(inheritedAttributes: Attributes, shape:
|
|||
}
|
||||
|
||||
def onResume(t: Throwable): Unit
|
||||
|
||||
def onStop(t: Throwable): Unit = failStage(t)
|
||||
|
||||
def onRestart(t: Throwable): Unit = onResume(t)
|
||||
}
|
||||
|
||||
|
|
@ -192,10 +201,13 @@ final case class Collect[In, Out](pf: PartialFunction[In, Out]) extends GraphSta
|
|||
val in = Inlet[In]("Collect.in")
|
||||
val out = Outlet[Out]("Collect.out")
|
||||
override val shape = FlowShape(in, out)
|
||||
|
||||
override def initialAttributes: Attributes = DefaultAttributes.collect
|
||||
|
||||
def createLogic(inheritedAttributes: Attributes) = new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler {
|
||||
|
||||
import Collect.NotApplied
|
||||
|
||||
val wrappedPf = () ⇒ pf.applyOrElse(grab(in), NotApplied)
|
||||
|
||||
override def onPush(): Unit = withSupervision(wrappedPf) match {
|
||||
|
|
@ -207,9 +219,12 @@ final case class Collect[In, Out](pf: PartialFunction[In, Out]) extends GraphSta
|
|||
}
|
||||
|
||||
override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in)
|
||||
|
||||
override def onPull(): Unit = pull(in)
|
||||
|
||||
setHandlers(in, out, this)
|
||||
}
|
||||
|
||||
override def toString = "Collect"
|
||||
}
|
||||
|
||||
|
|
@ -224,6 +239,7 @@ final case class Recover[T](pf: PartialFunction[Throwable, T]) extends GraphStag
|
|||
override protected val initialAttributes: Attributes = DefaultAttributes.recover
|
||||
|
||||
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler {
|
||||
|
||||
import Collect.NotApplied
|
||||
|
||||
var recovered: Option[T] = None
|
||||
|
|
@ -320,6 +336,7 @@ final case class Scan[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends GraphSta
|
|||
override val shape = FlowShape[In, Out](Inlet("Scan.in"), Outlet("Scan.out"))
|
||||
|
||||
override def initialAttributes: Attributes = DefaultAttributes.scan
|
||||
|
||||
override def toString: String = "Scan"
|
||||
|
||||
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
|
||||
|
|
@ -342,6 +359,7 @@ final case class Scan[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends GraphSta
|
|||
|
||||
setHandler(in, new InHandler {
|
||||
override def onPush(): Unit = ()
|
||||
|
||||
override def onUpstreamFinish(): Unit = setHandler(out, new OutHandler {
|
||||
override def onPull(): Unit = {
|
||||
push(out, aggregator)
|
||||
|
|
@ -351,6 +369,7 @@ final case class Scan[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends GraphSta
|
|||
})
|
||||
|
||||
override def onPull(): Unit = pull(in)
|
||||
|
||||
override def onPush(): Unit = {
|
||||
try {
|
||||
aggregator = f(aggregator, grab(in))
|
||||
|
|
@ -426,6 +445,7 @@ final case class Fold[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends GraphSta
|
|||
* INTERNAL API
|
||||
*/
|
||||
final class FoldAsync[In, Out](zero: Out, f: (Out, In) ⇒ Future[Out]) extends GraphStage[FlowShape[In, Out]] {
|
||||
|
||||
import akka.dispatch.ExecutionContexts
|
||||
|
||||
val in = Inlet[In]("FoldAsync.in")
|
||||
|
|
@ -436,30 +456,29 @@ final class FoldAsync[In, Out](zero: Out, f: (Out, In) ⇒ Future[Out]) extends
|
|||
|
||||
override val initialAttributes = DefaultAttributes.foldAsync
|
||||
|
||||
def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler {
|
||||
val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider)
|
||||
def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
|
||||
new GraphStageLogic(shape) with InHandler with OutHandler {
|
||||
val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider)
|
||||
|
||||
private var aggregator: Out = zero
|
||||
private var aggregating: Future[Out] = Future.successful(aggregator)
|
||||
private var aggregator: Out = zero
|
||||
private var aggregating: Future[Out] = Future.successful(aggregator)
|
||||
|
||||
private def onRestart(t: Throwable): Unit = {
|
||||
aggregator = zero
|
||||
}
|
||||
private def onRestart(t: Throwable): Unit = {
|
||||
aggregator = zero
|
||||
}
|
||||
|
||||
private def ec = ExecutionContexts.sameThreadExecutionContext
|
||||
private def ec = ExecutionContexts.sameThreadExecutionContext
|
||||
|
||||
private val futureCB = getAsyncCallback[Try[Out]]((result: Try[Out]) ⇒ {
|
||||
result match {
|
||||
case Success(update) if update != null ⇒ {
|
||||
private val futureCB = getAsyncCallback[Try[Out]] {
|
||||
case Success(update) if update != null ⇒
|
||||
aggregator = update
|
||||
|
||||
if (isClosed(in)) {
|
||||
push(out, update)
|
||||
completeStage()
|
||||
} else if (isAvailable(out) && !hasBeenPulled(in)) tryPull(in)
|
||||
}
|
||||
|
||||
case other ⇒ {
|
||||
case other ⇒
|
||||
val ex = other match {
|
||||
case Failure(t) ⇒ t
|
||||
case Success(s) if s == null ⇒
|
||||
|
|
@ -476,42 +495,45 @@ final class FoldAsync[In, Out](zero: Out, f: (Out, In) ⇒ Future[Out]) extends
|
|||
completeStage()
|
||||
} else if (isAvailable(out) && !hasBeenPulled(in)) tryPull(in)
|
||||
}
|
||||
}.invoke _
|
||||
|
||||
def onPush(): Unit = {
|
||||
try {
|
||||
aggregating = f(aggregator, grab(in))
|
||||
handleAggregatingValue()
|
||||
} catch {
|
||||
case NonFatal(ex) ⇒ decider(ex) match {
|
||||
case Supervision.Stop ⇒ failStage(ex)
|
||||
case supervision ⇒ {
|
||||
supervision match {
|
||||
case Supervision.Restart ⇒ onRestart(ex)
|
||||
case _ ⇒ () // just ignore on Resume
|
||||
}
|
||||
|
||||
tryPull(in)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}).invoke _
|
||||
|
||||
def onPush(): Unit = {
|
||||
try {
|
||||
aggregating = f(aggregator, grab(in))
|
||||
override def onUpstreamFinish(): Unit = {
|
||||
handleAggregatingValue()
|
||||
}
|
||||
|
||||
def onPull(): Unit = if (!hasBeenPulled(in)) tryPull(in)
|
||||
|
||||
private def handleAggregatingValue(): Unit = {
|
||||
aggregating.value match {
|
||||
case Some(result) ⇒ futureCB(result) // already completed
|
||||
case _ ⇒ aggregating.onComplete(futureCB)(ec)
|
||||
}
|
||||
} catch {
|
||||
case NonFatal(ex) ⇒ decider(ex) match {
|
||||
case Supervision.Stop ⇒ failStage(ex)
|
||||
case supervision ⇒ {
|
||||
supervision match {
|
||||
case Supervision.Restart ⇒ onRestart(ex)
|
||||
case _ ⇒ () // just ignore on Resume
|
||||
}
|
||||
|
||||
tryPull(in)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
setHandlers(in, out, this)
|
||||
|
||||
override def toString =
|
||||
s"FoldAsync.Logic(completed=${aggregating.isCompleted})"
|
||||
}
|
||||
|
||||
override def onUpstreamFinish(): Unit = {}
|
||||
|
||||
def onPull(): Unit = if (!hasBeenPulled(in)) tryPull(in)
|
||||
|
||||
setHandlers(in, out, this)
|
||||
|
||||
override def toString =
|
||||
s"FoldAsync.Logic(completed=${aggregating.isCompleted})"
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -619,6 +641,7 @@ final case class LimitWeighted[T](val n: Long, val costFn: T ⇒ Long) extends G
|
|||
val in = Inlet[T]("LimitWeighted.in")
|
||||
val out = Outlet[T]("LimitWeighted.out")
|
||||
override val shape = FlowShape(in, out)
|
||||
|
||||
override def initialAttributes: Attributes = DefaultAttributes.limitWeighted
|
||||
|
||||
def createLogic(inheritedAttributes: Attributes) = new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler {
|
||||
|
|
@ -633,14 +656,19 @@ final case class LimitWeighted[T](val n: Long, val costFn: T ⇒ Long) extends G
|
|||
case None ⇒ //do nothing
|
||||
}
|
||||
}
|
||||
|
||||
override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in)
|
||||
|
||||
override def onRestart(t: Throwable): Unit = {
|
||||
left = n
|
||||
if (!hasBeenPulled(in)) pull(in)
|
||||
}
|
||||
|
||||
override def onPull(): Unit = pull(in)
|
||||
|
||||
setHandlers(in, out, this)
|
||||
}
|
||||
|
||||
override def toString = "LimitWeighted"
|
||||
}
|
||||
|
||||
|
|
@ -892,6 +920,7 @@ final class Expand[In, Out](val extrapolate: In ⇒ Iterator[Out]) extends Graph
|
|||
private val out = Outlet[Out]("expand.out")
|
||||
|
||||
override def initialAttributes = DefaultAttributes.expand
|
||||
|
||||
override val shape = FlowShape(in, out)
|
||||
|
||||
override def createLogic(attr: Attributes) = new GraphStageLogic(shape) with InHandler with OutHandler {
|
||||
|
|
@ -941,6 +970,7 @@ final class Expand[In, Out](val extrapolate: In ⇒ Iterator[Out]) extends Graph
|
|||
* INTERNAL API
|
||||
*/
|
||||
private[akka] object MapAsync {
|
||||
|
||||
final class Holder[T](var elem: Try[T], val cb: AsyncCallback[Holder[T]]) extends (Try[T] ⇒ Unit) {
|
||||
def setElem(t: Try[T]): Unit =
|
||||
elem = t match {
|
||||
|
|
@ -953,6 +983,7 @@ private[akka] object MapAsync {
|
|||
cb.invoke(this)
|
||||
}
|
||||
}
|
||||
|
||||
val NotYetThere = Failure(new Exception)
|
||||
}
|
||||
|
||||
|
|
@ -968,6 +999,7 @@ final case class MapAsync[In, Out](parallelism: Int, f: In ⇒ Future[Out])
|
|||
private val out = Outlet[Out]("MapAsync.out")
|
||||
|
||||
override def initialAttributes = DefaultAttributes.mapAsync
|
||||
|
||||
override val shape = FlowShape(in, out)
|
||||
|
||||
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
|
||||
|
|
@ -984,6 +1016,7 @@ final case class MapAsync[In, Out](parallelism: Int, f: In ⇒ Future[Out])
|
|||
case _ ⇒ if (isAvailable(out)) pushOne()
|
||||
}
|
||||
}
|
||||
|
||||
val futureCB = getAsyncCallback[Holder[Out]](holderCompleted)
|
||||
|
||||
private[this] def todo = buffer.used
|
||||
|
|
@ -1023,6 +1056,7 @@ final case class MapAsync[In, Out](parallelism: Int, f: In ⇒ Future[Out])
|
|||
}
|
||||
if (todo < parallelism && !hasBeenPulled(in)) tryPull(in)
|
||||
}
|
||||
|
||||
override def onUpstreamFinish(): Unit = if (todo == 0) completeStage()
|
||||
|
||||
override def onPull(): Unit = pushOne()
|
||||
|
|
@ -1041,6 +1075,7 @@ final case class MapAsyncUnordered[In, Out](parallelism: Int, f: In ⇒ Future[O
|
|||
private val out = Outlet[Out]("MapAsyncUnordered.out")
|
||||
|
||||
override def initialAttributes = DefaultAttributes.mapAsyncUnordered
|
||||
|
||||
override val shape = FlowShape(in, out)
|
||||
|
||||
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
|
||||
|
|
@ -1052,6 +1087,7 @@ final case class MapAsyncUnordered[In, Out](parallelism: Int, f: In ⇒ Future[O
|
|||
|
||||
private var inFlight = 0
|
||||
private var buffer: BufferImpl[Out] = _
|
||||
|
||||
private[this] def todo = inFlight + buffer.used
|
||||
|
||||
override def preStart(): Unit = buffer = BufferImpl(parallelism, materializer)
|
||||
|
|
@ -1074,6 +1110,7 @@ final case class MapAsyncUnordered[In, Out](parallelism: Int, f: In ⇒ Future[O
|
|||
else if (!hasBeenPulled(in)) tryPull(in)
|
||||
}
|
||||
}
|
||||
|
||||
private val futureCB = getAsyncCallback(futureCompleted)
|
||||
private val invokeFutureCB: Try[Out] ⇒ Unit = futureCB.invoke
|
||||
|
||||
|
|
@ -1119,6 +1156,7 @@ final case class Log[T](
|
|||
// TODO more optimisations can be done here - prepare logOnPush function etc
|
||||
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
|
||||
new GraphStageLogic(shape) with OutHandler with InHandler {
|
||||
|
||||
import Log._
|
||||
|
||||
private var logLevels: LogLevels = _
|
||||
|
|
@ -1221,9 +1259,13 @@ private[akka] object Log {
|
|||
* INTERNAL API
|
||||
*/
|
||||
private[stream] object TimerKeys {
|
||||
|
||||
case object TakeWithinTimerKey
|
||||
|
||||
case object DropWithinTimerKey
|
||||
|
||||
case object GroupedWithinTimerKey
|
||||
|
||||
}
|
||||
|
||||
final class GroupedWithin[T](val n: Int, val d: FiniteDuration) extends GraphStage[FlowShape[T, immutable.Seq[T]]] {
|
||||
|
|
@ -1232,7 +1274,9 @@ final class GroupedWithin[T](val n: Int, val d: FiniteDuration) extends GraphSta
|
|||
|
||||
val in = Inlet[T]("in")
|
||||
val out = Outlet[immutable.Seq[T]]("out")
|
||||
|
||||
override def initialAttributes = DefaultAttributes.groupedWithin
|
||||
|
||||
val shape = FlowShape(in, out)
|
||||
|
||||
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) with InHandler with OutHandler {
|
||||
|
|
@ -1304,7 +1348,9 @@ final class GroupedWithin[T](val n: Int, val d: FiniteDuration) extends GraphSta
|
|||
|
||||
final class Delay[T](val d: FiniteDuration, val strategy: DelayOverflowStrategy) extends SimpleLinearGraphStage[T] {
|
||||
private[this] def timerName = "DelayedTimer"
|
||||
|
||||
override def initialAttributes: Attributes = DefaultAttributes.delay
|
||||
|
||||
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) with InHandler with OutHandler {
|
||||
val size =
|
||||
inheritedAttributes.get[InputBuffer] match {
|
||||
|
|
@ -1418,6 +1464,7 @@ final class TakeWithin[T](val timeout: FiniteDuration) extends SimpleLinearGraph
|
|||
|
||||
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) with InHandler with OutHandler {
|
||||
def onPush(): Unit = push(out, grab(in))
|
||||
|
||||
def onPull(): Unit = pull(in)
|
||||
|
||||
setHandler(in, this)
|
||||
|
|
@ -1461,7 +1508,8 @@ final class DropWithin[T](val timeout: FiniteDuration) extends SimpleLinearGraph
|
|||
final class Reduce[T](val f: (T, T) ⇒ T) extends SimpleLinearGraphStage[T] {
|
||||
override def initialAttributes: Attributes = DefaultAttributes.reduce
|
||||
|
||||
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler { self ⇒
|
||||
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler {
|
||||
self ⇒
|
||||
override def toString = s"Reduce.Logic(aggregator=$aggregator)"
|
||||
|
||||
var aggregator: T = _
|
||||
|
|
@ -1505,6 +1553,7 @@ private[stream] object RecoverWith {
|
|||
|
||||
final class RecoverWith[T, M](val maximumRetries: Int, val pf: PartialFunction[Throwable, Graph[SourceShape[T], M]]) extends SimpleLinearGraphStage[T] {
|
||||
require(maximumRetries >= -1, "number of retries must be non-negative or equal to -1")
|
||||
|
||||
override def initialAttributes = DefaultAttributes.recoverWith
|
||||
|
||||
override def createLogic(attr: Attributes) = new GraphStageLogic(shape) {
|
||||
|
|
@ -1512,6 +1561,7 @@ final class RecoverWith[T, M](val maximumRetries: Int, val pf: PartialFunction[T
|
|||
|
||||
setHandler(in, new InHandler {
|
||||
override def onPush(): Unit = push(out, grab(in))
|
||||
|
||||
override def onUpstreamFailure(ex: Throwable) = onFailure(ex)
|
||||
})
|
||||
|
||||
|
|
@ -1531,12 +1581,15 @@ final class RecoverWith[T, M](val maximumRetries: Int, val pf: PartialFunction[T
|
|||
|
||||
sinkIn.setHandler(new InHandler {
|
||||
override def onPush(): Unit = push(out, sinkIn.grab())
|
||||
|
||||
override def onUpstreamFinish(): Unit = completeStage()
|
||||
|
||||
override def onUpstreamFailure(ex: Throwable) = onFailure(ex)
|
||||
})
|
||||
|
||||
val outHandler = new OutHandler {
|
||||
override def onPull(): Unit = sinkIn.pull()
|
||||
|
||||
override def onDownstreamFinish(): Unit = sinkIn.cancel()
|
||||
}
|
||||
|
||||
|
|
@ -1556,13 +1609,16 @@ final class StatefulMapConcat[In, Out](val f: () ⇒ In ⇒ immutable.Iterable[O
|
|||
val in = Inlet[In]("StatefulMapConcat.in")
|
||||
val out = Outlet[Out]("StatefulMapConcat.out")
|
||||
override val shape = FlowShape(in, out)
|
||||
|
||||
override def initialAttributes: Attributes = DefaultAttributes.statefulMapConcat
|
||||
|
||||
def createLogic(inheritedAttributes: Attributes) = new GraphStageLogic(shape) with InHandler with OutHandler {
|
||||
lazy val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider)
|
||||
var currentIterator: Iterator[Out] = _
|
||||
var plainFun = f()
|
||||
|
||||
def hasNext = if (currentIterator != null) currentIterator.hasNext else false
|
||||
|
||||
setHandlers(in, out, this)
|
||||
|
||||
def pushPull(): Unit =
|
||||
|
|
@ -1590,6 +1646,7 @@ final class StatefulMapConcat[In, Out](val f: () ⇒ In ⇒ immutable.Iterable[O
|
|||
}
|
||||
|
||||
override def onUpstreamFinish(): Unit = onFinish()
|
||||
|
||||
override def onPull(): Unit = pushPull()
|
||||
|
||||
private def restartState(): Unit = {
|
||||
|
|
@ -1597,6 +1654,7 @@ final class StatefulMapConcat[In, Out](val f: () ⇒ In ⇒ immutable.Iterable[O
|
|||
currentIterator = null
|
||||
}
|
||||
}
|
||||
|
||||
override def toString = "StatefulMapConcat"
|
||||
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue