diff --git a/akka-stream/src/main/scala/akka/stream/impl/ActorProcessor.scala b/akka-stream/src/main/scala/akka/stream/impl/ActorProcessor.scala index 3d260e17f3..a81fe87e0b 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/ActorProcessor.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/ActorProcessor.scala @@ -3,14 +3,12 @@ */ package akka.stream.impl -import akka.stream.ReactiveStreamsConstants import org.reactivestreams.{ Publisher, Subscriber, Subscription, Processor } import akka.actor._ -import akka.stream.MaterializerSettings +import akka.stream.{ ReactiveStreamsConstants, MaterializerSettings, TimerTransformer } import akka.stream.actor.ActorSubscriber.OnSubscribe import akka.stream.actor.ActorSubscriberMessage.{ OnNext, OnComplete, OnError } import java.util.Arrays -import akka.stream.TimerTransformer /** * INTERNAL API @@ -162,7 +160,7 @@ private[akka] abstract class BatchingInputBuffer(val size: Int, val pump: Pump) /** * INTERNAL API */ -private[akka] class SimpleOutputs(self: ActorRef, val pump: Pump) extends DefaultOutputTransferStates { +private[akka] class SimpleOutputs(val actor: ActorRef, val pump: Pump) extends DefaultOutputTransferStates { protected var exposedPublisher: ActorPublisher[Any] = _ @@ -198,11 +196,15 @@ private[akka] class SimpleOutputs(self: ActorRef, val pump: Pump) extends Defaul def isClosed: Boolean = downstreamCompleted + protected def createSubscription(): Subscription = { + new ActorSubscription(actor, subscriber) + } + private def subscribePending(subscribers: Seq[Subscriber[Any]]): Unit = subscribers foreach { sub ⇒ if (subscriber eq null) { subscriber = sub - subscriber.onSubscribe(new ActorSubscription(self, subscriber)) + subscriber.onSubscribe(createSubscription()) } else sub.onError(new IllegalStateException(s"${getClass.getSimpleName} ${ReactiveStreamsConstants.SupportsOnlyASingleSubscriber}")) } @@ -289,8 +291,8 @@ private[akka] abstract class ActorProcessorImpl(val settings: MaterializerSettin primaryOutputs.cancel(new IllegalStateException("Processor actor terminated abruptly")) } - override def preRestart(reason: Throwable, message: Option[Any]): Unit = { - super.preRestart(reason, message) + override def postRestart(reason: Throwable): Unit = { + super.postRestart(reason) throw new IllegalStateException("This actor cannot be restarted") } diff --git a/akka-stream/src/main/scala/akka/stream/impl/ActorPublisher.scala b/akka-stream/src/main/scala/akka/stream/impl/ActorPublisher.scala index df3efb8e57..c0de87bd5f 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/ActorPublisher.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/ActorPublisher.scala @@ -57,6 +57,8 @@ private[akka] class ActorPublisher[T](val impl: ActorRef, val equalityValue: Opt // the shutdown method. Subscription attempts after shutdown can be denied immediately. private val pendingSubscribers = new AtomicReference[immutable.Seq[Subscriber[_ >: T]]](Nil) + protected val wakeUpMsg: Any = SubscribePending + override def subscribe(subscriber: Subscriber[_ >: T]): Unit = { @tailrec def doSubscribe(subscriber: Subscriber[_ >: T]): Unit = { val current = pendingSubscribers.get @@ -64,7 +66,7 @@ private[akka] class ActorPublisher[T](val impl: ActorRef, val equalityValue: Opt reportSubscribeError(subscriber) else { if (pendingSubscribers.compareAndSet(current, subscriber +: current)) - impl ! SubscribePending + impl ! wakeUpMsg else doSubscribe(subscriber) // CAS retry } diff --git a/akka-stream/src/main/scala/akka/stream/impl2/ActorBasedFlowMaterializer.scala b/akka-stream/src/main/scala/akka/stream/impl2/ActorBasedFlowMaterializer.scala index cab1cc29a5..040e8a8936 100644 --- a/akka-stream/src/main/scala/akka/stream/impl2/ActorBasedFlowMaterializer.scala +++ b/akka-stream/src/main/scala/akka/stream/impl2/ActorBasedFlowMaterializer.scala @@ -4,6 +4,9 @@ package akka.stream.impl2 import java.util.concurrent.atomic.AtomicLong + +import akka.stream.actor.ActorSubscriber + import scala.annotation.tailrec import scala.collection.immutable import scala.concurrent.{ Future, Await } @@ -70,14 +73,22 @@ private[akka] object Ast { def name: String } - case object Merge extends JunctionAstNode { + // FIXME: Try to eliminate these + sealed trait FanInAstNode extends JunctionAstNode + sealed trait FanOutAstNode extends JunctionAstNode + + case object Merge extends FanInAstNode { override def name = "merge" } - case object Broadcast extends JunctionAstNode { + case object Broadcast extends FanOutAstNode { override def name = "broadcast" } + case object Zip extends FanInAstNode { + override def name = "zip" + } + } /** @@ -187,15 +198,38 @@ case class ActorBasedFlowMaterializer(override val settings: MaterializerSetting throw new IllegalStateException(s"Stream supervisor must be a local actor, was [${supervisor.getClass.getName}]") } - override def materializeJunction[In, Out](op: Ast.JunctionAstNode, inputCount: Int, outputCount: Int): (immutable.Seq[Subscriber[In]], immutable.Seq[Publisher[Out]]) = op match { - case Ast.Merge ⇒ - // FIXME real impl - require(outputCount == 1) - (Vector.fill(inputCount)(dummySubscriber[In]), List(dummyPublisher[Out])) - case Ast.Broadcast ⇒ - // FIXME real impl - require(inputCount == 1) - (List(dummySubscriber[In]), Vector.fill(outputCount)(dummyPublisher[Out])) + override def materializeJunction[In, Out](op: Ast.JunctionAstNode, inputCount: Int, outputCount: Int): (immutable.Seq[Subscriber[In]], immutable.Seq[Publisher[Out]]) = { + val flowName = createFlowName() + val actorName = s"$flowName-${op.name}" + + op match { + case fanin: Ast.FanInAstNode ⇒ + val impl = op match { + case Ast.Merge ⇒ + actorOf(Props(new FairMerge(settings, inputCount)).withDispatcher(settings.dispatcher), actorName) + case Ast.Zip ⇒ + actorOf(Props(new Zip(settings)).withDispatcher(settings.dispatcher), actorName) + } + + val publisher = new ActorPublisher[Out](impl, equalityValue = None) + impl ! ExposedPublisher(publisher.asInstanceOf[ActorPublisher[Any]]) + val subscribers = Vector.tabulate(inputCount)(FanIn.SubInput[In](impl, _)) + (subscribers, List(publisher)) + + case fanout: Ast.FanOutAstNode ⇒ + val impl = op match { + case Ast.Broadcast ⇒ + actorOf(Props(new Broadcast(settings, outputCount)).withDispatcher(settings.dispatcher), actorName) + } + + val publishers = Vector.tabulate(outputCount)(id ⇒ new ActorPublisher[Out](impl, equalityValue = None) { + override val wakeUpMsg = FanOut.SubstreamSubscribePending(id) + }) + impl ! FanOut.ExposedPublishers(publishers.asInstanceOf[immutable.Seq[ActorPublisher[Any]]]) + val subscriber = ActorSubscriber[In](impl) + (List(subscriber), publishers) + } + } // FIXME remove diff --git a/akka-stream/src/main/scala/akka/stream/impl2/FanIn.scala b/akka-stream/src/main/scala/akka/stream/impl2/FanIn.scala new file mode 100644 index 0000000000..22113ecfa9 --- /dev/null +++ b/akka-stream/src/main/scala/akka/stream/impl2/FanIn.scala @@ -0,0 +1,209 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ +package akka.stream.impl2 + +import akka.actor.{ ActorRef, ActorLogging, Actor } +import akka.stream.MaterializerSettings +import akka.stream.actor.{ ActorSubscriberMessage, ActorSubscriber } +import akka.stream.impl._ +import org.reactivestreams.{ Subscription, Subscriber } + +/** + * INTERNAL API + */ +private[akka] object FanIn { + + case class OnError(id: Int, cause: Throwable) + case class OnComplete(id: Int) + case class OnNext(id: Int, e: Any) + case class OnSubscribe(id: Int, subscription: Subscription) + + private[akka] final case class SubInput[T](impl: ActorRef, id: Int) extends Subscriber[T] { + override def onError(cause: Throwable): Unit = impl ! OnError(id, cause) + override def onComplete(): Unit = impl ! OnComplete(id) + override def onNext(element: T): Unit = impl ! OnNext(id, element) + override def onSubscribe(subscription: Subscription): Unit = impl ! OnSubscribe(id, subscription) + } + + abstract class InputBunch(inputCount: Int, bufferSize: Int, pump: Pump) { + private var cancelled = false + + private val inputs = Array.fill(inputCount)(new BatchingInputBuffer(bufferSize, pump) { + override protected def onError(e: Throwable): Unit = InputBunch.this.onError(e) + }) + + private val marked = Array.ofDim[Boolean](inputCount) + private var markCount = 0 + private val pending = Array.ofDim[Boolean](inputCount) + private var markedPending = 0 + private val completed = Array.ofDim[Boolean](inputCount) + private var markedCompleted = 0 + + private var preferredId = 0 + + def cancel(): Unit = + if (!cancelled) { + cancelled = true + inputs foreach (_.cancel()) + } + + def onError(e: Throwable): Unit + + def markInput(input: Int): Unit = { + if (!marked(input)) { + if (completed(input)) markedCompleted += 1 + if (pending(input)) markedPending += 1 + marked(input) = true + markCount += 1 + } + } + + def unmarkInput(input: Int): Unit = { + if (marked(input)) { + if (completed(input)) markedCompleted -= 1 + if (pending(input)) markedPending -= 1 + marked(input) = false + markCount -= 1 + } + } + + private def idToDequeue(): Int = { + var id = preferredId + while (!(marked(id) && pending(id))) { + id += 1 + if (id == inputCount) id = 0 + assert(id != preferredId, "Tried to dequeue without waiting for any input") + } + id + } + + def dequeue(id: Int): Any = { + val input = inputs(id) + val elem = input.dequeueInputElement() + if (!input.inputsAvailable) { + markedPending -= 1 + pending(id) = false + } + elem + } + + def dequeueAndYield(): Any = { + val id = idToDequeue() + preferredId = id + 1 + if (preferredId == inputCount) preferredId = 0 + dequeue(id) + } + def dequeueAndPrefer(preferred: Int): Any = { + val id = idToDequeue() + preferredId = preferred + dequeue(id) + } + + val AllOfMarkedInputs = new TransferState { + override def isCompleted: Boolean = markedCompleted == markCount && markedPending < markCount + override def isReady: Boolean = markedPending == markCount + } + + val AnyOfMarkedInputs = new TransferState { + override def isCompleted: Boolean = markedCompleted == markCount && markedPending == 0 + override def isReady: Boolean = markedPending > 0 + } + + // FIXME: Eliminate re-wraps + def subreceive: SubReceive = new SubReceive({ + case OnSubscribe(id, subscription) ⇒ + inputs(id).subreceive(ActorSubscriber.OnSubscribe(subscription)) + case OnNext(id, elem) ⇒ + if (marked(id) && !pending(id)) markedPending += 1 + pending(id) = true + inputs(id).subreceive(ActorSubscriberMessage.OnNext(elem)) + case OnComplete(id) ⇒ + if (marked(id) && !completed(id)) markedCompleted += 1 + completed(id) = true + inputs(id).subreceive(ActorSubscriberMessage.OnComplete) + case OnError(id, e) ⇒ onError(e) + }) + + } + +} + +/** + * INTERNAL API + */ +private[akka] abstract class FanIn(val settings: MaterializerSettings, val inputPorts: Int) extends Actor with ActorLogging with Pump { + import FanIn._ + + protected val primaryOutputs: Outputs = new SimpleOutputs(self, this) + protected val inputBunch = new InputBunch(inputPorts, settings.maxInputBufferSize, this) { + override def onError(e: Throwable): Unit = fail(e) + } + + override def pumpFinished(): Unit = { + inputBunch.cancel() + primaryOutputs.complete() + context.stop(self) + } + + override def pumpFailed(e: Throwable): Unit = fail(e) + + protected def fail(e: Throwable): Unit = { + log.error(e, "failure during processing") // FIXME: escalate to supervisor instead + inputBunch.cancel() + primaryOutputs.cancel(e) + context.stop(self) + } + + override def postStop(): Unit = { + inputBunch.cancel() + primaryOutputs.cancel(new IllegalStateException("Processor actor terminated abruptly")) + } + + override def postRestart(reason: Throwable): Unit = { + super.postRestart(reason) + throw new IllegalStateException("This actor cannot be restarted") + } + + def receive = inputBunch.subreceive orElse primaryOutputs.subreceive + +} + +/** + * INTERNAL API + */ +private[akka] class FairMerge(_settings: MaterializerSettings, _inputPorts: Int) extends FanIn(_settings, _inputPorts) { + (0 until inputPorts) foreach inputBunch.markInput + + nextPhase(TransferPhase(inputBunch.AnyOfMarkedInputs && primaryOutputs.NeedsDemand) { () ⇒ + val elem = inputBunch.dequeueAndYield() + primaryOutputs.enqueueOutputElement(elem) + }) + +} + +/** + * INTERNAL API + */ +private[akka] class UnfairMerge(_settings: MaterializerSettings, _inputPorts: Int, val preferred: Int) extends FanIn(_settings, _inputPorts) { + (0 until inputPorts) foreach inputBunch.markInput + + nextPhase(TransferPhase(inputBunch.AnyOfMarkedInputs && primaryOutputs.NeedsDemand) { () ⇒ + val elem = inputBunch.dequeueAndPrefer(preferred) + primaryOutputs.enqueueOutputElement(elem) + }) +} + +/** + * INTERNAL API + */ +private[akka] class Zip(_settings: MaterializerSettings) extends FanIn(_settings, inputPorts = 2) { + inputBunch.markInput(0) + inputBunch.markInput(1) + + nextPhase(TransferPhase(inputBunch.AllOfMarkedInputs && primaryOutputs.NeedsDemand) { () ⇒ + val elem0 = inputBunch.dequeue(0) + val elem1 = inputBunch.dequeue(1) + primaryOutputs.enqueueOutputElement((elem0, elem1)) + }) +} \ No newline at end of file diff --git a/akka-stream/src/main/scala/akka/stream/impl2/FanOut.scala b/akka-stream/src/main/scala/akka/stream/impl2/FanOut.scala new file mode 100644 index 0000000000..45ca401188 --- /dev/null +++ b/akka-stream/src/main/scala/akka/stream/impl2/FanOut.scala @@ -0,0 +1,239 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ +package akka.stream.impl2 + +import java.util.concurrent.atomic.AtomicReference + +import akka.actor.{ Actor, ActorLogging, ActorRef } +import akka.stream.MaterializerSettings +import akka.stream.impl.{ BatchingInputBuffer, Pump, SimpleOutputs, SubReceive, TransferState, _ } +import org.reactivestreams.{ Subscription, Subscriber, Publisher } + +import scala.collection.immutable + +/** + * INTERNAL API + */ +private[akka] object FanOut { + + case class SubstreamRequestMore(id: Int, demand: Long) + case class SubstreamCancel(id: Int) + case class SubstreamSubscribePending(id: Int) + + class SubstreamSubscription(val parent: ActorRef, val id: Int) extends Subscription { + override def request(elements: Long): Unit = + if (elements <= 0) throw new IllegalArgumentException("The number of requested elements must be > 0") + else parent ! SubstreamRequestMore(id, elements) + override def cancel(): Unit = parent ! SubstreamCancel(id) + override def toString = "SubstreamSubscription" + System.identityHashCode(this) + } + + class FanoutOutputs(val id: Int, _impl: ActorRef, _pump: Pump) extends SimpleOutputs(_impl, _pump) { + override def createSubscription(): Subscription = new SubstreamSubscription(actor, id) + } + + case class ExposedPublishers(publishers: immutable.Seq[ActorPublisher[Any]]) + + class OutputBunch(outputCount: Int, impl: ActorRef, pump: Pump) { + private var bunchCancelled = false + + private val outputs = Array.tabulate(outputCount)(new FanoutOutputs(_, impl, pump)) + + private val marked = Array.ofDim[Boolean](outputCount) + private var markCount = 0 + private val pending = Array.ofDim[Boolean](outputCount) + private var markedPending = 0 + private val cancelled = Array.ofDim[Boolean](outputCount) + private var markedCancelled = 0 + + private var unmarkCancelled = true + + private var preferredId = 0 + + def complete(): Unit = + if (!bunchCancelled) { + bunchCancelled = true + outputs foreach (_.complete()) + } + + def cancel(e: Throwable): Unit = + if (!bunchCancelled) { + bunchCancelled = true + outputs foreach (_.cancel(e)) + } + + def markOutput(output: Int): Unit = { + if (!marked(output)) { + if (cancelled(output)) markedCancelled += 1 + if (pending(output)) markedPending += 1 + marked(output) = true + markCount += 1 + } + } + + def unmarkOutput(output: Int): Unit = { + if (marked(output)) { + if (cancelled(output)) markedCancelled -= 1 + if (pending(output)) markedPending -= 1 + marked(output) = false + markCount -= 1 + } + } + + def unmarkCancelledOutputs(enabled: Boolean): Unit = unmarkCancelled = enabled + + private def idToEnqueue(): Int = { + var id = preferredId + while (!(marked(id) && pending(id))) { + id += 1 + if (id == outputCount) id = 0 + assert(id != preferredId, "Tried to enqueue without waiting for any demand") + } + id + } + + def enqueue(id: Int, elem: Any): Unit = { + val output = outputs(id) + output.enqueueOutputElement(elem) + if (!output.demandAvailable) { + markedPending -= 1 + pending(id) = false + } + } + + def enqueueMarked(elem: Any): Unit = { + var id = 0 + while (id < outputCount) { + if (marked(id)) enqueue(id, elem) + id += 1 + } + } + + def enqueueAndYield(elem: Any): Unit = { + val id = idToEnqueue() + preferredId = id + 1 + if (preferredId == outputCount) preferredId = 0 + enqueue(id, elem) + } + + def enqueueAndPrefer(elem: Any, preferred: Int): Unit = { + val id = idToEnqueue() + preferredId = preferred + enqueue(id, elem) + } + + val AllOfMarkedOutputs = new TransferState { + override def isCompleted: Boolean = markedCancelled > 0 + override def isReady: Boolean = markedPending == markCount + } + + val AnyOfMarkedOutputs = new TransferState { + override def isCompleted: Boolean = markedCancelled == markCount + override def isReady: Boolean = markedPending > 0 + } + + // FIXME: Eliminate re-wraps + def subreceive: SubReceive = new SubReceive({ + case ExposedPublishers(publishers) ⇒ + publishers.zip(outputs) foreach { + case (pub, output) ⇒ + output.subreceive(ExposedPublisher(pub)) + } + + case SubstreamRequestMore(id, demand) ⇒ + if (marked(id) && !pending(id)) markedPending += 1 + pending(id) = true + outputs(id).subreceive(RequestMore(null, demand)) + case SubstreamCancel(id) ⇒ + if (unmarkCancelled) { + if (marked(id)) markCount -= 1 + marked(id) = false + } + if (marked(id) && !cancelled(id)) markedCancelled += 1 + cancelled(id) = true + outputs(id).subreceive(Cancel(null)) + case SubstreamSubscribePending(id) ⇒ outputs(id).subreceive(SubscribePending) + }) + + } + +} + +/** + * INTERNAL API + */ +private[akka] abstract class FanOut(val settings: MaterializerSettings, val outputPorts: Int) extends Actor with ActorLogging with Pump { + import akka.stream.impl2.FanOut._ + + protected val outputBunch = new OutputBunch(outputPorts, self, this) + protected val primaryInputs: Inputs = new BatchingInputBuffer(settings.maxInputBufferSize, this) { + override def onError(e: Throwable): Unit = fail(e) + } + + override def pumpFinished(): Unit = { + primaryInputs.cancel() + outputBunch.complete() + context.stop(self) + } + + override def pumpFailed(e: Throwable): Unit = fail(e) + + protected def fail(e: Throwable): Unit = { + log.error(e, "failure during processing") // FIXME: escalate to supervisor instead + primaryInputs.cancel() + outputBunch.cancel(e) + context.stop(self) + } + + override def postStop(): Unit = { + primaryInputs.cancel() + outputBunch.cancel(new IllegalStateException("Processor actor terminated abruptly")) + } + + override def postRestart(reason: Throwable): Unit = { + super.postRestart(reason) + throw new IllegalStateException("This actor cannot be restarted") + } + + def receive = primaryInputs.subreceive orElse outputBunch.subreceive + +} + +/** + * INTERNAL API + */ +private[akka] class Broadcast(_settings: MaterializerSettings, _outputPorts: Int) extends FanOut(_settings, _outputPorts) { + (0 until outputPorts) foreach outputBunch.markOutput + + nextPhase(TransferPhase(primaryInputs.NeedsInput && outputBunch.AllOfMarkedOutputs) { () ⇒ + val elem = primaryInputs.dequeueInputElement() + outputBunch.enqueueMarked(elem) + }) +} + +/** + * INTERNAL API + */ +private[akka] class Balance(_settings: MaterializerSettings, _outputPorts: Int) extends FanOut(_settings, _outputPorts) { + (0 until outputPorts) foreach outputBunch.markOutput + + nextPhase(TransferPhase(primaryInputs.NeedsInput && outputBunch.AnyOfMarkedOutputs) { () ⇒ + val elem = primaryInputs.dequeueInputElement() + outputBunch.enqueueAndYield(elem) + }) +} + +/** + * INTERNAL API + */ +private[akka] class Unzip(_settings: MaterializerSettings, _outputPorts: Int) extends FanOut(_settings, _outputPorts) { + (0 until outputPorts) foreach outputBunch.markOutput + + nextPhase(TransferPhase(primaryInputs.NeedsInput && outputBunch.AllOfMarkedOutputs) { () ⇒ + val (elem0, elem1) = primaryInputs.dequeueInputElement().asInstanceOf[(Any, Any)] + outputBunch.enqueue(0, elem0) + outputBunch.enqueue(1, elem1) + }) +} + diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl2/FlowGraph.scala b/akka-stream/src/main/scala/akka/stream/scaladsl2/FlowGraph.scala index 7496cc7774..b95d616894 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl2/FlowGraph.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl2/FlowGraph.scala @@ -16,15 +16,23 @@ import akka.stream.impl2.Ast * Fan-in and fan-out vertices in the [[FlowGraph]] implements * this marker interface. */ -sealed trait Junction[T] extends FlowGraphInternal.Vertex -/** - * Fan-out vertices in the [[FlowGraph]] implements this marker interface. - */ -trait FanOutOperation[T] extends Junction[T] -/** - * Fan-in vertices in the [[FlowGraph]] implements this marker interface. - */ -trait FanInOperation[T] extends Junction[T] +sealed trait Junction[T] extends JunctionInPort[T] with JunctionOutPort[T] { + override def port: Int = FlowGraphInternal.UnlabeledPort + override def vertex: FlowGraphInternal.Vertex + type NextT = T + override def next = this +} + +sealed trait JunctionInPort[T] { + def port: Int = FlowGraphInternal.UnlabeledPort + def vertex: FlowGraphInternal.Vertex + type NextT + def next: JunctionOutPort[NextT] +} +sealed trait JunctionOutPort[T] { + def port: Int = FlowGraphInternal.UnlabeledPort + def vertex: FlowGraphInternal.Vertex +} object Merge { /** @@ -49,7 +57,13 @@ object Merge { * When building the [[FlowGraph]] you must connect one or more input flows/sources * and one output flow/sink to the `Merge` vertex. */ -final class Merge[T](override val name: Option[String]) extends FanInOperation[T] with FlowGraphInternal.NamedVertex +final class Merge[T](override val name: Option[String]) extends FlowGraphInternal.InternalVertex with Junction[T] { + override val vertex = this + override val minimumInputCount: Int = 2 + override val maximumInputCount: Int = Int.MaxValue + override val minimumOutputCount: Int = 1 + override val maximumOutputCount: Int = 1 +} object Broadcast { /** @@ -72,7 +86,59 @@ object Broadcast { * the other streams. It will not shutdown until the subscriptions for at least * two downstream subscribers have been established. */ -final class Broadcast[T](override val name: Option[String]) extends FanOutOperation[T] with FlowGraphInternal.NamedVertex +final class Broadcast[T](override val name: Option[String]) extends FlowGraphInternal.InternalVertex with Junction[T] { + override val vertex = this + override val minimumInputCount: Int = 1 + override val maximumInputCount: Int = 1 + override val minimumOutputCount: Int = 2 + override val maximumOutputCount: Int = Int.MaxValue +} + +object Zip { + /** + * Create a new anonymous `Zip` vertex with the specified input types. + * Note that a `Zip` instance can only be used at one place (one vertex) + * in the `FlowGraph`. This method creates a new instance every time it + * is called and those instances are not `equal`.* + */ + def apply[A, B]: Zip[A, B] = new Zip[A, B](None) + + /** + * Create a named `Zip` vertex with the specified input types. + * Note that a `Zip` instance can only be used at one place (one vertex) + * in the `FlowGraph`. This method creates a new instance every time it + * is called and those instances are not `equal`.* + */ + def apply[A, B](name: String): Zip[A, B] = new Zip[A, B](Some(name)) + + class Left[A, B] private[akka] (val vertex: Zip[A, B]) extends JunctionInPort[A] { + override val port = 0 + type NextT = (A, B) + override def next = vertex.out + } + class Right[A, B] private[akka] (val vertex: Zip[A, B]) extends JunctionInPort[B] { + override val port = 1 + type NextT = (A, B) + override def next = vertex.out + } + class Out[A, B] private[akka] (val vertex: Zip[A, B]) extends JunctionOutPort[(A, B)] +} + +/** + * Takes two streams and outputs an output stream formed from the two input streams + * by combining corresponding elements in pairs. If one of the two streams is + * longer than the other, its remaining elements are ignored. + */ +final class Zip[A, B](override val name: Option[String]) extends FlowGraphInternal.InternalVertex { + val left = new Zip.Left(this) + val right = new Zip.Right(this) + val out = new Zip.Out(this) + + override val minimumInputCount: Int = 2 + override val maximumInputCount: Int = 2 + override val minimumOutputCount: Int = 1 + override val maximumOutputCount: Int = 1 +} object UndefinedSink { /** @@ -95,7 +161,12 @@ object UndefinedSink { * yet by using this placeholder instead of the real [[Sink]]. Later the placeholder can * be replaced with [[FlowGraphBuilder#attachSink]]. */ -final class UndefinedSink[T](override val name: Option[String]) extends FlowGraphInternal.NamedVertex +final class UndefinedSink[T](override val name: Option[String]) extends FlowGraphInternal.InternalVertex { + override val minimumInputCount: Int = 1 + override val maximumInputCount: Int = 1 + override val minimumOutputCount: Int = 0 + override val maximumOutputCount: Int = 0 +} object UndefinedSource { /** @@ -118,13 +189,20 @@ object UndefinedSource { * yet by using this placeholder instead of the real [[Source]]. Later the placeholder can * be replaced with [[FlowGraphBuilder#attachSource]]. */ -final class UndefinedSource[T](override val name: Option[String]) extends FlowGraphInternal.NamedVertex +final class UndefinedSource[T](override val name: Option[String]) extends FlowGraphInternal.InternalVertex { + override val minimumInputCount: Int = 0 + override val maximumInputCount: Int = 0 + override val minimumOutputCount: Int = 1 + override val maximumOutputCount: Int = 1 +} /** * INTERNAL API */ private[akka] object FlowGraphInternal { + val UnlabeledPort = -1 + sealed trait Vertex case class SourceVertex(source: Source[_]) extends Vertex { override def toString = source.toString @@ -139,12 +217,17 @@ private[akka] object FlowGraphInternal { final override def hashCode: Int = super.hashCode } - sealed trait NamedVertex extends Vertex { + sealed trait InternalVertex extends Vertex { def name: Option[String] + def minimumInputCount: Int + def maximumInputCount: Int + def minimumOutputCount: Int + def maximumOutputCount: Int + final override def equals(obj: Any): Boolean = obj match { - case other: NamedVertex ⇒ + case other: InternalVertex ⇒ if (name.isDefined) (getClass == other.getClass && name == other.name) else (this eq other) case _ ⇒ false } @@ -161,7 +244,7 @@ private[akka] object FlowGraphInternal { } // flow not part of equals/hashCode - case class EdgeLabel(qualifier: Int)(val flow: ProcessorFlow[Any, Any]) { + case class EdgeLabel(qualifier: Int, inputPort: Int)(val flow: ProcessorFlow[Any, Any]) { override def toString: String = flow.toString } @@ -185,56 +268,59 @@ class FlowGraphBuilder private (graph: Graph[FlowGraphInternal.Vertex, LkDiEdge] private var cyclesAllowed = false - def addEdge[In, Out](source: Source[In], flow: ProcessorFlow[In, Out], sink: Junction[Out]): this.type = { + def addEdge[In, Out](source: Source[In], flow: ProcessorFlow[In, Out], sink: JunctionInPort[Out]): this.type = { val sourceVertex = SourceVertex(source) checkAddSourceSinkPrecondition(sourceVertex) - checkAddFanPrecondition(sink, in = true) - addGraphEdge(sourceVertex, sink, flow) + checkJunctionInPortPrecondition(sink) + addGraphEdge(sourceVertex, sink.vertex, flow, sink.port) this } - def addEdge[In, Out](source: UndefinedSource[In], flow: ProcessorFlow[In, Out], sink: Junction[Out]): this.type = { + def addEdge[In, Out](source: UndefinedSource[In], flow: ProcessorFlow[In, Out], sink: JunctionInPort[Out]): this.type = { checkAddSourceSinkPrecondition(source) - checkAddFanPrecondition(sink, in = true) - addGraphEdge(source, sink, flow) + checkJunctionInPortPrecondition(sink) + addGraphEdge(source, sink.vertex, flow, sink.port) this } - def addEdge[In, Out](source: Junction[In], flow: ProcessorFlow[In, Out], sink: Sink[Out]): this.type = { + def addEdge[In, Out](source: JunctionOutPort[In], flow: ProcessorFlow[In, Out], sink: Sink[Out]): this.type = { val sinkVertex = SinkVertex(sink) checkAddSourceSinkPrecondition(sinkVertex) - checkAddFanPrecondition(source, in = false) - addGraphEdge(source, sinkVertex, flow) + checkJunctionOutPortPrecondition(source) + // FIXME: output ports are not handled yet + addGraphEdge(source.vertex, sinkVertex, flow, UnlabeledPort) this } - def addEdge[In, Out](source: Junction[In], flow: ProcessorFlow[In, Out], sink: UndefinedSink[Out]): this.type = { + def addEdge[In, Out](source: JunctionOutPort[In], flow: ProcessorFlow[In, Out], sink: UndefinedSink[Out]): this.type = { checkAddSourceSinkPrecondition(sink) - checkAddFanPrecondition(source, in = false) - addGraphEdge(source, sink, flow) + checkJunctionOutPortPrecondition(source) + // FIXME: output ports are not handled yet + addGraphEdge(source.vertex, sink, flow, UnlabeledPort) this } - def addEdge[In, Out](source: Junction[In], flow: ProcessorFlow[In, Out], sink: Junction[Out]): this.type = { - checkAddFanPrecondition(source, in = false) - checkAddFanPrecondition(sink, in = true) - addGraphEdge(source, sink, flow) + def addEdge[In, Out](source: JunctionOutPort[In], flow: ProcessorFlow[In, Out], sink: JunctionInPort[Out]): this.type = { + checkJunctionOutPortPrecondition(source) + checkJunctionInPortPrecondition(sink) + addGraphEdge(source.vertex, sink.vertex, flow, sink.port) this } - def addEdge[In, Out](flow: FlowWithSource[In, Out], sink: Junction[Out]): this.type = { + def addEdge[In, Out](flow: FlowWithSource[In, Out], sink: JunctionInPort[Out]): this.type = { addEdge(flow.input, flow.withoutSource, sink) this } - def addEdge[In, Out](source: Junction[In], flow: FlowWithSink[In, Out]): this.type = { + def addEdge[In, Out](source: JunctionOutPort[In], flow: FlowWithSink[In, Out]): this.type = { addEdge(source, flow.withoutSink, flow.output) this } - private def addGraphEdge[In, Out](from: Vertex, to: Vertex, flow: ProcessorFlow[In, Out]): Unit = { + private def addGraphEdge[In, Out](from: Vertex, to: Vertex, flow: ProcessorFlow[In, Out], portQualifier: Int): Unit = { if (edgeQualifier == Int.MaxValue) throw new IllegalArgumentException(s"Too many edges") - val label = EdgeLabel(edgeQualifier)(flow.asInstanceOf[ProcessorFlow[Any, Any]]) + val effectivePortQualifier = if (portQualifier == UnlabeledPort) edgeQualifier else portQualifier + val label = EdgeLabel(edgeQualifier, portQualifier)(flow.asInstanceOf[ProcessorFlow[Any, Any]]) graph.addLEdge(from, to)(label) edgeQualifier += 1 } @@ -276,22 +362,30 @@ class FlowGraphBuilder private (graph: Graph[FlowGraphInternal.Vertex, LkDiEdge] private def checkAddSourceSinkPrecondition(node: Vertex): Unit = require(graph.find(node) == None, s"[$node] instance is already used in this flow graph") - private def checkAddFanPrecondition(junction: Junction[_], in: Boolean): Unit = { - junction match { - case _: FanOutOperation[_] if in ⇒ - graph.find(junction) match { - case Some(existing) if existing.incoming.nonEmpty ⇒ - throw new IllegalArgumentException(s"Fan-out [$junction] is already attached to input [${existing.incoming.head}]") - case _ ⇒ // ok - } - case _: FanInOperation[_] if !in ⇒ - graph.find(junction) match { - case Some(existing) if existing.outgoing.nonEmpty ⇒ - throw new IllegalArgumentException(s"Fan-in [$junction] is already attached to output [${existing.outgoing.head}]") - case _ ⇒ // ok - } - case _ ⇒ // ok - } + private def checkJunctionInPortPrecondition(junction: JunctionInPort[_]): Unit = { + // FIXME: Reenable checks + // junction match { + // case _: FanOutOperation[_] if in ⇒ + // graph.find(junction.vertex) match { + // case Some(existing) if existing.incoming.nonEmpty ⇒ + // throw new IllegalArgumentException(s"Fan-out [$junction] is already attached to input [${existing.incoming.head}]") + // case _ ⇒ // ok + // } + // case _ ⇒ // ok + // } + } + + private def checkJunctionOutPortPrecondition(junction: JunctionOutPort[_]): Unit = { + // FIXME: Reenable checks + // junction match { + // case _: FanInOperation[_] if !in ⇒ + // graph.find(junction.vertex) match { + // case Some(existing) if existing.outgoing.nonEmpty ⇒ + // throw new IllegalArgumentException(s"Fan-in [$junction] is already attached to output [${existing.outgoing.head}]") + // case _ ⇒ // ok + // } + // case _ ⇒ // ok + // } } /** @@ -340,12 +434,19 @@ class FlowGraphBuilder private (graph: Graph[FlowGraphInternal.Vertex, LkDiEdge] // we will be able to relax these checks graph.nodes.foreach { node ⇒ node.value match { - case merge: Merge[_] ⇒ - require(node.incoming.size == 2, "Merge must have two incoming edges: " + node.incoming) - require(node.outgoing.size == 1, "Merge must have one outgoing edge: " + node.outgoing) - case bcast: Broadcast[_] ⇒ - require(node.incoming.size == 1, "Broadcast must have one incoming edge: " + node.incoming) - require(node.outgoing.size >= 1, "Broadcast must have at least one outgoing edge: " + node.outgoing) + case v: InternalVertex ⇒ + require( + node.inDegree >= v.minimumInputCount, + s"${node.incoming}) must have at least ${v.minimumInputCount} incoming edges") + require( + node.inDegree <= v.maximumInputCount, + s"${node.incoming}) must have at most ${v.maximumInputCount} incoming edges") + require( + node.outDegree >= v.minimumOutputCount, + s"${node.outgoing.size}) must have at least ${v.minimumOutputCount} outgoing edges") + require( + node.outDegree <= v.maximumOutputCount, + s"${node.incoming}) must have at most ${v.maximumOutputCount} outgoing edges") case _ ⇒ // no check for other node types } } @@ -452,43 +553,36 @@ class FlowGraph private[akka] (private[akka] val graph: ImmutableGraph[FlowGraph memo.copy(visited = memo.visited + edge, sources = memo.sources.updated(src, f)) - case merge: Merge[_] ⇒ - // one subscriber for each incoming edge of the merge vertex - val (subscribers, publishers) = - materializer.materializeJunction[Any, Any](Ast.Merge, edge.from.inDegree, 1) - val publisher = publishers.head - val edgeSubscribers = edge.from.incoming.zip(subscribers) - val materializedSink = connectToDownstream(publisher) - memo.copy( - visited = memo.visited + edge, - downstreamSubscriber = memo.downstreamSubscriber ++ edgeSubscribers, - materializedSinks = memo.materializedSinks ++ materializedSink) - - case bcast: Broadcast[_] ⇒ + case v: InternalVertex ⇒ if (memo.upstreamPublishers.contains(edge)) { - // broadcast vertex already materialized + // vertex already materialized val materializedSink = connectToDownstream(memo.upstreamPublishers(edge)) memo.copy( visited = memo.visited + edge, materializedSinks = memo.materializedSinks ++ materializedSink) } else { - // one publisher for each outgoing edge of the broadcast vertex + + val op: Ast.JunctionAstNode = v match { + case _: Merge[_] ⇒ Ast.Merge + case _: Broadcast[_] ⇒ Ast.Broadcast + case _: Zip[_, _] ⇒ Ast.Zip + } val (subscribers, publishers) = - materializer.materializeJunction[Any, Any](Ast.Broadcast, 1, edge.from.outDegree) - val subscriber = subscribers.head - val edgePublishers = edge.from.outgoing.zip(publishers).toMap + materializer.materializeJunction[Any, Any](op, edge.from.inDegree, edge.from.outDegree) + // TODO: Check for gaps in port numbers + val edgeSubscribers = + edge.from.incoming.toSeq.sortBy(_.label.asInstanceOf[EdgeLabel].inputPort).zip(subscribers) + val edgePublishers = + edge.from.outgoing.toSeq.sortBy(_.label.asInstanceOf[EdgeLabel].inputPort).zip(publishers).toMap val publisher = edgePublishers(edge) val materializedSink = connectToDownstream(publisher) memo.copy( visited = memo.visited + edge, - downstreamSubscriber = memo.downstreamSubscriber + (edge.from.incoming.head -> subscriber), + downstreamSubscriber = memo.downstreamSubscriber ++ edgeSubscribers, upstreamPublishers = memo.upstreamPublishers ++ edgePublishers, materializedSinks = memo.materializedSinks ++ materializedSink) } - case other ⇒ - throw new IllegalArgumentException("Unknown junction operation: " + other) - } } @@ -601,20 +695,20 @@ object FlowGraphImplicits { new SourceNextStep(source, flow, builder) } - def ~>(sink: Junction[In])(implicit builder: FlowGraphBuilder): Junction[In] = { + def ~>(sink: JunctionInPort[In])(implicit builder: FlowGraphBuilder): JunctionOutPort[sink.NextT] = { builder.addEdge(source, ProcessorFlow.empty[In], sink) - sink + sink.next } } class SourceNextStep[In, Out](source: Source[In], flow: ProcessorFlow[In, Out], builder: FlowGraphBuilder) { - def ~>(sink: Junction[Out]): Junction[Out] = { + def ~>(sink: JunctionInPort[Out]): JunctionOutPort[sink.NextT] = { builder.addEdge(source, flow, sink) - sink + sink.next } } - implicit class JunctionOps[In](val junction: Junction[In]) extends AnyVal { + implicit class JunctionOps[In](val junction: JunctionOutPort[In]) extends AnyVal { def ~>[Out](flow: ProcessorFlow[In, Out])(implicit builder: FlowGraphBuilder): JunctionNextStep[In, Out] = { new JunctionNextStep(junction, flow, builder) } @@ -625,19 +719,19 @@ object FlowGraphImplicits { def ~>(sink: UndefinedSink[In])(implicit builder: FlowGraphBuilder): Unit = builder.addEdge(junction, ProcessorFlow.empty[In], sink) - def ~>(sink: Junction[In])(implicit builder: FlowGraphBuilder): Junction[In] = { + def ~>(sink: JunctionInPort[In])(implicit builder: FlowGraphBuilder): JunctionOutPort[sink.NextT] = { builder.addEdge(junction, ProcessorFlow.empty[In], sink) - sink + sink.next } def ~>(flow: FlowWithSink[In, _])(implicit builder: FlowGraphBuilder): Unit = builder.addEdge(junction, flow) } - class JunctionNextStep[In, Out](junction: Junction[In], flow: ProcessorFlow[In, Out], builder: FlowGraphBuilder) { - def ~>(sink: Junction[Out]): Junction[Out] = { + class JunctionNextStep[In, Out](junction: JunctionOutPort[In], flow: ProcessorFlow[In, Out], builder: FlowGraphBuilder) { + def ~>(sink: JunctionInPort[Out]): JunctionOutPort[sink.NextT] = { builder.addEdge(junction, flow, sink) - sink + sink.next } def ~>(sink: Sink[Out]): Unit = { @@ -650,9 +744,9 @@ object FlowGraphImplicits { } implicit class FlowWithSourceOps[In, Out](val flow: FlowWithSource[In, Out]) extends AnyVal { - def ~>(sink: Junction[Out])(implicit builder: FlowGraphBuilder): Junction[Out] = { + def ~>(sink: JunctionInPort[Out])(implicit builder: FlowGraphBuilder): JunctionOutPort[sink.NextT] = { builder.addEdge(flow, sink) - sink + sink.next } } @@ -661,17 +755,17 @@ object FlowGraphImplicits { new UndefinedSourceNextStep(source, flow, builder) } - def ~>(sink: Junction[In])(implicit builder: FlowGraphBuilder): Junction[In] = { + def ~>(sink: JunctionInPort[In])(implicit builder: FlowGraphBuilder): JunctionOutPort[sink.NextT] = { builder.addEdge(source, ProcessorFlow.empty[In], sink) - sink + sink.next } } class UndefinedSourceNextStep[In, Out](source: UndefinedSource[In], flow: ProcessorFlow[In, Out], builder: FlowGraphBuilder) { - def ~>(sink: Junction[Out]): Junction[Out] = { + def ~>(sink: JunctionInPort[Out]): JunctionOutPort[sink.NextT] = { builder.addEdge(source, flow, sink) - sink + sink.next } } diff --git a/akka-stream/src/test/scala/akka/stream/scaladsl2/FlowGraphCompileSpec.scala b/akka-stream/src/test/scala/akka/stream/scaladsl2/FlowGraphCompileSpec.scala index 8eb6d31d99..9dce1e0124 100644 --- a/akka-stream/src/test/scala/akka/stream/scaladsl2/FlowGraphCompileSpec.scala +++ b/akka-stream/src/test/scala/akka/stream/scaladsl2/FlowGraphCompileSpec.scala @@ -226,5 +226,30 @@ class FlowGraphCompileSpec extends AkkaSpec { }.run() } + "chain input and output ports" in { + FlowGraph { implicit b ⇒ + val zip = Zip[Int, String] + val out = PublisherSink[(Int, String)] + import FlowGraphImplicits._ + FlowFrom(List(1, 2, 3)) ~> zip.left ~> out + FlowFrom(List("a", "b", "c")) ~> zip.right + }.run() + } + + "distinguish between input and output ports" in { + intercept[IllegalArgumentException] { + FlowGraph { implicit b ⇒ + val zip = Zip[Int, String] + val wrongOut = PublisherSink[(Int, Int)] + import FlowGraphImplicits._ + "FlowFrom(List(1, 2, 3)) ~> zip.left ~> wrongOut" shouldNot compile + """FlowFrom(List("a", "b", "c")) ~> zip.left""" shouldNot compile + """FlowFrom(List("a", "b", "c")) ~> zip.out""" shouldNot compile + "zip.left ~> zip.right" shouldNot compile + "FlowFrom(List(1, 2, 3)) ~> zip.left ~> wrongOut" shouldNot compile + } + }.getMessage should include("empty") + } + } } diff --git a/akka-stream/src/test/scala/akka/stream/scaladsl2/GraphBroadcastSpec.scala b/akka-stream/src/test/scala/akka/stream/scaladsl2/GraphBroadcastSpec.scala new file mode 100644 index 0000000000..a61f87ef73 --- /dev/null +++ b/akka-stream/src/test/scala/akka/stream/scaladsl2/GraphBroadcastSpec.scala @@ -0,0 +1,116 @@ +package akka.stream.scaladsl2 + +import akka.stream.{ OverflowStrategy, MaterializerSettings } +import akka.stream.testkit.{ StreamTestKit, AkkaSpec } +import scala.concurrent.Await +import scala.concurrent.duration._ +import akka.stream.scaladsl2.FlowGraphImplicits._ + +class GraphBroadcastSpec extends AkkaSpec { + + val settings = MaterializerSettings(system) + .withInputBuffer(initialSize = 2, maxSize = 16) + .withFanOutBuffer(initialSize = 1, maxSize = 16) + + implicit val materializer = FlowMaterializer(settings) + + "A broadcast" must { + + "broadcast to other subscriber" in { + val c1 = StreamTestKit.SubscriberProbe[Int]() + val c2 = StreamTestKit.SubscriberProbe[Int]() + + FlowGraph { implicit b ⇒ + val bcast = Broadcast[Int]("broadcast") + FlowFrom(List(1, 2, 3)) ~> bcast + bcast ~> FlowFrom[Int].buffer(16, OverflowStrategy.backpressure) ~> SubscriberSink(c1) + bcast ~> FlowFrom[Int].buffer(16, OverflowStrategy.backpressure) ~> SubscriberSink(c2) + }.run() + + val sub1 = c1.expectSubscription() + val sub2 = c2.expectSubscription() + sub1.request(1) + sub2.request(2) + c1.expectNext(1) + c1.expectNoMsg(100.millis) + c2.expectNext(1) + c2.expectNext(2) + c2.expectNoMsg(100.millis) + sub1.request(3) + c1.expectNext(2) + c1.expectNext(3) + c1.expectComplete() + sub2.request(3) + c2.expectNext(3) + c2.expectComplete() + } + + "work with n-way broadcast" in { + val f1 = FutureSink[Seq[Int]] + val f2 = FutureSink[Seq[Int]] + val f3 = FutureSink[Seq[Int]] + val f4 = FutureSink[Seq[Int]] + val f5 = FutureSink[Seq[Int]] + + val g = FlowGraph { implicit b ⇒ + val bcast = Broadcast[Int]("broadcast") + FlowFrom(List(1, 2, 3)) ~> bcast + bcast ~> FlowFrom[Int].grouped(5) ~> f1 + bcast ~> FlowFrom[Int].grouped(5) ~> f2 + bcast ~> FlowFrom[Int].grouped(5) ~> f3 + bcast ~> FlowFrom[Int].grouped(5) ~> f4 + bcast ~> FlowFrom[Int].grouped(5) ~> f5 + }.run() + + Await.result(g.getSinkFor(f1), 3.seconds) should be(List(1, 2, 3)) + Await.result(g.getSinkFor(f2), 3.seconds) should be(List(1, 2, 3)) + Await.result(g.getSinkFor(f3), 3.seconds) should be(List(1, 2, 3)) + Await.result(g.getSinkFor(f4), 3.seconds) should be(List(1, 2, 3)) + Await.result(g.getSinkFor(f5), 3.seconds) should be(List(1, 2, 3)) + } + + "produce to other even though downstream cancels" in { + val c1 = StreamTestKit.SubscriberProbe[Int]() + val c2 = StreamTestKit.SubscriberProbe[Int]() + + FlowGraph { implicit b ⇒ + val bcast = Broadcast[Int]("broadcast") + FlowFrom(List(1, 2, 3)) ~> bcast + bcast ~> FlowFrom[Int] ~> SubscriberSink(c1) + bcast ~> FlowFrom[Int] ~> SubscriberSink(c2) + }.run() + + val sub1 = c1.expectSubscription() + sub1.cancel() + val sub2 = c2.expectSubscription() + sub2.request(3) + c2.expectNext(1) + c2.expectNext(2) + c2.expectNext(3) + c2.expectComplete() + } + + "produce to downstream even though other cancels" in { + val c1 = StreamTestKit.SubscriberProbe[Int]() + val c2 = StreamTestKit.SubscriberProbe[Int]() + + FlowGraph { implicit b ⇒ + val bcast = Broadcast[Int]("broadcast") + FlowFrom(List(1, 2, 3)) ~> bcast + bcast ~> FlowFrom[Int] ~> SubscriberSink(c1) + bcast ~> FlowFrom[Int] ~> SubscriberSink(c2) + }.run() + + val sub1 = c1.expectSubscription() + sub1.cancel() + val sub2 = c2.expectSubscription() + sub2.request(3) + c2.expectNext(1) + c2.expectNext(2) + c2.expectNext(3) + c2.expectComplete() + } + + } + +} diff --git a/akka-stream/src/test/scala/akka/stream/scaladsl2/GraphMergeSpec.scala b/akka-stream/src/test/scala/akka/stream/scaladsl2/GraphMergeSpec.scala new file mode 100644 index 0000000000..d6e0621972 --- /dev/null +++ b/akka-stream/src/test/scala/akka/stream/scaladsl2/GraphMergeSpec.scala @@ -0,0 +1,138 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ +package akka.stream.scaladsl2 + +import scala.concurrent.duration._ +import akka.stream.testkit.StreamTestKit +import akka.stream.testkit2.TwoStreamsSetup +import akka.stream.scaladsl2.FlowGraphImplicits._ + +class GraphMergeSpec extends TwoStreamsSetup { + + override type Outputs = Int + val op = Merge[Int] + override def operationUnderTestLeft = op + override def operationUnderTestRight = op + + "merge" must { + + "work in the happy case" in { + // Different input sizes (4 and 6) + val source1 = FlowFrom((0 to 3).iterator) + val source2 = FlowFrom((4 to 9).iterator) + val source3 = FlowFrom(List.empty[Int].iterator) + val probe = StreamTestKit.SubscriberProbe[Int]() + + FlowGraph { implicit b ⇒ + val m1 = Merge[Int]("m1") + val m2 = Merge[Int]("m2") + val m3 = Merge[Int]("m3") + + source1 ~> m1 ~> FlowFrom[Int].map(_ * 2) ~> m2 ~> FlowFrom[Int].map(_ / 2).map(_ + 1) ~> SubscriberSink(probe) + source2 ~> m1 + source3 ~> m2 + + }.run() + + val subscription = probe.expectSubscription() + + var collected = Set.empty[Int] + for (_ ← 1 to 10) { + subscription.request(1) + collected += probe.expectNext() + } + + collected should be(Set(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)) + probe.expectComplete() + } + + "work with n-way merge" in { + val source1 = FlowFrom(List(1)) + val source2 = FlowFrom(List(2)) + val source3 = FlowFrom(List(3)) + val source4 = FlowFrom(List(4)) + val source5 = FlowFrom(List(5)) + val source6 = FlowFrom(List.empty[Int]) + + val probe = StreamTestKit.SubscriberProbe[Int]() + + FlowGraph { implicit b ⇒ + val merge = Merge[Int]("merge") + + source1 ~> merge ~> FlowFrom[Int] ~> SubscriberSink(probe) + source2 ~> merge + source3 ~> merge + source4 ~> merge + source5 ~> merge + source6 ~> merge + + }.run() + + val subscription = probe.expectSubscription() + + var collected = Set.empty[Int] + for (_ ← 1 to 5) { + subscription.request(1) + collected += probe.expectNext() + } + + collected should be(Set(1, 2, 3, 4, 5)) + probe.expectComplete() + } + + commonTests() + + "work with one immediately completed and one nonempty publisher" in { + val subscriber1 = setup(completedPublisher, nonemptyPublisher((1 to 4).iterator)) + val subscription1 = subscriber1.expectSubscription() + subscription1.request(4) + subscriber1.expectNext(1) + subscriber1.expectNext(2) + subscriber1.expectNext(3) + subscriber1.expectNext(4) + subscriber1.expectComplete() + + val subscriber2 = setup(nonemptyPublisher((1 to 4).iterator), completedPublisher) + val subscription2 = subscriber2.expectSubscription() + subscription2.request(4) + subscriber2.expectNext(1) + subscriber2.expectNext(2) + subscriber2.expectNext(3) + subscriber2.expectNext(4) + subscriber2.expectComplete() + } + + "work with one delayed completed and one nonempty publisher" in { + val subscriber1 = setup(soonToCompletePublisher, nonemptyPublisher((1 to 4).iterator)) + val subscription1 = subscriber1.expectSubscription() + subscription1.request(4) + subscriber1.expectNext(1) + subscriber1.expectNext(2) + subscriber1.expectNext(3) + subscriber1.expectNext(4) + subscriber1.expectComplete() + + val subscriber2 = setup(nonemptyPublisher((1 to 4).iterator), soonToCompletePublisher) + val subscription2 = subscriber2.expectSubscription() + subscription2.request(4) + subscriber2.expectNext(1) + subscriber2.expectNext(2) + subscriber2.expectNext(3) + subscriber2.expectNext(4) + subscriber2.expectComplete() + } + + "work with one immediately failed and one nonempty publisher" in { + // This is nondeterministic, multiple scenarios can happen + pending + } + + "work with one delayed failed and one nonempty publisher" in { + // This is nondeterministic, multiple scenarios can happen + pending + } + + } + +} diff --git a/akka-stream/src/test/scala/akka/stream/scaladsl2/GraphOpsIntegrationSpec.scala b/akka-stream/src/test/scala/akka/stream/scaladsl2/GraphOpsIntegrationSpec.scala new file mode 100644 index 0000000000..7c27cfc2c3 --- /dev/null +++ b/akka-stream/src/test/scala/akka/stream/scaladsl2/GraphOpsIntegrationSpec.scala @@ -0,0 +1,88 @@ +package akka.stream.scaladsl2 + +import akka.stream.testkit.AkkaSpec + +import akka.stream.{ OverflowStrategy, MaterializerSettings } +import akka.stream.testkit.{ StreamTestKit, AkkaSpec } +import scala.concurrent.Await +import scala.concurrent.duration._ +import akka.stream.scaladsl2.FlowGraphImplicits._ + +class GraphOpsIntegrationSpec extends AkkaSpec { + + val settings = MaterializerSettings(system) + .withInputBuffer(initialSize = 2, maxSize = 16) + .withFanOutBuffer(initialSize = 1, maxSize = 16) + + implicit val materializer = FlowMaterializer(settings) + + "FlowGraphs" must { + + "support broadcast - merge layouts" in { + val resultFuture = FutureSink[Seq[Int]] + + val g = FlowGraph { implicit b ⇒ + val bcast = Broadcast[Int]("broadcast") + val merge = Merge[Int]("merge") + + FlowFrom(List(1, 2, 3)) ~> bcast + bcast ~> merge + bcast ~> FlowFrom[Int].map(_ + 3) ~> merge + merge ~> FlowFrom[Int].grouped(10) ~> resultFuture + }.run() + + Await.result(g.getSinkFor(resultFuture), 3.seconds).sorted should be(List(1, 2, 3, 4, 5, 6)) + } + + "support wikipedia Topological_sorting 2" in { + // see https://en.wikipedia.org/wiki/Topological_sorting#mediaviewer/File:Directed_acyclic_graph.png + val resultFuture2 = FutureSink[Seq[Int]] + val resultFuture9 = FutureSink[Seq[Int]] + val resultFuture10 = FutureSink[Seq[Int]] + + val g = FlowGraph { implicit b ⇒ + val b3 = Broadcast[Int]("b3") + val b7 = Broadcast[Int]("b7") + val b11 = Broadcast[Int]("b11") + val m8 = Merge[Int]("m8") + val m9 = Merge[Int]("m9") + val m10 = Merge[Int]("m10") + val m11 = Merge[Int]("m11") + val in3 = IterableSource(List(3)) + val in5 = IterableSource(List(5)) + val in7 = IterableSource(List(7)) + + // First layer + in7 ~> b7 + b7 ~> m11 + b7 ~> m8 + + in5 ~> m11 + + in3 ~> b3 + b3 ~> m8 + b3 ~> m10 + + // Second layer + m11 ~> b11 + b11 ~> FlowFrom[Int].grouped(1000) ~> resultFuture2 // Vertex 2 is omitted since it has only one in and out + b11 ~> m9 + b11 ~> m10 + + m8 ~> m9 + + // Third layer + m9 ~> FlowFrom[Int].grouped(1000) ~> resultFuture9 + m10 ~> FlowFrom[Int].grouped(1000) ~> resultFuture10 + + }.run() + + Await.result(g.getSinkFor(resultFuture2), 3.seconds).sorted should be(List(5, 7)) + Await.result(g.getSinkFor(resultFuture9), 3.seconds).sorted should be(List(3, 5, 7, 7)) + Await.result(g.getSinkFor(resultFuture10), 3.seconds).sorted should be(List(3, 5, 7)) + + } + + } + +} diff --git a/akka-stream/src/test/scala/akka/stream/scaladsl2/GraphZipSpec.scala b/akka-stream/src/test/scala/akka/stream/scaladsl2/GraphZipSpec.scala new file mode 100644 index 0000000000..1b758d884a --- /dev/null +++ b/akka-stream/src/test/scala/akka/stream/scaladsl2/GraphZipSpec.scala @@ -0,0 +1,81 @@ +package akka.stream.scaladsl2 + +import akka.stream.MaterializerSettings + +import scala.concurrent.duration._ +import akka.stream.testkit.{ AkkaSpec, StreamTestKit } +import akka.stream.testkit2.TwoStreamsSetup +import akka.stream.scaladsl2.FlowGraphImplicits._ + +class GraphZipSpec extends TwoStreamsSetup { + + override type Outputs = (Int, Int) + val op = Zip[Int, Int] + override def operationUnderTestLeft() = op.left + override def operationUnderTestRight() = op.right + + "Zip" must { + // + // "work in the happy case" in { + // val probe = StreamTestKit.SubscriberProbe[(Int, String)]() + // + // FlowGraph { implicit b ⇒ + // val zip = Zip[Int, String] + // + // FlowFrom(1 to 4) ~> zip.left + // FlowFrom(List("A", "B", "C", "D", "E", "F")) ~> zip.right + // + // zip.out ~> SubscriberSink(probe) + // }.run() + // + // val subscription = probe.expectSubscription() + // + // subscription.request(2) + // probe.expectNext((1, "A")) + // probe.expectNext((2, "B")) + // + // subscription.request(1) + // probe.expectNext((3, "C")) + // subscription.request(1) + // probe.expectNext((4, "D")) + // + // probe.expectComplete() + // } + // + // commonTests() + // + // "work with one immediately completed and one nonempty publisher" in { + // val subscriber1 = setup(completedPublisher, nonemptyPublisher((1 to 4).iterator)) + // subscriber1.expectCompletedOrSubscriptionFollowedByComplete() + // + // val subscriber2 = setup(nonemptyPublisher((1 to 4).iterator), completedPublisher) + // subscriber2.expectCompletedOrSubscriptionFollowedByComplete() + // } + // + // "work with one delayed completed and one nonempty publisher" in { + // val subscriber1 = setup(soonToCompletePublisher, nonemptyPublisher((1 to 4).iterator)) + // subscriber1.expectCompletedOrSubscriptionFollowedByComplete() + // + // val subscriber2 = setup(nonemptyPublisher((1 to 4).iterator), soonToCompletePublisher) + // subscriber2.expectCompletedOrSubscriptionFollowedByComplete() + // } + // + // "work with one immediately failed and one nonempty publisher" in { + // val subscriber1 = setup(failedPublisher, nonemptyPublisher((1 to 4).iterator)) + // subscriber1.expectErrorOrSubscriptionFollowedByError(TestException) + // + // val subscriber2 = setup(nonemptyPublisher((1 to 4).iterator), failedPublisher) + // subscriber2.expectErrorOrSubscriptionFollowedByError(TestException) + // } + // + // "work with one delayed failed and one nonempty publisher" in { + // val subscriber1 = setup(soonToFailPublisher, nonemptyPublisher((1 to 4).iterator)) + // subscriber1.expectErrorOrSubscriptionFollowedByError(TestException) + // + // val subscriber2 = setup(nonemptyPublisher((1 to 4).iterator), soonToFailPublisher) + // val subscription2 = subscriber2.expectErrorOrSubscriptionFollowedByError(TestException) + // } + + } + +} diff --git a/akka-stream/src/test/scala/akka/stream/testkit2/TwoStreamsSetup.scala b/akka-stream/src/test/scala/akka/stream/testkit2/TwoStreamsSetup.scala new file mode 100644 index 0000000000..73d0604eb9 --- /dev/null +++ b/akka-stream/src/test/scala/akka/stream/testkit2/TwoStreamsSetup.scala @@ -0,0 +1,89 @@ +package akka.stream.testkit2 + +import akka.stream.MaterializerSettings +import akka.stream.scaladsl2._ +import akka.stream.testkit.{ StreamTestKit, AkkaSpec } +import org.reactivestreams.Publisher + +import scala.util.control.NoStackTrace + +abstract class TwoStreamsSetup extends AkkaSpec { + + val settings = MaterializerSettings(system) + .withInputBuffer(initialSize = 2, maxSize = 2) + .withFanOutBuffer(initialSize = 2, maxSize = 2) + + implicit val materializer = FlowMaterializer(settings) + + case class TE(message: String) extends RuntimeException(message) with NoStackTrace + + val TestException = TE("test") + + type Outputs + + def operationUnderTestLeft(): JunctionInPort[Int] { type NextT = Outputs } + def operationUnderTestRight(): JunctionInPort[Int] { type NextT = Outputs } + + def setup(p1: Publisher[Int], p2: Publisher[Int]) = { + val subscriber = StreamTestKit.SubscriberProbe[Outputs]() + FlowGraph { implicit b ⇒ + import FlowGraphImplicits._ + val left = operationUnderTestLeft() + val right = operationUnderTestRight() + val x = FlowFrom(p1) ~> left ~> FlowFrom[Outputs] ~> SubscriberSink(subscriber) + FlowFrom(p2) ~> right + }.run() + + subscriber + } + + def failedPublisher[T]: Publisher[T] = StreamTestKit.errorPublisher[T](TestException) + + def completedPublisher[T]: Publisher[T] = StreamTestKit.emptyPublisher[T] + + def nonemptyPublisher[T](elems: Iterator[T]): Publisher[T] = FlowFrom(elems).toPublisher() + + def soonToFailPublisher[T]: Publisher[T] = StreamTestKit.lazyErrorPublisher[T](TestException) + + def soonToCompletePublisher[T]: Publisher[T] = StreamTestKit.lazyEmptyPublisher[T] + + def commonTests() = { + "work with two immediately completed publishers" in { + val subscriber = setup(completedPublisher, completedPublisher) + subscriber.expectCompletedOrSubscriptionFollowedByComplete() + } + + "work with two delayed completed publishers" in { + val subscriber = setup(soonToCompletePublisher, soonToCompletePublisher) + subscriber.expectCompletedOrSubscriptionFollowedByComplete() + } + + "work with one immediately completed and one delayed completed publisher" in { + val subscriber = setup(completedPublisher, soonToCompletePublisher) + subscriber.expectCompletedOrSubscriptionFollowedByComplete() + } + + "work with two immediately failed publishers" in { + val subscriber = setup(failedPublisher, failedPublisher) + subscriber.expectErrorOrSubscriptionFollowedByError(TestException) + } + + "work with two delayed failed publishers" in { + val subscriber = setup(soonToFailPublisher, soonToFailPublisher) + subscriber.expectErrorOrSubscriptionFollowedByError(TestException) + } + + // Warning: The two test cases below are somewhat implementation specific and might fail if the implementation + // is changed. They are here to be an early warning though. + "work with one immediately failed and one delayed failed publisher (case 1)" in { + val subscriber = setup(soonToFailPublisher, failedPublisher) + subscriber.expectErrorOrSubscriptionFollowedByError(TestException) + } + + "work with one immediately failed and one delayed failed publisher (case 2)" in { + val subscriber = setup(failedPublisher, soonToFailPublisher) + subscriber.expectErrorOrSubscriptionFollowedByError(TestException) + } + } + +}