Converts the Scan-operation from PushPullStage to GraphStage
This commit is contained in:
parent
a1423b6e7d
commit
455805cda9
7 changed files with 54 additions and 80 deletions
|
|
@ -310,9 +310,7 @@ trait GraphInterpreterSpecKit extends AkkaSpec {
|
|||
abstract class OneBoundedSetup[T](_ops: GraphStageWithMaterializedValue[Shape, Any]*) extends Builder {
|
||||
val ops = _ops.toArray
|
||||
|
||||
def this(op: Seq[Stage[_, _]], dummy: Int = 42) = {
|
||||
this(op.map(_.toGS): _*)
|
||||
}
|
||||
def this(op: Seq[Stage[_, _]], dummy: Int = 42) = this(op.map(_.toGS): _*)
|
||||
|
||||
val upstream = new UpstreamOneBoundedProbe[T]
|
||||
val downstream = new DownstreamOneBoundedPortProbe[T]
|
||||
|
|
|
|||
|
|
@ -288,49 +288,6 @@ class InterpreterSupervisionSpec extends AkkaSpec with GraphInterpreterSpecKit {
|
|||
}
|
||||
}
|
||||
|
||||
"resume when Scan throws" in new OneBoundedSetup[Int](Seq(
|
||||
Scan(1, (acc: Int, x: Int) ⇒ if (x == 10) throw TE else acc + x, resumingDecider))) {
|
||||
downstream.requestOne()
|
||||
lastEvents() should be(Set(OnNext(1)))
|
||||
downstream.requestOne()
|
||||
lastEvents() should be(Set(RequestOne))
|
||||
upstream.onNext(2)
|
||||
lastEvents() should be(Set(OnNext(3)))
|
||||
|
||||
downstream.requestOne()
|
||||
lastEvents() should be(Set(RequestOne))
|
||||
upstream.onNext(10) // boom
|
||||
lastEvents() should be(Set(RequestOne))
|
||||
|
||||
upstream.onNext(4)
|
||||
lastEvents() should be(Set(OnNext(7))) // 1 + 2 + 4
|
||||
}
|
||||
|
||||
"restart when Scan throws" in new OneBoundedSetup[Int](Seq(
|
||||
Scan(1, (acc: Int, x: Int) ⇒ if (x == 10) throw TE else acc + x, restartingDecider))) {
|
||||
downstream.requestOne()
|
||||
lastEvents() should be(Set(OnNext(1)))
|
||||
downstream.requestOne()
|
||||
lastEvents() should be(Set(RequestOne))
|
||||
upstream.onNext(2)
|
||||
lastEvents() should be(Set(OnNext(3)))
|
||||
|
||||
downstream.requestOne()
|
||||
lastEvents() should be(Set(RequestOne))
|
||||
upstream.onNext(10) // boom
|
||||
lastEvents() should be(Set(RequestOne))
|
||||
|
||||
upstream.onNext(4)
|
||||
lastEvents() should be(Set(OnNext(1))) // starts over again
|
||||
|
||||
downstream.requestOne()
|
||||
lastEvents() should be(Set(OnNext(5)))
|
||||
downstream.requestOne()
|
||||
lastEvents() should be(Set(RequestOne))
|
||||
upstream.onNext(20)
|
||||
lastEvents() should be(Set(OnNext(25))) // 1 + 4 + 20
|
||||
}
|
||||
|
||||
"fail when Expand `seed` throws" in new OneBoundedSetup[Int](
|
||||
new Expand((in: Int) ⇒ if (in == 2) throw TE else Iterator(in) ++ Iterator.continually(-math.abs(in)))) {
|
||||
|
||||
|
|
|
|||
|
|
@ -37,8 +37,7 @@ class FlowScanSpec extends AkkaSpec {
|
|||
}
|
||||
|
||||
"Scan empty" in assertAllStagesStopped {
|
||||
val v = Vector.empty[Int]
|
||||
scan(Source(v)) should be(v.scan(0)(_ + _))
|
||||
scan(Source.empty[Int]) should be(Vector.empty[Int].scan(0)(_ + _))
|
||||
}
|
||||
|
||||
"emit values promptly" in {
|
||||
|
|
@ -46,7 +45,7 @@ class FlowScanSpec extends AkkaSpec {
|
|||
Await.result(f, 1.second) should ===(Seq(0, 1))
|
||||
}
|
||||
|
||||
"fail properly" in {
|
||||
"restart properly" in {
|
||||
import ActorAttributes._
|
||||
val scan = Flow[Int].scan(0) { (old, current) ⇒
|
||||
require(current > 0)
|
||||
|
|
@ -55,5 +54,15 @@ class FlowScanSpec extends AkkaSpec {
|
|||
Source(List(1, 3, -1, 5, 7)).via(scan).runWith(TestSink.probe)
|
||||
.toStrict(1.second) should ===(Seq(0, 1, 4, 0, 5, 12))
|
||||
}
|
||||
|
||||
"resume properly" in {
|
||||
import ActorAttributes._
|
||||
val scan = Flow[Int].scan(0) { (old, current) ⇒
|
||||
require(current > 0)
|
||||
old + current
|
||||
}.withAttributes(supervisionStrategy(Supervision.resumingDecider))
|
||||
Source(List(1, 3, -1, 5, 7)).via(scan).runWith(TestSink.probe)
|
||||
.toStrict(1.second) should ===(Seq(0, 1, 4, 9, 16))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ class FlowThrottleSpec extends AkkaSpec {
|
|||
}
|
||||
|
||||
"accept very high rates" in Utils.assertAllStagesStopped {
|
||||
Source(1 to 5).throttle(1, 1 nanos, 0, ThrottleMode.Shaping)
|
||||
Source(1 to 5).throttle(1, 1.nanos, 0, ThrottleMode.Shaping)
|
||||
.runWith(TestSink.probe[Int])
|
||||
.request(5)
|
||||
.expectNext(1, 2, 3, 4, 5)
|
||||
|
|
|
|||
|
|
@ -169,10 +169,6 @@ private[stream] object Stages {
|
|||
override def create(attr: Attributes): Stage[T, immutable.Seq[T]] = fusing.Sliding(n, step)
|
||||
}
|
||||
|
||||
final case class Scan[In, Out](zero: Out, f: (Out, In) ⇒ Out, attributes: Attributes = scan) extends SymbolicStage[In, Out] {
|
||||
override def create(attr: Attributes): Stage[In, Out] = fusing.Scan(zero, f, supervision(attr))
|
||||
}
|
||||
|
||||
final case class Fold[In, Out](zero: Out, f: (Out, In) ⇒ Out, attributes: Attributes = fold) extends SymbolicStage[In, Out] {
|
||||
override def create(attr: Attributes): Stage[In, Out] = fusing.Fold(zero, f, supervision(attr))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -275,33 +275,47 @@ private[akka] final case class Drop[T](count: Long) extends SimpleLinearGraphSta
|
|||
/**
|
||||
* INTERNAL API
|
||||
*/
|
||||
private[akka] final case class Scan[In, Out](zero: Out, f: (Out, In) ⇒ Out, decider: Supervision.Decider) extends PushPullStage[In, Out] {
|
||||
private[akka] final case class Scan[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends GraphStage[FlowShape[In, Out]] {
|
||||
override val shape = FlowShape[In, Out](Inlet("Scan.in"), Outlet("Scan.out"))
|
||||
|
||||
override def initialAttributes: Attributes = DefaultAttributes.scan
|
||||
override def toString: String = "Scan"
|
||||
|
||||
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
|
||||
new GraphStageLogic(shape) with InHandler with OutHandler {
|
||||
self ⇒
|
||||
|
||||
private var aggregator = zero
|
||||
private var pushedZero = false
|
||||
private lazy val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider)
|
||||
|
||||
override def onPush(elem: In, ctx: Context[Out]): SyncDirective = {
|
||||
if (pushedZero) {
|
||||
aggregator = f(aggregator, elem)
|
||||
ctx.push(aggregator)
|
||||
} else {
|
||||
aggregator = f(zero, elem)
|
||||
ctx.push(zero)
|
||||
import Supervision.{ Stop, Resume, Restart }
|
||||
import shape.{ in, out }
|
||||
|
||||
// Initial behavior makes sure that the zero gets flushed if upstream is empty
|
||||
setHandler(out, new OutHandler {
|
||||
override def onPull(): Unit = {
|
||||
push(out, aggregator)
|
||||
setHandlers(in, out, self)
|
||||
}
|
||||
})
|
||||
setHandler(in, totallyIgnorantInput)
|
||||
|
||||
override def onPull(): Unit = pull(in)
|
||||
override def onPush(): Unit = {
|
||||
try {
|
||||
aggregator = f(aggregator, grab(in))
|
||||
push(out, aggregator)
|
||||
} catch {
|
||||
case NonFatal(ex) ⇒ decider(ex) match {
|
||||
case Resume ⇒ if (!hasBeenPulled(in)) pull(in)
|
||||
case Stop ⇒ failStage(ex)
|
||||
case Restart ⇒
|
||||
aggregator = zero
|
||||
push(out, aggregator)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def onPull(ctx: Context[Out]): SyncDirective =
|
||||
if (!pushedZero) {
|
||||
pushedZero = true
|
||||
if (ctx.isFinishing) ctx.pushAndFinish(aggregator) else ctx.push(aggregator)
|
||||
} else ctx.pull()
|
||||
|
||||
override def onUpstreamFinish(ctx: Context[Out]): TerminationDirective =
|
||||
if (pushedZero) ctx.finish()
|
||||
else ctx.absorbTermination()
|
||||
|
||||
override def decide(t: Throwable): Supervision.Directive = decider(t)
|
||||
|
||||
override def restart(): Scan[In, Out] = copy()
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -745,7 +745,7 @@ trait FlowOps[+Out, +Mat] {
|
|||
*
|
||||
* '''Cancels when''' downstream cancels
|
||||
*/
|
||||
def scan[T](zero: T)(f: (T, Out) ⇒ T): Repr[T] = andThen(Scan(zero, f))
|
||||
def scan[T](zero: T)(f: (T, Out) ⇒ T): Repr[T] = via(Scan(zero, f))
|
||||
|
||||
/**
|
||||
* Similar to `scan` but only emits its result when the upstream completes,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue