diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/ActorGraphInterpreterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/ActorGraphInterpreterSpec.scala index 2ef8eaa305..4b0c607a3d 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/ActorGraphInterpreterSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/ActorGraphInterpreterSpec.scala @@ -3,17 +3,14 @@ */ package akka.stream.impl.fusing -import java.util.concurrent.TimeoutException - -import akka.actor.{ ActorSystem, Cancellable, Scheduler } import akka.stream._ import akka.stream.scaladsl._ import akka.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler } -import akka.stream.testkit.{ TestPublisher, AkkaSpec, TestSubscriber } +import akka.stream.testkit.AkkaSpec +import akka.stream.testkit.Utils._ import scala.concurrent.Await import scala.concurrent.duration._ -import akka.stream.testkit.Utils._ class ActorGraphInterpreterSpec extends AkkaSpec { implicit val mat = ActorMaterializer() @@ -29,6 +26,19 @@ class ActorGraphInterpreterSpec extends AkkaSpec { } + "be able to reuse a simple identity graph stage" in assertAllStagesStopped { + val identity = new GraphStages.Identity[Int] + + Await.result( + Source(1 to 100) + .via(identity) + .via(identity) + .via(identity) + .grouped(200) + .runWith(Sink.head), + 3.seconds) should ===(1 to 100) + } + "be able to interpret a simple bidi stage" in assertAllStagesStopped { val identityBidi = new GraphStage[BidiShape[Int, Int, Int, Int]] { val in1 = Inlet[Int]("in1") @@ -70,6 +80,52 @@ class ActorGraphInterpreterSpec extends AkkaSpec { } + "be able to interpret and resuse a simple bidi stage" in assertAllStagesStopped { + val identityBidi = new GraphStage[BidiShape[Int, Int, Int, Int]] { + val in1 = Inlet[Int]("in1") + val in2 = Inlet[Int]("in2") + val out1 = Outlet[Int]("out1") + val out2 = Outlet[Int]("out2") + val shape = BidiShape(in1, out1, in2, out2) + + override def createLogic: GraphStageLogic = new GraphStageLogic { + setHandler(in1, new InHandler { + override def onPush(): Unit = push(out1, grab(in1)) + + override def onUpstreamFinish(): Unit = complete(out1) + }) + + setHandler(in2, new InHandler { + override def onPush(): Unit = push(out2, grab(in2)) + + override def onUpstreamFinish(): Unit = complete(out2) + }) + + setHandler(out1, new OutHandler { + override def onPull(): Unit = pull(in1) + + override def onDownstreamFinish(): Unit = cancel(in1) + }) + + setHandler(out2, new OutHandler { + override def onPull(): Unit = pull(in2) + + override def onDownstreamFinish(): Unit = cancel(in2) + }) + } + + override def toString = "IdentityBidi" + } + + val identityBidiF = BidiFlow.wrap(identityBidi) + val identity = (identityBidiF atop identityBidiF atop identityBidiF).join(Flow[Int].map { x ⇒ x }) + + Await.result( + Source(1 to 10).via(identity).grouped(100).runWith(Sink.head), + 3.seconds) should ===(1 to 10) + + } + "be able to interpret a rotated identity bidi stage" in assertAllStagesStopped { // This is a "rotated" identity BidiStage, as it loops back upstream elements // to its upstream, and loops back downstream elementd to its downstream. @@ -84,21 +140,25 @@ class ActorGraphInterpreterSpec extends AkkaSpec { override def createLogic: GraphStageLogic = new GraphStageLogic { setHandler(in1, new InHandler { override def onPush(): Unit = push(out2, grab(in1)) + override def onUpstreamFinish(): Unit = complete(out2) }) setHandler(in2, new InHandler { override def onPush(): Unit = push(out1, grab(in2)) + override def onUpstreamFinish(): Unit = complete(out1) }) setHandler(out1, new OutHandler { override def onPull(): Unit = pull(in2) + override def onDownstreamFinish(): Unit = cancel(in2) }) setHandler(out2, new OutHandler { override def onPull(): Unit = pull(in1) + override def onDownstreamFinish(): Unit = cancel(in1) }) } @@ -124,113 +184,5 @@ class ActorGraphInterpreterSpec extends AkkaSpec { Await.result(f2, 3.seconds) should ===(1 to 10) } - "be able to implement a timeout bidiStage" in { - class IdleTimeout[I, O]( - val system: ActorSystem, - val timeout: FiniteDuration) extends GraphStage[BidiShape[I, I, O, O]] { - val in1 = Inlet[I]("in1") - val in2 = Inlet[O]("in2") - val out1 = Outlet[I]("out1") - val out2 = Outlet[O]("out2") - val shape = BidiShape(in1, out1, in2, out2) - - override def toString = "IdleTimeout" - - override def createLogic: GraphStageLogic = new GraphStageLogic { - private var timerCancellable: Option[Cancellable] = None - private var nextDeadline: Deadline = Deadline.now + timeout - - setHandler(in1, new InHandler { - override def onPush(): Unit = { - onActivity() - push(out1, grab(in1)) - } - override def onUpstreamFinish(): Unit = complete(out1) - }) - - setHandler(in2, new InHandler { - override def onPush(): Unit = { - onActivity() - push(out2, grab(in2)) - } - override def onUpstreamFinish(): Unit = complete(out2) - }) - - setHandler(out1, new OutHandler { - override def onPull(): Unit = pull(in1) - override def onDownstreamFinish(): Unit = cancel(in1) - }) - - setHandler(out2, new OutHandler { - override def onPull(): Unit = pull(in2) - override def onDownstreamFinish(): Unit = cancel(in2) - }) - - private def onActivity(): Unit = nextDeadline = Deadline.now + timeout - - private def onTimerTick(): Unit = - if (nextDeadline.isOverdue()) - failStage(new TimeoutException(s"No reads or writes happened in $timeout.")) - - override def preStart(): Unit = { - super.preStart() - val checkPeriod = timeout / 8 - val callback = getAsyncCallback[Unit]((_) ⇒ onTimerTick()) - import system.dispatcher - timerCancellable = Some(system.scheduler.schedule(timeout, checkPeriod)(callback.invoke(()))) - } - - override def postStop(): Unit = { - super.postStop() - timerCancellable.foreach(_.cancel()) - } - } - } - - val upWrite = TestPublisher.probe[String]() - val upRead = TestSubscriber.probe[Int]() - - val downWrite = TestPublisher.probe[Int]() - val downRead = TestSubscriber.probe[String]() - - FlowGraph.closed() { implicit b ⇒ - import FlowGraph.Implicits._ - val timeoutStage = b.add(new IdleTimeout[String, Int](system, 2.seconds)) - Source(upWrite) ~> timeoutStage.in1; timeoutStage.out1 ~> Sink(downRead) - Sink(upRead) <~ timeoutStage.out2; timeoutStage.in2 <~ Source(downWrite) - }.run() - - // Request enough for the whole test - upRead.request(100) - downRead.request(100) - - upWrite.sendNext("DATA1") - downRead.expectNext("DATA1") - Thread.sleep(1500) - - downWrite.sendNext(1) - upRead.expectNext(1) - Thread.sleep(1500) - - upWrite.sendNext("DATA2") - downRead.expectNext("DATA2") - Thread.sleep(1000) - - downWrite.sendNext(2) - upRead.expectNext(2) - - upRead.expectNoMsg(500.millis) - val error1 = upRead.expectError() - val error2 = downRead.expectError() - - error1.isInstanceOf[TimeoutException] should be(true) - error1.getMessage should be("No reads or writes happened in 2 seconds.") - error2 should ===(error1) - - upWrite.expectCancellation() - downWrite.expectCancellation() - } - } - } diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala index 5c995789c2..39985f0841 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala @@ -7,7 +7,7 @@ import akka.actor._ import akka.event.Logging import akka.stream._ import akka.stream.impl.ReactiveStreamsCompliance._ -import akka.stream.impl.StreamLayout.Module +import akka.stream.impl.StreamLayout.{ CopiedModule, Module } import akka.stream.impl.fusing.GraphInterpreter.{ DownstreamBoundaryStageLogic, UpstreamBoundaryStageLogic, GraphAssembly } import akka.stream.impl.{ ActorPublisher, ReactiveStreamsCompliance } import akka.stream.stage.{ GraphStageLogic, InHandler, OutHandler } @@ -22,9 +22,13 @@ private[stream] case class GraphModule(assembly: GraphAssembly, shape: Shape, at override def subModules: Set[Module] = Set.empty override def withAttributes(newAttr: Attributes): Module = copy(attributes = newAttr) - override def carbonCopy: Module = copy() + override final def carbonCopy: Module = { + val newShape = shape.deepCopy() + replaceShape(newShape) + } - override def replaceShape(s: Shape): Module = ??? + override final def replaceShape(newShape: Shape): Module = + CopiedModule(newShape, attributes, copyOf = this) } /** @@ -50,7 +54,7 @@ private[stream] object ActorGraphInterpreter { final class BoundarySubscription(val parent: ActorRef, val id: Int) extends Subscription { override def request(elements: Long): Unit = parent ! RequestMore(id, elements) override def cancel(): Unit = parent ! Cancel(id) - override def toString = "BoundarySubscription" + System.identityHashCode(this) + override def toString = s"BoundarySubscription[$parent, $id]" } final class BoundarySubscriber(val parent: ActorRef, id: Int) extends Subscriber[Any] { @@ -306,6 +310,25 @@ private[stream] class ActorGraphInterpreter(assembly: GraphAssembly, shape: Shap } override def receive: Receive = { + // Cases that are most likely on the hot path, in decreasing order of frequency + case OnNext(id: Int, e: Any) ⇒ + if (GraphInterpreter.Debug) println(s" onNext $e id=$id") + inputs(id).onNext(e) + runBatch() + case RequestMore(id: Int, demand: Long) ⇒ + if (GraphInterpreter.Debug) println(s" request $demand id=$id") + outputs(id).requestMore(demand) + runBatch() + case Resume ⇒ + resumeScheduled = false + if (interpreter.isSuspended) runBatch() + case AsyncInput(logic, event, handler) ⇒ + if (GraphInterpreter.Debug) println(s"ASYNC $event") + if (!interpreter.isStageCompleted(logic.stageId)) + handler(event) + runBatch() + + // Initialization and completion messages case OnError(id: Int, cause: Throwable) ⇒ if (GraphInterpreter.Debug) println(s" onError id=$id") inputs(id).onError(cause) @@ -314,17 +337,8 @@ private[stream] class ActorGraphInterpreter(assembly: GraphAssembly, shape: Shap if (GraphInterpreter.Debug) println(s" onComplete id=$id") inputs(id).onComplete() runBatch() - case OnNext(id: Int, e: Any) ⇒ - if (GraphInterpreter.Debug) println(s" onNext $e id=$id") - inputs(id).onNext(e) - runBatch() case OnSubscribe(id: Int, subscription: Subscription) ⇒ inputs(id).onSubscribe(subscription) - - case RequestMore(id: Int, demand: Long) ⇒ - if (GraphInterpreter.Debug) println(s" request $demand id=$id") - outputs(id).requestMore(demand) - runBatch() case Cancel(id: Int) ⇒ if (GraphInterpreter.Debug) println(s" cancel id=$id") outputs(id).cancel() @@ -333,16 +347,6 @@ private[stream] class ActorGraphInterpreter(assembly: GraphAssembly, shape: Shap outputs(id).subscribePending() case ExposedPublisher(id, publisher) ⇒ outputs(id).exposedPublisher(publisher) - - case AsyncInput(_, event, handler) ⇒ - if (GraphInterpreter.Debug) println(s"ASYNC $event") - handler(event) - runBatch() - - case Resume ⇒ - resumeScheduled = false - if (interpreter.isSuspended) runBatch() - } override protected[akka] def aroundReceive(receive: Actor.Receive, msg: Any): Unit = { diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala index edcd0532eb..615fa6ea03 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala @@ -255,6 +255,7 @@ private[stream] final class GraphInterpreter( def init(): Unit = { var i = 0 while (i < logics.length) { + logics(i).stageId = i logics(i).preStart() i += 1 } @@ -373,7 +374,7 @@ private[stream] final class GraphInterpreter( def isConnectionCompleted(connection: Int): Boolean = connectionStates(connection).isInstanceOf[CompletedState] // Returns true if the given stage is alredy completed - private def isStageCompleted(stageId: Int): Boolean = stageId != Boundary && shutdownCounter(stageId) == 0 + def isStageCompleted(stageId: Int): Boolean = stageId != Boundary && shutdownCounter(stageId) == 0 private def isPushInFlight(connection: Int): Boolean = !inAvailable(connection) && diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala index 0b1aa4073c..c9c3f8ebdc 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala @@ -15,7 +15,7 @@ object GraphStages { val in = Inlet[T]("in") val out = Outlet[T]("out") - val shape = FlowShape(in, out) + override val shape = FlowShape(in, out) override def createLogic: GraphStageLogic = new GraphStageLogic { setHandler(in, new InHandler { @@ -36,7 +36,7 @@ object GraphStages { class Detacher[T] extends GraphStage[FlowShape[T, T]] { val in = Inlet[T]("in") val out = Outlet[T]("out") - val shape = FlowShape(in, out) + override val shape = FlowShape(in, out) override def createLogic: GraphStageLogic = new GraphStageLogic { var initialized = false @@ -70,7 +70,7 @@ object GraphStages { class Broadcast[T](private val outCount: Int) extends GraphStage[UniformFanOutShape[T, T]] { val in = Inlet[T]("in") val out = Vector.fill(outCount)(Outlet[T]("out")) - val shape = UniformFanOutShape(in, out: _*) + override val shape = UniformFanOutShape(in, out: _*) override def createLogic: GraphStageLogic = new GraphStageLogic { private var pending = outCount @@ -101,7 +101,7 @@ object GraphStages { val in0 = Inlet[A]("in0") val in1 = Inlet[B]("in1") val out = Outlet[(A, B)]("out") - val shape = new FanInShape2[A, B, (A, B)](in0, in1, out) + override val shape = new FanInShape2[A, B, (A, B)](in0, in1, out) override def createLogic: GraphStageLogic = new GraphStageLogic { var pending = 2 @@ -130,7 +130,7 @@ object GraphStages { class Merge[T](private val inCount: Int) extends GraphStage[UniformFanInShape[T, T]] { val in = Vector.fill(inCount)(Inlet[T]("in")) val out = Outlet[T]("out") - val shape = UniformFanInShape(out, in: _*) + override val shape = UniformFanInShape(out, in: _*) override def createLogic: GraphStageLogic = new GraphStageLogic { private var initialized = false @@ -187,7 +187,7 @@ object GraphStages { class Balance[T](private val outCount: Int) extends GraphStage[UniformFanOutShape[T, T]] { val in = Inlet[T]("in") val out = Vector.fill(outCount)(Outlet[T]("out")) - val shape = UniformFanOutShape[T, T](in, out: _*) + override val shape = UniformFanOutShape[T, T](in, out: _*) override def createLogic: GraphStageLogic = new GraphStageLogic { private val pendingQueue = Array.ofDim[Outlet[T]](outCount) diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Source.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Source.scala index a1c5e83e5c..f3951a5f1b 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Source.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Source.scala @@ -44,7 +44,7 @@ final class Source[+Out, +Mat](private[stream] override val module: Module) new Source( module .fuse(flowCopy, shape.outlet, flowCopy.shape.inlets.head, combine) - .replaceShape(SourceShape(flowCopy.shape.outlets.head))) // FIXME why is not .wrap() needed here? + .replaceShape(SourceShape(flowCopy.shape.outlets.head))) } } diff --git a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala index c90e886bf7..f007bdb302 100644 --- a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala +++ b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala @@ -43,9 +43,12 @@ abstract class GraphStage[S <: Shape] extends Graph[S, Unit] { * This method throws an [[UnsupportedOperationException]] by default. The subclass can override this method * and provide a correct implementation that creates an exact copy of the stage with the provided new attributes. */ - override def withAttributes(attr: Attributes): Graph[S, Unit] = - throw new UnsupportedOperationException( - "withAttributes not supported by default by FlexiMerge, subclass may override and implement it") + final override def withAttributes(attr: Attributes): Graph[S, Unit] = new Graph[S, Unit] { + override def shape = GraphStage.this.shape + override private[stream] def module = GraphStage.this.module.withAttributes(attr) + + override def withAttributes(attr: Attributes) = GraphStage.this.withAttributes(attr) + } } /** @@ -64,6 +67,11 @@ abstract class GraphStage[S <: Shape] extends Graph[S, Unit] { abstract class GraphStageLogic { import GraphInterpreter._ + /** + * INTERNAL API + */ + private[stream] var stageId: Int = Int.MinValue + /** * INTERNAL API */