FoldAsync op for Flow #18603

This commit is contained in:
Cédric Chantepie 2016-08-24 21:02:32 +02:00 committed by Johan Andrén
parent 9630feb6cc
commit efc87af58a
16 changed files with 519 additions and 620 deletions

View file

@ -161,6 +161,7 @@ final case class DropWhile[T](p: T ⇒ Boolean) extends GraphStage[FlowShape[T,
*/
abstract class SupervisedGraphStageLogic(inheritedAttributes: Attributes, shape: Shape) extends GraphStageLogic(shape) {
private lazy val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider)
def withSupervision[T](f: () T): Option[T] =
try { Some(f()) } catch {
case NonFatal(ex)
@ -376,6 +377,8 @@ final case class Fold[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends GraphSta
val out = Outlet[Out]("Fold.out")
override val shape: FlowShape[In, Out] = FlowShape(in, out)
override def toString: String = "Fold"
override val initialAttributes = DefaultAttributes.fold
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
@ -419,6 +422,98 @@ final case class Fold[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends GraphSta
}
}
/**
* INTERNAL API
*/
final class FoldAsync[In, Out](zero: Out, f: (Out, In) Future[Out]) extends GraphStage[FlowShape[In, Out]] {
import akka.dispatch.ExecutionContexts
val in = Inlet[In]("FoldAsync.in")
val out = Outlet[Out]("FoldAsync.out")
val shape = FlowShape.of(in, out)
override def toString: String = "FoldAsync"
override val initialAttributes = DefaultAttributes.foldAsync
def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler {
val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider)
private var aggregator: Out = zero
private var aggregating: Future[Out] = Future.successful(aggregator)
private def onRestart(t: Throwable): Unit = {
aggregator = zero
}
private def ec = ExecutionContexts.sameThreadExecutionContext
private val futureCB = getAsyncCallback[Try[Out]]((result: Try[Out]) {
result match {
case Success(update) if update != null {
aggregator = update
if (isClosed(in)) {
push(out, update)
completeStage()
} else if (isAvailable(out) && !hasBeenPulled(in)) tryPull(in)
}
case other {
val ex = other match {
case Failure(t) t
case Success(s) if s == null
ReactiveStreamsCompliance.elementMustNotBeNullException
}
val supervision = decider(ex)
if (supervision == Supervision.Stop) failStage(ex)
else {
if (supervision == Supervision.Restart) onRestart(ex)
if (isClosed(in)) {
push(out, aggregator)
completeStage()
} else if (isAvailable(out) && !hasBeenPulled(in)) tryPull(in)
}
}
}
}).invoke _
def onPush(): Unit = {
try {
aggregating = f(aggregator, grab(in))
aggregating.value match {
case Some(result) futureCB(result) // already completed
case _ aggregating.onComplete(futureCB)(ec)
}
} catch {
case NonFatal(ex) decider(ex) match {
case Supervision.Stop failStage(ex)
case supervision {
supervision match {
case Supervision.Restart onRestart(ex)
case _ () // just ignore on Resume
}
tryPull(in)
}
}
}
}
override def onUpstreamFinish(): Unit = {}
def onPull(): Unit = if (!hasBeenPulled(in)) tryPull(in)
setHandlers(in, out, this)
override def toString =
s"FoldAsync.Logic(completed=${aggregating.isCompleted})"
}
}
/**
* INTERNAL API
*/
@ -954,8 +1049,8 @@ final case class MapAsyncUnordered[In, Out](parallelism: Int, f: In ⇒ Future[O
val decider =
inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider)
var inFlight = 0
var buffer: BufferImpl[Out] = _
private var inFlight = 0
private var buffer: BufferImpl[Out] = _
private[this] def todo = inFlight + buffer.used
override def preStart(): Unit = buffer = BufferImpl(parallelism, materializer)
@ -993,6 +1088,7 @@ final case class MapAsyncUnordered[In, Out](parallelism: Int, f: In ⇒ Future[O
}
if (todo < parallelism) tryPull(in)
}
override def onUpstreamFinish(): Unit = {
if (todo == 0) completeStage()
}
@ -1000,6 +1096,7 @@ final case class MapAsyncUnordered[In, Out](parallelism: Int, f: In ⇒ Future[O
override def onPull(): Unit = {
if (!buffer.isEmpty) push(out, buffer.dequeue())
else if (isClosed(in) && todo == 0) completeStage()
if (todo < parallelism && !hasBeenPulled(in)) tryPull(in)
}