=str 19834 Migrating PushStages to GraphStage

Collect, DropWhile, LimitWeighted
This commit is contained in:
Alexander Golubev 2016-03-09 20:46:42 -05:00
parent b3444fa1b0
commit 951afec88e
7 changed files with 144 additions and 64 deletions

View file

@ -302,27 +302,6 @@ class InterpreterSupervisionSpec extends AkkaSpec with GraphInterpreterSpecKit {
lastEvents() should be(Set(OnNext(3)))
}
"restart when Collect throws" in {
// TODO can't get type inference to work with `pf` inlined
val pf: PartialFunction[Int, Int] =
{ case x: Int if (x == 0) throw TE else x }
new OneBoundedSetup[Int](Seq(
Collect(pf, restartingDecider))) {
downstream.requestOne()
lastEvents() should be(Set(RequestOne))
upstream.onNext(2)
lastEvents() should be(Set(OnNext(2)))
downstream.requestOne()
lastEvents() should be(Set(RequestOne))
upstream.onNext(0) // boom
lastEvents() should be(Set(RequestOne))
upstream.onNext(3)
lastEvents() should be(Set(OnNext(3)))
}
}
"resume when Scan throws" in new OneBoundedSetup[Int](Seq(
Scan(1, (acc: Int, x: Int) if (x == 10) throw TE else acc + x, resumingDecider))) {
downstream.requestOne()

View file

@ -3,15 +3,24 @@
*/
package akka.stream.scaladsl
import akka.stream.ActorAttributes._
import akka.stream.Supervision._
import akka.stream.impl.ConstantFun
import akka.stream.testkit.Utils.TE
import akka.stream.testkit.scaladsl.TestSink
import scala.concurrent.forkjoin.ThreadLocalRandom.{ current random }
import akka.stream.ActorMaterializerSettings
import akka.stream.{ ActorMaterializer, ActorMaterializerSettings }
import akka.testkit.AkkaSpec
import akka.stream.testkit.ScriptedTest
import akka.stream.testkit.{ TestSubscriber, ScriptedTest }
import scala.util.control.NoStackTrace
class FlowCollectSpec extends AkkaSpec with ScriptedTest {
val settings = ActorMaterializerSettings(system)
implicit val materializer = ActorMaterializer(settings)
"A Collect" must {
@ -23,6 +32,19 @@ class FlowCollectSpec extends AkkaSpec with ScriptedTest {
TestConfig.RandomTestRange foreach (_ runScript(script, settings)(_.collect { case x if x % 2 == 0 (x * x).toString }))
}
"restart when Collect throws" in {
val pf: PartialFunction[Int, Int] =
{ case x: Int if (x == 2) throw TE("") else x }
Source(1 to 3).collect(pf).withAttributes(supervisionStrategy(restartingDecider))
.runWith(TestSink.probe[Int])
.request(1)
.expectNext(1)
.request(1)
.expectNext(3)
.request(1)
.expectComplete()
}
}
}

View file

@ -35,13 +35,24 @@ class FlowDropWhileSpec extends AkkaSpec {
}
"continue if error" in assertAllStagesStopped {
val testException = new Exception("test") with NoStackTrace
Source(1 to 4).dropWhile(a if (a < 3) true else throw testException).withAttributes(supervisionStrategy(resumingDecider))
Source(1 to 4).dropWhile(a if (a < 3) true else throw TE("")).withAttributes(supervisionStrategy(resumingDecider))
.runWith(TestSink.probe[Int])
.request(1)
.expectComplete()
}
"restart with strategy" in assertAllStagesStopped {
Source(1 to 4).dropWhile {
case 1 | 3 true
case 4 false
case 2 throw TE("")
}.withAttributes(supervisionStrategy(restartingDecider))
.runWith(TestSink.probe[Int])
.request(1)
.expectNext(4)
.expectComplete()
}
}
}

View file

