Merge pull request #18473 from drewhk/wip-graphstage-improvements-drewhk
=str: GraphStage improvements
This commit is contained in:
commit
7a85508a5d
6 changed files with 112 additions and 147 deletions
|
|
@ -3,17 +3,14 @@
|
|||
*/
|
||||
package akka.stream.impl.fusing
|
||||
|
||||
import java.util.concurrent.TimeoutException
|
||||
|
||||
import akka.actor.{ ActorSystem, Cancellable, Scheduler }
|
||||
import akka.stream._
|
||||
import akka.stream.scaladsl._
|
||||
import akka.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler }
|
||||
import akka.stream.testkit.{ TestPublisher, AkkaSpec, TestSubscriber }
|
||||
import akka.stream.testkit.AkkaSpec
|
||||
import akka.stream.testkit.Utils._
|
||||
|
||||
import scala.concurrent.Await
|
||||
import scala.concurrent.duration._
|
||||
import akka.stream.testkit.Utils._
|
||||
|
||||
class ActorGraphInterpreterSpec extends AkkaSpec {
|
||||
implicit val mat = ActorMaterializer()
|
||||
|
|
@ -29,6 +26,19 @@ class ActorGraphInterpreterSpec extends AkkaSpec {
|
|||
|
||||
}
|
||||
|
||||
"be able to reuse a simple identity graph stage" in assertAllStagesStopped {
|
||||
val identity = new GraphStages.Identity[Int]
|
||||
|
||||
Await.result(
|
||||
Source(1 to 100)
|
||||
.via(identity)
|
||||
.via(identity)
|
||||
.via(identity)
|
||||
.grouped(200)
|
||||
.runWith(Sink.head),
|
||||
3.seconds) should ===(1 to 100)
|
||||
}
|
||||
|
||||
"be able to interpret a simple bidi stage" in assertAllStagesStopped {
|
||||
val identityBidi = new GraphStage[BidiShape[Int, Int, Int, Int]] {
|
||||
val in1 = Inlet[Int]("in1")
|
||||
|
|
@ -70,6 +80,52 @@ class ActorGraphInterpreterSpec extends AkkaSpec {
|
|||
|
||||
}
|
||||
|
||||
"be able to interpret and resuse a simple bidi stage" in assertAllStagesStopped {
|
||||
val identityBidi = new GraphStage[BidiShape[Int, Int, Int, Int]] {
|
||||
val in1 = Inlet[Int]("in1")
|
||||
val in2 = Inlet[Int]("in2")
|
||||
val out1 = Outlet[Int]("out1")
|
||||
val out2 = Outlet[Int]("out2")
|
||||
val shape = BidiShape(in1, out1, in2, out2)
|
||||
|
||||
override def createLogic: GraphStageLogic = new GraphStageLogic {
|
||||
setHandler(in1, new InHandler {
|
||||
override def onPush(): Unit = push(out1, grab(in1))
|
||||
|
||||
override def onUpstreamFinish(): Unit = complete(out1)
|
||||
})
|
||||
|
||||
setHandler(in2, new InHandler {
|
||||
override def onPush(): Unit = push(out2, grab(in2))
|
||||
|
||||
override def onUpstreamFinish(): Unit = complete(out2)
|
||||
})
|
||||
|
||||
setHandler(out1, new OutHandler {
|
||||
override def onPull(): Unit = pull(in1)
|
||||
|
||||
override def onDownstreamFinish(): Unit = cancel(in1)
|
||||
})
|
||||
|
||||
setHandler(out2, new OutHandler {
|
||||
override def onPull(): Unit = pull(in2)
|
||||
|
||||
override def onDownstreamFinish(): Unit = cancel(in2)
|
||||
})
|
||||
}
|
||||
|
||||
override def toString = "IdentityBidi"
|
||||
}
|
||||
|
||||
val identityBidiF = BidiFlow.wrap(identityBidi)
|
||||
val identity = (identityBidiF atop identityBidiF atop identityBidiF).join(Flow[Int].map { x ⇒ x })
|
||||
|
||||
Await.result(
|
||||
Source(1 to 10).via(identity).grouped(100).runWith(Sink.head),
|
||||
3.seconds) should ===(1 to 10)
|
||||
|
||||
}
|
||||
|
||||
"be able to interpret a rotated identity bidi stage" in assertAllStagesStopped {
|
||||
// This is a "rotated" identity BidiStage, as it loops back upstream elements
|
||||
// to its upstream, and loops back downstream elementd to its downstream.
|
||||
|
|
@ -84,21 +140,25 @@ class ActorGraphInterpreterSpec extends AkkaSpec {
|
|||
override def createLogic: GraphStageLogic = new GraphStageLogic {
|
||||
setHandler(in1, new InHandler {
|
||||
override def onPush(): Unit = push(out2, grab(in1))
|
||||
|
||||
override def onUpstreamFinish(): Unit = complete(out2)
|
||||
})
|
||||
|
||||
setHandler(in2, new InHandler {
|
||||
override def onPush(): Unit = push(out1, grab(in2))
|
||||
|
||||
override def onUpstreamFinish(): Unit = complete(out1)
|
||||
})
|
||||
|
||||
setHandler(out1, new OutHandler {
|
||||
override def onPull(): Unit = pull(in2)
|
||||
|
||||
override def onDownstreamFinish(): Unit = cancel(in2)
|
||||
})
|
||||
|
||||
setHandler(out2, new OutHandler {
|
||||
override def onPull(): Unit = pull(in1)
|
||||
|
||||
override def onDownstreamFinish(): Unit = cancel(in1)
|
||||
})
|
||||
}
|
||||
|
|
@ -124,113 +184,5 @@ class ActorGraphInterpreterSpec extends AkkaSpec {
|
|||
Await.result(f2, 3.seconds) should ===(1 to 10)
|
||||
}
|
||||
|
||||
"be able to implement a timeout bidiStage" in {
|
||||
class IdleTimeout[I, O](
|
||||
val system: ActorSystem,
|
||||
val timeout: FiniteDuration) extends GraphStage[BidiShape[I, I, O, O]] {
|
||||
val in1 = Inlet[I]("in1")
|
||||
val in2 = Inlet[O]("in2")
|
||||
val out1 = Outlet[I]("out1")
|
||||
val out2 = Outlet[O]("out2")
|
||||
val shape = BidiShape(in1, out1, in2, out2)
|
||||
|
||||
override def toString = "IdleTimeout"
|
||||
|
||||
override def createLogic: GraphStageLogic = new GraphStageLogic {
|
||||
private var timerCancellable: Option[Cancellable] = None
|
||||
private var nextDeadline: Deadline = Deadline.now + timeout
|
||||
|
||||
setHandler(in1, new InHandler {
|
||||
override def onPush(): Unit = {
|
||||
onActivity()
|
||||
push(out1, grab(in1))
|
||||
}
|
||||
override def onUpstreamFinish(): Unit = complete(out1)
|
||||
})
|
||||
|
||||
setHandler(in2, new InHandler {
|
||||
override def onPush(): Unit = {
|
||||
onActivity()
|
||||
push(out2, grab(in2))
|
||||
}
|
||||
override def onUpstreamFinish(): Unit = complete(out2)
|
||||
})
|
||||
|
||||
setHandler(out1, new OutHandler {
|
||||
override def onPull(): Unit = pull(in1)
|
||||
override def onDownstreamFinish(): Unit = cancel(in1)
|
||||
})
|
||||
|
||||
setHandler(out2, new OutHandler {
|
||||
override def onPull(): Unit = pull(in2)
|
||||
override def onDownstreamFinish(): Unit = cancel(in2)
|
||||
})
|
||||
|
||||
private def onActivity(): Unit = nextDeadline = Deadline.now + timeout
|
||||
|
||||
private def onTimerTick(): Unit =
|
||||
if (nextDeadline.isOverdue())
|
||||
failStage(new TimeoutException(s"No reads or writes happened in $timeout."))
|
||||
|
||||
override def preStart(): Unit = {
|
||||
super.preStart()
|
||||
val checkPeriod = timeout / 8
|
||||
val callback = getAsyncCallback[Unit]((_) ⇒ onTimerTick())
|
||||
import system.dispatcher
|
||||
timerCancellable = Some(system.scheduler.schedule(timeout, checkPeriod)(callback.invoke(())))
|
||||
}
|
||||
|
||||
override def postStop(): Unit = {
|
||||
super.postStop()
|
||||
timerCancellable.foreach(_.cancel())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val upWrite = TestPublisher.probe[String]()
|
||||
val upRead = TestSubscriber.probe[Int]()
|
||||
|
||||
val downWrite = TestPublisher.probe[Int]()
|
||||
val downRead = TestSubscriber.probe[String]()
|
||||
|
||||
FlowGraph.closed() { implicit b ⇒
|
||||
import FlowGraph.Implicits._
|
||||
val timeoutStage = b.add(new IdleTimeout[String, Int](system, 2.seconds))
|
||||
Source(upWrite) ~> timeoutStage.in1; timeoutStage.out1 ~> Sink(downRead)
|
||||
Sink(upRead) <~ timeoutStage.out2; timeoutStage.in2 <~ Source(downWrite)
|
||||
}.run()
|
||||
|
||||
// Request enough for the whole test
|
||||
upRead.request(100)
|
||||
downRead.request(100)
|
||||
|
||||
upWrite.sendNext("DATA1")
|
||||
downRead.expectNext("DATA1")
|
||||
Thread.sleep(1500)
|
||||
|
||||
downWrite.sendNext(1)
|
||||
upRead.expectNext(1)
|
||||
Thread.sleep(1500)
|
||||
|
||||
upWrite.sendNext("DATA2")
|
||||
downRead.expectNext("DATA2")
|
||||
Thread.sleep(1000)
|
||||
|
||||
downWrite.sendNext(2)
|
||||
upRead.expectNext(2)
|
||||
|
||||
upRead.expectNoMsg(500.millis)
|
||||
val error1 = upRead.expectError()
|
||||
val error2 = downRead.expectError()
|
||||
|
||||
error1.isInstanceOf[TimeoutException] should be(true)
|
||||
error1.getMessage should be("No reads or writes happened in 2 seconds.")
|
||||
error2 should ===(error1)
|
||||
|
||||
upWrite.expectCancellation()
|
||||
downWrite.expectCancellation()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import akka.actor._
|
|||
import akka.event.Logging
|
||||
import akka.stream._
|
||||
import akka.stream.impl.ReactiveStreamsCompliance._
|
||||
import akka.stream.impl.StreamLayout.Module
|
||||
import akka.stream.impl.StreamLayout.{ CopiedModule, Module }
|
||||
import akka.stream.impl.fusing.GraphInterpreter.{ DownstreamBoundaryStageLogic, UpstreamBoundaryStageLogic, GraphAssembly }
|
||||
import akka.stream.impl.{ ActorPublisher, ReactiveStreamsCompliance }
|
||||
import akka.stream.stage.{ GraphStageLogic, InHandler, OutHandler }
|
||||
|
|
@ -22,9 +22,13 @@ private[stream] case class GraphModule(assembly: GraphAssembly, shape: Shape, at
|
|||
override def subModules: Set[Module] = Set.empty
|
||||
override def withAttributes(newAttr: Attributes): Module = copy(attributes = newAttr)
|
||||
|
||||
override def carbonCopy: Module = copy()
|
||||
override final def carbonCopy: Module = {
|
||||
val newShape = shape.deepCopy()
|
||||
replaceShape(newShape)
|
||||
}
|
||||
|
||||
override def replaceShape(s: Shape): Module = ???
|
||||
override final def replaceShape(newShape: Shape): Module =
|
||||
CopiedModule(newShape, attributes, copyOf = this)
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -50,7 +54,7 @@ private[stream] object ActorGraphInterpreter {
|
|||
final class BoundarySubscription(val parent: ActorRef, val id: Int) extends Subscription {
|
||||
override def request(elements: Long): Unit = parent ! RequestMore(id, elements)
|
||||
override def cancel(): Unit = parent ! Cancel(id)
|
||||
override def toString = "BoundarySubscription" + System.identityHashCode(this)
|
||||
override def toString = s"BoundarySubscription[$parent, $id]"
|
||||
}
|
||||
|
||||
final class BoundarySubscriber(val parent: ActorRef, id: Int) extends Subscriber[Any] {
|
||||
|
|
@ -306,6 +310,25 @@ private[stream] class ActorGraphInterpreter(assembly: GraphAssembly, shape: Shap
|
|||
}
|
||||
|
||||
override def receive: Receive = {
|
||||
// Cases that are most likely on the hot path, in decreasing order of frequency
|
||||
case OnNext(id: Int, e: Any) ⇒
|
||||
if (GraphInterpreter.Debug) println(s" onNext $e id=$id")
|
||||
inputs(id).onNext(e)
|
||||
runBatch()
|
||||
case RequestMore(id: Int, demand: Long) ⇒
|
||||
if (GraphInterpreter.Debug) println(s" request $demand id=$id")
|
||||
outputs(id).requestMore(demand)
|
||||
runBatch()
|
||||
case Resume ⇒
|
||||
resumeScheduled = false
|
||||
if (interpreter.isSuspended) runBatch()
|
||||
case AsyncInput(logic, event, handler) ⇒
|
||||
if (GraphInterpreter.Debug) println(s"ASYNC $event")
|
||||
if (!interpreter.isStageCompleted(logic.stageId))
|
||||
handler(event)
|
||||
runBatch()
|
||||
|
||||
// Initialization and completion messages
|
||||
case OnError(id: Int, cause: Throwable) ⇒
|
||||
if (GraphInterpreter.Debug) println(s" onError id=$id")
|
||||
inputs(id).onError(cause)
|
||||
|
|
@ -314,17 +337,8 @@ private[stream] class ActorGraphInterpreter(assembly: GraphAssembly, shape: Shap
|
|||
if (GraphInterpreter.Debug) println(s" onComplete id=$id")
|
||||
inputs(id).onComplete()
|
||||
runBatch()
|
||||
case OnNext(id: Int, e: Any) ⇒
|
||||
if (GraphInterpreter.Debug) println(s" onNext $e id=$id")
|
||||
inputs(id).onNext(e)
|
||||
runBatch()
|
||||
case OnSubscribe(id: Int, subscription: Subscription) ⇒
|
||||
inputs(id).onSubscribe(subscription)
|
||||
|
||||
case RequestMore(id: Int, demand: Long) ⇒
|
||||
if (GraphInterpreter.Debug) println(s" request $demand id=$id")
|
||||
outputs(id).requestMore(demand)
|
||||
runBatch()
|
||||
case Cancel(id: Int) ⇒
|
||||
if (GraphInterpreter.Debug) println(s" cancel id=$id")
|
||||
outputs(id).cancel()
|
||||
|
|
@ -333,16 +347,6 @@ private[stream] class ActorGraphInterpreter(assembly: GraphAssembly, shape: Shap
|
|||
outputs(id).subscribePending()
|
||||
case ExposedPublisher(id, publisher) ⇒
|
||||
outputs(id).exposedPublisher(publisher)
|
||||
|
||||
case AsyncInput(_, event, handler) ⇒
|
||||
if (GraphInterpreter.Debug) println(s"ASYNC $event")
|
||||
handler(event)
|
||||
runBatch()
|
||||
|
||||
case Resume ⇒
|
||||
resumeScheduled = false
|
||||
if (interpreter.isSuspended) runBatch()
|
||||
|
||||
}
|
||||
|
||||
override protected[akka] def aroundReceive(receive: Actor.Receive, msg: Any): Unit = {
|
||||
|
|
|
|||
|
|
@ -255,6 +255,7 @@ private[stream] final class GraphInterpreter(
|
|||
def init(): Unit = {
|
||||
var i = 0
|
||||
while (i < logics.length) {
|
||||
logics(i).stageId = i
|
||||
logics(i).preStart()
|
||||
i += 1
|
||||
}
|
||||
|
|
@ -373,7 +374,7 @@ private[stream] final class GraphInterpreter(
|
|||
def isConnectionCompleted(connection: Int): Boolean = connectionStates(connection).isInstanceOf[CompletedState]
|
||||
|
||||
// Returns true if the given stage is alredy completed
|
||||
private def isStageCompleted(stageId: Int): Boolean = stageId != Boundary && shutdownCounter(stageId) == 0
|
||||
def isStageCompleted(stageId: Int): Boolean = stageId != Boundary && shutdownCounter(stageId) == 0
|
||||
|
||||
private def isPushInFlight(connection: Int): Boolean =
|
||||
!inAvailable(connection) &&
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ object GraphStages {
|
|||
val in = Inlet[T]("in")
|
||||
val out = Outlet[T]("out")
|
||||
|
||||
val shape = FlowShape(in, out)
|
||||
override val shape = FlowShape(in, out)
|
||||
|
||||
override def createLogic: GraphStageLogic = new GraphStageLogic {
|
||||
setHandler(in, new InHandler {
|
||||
|
|
@ -36,7 +36,7 @@ object GraphStages {
|
|||
class Detacher[T] extends GraphStage[FlowShape[T, T]] {
|
||||
val in = Inlet[T]("in")
|
||||
val out = Outlet[T]("out")
|
||||
val shape = FlowShape(in, out)
|
||||
override val shape = FlowShape(in, out)
|
||||
|
||||
override def createLogic: GraphStageLogic = new GraphStageLogic {
|
||||
var initialized = false
|
||||
|
|
@ -70,7 +70,7 @@ object GraphStages {
|
|||
class Broadcast[T](private val outCount: Int) extends GraphStage[UniformFanOutShape[T, T]] {
|
||||
val in = Inlet[T]("in")
|
||||
val out = Vector.fill(outCount)(Outlet[T]("out"))
|
||||
val shape = UniformFanOutShape(in, out: _*)
|
||||
override val shape = UniformFanOutShape(in, out: _*)
|
||||
|
||||
override def createLogic: GraphStageLogic = new GraphStageLogic {
|
||||
private var pending = outCount
|
||||
|
|
@ -101,7 +101,7 @@ object GraphStages {
|
|||
val in0 = Inlet[A]("in0")
|
||||
val in1 = Inlet[B]("in1")
|
||||
val out = Outlet[(A, B)]("out")
|
||||
val shape = new FanInShape2[A, B, (A, B)](in0, in1, out)
|
||||
override val shape = new FanInShape2[A, B, (A, B)](in0, in1, out)
|
||||
|
||||
override def createLogic: GraphStageLogic = new GraphStageLogic {
|
||||
var pending = 2
|
||||
|
|
@ -130,7 +130,7 @@ object GraphStages {
|
|||
class Merge[T](private val inCount: Int) extends GraphStage[UniformFanInShape[T, T]] {
|
||||
val in = Vector.fill(inCount)(Inlet[T]("in"))
|
||||
val out = Outlet[T]("out")
|
||||
val shape = UniformFanInShape(out, in: _*)
|
||||
override val shape = UniformFanInShape(out, in: _*)
|
||||
|
||||
override def createLogic: GraphStageLogic = new GraphStageLogic {
|
||||
private var initialized = false
|
||||
|
|
@ -187,7 +187,7 @@ object GraphStages {
|
|||
class Balance[T](private val outCount: Int) extends GraphStage[UniformFanOutShape[T, T]] {
|
||||
val in = Inlet[T]("in")
|
||||
val out = Vector.fill(outCount)(Outlet[T]("out"))
|
||||
val shape = UniformFanOutShape[T, T](in, out: _*)
|
||||
override val shape = UniformFanOutShape[T, T](in, out: _*)
|
||||
|
||||
override def createLogic: GraphStageLogic = new GraphStageLogic {
|
||||
private val pendingQueue = Array.ofDim[Outlet[T]](outCount)
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ final class Source[+Out, +Mat](private[stream] override val module: Module)
|
|||
new Source(
|
||||
module
|
||||
.fuse(flowCopy, shape.outlet, flowCopy.shape.inlets.head, combine)
|
||||
.replaceShape(SourceShape(flowCopy.shape.outlets.head))) // FIXME why is not .wrap() needed here?
|
||||
.replaceShape(SourceShape(flowCopy.shape.outlets.head)))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -43,9 +43,12 @@ abstract class GraphStage[S <: Shape] extends Graph[S, Unit] {
|
|||
* This method throws an [[UnsupportedOperationException]] by default. The subclass can override this method
|
||||
* and provide a correct implementation that creates an exact copy of the stage with the provided new attributes.
|
||||
*/
|
||||
override def withAttributes(attr: Attributes): Graph[S, Unit] =
|
||||
throw new UnsupportedOperationException(
|
||||
"withAttributes not supported by default by FlexiMerge, subclass may override and implement it")
|
||||
final override def withAttributes(attr: Attributes): Graph[S, Unit] = new Graph[S, Unit] {
|
||||
override def shape = GraphStage.this.shape
|
||||
override private[stream] def module = GraphStage.this.module.withAttributes(attr)
|
||||
|
||||
override def withAttributes(attr: Attributes) = GraphStage.this.withAttributes(attr)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -64,6 +67,11 @@ abstract class GraphStage[S <: Shape] extends Graph[S, Unit] {
|
|||
abstract class GraphStageLogic {
|
||||
import GraphInterpreter._
|
||||
|
||||
/**
|
||||
* INTERNAL API
|
||||
*/
|
||||
private[stream] var stageId: Int = Int.MinValue
|
||||
|
||||
/**
|
||||
* INTERNAL API
|
||||
*/
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue