Merge pull request #20325 from akka/wip-convert-scan-to-graphstage-√

Converts the Scan-operation from PushPullStage to GraphStage
This commit is contained in:
Patrik Nordwall 2016-04-20 15:37:16 +02:00
commit dfbcd948bf
7 changed files with 54 additions and 80 deletions

View file

@ -310,9 +310,7 @@ trait GraphInterpreterSpecKit extends AkkaSpec {
abstract class OneBoundedSetup[T](_ops: GraphStageWithMaterializedValue[Shape, Any]*) extends Builder { abstract class OneBoundedSetup[T](_ops: GraphStageWithMaterializedValue[Shape, Any]*) extends Builder {
val ops = _ops.toArray val ops = _ops.toArray
def this(op: Seq[Stage[_, _]], dummy: Int = 42) = { def this(op: Seq[Stage[_, _]], dummy: Int = 42) = this(op.map(_.toGS): _*)
this(op.map(_.toGS): _*)
}
val upstream = new UpstreamOneBoundedProbe[T] val upstream = new UpstreamOneBoundedProbe[T]
val downstream = new DownstreamOneBoundedPortProbe[T] val downstream = new DownstreamOneBoundedPortProbe[T]

View file

@ -288,49 +288,6 @@ class InterpreterSupervisionSpec extends AkkaSpec with GraphInterpreterSpecKit {
} }
} }
"resume when Scan throws" in new OneBoundedSetup[Int](Seq(
Scan(1, (acc: Int, x: Int) if (x == 10) throw TE else acc + x, resumingDecider))) {
downstream.requestOne()
lastEvents() should be(Set(OnNext(1)))
downstream.requestOne()
lastEvents() should be(Set(RequestOne))
upstream.onNext(2)
lastEvents() should be(Set(OnNext(3)))
downstream.requestOne()
lastEvents() should be(Set(RequestOne))
upstream.onNext(10) // boom
lastEvents() should be(Set(RequestOne))
upstream.onNext(4)
lastEvents() should be(Set(OnNext(7))) // 1 + 2 + 4
}
"restart when Scan throws" in new OneBoundedSetup[Int](Seq(
Scan(1, (acc: Int, x: Int) if (x == 10) throw TE else acc + x, restartingDecider))) {
downstream.requestOne()
lastEvents() should be(Set(OnNext(1)))
downstream.requestOne()
lastEvents() should be(Set(RequestOne))
upstream.onNext(2)
lastEvents() should be(Set(OnNext(3)))
downstream.requestOne()
lastEvents() should be(Set(RequestOne))
upstream.onNext(10) // boom
lastEvents() should be(Set(RequestOne))
upstream.onNext(4)
lastEvents() should be(Set(OnNext(1))) // starts over again
downstream.requestOne()
lastEvents() should be(Set(OnNext(5)))
downstream.requestOne()
lastEvents() should be(Set(RequestOne))
upstream.onNext(20)
lastEvents() should be(Set(OnNext(25))) // 1 + 4 + 20
}
"fail when Expand `seed` throws" in new OneBoundedSetup[Int]( "fail when Expand `seed` throws" in new OneBoundedSetup[Int](
new Expand((in: Int) if (in == 2) throw TE else Iterator(in) ++ Iterator.continually(-math.abs(in)))) { new Expand((in: Int) if (in == 2) throw TE else Iterator(in) ++ Iterator.continually(-math.abs(in)))) {

View file

@ -37,8 +37,7 @@ class FlowScanSpec extends AkkaSpec {
} }
"Scan empty" in assertAllStagesStopped { "Scan empty" in assertAllStagesStopped {
val v = Vector.empty[Int] scan(Source.empty[Int]) should be(Vector.empty[Int].scan(0)(_ + _))
scan(Source(v)) should be(v.scan(0)(_ + _))
} }
"emit values promptly" in { "emit values promptly" in {
@ -46,7 +45,7 @@ class FlowScanSpec extends AkkaSpec {
Await.result(f, 1.second) should ===(Seq(0, 1)) Await.result(f, 1.second) should ===(Seq(0, 1))
} }
"fail properly" in { "restart properly" in {
import ActorAttributes._ import ActorAttributes._
val scan = Flow[Int].scan(0) { (old, current) val scan = Flow[Int].scan(0) { (old, current)
require(current > 0) require(current > 0)
@ -55,5 +54,15 @@ class FlowScanSpec extends AkkaSpec {
Source(List(1, 3, -1, 5, 7)).via(scan).runWith(TestSink.probe) Source(List(1, 3, -1, 5, 7)).via(scan).runWith(TestSink.probe)
.toStrict(1.second) should ===(Seq(0, 1, 4, 0, 5, 12)) .toStrict(1.second) should ===(Seq(0, 1, 4, 0, 5, 12))
} }
"resume properly" in {
import ActorAttributes._
val scan = Flow[Int].scan(0) { (old, current)
require(current > 0)
old + current
}.withAttributes(supervisionStrategy(Supervision.resumingDecider))
Source(List(1, 3, -1, 5, 7)).via(scan).runWith(TestSink.probe)
.toStrict(1.second) should ===(Seq(0, 1, 4, 9, 16))
}
} }
} }

View file

@ -30,7 +30,7 @@ class FlowThrottleSpec extends AkkaSpec {
} }
"accept very high rates" in Utils.assertAllStagesStopped { "accept very high rates" in Utils.assertAllStagesStopped {
Source(1 to 5).throttle(1, 1 nanos, 0, ThrottleMode.Shaping) Source(1 to 5).throttle(1, 1.nanos, 0, ThrottleMode.Shaping)
.runWith(TestSink.probe[Int]) .runWith(TestSink.probe[Int])
.request(5) .request(5)
.expectNext(1, 2, 3, 4, 5) .expectNext(1, 2, 3, 4, 5)

View file

@ -165,10 +165,6 @@ private[stream] object Stages {
override def create(attr: Attributes): Stage[T, immutable.Seq[T]] = fusing.Sliding(n, step) override def create(attr: Attributes): Stage[T, immutable.Seq[T]] = fusing.Sliding(n, step)
} }
final case class Scan[In, Out](zero: Out, f: (Out, In) Out, attributes: Attributes = scan) extends SymbolicStage[In, Out] {
override def create(attr: Attributes): Stage[In, Out] = fusing.Scan(zero, f, supervision(attr))
}
final case class Fold[In, Out](zero: Out, f: (Out, In) Out, attributes: Attributes = fold) extends SymbolicStage[In, Out] { final case class Fold[In, Out](zero: Out, f: (Out, In) Out, attributes: Attributes = fold) extends SymbolicStage[In, Out] {
override def create(attr: Attributes): Stage[In, Out] = fusing.Fold(zero, f, supervision(attr)) override def create(attr: Attributes): Stage[In, Out] = fusing.Fold(zero, f, supervision(attr))
} }

View file

@ -294,33 +294,47 @@ private[akka] final case class Drop[T](count: Long) extends SimpleLinearGraphSta
/** /**
* INTERNAL API * INTERNAL API
*/ */
private[akka] final case class Scan[In, Out](zero: Out, f: (Out, In) Out, decider: Supervision.Decider) extends PushPullStage[In, Out] { private[akka] final case class Scan[In, Out](zero: Out, f: (Out, In) Out) extends GraphStage[FlowShape[In, Out]] {
private var aggregator = zero override val shape = FlowShape[In, Out](Inlet("Scan.in"), Outlet("Scan.out"))
private var pushedZero = false
override def onPush(elem: In, ctx: Context[Out]): SyncDirective = { override def initialAttributes: Attributes = DefaultAttributes.scan
if (pushedZero) { override def toString: String = "Scan"
aggregator = f(aggregator, elem)
ctx.push(aggregator) override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
} else { new GraphStageLogic(shape) with InHandler with OutHandler {
aggregator = f(zero, elem) self
ctx.push(zero)
private var aggregator = zero
private lazy val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider)
import Supervision.{ Stop, Resume, Restart }
import shape.{ in, out }
// Initial behavior makes sure that the zero gets flushed if upstream is empty
setHandler(out, new OutHandler {
override def onPull(): Unit = {
push(out, aggregator)
setHandlers(in, out, self)
}
})
setHandler(in, totallyIgnorantInput)
override def onPull(): Unit = pull(in)
override def onPush(): Unit = {
try {
aggregator = f(aggregator, grab(in))
push(out, aggregator)
} catch {
case NonFatal(ex) decider(ex) match {
case Resume if (!hasBeenPulled(in)) pull(in)
case Stop failStage(ex)
case Restart
aggregator = zero
push(out, aggregator)
}
}
}
} }
}
override def onPull(ctx: Context[Out]): SyncDirective =
if (!pushedZero) {
pushedZero = true
if (ctx.isFinishing) ctx.pushAndFinish(aggregator) else ctx.push(aggregator)
} else ctx.pull()
override def onUpstreamFinish(ctx: Context[Out]): TerminationDirective =
if (pushedZero) ctx.finish()
else ctx.absorbTermination()
override def decide(t: Throwable): Supervision.Directive = decider(t)
override def restart(): Scan[In, Out] = copy()
} }
/** /**

View file

@ -768,7 +768,7 @@ trait FlowOps[+Out, +Mat] {
* *
* '''Cancels when''' downstream cancels * '''Cancels when''' downstream cancels
*/ */
def scan[T](zero: T)(f: (T, Out) T): Repr[T] = andThen(Scan(zero, f)) def scan[T](zero: T)(f: (T, Out) T): Repr[T] = via(Scan(zero, f))
/** /**
* Similar to `scan` but only emits its result when the upstream completes, * Similar to `scan` but only emits its result when the upstream completes,