diff --git a/akka-bench-jmh/src/main/scala/akka/remote/artery/LatchSink.scala b/akka-bench-jmh/src/main/scala/akka/remote/artery/LatchSink.scala index 5953b352bc..4ac234e9cd 100644 --- a/akka-bench-jmh/src/main/scala/akka/remote/artery/LatchSink.scala +++ b/akka-bench-jmh/src/main/scala/akka/remote/artery/LatchSink.scala @@ -24,6 +24,11 @@ class LatchSink(countDownAfter: Int, latch: CountDownLatch) extends GraphStage[S override def preStart(): Unit = pull(in) + override def onUpstreamFailure(ex: Throwable): Unit = { + println(ex.getMessage) + ex.printStackTrace() + } + override def onPush(): Unit = { n += 1 if (n == countDownAfter) diff --git a/akka-bench-jmh/src/main/scala/akka/stream/PartitionHubBenchmark.scala b/akka-bench-jmh/src/main/scala/akka/stream/PartitionHubBenchmark.scala new file mode 100644 index 0000000000..f8c1c7ead6 --- /dev/null +++ b/akka-bench-jmh/src/main/scala/akka/stream/PartitionHubBenchmark.scala @@ -0,0 +1,128 @@ +/** + * Copyright (C) 2014-2017 Lightbend Inc. + */ + +package akka.stream + +import java.util.concurrent.TimeUnit +import akka.NotUsed +import akka.actor.ActorSystem +import akka.stream.scaladsl._ +import com.typesafe.config.ConfigFactory +import org.openjdk.jmh.annotations._ +import java.util.concurrent.Semaphore +import scala.util.Success +import akka.stream.impl.fusing.GraphStages +import org.reactivestreams._ +import scala.concurrent.Await +import scala.concurrent.duration._ +import akka.remote.artery.BenchTestSource +import java.util.concurrent.CountDownLatch +import akka.remote.artery.LatchSink +import akka.stream.impl.PhasedFusingActorMaterializer +import akka.testkit.TestProbe +import akka.stream.impl.StreamSupervisor +import akka.stream.scaladsl.PartitionHub +import akka.remote.artery.FixedSizePartitionHub + +object PartitionHubBenchmark { + final val OperationsPerInvocation = 100000 +} + +@State(Scope.Benchmark) +@OutputTimeUnit(TimeUnit.SECONDS) +@BenchmarkMode(Array(Mode.Throughput)) +class PartitionHubBenchmark { + 
import PartitionHubBenchmark._ + + val config = ConfigFactory.parseString( + """ + akka.actor.default-dispatcher { + executor = "fork-join-executor" + fork-join-executor { + parallelism-factor = 1 + } + } + """ + ) + + implicit val system = ActorSystem("PartitionHubBenchmark", config) + + var materializer: ActorMaterializer = _ + + @Param(Array("2", "5", "10", "20", "30")) + var NumberOfStreams = 0 + + @Param(Array("256")) + var BufferSize = 0 + + var testSource: Source[java.lang.Integer, NotUsed] = _ + + @Setup + def setup(): Unit = { + val settings = ActorMaterializerSettings(system) + materializer = ActorMaterializer(settings) + + testSource = Source.fromGraph(new BenchTestSource(OperationsPerInvocation)) + } + + @TearDown + def shutdown(): Unit = { + Await.result(system.terminate(), 5.seconds) + } + + @Benchmark + @OperationsPerInvocation(OperationsPerInvocation) + def partition(): Unit = { + val N = OperationsPerInvocation + val latch = new CountDownLatch(NumberOfStreams) + + val source = testSource + .runWith(PartitionHub.sink[java.lang.Integer]( + (size, elem) => elem.intValue % NumberOfStreams, + startAfterNrOfConsumers = NumberOfStreams, bufferSize = BufferSize + ))(materializer) + + for (_ <- 0 until NumberOfStreams) + source.runWith(new LatchSink(N / NumberOfStreams, latch))(materializer) + + if (!latch.await(30, TimeUnit.SECONDS)) { + dumpMaterializer() + throw new RuntimeException("Latch didn't complete in time") + } + } + + // @Benchmark + // @OperationsPerInvocation(OperationsPerInvocation) + def arteryLanes(): Unit = { + val N = OperationsPerInvocation + val latch = new CountDownLatch(NumberOfStreams) + + val source = testSource + .runWith( + Sink.fromGraph(new FixedSizePartitionHub( + _.intValue % NumberOfStreams, + lanes = NumberOfStreams, bufferSize = BufferSize + )) + )(materializer) + + for (_ <- 0 until NumberOfStreams) + source.runWith(new LatchSink(N / NumberOfStreams, latch))(materializer) + + if (!latch.await(30, TimeUnit.SECONDS)) { + 
dumpMaterializer() + throw new RuntimeException("Latch didn't complete in time") + } + } + + private def dumpMaterializer(): Unit = { + materializer match { + case impl: PhasedFusingActorMaterializer ⇒ + val probe = TestProbe()(system) + impl.supervisor.tell(StreamSupervisor.GetChildren, probe.ref) + val children = probe.expectMsgType[StreamSupervisor.Children].children + children.foreach(_ ! StreamSupervisor.PrintDebugDump) + } + } + +} diff --git a/akka-docs/src/main/paradox/scala/stream/stream-dynamic.md b/akka-docs/src/main/paradox/scala/stream/stream-dynamic.md index 026399ff94..f408514ccd 100644 --- a/akka-docs/src/main/paradox/scala/stream/stream-dynamic.md +++ b/akka-docs/src/main/paradox/scala/stream/stream-dynamic.md @@ -74,7 +74,7 @@ before any materialization takes place. @@@ -## Dynamic fan-in and fan-out with MergeHub and BroadcastHub +## Dynamic fan-in and fan-out with MergeHub, BroadcastHub and PartitionHub There are many cases when consumers or producers of a certain service (represented as a Sink, Source, or possibly Flow) are dynamic and not known in advance. The Graph DSL does not allow to represent this, all connections of the graph @@ -169,4 +169,71 @@ Scala : @@snip [HubsDocSpec.scala]($code$/scala/docs/stream/HubsDocSpec.scala) { #pub-sub-4 } Java -: @@snip [HubDocTest.java]($code$/java/jdocs/stream/HubDocTest.java) { #pub-sub-4 } \ No newline at end of file +: @@snip [HubDocTest.java]($code$/java/jdocs/stream/HubDocTest.java) { #pub-sub-4 } + +### Using the PartitionHub + +**This is a @ref:[may change](../common/may-change.md) feature** + +A `PartitionHub` can be used to route elements from a common producer to a dynamic set of consumers. +The selection of the consumer is done with a function. Each element can be routed to only one consumer. + +The rate of the producer will be automatically adapted to the slowest consumer. In this case, the hub is a `Sink` +to which the single producer must be attached first. 
Consumers can only be attached once the `Sink` has +been materialized (i.e. the producer has been started). One example of using the `PartitionHub`: + +Scala +: @@snip [HubsDocSpec.scala]($code$/scala/docs/stream/HubsDocSpec.scala) { #partition-hub } + +Java +: @@snip [HubDocTest.java]($code$/java/jdocs/stream/HubDocTest.java) { #partition-hub } + +The `partitioner` function takes two parameters; the first is the number of active consumers and the second +is the stream element. The function should return the index of the selected consumer for the given element, +i.e. an `int` greater than or equal to 0 and less than the number of consumers. + +The resulting `Source` can be materialized any number of times, each materialization effectively attaching +a new consumer. If there are no consumers attached to this hub then it will not drop any elements but instead +backpressure the upstream producer until consumers arrive. This behavior can be tweaked by using the combinators +`.buffer` for example with a drop strategy, or just attaching a consumer that drops all messages. If there +are no other consumers, this will ensure that the producer is kept drained (dropping all elements) and once a new +consumer arrives and messages are routed to the new consumer it will adaptively slow down, ensuring no more messages +are dropped. + +It is possible to define how many initial consumers are required before the hub starts emitting any messages +to the attached consumers. While not enough consumers have been attached, messages are buffered and when the +buffer is full the upstream producer is backpressured. No messages are dropped. + +The above example illustrates a stateless partition function. For more advanced stateful routing the @java[`ofStateful`] +@scala[`statefulSink`] can be used. 
Here is an example of a stateful round-robin function: + +Scala +: @@snip [HubsDocSpec.scala]($code$/scala/docs/stream/HubsDocSpec.scala) { #partition-hub-stateful } + +Java +: @@snip [HubDocTest.java]($code$/java/jdocs/stream/HubDocTest.java) { #partition-hub-stateful } + +Note that it is a factory of a function to be able to hold stateful variables that are +unique for each materialization. @java[In this example the `partitioner` function is implemented as a class to +be able to hold the mutable variable. A new instance of `RoundRobin` is created for each materialization of the hub.] + +@@@ div { .group-java } +@@snip [HubDocTest.java]($code$/java/jdocs/stream/HubDocTest.java) { #partition-hub-stateful-function } +@@@ + +The function takes two parameters; the first is information about active consumers, including an array of +consumer identifiers and the second is the stream element. The function should return the selected consumer +identifier for the given element. The function will never be called when there are no active consumers, i.e. +there is always at least one element in the array of identifiers. + +Another interesting type of routing is to prefer routing to the fastest consumers. The `ConsumerInfo` +has an accessor `queueSize` that is the approximate number of buffered elements for a consumer. +A larger value than for other consumers could be an indication that the consumer is slow. +Note that this is a moving target since the elements are consumed concurrently. 
Here is an example of +a hub that routes to the consumer with least buffered elements: + +Scala +: @@snip [HubsDocSpec.scala]($code$/scala/docs/stream/HubsDocSpec.scala) { #partition-hub-fastest } + +Java +: @@snip [HubDocTest.java]($code$/java/jdocs/stream/HubDocTest.java) { #partition-hub-fastest } diff --git a/akka-docs/src/test/java/jdocs/stream/HubDocTest.java b/akka-docs/src/test/java/jdocs/stream/HubDocTest.java index 9e6ea667b6..63746772d2 100644 --- a/akka-docs/src/test/java/jdocs/stream/HubDocTest.java +++ b/akka-docs/src/test/java/jdocs/stream/HubDocTest.java @@ -11,17 +11,25 @@ import akka.japi.Pair; import akka.stream.ActorMaterializer; import akka.stream.KillSwitches; import akka.stream.Materializer; +import akka.stream.ThrottleMode; import akka.stream.UniqueKillSwitch; import akka.stream.javadsl.*; +import akka.stream.javadsl.PartitionHub.ConsumerInfo; + import jdocs.AbstractJavaTest; import akka.testkit.javadsl.TestKit; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; +import scala.concurrent.duration.Duration; import scala.concurrent.duration.FiniteDuration; +import java.util.List; import java.util.concurrent.CompletionStage; import java.util.concurrent.TimeUnit; +import java.util.function.BiFunction; +import java.util.function.Supplier; +import java.util.function.ToLongBiFunction; public class HubDocTest extends AbstractJavaTest { @@ -137,4 +145,136 @@ public class HubDocTest extends AbstractJavaTest { killSwitch.shutdown(); //#pub-sub-4 } + + @Test + public void dynamicPartition() { + // Used to be able to clean up the running stream + ActorMaterializer materializer = ActorMaterializer.create(system); + + //#partition-hub + // A simple producer that publishes a new "message-n" every second + Source producer = Source.tick( + FiniteDuration.create(1, TimeUnit.SECONDS), + FiniteDuration.create(1, TimeUnit.SECONDS), + "message" + ).zipWith(Source.range(0, 100), (a, b) -> a + "-" + b); + + // Attach a PartitionHub Sink 
to the producer. This will materialize to a + // corresponding Source. + // (We need to use toMat and Keep.right since by default the materialized + // value to the left is used) + RunnableGraph> runnableGraph = + producer.toMat(PartitionHub.of( + String.class, + (size, elem) -> Math.abs(elem.hashCode()) % size, + 2, 256), Keep.right()); + + // By running/materializing the producer, we get back a Source, which + // gives us access to the elements published by the producer. + Source fromProducer = runnableGraph.run(materializer); + + // Print out messages from the producer in two independent consumers + fromProducer.runForeach(msg -> System.out.println("consumer1: " + msg), materializer); + fromProducer.runForeach(msg -> System.out.println("consumer2: " + msg), materializer); + //#partition-hub + + // Cleanup + materializer.shutdown(); + } + + //#partition-hub-stateful-function + // Using a class since variable must otherwise be final. + // New instance is created for each materialization of the PartitionHub. + static class RoundRobin implements ToLongBiFunction { + + private long i = -1; + + @Override + public long applyAsLong(ConsumerInfo info, T elem) { + i++; + return info.consumerIdByIdx((int) (i % info.size())); + } + } + //#partition-hub-stateful-function + + @Test + public void dynamicStatefulPartition() { + // Used to be able to clean up the running stream + ActorMaterializer materializer = ActorMaterializer.create(system); + + //#partition-hub-stateful + // A simple producer that publishes a new "message-n" every second + Source producer = Source.tick( + FiniteDuration.create(1, TimeUnit.SECONDS), + FiniteDuration.create(1, TimeUnit.SECONDS), + "message" + ).zipWith(Source.range(0, 100), (a, b) -> a + "-" + b); + + // Attach a PartitionHub Sink to the producer. This will materialize to a + // corresponding Source. 
+ // (We need to use toMat and Keep.right since by default the materialized + // value to the left is used) + RunnableGraph> runnableGraph = + producer.toMat( + PartitionHub.ofStateful( + String.class, + () -> new RoundRobin(), + 2, + 256), + Keep.right()); + + // By running/materializing the producer, we get back a Source, which + // gives us access to the elements published by the producer. + Source fromProducer = runnableGraph.run(materializer); + + // Print out messages from the producer in two independent consumers + fromProducer.runForeach(msg -> System.out.println("consumer1: " + msg), materializer); + fromProducer.runForeach(msg -> System.out.println("consumer2: " + msg), materializer); + //#partition-hub-stateful + + // Cleanup + materializer.shutdown(); + } + + @Test + public void dynamicFastestPartition() { + // Used to be able to clean up the running stream + ActorMaterializer materializer = ActorMaterializer.create(system); + + //#partition-hub-fastest + Source producer = Source.range(0, 100); + + // ConsumerInfo.queueSize is the approximate number of buffered elements for a consumer. + // Note that this is a moving target since the elements are consumed concurrently. 
+ RunnableGraph> runnableGraph = + producer.toMat( + PartitionHub.ofStateful( + Integer.class, + () -> (info, elem) -> { + final List ids = info.getConsumerIds(); + int minValue = info.queueSize(0); + long fastest = info.consumerIdByIdx(0); + for (int i = 1; i < ids.size(); i++) { + int value = info.queueSize(i); + if (value < minValue) { + minValue = value; + fastest = info.consumerIdByIdx(i); + } + } + return fastest; + }, + 2, + 8), + Keep.right()); + + Source fromProducer = runnableGraph.run(materializer); + + fromProducer.runForeach(msg -> System.out.println("consumer1: " + msg), materializer); + fromProducer.throttle(10, Duration.create(100, TimeUnit.MILLISECONDS), 10, ThrottleMode.shaping()) + .runForeach(msg -> System.out.println("consumer2: " + msg), materializer); + //#partition-hub-fastest + + // Cleanup + materializer.shutdown(); + } } diff --git a/akka-docs/src/test/scala/docs/stream/HubsDocSpec.scala b/akka-docs/src/test/scala/docs/stream/HubsDocSpec.scala index 2d234c66bc..892676e68c 100644 --- a/akka-docs/src/test/scala/docs/stream/HubsDocSpec.scala +++ b/akka-docs/src/test/scala/docs/stream/HubsDocSpec.scala @@ -10,6 +10,7 @@ import akka.testkit.AkkaSpec import docs.CompileOnlySpec import scala.concurrent.duration._ +import akka.stream.ThrottleMode class HubsDocSpec extends AkkaSpec with CompileOnlySpec { implicit val materializer = ActorMaterializer() @@ -104,6 +105,86 @@ class HubsDocSpec extends AkkaSpec with CompileOnlySpec { //#pub-sub-4 } + "demonstrate creating a dynamic partition hub" in compileOnlySpec { + //#partition-hub + // A simple producer that publishes a new "message-" every second + val producer = Source.tick(1.second, 1.second, "message") + .zipWith(Source(1 to 100))((a, b) => s"$a-$b") + + // Attach a PartitionHub Sink to the producer. This will materialize to a + // corresponding Source. 
+ // (We need to use toMat and Keep.right since by default the materialized + // value to the left is used) + val runnableGraph: RunnableGraph[Source[String, NotUsed]] = + producer.toMat(PartitionHub.sink( + (size, elem) => math.abs(elem.hashCode) % size, + startAfterNrOfConsumers = 2, bufferSize = 256))(Keep.right) + + // By running/materializing the producer, we get back a Source, which + // gives us access to the elements published by the producer. + val fromProducer: Source[String, NotUsed] = runnableGraph.run() + + // Print out messages from the producer in two independent consumers + fromProducer.runForeach(msg => println("consumer1: " + msg)) + fromProducer.runForeach(msg => println("consumer2: " + msg)) + //#partition-hub + } + + "demonstrate creating a dynamic stateful partition hub" in compileOnlySpec { + //#partition-hub-stateful + // A simple producer that publishes a new "message-" every second + val producer = Source.tick(1.second, 1.second, "message") + .zipWith(Source(1 to 100))((a, b) => s"$a-$b") + + // New instance of the partitioner function and its state is created + // for each materialization of the PartitionHub. + def roundRobin(): (PartitionHub.ConsumerInfo, String) ⇒ Long = { + var i = -1L + + (info, elem) => { + i += 1 + info.consumerIdByIdx((i % info.size).toInt) + } + } + + // Attach a PartitionHub Sink to the producer. This will materialize to a + // corresponding Source. + // (We need to use toMat and Keep.right since by default the materialized + // value to the left is used) + val runnableGraph: RunnableGraph[Source[String, NotUsed]] = + producer.toMat(PartitionHub.statefulSink( + () => roundRobin(), + startAfterNrOfConsumers = 2, bufferSize = 256))(Keep.right) + + // By running/materializing the producer, we get back a Source, which + // gives us access to the elements published by the producer. 
+ val fromProducer: Source[String, NotUsed] = runnableGraph.run() + + // Print out messages from the producer in two independent consumers + fromProducer.runForeach(msg => println("consumer1: " + msg)) + fromProducer.runForeach(msg => println("consumer2: " + msg)) + //#partition-hub-stateful + } + + "demonstrate creating a dynamic partition hub routing to fastest consumer" in compileOnlySpec { + //#partition-hub-fastest + val producer = Source(0 until 100) + + // ConsumerInfo.queueSize is the approximate number of buffered elements for a consumer. + // Note that this is a moving target since the elements are consumed concurrently. + val runnableGraph: RunnableGraph[Source[Int, NotUsed]] = + producer.toMat(PartitionHub.statefulSink( + () => (info, elem) ⇒ info.consumerIds.minBy(id ⇒ info.queueSize(id)), + startAfterNrOfConsumers = 2, bufferSize = 16))(Keep.right) + + val fromProducer: Source[Int, NotUsed] = runnableGraph.run() + + fromProducer.runForeach(msg => println("consumer1: " + msg)) + fromProducer.throttle(10, 100.millis, 10, ThrottleMode.Shaping) + .runForeach(msg => println("consumer2: " + msg)) + //#partition-hub-fastest + } + } } diff --git a/akka-remote/src/main/mima-filters/2.5.3.backwards.excludes b/akka-remote/src/main/mima-filters/2.5.3.backwards.excludes new file mode 100644 index 0000000000..c82dbcace6 --- /dev/null +++ b/akka-remote/src/main/mima-filters/2.5.3.backwards.excludes @@ -0,0 +1,2 @@ +#21880 PartitionHub in Artery +ProblemFilters.exclude[DirectMissingMethodProblem]("akka.remote.artery.ArterySettings#Advanced.InboundBroadcastHubBufferSize") diff --git a/akka-remote/src/main/scala/akka/remote/artery/ArterySettings.scala b/akka-remote/src/main/scala/akka/remote/artery/ArterySettings.scala index 3314156524..5ff4b6dd3a 100644 --- a/akka-remote/src/main/scala/akka/remote/artery/ArterySettings.scala +++ b/akka-remote/src/main/scala/akka/remote/artery/ArterySettings.scala @@ -142,7 +142,7 @@ private[akka] final class ArterySettings private 
(config: Config) { .requiring(_ >= 32 * 1024, "maximum-frame-size must be greater than or equal to 32 KiB") final val BufferPoolSize: Int = getInt("buffer-pool-size") .requiring(_ > 0, "buffer-pool-size must be greater than 0") - final val InboundBroadcastHubBufferSize = BufferPoolSize / 2 + final val InboundHubBufferSize = BufferPoolSize / 2 final val MaximumLargeFrameSize: Int = math.min(getBytes("maximum-large-frame-size"), Int.MaxValue).toInt .requiring(_ >= 32 * 1024, "maximum-large-frame-size must be greater than or equal to 32 KiB") final val LargeBufferPoolSize: Int = getInt("large-buffer-pool-size") diff --git a/akka-remote/src/main/scala/akka/remote/artery/ArteryTransport.scala b/akka-remote/src/main/scala/akka/remote/artery/ArteryTransport.scala index 3631212bb6..51e2d3af27 100644 --- a/akka-remote/src/main/scala/akka/remote/artery/ArteryTransport.scala +++ b/akka-remote/src/main/scala/akka/remote/artery/ArteryTransport.scala @@ -729,34 +729,29 @@ private[remote] class ArteryTransport(_system: ExtendedActorSystem, _provider: R } else { val hubKillSwitch = KillSwitches.shared("hubKillSwitch") - val source: Source[(OptionVal[InternalActorRef], InboundEnvelope), (ResourceLifecycle, InboundCompressionAccess)] = + val source: Source[InboundEnvelope, (ResourceLifecycle, InboundCompressionAccess)] = aeronSource(ordinaryStreamId, envelopeBufferPool) .via(hubKillSwitch.flow) .viaMat(inboundFlow(settings, _inboundCompressions))(Keep.both) - .map(env ⇒ (env.recipient, env)) - - val (resourceLife, compressionAccess, broadcastHub) = - source - .toMat(BroadcastHub.sink(bufferSize = settings.Advanced.InboundBroadcastHubBufferSize))({ case ((a, b), c) ⇒ (a, b, c) }) - .run()(materializer) // select lane based on destination, to preserve message order - def shouldUseLane(recipient: OptionVal[ActorRef], targetLane: Int): Boolean = - recipient match { - case OptionVal.Some(r) ⇒ math.abs(r.path.uid) % inboundLanes == targetLane - case OptionVal.None ⇒ 0 == targetLane + val 
partitioner: InboundEnvelope ⇒ Int = env ⇒ { + env.recipient match { + case OptionVal.Some(r) ⇒ math.abs(r.path.uid) % inboundLanes + case OptionVal.None ⇒ 0 } + } + + val (resourceLife, compressionAccess, hub) = + source + .toMat(Sink.fromGraph(new FixedSizePartitionHub[InboundEnvelope](partitioner, inboundLanes, + settings.Advanced.InboundHubBufferSize)))({ case ((a, b), c) ⇒ (a, b, c) }) + .run()(materializer) val lane = inboundSink(envelopeBufferPool) val completedValues: Vector[Future[Done]] = - (0 until inboundLanes).map { laneId ⇒ - broadcastHub - // TODO replace filter with "PartitionHub" when that is implemented - // must use a tuple here because envelope is pooled and must only be read in the selected lane - // otherwise, the lane that actually processes it might have already released it. - .collect { case (recipient, env) if shouldUseLane(recipient, laneId) ⇒ env } - .toMat(lane)(Keep.right) - .run()(materializer) + (0 until inboundLanes).map { _ ⇒ + hub.toMat(lane)(Keep.right).run()(materializer) }(collection.breakOut) import system.dispatcher diff --git a/akka-remote/src/main/scala/akka/remote/artery/FixedSizePartitionHub.scala b/akka-remote/src/main/scala/akka/remote/artery/FixedSizePartitionHub.scala new file mode 100644 index 0000000000..7426b58db4 --- /dev/null +++ b/akka-remote/src/main/scala/akka/remote/artery/FixedSizePartitionHub.scala @@ -0,0 +1,73 @@ +/** + * Copyright (C) 2017 Lightbend Inc. 
+ */ +package akka.remote.artery + +import akka.annotation.InternalApi +import akka.stream.scaladsl.PartitionHub +import org.agrona.concurrent.OneToOneConcurrentArrayQueue +import java.util.concurrent.atomic.AtomicInteger +import org.agrona.concurrent.ManyToManyConcurrentArrayQueue + +/** + * INTERNAL API + */ +@InternalApi private[akka] class FixedSizePartitionHub[T]( + partitioner: T ⇒ Int, + lanes: Int, + bufferSize: Int) extends PartitionHub[T](() ⇒ (info, elem) ⇒ info.consumerIdByIdx(partitioner(elem)), lanes, bufferSize - 1) { + // -1 because of the Completed token + + override def createQueue(): PartitionHub.Internal.PartitionQueue = + new FixedSizePartitionQueue(lanes, bufferSize) + +} + +/** + * INTERNAL API + */ +@InternalApi private[akka] class FixedSizePartitionQueue(lanes: Int, capacity: Int) extends PartitionHub.Internal.PartitionQueue { + + private val queues = { + val arr = new Array[OneToOneConcurrentArrayQueue[AnyRef]](lanes) + var i = 0 + while (i < arr.length) { + arr(i) = new OneToOneConcurrentArrayQueue(capacity) + i += 1 + } + arr + } + + override def init(id: Long): Unit = () + + override def totalSize: Int = { + var sum = 0 + var i = 0 + while (i < lanes) { + sum += queues(i).size + i += 1 + } + sum + } + + override def size(id: Long): Int = + queues(id.toInt).size + + override def isEmpty(id: Long): Boolean = + queues(id.toInt).isEmpty + + override def nonEmpty(id: Long): Boolean = + !isEmpty(id) + + override def offer(id: Long, elem: Any): Unit = { + if (!queues(id.toInt).offer(elem.asInstanceOf[AnyRef])) + throw new IllegalStateException(s"queue is full, id [$id]") + } + + override def poll(id: Long): AnyRef = + queues(id.toInt).poll() + + override def remove(id: Long): Unit = + queues(id.toInt).clear() + +} diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/HubSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/HubSpec.scala index dece30419a..9897e3a2a4 100644 --- 
a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/HubSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/HubSpec.scala @@ -327,7 +327,7 @@ class HubSpec extends StreamSpec { downstream2.expectError(TE("Failed")) } - "properly singal completion to consumers arriving after producer finished" in assertAllStagesStopped { + "properly signal completion to consumers arriving after producer finished" in assertAllStagesStopped { val source = Source.empty[Int].runWith(BroadcastHub.sink(8)) // Wait enough so the Hub gets the completion. This is racy, but this is fine because both // cases should work in the end @@ -354,7 +354,7 @@ class HubSpec extends StreamSpec { sink2Probe.expectComplete() } - "properly singal error to consumers arriving after producer finished" in assertAllStagesStopped { + "properly signal error to consumers arriving after producer finished" in assertAllStagesStopped { val source = Source.failed(TE("Fail!")).runWith(BroadcastHub.sink(8)) // Wait enough so the Hub gets the completion. 
This is racy, but this is fine because both // cases should work in the end @@ -367,4 +367,241 @@ class HubSpec extends StreamSpec { } + "PartitionHub" must { + + "work in the happy case with one stream" in assertAllStagesStopped { + val source = Source(1 to 10).runWith(PartitionHub.sink((size, elem) ⇒ 0, startAfterNrOfConsumers = 0, bufferSize = 8)) + source.runWith(Sink.seq).futureValue should ===(1 to 10) + } + + "work in the happy case with two streams" in assertAllStagesStopped { + val source = Source(0 until 10).runWith(PartitionHub.sink((size, elem) ⇒ elem % size, startAfterNrOfConsumers = 2, bufferSize = 8)) + val result1 = source.runWith(Sink.seq) + // it should not start publishing until startAfterNrOfConsumers = 2 + Thread.sleep(20) + val result2 = source.runWith(Sink.seq) + result1.futureValue should ===(0 to 8 by 2) + result2.futureValue should ===(1 to 9 by 2) + } + + "be able to use as round-robin router" in assertAllStagesStopped { + val source = Source(0 until 10).runWith(PartitionHub.statefulSink(() ⇒ { + var n = 0L + + (info, elem) ⇒ { + n += 1 + info.consumerIdByIdx((n % info.size).toInt) + } + }, startAfterNrOfConsumers = 2, bufferSize = 8)) + val result1 = source.runWith(Sink.seq) + val result2 = source.runWith(Sink.seq) + result1.futureValue should ===(1 to 9 by 2) + result2.futureValue should ===(0 to 8 by 2) + } + + "be able to use as sticky session router" in assertAllStagesStopped { + val source = Source(List("usr-1", "usr-2", "usr-1", "usr-3")).runWith(PartitionHub.statefulSink(() ⇒ { + var sessions = Map.empty[String, Long] + var n = 0L + + (info, elem) ⇒ { + sessions.get(elem) match { + case Some(id) if info.consumerIds.exists(_ == id) ⇒ id + case _ ⇒ + n += 1 + val id = info.consumerIdByIdx((n % info.size).toInt) + sessions = sessions.updated(elem, id) + id + } + } + }, startAfterNrOfConsumers = 2, bufferSize = 8)) + val result1 = source.runWith(Sink.seq) + val result2 = source.runWith(Sink.seq) + result1.futureValue should 
===(List("usr-2")) + result2.futureValue should ===(List("usr-1", "usr-1", "usr-3")) + } + + "be able to use as fastest consumer router" in assertAllStagesStopped { + val source = Source(0 until 1000).runWith(PartitionHub.statefulSink( + () ⇒ (info, elem) ⇒ info.consumerIds.toVector.minBy(id ⇒ info.queueSize(id)), + startAfterNrOfConsumers = 2, bufferSize = 4)) + val result1 = source.runWith(Sink.seq) + val result2 = source.throttle(10, 100.millis, 10, ThrottleMode.Shaping).runWith(Sink.seq) + + result1.futureValue.size should be > (result2.futureValue.size) + } + + "route evenly" in assertAllStagesStopped { + val (testSource, hub) = TestSource.probe[Int].toMat( + PartitionHub.sink((size, elem) ⇒ elem % size, startAfterNrOfConsumers = 2, bufferSize = 8))(Keep.both).run() + val probe0 = hub.runWith(TestSink.probe[Int]) + val probe1 = hub.runWith(TestSink.probe[Int]) + probe0.request(3) + probe1.request(10) + testSource.sendNext(0) + probe0.expectNext(0) + testSource.sendNext(1) + probe1.expectNext(1) + + testSource.sendNext(2) + testSource.sendNext(3) + testSource.sendNext(4) + probe0.expectNext(2) + probe1.expectNext(3) + probe0.expectNext(4) + + // probe1 has not requested more + testSource.sendNext(5) + testSource.sendNext(6) + testSource.sendNext(7) + probe1.expectNext(5) + probe1.expectNext(7) + probe0.expectNoMsg(10.millis) + probe0.request(10) + probe0.expectNext(6) + + testSource.sendComplete() + probe0.expectComplete() + probe1.expectComplete() + } + + "route unevenly" in assertAllStagesStopped { + val (testSource, hub) = TestSource.probe[Int].toMat( + PartitionHub.sink((size, elem) ⇒ (elem % 3) % 2, startAfterNrOfConsumers = 2, bufferSize = 8))(Keep.both).run() + val probe0 = hub.runWith(TestSink.probe[Int]) + val probe1 = hub.runWith(TestSink.probe[Int]) + + // (_ % 3) % 2 + // 0 => 0 + // 1 => 1 + // 2 => 0 + // 3 => 0 + // 4 => 1 + + probe0.request(10) + probe1.request(10) + testSource.sendNext(0) + probe0.expectNext(0) + testSource.sendNext(1) + 
probe1.expectNext(1) + testSource.sendNext(2) + probe0.expectNext(2) + testSource.sendNext(3) + probe0.expectNext(3) + testSource.sendNext(4) + probe1.expectNext(4) + + testSource.sendComplete() + probe0.expectComplete() + probe1.expectComplete() + } + + "backpressure" in assertAllStagesStopped { + val (testSource, hub) = TestSource.probe[Int].toMat( + PartitionHub.sink((size, elem) ⇒ 0, startAfterNrOfConsumers = 2, bufferSize = 4))(Keep.both).run() + val probe0 = hub.runWith(TestSink.probe[Int]) + val probe1 = hub.runWith(TestSink.probe[Int]) + probe0.request(10) + probe1.request(10) + testSource.sendNext(0) + probe0.expectNext(0) + testSource.sendNext(1) + probe0.expectNext(1) + testSource.sendNext(2) + probe0.expectNext(2) + testSource.sendNext(3) + probe0.expectNext(3) + testSource.sendNext(4) + probe0.expectNext(4) + + testSource.sendComplete() + probe0.expectComplete() + probe1.expectComplete() + } + + "ensure that from two different speed consumers the slower controls the rate" in assertAllStagesStopped { + val (firstElem, source) = Source.maybe[Int].concat(Source(1 until 20)).toMat( + PartitionHub.sink((size, elem) ⇒ elem % size, startAfterNrOfConsumers = 2, bufferSize = 1))(Keep.both).run() + + val f1 = source.throttle(1, 10.millis, 1, ThrottleMode.shaping).runWith(Sink.seq) + // Second cannot be overwhelmed since the first one throttles the overall rate, and second allows a higher rate + val f2 = source.throttle(10, 10.millis, 8, ThrottleMode.enforcing).runWith(Sink.seq) + + // Ensure subscription of Sinks. This is racy but there is no event we can hook into here. 
+ Thread.sleep(100) + firstElem.success(Some(0)) + f1.futureValue should ===(0 to 18 by 2) + f2.futureValue should ===(1 to 19 by 2) + + } + + "properly signal error to consumers" in assertAllStagesStopped { + val upstream = TestPublisher.probe[Int]() + val source = Source.fromPublisher(upstream).runWith( + PartitionHub.sink((size, elem) ⇒ elem % size, startAfterNrOfConsumers = 2, bufferSize = 8)) + + val downstream1 = TestSubscriber.probe[Int]() + source.runWith(Sink.fromSubscriber(downstream1)) + val downstream2 = TestSubscriber.probe[Int]() + source.runWith(Sink.fromSubscriber(downstream2)) + + downstream1.request(4) + downstream2.request(8) + + (0 until 16) foreach (upstream.sendNext(_)) + + downstream1.expectNext(0, 2, 4, 6) + downstream2.expectNext(1, 3, 5, 7, 9, 11, 13, 15) + + downstream1.expectNoMsg(100.millis) + downstream2.expectNoMsg(100.millis) + + upstream.sendError(TE("Failed")) + + downstream1.expectError(TE("Failed")) + downstream2.expectError(TE("Failed")) + } + + "properly signal completion to consumers arriving after producer finished" in assertAllStagesStopped { + val source = Source.empty[Int].runWith(PartitionHub.sink((size, elem) ⇒ elem % size, startAfterNrOfConsumers = 0)) + // Wait enough so the Hub gets the completion. This is racy, but this is fine because both + // cases should work in the end + Thread.sleep(10) + + source.runWith(Sink.seq).futureValue should ===(Nil) + } + + "remember completion for materialisations after completion" in { + + val (sourceProbe, source) = TestSource.probe[Unit].toMat( + PartitionHub.sink((size, elem) ⇒ 0, startAfterNrOfConsumers = 0))(Keep.both).run() + val sinkProbe = source.runWith(TestSink.probe[Unit]) + + sourceProbe.sendComplete() + + sinkProbe.request(1) + sinkProbe.expectComplete() + + // Materialize a second time. There was a race here, where we managed to enqueue our Source registration just + // immediately before the Hub shut down. 
+ val sink2Probe = source.runWith(TestSink.probe[Unit]) + + sink2Probe.request(1) + sink2Probe.expectComplete() + } + + "properly signal error to consumers arriving after producer finished" in assertAllStagesStopped { + val source = Source.failed[Int](TE("Fail!")).runWith( + PartitionHub.sink((size, elem) ⇒ 0, startAfterNrOfConsumers = 0)) + // Wait enough so the Hub gets the failure. This is racy, but this is fine because both + // cases should work in the end + Thread.sleep(10) + + a[TE] shouldBe thrownBy { + Await.result(source.runWith(Sink.seq), 3.seconds) + } + } + + } + } diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala index cfb812754c..dd279a3da5 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala @@ -635,7 +635,7 @@ import scala.util.control.NonFatal interpreter.connections.foreach { connection ⇒ builder .append(" ") - .append(connection.toString) + .append(if (connection == null) "null" else connection.toString) .append(",\n") } builder.setLength(builder.length - 2) diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala index 0216e90cc1..442219591a 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala @@ -641,7 +641,7 @@ import scala.util.control.NonFatal } val logicIndexes = logics.zipWithIndex.map { case (stage, idx) ⇒ stage → idx }.toMap - for (connection ← connections) { + for (connection ← connections if connection != null) { val inName = "N" + logicIndexes(connection.inOwner) val outName = "N" + logicIndexes(connection.outOwner) diff --git 
a/akka-stream/src/main/scala/akka/stream/javadsl/Hub.scala b/akka-stream/src/main/scala/akka/stream/javadsl/Hub.scala index 12be1fae17..c183381156 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/Hub.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/Hub.scala @@ -4,6 +4,10 @@ package akka.stream.javadsl import akka.NotUsed +import java.util.function.{ BiFunction, Supplier, ToLongBiFunction } + +import akka.annotation.DoNotInherit +import akka.annotation.ApiMayChange /** * A MergeHub is a special streaming hub that is able to collect streamed elements from a dynamic set of @@ -91,3 +95,129 @@ object BroadcastHub { def of[T](clazz: Class[T]): Sink[T, Source[T, NotUsed]] = of(clazz, 256) } + +/** + * A `PartitionHub` is a special streaming hub that is able to route streamed elements to a dynamic set of consumers. + * It consists of two parts, a [[Sink]] and a [[Source]]. The [[Sink]] routes elements from a producer to the + * actually live consumers it has. The selection of consumer is done with a function. Each element can be routed to + * only one consumer. Once the producer has been materialized, the [[Sink]] it feeds into returns a + * materialized value which is the corresponding [[Source]]. This [[Source]] can be materialized an arbitrary number + * of times, where each of the new materializations will receive their elements from the original [[Sink]]. + */ +object PartitionHub { + + /** + * Creates a [[Sink]] that receives elements from its upstream producer and routes them to a dynamic set + * of consumers. After the [[Sink]] returned by this method is materialized, it returns a [[Source]] as materialized + * value. This [[Source]] can be materialized an arbitrary number of times and each materialization will receive the + * elements from the original [[Sink]]. + * + * Every new materialization of the [[Sink]] results in a new, independent hub, which materializes to its own + * [[Source]] for consuming the [[Sink]] of that materialization.
+ * + * If the original [[Sink]] is failed, then the failure is immediately propagated to all of its materialized + * [[Source]]s (possibly jumping over already buffered elements). If the original [[Sink]] is completed, then + * all corresponding [[Source]]s are completed. Both failure and normal completion is "remembered" and later + * materializations of the [[Source]] will see the same (failure or completion) state. [[Source]]s that are + * cancelled are simply removed from the dynamic set of consumers. + * + * This `ofStateful` should be used when there is a need to keep mutable state in the partition function, + * e.g. for implementing round-robin or sticky session kind of routing. If state is not needed the [[#of]] can + * be more convenient to use. + * + * @param partitioner Function that decides where to route an element. It is a factory of a function + * to be able to hold stateful variables that are unique for each materialization. The function + * takes two parameters; the first is information about active consumers, including an array of consumer + * identifiers and the second is the stream element. The function should return the selected consumer + * identifier for the given element. The function will never be called when there are no active consumers, + * i.e. there is always at least one element in the array of identifiers. + * @param startAfterNrOfConsumers Elements are buffered until this number of consumers have been connected. + * This is only used initially when the stage is starting up, i.e. it is not honored when consumers have + * been removed (canceled). + * @param bufferSize Total number of elements that can be buffered. If this buffer is full, the producer + * is backpressured.
+ */ + @ApiMayChange def ofStateful[T](clazz: Class[T], partitioner: Supplier[ToLongBiFunction[ConsumerInfo, T]], + startAfterNrOfConsumers: Int, bufferSize: Int): Sink[T, Source[T, NotUsed]] = { + val p: () ⇒ (akka.stream.scaladsl.PartitionHub.ConsumerInfo, T) ⇒ Long = () ⇒ { + val f = partitioner.get() + (info, elem) ⇒ f.applyAsLong(info, elem) + } + akka.stream.scaladsl.PartitionHub.statefulSink[T](p, startAfterNrOfConsumers, bufferSize) + .mapMaterializedValue(_.asJava) + .asJava + } + + @ApiMayChange def ofStateful[T](clazz: Class[T], partitioner: Supplier[ToLongBiFunction[ConsumerInfo, T]], + startAfterNrOfConsumers: Int): Sink[T, Source[T, NotUsed]] = + ofStateful(clazz, partitioner, startAfterNrOfConsumers, akka.stream.scaladsl.PartitionHub.defaultBufferSize) + + /** + * Creates a [[Sink]] that receives elements from its upstream producer and routes them to a dynamic set + * of consumers. After the [[Sink]] returned by this method is materialized, it returns a [[Source]] as materialized + * value. This [[Source]] can be materialized an arbitrary number of times and each materialization will receive the + * elements from the original [[Sink]]. + * + * Every new materialization of the [[Sink]] results in a new, independent hub, which materializes to its own + * [[Source]] for consuming the [[Sink]] of that materialization. + * + * If the original [[Sink]] is failed, then the failure is immediately propagated to all of its materialized + * [[Source]]s (possibly jumping over already buffered elements). If the original [[Sink]] is completed, then + * all corresponding [[Source]]s are completed. Both failure and normal completion is "remembered" and later + * materializations of the [[Source]] will see the same (failure or completion) state. [[Source]]s that are + * cancelled are simply removed from the dynamic set of consumers. + * + * This `sink` should be used when the routing function is stateless, e.g. based on a hashed value of the + * elements. 
Otherwise the [[#ofStateful]] can be used to implement more advanced routing logic. + * + * @param partitioner Function that decides where to route an element. The function takes two parameters; + * the first is the number of active consumers and the second is the stream element. The function should + * return the index of the selected consumer for the given element, i.e. int greater than or equal to 0 + * and less than number of consumers. E.g. `(size, elem) -> Math.abs(elem.hashCode()) % size`. + * @param startAfterNrOfConsumers Elements are buffered until this number of consumers have been connected. + * This is only used initially when the stage is starting up, i.e. it is not honored when consumers have + * been removed (canceled). + * @param bufferSize Total number of elements that can be buffered. If this buffer is full, the producer + * is backpressured. + */ + @ApiMayChange def of[T](clazz: Class[T], partitioner: BiFunction[Integer, T, Integer], startAfterNrOfConsumers: Int, + bufferSize: Int): Sink[T, Source[T, NotUsed]] = + akka.stream.scaladsl.PartitionHub.sink[T]( + (size, elem) ⇒ partitioner.apply(size, elem), + startAfterNrOfConsumers, bufferSize) + .mapMaterializedValue(_.asJava) + .asJava + + @ApiMayChange def of[T](clazz: Class[T], partitioner: BiFunction[Integer, T, Integer], startAfterNrOfConsumers: Int): Sink[T, Source[T, NotUsed]] = + of(clazz, partitioner, startAfterNrOfConsumers, akka.stream.scaladsl.PartitionHub.defaultBufferSize) + + @DoNotInherit @ApiMayChange trait ConsumerInfo { + + /** + * Sequence of all identifiers of current consumers. + * + * Use this method only if you need to enumerate consumer existing ids. + * When selecting a specific consumerId by its index, prefer using the dedicated [[#consumerIdByIdx]] method instead, + * which is optimised for this use case. 
+ */ + def getConsumerIds: java.util.List[Long] + + /** Obtain consumer identifier by index */ + def consumerIdByIdx(idx: Int): Long + + /** + * Approximate number of buffered elements for a consumer. + * A larger value than for other consumers could be an indication + * that the consumer is slow. + * + * Note that this is a moving target since the elements are + * consumed concurrently. + */ + def queueSize(consumerId: Long): Int + + /** + * Number of attached consumers. + */ + def size: Int + } +} diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Hub.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Hub.scala index e14b444f9a..0d0f13a480 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Hub.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Hub.scala @@ -3,6 +3,7 @@ */ package akka.stream.scaladsl +import java.util import java.util.concurrent.atomic.{ AtomicLong, AtomicReference } import akka.NotUsed @@ -13,6 +14,17 @@ import akka.stream.stage._ import scala.annotation.tailrec import scala.concurrent.{ Future, Promise } import scala.util.{ Failure, Success, Try } +import java.util.Arrays +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.atomic.AtomicReferenceArray + +import scala.collection.immutable +import scala.collection.mutable.LongMap +import scala.collection.immutable.Queue +import akka.annotation.InternalApi +import akka.annotation.DoNotInherit +import akka.annotation.ApiMayChange /** * A MergeHub is a special streaming hub that is able to collect streamed elements from a dynamic set of @@ -107,8 +119,7 @@ private[akka] class MergeHub[T](perProducerBufferSize: Int) extends GraphStageWi private[this] val demands = scala.collection.mutable.LongMap.empty[InputState] private[this] val wakeupCallback = getAsyncCallback[NotUsed]((_) ⇒ // We are only allowed to dequeue if we are not backpressured. See comment in tryProcessNext() for details.
- if (isAvailable(out)) tryProcessNext(firstAttempt = true) - ) + if (isAvailable(out)) tryProcessNext(firstAttempt = true)) setHandler(out, this) @@ -291,7 +302,7 @@ object BroadcastHub { * Creates a [[Sink]] that receives elements from its upstream producer and broadcasts them to a dynamic set * of consumers. After the [[Sink]] returned by this method is materialized, it returns a [[Source]] as materialized * value. This [[Source]] can be materialized an arbitrary number of times and each materialization will receive the - * broadcast elements form the ofiginal [[Sink]]. + * broadcast elements from the original [[Sink]]. * * Every new materialization of the [[Sink]] results in a new, independent hub, which materializes to its own * [[Source]] for consuming the [[Sink]] of that materialization. @@ -693,3 +704,554 @@ private[akka] class BroadcastHub[T](bufferSize: Int) extends GraphStageWithMater (logic, Source.fromGraph(source)) } } + +/** + * A `PartitionHub` is a special streaming hub that is able to route streamed elements to a dynamic set of consumers. + * It consists of two parts, a [[Sink]] and a [[Source]]. The [[Sink]] routes elements from a producer to the + * actually live consumers it has. The selection of consumer is done with a function. Each element can be routed to + * only one consumer. Once the producer has been materialized, the [[Sink]] it feeds into returns a + * materialized value which is the corresponding [[Source]]. This [[Source]] can be materialized an arbitrary number + * of times, where each of the new materializations will receive their elements from the original [[Sink]]. + */ +object PartitionHub { + + /** + * INTERNAL API + */ + @InternalApi private[akka] val defaultBufferSize = 256 + + /** + * Creates a [[Sink]] that receives elements from its upstream producer and routes them to a dynamic set + * of consumers. After the [[Sink]] returned by this method is materialized, it returns a [[Source]] as materialized + * value.
This [[Source]] can be materialized an arbitrary number of times and each materialization will receive the + * elements from the original [[Sink]]. + * + * Every new materialization of the [[Sink]] results in a new, independent hub, which materializes to its own + * [[Source]] for consuming the [[Sink]] of that materialization. + * + * If the original [[Sink]] is failed, then the failure is immediately propagated to all of its materialized + * [[Source]]s (possibly jumping over already buffered elements). If the original [[Sink]] is completed, then + * all corresponding [[Source]]s are completed. Both failure and normal completion is "remembered" and later + * materializations of the [[Source]] will see the same (failure or completion) state. [[Source]]s that are + * cancelled are simply removed from the dynamic set of consumers. + * + * This `statefulSink` should be used when there is a need to keep mutable state in the partition function, + * e.g. for implementing round-robin or sticky session kind of routing. If state is not needed the [[#sink]] can + * be more convenient to use. + * + * @param partitioner Function that decides where to route an element. It is a factory of a function + * to be able to hold stateful variables that are unique for each materialization. The function + * takes two parameters; the first is information about active consumers, including an array of consumer + * identifiers and the second is the stream element. The function should return the selected consumer + * identifier for the given element. The function will never be called when there are no active consumers, + * i.e. there is always at least one element in the array of identifiers. + * @param startAfterNrOfConsumers Elements are buffered until this number of consumers have been connected. + * This is only used initially when the stage is starting up, i.e. it is not honored when consumers have + * been removed (canceled).
+ * @param bufferSize Total number of elements that can be buffered. If this buffer is full, the producer + * is backpressured. + */ + @ApiMayChange def statefulSink[T](partitioner: () ⇒ (ConsumerInfo, T) ⇒ Long, startAfterNrOfConsumers: Int, + bufferSize: Int = defaultBufferSize): Sink[T, Source[T, NotUsed]] = + Sink.fromGraph(new PartitionHub[T](partitioner, startAfterNrOfConsumers, bufferSize)) + + /** + * Creates a [[Sink]] that receives elements from its upstream producer and routes them to a dynamic set + * of consumers. After the [[Sink]] returned by this method is materialized, it returns a [[Source]] as materialized + * value. This [[Source]] can be materialized an arbitrary number of times and each materialization will receive the + * elements from the original [[Sink]]. + * + * Every new materialization of the [[Sink]] results in a new, independent hub, which materializes to its own + * [[Source]] for consuming the [[Sink]] of that materialization. + * + * If the original [[Sink]] is failed, then the failure is immediately propagated to all of its materialized + * [[Source]]s (possibly jumping over already buffered elements). If the original [[Sink]] is completed, then + * all corresponding [[Source]]s are completed. Both failure and normal completion is "remembered" and later + * materializations of the [[Source]] will see the same (failure or completion) state. [[Source]]s that are + * cancelled are simply removed from the dynamic set of consumers. + * + * This `sink` should be used when the routing function is stateless, e.g. based on a hashed value of the + * elements. Otherwise the [[#statefulSink]] can be used to implement more advanced routing logic. + * + * @param partitioner Function that decides where to route an element. The function takes two parameters; + * the first is the number of active consumers and the second is the stream element. The function should + * return the index of the selected consumer for the given element, i.e. 
int greater than or equal to 0 + * and less than number of consumers. E.g. `(size, elem) => math.abs(elem.hashCode) % size`. + * @param startAfterNrOfConsumers Elements are buffered until this number of consumers have been connected. + * This is only used initially when the stage is starting up, i.e. it is not honored when consumers have + * been removed (canceled). + * @param bufferSize Total number of elements that can be buffered. If this buffer is full, the producer + * is backpressured. + */ + @ApiMayChange + def sink[T](partitioner: (Int, T) ⇒ Int, startAfterNrOfConsumers: Int, + bufferSize: Int = defaultBufferSize): Sink[T, Source[T, NotUsed]] = + statefulSink(() ⇒ (info, elem) ⇒ info.consumerIdByIdx(partitioner(info.size, elem)), startAfterNrOfConsumers, bufferSize) + + @DoNotInherit @ApiMayChange trait ConsumerInfo extends akka.stream.javadsl.PartitionHub.ConsumerInfo { + + /** + * Sequence of all identifiers of current consumers. + * + * Use this method only if you need to enumerate the ids of existing consumers. + * When selecting a specific consumerId by its index, prefer using the dedicated [[#consumerIdByIdx]] method instead, + * which is optimised for this use case. + */ + def consumerIds: immutable.IndexedSeq[Long] + + /** Obtain consumer identifier by index */ + def consumerIdByIdx(idx: Int): Long + + /** + * Approximate number of buffered elements for a consumer. + * A larger value than for other consumers could be an indication + * that the consumer is slow. + * + * Note that this is a moving target since the elements are + * consumed concurrently. + */ + def queueSize(consumerId: Long): Int + + /** + * Number of attached consumers.
+ */ + def size: Int + + } + + /** + * INTERNAL API + */ + @InternalApi private[akka] object Internal { + sealed trait ConsumerEvent + case object Wakeup extends ConsumerEvent + final case class HubCompleted(failure: Option[Throwable]) extends ConsumerEvent + case object Initialize extends ConsumerEvent + + sealed trait HubEvent + case object RegistrationPending extends HubEvent + final case class UnRegister(id: Long) extends HubEvent + final case class NeedWakeup(consumer: Consumer) extends HubEvent + final case class Consumer(id: Long, callback: AsyncCallback[ConsumerEvent]) + case object TryPull extends HubEvent + + case object Completed + + sealed trait HubState + final case class Open(callbackFuture: Future[AsyncCallback[HubEvent]], registrations: List[Consumer]) extends HubState + final case class Closed(failure: Option[Throwable]) extends HubState + + // The reason for the two implementations here is that the common case (as I see it) is to have a few (< 100) + // consumers over the lifetime of the hub but we must of course also support more. + // FixedQueues is more efficient than ConcurrentHashMap so we use that for the first 128 consumers. 
+ private val FixedQueues = 128 + + // Need the queue to be pluggable to be able to use a more performant (less general) + // queue in Artery + trait PartitionQueue { + def init(id: Long): Unit + def totalSize: Int + def size(id: Long): Int + def isEmpty(id: Long): Boolean + def nonEmpty(id: Long): Boolean + def offer(id: Long, elem: Any): Unit + def poll(id: Long): AnyRef + def remove(id: Long): Unit + } + + object ConsumerQueue { + val empty = ConsumerQueue(Queue.empty, 0) + } + + final case class ConsumerQueue(queue: Queue[Any], size: Int) { + def enqueue(elem: Any): ConsumerQueue = + new ConsumerQueue(queue.enqueue(elem), size + 1) + + def isEmpty: Boolean = size == 0 + + def head: Any = queue.head + + def tail: ConsumerQueue = + new ConsumerQueue(queue.tail, size - 1) + } + + class PartitionQueueImpl extends PartitionQueue { + private val queues1 = new AtomicReferenceArray[ConsumerQueue](FixedQueues) + private val queues2 = new ConcurrentHashMap[Long, ConsumerQueue] + private val _totalSize = new AtomicInteger + + override def init(id: Long): Unit = { + if (id < FixedQueues) + queues1.set(id.toInt, ConsumerQueue.empty) + else + queues2.put(id, ConsumerQueue.empty) + } + + override def totalSize: Int = _totalSize.get + + def size(id: Long): Int = { + val queue = + if (id < FixedQueues) queues1.get(id.toInt) + else queues2.get(id) + if (queue eq null) + throw new IllegalArgumentException(s"Invalid stream identifier: $id") + queue.size + } + + override def isEmpty(id: Long): Boolean = { + val queue = + if (id < FixedQueues) queues1.get(id.toInt) + else queues2.get(id) + if (queue eq null) + throw new IllegalArgumentException(s"Invalid stream identifier: $id") + queue.isEmpty + } + + override def nonEmpty(id: Long): Boolean = !isEmpty(id) + + override def offer(id: Long, elem: Any): Unit = { + @tailrec def offer1(): Unit = { + val i = id.toInt + val queue = queues1.get(i) + if (queue eq null) + throw new IllegalArgumentException(s"Invalid stream identifier: $id") 
+ if (queues1.compareAndSet(i, queue, queue.enqueue(elem))) + _totalSize.incrementAndGet() + else + offer1() // CAS failed, retry + } + + @tailrec def offer2(): Unit = { + val queue = queues2.get(id) + if (queue eq null) + throw new IllegalArgumentException(s"Invalid stream identifier: $id") + if (queues2.replace(id, queue, queue.enqueue(elem))) { + _totalSize.incrementAndGet() + } else + offer2() // CAS failed, retry + } + + if (id < FixedQueues) offer1() else offer2() + } + + override def poll(id: Long): AnyRef = { + @tailrec def poll1(): AnyRef = { + val i = id.toInt + val queue = queues1.get(i) + if ((queue eq null) || queue.isEmpty) null + else if (queues1.compareAndSet(i, queue, queue.tail)) { + _totalSize.decrementAndGet() + queue.head.asInstanceOf[AnyRef] + } else + poll1() // CAS failed, try again + } + + @tailrec def poll2(): AnyRef = { + val queue = queues2.get(id) + if ((queue eq null) || queue.isEmpty) null + else if (queues2.replace(id, queue, queue.tail)) { + _totalSize.decrementAndGet() + queue.head.asInstanceOf[AnyRef] + } else + poll2() // CAS failed, try again + } + + if (id < FixedQueues) poll1() else poll2() + } + + override def remove(id: Long): Unit = { + (if (id < FixedQueues) queues1.getAndSet(id.toInt, null) + else queues2.remove(id)) match { + case null ⇒ + case queue ⇒ _totalSize.addAndGet(-queue.size) + } + } + + } + } +} + +/** + * INTERNAL API + */ +@InternalApi private[akka] class PartitionHub[T]( + partitioner: () ⇒ (PartitionHub.ConsumerInfo, T) ⇒ Long, + startAfterNrOfConsumers: Int, bufferSize: Int) + extends GraphStageWithMaterializedValue[SinkShape[T], Source[T, NotUsed]] { + import PartitionHub.Internal._ + import PartitionHub.ConsumerInfo + + val in: Inlet[T] = Inlet("PartitionHub.in") + override val shape: SinkShape[T] = SinkShape(in) + + // Need the queue to be pluggable to be able to use a more performant (less general) + // queue in Artery + def createQueue(): PartitionQueue = new PartitionQueueImpl + + private class 
PartitionSinkLogic(_shape: Shape) + extends GraphStageLogic(_shape) with InHandler { + + // Half of buffer size, rounded up + private val DemandThreshold = (bufferSize / 2) + (bufferSize % 2) + + private val materializedPartitioner = partitioner() + + private val callbackPromise: Promise[AsyncCallback[HubEvent]] = Promise() + private val noRegistrationsState = Open(callbackPromise.future, Nil) + val state = new AtomicReference[HubState](noRegistrationsState) + private var initialized = false + + private val queue = createQueue() + private var pending = Vector.empty[T] + private var consumerInfo: ConsumerInfoImpl = new ConsumerInfoImpl(Vector.empty) + private val needWakeup: LongMap[Consumer] = LongMap.empty + + private var callbackCount = 0L + + private final class ConsumerInfoImpl(val consumers: Vector[Consumer]) + extends ConsumerInfo { info ⇒ + + override def queueSize(consumerId: Long): Int = + queue.size(consumerId) + + override def size: Int = consumers.size + + override def consumerIds: immutable.IndexedSeq[Long] = + consumers.map(_.id) + + override def consumerIdByIdx(idx: Int): Long = + consumers(idx).id + + override def getConsumerIds: java.util.List[Long] = + new util.AbstractList[Long] { + override def get(idx: Int): Long = info.consumerIdByIdx(idx) + override def size(): Int = info.size + } + } + + override def preStart(): Unit = { + setKeepGoing(true) + callbackPromise.success(getAsyncCallback[HubEvent](onEvent)) + if (startAfterNrOfConsumers == 0) + pull(in) + } + + override def onPush(): Unit = { + publish(grab(in)) + if (!isFull) pull(in) + } + + private def isFull: Boolean = { + (queue.totalSize + pending.size) >= bufferSize + } + + private def publish(elem: T): Unit = { + if (!initialized || consumerInfo.consumers.isEmpty) { + // will be published when first consumers are registered + pending :+= elem + } else { + val id = materializedPartitioner(consumerInfo, elem) + queue.offer(id, elem) + wakeup(id) + } + } + + private def wakeup(id: Long): 
Unit = { + needWakeup.get(id) match { + case None ⇒ // ignore + case Some(consumer) ⇒ + needWakeup -= id + consumer.callback.invoke(Wakeup) + } + } + + override def onUpstreamFinish(): Unit = { + if (consumerInfo.consumers.isEmpty) + completeStage() + else { + consumerInfo.consumers.foreach(c ⇒ complete(c.id)) + } + } + + private def complete(id: Long): Unit = { + queue.offer(id, Completed) + wakeup(id) + } + + private def tryPull(): Unit = { + if (initialized && !isClosed(in) && !hasBeenPulled(in) && !isFull) + pull(in) + } + + private def onEvent(ev: HubEvent): Unit = { + callbackCount += 1 + ev match { + case NeedWakeup(consumer) ⇒ + // Also check if the consumer is now unblocked since we published an element since it went asleep. + if (queue.nonEmpty(consumer.id)) + consumer.callback.invoke(Wakeup) + else { + needWakeup.update(consumer.id, consumer) + tryPull() + } + + case TryPull ⇒ + tryPull() + + case RegistrationPending ⇒ + state.getAndSet(noRegistrationsState).asInstanceOf[Open].registrations foreach { consumer ⇒ + val newConsumers = (consumerInfo.consumers :+ consumer).sortBy(_.id) + consumerInfo = new ConsumerInfoImpl(newConsumers) + queue.init(consumer.id) + if (newConsumers.size >= startAfterNrOfConsumers) { + initialized = true + } + + consumer.callback.invoke(Initialize) + + if (initialized && pending.nonEmpty) { + pending.foreach(publish) + pending = Vector.empty[T] + } + + tryPull() + } + + case UnRegister(id) ⇒ + val newConsumers = consumerInfo.consumers.filterNot(_.id == id) + consumerInfo = new ConsumerInfoImpl(newConsumers) + queue.remove(id) + if (newConsumers.isEmpty) { + if (isClosed(in)) completeStage() + } else + tryPull() + } + } + + override def onUpstreamFailure(ex: Throwable): Unit = { + val failMessage = HubCompleted(Some(ex)) + + // Notify pending consumers and set tombstone + state.getAndSet(Closed(Some(ex))).asInstanceOf[Open].registrations foreach { consumer ⇒ + consumer.callback.invoke(failMessage) + } + + // Notify registered 
consumers + consumerInfo.consumers.foreach { consumer ⇒ + consumer.callback.invoke(failMessage) + } + failStage(ex) + } + + override def postStop(): Unit = { + // Notify pending consumers and set tombstone + + @tailrec def tryClose(): Unit = state.get() match { + case Closed(_) ⇒ // Already closed, ignore + case open: Open ⇒ + if (state.compareAndSet(open, Closed(None))) { + val completedMessage = HubCompleted(None) + open.registrations foreach { consumer ⇒ + consumer.callback.invoke(completedMessage) + } + } else tryClose() + } + + tryClose() + } + + // Consumer API + def poll(id: Long, hubCallback: AsyncCallback[HubEvent]): AnyRef = { + // try pull via async callback when half full + // this is racy with other threads doing poll but doesn't matter + if (queue.totalSize == DemandThreshold) + hubCallback.invoke(TryPull) + + queue.poll(id) + } + + setHandler(in, this) + } + + override def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, Source[T, NotUsed]) = { + val idCounter = new AtomicLong + + val logic = new PartitionSinkLogic(shape) + + val source = new GraphStage[SourceShape[T]] { + val out: Outlet[T] = Outlet("PartitionHub.out") + override val shape: SourceShape[T] = SourceShape(out) + + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with OutHandler { + private val id = idCounter.getAndIncrement() + private var hubCallback: AsyncCallback[HubEvent] = _ + private val callback = getAsyncCallback(onCommand) + private val consumer = Consumer(id, callback) + + private var callbackCount = 0L + + override def preStart(): Unit = { + val onHubReady: Try[AsyncCallback[HubEvent]] ⇒ Unit = { + case Success(callback) ⇒ + hubCallback = callback + callback.invoke(RegistrationPending) + if (isAvailable(out)) onPull() + case Failure(ex) ⇒ + failStage(ex) + } + + @tailrec def register(): Unit = { + logic.state.get() match { + case Closed(Some(ex)) ⇒ failStage(ex) + case Closed(None) ⇒ 
completeStage() + case previousState @ Open(callbackFuture, registrations) ⇒ + val newRegistrations = consumer :: registrations + if (logic.state.compareAndSet(previousState, Open(callbackFuture, newRegistrations))) { + callbackFuture.onComplete(getAsyncCallback(onHubReady).invoke)(materializer.executionContext) + } else register() + } + } + + register() + + } + + override def onPull(): Unit = { + if (hubCallback ne null) { + val elem = logic.poll(id, hubCallback) + + elem match { + case null ⇒ + hubCallback.invoke(NeedWakeup(consumer)) + case Completed ⇒ + completeStage() + case _ ⇒ + push(out, elem.asInstanceOf[T]) + } + } + } + + override def postStop(): Unit = { + if (hubCallback ne null) + hubCallback.invoke(UnRegister(id)) + } + + private def onCommand(cmd: ConsumerEvent): Unit = { + callbackCount += 1 + cmd match { + case HubCompleted(Some(ex)) ⇒ failStage(ex) + case HubCompleted(None) ⇒ completeStage() + case Wakeup ⇒ + if (isAvailable(out)) onPull() + case Initialize ⇒ + if (isAvailable(out) && (hubCallback ne null)) onPull() + } + } + + setHandler(out, this) + } + } + + (logic, Source.fromGraph(source)) + } +}