@ -155,10 +155,6 @@ private[stream] object Stages {
override def create(attr: Attributes): Stage[T, T] = fusing.Filter(p, supervision(attr))
}
final case class Collect[In, Out](pf: PartialFunction[In, Out], attributes: Attributes = collect) extends SymbolicStage[In, Out] {
override def create(attr: Attributes): Stage[In, Out] = fusing.Collect(pf, supervision(attr))
}
final case class Recover[In, Out >: In](pf: PartialFunction[Throwable, Out], attributes: Attributes = recover) extends SymbolicStage[In, Out] {
override def create(attr: Attributes): Stage[In, Out] = fusing.Recover(pf)
}
@ -168,10 +164,6 @@ private[stream] object Stages {
override def create(attr: Attributes): Stage[T, immutable.Seq[T]] = fusing.Grouped(n)
}
final case class LimitWeighted[T](max: Long, weightFn: T Long, attributes: Attributes = limitWeighted) extends SymbolicStage[T, T] {
override def create(attr: Attributes): Stage[T, T] = fusing.LimitWeighted(max, weightFn)
}
final case class Sliding[T](n: Int, step: Int, attributes: Attributes = sliding) extends SymbolicStage[T, immutable.Seq[T]] {
require(n > 0, "n must be greater than 0")
require(step > 0, "step must be greater than 0")
@ -183,10 +175,6 @@ private[stream] object Stages {
override def create(attr: Attributes): Stage[T, T] = fusing.TakeWhile(p, supervision(attr))
}
final case class DropWhile[T](p: T Boolean, attributes: Attributes = dropWhile) extends SymbolicStage[T, T] {
override def create(attr: Attributes): Stage[T, T] = fusing.DropWhile(p, supervision(attr))
}
final case class Scan[In, Out](zero: Out, f: (Out, In) Out, attributes: Attributes = scan) extends SymbolicStage[In, Out] {
override def create(attr: Attributes): Stage[In, Out] = fusing.Scan(zero, f, supervision(attr))
}

View file

@ -59,38 +59,89 @@ private[akka] final case class TakeWhile[T](p: T ⇒ Boolean, decider: Supervisi
/**
* INTERNAL API
*/
private[akka] final case class DropWhile[T](p: T Boolean, decider: Supervision.Decider) extends PushStage[T, T] {
var taking = false
private[stream] final case class DropWhile[T](p: T Boolean) extends GraphStage[FlowShape[T, T]] {
val in = Inlet[T]("DropWhile.in")
val out = Outlet[T]("DropWhile.out")
override val shape = FlowShape(in, out)
override def initialAttributes: Attributes = DefaultAttributes.dropWhile
override def onPush(elem: T, ctx: Context[T]): SyncDirective =
if (taking || !p(elem)) {
taking = true
ctx.push(elem)
} else {
ctx.pull()
def createLogic(inheritedAttributes: Attributes) = new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler {
override def onPush(): Unit = {
val elem = grab(in)
withSupervision(() p(elem)) match {
case Some(flag) if flag pull(in)
case Some(flag) if !flag
push(out, elem)
setHandler(in, rest)
case None // do nothing
}
}
override def decide(t: Throwable): Supervision.Directive = decider(t)
def rest = new InHandler {
def onPush() = push(out, grab(in))
}
override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in)
override def onPull(): Unit = pull(in)
setHandlers(in, out, this)
}
override def toString = "DropWhile"
}
private[akka] object Collect {
/**
* INTERNAL API
*/
abstract private[stream] class SupervisedGraphStageLogic(inheritedAttributes: Attributes, shape: Shape) extends GraphStageLogic(shape) {
private lazy val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider)
def withSupervision[T](f: () T): Option[T] =
try { Some(f()) } catch {
case NonFatal(ex)
decider(ex) match {
case Supervision.Stop onStop(ex)
case Supervision.Resume onResume(ex)
case Supervision.Restart onRestart(ex)
}
None
}
def onResume(t: Throwable): Unit
def onStop(t: Throwable): Unit = failStage(t)
def onRestart(t: Throwable): Unit = onResume(t)
}
private[stream] object Collect {
// Cached function that can be used with PartialFunction.applyOrElse to ensure that A) the guard is only applied once,
// and the caller can check the returned value with Collect.notApplied to query whether the PF was applied or not.
// Prior art: https://github.com/scala/scala/blob/v2.11.4/src/library/scala/collection/immutable/List.scala#L458
final val NotApplied: Any Any = _ Collect.NotApplied
}
private[akka] final case class Collect[In, Out](pf: PartialFunction[In, Out], decider: Supervision.Decider) extends PushStage[In, Out] {
/**
* INTERNAL API
*/
private[stream] final case class Collect[In, Out](pf: PartialFunction[In, Out]) extends GraphStage[FlowShape[In, Out]] {
val in = Inlet[In]("Collect.in")
val out = Outlet[Out]("Collect.out")
override val shape = FlowShape(in, out)
override def initialAttributes: Attributes = DefaultAttributes.collect
import Collect.NotApplied
def createLogic(inheritedAttributes: Attributes) = new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler {
import Collect.NotApplied
val wrappedPf = () pf.applyOrElse(grab(in), NotApplied)
override def onPush(elem: In, ctx: Context[Out]): SyncDirective =
pf.applyOrElse(elem, NotApplied) match {
case NotApplied ctx.pull()
case result: Out @unchecked ctx.push(result)
override def onPush(): Unit = withSupervision(wrappedPf) match {
case Some(result) result match {
case NotApplied pull(in)
case result: Out @unchecked push(out, result)
}
case None //do nothing
}
override def decide(t: Throwable): Supervision.Directive = decider(t)
override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in)
override def onPull(): Unit = pull(in)
setHandlers(in, out, this)
}
override def toString = "Collect"
}
/**
@ -312,15 +363,33 @@ private[akka] final case class Grouped[T](n: Int) extends PushPullStage[T, immut
/**
* INTERNAL API
*/
private[stream] final case class LimitWeighted[T](n: Long, costFn: T Long) extends GraphStage[FlowShape[T, T]] {
val in = Inlet[T]("LimitWeighted.in")
val out = Outlet[T]("LimitWeighted.out")
override val shape = FlowShape(in, out)
override def initialAttributes: Attributes = DefaultAttributes.limitWeighted
private[akka] final case class LimitWeighted[T](n: Long, costFn: T Long) extends PushStage[T, T] {
private var left = n
def createLogic(inheritedAttributes: Attributes) = new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler {
private var left = n
override def onPush(elem: T, ctx: Context[T]): SyncDirective = {
left -= costFn(elem)
if (left >= 0) ctx.push(elem)
else ctx.fail(new StreamLimitReachedException(n))
override def onPush(): Unit = {
val elem = grab(in)
withSupervision(() costFn(elem)) match {
case Some(wight)
left -= wight
if (left >= 0) push(out, elem) else failStage(new StreamLimitReachedException(n))
case None //do nothing
}
}
override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in)
override def onRestart(t: Throwable): Unit = {
left = n
if (!hasBeenPulled(in)) pull(in)
}
override def onPull(): Unit = pull(in)
setHandlers(in, out, this)
}
override def toString = "LimitWeighted"
}
/**

View file

@ -627,7 +627,7 @@ trait FlowOps[+Out, +Mat] {
*
* '''Cancels when''' downstream cancels
*/
def dropWhile(p: Out Boolean): Repr[Out] = andThen(DropWhile(p))
def dropWhile(p: Out Boolean): Repr[Out] = via(DropWhile(p))
/**
* Transform this stream by applying the given partial function to each of the elements
@ -642,7 +642,7 @@ trait FlowOps[+Out, +Mat] {
*
* '''Cancels when''' downstream cancels
*/
def collect[T](pf: PartialFunction[Out, T]): Repr[T] = andThen(Collect(pf))
def collect[T](pf: PartialFunction[Out, T]): Repr[T] = via(Collect(pf))
/**
* Chunk up this stream into groups of the given size, with the last group
@ -705,7 +705,7 @@ trait FlowOps[+Out, +Mat] {
*
* See also [[FlowOps.take]], [[FlowOps.takeWithin]], [[FlowOps.takeWhile]]
*/
def limitWeighted[T](max: Long)(costFn: Out Long): Repr[Out] = andThen(LimitWeighted(max, costFn))
def limitWeighted[T](max: Long)(costFn: Out Long): Repr[Out] = via(LimitWeighted(max, costFn))
/**
* Apply a sliding window over the stream and return the windows as groups of elements, with the last group

View file

@ -754,7 +754,18 @@ object MiMa extends AutoPlugin {
// #20028 Simplify TickSource cancellation
ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.fusing.GraphStages$TickSource$TickSourceCancellable"),
ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.fusing.GraphStages$TickSource$")
ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.fusing.GraphStages$TickSource$"),
// #19834 replacing PushStages usages with GraphStages
ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.Stages$LimitWeighted"),
ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.Stages$Collect$"),
ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.Stages$DropWhile"),
ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.Stages$LimitWeighted$"),
ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.Stages$Collect"),
ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.Stages$DropWhile$"),
FilterAnyProblemStartingWith("akka.stream.impl.fusing.Collect"),
FilterAnyProblemStartingWith("akka.stream.impl.fusing.DropWhile"),
FilterAnyProblemStartingWith("akka.stream.impl.fusing.LimitWeighted")
)
)
}