Merge pull request #18473 from drewhk/wip-graphstage-improvements-drewhk

=str: GraphStage improvements
This commit is contained in:
drewhk 2015-09-17 12:57:57 +02:00
commit 7a85508a5d
6 changed files with 112 additions and 147 deletions

View file

@ -3,17 +3,14 @@
*/
package akka.stream.impl.fusing
import java.util.concurrent.TimeoutException
import akka.actor.{ ActorSystem, Cancellable, Scheduler }
import akka.stream._
import akka.stream.scaladsl._
import akka.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler }
import akka.stream.testkit.{ TestPublisher, AkkaSpec, TestSubscriber }
import akka.stream.testkit.AkkaSpec
import akka.stream.testkit.Utils._
import scala.concurrent.Await
import scala.concurrent.duration._
import akka.stream.testkit.Utils._
class ActorGraphInterpreterSpec extends AkkaSpec {
implicit val mat = ActorMaterializer()
@ -29,6 +26,19 @@ class ActorGraphInterpreterSpec extends AkkaSpec {
}
"be able to reuse a simple identity graph stage" in assertAllStagesStopped {
val identity = new GraphStages.Identity[Int]
Await.result(
Source(1 to 100)
.via(identity)
.via(identity)
.via(identity)
.grouped(200)
.runWith(Sink.head),
3.seconds) should ===(1 to 100)
}
"be able to interpret a simple bidi stage" in assertAllStagesStopped {
val identityBidi = new GraphStage[BidiShape[Int, Int, Int, Int]] {
val in1 = Inlet[Int]("in1")
@ -70,6 +80,52 @@ class ActorGraphInterpreterSpec extends AkkaSpec {
}
"be able to interpret and resuse a simple bidi stage" in assertAllStagesStopped {
val identityBidi = new GraphStage[BidiShape[Int, Int, Int, Int]] {
val in1 = Inlet[Int]("in1")
val in2 = Inlet[Int]("in2")
val out1 = Outlet[Int]("out1")
val out2 = Outlet[Int]("out2")
val shape = BidiShape(in1, out1, in2, out2)
override def createLogic: GraphStageLogic = new GraphStageLogic {
setHandler(in1, new InHandler {
override def onPush(): Unit = push(out1, grab(in1))
override def onUpstreamFinish(): Unit = complete(out1)
})
setHandler(in2, new InHandler {
override def onPush(): Unit = push(out2, grab(in2))
override def onUpstreamFinish(): Unit = complete(out2)
})
setHandler(out1, new OutHandler {
override def onPull(): Unit = pull(in1)
override def onDownstreamFinish(): Unit = cancel(in1)
})
setHandler(out2, new OutHandler {
override def onPull(): Unit = pull(in2)
override def onDownstreamFinish(): Unit = cancel(in2)
})
}
override def toString = "IdentityBidi"
}
val identityBidiF = BidiFlow.wrap(identityBidi)
val identity = (identityBidiF atop identityBidiF atop identityBidiF).join(Flow[Int].map { x x })
Await.result(
Source(1 to 10).via(identity).grouped(100).runWith(Sink.head),
3.seconds) should ===(1 to 10)
}
"be able to interpret a rotated identity bidi stage" in assertAllStagesStopped {
// This is a "rotated" identity BidiStage, as it loops back upstream elements
// to its upstream, and loops back downstream elementd to its downstream.
@ -84,21 +140,25 @@ class ActorGraphInterpreterSpec extends AkkaSpec {
override def createLogic: GraphStageLogic = new GraphStageLogic {
setHandler(in1, new InHandler {
override def onPush(): Unit = push(out2, grab(in1))
override def onUpstreamFinish(): Unit = complete(out2)
})
setHandler(in2, new InHandler {
override def onPush(): Unit = push(out1, grab(in2))
override def onUpstreamFinish(): Unit = complete(out1)
})
setHandler(out1, new OutHandler {
override def onPull(): Unit = pull(in2)
override def onDownstreamFinish(): Unit = cancel(in2)
})
setHandler(out2, new OutHandler {
override def onPull(): Unit = pull(in1)
override def onDownstreamFinish(): Unit = cancel(in1)
})
}
@ -124,113 +184,5 @@ class ActorGraphInterpreterSpec extends AkkaSpec {
Await.result(f2, 3.seconds) should ===(1 to 10)
}
"be able to implement a timeout bidiStage" in {
class IdleTimeout[I, O](
val system: ActorSystem,
val timeout: FiniteDuration) extends GraphStage[BidiShape[I, I, O, O]] {
val in1 = Inlet[I]("in1")
val in2 = Inlet[O]("in2")
val out1 = Outlet[I]("out1")
val out2 = Outlet[O]("out2")
val shape = BidiShape(in1, out1, in2, out2)
override def toString = "IdleTimeout"
override def createLogic: GraphStageLogic = new GraphStageLogic {
private var timerCancellable: Option[Cancellable] = None
private var nextDeadline: Deadline = Deadline.now + timeout
setHandler(in1, new InHandler {
override def onPush(): Unit = {
onActivity()
push(out1, grab(in1))
}
override def onUpstreamFinish(): Unit = complete(out1)
})
setHandler(in2, new InHandler {
override def onPush(): Unit = {
onActivity()
push(out2, grab(in2))
}
override def onUpstreamFinish(): Unit = complete(out2)
})
setHandler(out1, new OutHandler {
override def onPull(): Unit = pull(in1)
override def onDownstreamFinish(): Unit = cancel(in1)
})
setHandler(out2, new OutHandler {
override def onPull(): Unit = pull(in2)
override def onDownstreamFinish(): Unit = cancel(in2)
})
private def onActivity(): Unit = nextDeadline = Deadline.now + timeout
private def onTimerTick(): Unit =
if (nextDeadline.isOverdue())
failStage(new TimeoutException(s"No reads or writes happened in $timeout."))
override def preStart(): Unit = {
super.preStart()
val checkPeriod = timeout / 8
val callback = getAsyncCallback[Unit]((_) onTimerTick())
import system.dispatcher
timerCancellable = Some(system.scheduler.schedule(timeout, checkPeriod)(callback.invoke(())))
}
override def postStop(): Unit = {
super.postStop()
timerCancellable.foreach(_.cancel())
}
}
}
val upWrite = TestPublisher.probe[String]()
val upRead = TestSubscriber.probe[Int]()
val downWrite = TestPublisher.probe[Int]()
val downRead = TestSubscriber.probe[String]()
FlowGraph.closed() { implicit b
import FlowGraph.Implicits._
val timeoutStage = b.add(new IdleTimeout[String, Int](system, 2.seconds))
Source(upWrite) ~> timeoutStage.in1; timeoutStage.out1 ~> Sink(downRead)
Sink(upRead) <~ timeoutStage.out2; timeoutStage.in2 <~ Source(downWrite)
}.run()
// Request enough for the whole test
upRead.request(100)
downRead.request(100)
upWrite.sendNext("DATA1")
downRead.expectNext("DATA1")
Thread.sleep(1500)
downWrite.sendNext(1)
upRead.expectNext(1)
Thread.sleep(1500)
upWrite.sendNext("DATA2")
downRead.expectNext("DATA2")
Thread.sleep(1000)
downWrite.sendNext(2)
upRead.expectNext(2)
upRead.expectNoMsg(500.millis)
val error1 = upRead.expectError()
val error2 = downRead.expectError()
error1.isInstanceOf[TimeoutException] should be(true)
error1.getMessage should be("No reads or writes happened in 2 seconds.")
error2 should ===(error1)
upWrite.expectCancellation()
downWrite.expectCancellation()
}
}
}

View file

@ -7,7 +7,7 @@ import akka.actor._
import akka.event.Logging
import akka.stream._
import akka.stream.impl.ReactiveStreamsCompliance._
import akka.stream.impl.StreamLayout.Module
import akka.stream.impl.StreamLayout.{ CopiedModule, Module }
import akka.stream.impl.fusing.GraphInterpreter.{ DownstreamBoundaryStageLogic, UpstreamBoundaryStageLogic, GraphAssembly }
import akka.stream.impl.{ ActorPublisher, ReactiveStreamsCompliance }
import akka.stream.stage.{ GraphStageLogic, InHandler, OutHandler }
@ -22,9 +22,13 @@ private[stream] case class GraphModule(assembly: GraphAssembly, shape: Shape, at
override def subModules: Set[Module] = Set.empty
override def withAttributes(newAttr: Attributes): Module = copy(attributes = newAttr)
override def carbonCopy: Module = copy()
override final def carbonCopy: Module = {
val newShape = shape.deepCopy()
replaceShape(newShape)
}
override def replaceShape(s: Shape): Module = ???
override final def replaceShape(newShape: Shape): Module =
CopiedModule(newShape, attributes, copyOf = this)
}
/**
@ -50,7 +54,7 @@ private[stream] object ActorGraphInterpreter {
final class BoundarySubscription(val parent: ActorRef, val id: Int) extends Subscription {
override def request(elements: Long): Unit = parent ! RequestMore(id, elements)
override def cancel(): Unit = parent ! Cancel(id)
override def toString = "BoundarySubscription" + System.identityHashCode(this)
override def toString = s"BoundarySubscription[$parent, $id]"
}
final class BoundarySubscriber(val parent: ActorRef, id: Int) extends Subscriber[Any] {
@ -306,6 +310,25 @@ private[stream] class ActorGraphInterpreter(assembly: GraphAssembly, shape: Shap
}
override def receive: Receive = {
// Cases that are most likely on the hot path, in decreasing order of frequency
case OnNext(id: Int, e: Any)
if (GraphInterpreter.Debug) println(s" onNext $e id=$id")
inputs(id).onNext(e)
runBatch()
case RequestMore(id: Int, demand: Long)
if (GraphInterpreter.Debug) println(s" request $demand id=$id")
outputs(id).requestMore(demand)
runBatch()
case Resume
resumeScheduled = false
if (interpreter.isSuspended) runBatch()
case AsyncInput(logic, event, handler)
if (GraphInterpreter.Debug) println(s"ASYNC $event")
if (!interpreter.isStageCompleted(logic.stageId))
handler(event)
runBatch()
// Initialization and completion messages
case OnError(id: Int, cause: Throwable)
if (GraphInterpreter.Debug) println(s" onError id=$id")
inputs(id).onError(cause)
@ -314,17 +337,8 @@ private[stream] class ActorGraphInterpreter(assembly: GraphAssembly, shape: Shap
if (GraphInterpreter.Debug) println(s" onComplete id=$id")
inputs(id).onComplete()
runBatch()
case OnNext(id: Int, e: Any)
if (GraphInterpreter.Debug) println(s" onNext $e id=$id")
inputs(id).onNext(e)
runBatch()
case OnSubscribe(id: Int, subscription: Subscription)
inputs(id).onSubscribe(subscription)
case RequestMore(id: Int, demand: Long)
if (GraphInterpreter.Debug) println(s" request $demand id=$id")
outputs(id).requestMore(demand)
runBatch()
case Cancel(id: Int)
if (GraphInterpreter.Debug) println(s" cancel id=$id")
outputs(id).cancel()
@ -333,16 +347,6 @@ private[stream] class ActorGraphInterpreter(assembly: GraphAssembly, shape: Shap
outputs(id).subscribePending()
case ExposedPublisher(id, publisher)
outputs(id).exposedPublisher(publisher)
case AsyncInput(_, event, handler)
if (GraphInterpreter.Debug) println(s"ASYNC $event")
handler(event)
runBatch()
case Resume
resumeScheduled = false
if (interpreter.isSuspended) runBatch()
}
override protected[akka] def aroundReceive(receive: Actor.Receive, msg: Any): Unit = {

View file

@ -255,6 +255,7 @@ private[stream] final class GraphInterpreter(
def init(): Unit = {
var i = 0
while (i < logics.length) {
logics(i).stageId = i
logics(i).preStart()
i += 1
}
@ -373,7 +374,7 @@ private[stream] final class GraphInterpreter(
def isConnectionCompleted(connection: Int): Boolean = connectionStates(connection).isInstanceOf[CompletedState]
// Returns true if the given stage is alredy completed
private def isStageCompleted(stageId: Int): Boolean = stageId != Boundary && shutdownCounter(stageId) == 0
def isStageCompleted(stageId: Int): Boolean = stageId != Boundary && shutdownCounter(stageId) == 0
private def isPushInFlight(connection: Int): Boolean =
!inAvailable(connection) &&

View file

@ -15,7 +15,7 @@ object GraphStages {
val in = Inlet[T]("in")
val out = Outlet[T]("out")
val shape = FlowShape(in, out)
override val shape = FlowShape(in, out)
override def createLogic: GraphStageLogic = new GraphStageLogic {
setHandler(in, new InHandler {
@ -36,7 +36,7 @@ object GraphStages {
class Detacher[T] extends GraphStage[FlowShape[T, T]] {
val in = Inlet[T]("in")
val out = Outlet[T]("out")
val shape = FlowShape(in, out)
override val shape = FlowShape(in, out)
override def createLogic: GraphStageLogic = new GraphStageLogic {
var initialized = false
@ -70,7 +70,7 @@ object GraphStages {
class Broadcast[T](private val outCount: Int) extends GraphStage[UniformFanOutShape[T, T]] {
val in = Inlet[T]("in")
val out = Vector.fill(outCount)(Outlet[T]("out"))
val shape = UniformFanOutShape(in, out: _*)
override val shape = UniformFanOutShape(in, out: _*)
override def createLogic: GraphStageLogic = new GraphStageLogic {
private var pending = outCount
@ -101,7 +101,7 @@ object GraphStages {
val in0 = Inlet[A]("in0")
val in1 = Inlet[B]("in1")
val out = Outlet[(A, B)]("out")
val shape = new FanInShape2[A, B, (A, B)](in0, in1, out)
override val shape = new FanInShape2[A, B, (A, B)](in0, in1, out)
override def createLogic: GraphStageLogic = new GraphStageLogic {
var pending = 2
@ -130,7 +130,7 @@ object GraphStages {
class Merge[T](private val inCount: Int) extends GraphStage[UniformFanInShape[T, T]] {
val in = Vector.fill(inCount)(Inlet[T]("in"))
val out = Outlet[T]("out")
val shape = UniformFanInShape(out, in: _*)
override val shape = UniformFanInShape(out, in: _*)
override def createLogic: GraphStageLogic = new GraphStageLogic {
private var initialized = false
@ -187,7 +187,7 @@ object GraphStages {
class Balance[T](private val outCount: Int) extends GraphStage[UniformFanOutShape[T, T]] {
val in = Inlet[T]("in")
val out = Vector.fill(outCount)(Outlet[T]("out"))
val shape = UniformFanOutShape[T, T](in, out: _*)
override val shape = UniformFanOutShape[T, T](in, out: _*)
override def createLogic: GraphStageLogic = new GraphStageLogic {
private val pendingQueue = Array.ofDim[Outlet[T]](outCount)

View file

@ -44,7 +44,7 @@ final class Source[+Out, +Mat](private[stream] override val module: Module)
new Source(
module
.fuse(flowCopy, shape.outlet, flowCopy.shape.inlets.head, combine)
.replaceShape(SourceShape(flowCopy.shape.outlets.head))) // FIXME why is not .wrap() needed here?
.replaceShape(SourceShape(flowCopy.shape.outlets.head)))
}
}

View file

@ -43,9 +43,12 @@ abstract class GraphStage[S <: Shape] extends Graph[S, Unit] {
* This method throws an [[UnsupportedOperationException]] by default. The subclass can override this method
* and provide a correct implementation that creates an exact copy of the stage with the provided new attributes.
*/
override def withAttributes(attr: Attributes): Graph[S, Unit] =
throw new UnsupportedOperationException(
"withAttributes not supported by default by FlexiMerge, subclass may override and implement it")
final override def withAttributes(attr: Attributes): Graph[S, Unit] = new Graph[S, Unit] {
override def shape = GraphStage.this.shape
override private[stream] def module = GraphStage.this.module.withAttributes(attr)
override def withAttributes(attr: Attributes) = GraphStage.this.withAttributes(attr)
}
}
/**
@ -64,6 +67,11 @@ abstract class GraphStage[S <: Shape] extends Graph[S, Unit] {
abstract class GraphStageLogic {
import GraphInterpreter._
/**
* INTERNAL API
*/
private[stream] var stageId: Int = Int.MinValue
/**
* INTERNAL API
*/