fold async emits zero given an empty stream #21562

This commit is contained in:
Vsevolod Belousov 2016-10-17 08:02:54 +01:00 committed by Johan Andrén
parent 0b9a5026a0
commit 0e60020c58
3 changed files with 132 additions and 52 deletions

View file

@ -3,22 +3,20 @@
*/ */
package akka.stream.scaladsl package akka.stream.scaladsl
import scala.util.control.NoStackTrace
import scala.concurrent.{ Await, Future }
import scala.concurrent.duration._
import akka.NotUsed import akka.NotUsed
import akka.stream.ActorMaterializer
import akka.stream.ActorAttributes.supervisionStrategy import akka.stream.ActorAttributes.supervisionStrategy
import akka.stream.Supervision.{ restartingDecider, resumingDecider } import akka.stream.ActorMaterializer
import akka.stream.Supervision.{restartingDecider, resumingDecider}
import akka.stream.impl.ReactiveStreamsCompliance import akka.stream.impl.ReactiveStreamsCompliance
import akka.stream.testkit.Utils._
import akka.testkit.{ AkkaSpec, TestLatch } import akka.stream.testkit._
import akka.stream.testkit._, Utils._ import akka.testkit.TestLatch
import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.concurrent.PatienceConfiguration.Timeout
import scala.concurrent.duration._
import scala.concurrent.{Await, Future}
import scala.util.control.NoStackTrace
class FlowFoldAsyncSpec extends StreamSpec { class FlowFoldAsyncSpec extends StreamSpec {
implicit val materializer = ActorMaterializer() implicit val materializer = ActorMaterializer()
implicit def ec = materializer.executionContext implicit def ec = materializer.executionContext
@ -261,6 +259,22 @@ class FlowFoldAsyncSpec extends StreamSpec {
upstream.expectCancellation() upstream.expectCancellation()
} }
"complete future and return zero given an empty stream" in assertAllStagesStopped {
val futureValue =
Source.fromIterator[Int](() Iterator.empty)
.runFoldAsync(0)((acc, elem) Future.successful(acc + elem))
Await.result(futureValue, remainingOrDefault) should be(0)
}
"complete future and return zero + item given a stream of one item" in assertAllStagesStopped {
val futureValue =
Source.single(100)
.runFoldAsync(5)((acc, elem) Future.successful(acc + elem))
Await.result(futureValue, remainingOrDefault) should be(105)
}
} }
// Keep // Keep

View file

@ -56,6 +56,14 @@ class FlowFoldSpec extends StreamSpec {
the[Exception] thrownBy Await.result(future, 3.seconds) should be(error) the[Exception] thrownBy Await.result(future, 3.seconds) should be(error)
} }
"complete future and return zero given an empty stream" in assertAllStagesStopped {
val futureValue =
Source.fromIterator[Int](() Iterator.empty)
.runFold(0)(_ + _)
Await.result(futureValue, 3.seconds) should be(0)
}
} }
} }

View file

