From 5f75af70596c9296665f1d92aa3c3af3ef8f3a91 Mon Sep 17 00:00:00 2001 From: Patrik Nordwall Date: Thu, 30 Oct 2014 09:13:25 +0100 Subject: [PATCH] +str #15205 Add FlexiRoute junction --- .../akka/stream/testkit/StreamTestKit.scala | 4 + .../stream/scaladsl/GraphFlexiRouteSpec.scala | 427 ++++++++++++++++++ .../impl/ActorBasedFlowMaterializer.scala | 13 +- .../main/scala/akka/stream/impl/FanOut.scala | 69 ++- .../akka/stream/impl/FlexiRouteImpl.scala | 145 ++++++ .../akka/stream/scaladsl/FlexiMerge.scala | 14 +- .../akka/stream/scaladsl/FlexiRoute.scala | 269 +++++++++++ 7 files changed, 923 insertions(+), 18 deletions(-) create mode 100644 akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphFlexiRouteSpec.scala create mode 100644 akka-stream/src/main/scala/akka/stream/impl/FlexiRouteImpl.scala create mode 100644 akka-stream/src/main/scala/akka/stream/scaladsl/FlexiRoute.scala diff --git a/akka-stream-testkit/src/test/scala/akka/stream/testkit/StreamTestKit.scala b/akka-stream-testkit/src/test/scala/akka/stream/testkit/StreamTestKit.scala index 41a1ad6b9e..c76924f2da 100644 --- a/akka-stream-testkit/src/test/scala/akka/stream/testkit/StreamTestKit.scala +++ b/akka-stream-testkit/src/test/scala/akka/stream/testkit/StreamTestKit.scala @@ -54,6 +54,10 @@ object StreamTestKit { pendingRequests -= 1 subscription.sendNext(elem) } + + def sendComplete(): Unit = subscription.sendComplete() + + def sendError(cause: Exception): Unit = subscription.sendError(cause) } sealed trait SubscriberEvent diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphFlexiRouteSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphFlexiRouteSpec.scala new file mode 100644 index 0000000000..f9868db865 --- /dev/null +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphFlexiRouteSpec.scala @@ -0,0 +1,427 @@ +package akka.stream.scaladsl + +import scala.concurrent.duration._ +import scala.util.control.NoStackTrace +import FlowGraphImplicits._ +import akka.stream.FlowMaterializer +import akka.stream.testkit.AkkaSpec +import akka.stream.testkit.StreamTestKit.AutoPublisher +import akka.stream.testkit.StreamTestKit.OnNext +import akka.stream.testkit.StreamTestKit.PublisherProbe +import akka.stream.testkit.StreamTestKit.SubscriberProbe +import akka.actor.ActorSystem + +object GraphRouteSpec { + + /** + * This is fair in that sense that after enqueueing to an output it yields to other output if + * they are have requested elements. Or in other words, if all outputs have demand available at the same + * time then in finite steps all elements are enqueued to them. + */ + class Fair[T] extends FlexiRoute[T]("fairRoute") { + import FlexiRoute._ + val out1 = createOutputPort[T]() + val out2 = createOutputPort[T]() + + override def createRouteLogic: RouteLogic[T] = new RouteLogic[T] { + override def outputHandles(outputCount: Int) = Vector(out1, out2) + + val emitToAnyWithDemand = State[T](DemandFromAny(out1, out2)) { (ctx, preferredOutput, element) ⇒ + ctx.emit(preferredOutput, element) + SameState + } + + // initally, wait for demand from all + override def initialState = State[T](DemandFromAll(out1, out2)) { (ctx, preferredOutput, element) ⇒ + ctx.emit(preferredOutput, element) + emitToAnyWithDemand + } + } + } + + /** + * It never skips an output while cycling but waits on it instead (closed outputs are skipped though). + * The fair route above is a non-strict round-robin (skips currently unavailable outputs). + */ + class StrictRoundRobin[T] extends FlexiRoute[T]("roundRobinRoute") { + import FlexiRoute._ + val out1 = createOutputPort[T]() + val out2 = createOutputPort[T]() + + override def createRouteLogic = new RouteLogic[T] { + + override def outputHandles(outputCount: Int) = Vector(out1, out2) + + val toOutput1: State[T] = State[T](DemandFrom(out1)) { (ctx, _, element) ⇒ + ctx.emit(out1, element) + toOutput2 + } + + val toOutput2 = State[T](DemandFrom(out2)) { (ctx, _, element) ⇒ + ctx.emit(out2, element) + toOutput1 + } + + override def initialState = toOutput1 + } + } + + class Unzip[A, B] extends FlexiRoute[(A, B)]("unzip") { + import FlexiRoute._ + val outA = createOutputPort[A]() + val outB = createOutputPort[B]() + + override def createRouteLogic() = new RouteLogic[(A, B)] { + var lastInA: Option[A] = None + var lastInB: Option[B] = None + + override def outputHandles(outputCount: Int) = { + require(outputCount == 2, s"Unzip must have two connected outputs, was $outputCount") + Vector(outA, outB) + } + + override def initialState = State[Any](DemandFromAll(outA, outB)) { (ctx, _, element) ⇒ + val (a, b) = element + ctx.emit(outA, a) + ctx.emit(outB, b) + SameState + } + + override def initialCompletionHandling = eagerClose + } + } + + class TestRoute extends FlexiRoute[String]("testRoute") { + import FlexiRoute._ + val output1 = createOutputPort[String]() + val output2 = createOutputPort[String]() + val output3 = createOutputPort[String]() + + def createRouteLogic: RouteLogic[String] = new RouteLogic[String] { + val handles = Vector(output1, output2, output3) + override def outputHandles(outputCount: Int) = handles + + override def initialState = State[String](DemandFromAny(handles)) { + (ctx, preferred, element) ⇒ + if (element == "err") + ctx.error(new RuntimeException("err") with NoStackTrace) + else if (element == "err-output1") + ctx.error(output1, new RuntimeException("err-1") with NoStackTrace) + else if (element == "complete") + ctx.complete() + else + ctx.emit(preferred, "onInput: " + element) + + SameState + } + + override def initialCompletionHandling = CompletionHandling( + onComplete = { ctx ⇒ + handles.foreach { output ⇒ + if (ctx.isDemandAvailable(output)) + ctx.emit(output, "onComplete") + } + }, + onError = { (ctx, cause) ⇒ + cause match { + case _: IllegalArgumentException ⇒ // swallow + case _ ⇒ + handles.foreach { output ⇒ + if (ctx.isDemandAvailable(output)) + ctx.emit(output, "onError") + } + } + }, + onCancel = { (ctx, cancelledOutput) ⇒ + handles.foreach { output ⇒ + if (output != cancelledOutput && ctx.isDemandAvailable(output)) + ctx.emit(output, "onCancel: " + cancelledOutput.portIndex) + } + SameState + }) + } + } + + class TestFixture(implicit val system: ActorSystem, implicit val materializer: FlowMaterializer) { + val publisher = PublisherProbe[String] + val s1 = SubscriberProbe[String] + val s2 = SubscriberProbe[String] + FlowGraph { implicit b ⇒ + val route = new TestRoute + Source(publisher) ~> route.in + route.output1 ~> Sink(s1) + route.output2 ~> Sink(s2) + }.run() + + val autoPublisher = new AutoPublisher(publisher) + autoPublisher.sendNext("a") + autoPublisher.sendNext("b") + + val sub1 = s1.expectSubscription() + val sub2 = s2.expectSubscription() + } + +} + +@org.junit.runner.RunWith(classOf[org.scalatest.junit.JUnitRunner]) +class GraphRouteSpec extends AkkaSpec { + import GraphRouteSpec._ + + implicit val materializer = FlowMaterializer() + + val in = Source(List("a", "b", "c", "d", "e")) + + val out1 = Sink.publisher[String] + val out2 = Sink.publisher[String] + + "FlexiRoute" must { + + "build simple fair route" in { + val m = FlowGraph { implicit b ⇒ + val route = new Fair[String] + in ~> route.in + route.out1 ~> out1 + route.out2 ~> out2 + }.run() + + val s1 = SubscriberProbe[String] + val p1 = m.get(out1) + p1.subscribe(s1) + val sub1 = s1.expectSubscription() + val s2 = SubscriberProbe[String] + val p2 = m.get(out2) + p2.subscribe(s2) + val sub2 = s2.expectSubscription() + + sub1.request(10) + sub2.request(10) + + s1.expectNext("a") + s2.expectNext("b") + s1.expectNext("c") + s2.expectNext("d") + s1.expectNext("e") + + s1.expectComplete() + s2.expectComplete() + } + + "build simple round-robin route" in { + val m = FlowGraph { implicit b ⇒ + val route = new StrictRoundRobin[String] + in ~> route.in + route.out1 ~> out1 + route.out2 ~> out2 + }.run() + + val s1 = SubscriberProbe[String] + val p1 = m.get(out1) + p1.subscribe(s1) + val sub1 = s1.expectSubscription() + val s2 = SubscriberProbe[String] + val p2 = m.get(out2) + p2.subscribe(s2) + val sub2 = s2.expectSubscription() + + sub1.request(10) + sub2.request(10) + + s1.expectNext("a") + s2.expectNext("b") + s1.expectNext("c") + s2.expectNext("d") + s1.expectNext("e") + + s1.expectComplete() + s2.expectComplete() + } + + "build simple unzip route" in { + val outA = Sink.publisher[Int] + val outB = Sink.publisher[String] + + val m = FlowGraph { implicit b ⇒ + val route = new Unzip[Int, String] + Source(List(1 -> "A", 2 -> "B", 3 -> "C", 4 -> "D")) ~> route.in + route.outA ~> outA + route.outB ~> outB + }.run() + + val s1 = SubscriberProbe[Int] + val p1 = m.get(outA) + p1.subscribe(s1) + val sub1 = s1.expectSubscription() + val s2 = SubscriberProbe[String] + val p2 = m.get(outB) + p2.subscribe(s2) + val sub2 = s2.expectSubscription() + + sub1.request(3) + sub2.request(4) + + s1.expectNext(1) + s2.expectNext("A") + s1.expectNext(2) + s2.expectNext("B") + s1.expectNext(3) + s2.expectNext("C") + sub1.cancel() + + s2.expectComplete() + } + + "support complete of downstreams and cancel of upstream" in { + val fixture = new TestFixture + import fixture._ + + autoPublisher.sendNext("complete") + + sub1.request(1) + s1.expectNext("onInput: a") + sub2.request(2) + s2.expectNext("onInput: b") + + s1.expectComplete() + s2.expectComplete() + } + + "support error of outputs" in { + val fixture = new TestFixture + import fixture._ + + autoPublisher.sendNext("err") + + sub1.request(1) + s1.expectNext("onInput: a") + sub2.request(2) + s2.expectNext("onInput: b") + + s1.expectError().getMessage should be("err") + s2.expectError().getMessage should be("err") + autoPublisher.subscription.expectCancellation() + } + + "support error of a specific output" in { + val fixture = new TestFixture + import fixture._ + + sub1.request(1) + s1.expectNext("onInput: a") + sub2.request(1) + s2.expectNext("onInput: b") + + sub1.request(5) + sub2.request(5) + autoPublisher.sendNext("err-output1") + autoPublisher.sendNext("c") + + s2.expectNext("onInput: c") + s1.expectError().getMessage should be("err-1") + + autoPublisher.sendComplete() + s2.expectNext("onComplete") + s2.expectComplete() + } + + "handle cancel from output" in { + val fixture = new TestFixture + import fixture._ + + sub1.request(1) + s1.expectNext("onInput: a") + sub2.request(1) + s2.expectNext("onInput: b") + + sub1.request(2) + sub2.request(2) + sub1.cancel() + + s2.expectNext("onCancel: 0") + s1.expectNoMsg(200.millis) + + autoPublisher.sendNext("c") + s2.expectNext("onInput: c") + + autoPublisher.sendComplete() + s2.expectComplete() + } + + "handle complete from upstream input" in { + val fixture = new TestFixture + import fixture._ + + sub1.request(1) + s1.expectNext("onInput: a") + sub2.request(1) + s2.expectNext("onInput: b") + + sub1.request(2) + sub2.request(2) + autoPublisher.sendComplete() + + s1.expectNext("onComplete") + s2.expectNext("onComplete") + + s1.expectComplete() + s2.expectComplete() + } + + "handle error from upstream input" in { + val fixture = new TestFixture + import fixture._ + + sub1.request(1) + s1.expectNext("onInput: a") + sub2.request(1) + s2.expectNext("onInput: b") + + sub1.request(2) + sub2.request(2) + autoPublisher.sendError(new RuntimeException("test err") with NoStackTrace) + + s1.expectNext("onError") + s2.expectNext("onError") + + s1.expectError().getMessage should be("test err") + s2.expectError().getMessage should be("test err") + } + + "cancel upstream input when all outputs cancelled" in { + val fixture = new TestFixture + import fixture._ + + sub1.request(1) + s1.expectNext("onInput: a") + sub2.request(1) + s2.expectNext("onInput: b") + + sub1.request(2) + sub2.request(2) + sub1.cancel() + + s2.expectNext("onCancel: 0") + sub2.cancel() + + autoPublisher.subscription.expectCancellation() + } + + "cancel upstream input when all outputs completed" in { + val fixture = new TestFixture + import fixture._ + + sub1.request(1) + s1.expectNext("onInput: a") + sub2.request(1) + s2.expectNext("onInput: b") + + sub1.request(2) + sub2.request(2) + autoPublisher.sendNext("complete") + s1.expectComplete() + s2.expectComplete() + autoPublisher.subscription.expectCancellation() + } + + } +} + diff --git a/akka-stream/src/main/scala/akka/stream/impl/ActorBasedFlowMaterializer.scala b/akka-stream/src/main/scala/akka/stream/impl/ActorBasedFlowMaterializer.scala index 1fc693feb5..0830199446 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/ActorBasedFlowMaterializer.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/ActorBasedFlowMaterializer.scala @@ -114,7 +114,11 @@ private[akka] object Ast { } case class FlexiMergeNode(merger: FlexiMerge[Any]) extends FanInAstNode { - override def name = merger.name.getOrElse("") + override def name = merger.name.getOrElse("flexMerge") + } + + case class RouteNode(route: FlexiRoute[Any]) extends FanOutAstNode { + override def name = route.name.getOrElse("route") } } @@ -232,7 +236,7 @@ case class ActorBasedFlowMaterializer(override val settings: MaterializerSetting op match { case fanin: Ast.FanInAstNode ⇒ - val impl = op match { + val impl = fanin match { case Ast.Merge ⇒ actorOf(FairMerge.props(settings, inputCount).withDispatcher(settings.dispatcher), actorName) case Ast.MergePreferred ⇒ @@ -252,13 +256,16 @@ case class ActorBasedFlowMaterializer(override val settings: MaterializerSetting (subscribers, List(publisher)) case fanout: Ast.FanOutAstNode ⇒ - val impl = op match { + val impl = fanout match { case Ast.Broadcast ⇒ actorOf(Broadcast.props(settings, outputCount).withDispatcher(settings.dispatcher), actorName) case Ast.Balance(waitForAllDownstreams) ⇒ actorOf(Balance.props(settings, outputCount, waitForAllDownstreams).withDispatcher(settings.dispatcher), actorName) case Ast.Unzip ⇒ actorOf(Unzip.props(settings).withDispatcher(settings.dispatcher), actorName) + case Ast.RouteNode(route) ⇒ + actorOf(FlexiRouteImpl.props(settings, outputCount, route.createRouteLogic()). + withDispatcher(settings.dispatcher), actorName) } val publishers = Vector.tabulate(outputCount)(id ⇒ new ActorPublisher[Out](impl) { diff --git a/akka-stream/src/main/scala/akka/stream/impl/FanOut.scala b/akka-stream/src/main/scala/akka/stream/impl/FanOut.scala index 6924d12e4c..a8dd5cb71b 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/FanOut.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/FanOut.scala @@ -46,21 +46,51 @@ private[akka] object FanOut { private var markedPending = 0 private val cancelled = Array.ofDim[Boolean](outputCount) private var markedCancelled = 0 + private val completed = Array.ofDim[Boolean](outputCount) + private val errored = Array.ofDim[Boolean](outputCount) private var unmarkCancelled = true private var preferredId = 0 + def isPending(output: Int): Boolean = pending(output) + + def isCompleted(output: Int): Boolean = completed(output) + + def isCancelled(output: Int): Boolean = cancelled(output) + def complete(): Unit = if (!bunchCancelled) { bunchCancelled = true - outputs foreach (_.complete()) + var i = 0 + while (i < outputs.length) { + complete(i) + i += 1 + } + } + + def complete(output: Int) = + if (!completed(output) && !errored(output) && !cancelled(output)) { + outputs(output).complete() + completed(output) = true + unmarkOutput(output) } def cancel(e: Throwable): Unit = if (!bunchCancelled) { bunchCancelled = true - outputs foreach (_.cancel(e)) + var i = 0 + while (i < outputs.length) { + error(i, e) + i += 1 + } + } + + def error(output: Int, e: Throwable): Unit = + if (!errored(output) && !cancelled(output) && !completed(output)) { + outputs(output).cancel(e) + errored(output) = true + unmarkOutput(output) } def markOutput(output: Int): Unit = { @@ -81,9 +111,25 @@ private[akka] object FanOut { } } + def markAllOutputs(): Unit = { + var i = 0 + while (i < outputCount) { + markOutput(i) + i += 1 + } + } + + def unmarkAllOutputs(): Unit = { + var i = 0 + while (i < outputCount) { + unmarkOutput(i) + i += 1 + } + } + def unmarkCancelledOutputs(enabled: Boolean): Unit = unmarkCancelled = enabled - private def idToEnqueue(): Int = { + def idToEnqueue(): Int = { var id = preferredId while (!(marked(id) && pending(id))) { id += 1 @@ -110,10 +156,15 @@ private[akka] object FanOut { } } - def enqueueAndYield(elem: Any): Unit = { + def idToEnqueueAndYield(): Int = { val id = idToEnqueue() preferredId = id + 1 if (preferredId == outputCount) preferredId = 0 + id + } + + def enqueueAndYield(elem: Any): Unit = { + val id = idToEnqueueAndYield() enqueue(id, elem) } @@ -123,6 +174,8 @@ private[akka] object FanOut { enqueue(id, elem) } + def onCancel(output: Int): Unit = () + /** * Will only transfer an element when all marked outputs * have demand, and will complete as soon as any of the marked @@ -161,6 +214,7 @@ private[akka] object FanOut { } if (marked(id) && !cancelled(id)) markedCancelled += 1 cancelled(id) = true + onCancel(id) outputs(id).subreceive(Cancel(null)) case SubstreamSubscribePending(id) ⇒ outputs(id).subreceive(SubscribePending) @@ -222,7 +276,7 @@ private[akka] object Broadcast { * INTERNAL API */ private[akka] class Broadcast(_settings: MaterializerSettings, _outputPorts: Int) extends FanOut(_settings, _outputPorts) { - (0 until outputPorts) foreach outputBunch.markOutput + outputBunch.markAllOutputs() nextPhase(TransferPhase(primaryInputs.NeedsInput && outputBunch.AllOfMarkedOutputs) { () ⇒ val elem = primaryInputs.dequeueInputElement() @@ -242,7 +296,7 @@ private[akka] object Balance { * INTERNAL API */ private[akka] class Balance(_settings: MaterializerSettings, _outputPorts: Int, waitForAllDownstreams: Boolean) extends FanOut(_settings, _outputPorts) { - (0 until outputPorts) foreach outputBunch.markOutput + outputBunch.markAllOutputs() val runningPhase = TransferPhase(primaryInputs.NeedsInput && outputBunch.AnyOfMarkedOutputs) { () ⇒ val elem = primaryInputs.dequeueInputElement() @@ -269,7 +323,7 @@ private[akka] object Unzip { * INTERNAL API */ private[akka] class Unzip(_settings: MaterializerSettings) extends FanOut(_settings, outputPorts = 2) { - (0 until outputPorts) foreach outputBunch.markOutput + outputBunch.markAllOutputs() nextPhase(TransferPhase(primaryInputs.NeedsInput && outputBunch.AllOfMarkedOutputs) { () ⇒ primaryInputs.dequeueInputElement() match { @@ -288,4 +342,3 @@ private[akka] class Unzip(_settings: MaterializerSettings) extends FanOut(_setti } }) } - diff --git a/akka-stream/src/main/scala/akka/stream/impl/FlexiRouteImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/FlexiRouteImpl.scala new file mode 100644 index 0000000000..3746159b56 --- /dev/null +++ b/akka-stream/src/main/scala/akka/stream/impl/FlexiRouteImpl.scala @@ -0,0 +1,145 @@ +/** + * Copyright (C) 2014 Typesafe Inc. + */ +package akka.stream.impl + +import scala.collection.breakOut +import akka.actor.Props +import akka.stream.scaladsl.FlexiRoute +import akka.stream.MaterializerSettings +import akka.stream.impl.FanOut.OutputBunch + +/** + * INTERNAL API + */ +private[akka] object FlexiRouteImpl { + def props(settings: MaterializerSettings, outputCount: Int, routeLogic: FlexiRoute.RouteLogic[Any]): Props = + Props(new FlexiRouteImpl(settings, outputCount, routeLogic)) +} + +/** + * INTERNAL API + */ +private[akka] class FlexiRouteImpl(_settings: MaterializerSettings, + outputCount: Int, + routeLogic: FlexiRoute.RouteLogic[Any]) + extends FanOut(_settings, outputCount) { + + import FlexiRoute._ + + val outputMapping: Map[Int, OutputHandle] = + routeLogic.outputHandles(outputCount).take(outputCount).zipWithIndex.map(_.swap)(breakOut) + + private type StateT = routeLogic.State[Any] + private type CompletionT = routeLogic.CompletionHandling + + private var behavior: StateT = _ + private var completion: CompletionT = _ + + override protected val outputBunch = new OutputBunch(outputPorts, self, this) { + override def onCancel(output: Int): Unit = + changeBehavior(completion.onCancel(ctx, outputMapping(output))) + } + + override protected val primaryInputs: Inputs = new BatchingInputBuffer(settings.maxInputBufferSize, this) { + override def onError(e: Throwable): Unit = { + completion.onError(ctx, e) + fail(e) + } + + override def onComplete(): Unit = { + completion.onComplete(ctx) + super.onComplete() + } + } + + private val ctx: routeLogic.RouteLogicContext[Any] = new routeLogic.RouteLogicContext[Any] { + override def isDemandAvailable(output: OutputHandle): Boolean = + (output.portIndex < outputCount) && outputBunch.isPending(output.portIndex) + + override def emit(output: OutputHandle, elem: Any): Unit = { + require(outputBunch.isPending(output.portIndex), s"emit to [$output] not allowed when no demand available") + outputBunch.enqueue(output.portIndex, elem) + } + + override def complete(): Unit = { + primaryInputs.cancel() + outputBunch.complete() + context.stop(self) + } + + override def complete(output: OutputHandle): Unit = + outputBunch.complete(output.portIndex) + + override def error(cause: Throwable): Unit = fail(cause) + + override def error(output: OutputHandle, cause: Throwable): Unit = + outputBunch.error(output.portIndex, cause) + + override def changeCompletionHandling(newCompletion: CompletionT): Unit = + FlexiRouteImpl.this.changeCompletionHandling(newCompletion) + + } + + private def markOutputs(outputs: Array[OutputHandle]): Unit = { + outputBunch.unmarkAllOutputs() + var i = 0 + while (i < outputs.length) { + val id = outputs(i).portIndex + if (outputMapping.contains(id) && !outputBunch.isCancelled(id) && !outputBunch.isCompleted(id)) + outputBunch.markOutput(id) + i += 1 + } + } + + private def precondition: TransferState = { + behavior.condition match { + case _: DemandFrom | _: DemandFromAny ⇒ primaryInputs.NeedsInput && outputBunch.AnyOfMarkedOutputs + case _: DemandFromAll ⇒ primaryInputs.NeedsInput && outputBunch.AllOfMarkedOutputs + } + } + + private def changeCompletionHandling(newCompletion: CompletionT): Unit = + completion = newCompletion.asInstanceOf[CompletionT] + + private def changeBehavior[A](newBehavior: routeLogic.State[A]): Unit = + if (newBehavior != routeLogic.SameState && (newBehavior ne behavior)) { + behavior = newBehavior.asInstanceOf[StateT] + behavior.condition match { + case any: DemandFromAny ⇒ + markOutputs(any.outputs.toArray) + case all: DemandFromAll ⇒ + markOutputs(all.outputs.toArray) + case DemandFrom(output) ⇒ + require(outputMapping.contains(output.portIndex), s"Unknown output handle $output") + require(!outputBunch.isCancelled(output.portIndex), s"Demand not allowed from cancelled $output") + require(!outputBunch.isCompleted(output.portIndex), s"Demand not allowed from completed $output") + outputBunch.unmarkAllOutputs() + outputBunch.markOutput(output.portIndex) + } + } + + changeBehavior(routeLogic.initialState) + changeCompletionHandling(routeLogic.initialCompletionHandling) + + nextPhase(TransferPhase(precondition) { () ⇒ + val elem = primaryInputs.dequeueInputElement() + behavior.condition match { + case any: DemandFromAny ⇒ + val id = outputBunch.idToEnqueueAndYield() + val outputHandle = outputMapping(id) + changeBehavior(behavior.onInput(ctx, outputHandle, elem)) + + case DemandFrom(outputHandle) ⇒ + changeBehavior(behavior.onInput(ctx, outputHandle, elem)) + + case all: DemandFromAll ⇒ + val id = outputBunch.idToEnqueueAndYield() + val outputHandle = outputMapping(id) + changeBehavior(behavior.onInput(ctx, outputHandle, elem)) + + } + + }) + +} diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/FlexiMerge.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/FlexiMerge.scala index adf366d316..ed895ccf2e 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/FlexiMerge.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/FlexiMerge.scala @@ -34,8 +34,8 @@ object FlexiMerge { sealed trait ReadCondition /** - * Read condition for the [[MergeLogic#State]] that will is - * fulfilled when there are elements for one specfic upstream + * Read condition for the [[MergeLogic#State]] that will be + * fulfilled when there are elements for one specific upstream * input. * * It is not allowed to use a handle that has been cancelled or @@ -48,7 +48,7 @@ object FlexiMerge { def apply(inputs: immutable.Seq[InputHandle]): ReadAny = new ReadAny(inputs: _*) } /** - * Read condition for the [[MergeLogic#State]] that will is + * Read condition for the [[MergeLogic#State]] that will be * fulfilled when there are elements for any of the given upstream * inputs. * @@ -69,7 +69,7 @@ object FlexiMerge { def initialCompletionHandling: CompletionHandling = defaultCompletionHandling /** - * Context that is passed to the methods of [[State]] and [[CompletionHandling]]. + * Context that is passed to the functions of [[State]] and [[CompletionHandling]]. * The context provides means for performing side effects, such as emitting elements * downstream. */ @@ -87,7 +87,7 @@ object FlexiMerge { def emit(elem: Out): Unit /** - * Complete this stream succesfully. Upstream subscriptions will be cancelled. + * Complete this stream successfully. Upstream subscriptions will be cancelled. */ def complete(): Unit @@ -136,12 +136,12 @@ object FlexiMerge { /** * How to handle completion or error from upstream input. * - * The `onComplete` function is called when an upstream input was completed sucessfully. + * The `onComplete` function is called when an upstream input was completed successfully. * It returns next behavior or [[#SameState]] to keep current behavior. * A completion can be propagated downstream with [[MergeLogicContext#complete]], * or it can be swallowed to continue with remaining inputs. * - * The `onError` function is called when an upstream input was completed sucessfully. + * The `onError` function is called when an upstream input was completed with failure. * It returns next behavior or [[#SameState]] to keep current behavior. * An error can be propagated downstream with [[MergeLogicContext#error]], * or it can be swallowed to continue with remaining inputs. diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/FlexiRoute.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/FlexiRoute.scala new file mode 100644 index 0000000000..9e968286f2 --- /dev/null +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/FlexiRoute.scala @@ -0,0 +1,269 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ +package akka.stream.scaladsl + +import scala.collection.immutable +import akka.stream.impl.Ast + +object FlexiRoute { + + /** + * @see [[OutputPort]] + */ + sealed trait OutputHandle { + private[akka] def portIndex: Int + } + + /** + * An `OutputPort` can be connected to a [[Sink]] with the [[FlowGraphBuilder]]. + * The `OutputPort` is also an [[OutputHandle]] which you use to define to which + * downstream output to emit an element. + */ + class OutputPort[In, Out] private[akka] (override private[akka] val port: Int, parent: FlexiRoute[In]) + extends JunctionOutPort[Out] with OutputHandle { + + override private[akka] def vertex = parent.vertex + + override private[akka] def portIndex: Int = port + + override def toString: String = s"OutputPort($port)" + } + + sealed trait DemandCondition + + /** + * Demand condition for the [[RouteLogic#State]] that will be + * fulfilled when there are requests for elements from one specific downstream + * output. + * + * It is not allowed to use a handle that has been cancelled or + * has been completed. `IllegalArgumentException` is thrown if + * that is not obeyed. + */ + final case class DemandFrom(output: OutputHandle) extends DemandCondition + + object DemandFromAny { + def apply(outputs: immutable.Seq[OutputHandle]): DemandFromAny = new DemandFromAny(outputs: _*) + } + /** + * Demand condition for the [[RouteLogic#State]] that will be + * fulfilled when there are requests for elements from any of the given downstream + * outputs. + * + * Cancelled and completed inputs are not used, i.e. it is allowed + * to specify them in the list of `outputs`. + */ + final case class DemandFromAny(outputs: OutputHandle*) extends DemandCondition + + object DemandFromAll { + def apply(outputs: immutable.Seq[OutputHandle]): DemandFromAll = new DemandFromAll(outputs: _*) + } + /** + * Demand condition for the [[RouteLogic#State]] that will be + * fulfilled when there are requests for elements from all of the given downstream + * outputs. + * + * Cancelled and completed inputs are not used, i.e. it is allowed + * to specify them in the list of `outputs`. + */ + final case class DemandFromAll(outputs: OutputHandle*) extends DemandCondition + + /** + * The possibly stateful logic that reads from the input and enables emitting to downstream + * via the defined [[State]]. Handles completion, error and cancel via the defined + * [[CompletionHandling]]. + * + * Concrete instance is supposed to be created by implementing [[FlexiRoute#createRouteLogic]]. + */ + abstract class RouteLogic[In] { + def outputHandles(outputCount: Int): immutable.IndexedSeq[OutputHandle] + def initialState: State[_] + def initialCompletionHandling: CompletionHandling = defaultCompletionHandling + + /** + * Context that is passed to the functions of [[State]] and [[CompletionHandling]]. + * The context provides means for performing side effects, such as emitting elements + * downstream. + */ + trait RouteLogicContext[Out] { + /** + * @return `true` if at least one element has been requested by the given downstream (output). + */ + def isDemandAvailable(output: OutputHandle): Boolean + + /** + * Emit one element downstream. It is only allowed to `emit` when + * [[#isDemandAvailable]] is `true` for the given `output`, otherwise + * `IllegalArgumentException` is thrown. + */ + def emit(output: OutputHandle, elem: Out): Unit + + /** + * Complete the given downstream successfully. + */ + def complete(output: OutputHandle): Unit + + /** + * Complete all downstreams successfully and cancel upstream. + */ + def complete(): Unit + + /** + * Complete the given downstream with failure. + */ + def error(output: OutputHandle, cause: Throwable): Unit + + /** + * Complete all downstreams with failure and cancel upstream. + */ + def error(cause: Throwable): Unit + + /** + * Replace current [[CompletionHandling]]. + */ + def changeCompletionHandling(completion: CompletionHandling): Unit + } + + /** + * Definition of which outputs that must have requested elements and how to act + * on the read elements. When an element has been read [[#onInput]] is called and + * then it is ensured that the specified downstream outputs have requested at least + * one element, i.e. it is allowed to emit at least one element downstream with + * [[RouteLogicContext#emit]]. + * + * The `onInput` function is called when an `element` was read from upstream. + * The function returns next behavior or [[#SameState]] to keep current behavior. + */ + sealed case class State[Out](val condition: DemandCondition)( + val onInput: (RouteLogicContext[Out], OutputHandle, In) ⇒ State[_]) + + /** + * Return this from [[State]] `onInput` to use same state for next element. + */ + def SameState[In]: State[In] = sameStateInstance.asInstanceOf[State[In]] + + private val sameStateInstance = new State[Any](DemandFromAny(Nil))((_, _, _) ⇒ + throw new UnsupportedOperationException("SameState.onInput should not be called")) { + + // unique instance, don't use case class + override def equals(other: Any): Boolean = super.equals(other) + override def hashCode: Int = super.hashCode + override def toString: String = "SameState" + } + + /** + * How to handle completion or error from upstream input and how to + * handle cancel from downstream output. + * + * The `onComplete` function is called the upstream input was completed successfully. + * It returns next behavior or [[#SameState]] to keep current behavior. + * + * The `onError` function is called when the upstream input was completed with failure. + * It returns next behavior or [[#SameState]] to keep current behavior. + * + * The `onCancel` function is called when a downstream output cancels. + * It returns next behavior or [[#SameState]] to keep current behavior. + */ + sealed case class CompletionHandling( + onComplete: RouteLogicContext[Any] ⇒ Unit, + onError: (RouteLogicContext[Any], Throwable) ⇒ Unit, + onCancel: (RouteLogicContext[Any], OutputHandle) ⇒ State[_]) + + /** + * When an output cancels it continues with remaining outputs. + * Error or completion from upstream are immediately propagated. + */ + val defaultCompletionHandling: CompletionHandling = CompletionHandling( + onComplete = _ ⇒ (), + onError = (ctx, cause) ⇒ (), + onCancel = (ctx, _) ⇒ SameState) + + /** + * Completes as soon as any output cancels. + * Error or completion from upstream are immediately propagated. + */ + val eagerClose: CompletionHandling = CompletionHandling( + onComplete = _ ⇒ (), + onError = (ctx, cause) ⇒ (), + onCancel = (ctx, _) ⇒ { ctx.complete(); SameState }) + + } + +} + +/** + * Base class for implementing custom route junctions. + * Such a junction always has one [[#in]] port and one or more output ports. + * The output ports are to be defined in the concrete subclass and are created with + * [[#createOutputPort]]. + * + * The concrete subclass must implement [[#createRouteLogic]] to define the [[FlexiRoute#RouteLogic]] + * that will be used when reading input elements and emitting output elements. + * The [[FlexiRoute#RouteLogic]] instance may be stateful, but the ``FlexiRoute`` instance + * must not hold mutable state, since it may be shared across several materialized ``FlowGraph`` + * instances. + * + * Note that a `FlexiRoute` with a specific name can only be used at one place (one vertex) + * in the `FlowGraph`. If the `name` is not specified the `FlexiRoute` instance can only + * be used at one place (one vertex) in the `FlowGraph`. + * + * @param name optional name of the junction in the [[FlowGraph]], + */ +abstract class FlexiRoute[In](val name: Option[String]) { + import FlexiRoute._ + + def this(name: String) = this(Some(name)) + def this() = this(None) + + private var outputCount = 0 + + // hide the internal vertex things from subclass, and make it possible to create new instance + private class RouteVertex(vertexName: Option[String]) extends FlowGraphInternal.InternalVertex { + override def minimumInputCount = 1 + override def maximumInputCount = 1 + override def minimumOutputCount = 2 + override def maximumOutputCount = outputCount + + override private[akka] val astNode = Ast.RouteNode(FlexiRoute.this.asInstanceOf[FlexiRoute[Any]]) + override def name = vertexName + + final override private[scaladsl] def newInstance() = new RouteVertex(None) + } + + private[scaladsl] val vertex: FlowGraphInternal.InternalVertex = new RouteVertex(name) + + /** + * Input port of the `FlexiRoute` junction. A [[Source]] can be connected to this output + * with the [[FlowGraphBuilder]]. + */ + val in: JunctionInPort[In] = new JunctionInPort[In] { + override type NextT = Nothing + override private[akka] def next = NoNext + override private[akka] def vertex = FlexiRoute.this.vertex + } + + /** + * Concrete subclass is supposed to define one or more output ports and + * they are created by calling this method. Each [[FlexiRoute.OutputPort]] can be + * connected to a [[Sink]] with the [[FlowGraphBuilder]]. + * The `OutputPort` is also an [[FlexiRoute.OutputHandle]] which you use to define to which + * downstream output to emit an element. + */ + protected final def createOutputPort[T](): OutputPort[In, T] = { + val port = outputCount + outputCount += 1 + new OutputPort(port, parent = this) + } + + /** + * Create the stateful logic that will be used when reading input elements + * and emitting output elements. Create a new instance every time. + */ + def createRouteLogic(): RouteLogic[In] + + override def toString = name match { + case Some(n) ⇒ n + case None ⇒ getClass.getSimpleName + "@" + Integer.toHexString(super.hashCode()) + } +}