=str 19834 Migrating PushStages to GraphStage

Collect, DropWhile, LimitWeighted
This commit is contained in:
Alexander Golubev 2016-03-09 20:46:42 -05:00
parent b3444fa1b0
commit 951afec88e
7 changed files with 144 additions and 64 deletions

View file

@ -302,27 +302,6 @@ class InterpreterSupervisionSpec extends AkkaSpec with GraphInterpreterSpecKit {
lastEvents() should be(Set(OnNext(3))) lastEvents() should be(Set(OnNext(3)))
} }
"restart when Collect throws" in {
// TODO can't get type inference to work with `pf` inlined
val pf: PartialFunction[Int, Int] =
{ case x: Int if (x == 0) throw TE else x }
new OneBoundedSetup[Int](Seq(
Collect(pf, restartingDecider))) {
downstream.requestOne()
lastEvents() should be(Set(RequestOne))
upstream.onNext(2)
lastEvents() should be(Set(OnNext(2)))
downstream.requestOne()
lastEvents() should be(Set(RequestOne))
upstream.onNext(0) // boom
lastEvents() should be(Set(RequestOne))
upstream.onNext(3)
lastEvents() should be(Set(OnNext(3)))
}
}
"resume when Scan throws" in new OneBoundedSetup[Int](Seq( "resume when Scan throws" in new OneBoundedSetup[Int](Seq(
Scan(1, (acc: Int, x: Int) if (x == 10) throw TE else acc + x, resumingDecider))) { Scan(1, (acc: Int, x: Int) if (x == 10) throw TE else acc + x, resumingDecider))) {
downstream.requestOne() downstream.requestOne()

View file

@ -3,15 +3,24 @@
*/ */
package akka.stream.scaladsl package akka.stream.scaladsl
import akka.stream.ActorAttributes._
import akka.stream.Supervision._
import akka.stream.impl.ConstantFun
import akka.stream.testkit.Utils.TE
import akka.stream.testkit.scaladsl.TestSink
import scala.concurrent.forkjoin.ThreadLocalRandom.{ current random } import scala.concurrent.forkjoin.ThreadLocalRandom.{ current random }
import akka.stream.ActorMaterializerSettings import akka.stream.{ ActorMaterializer, ActorMaterializerSettings }
import akka.testkit.AkkaSpec import akka.testkit.AkkaSpec
import akka.stream.testkit.ScriptedTest import akka.stream.testkit.{ TestSubscriber, ScriptedTest }
import scala.util.control.NoStackTrace
class FlowCollectSpec extends AkkaSpec with ScriptedTest { class FlowCollectSpec extends AkkaSpec with ScriptedTest {
val settings = ActorMaterializerSettings(system) val settings = ActorMaterializerSettings(system)
implicit val materializer = ActorMaterializer(settings)
"A Collect" must { "A Collect" must {
@ -23,6 +32,19 @@ class FlowCollectSpec extends AkkaSpec with ScriptedTest {
TestConfig.RandomTestRange foreach (_ runScript(script, settings)(_.collect { case x if x % 2 == 0 (x * x).toString })) TestConfig.RandomTestRange foreach (_ runScript(script, settings)(_.collect { case x if x % 2 == 0 (x * x).toString }))
} }
"restart when Collect throws" in {
val pf: PartialFunction[Int, Int] =
{ case x: Int if (x == 2) throw TE("") else x }
Source(1 to 3).collect(pf).withAttributes(supervisionStrategy(restartingDecider))
.runWith(TestSink.probe[Int])
.request(1)
.expectNext(1)
.request(1)
.expectNext(3)
.request(1)
.expectComplete()
}
} }
} }

View file

@ -35,13 +35,24 @@ class FlowDropWhileSpec extends AkkaSpec {
} }
"continue if error" in assertAllStagesStopped { "continue if error" in assertAllStagesStopped {
val testException = new Exception("test") with NoStackTrace Source(1 to 4).dropWhile(a if (a < 3) true else throw TE("")).withAttributes(supervisionStrategy(resumingDecider))
Source(1 to 4).dropWhile(a if (a < 3) true else throw testException).withAttributes(supervisionStrategy(resumingDecider))
.runWith(TestSink.probe[Int]) .runWith(TestSink.probe[Int])
.request(1) .request(1)
.expectComplete() .expectComplete()
} }
"restart with strategy" in assertAllStagesStopped {
Source(1 to 4).dropWhile {
case 1 | 3 true
case 4 false
case 2 throw TE("")
}.withAttributes(supervisionStrategy(restartingDecider))
.runWith(TestSink.probe[Int])
.request(1)
.expectNext(4)
.expectComplete()
}
} }
} }