@ -30,6 +30,7 @@ final case class Map[In, Out](f: In ⇒ Out) extends GraphStage[FlowShape[In, Ou
val in = Inlet[In]("Map.in") val in = Inlet[In]("Map.in")
val out = Outlet[Out]("Map.out") val out = Outlet[Out]("Map.out")
override val shape = FlowShape(in, out) override val shape = FlowShape(in, out)
override def initialAttributes: Attributes = DefaultAttributes.map override def initialAttributes: Attributes = DefaultAttributes.map
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
@ -131,6 +132,7 @@ final case class DropWhile[T](p: T ⇒ Boolean) extends GraphStage[FlowShape[T,
val in = Inlet[T]("DropWhile.in") val in = Inlet[T]("DropWhile.in")
val out = Outlet[T]("DropWhile.out") val out = Outlet[T]("DropWhile.out")
override val shape = FlowShape(in, out) override val shape = FlowShape(in, out)
override def initialAttributes: Attributes = DefaultAttributes.dropWhile override def initialAttributes: Attributes = DefaultAttributes.dropWhile
def createLogic(inheritedAttributes: Attributes) = new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler { def createLogic(inheritedAttributes: Attributes) = new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler {
@ -150,9 +152,12 @@ final case class DropWhile[T](p: T ⇒ Boolean) extends GraphStage[FlowShape[T,
} }
override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in) override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in)
override def onPull(): Unit = pull(in) override def onPull(): Unit = pull(in)
setHandlers(in, out, this) setHandlers(in, out, this)
} }
override def toString = "DropWhile" override def toString = "DropWhile"
} }
@ -163,7 +168,9 @@ abstract class SupervisedGraphStageLogic(inheritedAttributes: Attributes, shape:
private lazy val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider) private lazy val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider)
def withSupervision[T](f: () T): Option[T] = def withSupervision[T](f: () T): Option[T] =
try { Some(f()) } catch { try {
Some(f())
} catch {
case NonFatal(ex) case NonFatal(ex)
decider(ex) match { decider(ex) match {
case Supervision.Stop onStop(ex) case Supervision.Stop onStop(ex)
@ -174,7 +181,9 @@ abstract class SupervisedGraphStageLogic(inheritedAttributes: Attributes, shape:
} }
def onResume(t: Throwable): Unit def onResume(t: Throwable): Unit
def onStop(t: Throwable): Unit = failStage(t) def onStop(t: Throwable): Unit = failStage(t)
def onRestart(t: Throwable): Unit = onResume(t) def onRestart(t: Throwable): Unit = onResume(t)
} }
@ -192,10 +201,13 @@ final case class Collect[In, Out](pf: PartialFunction[In, Out]) extends GraphSta
val in = Inlet[In]("Collect.in") val in = Inlet[In]("Collect.in")
val out = Outlet[Out]("Collect.out") val out = Outlet[Out]("Collect.out")
override val shape = FlowShape(in, out) override val shape = FlowShape(in, out)
override def initialAttributes: Attributes = DefaultAttributes.collect override def initialAttributes: Attributes = DefaultAttributes.collect
def createLogic(inheritedAttributes: Attributes) = new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler { def createLogic(inheritedAttributes: Attributes) = new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler {
import Collect.NotApplied import Collect.NotApplied
val wrappedPf = () pf.applyOrElse(grab(in), NotApplied) val wrappedPf = () pf.applyOrElse(grab(in), NotApplied)
override def onPush(): Unit = withSupervision(wrappedPf) match { override def onPush(): Unit = withSupervision(wrappedPf) match {
@ -207,9 +219,12 @@ final case class Collect[In, Out](pf: PartialFunction[In, Out]) extends GraphSta
} }
override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in) override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in)
override def onPull(): Unit = pull(in) override def onPull(): Unit = pull(in)
setHandlers(in, out, this) setHandlers(in, out, this)
} }
override def toString = "Collect" override def toString = "Collect"
} }
@ -224,6 +239,7 @@ final case class Recover[T](pf: PartialFunction[Throwable, T]) extends GraphStag
override protected val initialAttributes: Attributes = DefaultAttributes.recover override protected val initialAttributes: Attributes = DefaultAttributes.recover
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler { override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler {
import Collect.NotApplied import Collect.NotApplied
var recovered: Option[T] = None var recovered: Option[T] = None
@ -320,6 +336,7 @@ final case class Scan[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends GraphSta
override val shape = FlowShape[In, Out](Inlet("Scan.in"), Outlet("Scan.out")) override val shape = FlowShape[In, Out](Inlet("Scan.in"), Outlet("Scan.out"))
override def initialAttributes: Attributes = DefaultAttributes.scan override def initialAttributes: Attributes = DefaultAttributes.scan
override def toString: String = "Scan" override def toString: String = "Scan"
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
@ -342,6 +359,7 @@ final case class Scan[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends GraphSta
setHandler(in, new InHandler { setHandler(in, new InHandler {
override def onPush(): Unit = () override def onPush(): Unit = ()
override def onUpstreamFinish(): Unit = setHandler(out, new OutHandler { override def onUpstreamFinish(): Unit = setHandler(out, new OutHandler {
override def onPull(): Unit = { override def onPull(): Unit = {
push(out, aggregator) push(out, aggregator)
@ -351,6 +369,7 @@ final case class Scan[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends GraphSta
}) })
override def onPull(): Unit = pull(in) override def onPull(): Unit = pull(in)
override def onPush(): Unit = { override def onPush(): Unit = {
try { try {
aggregator = f(aggregator, grab(in)) aggregator = f(aggregator, grab(in))
@ -426,6 +445,7 @@ final case class Fold[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends GraphSta
* INTERNAL API * INTERNAL API
*/ */
final class FoldAsync[In, Out](zero: Out, f: (Out, In) Future[Out]) extends GraphStage[FlowShape[In, Out]] { final class FoldAsync[In, Out](zero: Out, f: (Out, In) Future[Out]) extends GraphStage[FlowShape[In, Out]] {
import akka.dispatch.ExecutionContexts import akka.dispatch.ExecutionContexts
val in = Inlet[In]("FoldAsync.in") val in = Inlet[In]("FoldAsync.in")
@ -436,30 +456,29 @@ final class FoldAsync[In, Out](zero: Out, f: (Out, In) ⇒ Future[Out]) extends
override val initialAttributes = DefaultAttributes.foldAsync override val initialAttributes = DefaultAttributes.foldAsync
def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler { def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider) new GraphStageLogic(shape) with InHandler with OutHandler {
val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider)
private var aggregator: Out = zero private var aggregator: Out = zero
private var aggregating: Future[Out] = Future.successful(aggregator) private var aggregating: Future[Out] = Future.successful(aggregator)
private def onRestart(t: Throwable): Unit = { private def onRestart(t: Throwable): Unit = {
aggregator = zero aggregator = zero
} }
private def ec = ExecutionContexts.sameThreadExecutionContext private def ec = ExecutionContexts.sameThreadExecutionContext
private val futureCB = getAsyncCallback[Try[Out]]((result: Try[Out]) { private val futureCB = getAsyncCallback[Try[Out]] {
result match { case Success(update) if update != null
case Success(update) if update != null {
aggregator = update aggregator = update
if (isClosed(in)) { if (isClosed(in)) {
push(out, update) push(out, update)
completeStage() completeStage()
} else if (isAvailable(out) && !hasBeenPulled(in)) tryPull(in) } else if (isAvailable(out) && !hasBeenPulled(in)) tryPull(in)
}
case other { case other
val ex = other match { val ex = other match {
case Failure(t) t case Failure(t) t
case Success(s) if s == null case Success(s) if s == null
@ -476,42 +495,45 @@ final class FoldAsync[In, Out](zero: Out, f: (Out, In) ⇒ Future[Out]) extends
completeStage() completeStage()
} else if (isAvailable(out) && !hasBeenPulled(in)) tryPull(in) } else if (isAvailable(out) && !hasBeenPulled(in)) tryPull(in)
} }
}.invoke _
def onPush(): Unit = {
try {
aggregating = f(aggregator, grab(in))
handleAggregatingValue()
} catch {
case NonFatal(ex) decider(ex) match {
case Supervision.Stop failStage(ex)
case supervision {
supervision match {
case Supervision.Restart onRestart(ex)
case _ () // just ignore on Resume
}
tryPull(in)
}
}
} }
} }
}).invoke _
def onPush(): Unit = { override def onUpstreamFinish(): Unit = {
try { handleAggregatingValue()
aggregating = f(aggregator, grab(in)) }
def onPull(): Unit = if (!hasBeenPulled(in)) tryPull(in)
private def handleAggregatingValue(): Unit = {
aggregating.value match { aggregating.value match {
case Some(result) futureCB(result) // already completed case Some(result) futureCB(result) // already completed
case _ aggregating.onComplete(futureCB)(ec) case _ aggregating.onComplete(futureCB)(ec)
} }
} catch {
case NonFatal(ex) decider(ex) match {
case Supervision.Stop failStage(ex)
case supervision {
supervision match {
case Supervision.Restart onRestart(ex)
case _ () // just ignore on Resume
}
tryPull(in)
}
}
} }
setHandlers(in, out, this)
override def toString =
s"FoldAsync.Logic(completed=${aggregating.isCompleted})"
} }
override def onUpstreamFinish(): Unit = {}
def onPull(): Unit = if (!hasBeenPulled(in)) tryPull(in)
setHandlers(in, out, this)
override def toString =
s"FoldAsync.Logic(completed=${aggregating.isCompleted})"
}
} }
/** /**
@ -619,6 +641,7 @@ final case class LimitWeighted[T](val n: Long, val costFn: T ⇒ Long) extends G
val in = Inlet[T]("LimitWeighted.in") val in = Inlet[T]("LimitWeighted.in")
val out = Outlet[T]("LimitWeighted.out") val out = Outlet[T]("LimitWeighted.out")
override val shape = FlowShape(in, out) override val shape = FlowShape(in, out)
override def initialAttributes: Attributes = DefaultAttributes.limitWeighted override def initialAttributes: Attributes = DefaultAttributes.limitWeighted
def createLogic(inheritedAttributes: Attributes) = new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler { def createLogic(inheritedAttributes: Attributes) = new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler {
@ -633,14 +656,19 @@ final case class LimitWeighted[T](val n: Long, val costFn: T ⇒ Long) extends G
case None //do nothing case None //do nothing
} }
} }
override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in) override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in)
override def onRestart(t: Throwable): Unit = { override def onRestart(t: Throwable): Unit = {
left = n left = n
if (!hasBeenPulled(in)) pull(in) if (!hasBeenPulled(in)) pull(in)
} }
override def onPull(): Unit = pull(in) override def onPull(): Unit = pull(in)
setHandlers(in, out, this) setHandlers(in, out, this)
} }
override def toString = "LimitWeighted" override def toString = "LimitWeighted"
} }
@ -892,6 +920,7 @@ final class Expand[In, Out](val extrapolate: In ⇒ Iterator[Out]) extends Graph
private val out = Outlet[Out]("expand.out") private val out = Outlet[Out]("expand.out")
override def initialAttributes = DefaultAttributes.expand override def initialAttributes = DefaultAttributes.expand
override val shape = FlowShape(in, out) override val shape = FlowShape(in, out)
override def createLogic(attr: Attributes) = new GraphStageLogic(shape) with InHandler with OutHandler { override def createLogic(attr: Attributes) = new GraphStageLogic(shape) with InHandler with OutHandler {
@ -941,6 +970,7 @@ final class Expand[In, Out](val extrapolate: In ⇒ Iterator[Out]) extends Graph
* INTERNAL API * INTERNAL API
*/ */
private[akka] object MapAsync { private[akka] object MapAsync {
final class Holder[T](var elem: Try[T], val cb: AsyncCallback[Holder[T]]) extends (Try[T] Unit) { final class Holder[T](var elem: Try[T], val cb: AsyncCallback[Holder[T]]) extends (Try[T] Unit) {
def setElem(t: Try[T]): Unit = def setElem(t: Try[T]): Unit =
elem = t match { elem = t match {
@ -953,6 +983,7 @@ private[akka] object MapAsync {
cb.invoke(this) cb.invoke(this)
} }
} }
val NotYetThere = Failure(new Exception) val NotYetThere = Failure(new Exception)
} }
@ -968,6 +999,7 @@ final case class MapAsync[In, Out](parallelism: Int, f: In ⇒ Future[Out])
private val out = Outlet[Out]("MapAsync.out") private val out = Outlet[Out]("MapAsync.out")
override def initialAttributes = DefaultAttributes.mapAsync override def initialAttributes = DefaultAttributes.mapAsync
override val shape = FlowShape(in, out) override val shape = FlowShape(in, out)
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
@ -984,6 +1016,7 @@ final case class MapAsync[In, Out](parallelism: Int, f: In ⇒ Future[Out])
case _ if (isAvailable(out)) pushOne() case _ if (isAvailable(out)) pushOne()
} }
} }
val futureCB = getAsyncCallback[Holder[Out]](holderCompleted) val futureCB = getAsyncCallback[Holder[Out]](holderCompleted)
private[this] def todo = buffer.used private[this] def todo = buffer.used
@ -1023,6 +1056,7 @@ final case class MapAsync[In, Out](parallelism: Int, f: In ⇒ Future[Out])
} }
if (todo < parallelism && !hasBeenPulled(in)) tryPull(in) if (todo < parallelism && !hasBeenPulled(in)) tryPull(in)
} }
override def onUpstreamFinish(): Unit = if (todo == 0) completeStage() override def onUpstreamFinish(): Unit = if (todo == 0) completeStage()
override def onPull(): Unit = pushOne() override def onPull(): Unit = pushOne()
@ -1041,6 +1075,7 @@ final case class MapAsyncUnordered[In, Out](parallelism: Int, f: In ⇒ Future[O
private val out = Outlet[Out]("MapAsyncUnordered.out") private val out = Outlet[Out]("MapAsyncUnordered.out")
override def initialAttributes = DefaultAttributes.mapAsyncUnordered override def initialAttributes = DefaultAttributes.mapAsyncUnordered
override val shape = FlowShape(in, out) override val shape = FlowShape(in, out)
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
@ -1052,6 +1087,7 @@ final case class MapAsyncUnordered[In, Out](parallelism: Int, f: In ⇒ Future[O
private var inFlight = 0 private var inFlight = 0
private var buffer: BufferImpl[Out] = _ private var buffer: BufferImpl[Out] = _
private[this] def todo = inFlight + buffer.used private[this] def todo = inFlight + buffer.used
override def preStart(): Unit = buffer = BufferImpl(parallelism, materializer) override def preStart(): Unit = buffer = BufferImpl(parallelism, materializer)
@ -1074,6 +1110,7 @@ final case class MapAsyncUnordered[In, Out](parallelism: Int, f: In ⇒ Future[O
else if (!hasBeenPulled(in)) tryPull(in) else if (!hasBeenPulled(in)) tryPull(in)
} }
} }
private val futureCB = getAsyncCallback(futureCompleted) private val futureCB = getAsyncCallback(futureCompleted)
private val invokeFutureCB: Try[Out] Unit = futureCB.invoke private val invokeFutureCB: Try[Out] Unit = futureCB.invoke
@ -1119,6 +1156,7 @@ final case class Log[T](
// TODO more optimisations can be done here - prepare logOnPush function etc // TODO more optimisations can be done here - prepare logOnPush function etc
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
new GraphStageLogic(shape) with OutHandler with InHandler { new GraphStageLogic(shape) with OutHandler with InHandler {
import Log._ import Log._
private var logLevels: LogLevels = _ private var logLevels: LogLevels = _
@ -1221,9 +1259,13 @@ private[akka] object Log {
* INTERNAL API * INTERNAL API
*/ */
private[stream] object TimerKeys { private[stream] object TimerKeys {
case object TakeWithinTimerKey case object TakeWithinTimerKey
case object DropWithinTimerKey case object DropWithinTimerKey
case object GroupedWithinTimerKey case object GroupedWithinTimerKey
} }
final class GroupedWithin[T](val n: Int, val d: FiniteDuration) extends GraphStage[FlowShape[T, immutable.Seq[T]]] { final class GroupedWithin[T](val n: Int, val d: FiniteDuration) extends GraphStage[FlowShape[T, immutable.Seq[T]]] {
@ -1232,7 +1274,9 @@ final class GroupedWithin[T](val n: Int, val d: FiniteDuration) extends GraphSta
val in = Inlet[T]("in") val in = Inlet[T]("in")
val out = Outlet[immutable.Seq[T]]("out") val out = Outlet[immutable.Seq[T]]("out")
override def initialAttributes = DefaultAttributes.groupedWithin override def initialAttributes = DefaultAttributes.groupedWithin
val shape = FlowShape(in, out) val shape = FlowShape(in, out)
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) with InHandler with OutHandler { override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) with InHandler with OutHandler {
@ -1304,7 +1348,9 @@ final class GroupedWithin[T](val n: Int, val d: FiniteDuration) extends GraphSta
final class Delay[T](val d: FiniteDuration, val strategy: DelayOverflowStrategy) extends SimpleLinearGraphStage[T] { final class Delay[T](val d: FiniteDuration, val strategy: DelayOverflowStrategy) extends SimpleLinearGraphStage[T] {
private[this] def timerName = "DelayedTimer" private[this] def timerName = "DelayedTimer"
override def initialAttributes: Attributes = DefaultAttributes.delay override def initialAttributes: Attributes = DefaultAttributes.delay
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) with InHandler with OutHandler { override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) with InHandler with OutHandler {
val size = val size =
inheritedAttributes.get[InputBuffer] match { inheritedAttributes.get[InputBuffer] match {
@ -1418,6 +1464,7 @@ final class TakeWithin[T](val timeout: FiniteDuration) extends SimpleLinearGraph
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) with InHandler with OutHandler { override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) with InHandler with OutHandler {
def onPush(): Unit = push(out, grab(in)) def onPush(): Unit = push(out, grab(in))
def onPull(): Unit = pull(in) def onPull(): Unit = pull(in)
setHandler(in, this) setHandler(in, this)
@ -1461,7 +1508,8 @@ final class DropWithin[T](val timeout: FiniteDuration) extends SimpleLinearGraph
final class Reduce[T](val f: (T, T) T) extends SimpleLinearGraphStage[T] { final class Reduce[T](val f: (T, T) T) extends SimpleLinearGraphStage[T] {
override def initialAttributes: Attributes = DefaultAttributes.reduce override def initialAttributes: Attributes = DefaultAttributes.reduce
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler { self override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler {
self
override def toString = s"Reduce.Logic(aggregator=$aggregator)" override def toString = s"Reduce.Logic(aggregator=$aggregator)"
var aggregator: T = _ var aggregator: T = _
@ -1505,6 +1553,7 @@ private[stream] object RecoverWith {
final class RecoverWith[T, M](val maximumRetries: Int, val pf: PartialFunction[Throwable, Graph[SourceShape[T], M]]) extends SimpleLinearGraphStage[T] { final class RecoverWith[T, M](val maximumRetries: Int, val pf: PartialFunction[Throwable, Graph[SourceShape[T], M]]) extends SimpleLinearGraphStage[T] {
require(maximumRetries >= -1, "number of retries must be non-negative or equal to -1") require(maximumRetries >= -1, "number of retries must be non-negative or equal to -1")
override def initialAttributes = DefaultAttributes.recoverWith override def initialAttributes = DefaultAttributes.recoverWith
override def createLogic(attr: Attributes) = new GraphStageLogic(shape) { override def createLogic(attr: Attributes) = new GraphStageLogic(shape) {
@ -1512,6 +1561,7 @@ final class RecoverWith[T, M](val maximumRetries: Int, val pf: PartialFunction[T
setHandler(in, new InHandler { setHandler(in, new InHandler {
override def onPush(): Unit = push(out, grab(in)) override def onPush(): Unit = push(out, grab(in))
override def onUpstreamFailure(ex: Throwable) = onFailure(ex) override def onUpstreamFailure(ex: Throwable) = onFailure(ex)
}) })
@ -1531,12 +1581,15 @@ final class RecoverWith[T, M](val maximumRetries: Int, val pf: PartialFunction[T
sinkIn.setHandler(new InHandler { sinkIn.setHandler(new InHandler {
override def onPush(): Unit = push(out, sinkIn.grab()) override def onPush(): Unit = push(out, sinkIn.grab())
override def onUpstreamFinish(): Unit = completeStage() override def onUpstreamFinish(): Unit = completeStage()
override def onUpstreamFailure(ex: Throwable) = onFailure(ex) override def onUpstreamFailure(ex: Throwable) = onFailure(ex)
}) })
val outHandler = new OutHandler { val outHandler = new OutHandler {
override def onPull(): Unit = sinkIn.pull() override def onPull(): Unit = sinkIn.pull()
override def onDownstreamFinish(): Unit = sinkIn.cancel() override def onDownstreamFinish(): Unit = sinkIn.cancel()
} }
@ -1556,13 +1609,16 @@ final class StatefulMapConcat[In, Out](val f: () ⇒ In ⇒ immutable.Iterable[O
val in = Inlet[In]("StatefulMapConcat.in") val in = Inlet[In]("StatefulMapConcat.in")
val out = Outlet[Out]("StatefulMapConcat.out") val out = Outlet[Out]("StatefulMapConcat.out")
override val shape = FlowShape(in, out) override val shape = FlowShape(in, out)
override def initialAttributes: Attributes = DefaultAttributes.statefulMapConcat override def initialAttributes: Attributes = DefaultAttributes.statefulMapConcat
def createLogic(inheritedAttributes: Attributes) = new GraphStageLogic(shape) with InHandler with OutHandler { def createLogic(inheritedAttributes: Attributes) = new GraphStageLogic(shape) with InHandler with OutHandler {
lazy val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider) lazy val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider)
var currentIterator: Iterator[Out] = _ var currentIterator: Iterator[Out] = _
var plainFun = f() var plainFun = f()
def hasNext = if (currentIterator != null) currentIterator.hasNext else false def hasNext = if (currentIterator != null) currentIterator.hasNext else false
setHandlers(in, out, this) setHandlers(in, out, this)
def pushPull(): Unit = def pushPull(): Unit =
@ -1590,6 +1646,7 @@ final class StatefulMapConcat[In, Out](val f: () ⇒ In ⇒ immutable.Iterable[O
} }
override def onUpstreamFinish(): Unit = onFinish() override def onUpstreamFinish(): Unit = onFinish()
override def onPull(): Unit = pushPull() override def onPull(): Unit = pushPull()
private def restartState(): Unit = { private def restartState(): Unit = {
@ -1597,6 +1654,7 @@ final class StatefulMapConcat[In, Out](val f: () ⇒ In ⇒ immutable.Iterable[O
currentIterator = null currentIterator = null
} }
} }
override def toString = "StatefulMapConcat" override def toString = "StatefulMapConcat"
} }