diff --git a/akka-docs/src/main/paradox/stream/operators/MergeSequence.md b/akka-docs/src/main/paradox/stream/operators/MergeSequence.md new file mode 100644 index 0000000000..a7f649c224 --- /dev/null +++ b/akka-docs/src/main/paradox/stream/operators/MergeSequence.md @@ -0,0 +1,46 @@ +# MergeSequence + +Merge a linear sequence partitioned across multiple sources. + +@ref[Fan-in operators](index.md#fan-in-operators) + +## Signature + +@apidoc[MergeSequence] + +## Description + +Merge a linear sequence partitioned across multiple sources. Each element from upstream must have a defined index, +starting from 0. There must be no gaps in the sequence, nor may there be any duplicates. Each upstream source must be +ordered by the sequence. + +## Example + +`MergeSequence` is most useful when used in combination with `Partition`, to merge the partitioned stream back into +a single stream, while maintaining the order of the original elements. `zipWithIndex` can be used before partitioning +the stream to generate the index. + +The example below shows partitioning a stream of messages into one stream for elements that must be processed by a +given processing flow, and another stream for elements for which no processing will be done, and then merges them +back together so that the messages can be acknowledged in order. + +Scala +: @@snip [MergeSequenceDocExample.scala](/akka-docs/src/test/scala/docs/stream/operators/MergeSequenceDocExample.scala) { #merge-sequence } + +Java +: @@snip [MergeSequenceDocExample.java](/akka-docs/src/test/java/jdocs/stream/operators/MergeSequenceDocExample.java) { #import #merge-sequence } + +## Reactive Streams semantics + +@@@div { .callout } + +**emits** when one of the upstreams has the next expected element in the sequence available. + +**backpressures** when downstream backpressures + +**completes** when all upstreams complete + +**cancels** downstream cancels + +@@@ + diff --git a/akka-docs/src/main/paradox/stream/operators/index.md b/akka-docs/src/main/paradox/stream/operators/index.md index 3253428348..4731863220 100644 --- a/akka-docs/src/main/paradox/stream/operators/index.md +++ b/akka-docs/src/main/paradox/stream/operators/index.md @@ -263,6 +263,7 @@ the inputs in different ways. | |Operator|Description| |--|--|--| +| |@ref[MergeSequence](MergeSequence.md)|Merge a linear sequence partitioned across multiple sources.| |Source/Flow|@ref[concat](Source-or-Flow/concat.md)|After completion of the original upstream the elements of the given source will be emitted.| |Source/Flow|@ref[interleave](Source-or-Flow/interleave.md)|Emits a specifiable number of elements from the original source, then from the provided source and repeats.| |Source/Flow|@ref[merge](Source-or-Flow/merge.md)|Merge multiple sources.| @@ -487,6 +488,7 @@ For more background see the @ref[Error Handling in Streams](../stream-error.md) * [mergeLatest](Source-or-Flow/mergeLatest.md) * [mergePreferred](Source-or-Flow/mergePreferred.md) * [mergePrioritized](Source-or-Flow/mergePrioritized.md) +* [MergeSequence](MergeSequence.md) * [mergeSorted](Source-or-Flow/mergeSorted.md) * [monitor](Source-or-Flow/monitor.md) * [never](Source/never.md) diff --git a/akka-docs/src/main/paradox/stream/stream-graphs.md b/akka-docs/src/main/paradox/stream/stream-graphs.md index 2a97db9fdc..f3db7327e7 100644 --- a/akka-docs/src/main/paradox/stream/stream-graphs.md +++ b/akka-docs/src/main/paradox/stream/stream-graphs.md @@ -47,6 +47,7 @@ Akka Streams currently provide these junctions (for a detailed list see the @ref * @scala[`MergePreferred[In]`]@java[`MergePreferred`] – like `Merge` but if elements are available on `preferred` port, it picks from it, otherwise randomly from `others` * @scala[`MergePrioritized[In]`]@java[`MergePrioritized`] – like `Merge` but if elements are available on all input ports, it picks from them randomly based on their `priority` * @scala[`MergeLatest[In]`]@java[`MergeLatest`] – *(N inputs, 1 output)* emits `List[In]`, when i-th input stream emits element, then i-th element in emitted list is updated + * @scala[`MergeSequence[In]`]@java[`MergeSequence`] – *(N inputs, 1 output)* emits `List[In]`, where the input streams must represent a partitioned sequence that must be merged back together in order * @scala[`ZipWith[A,B,...,Out]`]@java[`ZipWith`] – *(N inputs, 1 output)* which takes a function of N inputs that given a value for each input emits 1 output element * @scala[`Zip[A,B]`]@java[`Zip`] – *(2 inputs, 1 output)* is a `ZipWith` specialised to zipping input streams of `A` and `B` into a @scala[`(A,B)`]@java[`Pair(A,B)`] tuple stream * @scala[`Concat[A]`]@java[`Concat`] – *(2 inputs, 1 output)* concatenates two streams (first consume one, then the second one) diff --git a/akka-docs/src/test/java/jdocs/stream/operators/MergeSequenceDocExample.java b/akka-docs/src/test/java/jdocs/stream/operators/MergeSequenceDocExample.java new file mode 100644 index 0000000000..daf0129b31 --- /dev/null +++ b/akka-docs/src/test/java/jdocs/stream/operators/MergeSequenceDocExample.java @@ -0,0 +1,87 @@ +/* + * Copyright (C) 2020 Lightbend Inc. + */ + +package jdocs.stream.operators; + +import akka.actor.ActorSystem; + +// #import +import akka.NotUsed; +import akka.japi.Pair; +import akka.stream.ClosedShape; +import akka.stream.UniformFanInShape; +import akka.stream.UniformFanOutShape; +import akka.stream.javadsl.Flow; +import akka.stream.javadsl.GraphDSL; +import akka.stream.javadsl.MergeSequence; +import akka.stream.javadsl.Partition; +import akka.stream.javadsl.RunnableGraph; +import akka.stream.javadsl.Sink; +import akka.stream.javadsl.Source; +// #import + +public class MergeSequenceDocExample { + + private final ActorSystem system = ActorSystem.create("MergeSequenceDocExample"); + + interface Message {} + + boolean shouldProcess(Message message) { + return true; + } + + Source createSubscription() { + return null; + } + + Flow, Pair, NotUsed> createMessageProcessor() { + return null; + } + + Sink createMessageAcknowledger() { + return null; + } + + void mergeSequenceExample() { + + // #merge-sequence + + Source subscription = createSubscription(); + Flow, Pair, NotUsed> messageProcessor = + createMessageProcessor(); + Sink messageAcknowledger = createMessageAcknowledger(); + + RunnableGraph.fromGraph( + GraphDSL.create( + builder -> { + // Partitions stream into messages that should or should not be processed + UniformFanOutShape, Pair> partition = + builder.add( + Partition.create(2, element -> shouldProcess(element.first()) ? 0 : 1)); + // Merges stream by the index produced by zipWithIndex + UniformFanInShape, Pair> merge = + builder.add(MergeSequence.create(2, Pair::second)); + + builder.from(builder.add(subscription.zipWithIndex())).viaFanOut(partition); + // First goes through message processor + builder.from(partition.out(0)).via(builder.add(messageProcessor)).viaFanIn(merge); + // Second partition bypasses message processor + builder.from(partition.out(1)).viaFanIn(merge); + + // Unwrap message index pairs and send to acknowledger + builder + .from(merge.out()) + .to( + builder.add( + Flow.>create() + .map(Pair::first) + .to(messageAcknowledger))); + + return ClosedShape.getInstance(); + })) + .run(system); + + // #merge-sequence + } +} diff --git a/akka-docs/src/test/scala/docs/stream/operators/MergeSequenceDocExample.scala b/akka-docs/src/test/scala/docs/stream/operators/MergeSequenceDocExample.scala new file mode 100644 index 0000000000..12b3754d2e --- /dev/null +++ b/akka-docs/src/test/scala/docs/stream/operators/MergeSequenceDocExample.scala @@ -0,0 +1,52 @@ +/* + * Copyright (C) 2020 Lightbend Inc. + */ + +package docs.stream.operators + +import akka.actor.ActorSystem + +object MergeSequenceDocExample { + + implicit val system: ActorSystem = ??? + + // #merge-sequence + import akka.NotUsed + import akka.stream.ClosedShape + import akka.stream.scaladsl.{ Flow, GraphDSL, MergeSequence, Partition, RunnableGraph, Sink, Source } + + val subscription: Source[Message, NotUsed] = createSubscription() + val messageProcessor: Flow[(Message, Long), (Message, Long), NotUsed] = + createMessageProcessor() + val messageAcknowledger: Sink[Message, NotUsed] = createMessageAcknowledger() + + RunnableGraph + .fromGraph(GraphDSL.create() { implicit builder => + import GraphDSL.Implicits._ + // Partitions stream into messages that should or should not be processed + val partition = builder.add(Partition[(Message, Long)](2, { + case (message, _) if shouldProcess(message) => 0 + case _ => 1 + })) + // Merges stream by the index produced by zipWithIndex + val merge = builder.add(MergeSequence[(Message, Long)](2)(_._2)) + + subscription.zipWithIndex ~> partition.in + // First goes through message processor + partition.out(0) ~> messageProcessor ~> merge + // Second partition bypasses message processor + partition.out(1) ~> merge + merge.out.map(_._1) ~> messageAcknowledger + ClosedShape + }) + .run() + + // #merge-sequence + + def shouldProcess(message: Message): Boolean = true + trait Message + def createSubscription(): Source[Message, NotUsed] = ??? + def createMessageProcessor(): Flow[(Message, Long), (Message, Long), NotUsed] = ??? + def createMessageAcknowledger(): Sink[Message, NotUsed] = ??? + +} diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMergeSequenceSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMergeSequenceSpec.scala new file mode 100644 index 0000000000..ffdb6c7552 --- /dev/null +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMergeSequenceSpec.scala @@ -0,0 +1,87 @@ +/* + * Copyright (C) 2020 Lightbend Inc. + */ + +package akka.stream.scaladsl + +import akka.NotUsed +import akka.stream._ +import akka.stream.testkit._ +import akka.stream.testkit.scaladsl.StreamTestKit._ + +import scala.collection.immutable +import scala.concurrent.Await +import scala.concurrent.duration._ + +class GraphMergeSequenceSpec extends TwoStreamsSetup { + + override type Outputs = Int + + override def fixture(b: GraphDSL.Builder[_]): Fixture = new Fixture { + val merge = b.add(MergeSequence[Outputs](2)(i => i)) + + override def left: Inlet[Outputs] = merge.in(0) + override def right: Inlet[Outputs] = merge.in(1) + override def out: Outlet[Outputs] = merge.out + + } + + private def merge(seqs: immutable.Seq[Long]*): immutable.Seq[Long] = + mergeSources(seqs.map(Source(_)): _*) + + private def mergeSources(sources: Source[Long, NotUsed]*): immutable.Seq[Long] = { + val future = Source + .fromGraph(GraphDSL.create() { implicit builder => + import GraphDSL.Implicits._ + val merge = builder.add(new MergeSequence[Long](sources.size)(identity)) + sources.foreach { source => + source ~> merge + } + + SourceShape(merge.out) + }) + .runWith(Sink.seq) + Await.result(future, 3.seconds) + } + + "merge sequence" must { + + "merge interleaved streams" in assertAllStagesStopped { + (merge(List(0L, 4L, 8L, 9L, 11L), List(1L, 3L, 6L, 10L, 13L), List(2L, 5L, 7L, 12L)) should contain) + .theSameElementsInOrderAs(0L to 13L) + } + + "merge non interleaved streams" in assertAllStagesStopped { + (merge(List(5L, 6L, 7L, 8L, 9L), List(0L, 1L, 2L, 3L, 4L), List(10L, 11L, 12L, 13L)) should contain) + .theSameElementsInOrderAs(0L to 13L) + } + + "fail on duplicate sequence numbers" in assertAllStagesStopped { + an[IllegalStateException] should be thrownBy merge(List(0L, 1L, 2L), List(2L)) + } + + "fail on missing sequence numbers" in assertAllStagesStopped { + an[IllegalStateException] should be thrownBy merge( + List(0L, 4L, 8L, 9L, 11L), + List(1L, 3L, 10L, 13L), + List(2L, 5L, 7L, 12L)) + } + + "fail on missing sequence numbers if some streams have completed" in assertAllStagesStopped { + an[IllegalStateException] should be thrownBy merge( + List(0L, 4L, 8L, 9L, 11L), + List(1L, 3L, 6L, 10L, 13L, 15L), + List(2L, 5L, 7L, 12L)) + } + + "fail on sequence regression in a single stream" in assertAllStagesStopped { + an[IllegalStateException] should be thrownBy merge( + List(0L, 4L, 8L, 7L, 9L, 11L), + List(1L, 3L, 6L, 10L, 13L), + List(2L, 5L, 7L, 12L)) + } + + commonTests() + } + +} diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/Graph.scala b/akka-stream/src/main/scala/akka/stream/javadsl/Graph.scala index a45990f27b..b123a543b2 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/Graph.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/Graph.scala @@ -523,6 +523,70 @@ object Concat { } +/** + * Takes multiple streams whose elements in aggregate have a defined linear + * sequence with difference 1, starting at 0, and outputs a single stream + * containing these elements, in order. That is, given a set of input streams + * with combined elements *ek*: + * + * *e0*, *e1*, *e2*, ..., *en* + * + * This will output a stream ordered by *k*. + * + * The elements in the input streams must already be sorted according to the + * sequence. The input streams do not need to be linear, but the aggregate + * stream must be linear, no element *k* may be skipped or duplicated, either + * of these conditions will cause the stream to fail. + * + * The typical use case for this is to merge a partitioned stream back + * together while maintaining order. This can be achieved by first using + * `zipWithIndex` on the input stream, then partitioning using a + * [[Partition]] fanout, and then maintaining the index through the processing + * of each partition before bringing together with this stage. + * + * '''Emits when''' one of the upstreams has the next expected element in the + * sequence available. + * + * '''Backpressures when''' downstream backpressures + * + * '''Completes when''' all upstreams complete + * + * '''Cancels when''' downstream cancels + */ +object MergeSequence { + + /** + * Create a new anonymous `MergeSequence` operator with two input ports. + * + * @param extractSequence The function to extract the sequence from an element. + */ + def create[T](extractSequence: function.Function[T, Long]): Graph[UniformFanInShape[T, T], NotUsed] = + scaladsl.MergeSequence[T]()(extractSequence.apply) + + /** + * Create a new anonymous `MergeSequence` operator. + * + * @param inputCount The number of input streams. + * @param extractSequence The function to extract the sequence from an element. + */ + def create[T](inputCount: Int, extractSequence: function.Function[T, Long]): Graph[UniformFanInShape[T, T], NotUsed] = + scaladsl.MergeSequence[T](inputCount)(extractSequence.apply) + + /** + * Create a new anonymous `Concat` operator with the specified input types. + * + * @param clazz a type hint for this method + * @param inputCount The number of input streams. + * @param extractSequence The function to extract the sequence from an element. + */ + def create[T]( + @unused clazz: Class[T], + inputCount: Int, + extractSequence: function.Function[T, Long]): Graph[UniformFanInShape[T, T], NotUsed] = + create(inputCount, extractSequence) + +} + object GraphDSL extends GraphCreate { /** diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala index 232db8a9f4..bb4967c78e 100755 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala @@ -1395,6 +1395,137 @@ private[stream] final class OrElse[T] extends GraphStage[UniformFanInShape[T, T] } +object MergeSequence { + + private case class Pushed[T](in: Inlet[T], sequence: Long, elem: T) + + private implicit def ordering[T]: Ordering[Pushed[T]] = Ordering.by[Pushed[T], Long](_.sequence).reverse + + /** @see [[MergeSequence]] **/ + def apply[T](inputPorts: Int = 2)(extractSequence: T => Long): Graph[UniformFanInShape[T, T], NotUsed] = + GraphStages.withDetachedInputs(new MergeSequence[T](inputPorts)(extractSequence)) +} + +/** + * Takes multiple streams whose elements in aggregate have a defined linear + * sequence with difference 1, starting at 0, and outputs a single stream + * containing these elements, in order. That is, given a set of input streams + * with combined elements *ek*: + * + * *e0*, *e1*, *e2*, ..., *en* + * + * This will output a stream ordered by *k*. + * + * The elements in the input streams must already be sorted according to the + * sequence. The input streams do not need to be linear, but the aggregate + * stream must be linear, no element *k* may be skipped or duplicated, either + * of these conditions will cause the stream to fail. + * + * The typical use case for this is to merge a partitioned stream back + * together while maintaining order. This can be achieved by first using + * `zipWithIndex` on the input stream, then partitioning using a + * [[Partition]] fanout, and then maintaining the index through the processing + * of each partition before bringing together with this stage. + * + * '''Emits when''' one of the upstreams has the next expected element in the + * sequence available. + * + * '''Backpressures when''' downstream backpressures + * + * '''Completes when''' all upstreams complete + * + * '''Cancels when''' downstream cancels + */ +final class MergeSequence[T](val inputPorts: Int)(extractSequence: T => Long) + extends GraphStage[UniformFanInShape[T, T]] { + require(inputPorts > 1, "A MergeSequence must have more than 1 input ports") + private val in: IndexedSeq[Inlet[T]] = Vector.tabulate(inputPorts)(i => Inlet[T]("MergeSequence.in" + i)) + private val out: Outlet[T] = Outlet("MergeSequence.out") + override val shape: UniformFanInShape[T, T] = UniformFanInShape(out, in: _*) + + import MergeSequence._ + + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) with OutHandler { + private var nextSequence = 0L + private val available = mutable.PriorityQueue.empty[Pushed[T]] + private var complete = 0 + + setHandler(out, this) + + in.zipWithIndex.foreach { + case (inPort, idx) => + setHandler( + inPort, + new InHandler { + override def onPush(): Unit = { + val elem = grab(inPort) + val sequence = extractSequence(elem) + if (sequence < nextSequence) { + failStage( + new IllegalStateException(s"Sequence regression from $nextSequence to $sequence on port $idx")) + } else if (sequence == nextSequence && isAvailable(out)) { + push(out, elem) + tryPull(inPort) + nextSequence += 1 + } else { + available.enqueue(Pushed(inPort, sequence, elem)) + detectMissedSequence() + } + } + + override def onUpstreamFinish(): Unit = { + complete += 1 + if (complete == inputPorts && available.isEmpty) { + completeStage() + } else { + detectMissedSequence() + } + } + }) + } + + def onPull(): Unit = + if (available.nonEmpty && available.head.sequence == nextSequence) { + val pushed = available.dequeue() + push(out, pushed.elem) + if (complete == inputPorts && available.isEmpty) { + completeStage() + } else { + if (available.nonEmpty && available.head.sequence == nextSequence) { + failStage( + new IllegalStateException( + s"Duplicate sequence $nextSequence on ports ${pushed.in} and ${available.head.in}")) + } + tryPull(pushed.in) + nextSequence += 1 + } + } else { + detectMissedSequence() + } + + private def detectMissedSequence(): Unit = + // Cheap to calculate, but doesn't give the right answer, because there might be input ports + // that are both complete and still have one last buffered element. + if (isAvailable(out) && available.size + complete >= inputPorts) { + // So in the event that this was true we count the number of ports that we have elements buffered for that + // are not yet closed, and add that to the complete ones, to see if we're in a dead lock. + if (available.count(pushed => !isClosed(pushed.in)) + complete == inputPorts) { + failStage( + new IllegalStateException( + s"Expected sequence $nextSequence, but all input ports have pushed or are complete, " + + "but none have pushed the next sequence number. Pushed sequences: " + + available.toVector.map(p => s"${p.in}: ${p.sequence}").mkString(", "))) + } + } + + override def preStart(): Unit = + in.foreach(pull) + } + + override def toString: String = s"MergeSequence($inputPorts)" +} + object GraphDSL extends GraphApply { /** diff --git a/project/StreamOperatorsIndexGenerator.scala b/project/StreamOperatorsIndexGenerator.scala index b84d72f2fe..bcd4904f90 100644 --- a/project/StreamOperatorsIndexGenerator.scala +++ b/project/StreamOperatorsIndexGenerator.scala @@ -171,6 +171,7 @@ object StreamOperatorsIndexGenerator extends AutoPlugin { .map(method => (element, method)) } ++ List( (noElement, "Partition"), + (noElement, "MergeSequence"), (noElement, "Broadcast"), (noElement, "Balance"), (noElement, "Unzip"),