Converts the Scan-operation from PushPullStage to GraphStage

This commit is contained in:
Viktor Klang 2016-04-14 17:33:19 +02:00
parent a1423b6e7d
commit 455805cda9
7 changed files with 54 additions and 80 deletions

View file

@ -310,9 +310,7 @@ trait GraphInterpreterSpecKit extends AkkaSpec {
abstract class OneBoundedSetup[T](_ops: GraphStageWithMaterializedValue[Shape, Any]*) extends Builder {
val ops = _ops.toArray
def this(op: Seq[Stage[_, _]], dummy: Int = 42) = {
this(op.map(_.toGS): _*)
}
def this(op: Seq[Stage[_, _]], dummy: Int = 42) = this(op.map(_.toGS): _*)
val upstream = new UpstreamOneBoundedProbe[T]
val downstream = new DownstreamOneBoundedPortProbe[T]

View file

@ -288,49 +288,6 @@ class InterpreterSupervisionSpec extends AkkaSpec with GraphInterpreterSpecKit {
}
}
"resume when Scan throws" in new OneBoundedSetup[Int](Seq(
Scan(1, (acc: Int, x: Int) if (x == 10) throw TE else acc + x, resumingDecider))) {
downstream.requestOne()
lastEvents() should be(Set(OnNext(1)))
downstream.requestOne()
lastEvents() should be(Set(RequestOne))
upstream.onNext(2)
lastEvents() should be(Set(OnNext(3)))
downstream.requestOne()
lastEvents() should be(Set(RequestOne))
upstream.onNext(10) // boom
lastEvents() should be(Set(RequestOne))
upstream.onNext(4)
lastEvents() should be(Set(OnNext(7))) // 1 + 2 + 4
}
"restart when Scan throws" in new OneBoundedSetup[Int](Seq(
Scan(1, (acc: Int, x: Int) if (x == 10) throw TE else acc + x, restartingDecider))) {
downstream.requestOne()
lastEvents() should be(Set(OnNext(1)))
downstream.requestOne()
lastEvents() should be(Set(RequestOne))
upstream.onNext(2)
lastEvents() should be(Set(OnNext(3)))
downstream.requestOne()
lastEvents() should be(Set(RequestOne))
upstream.onNext(10) // boom
lastEvents() should be(Set(RequestOne))
upstream.onNext(4)
lastEvents() should be(Set(OnNext(1))) // starts over again
downstream.requestOne()
lastEvents() should be(Set(OnNext(5)))
downstream.requestOne()
lastEvents() should be(Set(RequestOne))
upstream.onNext(20)
lastEvents() should be(Set(OnNext(25))) // 1 + 4 + 20
}
"fail when Expand `seed` throws" in new OneBoundedSetup[Int](
new Expand((in: Int) if (in == 2) throw TE else Iterator(in) ++ Iterator.continually(-math.abs(in)))) {

View file

@ -37,8 +37,7 @@ class FlowScanSpec extends AkkaSpec {
}
"Scan empty" in assertAllStagesStopped {
val v = Vector.empty[Int]
scan(Source(v)) should be(v.scan(0)(_ + _))
scan(Source.empty[Int]) should be(Vector.empty[Int].scan(0)(_ + _))
}
"emit values promptly" in {
@ -46,7 +45,7 @@ class FlowScanSpec extends AkkaSpec {
Await.result(f, 1.second) should ===(Seq(0, 1))
}
"fail properly" in {
"restart properly" in {
import ActorAttributes._
val scan = Flow[Int].scan(0) { (old, current)
require(current > 0)
@ -55,5 +54,15 @@ class FlowScanSpec extends AkkaSpec {
Source(List(1, 3, -1, 5, 7)).via(scan).runWith(TestSink.probe)
.toStrict(1.second) should ===(Seq(0, 1, 4, 0, 5, 12))
}
"resume properly" in {
import ActorAttributes._
val scan = Flow[Int].scan(0) { (old, current)
require(current > 0)
old + current
}.withAttributes(supervisionStrategy(Supervision.resumingDecider))
Source(List(1, 3, -1, 5, 7)).via(scan).runWith(TestSink.probe)
.toStrict(1.second) should ===(Seq(0, 1, 4, 9, 16))
}
}
}

View file

@ -30,7 +30,7 @@ class FlowThrottleSpec extends AkkaSpec {
}
"accept very high rates" in Utils.assertAllStagesStopped {
Source(1 to 5).throttle(1, 1 nanos, 0, ThrottleMode.Shaping)
Source(1 to 5).throttle(1, 1.nanos, 0, ThrottleMode.Shaping)
.runWith(TestSink.probe[Int])
.request(5)
.expectNext(1, 2, 3, 4, 5)

View file

@ -169,10 +169,6 @@ private[stream] object Stages {
override def create(attr: Attributes): Stage[T, immutable.Seq[T]] = fusing.Sliding(n, step)
}
final case class Scan[In, Out](zero: Out, f: (Out, In) Out, attributes: Attributes = scan) extends SymbolicStage[In, Out] {
override def create(attr: Attributes): Stage[In, Out] = fusing.Scan(zero, f, supervision(attr))
}
final case class Fold[In, Out](zero: Out, f: (Out, In) Out, attributes: Attributes = fold) extends SymbolicStage[In, Out] {
override def create(attr: Attributes): Stage[In, Out] = fusing.Fold(zero, f, supervision(attr))
}

View file

@ -275,33 +275,47 @@ private[akka] final case class Drop[T](count: Long) extends SimpleLinearGraphSta
/**
* INTERNAL API
*/
private[akka] final case class Scan[In, Out](zero: Out, f: (Out, In) Out, decider: Supervision.Decider) extends PushPullStage[In, Out] {
private[akka] final case class Scan[In, Out](zero: Out, f: (Out, In) Out) extends GraphStage[FlowShape[In, Out]] {
override val shape = FlowShape[In, Out](Inlet("Scan.in"), Outlet("Scan.out"))
override def initialAttributes: Attributes = DefaultAttributes.scan
override def toString: String = "Scan"
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
new GraphStageLogic(shape) with InHandler with OutHandler {
self
private var aggregator = zero
private var pushedZero = false
private lazy val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider)
override def onPush(elem: In, ctx: Context[Out]): SyncDirective = {
if (pushedZero) {
aggregator = f(aggregator, elem)
ctx.push(aggregator)
} else {
aggregator = f(zero, elem)
ctx.push(zero)
import Supervision.{ Stop, Resume, Restart }
import shape.{ in, out }
// Initial behavior makes sure that the zero gets flushed if upstream is empty
setHandler(out, new OutHandler {
override def onPull(): Unit = {
push(out, aggregator)
setHandlers(in, out, self)
}
})
setHandler(in, totallyIgnorantInput)
override def onPull(): Unit = pull(in)
override def onPush(): Unit = {
try {
aggregator = f(aggregator, grab(in))
push(out, aggregator)
} catch {
case NonFatal(ex) decider(ex) match {
case Resume if (!hasBeenPulled(in)) pull(in)
case Stop failStage(ex)
case Restart
aggregator = zero
push(out, aggregator)
}
}
}
}
override def onPull(ctx: Context[Out]): SyncDirective =
if (!pushedZero) {
pushedZero = true
if (ctx.isFinishing) ctx.pushAndFinish(aggregator) else ctx.push(aggregator)
} else ctx.pull()
override def onUpstreamFinish(ctx: Context[Out]): TerminationDirective =
if (pushedZero) ctx.finish()
else ctx.absorbTermination()
override def decide(t: Throwable): Supervision.Directive = decider(t)
override def restart(): Scan[In, Out] = copy()
}
/**

View file

@ -745,7 +745,7 @@ trait FlowOps[+Out, +Mat] {
*
* '''Cancels when''' downstream cancels
*/
def scan[T](zero: T)(f: (T, Out) T): Repr[T] = andThen(Scan(zero, f))
def scan[T](zero: T)(f: (T, Out) T): Repr[T] = via(Scan(zero, f))
/**
* Similar to `scan` but only emits its result when the upstream completes,