diff --git a/akka-bench-jmh/src/main/scala/akka/stream/FlatMapConcatBenchmark.scala b/akka-bench-jmh/src/main/scala/akka/stream/FlatMapConcatBenchmark.scala new file mode 100644 index 0000000000..0259f122b0 --- /dev/null +++ b/akka-bench-jmh/src/main/scala/akka/stream/FlatMapConcatBenchmark.scala @@ -0,0 +1,130 @@ +/** + * Copyright (C) 2018 Lightbend Inc. + */ + +package akka.stream + +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit + +import scala.concurrent.Await +import scala.concurrent.duration._ + +import akka.NotUsed +import akka.actor.ActorSystem +import akka.remote.artery.BenchTestSource +import akka.remote.artery.LatchSink +import akka.stream.impl.PhasedFusingActorMaterializer +import akka.stream.impl.StreamSupervisor +import akka.stream.scaladsl._ +import akka.testkit.TestProbe +import com.typesafe.config.ConfigFactory +import org.openjdk.jmh.annotations._ +import akka.stream.impl.fusing.GraphStages + +object FlatMapConcatBenchmark { + final val OperationsPerInvocation = 100000 +} + +@State(Scope.Benchmark) +@OutputTimeUnit(TimeUnit.SECONDS) +@BenchmarkMode(Array(Mode.Throughput)) +class FlatMapConcatBenchmark { + import FlatMapConcatBenchmark._ + + private val config = ConfigFactory.parseString( + """ + akka.actor.default-dispatcher { + executor = "fork-join-executor" + fork-join-executor { + parallelism-factor = 1 + } + } + """ + ) + + private implicit val system: ActorSystem = ActorSystem("FlatMapConcatBenchmark", config) + + var materializer: ActorMaterializer = _ + + var testSource: Source[java.lang.Integer, NotUsed] = _ + + @Setup + def setup(): Unit = { + val settings = ActorMaterializerSettings(system) + materializer = ActorMaterializer(settings) + + testSource = Source.fromGraph(new BenchTestSource(OperationsPerInvocation)) + } + + @TearDown + def shutdown(): Unit = { + Await.result(system.terminate(), 5.seconds) + } + + @Benchmark + @OperationsPerInvocation(OperationsPerInvocation) + def sourceDotSingle(): Unit = { + val latch = new CountDownLatch(1) + + testSource + .flatMapConcat(Source.single) + .runWith(new LatchSink(OperationsPerInvocation, latch))(materializer) + + awaitLatch(latch) + } + + @Benchmark + @OperationsPerInvocation(OperationsPerInvocation) + def internalSingleSource(): Unit = { + val latch = new CountDownLatch(1) + + testSource + .flatMapConcat(elem ⇒ new GraphStages.SingleSource(elem)) + .runWith(new LatchSink(OperationsPerInvocation, latch))(materializer) + + awaitLatch(latch) + } + + @Benchmark + @OperationsPerInvocation(OperationsPerInvocation) + def oneElementList(): Unit = { + val latch = new CountDownLatch(1) + + testSource + .flatMapConcat(n ⇒ Source(n :: Nil)) + .runWith(new LatchSink(OperationsPerInvocation, latch))(materializer) + + awaitLatch(latch) + } + + @Benchmark + @OperationsPerInvocation(OperationsPerInvocation) + def mapBaseline(): Unit = { + val latch = new CountDownLatch(1) + + testSource + .map(elem ⇒ elem) + .runWith(new LatchSink(OperationsPerInvocation, latch))(materializer) + + awaitLatch(latch) + } + + private def awaitLatch(latch: CountDownLatch): Unit = { + if (!latch.await(30, TimeUnit.SECONDS)) { + dumpMaterializer() + throw new RuntimeException("Latch didn't complete in time") + } + } + + private def dumpMaterializer(): Unit = { + materializer match { + case impl: PhasedFusingActorMaterializer ⇒ + val probe = TestProbe()(system) + impl.supervisor.tell(StreamSupervisor.GetChildren, probe.ref) + val children = probe.expectMsgType[StreamSupervisor.Children].children + children.foreach(_ ! StreamSupervisor.PrintDebugDump) + } + } + +} diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowFlattenMergeSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowFlattenMergeSpec.scala index 4d71b84851..d3f858c4ac 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowFlattenMergeSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowFlattenMergeSpec.scala @@ -12,10 +12,15 @@ import akka.stream.testkit.scaladsl.StreamTestKit._ import scala.concurrent._ import scala.concurrent.duration._ + +import akka.stream.impl.TraversalBuilder +import akka.stream.impl.fusing.GraphStages +import akka.stream.impl.fusing.GraphStages.SingleSource import akka.stream.testkit.{ StreamSpec, TestPublisher } import org.scalatest.exceptions.TestFailedException import akka.stream.testkit.scaladsl.TestSink import akka.testkit.TestLatch +import akka.util.OptionVal class FlowFlattenMergeSpec extends StreamSpec { implicit val materializer = ActorMaterializer() @@ -210,5 +215,85 @@ class FlowFlattenMergeSpec extends StreamSpec { attributes.indexOf(Attributes.Name("inner")) < attributes.indexOf(Attributes.Name("outer")) should be(true) } + "work with optimized Source.single" in assertAllStagesStopped { + Source(0 to 3) + .flatMapConcat(Source.single) + .runWith(toSeq) + .futureValue should ===(0 to 3) + } + + "work with optimized Source.single when slow demand" in assertAllStagesStopped { + val probe = Source(0 to 4) + .flatMapConcat(Source.single) + .runWith(TestSink.probe) + + probe.request(3) + probe.expectNext(0) + probe.expectNext(1) + probe.expectNext(2) + probe.expectNoMessage(100.millis) + + probe.request(10) + probe.expectNext(3) + probe.expectNext(4) + probe.expectComplete() + } + + "work with mix of Source.single and other sources when slow demand" in assertAllStagesStopped { + val sources: Source[Source[Int, NotUsed], NotUsed] = Source(List( + Source.single(0), + Source.single(1), + Source(2 to 4), + Source.single(5), + Source(6 to 6), + Source.single(7), + Source(8 to 10), + Source.single(11) + )) + + val probe = + sources + .flatMapConcat(identity) + .runWith(TestSink.probe) + + probe.request(3) + probe.expectNext(0) + probe.expectNext(1) + probe.expectNext(2) + probe.expectNoMessage(100.millis) + + probe.request(1) + probe.expectNext(3) + probe.expectNoMessage(100.millis) + + probe.request(1) + probe.expectNext(4) + probe.expectNoMessage(100.millis) + + probe.request(3) + probe.expectNext(5) + probe.expectNext(6) + probe.expectNext(7) + probe.expectNoMessage(100.millis) + + probe.request(10) + probe.expectNext(8) + probe.expectNext(9) + probe.expectNext(10) + probe.expectNext(11) + probe.expectComplete() + } + + "find Source.single via TraversalBuilder" in assertAllStagesStopped { + TraversalBuilder.getSingleSource(Source.single("a")).get.elem should ===("a") + TraversalBuilder.getSingleSource(Source(List("a", "b"))) should be(OptionVal.None) + + val singleSourceA = new SingleSource("a") + TraversalBuilder.getSingleSource(singleSourceA) should be(OptionVal.Some(singleSourceA)) + + TraversalBuilder.getSingleSource(Source.single("c").async) should be(OptionVal.None) + TraversalBuilder.getSingleSource(Source.single("d").mapMaterializedValue(_ ⇒ "Mat")) should be(OptionVal.None) + } + } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/TraversalBuilder.scala b/akka-stream/src/main/scala/akka/stream/impl/TraversalBuilder.scala index e82e3ecff6..c31d140d13 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/TraversalBuilder.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/TraversalBuilder.scala @@ -10,10 +10,12 @@ import akka.stream.impl.StreamLayout.AtomicModule import akka.stream.impl.TraversalBuilder.{ AnyFunction1, AnyFunction2 } import akka.stream.scaladsl.Keep import akka.util.OptionVal - import scala.language.existentials import scala.collection.immutable.Map.Map1 +import akka.stream.impl.fusing.GraphStageModule +import akka.stream.impl.fusing.GraphStages.SingleSource + /** * INTERNAL API * @@ -334,6 +336,37 @@ import scala.collection.immutable.Map.Map1 } slot } + + /** + * Try to find `SingleSource` or wrapped such. This is used as a + * performance optimization in FlattenMerge and possibly other places. + */ + def getSingleSource[A >: Null](graph: Graph[SourceShape[A], _]): OptionVal[SingleSource[A]] = { + graph match { + case single: SingleSource[A] @unchecked ⇒ OptionVal.Some(single) + case _ ⇒ + graph.traversalBuilder match { + case l: LinearTraversalBuilder ⇒ + l.pendingBuilder match { + case OptionVal.Some(a: AtomicTraversalBuilder) ⇒ + a.module match { + case m: GraphStageModule[_, _] ⇒ + m.stage match { + case single: SingleSource[A] @unchecked ⇒ + // It would be != EmptyTraversal if mapMaterializedValue was used and then we can't optimize. + if ((l.traversalSoFar eq EmptyTraversal) && !l.attributes.isAsync) + OptionVal.Some(single) + else OptionVal.None + case _ ⇒ OptionVal.None + } + case _ ⇒ OptionVal.None + } + case _ ⇒ OptionVal.None + } + case _ ⇒ OptionVal.None + } + } + } } /** diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/StreamOfStreams.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/StreamOfStreams.scala index 8b756fd0c8..9e2f2cc4fc 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/StreamOfStreams.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/StreamOfStreams.scala @@ -16,15 +16,18 @@ import akka.stream.impl.SubscriptionTimeoutException import akka.stream.stage._ import akka.stream.scaladsl._ import akka.stream.actor.ActorSubscriberMessage - +import akka.util.OptionVal import scala.collection.immutable import scala.concurrent.duration.FiniteDuration import scala.util.control.NonFatal import scala.annotation.tailrec -import akka.stream.impl.{ Buffer ⇒ BufferImpl } +import akka.stream.impl.{ Buffer ⇒ BufferImpl } import scala.collection.JavaConverters._ +import akka.stream.impl.TraversalBuilder +import akka.stream.impl.fusing.GraphStages.SingleSource + /** * INTERNAL API */ @@ -37,17 +40,25 @@ import scala.collection.JavaConverters._ override def createLogic(enclosingAttributes: Attributes) = new GraphStageLogic(shape) { var sources = Set.empty[SubSinkInlet[T]] - def activeSources = sources.size + var pendingSingleSources = 0 + def activeSources = sources.size + pendingSingleSources - var q: BufferImpl[SubSinkInlet[T]] = _ + // To be able to optimize for SingleSource without materializing them the queue may hold either + // SubSinkInlet[T] or SingleSource + var queue: BufferImpl[AnyRef] = _ - override def preStart(): Unit = q = BufferImpl(breadth, materializer) + override def preStart(): Unit = queue = BufferImpl(breadth, materializer) def pushOut(): Unit = { - val src = q.dequeue() - push(out, src.grab()) - if (!src.isClosed) src.pull() - else removeSource(src) + queue.dequeue() match { + case src: SubSinkInlet[T] @unchecked ⇒ + push(out, src.grab()) + if (!src.isClosed) src.pull() + else removeSource(src) + case single: SingleSource[T] @unchecked ⇒ + push(out, single.elem) + removeSource(single) + } } setHandler(in, new InHandler { @@ -68,31 +79,48 @@ import scala.collection.JavaConverters._ val outHandler = new OutHandler { // could be unavailable due to async input having been executed before this notification - override def onPull(): Unit = if (q.nonEmpty && isAvailable(out)) pushOut() + override def onPull(): Unit = if (queue.nonEmpty && isAvailable(out)) pushOut() } def addSource(source: Graph[SourceShape[T], M]): Unit = { - val sinkIn = new SubSinkInlet[T]("FlattenMergeSink") - sinkIn.setHandler(new InHandler { - override def onPush(): Unit = { - if (isAvailable(out)) { - push(out, sinkIn.grab()) - sinkIn.pull() + // If it's a SingleSource or wrapped such we can push the element directly instead of materializing it. + // Have to use AnyRef because of OptionVal null value. + TraversalBuilder.getSingleSource(source.asInstanceOf[Graph[SourceShape[AnyRef], M]]) match { + case OptionVal.Some(single) ⇒ + if (isAvailable(out) && queue.isEmpty) { + push(out, single.elem.asInstanceOf[T]) } else { - q.enqueue(sinkIn) + queue.enqueue(single) + pendingSingleSources += 1 } - } - override def onUpstreamFinish(): Unit = if (!sinkIn.isAvailable) removeSource(sinkIn) - }) - sinkIn.pull() - sources += sinkIn - val graph = Source.fromGraph(source).to(sinkIn.sink) - interpreter.subFusingMaterializer.materialize(graph, defaultAttributes = enclosingAttributes) + case _ ⇒ + val sinkIn = new SubSinkInlet[T]("FlattenMergeSink") + sinkIn.setHandler(new InHandler { + override def onPush(): Unit = { + if (isAvailable(out)) { + push(out, sinkIn.grab()) + sinkIn.pull() + } else { + queue.enqueue(sinkIn) + } + } + override def onUpstreamFinish(): Unit = if (!sinkIn.isAvailable) removeSource(sinkIn) + }) + sinkIn.pull() + sources += sinkIn + val graph = Source.fromGraph(source).to(sinkIn.sink) + interpreter.subFusingMaterializer.materialize(graph, defaultAttributes = enclosingAttributes) + } } - def removeSource(src: SubSinkInlet[T]): Unit = { + def removeSource(src: AnyRef): Unit = { val pullSuppressed = activeSources == breadth - sources -= src + src match { + case sub: SubSinkInlet[T] @unchecked ⇒ + sources -= sub + case _: SingleSource[_] ⇒ + pendingSingleSources -= 1 + } if (pullSuppressed) tryPull(in) if (activeSources == 0 && isClosed(in)) completeStage() } diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Source.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Source.scala index 48d994a245..a6b079233a 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Source.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Source.scala @@ -17,16 +17,17 @@ import akka.stream.{ Outlet, SourceShape, _ } import akka.util.ConstantFun import akka.{ Done, NotUsed } import org.reactivestreams.{ Publisher, Subscriber } - import scala.annotation.tailrec import scala.annotation.unchecked.uncheckedVariance import scala.collection.immutable import scala.concurrent.duration.FiniteDuration import scala.concurrent.{ Future, Promise } -import akka.stream.stage.GraphStageWithMaterializedValue +import akka.stream.stage.GraphStageWithMaterializedValue import scala.compat.java8.FutureConverters._ +import akka.stream.impl.fusing.GraphStageModule + /** * A `Source` is a set of stream processing steps that has one open output. It can comprise * any number of internal sources and transformations that are wired together, or it can be