diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/ActorGraphInterpreterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/ActorGraphInterpreterSpec.scala new file mode 100644 index 0000000000..2ef8eaa305 --- /dev/null +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/ActorGraphInterpreterSpec.scala @@ -0,0 +1,236 @@ +/** + * Copyright (C) 2015 Typesafe Inc. + */ +package akka.stream.impl.fusing + +import java.util.concurrent.TimeoutException + +import akka.actor.{ ActorSystem, Cancellable, Scheduler } +import akka.stream._ +import akka.stream.scaladsl._ +import akka.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler } +import akka.stream.testkit.{ TestPublisher, AkkaSpec, TestSubscriber } + +import scala.concurrent.Await +import scala.concurrent.duration._ +import akka.stream.testkit.Utils._ + +class ActorGraphInterpreterSpec extends AkkaSpec { + implicit val mat = ActorMaterializer() + + "ActorGraphInterpreter" must { + + "be able to interpret a simple identity graph stage" in assertAllStagesStopped { + val identity = new GraphStages.Identity[Int] + + Await.result( + Source(1 to 100).via(identity).grouped(200).runWith(Sink.head), + 3.seconds) should ===(1 to 100) + + } + + "be able to interpret a simple bidi stage" in assertAllStagesStopped { + val identityBidi = new GraphStage[BidiShape[Int, Int, Int, Int]] { + val in1 = Inlet[Int]("in1") + val in2 = Inlet[Int]("in2") + val out1 = Outlet[Int]("out1") + val out2 = Outlet[Int]("out2") + val shape = BidiShape(in1, out1, in2, out2) + + override def createLogic: GraphStageLogic = new GraphStageLogic { + setHandler(in1, new InHandler { + override def onPush(): Unit = push(out1, grab(in1)) + override def onUpstreamFinish(): Unit = complete(out1) + }) + + setHandler(in2, new InHandler { + override def onPush(): Unit = push(out2, grab(in2)) + override def onUpstreamFinish(): Unit = complete(out2) + }) + + setHandler(out1, new OutHandler { + override def onPull(): Unit = pull(in1) + override def onDownstreamFinish(): Unit = cancel(in1) + }) + + setHandler(out2, new OutHandler { + override def onPull(): Unit = pull(in2) + override def onDownstreamFinish(): Unit = cancel(in2) + }) + } + + override def toString = "IdentityBidi" + } + + val identity = BidiFlow.wrap(identityBidi).join(Flow[Int].map { x ⇒ x }) + + Await.result( + Source(1 to 10).via(identity).grouped(100).runWith(Sink.head), + 3.seconds) should ===(1 to 10) + + } + + "be able to interpret a rotated identity bidi stage" in assertAllStagesStopped { + // This is a "rotated" identity BidiStage, as it loops back upstream elements + // to its upstream, and loops back downstream elementd to its downstream. 
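+      // Concretely (per the handler wiring below): elements arriving on in1 are emitted on out2,
+      // and elements arriving on in2 are emitted on out1.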
+ + val rotatedBidi = new GraphStage[BidiShape[Int, Int, Int, Int]] { + val in1 = Inlet[Int]("in1") + val in2 = Inlet[Int]("in2") + val out1 = Outlet[Int]("out1") + val out2 = Outlet[Int]("out2") + val shape = BidiShape(in1, out1, in2, out2) + + override def createLogic: GraphStageLogic = new GraphStageLogic { + setHandler(in1, new InHandler { + override def onPush(): Unit = push(out2, grab(in1)) + override def onUpstreamFinish(): Unit = complete(out2) + }) + + setHandler(in2, new InHandler { + override def onPush(): Unit = push(out1, grab(in2)) + override def onUpstreamFinish(): Unit = complete(out1) + }) + + setHandler(out1, new OutHandler { + override def onPull(): Unit = pull(in2) + override def onDownstreamFinish(): Unit = cancel(in2) + }) + + setHandler(out2, new OutHandler { + override def onPull(): Unit = pull(in1) + override def onDownstreamFinish(): Unit = cancel(in1) + }) + } + + override def toString = "IdentityBidi" + } + + val takeAll = Flow[Int].grouped(200).toMat(Sink.head)(Keep.right) + + val (f1, f2) = FlowGraph.closed(takeAll, takeAll)(Keep.both) { implicit b ⇒ + (out1, out2) ⇒ + import FlowGraph.Implicits._ + val bidi = b.add(rotatedBidi) + + Source(1 to 10) ~> bidi.in1 + out2 <~ bidi.out2 + + bidi.in2 <~ Source(1 to 100) + bidi.out1 ~> out1 + }.run() + + Await.result(f1, 3.seconds) should ===(1 to 100) + Await.result(f2, 3.seconds) should ===(1 to 10) + } + + "be able to implement a timeout bidiStage" in { + class IdleTimeout[I, O]( + val system: ActorSystem, + val timeout: FiniteDuration) extends GraphStage[BidiShape[I, I, O, O]] { + val in1 = Inlet[I]("in1") + val in2 = Inlet[O]("in2") + val out1 = Outlet[I]("out1") + val out2 = Outlet[O]("out2") + val shape = BidiShape(in1, out1, in2, out2) + + override def toString = "IdleTimeout" + + override def createLogic: GraphStageLogic = new GraphStageLogic { + private var timerCancellable: Option[Cancellable] = None + private var nextDeadline: Deadline = Deadline.now + timeout + + setHandler(in1, new InHandler { + override def onPush(): Unit = { + onActivity() + push(out1, grab(in1)) + } + override def onUpstreamFinish(): Unit = complete(out1) + }) + + setHandler(in2, new InHandler { + override def onPush(): Unit = { + onActivity() + push(out2, grab(in2)) + } + override def onUpstreamFinish(): Unit = complete(out2) + }) + + setHandler(out1, new OutHandler { + override def onPull(): Unit = pull(in1) + override def onDownstreamFinish(): Unit = cancel(in1) + }) + + setHandler(out2, new OutHandler { + override def onPull(): Unit = pull(in2) + override def onDownstreamFinish(): Unit = cancel(in2) + }) + + private def onActivity(): Unit = nextDeadline = Deadline.now + timeout + + private def onTimerTick(): Unit = + if (nextDeadline.isOverdue()) + failStage(new TimeoutException(s"No reads or writes happened in $timeout.")) + + override def preStart(): Unit = { + super.preStart() + val checkPeriod = timeout / 8 + val callback = getAsyncCallback[Unit]((_) ⇒ onTimerTick()) + import system.dispatcher + timerCancellable = Some(system.scheduler.schedule(timeout, checkPeriod)(callback.invoke(()))) + } + + override def postStop(): Unit = { + super.postStop() + timerCancellable.foreach(_.cancel()) + } + } + } + + val upWrite = TestPublisher.probe[String]() + val upRead = TestSubscriber.probe[Int]() + + val downWrite = TestPublisher.probe[Int]() + val downRead = TestSubscriber.probe[String]() + + FlowGraph.closed() { implicit b ⇒ + import FlowGraph.Implicits._ + val timeoutStage = b.add(new IdleTimeout[String, Int](system, 2.seconds)) + 
Source(upWrite) ~> timeoutStage.in1; timeoutStage.out1 ~> Sink(downRead) + Sink(upRead) <~ timeoutStage.out2; timeoutStage.in2 <~ Source(downWrite) + }.run() + + // Request enough for the whole test + upRead.request(100) + downRead.request(100) + + upWrite.sendNext("DATA1") + downRead.expectNext("DATA1") + Thread.sleep(1500) + + downWrite.sendNext(1) + upRead.expectNext(1) + Thread.sleep(1500) + + upWrite.sendNext("DATA2") + downRead.expectNext("DATA2") + Thread.sleep(1000) + + downWrite.sendNext(2) + upRead.expectNext(2) + + upRead.expectNoMsg(500.millis) + val error1 = upRead.expectError() + val error2 = downRead.expectError() + + error1.isInstanceOf[TimeoutException] should be(true) + error1.getMessage should be("No reads or writes happened in 2 seconds.") + error2 should ===(error1) + + upWrite.expectCancellation() + downWrite.expectCancellation() + } + + } + +} diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpec.scala new file mode 100644 index 0000000000..698e7a9c6f --- /dev/null +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpec.scala @@ -0,0 +1,457 @@ +/** + * Copyright (C) 2015 Typesafe Inc. + */ +package akka.stream.impl.fusing + +import akka.stream._ +import akka.stream.impl.fusing.GraphInterpreterSpec.TestSetup +import akka.stream.stage.{ InHandler, OutHandler, GraphStage, GraphStageLogic } +import akka.stream.testkit.AkkaSpec +import GraphInterpreter._ + +import scala.collection.immutable + +class GraphInterpreterSpec extends AkkaSpec { + import GraphInterpreterSpec._ + import GraphStages._ + + "GraphInterpreter" must { + + // Reusable components + val identity = new Identity[Int] + val detacher = new Detacher[Int] + val zip = new Zip[Int, String] + val bcast = new Broadcast[Int](2) + val merge = new Merge[Int](2) + val balance = new Balance[Int](2) + + "implement identity" in new TestSetup { + val source = UpstreamProbe[Int]("source") + val sink = DownstreamProbe[Int]("sink") + + builder(identity) + .connect(source, identity.in) + .connect(identity.out, sink) + .init() + + lastEvents() should ===(Set.empty) + + sink.requestOne() + lastEvents() should ===(Set(RequestOne(source))) + + source.onNext(1) + lastEvents() should ===(Set(OnNext(sink, 1))) + } + + "implement chained identity" in new TestSetup { + val source = new UpstreamProbe[Int]("source") + val sink = new DownstreamProbe[Int]("sink") + + // Constructing an assembly by hand and resolving ambiguities + val assembly = GraphAssembly( + stages = Array(identity, identity), + ins = Array(identity.in, identity.in, null), + inOwners = Array(0, 1, -1), + outs = Array(null, identity.out, identity.out), + outOwners = Array(-1, 0, 1)) + + manualInit(assembly) + interpreter.attachDownstreamBoundary(2, sink) + interpreter.attachUpstreamBoundary(0, source) + interpreter.init() + + lastEvents() should ===(Set.empty) + + sink.requestOne() + lastEvents() should ===(Set(RequestOne(source))) + + source.onNext(1) + lastEvents() should ===(Set(OnNext(sink, 1))) + } + + "implement detacher stage" in new TestSetup { + val source = UpstreamProbe[Int]("source") + val sink = DownstreamProbe[Int]("sink") + + builder(detacher) + .connect(source, detacher.in) + .connect(detacher.out, sink) + .init() + + lastEvents() should ===(Set.empty) + + sink.requestOne() + lastEvents() should ===(Set(RequestOne(source))) + + source.onNext(1) + lastEvents() should ===(Set(OnNext(sink, 1), 
RequestOne(source))) + + // Source waits + source.onNext(2) + lastEvents() should ===(Set.empty) + + // "pushAndPull" + sink.requestOne() + lastEvents() should ===(Set(OnNext(sink, 2), RequestOne(source))) + + // Sink waits + sink.requestOne() + lastEvents() should ===(Set.empty) + + // "pushAndPull" + source.onNext(3) + lastEvents() should ===(Set(OnNext(sink, 3), RequestOne(source))) + } + + "implement Zip" in new TestSetup { + val source1 = new UpstreamProbe[Int]("source1") + val source2 = new UpstreamProbe[String]("source2") + val sink = new DownstreamProbe[(Int, String)]("sink") + + builder(zip) + .connect(source1, zip.in0) + .connect(source2, zip.in1) + .connect(zip.out, sink) + .init() + + lastEvents() should ===(Set.empty) + + sink.requestOne() + lastEvents() should ===(Set(RequestOne(source1), RequestOne(source2))) + + source1.onNext(42) + lastEvents() should ===(Set.empty) + + source2.onNext("Meaning of life") + lastEvents() should ===(Set(OnNext(sink, (42, "Meaning of life")))) + + sink.requestOne() + lastEvents() should ===(Set(RequestOne(source1), RequestOne(source2))) + } + + "implement Broadcast" in new TestSetup { + val source = new UpstreamProbe[Int]("source") + val sink1 = new DownstreamProbe[Int]("sink1") + val sink2 = new DownstreamProbe[Int]("sink2") + + builder(bcast) + .connect(source, bcast.in) + .connect(bcast.out(0), sink1) + .connect(bcast.out(1), sink2) + .init() + + lastEvents() should ===(Set.empty) + + sink1.requestOne() + lastEvents() should ===(Set.empty) + + sink2.requestOne() + lastEvents() should ===(Set(RequestOne(source))) + + source.onNext(1) + lastEvents() should ===(Set(OnNext(sink1, 1), OnNext(sink2, 1))) + + } + + "implement broadcast-zip" in new TestSetup { + val source = new UpstreamProbe[Int]("source") + val sink = new DownstreamProbe[(Int, Int)]("sink") + val zip = new Zip[Int, Int] + + builder(zip, bcast) + .connect(source, bcast.in) + .connect(bcast.out(0), zip.in0) + .connect(bcast.out(1), zip.in1) + .connect(zip.out, sink) + .init() + + lastEvents() should ===(Set.empty) + + sink.requestOne() + lastEvents() should ===(Set(RequestOne(source))) + + source.onNext(1) + lastEvents() should ===(Set(OnNext(sink, (1, 1)))) + + sink.requestOne() + lastEvents() should ===(Set(RequestOne(source))) + + source.onNext(2) + lastEvents() should ===(Set(OnNext(sink, (2, 2)))) + + } + + "implement zip-broadcast" in new TestSetup { + val source1 = new UpstreamProbe[Int]("source1") + val source2 = new UpstreamProbe[Int]("source2") + val sink1 = new DownstreamProbe[(Int, Int)]("sink") + val sink2 = new DownstreamProbe[(Int, Int)]("sink2") + val zip = new Zip[Int, Int] + val bcast = new Broadcast[(Int, Int)](2) + + builder(bcast, zip) + .connect(source1, zip.in0) + .connect(source2, zip.in1) + .connect(zip.out, bcast.in) + .connect(bcast.out(0), sink1) + .connect(bcast.out(1), sink2) + .init() + + lastEvents() should ===(Set.empty) + + sink1.requestOne() + lastEvents() should ===(Set.empty) + + sink2.requestOne() + lastEvents() should ===(Set(RequestOne(source1), RequestOne(source2))) + + source1.onNext(1) + lastEvents() should ===(Set.empty) + + source2.onNext(2) + lastEvents() should ===(Set(OnNext(sink1, (1, 2)), OnNext(sink2, (1, 2)))) + + } + + "implement merge" in new TestSetup { + val source1 = new UpstreamProbe[Int]("source1") + val source2 = new UpstreamProbe[Int]("source2") + val sink = new DownstreamProbe[Int]("sink") + + builder(merge) + .connect(source1, merge.in(0)) + .connect(source2, merge.in(1)) + .connect(merge.out, sink) + .init() + + 
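+      // No events are expected until the sink signals demand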
lastEvents() should ===(Set.empty) + + sink.requestOne() + lastEvents() should ===(Set(RequestOne(source1), RequestOne(source2))) + + source1.onNext(1) + lastEvents() should ===(Set(OnNext(sink, 1), RequestOne(source1))) + + source2.onNext(2) + lastEvents() should ===(Set.empty) + + sink.requestOne() + lastEvents() should ===(Set(OnNext(sink, 2), RequestOne(source2))) + + sink.requestOne() + lastEvents() should ===(Set.empty) + + source2.onNext(3) + lastEvents() should ===(Set(OnNext(sink, 3), RequestOne(source2))) + + sink.requestOne() + lastEvents() should ===(Set.empty) + + source1.onNext(4) + lastEvents() should ===(Set(OnNext(sink, 4), RequestOne(source1))) + + } + + "implement balance" in new TestSetup { + val source = new UpstreamProbe[Int]("source") + val sink1 = new DownstreamProbe[Int]("sink1") + val sink2 = new DownstreamProbe[Int]("sink2") + + builder(balance) + .connect(source, balance.in) + .connect(balance.out(0), sink1) + .connect(balance.out(1), sink2) + .init() + + lastEvents() should ===(Set.empty) + + sink1.requestOne() + lastEvents() should ===(Set(RequestOne(source))) + + sink2.requestOne() + lastEvents() should ===(Set.empty) + + source.onNext(1) + lastEvents() should ===(Set(OnNext(sink1, 1), RequestOne(source))) + + source.onNext(2) + lastEvents() should ===(Set(OnNext(sink2, 2))) + } + + "implement bidi-stage" in pending + + "implement non-divergent cycle" in new TestSetup { + val source = new UpstreamProbe[Int]("source") + val sink = new DownstreamProbe[Int]("sink") + + builder(merge, balance) + .connect(source, merge.in(0)) + .connect(merge.out, balance.in) + .connect(balance.out(0), sink) + .connect(balance.out(1), merge.in(1)) + .init() + + lastEvents() should ===(Set.empty) + + sink.requestOne() + lastEvents() should ===(Set(RequestOne(source))) + + source.onNext(1) + lastEvents() should ===(Set(OnNext(sink, 1), RequestOne(source))) + + // Token enters merge-balance cycle and gets stuck + source.onNext(2) + lastEvents() should ===(Set(RequestOne(source))) + + // Unstuck it + sink.requestOne() + lastEvents() should ===(Set(OnNext(sink, 2))) + + } + + "implement divergent cycle" in new TestSetup { + val source = new UpstreamProbe[Int]("source") + val sink = new DownstreamProbe[Int]("sink") + + builder(detacher, balance, merge) + .connect(source, merge.in(0)) + .connect(merge.out, balance.in) + .connect(balance.out(0), sink) + .connect(balance.out(1), detacher.in) + .connect(detacher.out, merge.in(1)) + .init() + + lastEvents() should ===(Set.empty) + + sink.requestOne() + lastEvents() should ===(Set(RequestOne(source))) + + source.onNext(1) + lastEvents() should ===(Set(OnNext(sink, 1), RequestOne(source))) + + // Token enters merge-balance cycle and spins until event limit + // Without the limit this would spin forever (where forever = Int.MaxValue iterations) + source.onNext(2, eventLimit = 1000) + lastEvents() should ===(Set(RequestOne(source))) + + // The cycle is still alive and kicking, just suspended due to the event limit + interpreter.isSuspended should be(true) + + // Do to the fairness properties of both the interpreter event queue and the balance stage + // the element will eventually leave the cycle and reaches the sink. 
+ // This should not hang even though we do not have an event limit set + sink.requestOne() + lastEvents() should ===(Set(OnNext(sink, 2))) + + // The cycle is now empty + interpreter.isSuspended should be(false) + } + } + +} + +object GraphInterpreterSpec { + + sealed trait TestEvent { + def source: GraphStageLogic + } + + case class OnComplete(source: GraphStageLogic) extends TestEvent + case class Cancel(source: GraphStageLogic) extends TestEvent + case class OnError(source: GraphStageLogic, cause: Throwable) extends TestEvent + case class OnNext(source: GraphStageLogic, elem: Any) extends TestEvent + case class RequestOne(source: GraphStageLogic) extends TestEvent + case class RequestAnother(source: GraphStageLogic) extends TestEvent + + abstract class TestSetup { + private var lastEvent: Set[TestEvent] = Set.empty + private var _interpreter: GraphInterpreter = _ + protected def interpreter: GraphInterpreter = _interpreter + + class AssemblyBuilder(stages: Seq[GraphStage[_ <: Shape]]) { + var upstreams = Vector.empty[(UpstreamBoundaryStageLogic[_], Inlet[_])] + var downstreams = Vector.empty[(Outlet[_], DownstreamBoundaryStageLogic[_])] + var connections = Vector.empty[(Outlet[_], Inlet[_])] + + def connect[T](upstream: UpstreamBoundaryStageLogic[T], in: Inlet[T]): AssemblyBuilder = { + upstreams :+= upstream -> in + this + } + + def connect[T](out: Outlet[T], downstream: DownstreamBoundaryStageLogic[T]): AssemblyBuilder = { + downstreams :+= out -> downstream + this + } + + def connect[T](out: Outlet[T], in: Inlet[T]): AssemblyBuilder = { + connections :+= out -> in + this + } + + def init(): Unit = { + val ins = upstreams.map(_._2) ++ connections.map(_._2) + val outs = connections.map(_._1) ++ downstreams.map(_._1) + val inOwners = ins.map { in ⇒ stages.indexWhere(_.shape.inlets.contains(in)) } + val outOwners = outs.map { out ⇒ stages.indexWhere(_.shape.outlets.contains(out)) } + + val assembly = GraphAssembly( + stages.toArray, + (ins ++ Vector.fill(downstreams.size)(null)).toArray, + (inOwners ++ Vector.fill(downstreams.size)(-1)).toArray, + (Vector.fill(upstreams.size)(null) ++ outs).toArray, + (Vector.fill(upstreams.size)(-1) ++ outOwners).toArray) + + _interpreter = new GraphInterpreter(assembly, (_, _, _) ⇒ ()) + + for ((upstream, i) ← upstreams.zipWithIndex) { + _interpreter.attachUpstreamBoundary(i, upstream._1) + } + + for ((downstream, i) ← downstreams.zipWithIndex) { + _interpreter.attachDownstreamBoundary(i + upstreams.size + connections.size, downstream._2) + } + + _interpreter.init() + } + } + + def manualInit(assembly: GraphAssembly): Unit = _interpreter = new GraphInterpreter(assembly, (_, _, _) ⇒ ()) + + def builder(stages: GraphStage[_ <: Shape]*): AssemblyBuilder = new AssemblyBuilder(stages.toSeq) + + def lastEvents(): Set[TestEvent] = { + val result = lastEvent + lastEvent = Set.empty + result + } + + case class UpstreamProbe[T](override val toString: String) extends UpstreamBoundaryStageLogic[T] { + val out = Outlet[T]("out") + + setHandler(out, new OutHandler { + override def onPull(): Unit = lastEvent += RequestOne(UpstreamProbe.this) + }) + + def onNext(elem: T, eventLimit: Int = Int.MaxValue): Unit = { + if (GraphInterpreter.Debug) println(s"----- NEXT: $this $elem") + push(out, elem) + interpreter.execute(eventLimit) + } + } + + case class DownstreamProbe[T](override val toString: String) extends DownstreamBoundaryStageLogic[T] { + val in = Inlet[T]("in") + + setHandler(in, new InHandler { + override def onPush(): Unit = lastEvent += 
OnNext(DownstreamProbe.this, grab(in)) + }) + + def requestOne(eventLimit: Int = Int.MaxValue): Unit = { + if (GraphInterpreter.Debug) println(s"----- REQ $this") + pull(in) + interpreter.execute(eventLimit) + } + } + + } +} diff --git a/akka-stream/src/main/scala/akka/stream/impl/ActorMaterializerImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/ActorMaterializerImpl.scala index 44989599f2..ff6fa75336 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/ActorMaterializerImpl.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/ActorMaterializerImpl.scala @@ -13,7 +13,7 @@ import akka.stream.impl.GenJunctions.ZipWithModule import akka.stream.impl.GenJunctions.UnzipWithModule import akka.stream.impl.Junctions._ import akka.stream.impl.StreamLayout.Module -import akka.stream.impl.fusing.ActorInterpreter +import akka.stream.impl.fusing.{ ActorGraphInterpreter, GraphModule, ActorInterpreter } import akka.stream.impl.io.SslTlsCipherActor import akka.stream._ import akka.stream.io.SslTls.TlsModule @@ -112,6 +112,20 @@ private[akka] case class ActorMaterializerImpl(val system: ActorSystem, assignPort(tls.plainIn, FanIn.SubInput[Any](impl, SslTlsCipherActor.UserIn)) assignPort(tls.cipherIn, FanIn.SubInput[Any](impl, SslTlsCipherActor.TransportIn)) + case graph: GraphModule ⇒ + val calculatedSettings = effectiveSettings(effectiveAttributes) + val props = ActorGraphInterpreter.props(graph.assembly, graph.shape, calculatedSettings) + val impl = actorOf(props, stageName(effectiveAttributes), calculatedSettings.dispatcher) + for ((inlet, i) ← graph.shape.inlets.iterator.zipWithIndex) { + val subscriber = new ActorGraphInterpreter.BoundarySubscriber(impl, i) + assignPort(inlet, subscriber) + } + for ((outlet, i) ← graph.shape.outlets.iterator.zipWithIndex) { + val publisher = new ActorPublisher[Any](impl) { override val wakeUpMsg = ActorGraphInterpreter.SubscribePending(i) } + impl ! ActorGraphInterpreter.ExposedPublisher(i, publisher) + assignPort(outlet, publisher) + } + case junction: JunctionModule ⇒ materializeJunction(junction, effectiveAttributes, effectiveSettings(effectiveAttributes)) } diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala new file mode 100644 index 0000000000..5c995789c2 --- /dev/null +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala @@ -0,0 +1,390 @@ +/** + * Copyright (C) 2015 Typesafe Inc. + */ +package akka.stream.impl.fusing + +import akka.actor._ +import akka.event.Logging +import akka.stream._ +import akka.stream.impl.ReactiveStreamsCompliance._ +import akka.stream.impl.StreamLayout.Module +import akka.stream.impl.fusing.GraphInterpreter.{ DownstreamBoundaryStageLogic, UpstreamBoundaryStageLogic, GraphAssembly } +import akka.stream.impl.{ ActorPublisher, ReactiveStreamsCompliance } +import akka.stream.stage.{ GraphStageLogic, InHandler, OutHandler } +import org.reactivestreams.{ Subscriber, Subscription } + +import scala.util.control.NonFatal + +/** + * INTERNAL API + */ +private[stream] case class GraphModule(assembly: GraphAssembly, shape: Shape, attributes: Attributes) extends Module { + override def subModules: Set[Module] = Set.empty + override def withAttributes(newAttr: Attributes): Module = copy(attributes = newAttr) + + override def carbonCopy: Module = copy() + + override def replaceShape(s: Shape): Module = ??? 
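+  // Not implemented yet: calling replaceShape would throw scala.NotImplementedError (via ???)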
+} + +/** + * INTERNAL API + */ +private[stream] object ActorGraphInterpreter { + trait BoundaryEvent extends DeadLetterSuppression with NoSerializationVerificationNeeded + + final case class OnError(id: Int, cause: Throwable) extends BoundaryEvent + final case class OnComplete(id: Int) extends BoundaryEvent + final case class OnNext(id: Int, e: Any) extends BoundaryEvent + final case class OnSubscribe(id: Int, subscription: Subscription) extends BoundaryEvent + + final case class RequestMore(id: Int, demand: Long) extends BoundaryEvent + final case class Cancel(id: Int) extends BoundaryEvent + final case class SubscribePending(id: Int) extends BoundaryEvent + final case class ExposedPublisher(id: Int, publisher: ActorPublisher[Any]) extends BoundaryEvent + + final case class AsyncInput(logic: GraphStageLogic, evt: Any, handler: (Any) ⇒ Unit) extends BoundaryEvent + + case object Resume extends BoundaryEvent + + final class BoundarySubscription(val parent: ActorRef, val id: Int) extends Subscription { + override def request(elements: Long): Unit = parent ! RequestMore(id, elements) + override def cancel(): Unit = parent ! Cancel(id) + override def toString = "BoundarySubscription" + System.identityHashCode(this) + } + + final class BoundarySubscriber(val parent: ActorRef, id: Int) extends Subscriber[Any] { + override def onError(cause: Throwable): Unit = { + ReactiveStreamsCompliance.requireNonNullException(cause) + parent ! OnError(id, cause) + } + override def onComplete(): Unit = parent ! OnComplete(id) + override def onNext(element: Any): Unit = { + ReactiveStreamsCompliance.requireNonNullElement(element) + parent ! OnNext(id, element) + } + override def onSubscribe(subscription: Subscription): Unit = { + ReactiveStreamsCompliance.requireNonNullSubscription(subscription) + parent ! 
OnSubscribe(id, subscription) + } + } + + def props(assembly: GraphAssembly, shape: Shape, settings: ActorMaterializerSettings): Props = + Props(new ActorGraphInterpreter(assembly, shape, settings)).withDeploy(Deploy.local) + + class BatchingActorInputBoundary(size: Int, id: Int) extends UpstreamBoundaryStageLogic[Any] { + require(size > 0, "buffer size cannot be zero") + require((size & (size - 1)) == 0, "buffer size must be a power of two") + + private var upstream: Subscription = _ + private val inputBuffer = Array.ofDim[AnyRef](size) + private var inputBufferElements = 0 + private var nextInputElementCursor = 0 + private var upstreamCompleted = false + private var downstreamCanceled = false + private val IndexMask = size - 1 + + private def requestBatchSize = math.max(1, inputBuffer.length / 2) + private var batchRemaining = requestBatchSize + + val out: Outlet[Any] = Outlet[Any]("UpstreamBoundary" + id) + + private def dequeue(): Any = { + val elem = inputBuffer(nextInputElementCursor) + require(elem ne null, "Internal queue must never contain a null") + inputBuffer(nextInputElementCursor) = null + + batchRemaining -= 1 + if (batchRemaining == 0 && !upstreamCompleted) { + tryRequest(upstream, requestBatchSize) + batchRemaining = requestBatchSize + } + + inputBufferElements -= 1 + nextInputElementCursor = (nextInputElementCursor + 1) & IndexMask + elem + } + private def clear(): Unit = { + java.util.Arrays.fill(inputBuffer, 0, inputBuffer.length, null) + inputBufferElements = 0 + } + + def cancel(): Unit = { + if (!upstreamCompleted) { + upstreamCompleted = true + if (upstream ne null) tryCancel(upstream) + clear() + } + } + + def onNext(elem: Any): Unit = { + if (!upstreamCompleted) { + if (inputBufferElements == size) throw new IllegalStateException("Input buffer overrun") + inputBuffer((nextInputElementCursor + inputBufferElements) & IndexMask) = elem.asInstanceOf[AnyRef] + inputBufferElements += 1 + if (isAvailable(out)) push(out, dequeue()) + } + } + + def onError(e: Throwable): Unit = + if (!upstreamCompleted) { + upstreamCompleted = true + clear() + fail(out, e) + } + + // Call this when an error happens that does not come from the usual onError channel + // (exceptions while calling RS interfaces, abrupt termination etc) + def onInternalError(e: Throwable): Unit = { + if (!(upstreamCompleted || downstreamCanceled) && (upstream ne null)) { + upstream.cancel() + } + onError(e) + } + + def onComplete(): Unit = + if (!upstreamCompleted) { + upstreamCompleted = true + if (inputBufferElements == 0) complete(out) + } + + def onSubscribe(subscription: Subscription): Unit = { + require(subscription != null, "Subscription cannot be null") + if (upstreamCompleted) + tryCancel(subscription) + else if (downstreamCanceled) { + upstreamCompleted = true + tryCancel(subscription) + } else { + upstream = subscription + // Prefetch + tryRequest(upstream, inputBuffer.length) + } + } + + setHandler(out, new OutHandler { + override def onPull(): Unit = { + if (inputBufferElements > 1) push(out, dequeue()) + else if (inputBufferElements == 1) { + if (upstreamCompleted) { + push(out, dequeue()) + complete(out) + } else push(out, dequeue()) + } else if (upstreamCompleted) { + complete(out) + } + } + + override def onDownstreamFinish(): Unit = cancel() + }) + + } + + class ActorOutputBoundary(actor: ActorRef, id: Int) extends DownstreamBoundaryStageLogic[Any] { + val in: Inlet[Any] = Inlet[Any]("UpstreamBoundary" + id) + + private var exposedPublisher: ActorPublisher[Any] = _ + + private var subscriber: 
Subscriber[Any] = _ + private var downstreamDemand: Long = 0L + // This flag is only used if complete/fail is called externally since this op turns into a Finished one inside the + // interpreter (i.e. inside this op this flag has no effects since if it is completed the op will not be invoked) + private var downstreamCompleted = false + // this is true while we “hold the ball”; while “false” incoming demand will just be queued up + private var upstreamWaiting = true + // when upstream failed before we got the exposed publisher + private var upstreamFailed: Option[Throwable] = None + + private def onNext(elem: Any): Unit = { + downstreamDemand -= 1 + tryOnNext(subscriber, elem) + } + + private def complete(): Unit = { + if (!downstreamCompleted) { + downstreamCompleted = true + if (exposedPublisher ne null) exposedPublisher.shutdown(None) + if (subscriber ne null) tryOnComplete(subscriber) + } + } + + def fail(e: Throwable): Unit = { + if (!downstreamCompleted) { + downstreamCompleted = true + if (exposedPublisher ne null) exposedPublisher.shutdown(Some(e)) + if ((subscriber ne null) && !e.isInstanceOf[SpecViolation]) tryOnError(subscriber, e) + } else if (exposedPublisher == null && upstreamFailed.isEmpty) { + // fail called before the exposed publisher arrived, we must store it and fail when we're first able to + upstreamFailed = Some(e) + } + } + + setHandler(in, new InHandler { + override def onPush(): Unit = { + onNext(grab(in)) + if (downstreamCompleted) cancel(in) + else if (downstreamDemand > 0) pull(in) + } + + override def onUpstreamFinish(): Unit = complete() + + override def onUpstreamFailure(cause: Throwable): Unit = fail(cause) + }) + + def subscribePending(): Unit = + exposedPublisher.takePendingSubscribers() foreach { sub ⇒ + if (subscriber eq null) { + subscriber = sub + tryOnSubscribe(subscriber, new BoundarySubscription(actor, id)) + } else + rejectAdditionalSubscriber(subscriber, s"${Logging.simpleName(this)}") + } + + def exposedPublisher(publisher: ActorPublisher[Any]): Unit = { + upstreamFailed match { + case _: Some[_] ⇒ + publisher.shutdown(upstreamFailed) + case _ ⇒ + exposedPublisher = publisher + } + } + + def requestMore(elements: Long): Unit = { + if (elements < 1) { + cancel(in) + fail(ReactiveStreamsCompliance.numberOfElementsInRequestMustBePositiveException) + } else { + downstreamDemand += elements + if (downstreamDemand < 0) + downstreamDemand = Long.MaxValue // Long overflow, Reactive Streams Spec 3:17: effectively unbounded + if (!hasBeenPulled(in)) pull(in) + } + } + + def cancel(): Unit = { + downstreamCompleted = true + subscriber = null + exposedPublisher.shutdown(Some(new ActorPublisher.NormalShutdownException)) + cancel(in) + } + + } + +} + +/** + * INTERNAL API + */ +private[stream] class ActorGraphInterpreter(assembly: GraphAssembly, shape: Shape, settings: ActorMaterializerSettings) extends Actor { + import ActorGraphInterpreter._ + + val interpreter = new GraphInterpreter(assembly, (logic, event, handler) ⇒ self ! AsyncInput(logic, event, handler)) + val inputs = Array.tabulate(shape.inlets.size)(new BatchingActorInputBoundary(settings.maxInputBufferSize, _)) + val outputs = Array.tabulate(shape.outlets.size)(new ActorOutputBoundary(self, _)) + // Limits the number of events processed by the interpreter before scheduling a self-message for fairness with other + // actors. 
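+  // (If the limit is hit while events are still pending, the actor sends itself a Resume message and
+  //  continues processing in a later batch; see runBatch() below.)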
+ // TODO: Better heuristic here + val eventLimit = settings.maxInputBufferSize * assembly.stages.length * 4 // Roughly 4 events per element transfer + // Limits the number of events processed by the interpreter on an abort event. + // TODO: Better heuristic here + val abortLimit = eventLimit * 2 + var resumeScheduled = false + + override def preStart(): Unit = { + var i = 0 + while (i < inputs.length) { + interpreter.attachUpstreamBoundary(i, inputs(i)) + i += 1 + } + val offset = assembly.connectionCount - outputs.length + i = 0 + while (i < outputs.length) { + interpreter.attachDownstreamBoundary(i + offset, outputs(i)) + i += 1 + } + interpreter.init() + } + + override def receive: Receive = { + case OnError(id: Int, cause: Throwable) ⇒ + if (GraphInterpreter.Debug) println(s" onError id=$id") + inputs(id).onError(cause) + runBatch() + case OnComplete(id: Int) ⇒ + if (GraphInterpreter.Debug) println(s" onComplete id=$id") + inputs(id).onComplete() + runBatch() + case OnNext(id: Int, e: Any) ⇒ + if (GraphInterpreter.Debug) println(s" onNext $e id=$id") + inputs(id).onNext(e) + runBatch() + case OnSubscribe(id: Int, subscription: Subscription) ⇒ + inputs(id).onSubscribe(subscription) + + case RequestMore(id: Int, demand: Long) ⇒ + if (GraphInterpreter.Debug) println(s" request $demand id=$id") + outputs(id).requestMore(demand) + runBatch() + case Cancel(id: Int) ⇒ + if (GraphInterpreter.Debug) println(s" cancel id=$id") + outputs(id).cancel() + runBatch() + case SubscribePending(id: Int) ⇒ + outputs(id).subscribePending() + case ExposedPublisher(id, publisher) ⇒ + outputs(id).exposedPublisher(publisher) + + case AsyncInput(_, event, handler) ⇒ + if (GraphInterpreter.Debug) println(s"ASYNC $event") + handler(event) + runBatch() + + case Resume ⇒ + resumeScheduled = false + if (interpreter.isSuspended) runBatch() + + } + + override protected[akka] def aroundReceive(receive: Actor.Receive, msg: Any): Unit = { + super.aroundReceive(receive, msg) + + } + + private def runBatch(): Unit = { + try { + interpreter.execute(eventLimit) + if (interpreter.isCompleted) context.stop(self) + else if (interpreter.isSuspended && !resumeScheduled) { + resumeScheduled = true + self ! Resume + } + } catch { + case NonFatal(e) ⇒ + context.stop(self) + tryAbort(e) + } + } + + /** + * Attempts to abort execution, by first propagating the reason given until either + * - the interpreter successfully finishes + * - the event limit is reached + * - a new error is encountered + */ + private def tryAbort(ex: Throwable): Unit = { + // This should handle termination while interpreter is running. If the upstream have been closed already this + // call has no effect and therefore do the right thing: nothing. + try { + inputs.foreach(_.onInternalError(ex)) + interpreter.execute(abortLimit) + interpreter.finish() + } // Will only have an effect if the above call to the interpreter failed to emit a proper failure to the downstream + // otherwise this will have no effect + finally { + outputs.foreach(_.fail(ex)) + inputs.foreach(_.cancel()) + } + } + + override def postStop(): Unit = tryAbort(AbruptTerminationException(self)) +} diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala new file mode 100644 index 0000000000..edcd0532eb --- /dev/null +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala @@ -0,0 +1,434 @@ +/** + * Copyright (C) 2015 Typesafe Inc. 
+ */ +package akka.stream.impl.fusing + +import akka.stream.stage.{ OutHandler, InHandler, GraphStage, GraphStageLogic } +import akka.stream.{ Shape, Inlet, Outlet } + +/** + * INTERNAL API + * + * (See the class for the documentation of the internals) + */ +private[stream] object GraphInterpreter { + /** + * Compile time constant, enable it for debug logging to the console. + */ + final val Debug = false + + /** + * Marker object that indicates that a port holds no element since it was already grabbed. The port is still pullable, + * but there is no more element to grab. + */ + case object Empty + + sealed trait ConnectionState + sealed trait CompletedState extends ConnectionState + case object Pushable extends ConnectionState + case object Completed extends CompletedState + final case class PushCompleted(element: Any) extends ConnectionState + case object Cancelled extends CompletedState + final case class Failed(ex: Throwable) extends CompletedState + + val NoEvent = -1 + val Boundary = -1 + + abstract class UpstreamBoundaryStageLogic[T] extends GraphStageLogic { + def out: Outlet[T] + } + + abstract class DownstreamBoundaryStageLogic[T] extends GraphStageLogic { + def in: Inlet[T] + } + + /** + * INTERNAL API + * + * A GraphAssembly represents a small stream processing graph to be executed by the interpreter. Instances of this + * class **must not** be mutated after construction. + * + * The arrays [[ins]] and [[outs]] correspond to the notion of a *connection* in the [[GraphInterpreter]]. Each slot + * *i* contains the input and output port corresponding to connection *i*. Slots where the graph is not closed (i.e. + * ports are exposed to the external world) are marked with *null* values. For example if an input port *p* is + * exposed, then outs(p) will contain a *null*. + * + * The arrays [[inOwners]] and [[outOwners]] are lookup tables from a connection id (the index of the slot) + * to a slot in the [[stages]] array, indicating which stage is the owner of the given input or output port. + * Slots which would correspond to non-existent stages (where the corresponding port is null since it represents + * the currently unknown external context) contain the value [[GraphInterpreter#Boundary]]. + * + * The current assumption by the infrastructure is that the layout of these arrays looks like this: + * + * +---------------------------------------+-----------------+ + * inOwners: | index to stages array | Boundary (-1) | + * +----------------+----------------------+-----------------+ + * ins: | exposed inputs | internal connections | nulls | + * +----------------+----------------------+-----------------+ + * outs: | nulls | internal connections | exposed outputs | + * +----------------+----------------------+-----------------+ + * outOwners: | Boundary (-1) | index to stages array | + * +----------------+----------------------------------------+ + * + * In addition, it is also assumed by the infrastructure that the order of exposed inputs and outputs in the + * corresponding segments of these arrays matches the exact same order of the ports in the [[Shape]]. + * + */ + final case class GraphAssembly(stages: Array[GraphStage[_]], + ins: Array[Inlet[_]], + inOwners: Array[Int], + outs: Array[Outlet[_]], + outOwners: Array[Int]) { + + val connectionCount: Int = ins.length + + /** + * Takes an interpreter and returns three arrays required by the interpreter containing the input, output port + * handlers and the stage logic instances. 
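+     * The returned [[InHandler]] and [[OutHandler]] arrays are indexed by connection id (matching [[ins]] and
+     * [[outs]]), while the returned array of logics is indexed by the owning stage's position in [[stages]].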
+     */
+    def materialize(interpreter: GraphInterpreter): (Array[InHandler], Array[OutHandler], Array[GraphStageLogic]) = {
+      val logics = Array.ofDim[GraphStageLogic](stages.length)
+      for (i ← stages.indices) {
+        logics(i) = stages(i).createLogic
+        logics(i).interpreter = interpreter
+      }
+
+      val inHandlers = Array.ofDim[InHandler](connectionCount)
+      val outHandlers = Array.ofDim[OutHandler](connectionCount)
+
+      for (i ← 0 until connectionCount) {
+        if (ins(i) ne null) {
+          inHandlers(i) = logics(inOwners(i)).inHandlers(ins(i))
+          logics(inOwners(i)).inToConn += ins(i) -> i
+        }
+        if (outs(i) ne null) {
+          outHandlers(i) = logics(outOwners(i)).outHandlers(outs(i))
+          logics(outOwners(i)).outToConn += outs(i) -> i
+        }
+      }
+
+      (inHandlers, outHandlers, logics)
+    }
+
+    override def toString: String =
+      "GraphAssembly(" +
+        stages.mkString("[", ",", "]") + ", " +
+        ins.mkString("[", ",", "]") + ", " +
+        inOwners.mkString("[", ",", "]") + ", " +
+        outs.mkString("[", ",", "]") + ", " +
+        outOwners.mkString("[", ",", "]") +
+        ")"
+  }
+}
+
+/**
+ * INTERNAL API
+ *
+ * From an external viewpoint, the GraphInterpreter takes an assembly of graph processing stages encoded as a
+ * [[GraphInterpreter#GraphAssembly]] object and provides facilities to execute and interact with this assembly.
+ * The lifecycle of the interpreter is roughly the following:
+ *  - Boundary logics are attached via [[attachDownstreamBoundary()]] and [[attachUpstreamBoundary()]]
+ *  - [[init()]] is called
+ *  - [[execute()]] is called whenever there is need for execution, providing an upper limit on the processed events
+ *  - [[finish()]] is called before the interpreter is disposed, preferably after [[isCompleted]] has returned true, although
+ *    in abort cases this is not strictly necessary
+ *
+ * The [[execute()]] method of the interpreter accepts an upper bound on the events it will process. After this limit
+ * is reached or there are no more pending events to be processed, the call returns. It is possible to inspect
+ * whether there are unprocessed events left via the [[isSuspended]] method. [[isCompleted]] returns true once all stages
+ * reported completion inside the interpreter.
+ *
+ * The internal architecture of the interpreter is based on the usage of arrays and optimized for reducing allocations
+ * on the hot paths.
+ *
+ * One of the basic abstractions inside the interpreter is the notion of *connection*. In the abstract sense a
+ * connection represents an output-input port pair (an analogue for a connected RS Publisher-Subscriber pair),
+ * while in the practical sense a connection is a number which represents slots in certain arrays.
+ * In particular
+ *  - connectionStates is a mapping from a connection id to the current (or future) state of the connection
+ *  - inAvailable is a mapping from a connection to a boolean that indicates whether the input corresponding
+ *    to the connection is currently pullable
+ *  - outAvailable is a mapping from a connection to a boolean that indicates whether the output corresponding
+ *    to the connection is currently pushable
+ *  - inHandlers is a mapping from a connection id to the [[InHandler]] instance that handles the events corresponding
+ *    to the input port of the connection
+ *  - outHandlers is a mapping from a connection id to the [[OutHandler]] instance that handles the events corresponding
+ *    to the output port of the connection
+ *
+ * On top of these lookup tables there is an eventQueue, represented as a circular buffer of integers.
The integers + * it contains represents connections that have pending events to be processed. The pending event itself is encoded + * in the connectionStates table. This implies that there can be only one event in flight for a given connection, which + * is true in almost all cases, except a complete-after-push which is therefore handled with a special event + * [[GraphInterpreter#PushCompleted]]. + * + * Sending an event is usually the following sequence: + * - An action is requested by a stage logic (push, pull, complete, etc.) + * - the availability of the port is set on the sender side to false (inAvailable or outAvailable) + * - the scheduled event is put in the slot of the connection in the connectionStates table + * - the id of the affected connection is enqueued + * + * Receiving an event is usually the following sequence: + * - id of connection to be processed is dequeued + * - the type of the event is determined by the object in the corresponding connectionStates slot + * - the availability of the port is set on the receiver side to be true (inAvailable or outAvailable) + * - using the inHandlers/outHandlers table the corresponding callback is called on the stage logic. + * + * Because of the FIFO construction of the queue the interpreter is fair, i.e. a pending event is always executed + * after a bounded number of other events. This property, together with suspendability means that even infinite cycles can + * be modeled, or even dissolved (if preempted and a "stealing" external even is injected; for example the non-cycle + * edge of a balance is pulled, dissolving the original cycle). + */ +private[stream] final class GraphInterpreter( + private val assembly: GraphInterpreter.GraphAssembly, + val onAsyncInput: (GraphStageLogic, Any, (Any) ⇒ Unit) ⇒ Unit) { + import GraphInterpreter._ + + // Maintains the next event (and state) of the connection. + // Technically the connection cannot be considered being in the state that is encoded here before the enqueued + // connection event has been processed. The inAvailable and outAvailable arrays usually protect access to this + // field while it is in transient state. + val connectionStates = Array.fill[Any](assembly.connectionCount)(Empty) + + // Indicates whether the input port is pullable. After pulling it becomes false + // Be aware that when inAvailable goes to false outAvailable does not become true immediately, only after + // the corresponding event in the queue has been processed + val inAvailable = Array.fill[Boolean](assembly.connectionCount)(true) + + // Indicates whether the output port is pushable. After pushing it becomes false + // Be aware that when inAvailable goes to false outAvailable does not become true immediately, only after + // the corresponding event in the queue has been processed + val outAvailable = Array.fill[Boolean](assembly.connectionCount)(false) + + // Lookup tables for the InHandler and OutHandler for a given connection ID, and a lookup table for the + // GraphStageLogic instances + val (inHandlers, outHandlers, logics) = assembly.materialize(this) + + // The number of currently running stages. Once this counter reaches zero, the interpreter is considered to be + // completed + private var runningStages = assembly.stages.length + + // Counts how many active connections a stage has. Once it reaches zero, the stage is automatically stopped. 
+ private val shutdownCounter = Array.tabulate(assembly.stages.length) { i ⇒ + val shape = assembly.stages(i).shape.asInstanceOf[Shape] + shape.inlets.size + shape.outlets.size + } + + // An event queue implemented as a circular buffer + private val mask = 255 + private val eventQueue = Array.ofDim[Int](256) + private var queueHead: Int = 0 + private var queueTail: Int = 0 + + /** + * Assign the boundary logic to a given connection. This will serve as the interface to the external world + * (outside the interpreter) to process and inject events. + */ + def attachUpstreamBoundary(connection: Int, logic: UpstreamBoundaryStageLogic[_]): Unit = { + logic.outToConn += logic.out -> connection + logic.interpreter = this + outHandlers(connection) = logic.outHandlers.head._2 + } + + /** + * Assign the boundary logic to a given connection. This will serve as the interface to the external world + * (outside the interpreter) to process and inject events. + */ + def attachDownstreamBoundary(connection: Int, logic: DownstreamBoundaryStageLogic[_]): Unit = { + logic.inToConn += logic.in -> connection + logic.interpreter = this + inHandlers(connection) = logic.inHandlers.head._2 + } + + /** + * Returns true if there are pending unprocessed events in the event queue. + */ + def isSuspended: Boolean = queueHead != queueTail + + /** + * Returns true if there are no more running stages and pending events. + */ + def isCompleted: Boolean = runningStages == 0 && !isSuspended + + /** + * Initializes the states of all the stage logics by calling preStart() + */ + def init(): Unit = { + var i = 0 + while (i < logics.length) { + logics(i).preStart() + i += 1 + } + } + + /** + * Finalizes the state of all stages by calling postStop() (if necessary). + */ + def finish(): Unit = { + var i = 0 + while (i < logics.length) { + if (!isStageCompleted(i)) logics(i).postStop() + i += 1 + } + } + + // Debug name for a connections input part + private def inOwnerName(connection: Int): String = + if (assembly.inOwners(connection) == Boundary) "DownstreamBoundary" + else assembly.stages(assembly.inOwners(connection)).toString + + // Debug name for a connections ouput part + private def outOwnerName(connection: Int): String = + if (assembly.outOwners(connection) == Boundary) "UpstreamBoundary" + else assembly.stages(assembly.outOwners(connection)).toString + + /** + * Executes pending events until the given limit is met. If there were remaining events, isSuspended will return + * true. 
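+   * The interpreter can be run again later to continue with the remaining events; for example,
+   * [[ActorGraphInterpreter]] re-invokes execute() from a Resume self-message whenever a batch ends
+   * while the interpreter is still suspended.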
+ */ + def execute(eventLimit: Int): Unit = { + var eventsRemaining = eventLimit + var connection = dequeue() + while (eventsRemaining > 0 && connection != NoEvent) { + processEvent(connection) + eventsRemaining -= 1 + if (eventsRemaining > 0) connection = dequeue() + } + // TODO: deadlock detection + } + + // Decodes and processes a single event for the given connection + private def processEvent(connection: Int): Unit = { + + def processElement(elem: Any): Unit = { + if (!isStageCompleted(assembly.inOwners(connection))) { + if (GraphInterpreter.Debug) println(s"PUSH ${outOwnerName(connection)} -> ${inOwnerName(connection)}, $elem") + inAvailable(connection) = true + inHandlers(connection).onPush() + } + } + + connectionStates(connection) match { + case Pushable ⇒ + if (!isStageCompleted(assembly.outOwners(connection))) { + if (GraphInterpreter.Debug) println(s"PULL ${inOwnerName(connection)} -> ${outOwnerName(connection)}") + outAvailable(connection) = true + outHandlers(connection).onPull() + } + case Completed ⇒ + val stageId = assembly.inOwners(connection) + if (!isStageCompleted(stageId)) { + if (GraphInterpreter.Debug) println(s"COMPLETE ${outOwnerName(connection)} -> ${inOwnerName(connection)}") + inAvailable(connection) = false + inHandlers(connection).onUpstreamFinish() + completeConnection(stageId) + } + case Failed(ex) ⇒ + val stageId = assembly.inOwners(connection) + if (!isStageCompleted(stageId)) { + if (GraphInterpreter.Debug) println(s"FAIL ${outOwnerName(connection)} -> ${inOwnerName(connection)}") + inAvailable(connection) = false + inHandlers(connection).onUpstreamFailure(ex) + completeConnection(stageId) + } + case Cancelled ⇒ + val stageId = assembly.outOwners(connection) + if (!isStageCompleted(stageId)) { + if (GraphInterpreter.Debug) println(s"CANCEL ${inOwnerName(connection)} -> ${outOwnerName(connection)}") + outAvailable(connection) = false + outHandlers(connection).onDownstreamFinish() + completeConnection(stageId) + } + case PushCompleted(elem) ⇒ + inAvailable(connection) = true + connectionStates(connection) = elem + processElement(elem) + enqueue(connection, Completed) + case pushedElem ⇒ processElement(pushedElem) + + } + + } + + private def dequeue(): Int = { + if (queueHead == queueTail) NoEvent + else { + val idx = queueHead & mask + val elem = eventQueue(idx) + eventQueue(idx) = NoEvent + queueHead += 1 + elem + } + } + + private def enqueue(connection: Int, event: Any): Unit = { + connectionStates(connection) = event + eventQueue(queueTail & mask) = connection + queueTail += 1 + } + + // Returns true if a connection has been completed *or if the completion event is already enqueued*. This is useful + // to prevent redundant completion events in case of concurrent invocation on both sides of the connection. + // I.e. when one side already enqueued the completion event, then the other side will not enqueue the event since + // there is noone to process it anymore. 
+ def isConnectionCompleted(connection: Int): Boolean = connectionStates(connection).isInstanceOf[CompletedState] + + // Returns true if the given stage is alredy completed + private def isStageCompleted(stageId: Int): Boolean = stageId != Boundary && shutdownCounter(stageId) == 0 + + private def isPushInFlight(connection: Int): Boolean = + !inAvailable(connection) && + !connectionStates(connection).isInstanceOf[ConnectionState] && + connectionStates(connection) != Empty + + // Register that a connection in which the given stage participated has been completed and therefore the stage + // itself might stop, too. + private def completeConnection(stageId: Int): Unit = { + if (stageId != Boundary) { + val activeConnections = shutdownCounter(stageId) + if (activeConnections > 0) { + shutdownCounter(stageId) = activeConnections - 1 + // This was the last active connection keeping this stage alive + if (activeConnections == 1) { + runningStages -= 1 + logics(stageId).postStop() + } + } + } + } + + private[stream] def push(connection: Int, elem: Any): Unit = { + outAvailable(connection) = false + enqueue(connection, elem) + } + + private[stream] def pull(connection: Int): Unit = { + inAvailable(connection) = false + enqueue(connection, Pushable) + } + + private[stream] def complete(connection: Int): Unit = { + outAvailable(connection) = false + if (!isConnectionCompleted(connection)) { + // There is a pending push, we change the signal to be a PushCompleted (there can be only one signal in flight + // for a connection) + if (isPushInFlight(connection)) + connectionStates(connection) = PushCompleted(connectionStates(connection)) + else + enqueue(connection, Completed) + } + completeConnection(assembly.outOwners(connection)) + } + + private[stream] def fail(connection: Int, ex: Throwable): Unit = { + outAvailable(connection) = false + if (!isConnectionCompleted(connection)) enqueue(connection, Failed(ex)) + completeConnection(assembly.outOwners(connection)) + } + + private[stream] def cancel(connection: Int): Unit = { + inAvailable(connection) = false + if (!isConnectionCompleted(connection)) enqueue(connection, Cancelled) + completeConnection(assembly.inOwners(connection)) + } + +} \ No newline at end of file diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala new file mode 100644 index 0000000000..0b1aa4073c --- /dev/null +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala @@ -0,0 +1,235 @@ +/** + * Copyright (C) 2015 Typesafe Inc. 
+ */
+package akka.stream.impl.fusing
+
+import akka.stream._
+import akka.stream.stage.{ OutHandler, InHandler, GraphStageLogic, GraphStage }
+
+/**
+ * INTERNAL API
+ */
+object GraphStages {
+
+  class Identity[T] extends GraphStage[FlowShape[T, T]] {
+    val in = Inlet[T]("in")
+    val out = Outlet[T]("out")
+
+    val shape = FlowShape(in, out)
+
+    override def createLogic: GraphStageLogic = new GraphStageLogic {
+      setHandler(in, new InHandler {
+        override def onPush(): Unit = push(out, grab(in))
+        override def onUpstreamFinish(): Unit = completeStage()
+        override def onUpstreamFailure(ex: Throwable): Unit = failStage(ex)
+      })
+
+      setHandler(out, new OutHandler {
+        override def onPull(): Unit = pull(in)
+        override def onDownstreamFinish(): Unit = completeStage()
+      })
+    }
+
+    override def toString = "Identity"
+  }
+
+  class Detacher[T] extends GraphStage[FlowShape[T, T]] {
+    val in = Inlet[T]("in")
+    val out = Outlet[T]("out")
+    val shape = FlowShape(in, out)
+
+    override def createLogic: GraphStageLogic = new GraphStageLogic {
+      var initialized = false
+
+      setHandler(in, new InHandler {
+        override def onPush(): Unit = {
+          if (isAvailable(out)) {
+            push(out, grab(in))
+            pull(in)
+          }
+        }
+      })
+
+      setHandler(out, new OutHandler {
+        override def onPull(): Unit = {
+          if (!initialized) {
+            pull(in)
+            initialized = true
+          } else if (isAvailable(in)) {
+            push(out, grab(in))
+            if (!hasBeenPulled(in)) pull(in)
+          }
+        }
+      })
+
+    }
+
+    override def toString = "Detacher"
+  }
+
+  class Broadcast[T](private val outCount: Int) extends GraphStage[UniformFanOutShape[T, T]] {
+    val in = Inlet[T]("in")
+    val out = Vector.fill(outCount)(Outlet[T]("out"))
+    val shape = UniformFanOutShape(in, out: _*)
+
+    override def createLogic: GraphStageLogic = new GraphStageLogic {
+      private var pending = outCount
+
+      setHandler(in, new InHandler {
+        override def onPush(): Unit = {
+          pending = outCount
+          val elem = grab(in)
+          out.foreach(push(_, elem))
+        }
+      })
+
+      val outHandler = new OutHandler {
+        override def onPull(): Unit = {
+          pending -= 1
+          if (pending == 0) pull(in)
+        }
+      }
+
+      out.foreach(setHandler(_, outHandler))
+    }
+
+    override def toString = "Broadcast"
+
+  }
+
+  class Zip[A, B] extends GraphStage[FanInShape2[A, B, (A, B)]] {
+    val in0 = Inlet[A]("in0")
+    val in1 = Inlet[B]("in1")
+    val out = Outlet[(A, B)]("out")
+    val shape = new FanInShape2[A, B, (A, B)](in0, in1, out)
+
+    override def createLogic: GraphStageLogic = new GraphStageLogic {
+      var pending = 2
+
+      val inHandler = new InHandler {
+        override def onPush(): Unit = {
+          pending -= 1
+          if (pending == 0) push(out, (grab(in0), grab(in1)))
+        }
+      }
+
+      setHandler(in0, inHandler)
+      setHandler(in1, inHandler)
+      setHandler(out, new OutHandler {
+        override def onPull(): Unit = {
+          pending = 2
+          pull(in0)
+          pull(in1)
+        }
+      })
+    }
+
+    override def toString = "Zip"
+  }
+
+  class Merge[T](private val inCount: Int) extends GraphStage[UniformFanInShape[T, T]] {
+    val in = Vector.fill(inCount)(Inlet[T]("in"))
+    val out = Outlet[T]("out")
+    val shape = UniformFanInShape(out, in: _*)
+
+    override def createLogic: GraphStageLogic = new GraphStageLogic {
+      private var initialized = false
+
+      private val pendingQueue = Array.ofDim[Inlet[T]](inCount)
+      private var pendingHead: Int = 0
+      private var pendingTail: Int = 0
+
+      private def noPending: Boolean = pendingHead == pendingTail
+      private def enqueue(in: Inlet[T]): Unit = {
+        pendingQueue(pendingTail % inCount) = in
+        pendingTail += 1
+      }
+      private def dequeueAndDispatch(): Unit = {
+        val in = pendingQueue(pendingHead % inCount)
+        pendingHead += 1
+        push(out, grab(in))
+        pull(in)
+      }
+
+      in.foreach { i ⇒
+        setHandler(i, new InHandler {
+          override def onPush(): Unit = {
+            if (isAvailable(out)) {
+              if (noPending) {
+                push(out, grab(i))
+                pull(i)
+              } else {
+                enqueue(i)
+                dequeueAndDispatch()
+              }
+            } else enqueue(i)
+          }
+        })
+      }
+
+      setHandler(out, new OutHandler {
+        override def onPull(): Unit = {
+          if (!initialized) {
+            initialized = true
+            in.foreach(pull(_))
+          } else {
+            if (!noPending) {
+              dequeueAndDispatch()
+            }
+          }
+        }
+      })
+    }
+
+    override def toString = "Merge"
+  }
+
+  class Balance[T](private val outCount: Int) extends GraphStage[UniformFanOutShape[T, T]] {
+    val in = Inlet[T]("in")
+    val out = Vector.fill(outCount)(Outlet[T]("out"))
+    val shape = UniformFanOutShape[T, T](in, out: _*)
+
+    override def createLogic: GraphStageLogic = new GraphStageLogic {
+      private val pendingQueue = Array.ofDim[Outlet[T]](outCount)
+      private var pendingHead: Int = 0
+      private var pendingTail: Int = 0
+
+      private def noPending: Boolean = pendingHead == pendingTail
+      private def enqueue(out: Outlet[T]): Unit = {
+        pendingQueue(pendingTail % outCount) = out
+        pendingTail += 1
+      }
+      private def dequeueAndDispatch(): Unit = {
+        val out = pendingQueue(pendingHead % outCount)
+        pendingHead += 1
+        push(out, grab(in))
+        if (!noPending) pull(in)
+      }
+
+      setHandler(in, new InHandler {
+        override def onPush(): Unit = dequeueAndDispatch()
+      })
+
+      out.foreach { o ⇒
+        setHandler(o, new OutHandler {
+          override def onPull(): Unit = {
+            if (isAvailable(in)) {
+              if (noPending) {
+                push(o, grab(in))
+              } else {
+                enqueue(o)
+                dequeueAndDispatch()
+              }
+            } else {
+              if (!hasBeenPulled(in)) pull(in)
+              enqueue(o)
+            }
+          }
+        })
+      }
+    }
+
+    override def toString = "Balance"
+  }
+
+}
diff --git a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala
new file mode 100644
index 0000000000..c90e886bf7
--- /dev/null
+++ b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala
@@ -0,0 +1,255 @@
+/**
+ * Copyright (C) 2015 Typesafe Inc.
+ */
+package akka.stream.stage
+
+import akka.stream._
+import akka.stream.impl.StreamLayout.Module
+import akka.stream.impl.fusing.{ GraphModule, GraphInterpreter }
+import akka.stream.impl.fusing.GraphInterpreter.GraphAssembly
+
+/**
+ * A GraphStage represents a reusable graph stream processing stage. A GraphStage consists of a [[Shape]] which describes
+ * its input and output ports and a factory function that creates a [[GraphStageLogic]] which implements the processing
+ * logic that ties the ports together.
+ */
+abstract class GraphStage[S <: Shape] extends Graph[S, Unit] {
+  def shape: S
+  def createLogic: GraphStageLogic
+
+  final override private[stream] lazy val module: Module = {
+    val connectionCount = shape.inlets.size + shape.outlets.size
+    val assembly = GraphAssembly(
+      Array(this),
+      Array.ofDim(connectionCount),
+      Array.fill(connectionCount)(-1),
+      Array.ofDim(connectionCount),
+      Array.fill(connectionCount)(-1))
+
+    for ((inlet, i) ← shape.inlets.iterator.zipWithIndex) {
+      assembly.ins(i) = inlet
+      assembly.inOwners(i) = 0
+    }
+
+    for ((outlet, i) ← shape.outlets.iterator.zipWithIndex) {
+      assembly.outs(i + shape.inlets.size) = outlet
+      assembly.outOwners(i + shape.inlets.size) = 0
+    }
+
+    GraphModule(assembly, shape, Attributes.none)
+  }
+
+  /**
+   * This method throws an [[UnsupportedOperationException]] by default.
+   * The subclass can override this method and provide a correct implementation that creates an exact copy of
+   * the stage with the provided new attributes.
+   */
+  override def withAttributes(attr: Attributes): Graph[S, Unit] =
+    throw new UnsupportedOperationException(
+      "withAttributes not supported by default by GraphStage, subclass may override and implement it")
+}
+
+/**
+ * Represents the processing logic behind a [[GraphStage]]. Roughly speaking, a subclass of [[GraphStageLogic]] is a
+ * collection of the following parts:
+ *  * A set of [[InHandler]] and [[OutHandler]] instances and their assignments to the [[Inlet]]s and [[Outlet]]s
+ *    of the enclosing [[GraphStage]]
+ *  * Possible mutable state, accessible from the [[InHandler]] and [[OutHandler]] callbacks, but not from anywhere
+ *    else (as such access would not be thread-safe)
+ *  * The lifecycle hooks [[preStart()]] and [[postStop()]]
+ *  * Methods for performing stream processing actions, like pulling or pushing elements
+ *
+ * The stage logic is always stopped once all its input and output ports have been closed, i.e. it is not possible to
+ * keep the stage alive for further processing once it does not have any open ports.
+ */
+abstract class GraphStageLogic {
+  import GraphInterpreter._
+
+  /**
+   * INTERNAL API
+   */
+  private[stream] var inHandlers = scala.collection.Map.empty[Inlet[_], InHandler]
+  /**
+   * INTERNAL API
+   */
+  private[stream] var outHandlers = scala.collection.Map.empty[Outlet[_], OutHandler]
+
+  /**
+   * INTERNAL API
+   */
+  private[stream] var inToConn = scala.collection.Map.empty[Inlet[_], Int]
+  /**
+   * INTERNAL API
+   */
+  private[stream] var outToConn = scala.collection.Map.empty[Outlet[_], Int]
+
+  /**
+   * INTERNAL API
+   */
+  private[stream] var interpreter: GraphInterpreter = _
+
+  /**
+   * Assigns callbacks for the events of an [[Inlet]]
+   */
+  final protected def setHandler(in: Inlet[_], handler: InHandler): Unit = inHandlers += in -> handler
+  /**
+   * Assigns callbacks for the events of an [[Outlet]]
+   */
+  final protected def setHandler(out: Outlet[_], handler: OutHandler): Unit = outHandlers += out -> handler
+
+  private def conn[T](in: Inlet[T]): Int = inToConn(in)
+  private def conn[T](out: Outlet[T]): Int = outToConn(out)
+
+  /**
+   * Requests an element on the given port. Calling this method twice before an element has arrived will fail.
+   * There can only be one outstanding request at any given time. The method [[hasBeenPulled()]] can be used to
+   * query whether pull is allowed to be called or not.
+   */
+  final def pull[T](in: Inlet[T]): Unit = {
+    require(!hasBeenPulled(in), "Cannot pull port twice")
+    interpreter.pull(conn(in))
+  }
+
+  /**
+   * Requests to stop receiving events from a given input port.
+   */
+  final def cancel[T](in: Inlet[T]): Unit = interpreter.cancel(conn(in))
+
+  /**
+   * Once the callback [[InHandler.onPush()]] for an input port has been invoked, the element that has been pushed
+   * can be retrieved via this method. After [[grab()]] has been called the port is considered to be empty, and further
+   * calls to [[grab()]] will fail until the port is pulled again and a new element is pushed as a response.
+   *
+   * The method [[isAvailable()]] can be used to query if the port has an element that can be grabbed or not.
+   */
+  final def grab[T](in: Inlet[T]): T = {
+    require(isAvailable(in), "Cannot get element from already empty input port")
+    val connection = conn(in)
+    val elem = interpreter.connectionStates(connection)
+    interpreter.connectionStates(connection) = Empty
+    elem.asInstanceOf[T]
+  }
+
+  /**
+   * Indicates whether there is already a pending pull for the given input port. If this method returns true
+   * then [[isAvailable()]] must return false for that same port.
+   */
+  final def hasBeenPulled[T](in: Inlet[T]): Boolean = !interpreter.inAvailable(conn(in))
+
+  /**
+   * Indicates whether there is an element waiting at the given input port. [[grab()]] can be used to retrieve the
+   * element. After calling [[grab()]] this method will return false.
+   *
+   * If this method returns true then [[hasBeenPulled()]] will return false for that same port.
+   */
+  final def isAvailable[T](in: Inlet[T]): Boolean = {
+    val connection = conn(in)
+    interpreter.inAvailable(connection) && !(interpreter.connectionStates(connection) == Empty)
+  }
+
+  /**
+   * Emits an element through the given output port. Calling this method twice before a [[pull()]] has arrived
+   * will fail. There can be only one outstanding push request at any given time. The method [[isAvailable()]] can be
+   * used to check if the port is ready to be pushed or not.
+   */
+  final def push[T](out: Outlet[T], elem: T): Unit = {
+    require(isAvailable(out), "Cannot push port twice")
+    interpreter.push(conn(out), elem)
+  }
+
+  /**
+   * Signals that there will be no more elements emitted on the given port.
+   */
+  final def complete[T](out: Outlet[T]): Unit = interpreter.complete(conn(out))
+
+  /**
+   * Signals failure through the given port.
+   */
+  final def fail[T](out: Outlet[T], ex: Throwable): Unit = interpreter.fail(conn(out), ex)
+
+  /**
+   * Automatically invokes [[cancel()]] on all input ports and [[complete()]] on all output ports, stopping the
+   * stage, after which [[postStop()]] is called.
+   */
+  final def completeStage(): Unit = {
+    inToConn.valuesIterator.foreach(interpreter.cancel)
+    outToConn.valuesIterator.foreach(interpreter.complete)
+  }
+
+  /**
+   * Automatically invokes [[cancel()]] on all input ports and [[fail()]] on all output ports, stopping the
+   * stage, after which [[postStop()]] is called.
+   */
+  final def failStage(ex: Throwable): Unit = {
+    inToConn.valuesIterator.foreach(interpreter.cancel)
+    outToConn.valuesIterator.foreach(interpreter.fail(_, ex))
+  }
+
+  /**
+   * Returns true if the given output port is ready to be pushed.
+   */
+  final def isAvailable[T](out: Outlet[T]): Boolean = interpreter.outAvailable(conn(out))
+
+  /**
+   * Obtain a callback object that can be used asynchronously to re-enter the
+   * current [[GraphStageLogic]] with an asynchronous notification. The [[invoke()]] method of the returned
+   * [[AsyncCallback]] is safe to be called from other threads. It will in the background thread-safely
+   * delegate to the passed callback function, i.e. [[invoke()]] will be called by the external world and
+   * the passed handler will be invoked eventually in a thread-safe way by the execution environment.
+   *
+   * This object can be cached and reused within the same [[GraphStageLogic]].
+   */
+  final def getAsyncCallback[T](handler: T ⇒ Unit): AsyncCallback[T] = {
+    new AsyncCallback[T] {
+      override def invoke(event: T): Unit =
+        interpreter.onAsyncInput(GraphStageLogic.this, event, handler.asInstanceOf[Any ⇒ Unit])
+    }
+  }
+
+  /**
+   * Invoked before any external events are processed, at the startup of the stage.
+   */
+  def preStart(): Unit = ()
+
+  /**
+   * Invoked after processing of external events stopped because the stage is about to stop or fail.
+   */
+  def postStop(): Unit = ()
+}
+
+/**
+ * Collection of callbacks for an input port of a [[GraphStage]]
+ */
+trait InHandler {
+  /**
+   * Called when the input port has a new element available. The actual element can be retrieved via the
+   * [[GraphStageLogic.grab()]] method.
+   */
+  def onPush(): Unit
+
+  /**
+   * Called when the input port is finished. After this callback no other callbacks will be called for this port.
+   */
+  def onUpstreamFinish(): Unit = ()
+
+  /**
+   * Called when the input port has failed. After this callback no other callbacks will be called for this port.
+   */
+  def onUpstreamFailure(ex: Throwable): Unit = ()
+}
+
+/**
+ * Collection of callbacks for an output port of a [[GraphStage]]
+ */
+trait OutHandler {
+  /**
+   * Called when the output port has received a pull, and is therefore ready to emit an element, i.e.
+   * [[GraphStageLogic.push()]] is now allowed to be called on this port.
+   */
+  def onPull(): Unit
+
+  /**
+   * Called when the output port will no longer accept any new elements. After this callback no other callbacks will
+   * be called for this port.
+   */
+  def onDownstreamFinish(): Unit = ()
+}
\ No newline at end of file
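
As an illustration of how the API introduced in GraphStage.scala is meant to be used, the following sketch shows a hypothetical one-to-one mapping stage built only from constructs added in this patch (GraphStage, GraphStageLogic, FlowShape, setHandler, push, pull, grab, completeStage, failStage). The class name MapStage and the usage line in the trailing comment are illustrative assumptions and are not part of the change itself.

import akka.stream._
import akka.stream.stage._

// Illustrative sketch, not part of this patch: a user-defined mapping stage on top of the new API.
class MapStage[A, B](f: A ⇒ B) extends GraphStage[FlowShape[A, B]] {
  val in = Inlet[A]("in")
  val out = Outlet[B]("out")
  val shape = FlowShape(in, out)

  override def createLogic: GraphStageLogic = new GraphStageLogic {
    setHandler(in, new InHandler {
      // Grab the pushed element, apply the user function and emit the result downstream.
      override def onPush(): Unit = push(out, f(grab(in)))
      override def onUpstreamFinish(): Unit = completeStage()
      override def onUpstreamFailure(ex: Throwable): Unit = failStage(ex)
    })

    setHandler(out, new OutHandler {
      // A downstream pull is simply forwarded upstream; at most one element is in flight per connection.
      override def onPull(): Unit = pull(in)
      override def onDownstreamFinish(): Unit = completeStage()
    })
  }

  override def toString = "MapStage"
}

// Hypothetical usage, mirroring the Identity test above:
//   Source(1 to 10).via(new MapStage[Int, Int](_ * 2)).grouped(100).runWith(Sink.head)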