=str 19834 Migrating PushStages to GraphStage

Collect, DropWhile, LimitWeighted
This commit is contained in:
Alexander Golubev 2016-03-09 20:46:42 -05:00
parent b3444fa1b0
commit 951afec88e
7 changed files with 144 additions and 64 deletions

View file

@ -59,38 +59,89 @@ private[akka] final case class TakeWhile[T](p: T ⇒ Boolean, decider: Supervisi
/**
* INTERNAL API
*/
private[akka] final case class DropWhile[T](p: T Boolean, decider: Supervision.Decider) extends PushStage[T, T] {
var taking = false
private[stream] final case class DropWhile[T](p: T Boolean) extends GraphStage[FlowShape[T, T]] {
val in = Inlet[T]("DropWhile.in")
val out = Outlet[T]("DropWhile.out")
override val shape = FlowShape(in, out)
override def initialAttributes: Attributes = DefaultAttributes.dropWhile
override def onPush(elem: T, ctx: Context[T]): SyncDirective =
if (taking || !p(elem)) {
taking = true
ctx.push(elem)
} else {
ctx.pull()
def createLogic(inheritedAttributes: Attributes) = new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler {
override def onPush(): Unit = {
val elem = grab(in)
withSupervision(() p(elem)) match {
case Some(flag) if flag pull(in)
case Some(flag) if !flag
push(out, elem)
setHandler(in, rest)
case None // do nothing
}
}
override def decide(t: Throwable): Supervision.Directive = decider(t)
def rest = new InHandler {
def onPush() = push(out, grab(in))
}
override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in)
override def onPull(): Unit = pull(in)
setHandlers(in, out, this)
}
override def toString = "DropWhile"
}
private[akka] object Collect {
/**
* INTERNAL API
*/
abstract private[stream] class SupervisedGraphStageLogic(inheritedAttributes: Attributes, shape: Shape) extends GraphStageLogic(shape) {
private lazy val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider)
def withSupervision[T](f: () T): Option[T] =
try { Some(f()) } catch {
case NonFatal(ex)
decider(ex) match {
case Supervision.Stop onStop(ex)
case Supervision.Resume onResume(ex)
case Supervision.Restart onRestart(ex)
}
None
}
def onResume(t: Throwable): Unit
def onStop(t: Throwable): Unit = failStage(t)
def onRestart(t: Throwable): Unit = onResume(t)
}
private[stream] object Collect {
// Cached function that can be used with PartialFunction.applyOrElse to ensure that A) the guard is only applied once,
// and the caller can check the returned value with Collect.notApplied to query whether the PF was applied or not.
// Prior art: https://github.com/scala/scala/blob/v2.11.4/src/library/scala/collection/immutable/List.scala#L458
final val NotApplied: Any Any = _ Collect.NotApplied
}
private[akka] final case class Collect[In, Out](pf: PartialFunction[In, Out], decider: Supervision.Decider) extends PushStage[In, Out] {
/**
* INTERNAL API
*/
private[stream] final case class Collect[In, Out](pf: PartialFunction[In, Out]) extends GraphStage[FlowShape[In, Out]] {
val in = Inlet[In]("Collect.in")
val out = Outlet[Out]("Collect.out")
override val shape = FlowShape(in, out)
override def initialAttributes: Attributes = DefaultAttributes.collect
import Collect.NotApplied
def createLogic(inheritedAttributes: Attributes) = new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler {
import Collect.NotApplied
val wrappedPf = () pf.applyOrElse(grab(in), NotApplied)
override def onPush(elem: In, ctx: Context[Out]): SyncDirective =
pf.applyOrElse(elem, NotApplied) match {
case NotApplied ctx.pull()
case result: Out @unchecked ctx.push(result)
override def onPush(): Unit = withSupervision(wrappedPf) match {
case Some(result) result match {
case NotApplied pull(in)
case result: Out @unchecked push(out, result)
}
case None //do nothing
}
override def decide(t: Throwable): Supervision.Directive = decider(t)
override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in)
override def onPull(): Unit = pull(in)
setHandlers(in, out, this)
}
override def toString = "Collect"
}
/**
@ -312,15 +363,33 @@ private[akka] final case class Grouped[T](n: Int) extends PushPullStage[T, immut
/**
* INTERNAL API
*/
private[stream] final case class LimitWeighted[T](n: Long, costFn: T Long) extends GraphStage[FlowShape[T, T]] {
val in = Inlet[T]("LimitWeighted.in")
val out = Outlet[T]("LimitWeighted.out")
override val shape = FlowShape(in, out)
override def initialAttributes: Attributes = DefaultAttributes.limitWeighted
private[akka] final case class LimitWeighted[T](n: Long, costFn: T Long) extends PushStage[T, T] {
private var left = n
def createLogic(inheritedAttributes: Attributes) = new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler {
private var left = n
override def onPush(elem: T, ctx: Context[T]): SyncDirective = {
left -= costFn(elem)
if (left >= 0) ctx.push(elem)
else ctx.fail(new StreamLimitReachedException(n))
override def onPush(): Unit = {
val elem = grab(in)
withSupervision(() costFn(elem)) match {
case Some(wight)
left -= wight
if (left >= 0) push(out, elem) else failStage(new StreamLimitReachedException(n))
case None //do nothing
}
}
override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in)
override def onRestart(t: Throwable): Unit = {
left = n
if (!hasBeenPulled(in)) pull(in)
}
override def onPull(): Unit = pull(in)
setHandlers(in, out, this)
}
override def toString = "LimitWeighted"
}
/**