View file

@ -155,10 +155,6 @@ private[stream] object Stages {
override def create(attr: Attributes): Stage[T, T] = fusing.Filter(p, supervision(attr)) override def create(attr: Attributes): Stage[T, T] = fusing.Filter(p, supervision(attr))
} }
final case class Collect[In, Out](pf: PartialFunction[In, Out], attributes: Attributes = collect) extends SymbolicStage[In, Out] {
override def create(attr: Attributes): Stage[In, Out] = fusing.Collect(pf, supervision(attr))
}
final case class Recover[In, Out >: In](pf: PartialFunction[Throwable, Out], attributes: Attributes = recover) extends SymbolicStage[In, Out] { final case class Recover[In, Out >: In](pf: PartialFunction[Throwable, Out], attributes: Attributes = recover) extends SymbolicStage[In, Out] {
override def create(attr: Attributes): Stage[In, Out] = fusing.Recover(pf) override def create(attr: Attributes): Stage[In, Out] = fusing.Recover(pf)
} }
@ -168,10 +164,6 @@ private[stream] object Stages {
override def create(attr: Attributes): Stage[T, immutable.Seq[T]] = fusing.Grouped(n) override def create(attr: Attributes): Stage[T, immutable.Seq[T]] = fusing.Grouped(n)
} }
final case class LimitWeighted[T](max: Long, weightFn: T Long, attributes: Attributes = limitWeighted) extends SymbolicStage[T, T] {
override def create(attr: Attributes): Stage[T, T] = fusing.LimitWeighted(max, weightFn)
}
final case class Sliding[T](n: Int, step: Int, attributes: Attributes = sliding) extends SymbolicStage[T, immutable.Seq[T]] { final case class Sliding[T](n: Int, step: Int, attributes: Attributes = sliding) extends SymbolicStage[T, immutable.Seq[T]] {
require(n > 0, "n must be greater than 0") require(n > 0, "n must be greater than 0")
require(step > 0, "step must be greater than 0") require(step > 0, "step must be greater than 0")
@ -183,10 +175,6 @@ private[stream] object Stages {
override def create(attr: Attributes): Stage[T, T] = fusing.TakeWhile(p, supervision(attr)) override def create(attr: Attributes): Stage[T, T] = fusing.TakeWhile(p, supervision(attr))
} }
final case class DropWhile[T](p: T Boolean, attributes: Attributes = dropWhile) extends SymbolicStage[T, T] {
override def create(attr: Attributes): Stage[T, T] = fusing.DropWhile(p, supervision(attr))
}
final case class Scan[In, Out](zero: Out, f: (Out, In) Out, attributes: Attributes = scan) extends SymbolicStage[In, Out] { final case class Scan[In, Out](zero: Out, f: (Out, In) Out, attributes: Attributes = scan) extends SymbolicStage[In, Out] {
override def create(attr: Attributes): Stage[In, Out] = fusing.Scan(zero, f, supervision(attr)) override def create(attr: Attributes): Stage[In, Out] = fusing.Scan(zero, f, supervision(attr))
} }

View file

@ -59,38 +59,89 @@ private[akka] final case class TakeWhile[T](p: T ⇒ Boolean, decider: Supervisi
/** /**
* INTERNAL API * INTERNAL API
*/ */
private[akka] final case class DropWhile[T](p: T Boolean, decider: Supervision.Decider) extends PushStage[T, T] { private[stream] final case class DropWhile[T](p: T Boolean) extends GraphStage[FlowShape[T, T]] {
var taking = false val in = Inlet[T]("DropWhile.in")
val out = Outlet[T]("DropWhile.out")
override val shape = FlowShape(in, out)
override def initialAttributes: Attributes = DefaultAttributes.dropWhile
override def onPush(elem: T, ctx: Context[T]): SyncDirective = def createLogic(inheritedAttributes: Attributes) = new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler {
if (taking || !p(elem)) { override def onPush(): Unit = {
taking = true val elem = grab(in)
ctx.push(elem) withSupervision(() p(elem)) match {
} else { case Some(flag) if flag pull(in)
ctx.pull() case Some(flag) if !flag
push(out, elem)
setHandler(in, rest)
case None // do nothing
}
} }
override def decide(t: Throwable): Supervision.Directive = decider(t) def rest = new InHandler {
def onPush() = push(out, grab(in))
}
override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in)
override def onPull(): Unit = pull(in)
setHandlers(in, out, this)
}
override def toString = "DropWhile"
} }
private[akka] object Collect { /**
* INTERNAL API
*/
abstract private[stream] class SupervisedGraphStageLogic(inheritedAttributes: Attributes, shape: Shape) extends GraphStageLogic(shape) {
private lazy val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider)
def withSupervision[T](f: () T): Option[T] =
try { Some(f()) } catch {
case NonFatal(ex)
decider(ex) match {
case Supervision.Stop onStop(ex)
case Supervision.Resume onResume(ex)
case Supervision.Restart onRestart(ex)
}
None
}
def onResume(t: Throwable): Unit
def onStop(t: Throwable): Unit = failStage(t)
def onRestart(t: Throwable): Unit = onResume(t)
}
private[stream] object Collect {
// Cached function that can be used with PartialFunction.applyOrElse to ensure that A) the guard is only applied once, // Cached function that can be used with PartialFunction.applyOrElse to ensure that A) the guard is only applied once,
// and the caller can check the returned value with Collect.notApplied to query whether the PF was applied or not. // and the caller can check the returned value with Collect.notApplied to query whether the PF was applied or not.
// Prior art: https://github.com/scala/scala/blob/v2.11.4/src/library/scala/collection/immutable/List.scala#L458 // Prior art: https://github.com/scala/scala/blob/v2.11.4/src/library/scala/collection/immutable/List.scala#L458
final val NotApplied: Any Any = _ Collect.NotApplied final val NotApplied: Any Any = _ Collect.NotApplied
} }
private[akka] final case class Collect[In, Out](pf: PartialFunction[In, Out], decider: Supervision.Decider) extends PushStage[In, Out] { /**
* INTERNAL API
*/
private[stream] final case class Collect[In, Out](pf: PartialFunction[In, Out]) extends GraphStage[FlowShape[In, Out]] {
val in = Inlet[In]("Collect.in")
val out = Outlet[Out]("Collect.out")
override val shape = FlowShape(in, out)
override def initialAttributes: Attributes = DefaultAttributes.collect
import Collect.NotApplied def createLogic(inheritedAttributes: Attributes) = new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler {
import Collect.NotApplied
val wrappedPf = () pf.applyOrElse(grab(in), NotApplied)
override def onPush(elem: In, ctx: Context[Out]): SyncDirective = override def onPush(): Unit = withSupervision(wrappedPf) match {
pf.applyOrElse(elem, NotApplied) match { case Some(result) result match {
case NotApplied ctx.pull() case NotApplied pull(in)
case result: Out @unchecked ctx.push(result) case result: Out @unchecked push(out, result)
}
case None //do nothing
} }
override def decide(t: Throwable): Supervision.Directive = decider(t) override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in)
override def onPull(): Unit = pull(in)
setHandlers(in, out, this)
}
override def toString = "Collect"
} }
/** /**
@ -312,15 +363,33 @@ private[akka] final case class Grouped[T](n: Int) extends PushPullStage[T, immut
/** /**
* INTERNAL API * INTERNAL API
*/ */
private[stream] final case class LimitWeighted[T](n: Long, costFn: T Long) extends GraphStage[FlowShape[T, T]] {
val in = Inlet[T]("LimitWeighted.in")
val out = Outlet[T]("LimitWeighted.out")
override val shape = FlowShape(in, out)
override def initialAttributes: Attributes = DefaultAttributes.limitWeighted
private[akka] final case class LimitWeighted[T](n: Long, costFn: T Long) extends PushStage[T, T] { def createLogic(inheritedAttributes: Attributes) = new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler {
private var left = n private var left = n
override def onPush(elem: T, ctx: Context[T]): SyncDirective = { override def onPush(): Unit = {
left -= costFn(elem) val elem = grab(in)
if (left >= 0) ctx.push(elem) withSupervision(() costFn(elem)) match {
else ctx.fail(new StreamLimitReachedException(n)) case Some(wight)
left -= wight
if (left >= 0) push(out, elem) else failStage(new StreamLimitReachedException(n))
case None //do nothing
}
}
override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in)
override def onRestart(t: Throwable): Unit = {
left = n
if (!hasBeenPulled(in)) pull(in)
}
override def onPull(): Unit = pull(in)
setHandlers(in, out, this)
} }
override def toString = "LimitWeighted"
} }
/** /**

View file

@ -627,7 +627,7 @@ trait FlowOps[+Out, +Mat] {
* *
* '''Cancels when''' downstream cancels * '''Cancels when''' downstream cancels
*/ */
def dropWhile(p: Out Boolean): Repr[Out] = andThen(DropWhile(p)) def dropWhile(p: Out Boolean): Repr[Out] = via(DropWhile(p))
/** /**
* Transform this stream by applying the given partial function to each of the elements * Transform this stream by applying the given partial function to each of the elements
@ -642,7 +642,7 @@ trait FlowOps[+Out, +Mat] {
* *
* '''Cancels when''' downstream cancels * '''Cancels when''' downstream cancels
*/ */
def collect[T](pf: PartialFunction[Out, T]): Repr[T] = andThen(Collect(pf)) def collect[T](pf: PartialFunction[Out, T]): Repr[T] = via(Collect(pf))
/** /**
* Chunk up this stream into groups of the given size, with the last group * Chunk up this stream into groups of the given size, with the last group
@ -705,7 +705,7 @@ trait FlowOps[+Out, +Mat] {
* *
* See also [[FlowOps.take]], [[FlowOps.takeWithin]], [[FlowOps.takeWhile]] * See also [[FlowOps.take]], [[FlowOps.takeWithin]], [[FlowOps.takeWhile]]
*/ */
def limitWeighted[T](max: Long)(costFn: Out Long): Repr[Out] = andThen(LimitWeighted(max, costFn)) def limitWeighted[T](max: Long)(costFn: Out Long): Repr[Out] = via(LimitWeighted(max, costFn))
/** /**
* Apply a sliding window over the stream and return the windows as groups of elements, with the last group * Apply a sliding window over the stream and return the windows as groups of elements, with the last group

View file

@ -754,7 +754,18 @@ object MiMa extends AutoPlugin {
// #20028 Simplify TickSource cancellation // #20028 Simplify TickSource cancellation
ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.fusing.GraphStages$TickSource$TickSourceCancellable"), ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.fusing.GraphStages$TickSource$TickSourceCancellable"),
ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.fusing.GraphStages$TickSource$") ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.fusing.GraphStages$TickSource$"),
// #19834 replacing PushStages usages with GraphStages
ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.Stages$LimitWeighted"),
ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.Stages$Collect$"),
ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.Stages$DropWhile"),
ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.Stages$LimitWeighted$"),
ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.Stages$Collect"),
ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.Stages$DropWhile$"),
FilterAnyProblemStartingWith("akka.stream.impl.fusing.Collect"),
FilterAnyProblemStartingWith("akka.stream.impl.fusing.DropWhile"),
FilterAnyProblemStartingWith("akka.stream.impl.fusing.LimitWeighted")
) )
) )
} }