=str 19834 Migrating PushStages to GraphStage
Collect, DropWhile, LimitWeighted
This commit is contained in:
parent
b3444fa1b0
commit
951afec88e
7 changed files with 144 additions and 64 deletions
|
|
@ -59,38 +59,89 @@ private[akka] final case class TakeWhile[T](p: T ⇒ Boolean, decider: Supervisi
|
|||
/**
|
||||
* INTERNAL API
|
||||
*/
|
||||
private[akka] final case class DropWhile[T](p: T ⇒ Boolean, decider: Supervision.Decider) extends PushStage[T, T] {
|
||||
var taking = false
|
||||
private[stream] final case class DropWhile[T](p: T ⇒ Boolean) extends GraphStage[FlowShape[T, T]] {
|
||||
val in = Inlet[T]("DropWhile.in")
|
||||
val out = Outlet[T]("DropWhile.out")
|
||||
override val shape = FlowShape(in, out)
|
||||
override def initialAttributes: Attributes = DefaultAttributes.dropWhile
|
||||
|
||||
override def onPush(elem: T, ctx: Context[T]): SyncDirective =
|
||||
if (taking || !p(elem)) {
|
||||
taking = true
|
||||
ctx.push(elem)
|
||||
} else {
|
||||
ctx.pull()
|
||||
def createLogic(inheritedAttributes: Attributes) = new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler {
|
||||
override def onPush(): Unit = {
|
||||
val elem = grab(in)
|
||||
withSupervision(() ⇒ p(elem)) match {
|
||||
case Some(flag) if flag ⇒ pull(in)
|
||||
case Some(flag) if !flag ⇒
|
||||
push(out, elem)
|
||||
setHandler(in, rest)
|
||||
case None ⇒ // do nothing
|
||||
}
|
||||
}
|
||||
|
||||
override def decide(t: Throwable): Supervision.Directive = decider(t)
|
||||
def rest = new InHandler {
|
||||
def onPush() = push(out, grab(in))
|
||||
}
|
||||
|
||||
override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in)
|
||||
override def onPull(): Unit = pull(in)
|
||||
setHandlers(in, out, this)
|
||||
}
|
||||
override def toString = "DropWhile"
|
||||
}
|
||||
|
||||
private[akka] object Collect {
|
||||
/**
|
||||
* INTERNAL API
|
||||
*/
|
||||
abstract private[stream] class SupervisedGraphStageLogic(inheritedAttributes: Attributes, shape: Shape) extends GraphStageLogic(shape) {
|
||||
private lazy val decider = inheritedAttributes.get[SupervisionStrategy].map(_.decider).getOrElse(Supervision.stoppingDecider)
|
||||
def withSupervision[T](f: () ⇒ T): Option[T] =
|
||||
try { Some(f()) } catch {
|
||||
case NonFatal(ex) ⇒
|
||||
decider(ex) match {
|
||||
case Supervision.Stop ⇒ onStop(ex)
|
||||
case Supervision.Resume ⇒ onResume(ex)
|
||||
case Supervision.Restart ⇒ onRestart(ex)
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
def onResume(t: Throwable): Unit
|
||||
def onStop(t: Throwable): Unit = failStage(t)
|
||||
def onRestart(t: Throwable): Unit = onResume(t)
|
||||
}
|
||||
|
||||
private[stream] object Collect {
|
||||
// Cached function that can be used with PartialFunction.applyOrElse to ensure that A) the guard is only applied once,
|
||||
// and the caller can check the returned value with Collect.notApplied to query whether the PF was applied or not.
|
||||
// Prior art: https://github.com/scala/scala/blob/v2.11.4/src/library/scala/collection/immutable/List.scala#L458
|
||||
final val NotApplied: Any ⇒ Any = _ ⇒ Collect.NotApplied
|
||||
}
|
||||
|
||||
private[akka] final case class Collect[In, Out](pf: PartialFunction[In, Out], decider: Supervision.Decider) extends PushStage[In, Out] {
|
||||
/**
|
||||
* INTERNAL API
|
||||
*/
|
||||
private[stream] final case class Collect[In, Out](pf: PartialFunction[In, Out]) extends GraphStage[FlowShape[In, Out]] {
|
||||
val in = Inlet[In]("Collect.in")
|
||||
val out = Outlet[Out]("Collect.out")
|
||||
override val shape = FlowShape(in, out)
|
||||
override def initialAttributes: Attributes = DefaultAttributes.collect
|
||||
|
||||
import Collect.NotApplied
|
||||
def createLogic(inheritedAttributes: Attributes) = new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler {
|
||||
import Collect.NotApplied
|
||||
val wrappedPf = () ⇒ pf.applyOrElse(grab(in), NotApplied)
|
||||
|
||||
override def onPush(elem: In, ctx: Context[Out]): SyncDirective =
|
||||
pf.applyOrElse(elem, NotApplied) match {
|
||||
case NotApplied ⇒ ctx.pull()
|
||||
case result: Out @unchecked ⇒ ctx.push(result)
|
||||
override def onPush(): Unit = withSupervision(wrappedPf) match {
|
||||
case Some(result) ⇒ result match {
|
||||
case NotApplied ⇒ pull(in)
|
||||
case result: Out @unchecked ⇒ push(out, result)
|
||||
}
|
||||
case None ⇒ //do nothing
|
||||
}
|
||||
|
||||
override def decide(t: Throwable): Supervision.Directive = decider(t)
|
||||
override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in)
|
||||
override def onPull(): Unit = pull(in)
|
||||
setHandlers(in, out, this)
|
||||
}
|
||||
override def toString = "Collect"
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -312,15 +363,33 @@ private[akka] final case class Grouped[T](n: Int) extends PushPullStage[T, immut
|
|||
/**
|
||||
* INTERNAL API
|
||||
*/
|
||||
private[stream] final case class LimitWeighted[T](n: Long, costFn: T ⇒ Long) extends GraphStage[FlowShape[T, T]] {
|
||||
val in = Inlet[T]("LimitWeighted.in")
|
||||
val out = Outlet[T]("LimitWeighted.out")
|
||||
override val shape = FlowShape(in, out)
|
||||
override def initialAttributes: Attributes = DefaultAttributes.limitWeighted
|
||||
|
||||
private[akka] final case class LimitWeighted[T](n: Long, costFn: T ⇒ Long) extends PushStage[T, T] {
|
||||
private var left = n
|
||||
def createLogic(inheritedAttributes: Attributes) = new SupervisedGraphStageLogic(inheritedAttributes, shape) with InHandler with OutHandler {
|
||||
private var left = n
|
||||
|
||||
override def onPush(elem: T, ctx: Context[T]): SyncDirective = {
|
||||
left -= costFn(elem)
|
||||
if (left >= 0) ctx.push(elem)
|
||||
else ctx.fail(new StreamLimitReachedException(n))
|
||||
override def onPush(): Unit = {
|
||||
val elem = grab(in)
|
||||
withSupervision(() ⇒ costFn(elem)) match {
|
||||
case Some(wight) ⇒
|
||||
left -= wight
|
||||
if (left >= 0) push(out, elem) else failStage(new StreamLimitReachedException(n))
|
||||
case None ⇒ //do nothing
|
||||
}
|
||||
}
|
||||
override def onResume(t: Throwable): Unit = if (!hasBeenPulled(in)) pull(in)
|
||||
override def onRestart(t: Throwable): Unit = {
|
||||
left = n
|
||||
if (!hasBeenPulled(in)) pull(in)
|
||||
}
|
||||
override def onPull(): Unit = pull(in)
|
||||
setHandlers(in, out, this)
|
||||
}
|
||||
override def toString = "LimitWeighted"
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue