Optimize flatMapConcat for single element source, #25241 (#25242)

* Optimize flatMapConcat for single element source, #25241

* Grab the SourceSingle via TraversalBuilder

* Also handle the case when there is no demand

* Don't match when mapMaterializedValue or async is used
This commit is contained in:
Patrik Nordwall 2018-07-11 18:19:40 +02:00 committed by GitHub
parent 97490eb30c
commit d76b27ba3e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 306 additions and 29 deletions

View file

@ -0,0 +1,130 @@
/**
* Copyright (C) 2018 Lightbend Inc. <https://www.lightbend.com>
*/
package akka.stream
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit
import scala.concurrent.Await
import scala.concurrent.duration._
import akka.NotUsed
import akka.actor.ActorSystem
import akka.remote.artery.BenchTestSource
import akka.remote.artery.LatchSink
import akka.stream.impl.PhasedFusingActorMaterializer
import akka.stream.impl.StreamSupervisor
import akka.stream.scaladsl._
import akka.testkit.TestProbe
import com.typesafe.config.ConfigFactory
import org.openjdk.jmh.annotations._
import akka.stream.impl.fusing.GraphStages
/**
 * Companion holding the constants shared by the `FlatMapConcatBenchmark` class.
 */
object FlatMapConcatBenchmark {
// Number of elements processed per benchmark invocation. It is passed to the
// `@OperationsPerInvocation` annotation as well as to `BenchTestSource` and `LatchSink`,
// so it has to stay a compile-time constant (`final val` without a type ascription).
final val OperationsPerInvocation = 100000
}
/**
 * JMH throughput benchmark for `flatMapConcat` when every sub-source holds exactly one element.
 *
 * Three ways of producing the single-element sub-source are measured (`Source.single`,
 * the internal `GraphStages.SingleSource` and a one-element `List`) together with a plain
 * `map` stage that serves as the baseline.
 *
 * Each benchmark method runs one stream from `testSource` into a `LatchSink` and blocks
 * until the latch is released (presumably after `OperationsPerInvocation` elements have
 * been received -- `LatchSink` is defined elsewhere).
 */
@State(Scope.Benchmark)
@OutputTimeUnit(TimeUnit.SECONDS)
@BenchmarkMode(Array(Mode.Throughput))
class FlatMapConcatBenchmark {
  import FlatMapConcatBenchmark._

  // The HOCON text is kept exactly as in the original; whitespace is not significant to it.
  private val config = ConfigFactory.parseString(
    """
akka.actor.default-dispatcher {
executor = "fork-join-executor"
fork-join-executor {
parallelism-factor = 1
}
}
"""
  )

  private implicit val system: ActorSystem = ActorSystem("FlatMapConcatBenchmark", config)

  // Both fields are assigned in `setup()` before any benchmark method runs.
  var materializer: ActorMaterializer = _
  var testSource: Source[java.lang.Integer, NotUsed] = _

  /** Creates the materializer and the shared upstream source. */
  @Setup
  def setup(): Unit = {
    val settings = ActorMaterializerSettings(system)
    materializer = ActorMaterializer(settings)
    testSource = Source.fromGraph(new BenchTestSource(OperationsPerInvocation))
  }

  /** Terminates the actor system, waiting at most five seconds for it to stop. */
  @TearDown
  def shutdown(): Unit = {
    Await.result(system.terminate(), 5.seconds)
  }

  /** Sub-source created through the public `Source.single` factory. */
  @Benchmark
  @OperationsPerInvocation(OperationsPerInvocation)
  def sourceDotSingle(): Unit = {
    val latch = new CountDownLatch(1)

    testSource
      .flatMapConcat(Source.single)
      .runWith(new LatchSink(OperationsPerInvocation, latch))(materializer)

    awaitLatch(latch)
  }

  /** Sub-source created by instantiating the internal `SingleSource` stage directly. */
  @Benchmark
  @OperationsPerInvocation(OperationsPerInvocation)
  def internalSingleSource(): Unit = {
    val latch = new CountDownLatch(1)

    testSource
      .flatMapConcat(elem => new GraphStages.SingleSource(elem))
      .runWith(new LatchSink(OperationsPerInvocation, latch))(materializer)

    awaitLatch(latch)
  }

  /** Sub-source created from a one-element `List`. */
  @Benchmark
  @OperationsPerInvocation(OperationsPerInvocation)
  def oneElementList(): Unit = {
    val latch = new CountDownLatch(1)

    testSource
      .flatMapConcat(n => Source(n :: Nil))
      .runWith(new LatchSink(OperationsPerInvocation, latch))(materializer)

    awaitLatch(latch)
  }

  /** Baseline: an identity `map` instead of `flatMapConcat`. */
  @Benchmark
  @OperationsPerInvocation(OperationsPerInvocation)
  def mapBaseline(): Unit = {
    val latch = new CountDownLatch(1)

    testSource
      .map(elem => elem)
      .runWith(new LatchSink(OperationsPerInvocation, latch))(materializer)

    awaitLatch(latch)
  }

  /**
   * Waits up to 30 seconds for `latch`; on timeout asks the running streams for a debug
   * dump and then fails the invocation with a `RuntimeException`.
   */
  private def awaitLatch(latch: CountDownLatch): Unit = {
    if (!latch.await(30, TimeUnit.SECONDS)) {
      dumpMaterializer()
      throw new RuntimeException("Latch didn't complete in time")
    }
  }

  /**
   * Best-effort diagnostics: sends `PrintDebugDump` to every child of the stream supervisor.
   * Only possible for `PhasedFusingActorMaterializer`; any other materializer is ignored so
   * that a `MatchError` cannot hide the timeout failure raised by the caller.
   */
  private def dumpMaterializer(): Unit = {
    materializer match {
      case impl: PhasedFusingActorMaterializer =>
        val probe = TestProbe()(system)
        impl.supervisor.tell(StreamSupervisor.GetChildren, probe.ref)
        val children = probe.expectMsgType[StreamSupervisor.Children].children
        children.foreach(_ ! StreamSupervisor.PrintDebugDump)
      case _ =>
        () // no supervisor to query for other materializer implementations
    }
  }
}

View file

@ -12,10 +12,15 @@ import akka.stream.testkit.scaladsl.StreamTestKit._
import scala.concurrent._ import scala.concurrent._
import scala.concurrent.duration._ import scala.concurrent.duration._
import akka.stream.impl.TraversalBuilder
import akka.stream.impl.fusing.GraphStages
import akka.stream.impl.fusing.GraphStages.SingleSource
import akka.stream.testkit.{ StreamSpec, TestPublisher } import akka.stream.testkit.{ StreamSpec, TestPublisher }
import org.scalatest.exceptions.TestFailedException import org.scalatest.exceptions.TestFailedException
import akka.stream.testkit.scaladsl.TestSink import akka.stream.testkit.scaladsl.TestSink
import akka.testkit.TestLatch import akka.testkit.TestLatch
import akka.util.OptionVal
class FlowFlattenMergeSpec extends StreamSpec { class FlowFlattenMergeSpec extends StreamSpec {
implicit val materializer = ActorMaterializer() implicit val materializer = ActorMaterializer()
@ -210,5 +215,85 @@ class FlowFlattenMergeSpec extends StreamSpec {
attributes.indexOf(Attributes.Name("inner")) < attributes.indexOf(Attributes.Name("outer")) should be(true) attributes.indexOf(Attributes.Name("inner")) < attributes.indexOf(Attributes.Name("outer")) should be(true)
} }
"work with optimized Source.single" in assertAllStagesStopped {
  // Every element is wrapped in its own Source.single; the flattened output must be
  // the same elements in the same order.
  val expected = 0 to 3
  val collected = Source(expected).flatMapConcat(Source.single).runWith(toSeq)
  collected.futureValue should ===(expected)
}
// Checks that single-element sub-sources respect downstream demand: nothing beyond
// the requested number of elements is emitted until more is requested.
"work with optimized Source.single when slow demand" in assertAllStagesStopped {
val probe = Source(0 to 4)
.flatMapConcat(Source.single)
.runWith(TestSink.probe)
// Demand for three elements only.
probe.request(3)
probe.expectNext(0)
probe.expectNext(1)
probe.expectNext(2)
// Demand is exhausted, so the remaining elements must be held back.
probe.expectNoMessage(100.millis)
// More demand than remaining elements: the rest arrives, followed by completion.
probe.request(10)
probe.expectNext(3)
probe.expectNext(4)
probe.expectComplete()
}
// Interleaves Source.single sub-sources with multi-element (and one-element range)
// sub-sources and verifies ordering and back-pressure across the boundaries between them.
"work with mix of Source.single and other sources when slow demand" in assertAllStagesStopped {
val sources: Source[Source[Int, NotUsed], NotUsed] = Source(List(
Source.single(0),
Source.single(1),
Source(2 to 4),
Source.single(5),
Source(6 to 6),
Source.single(7),
Source(8 to 10),
Source.single(11)
))
val probe =
sources
.flatMapConcat(identity)
.runWith(TestSink.probe)
// First batch ends in the middle of the 2 to 4 sub-source.
probe.request(3)
probe.expectNext(0)
probe.expectNext(1)
probe.expectNext(2)
probe.expectNoMessage(100.millis)
// One element at a time through the rest of the 2 to 4 sub-source.
probe.request(1)
probe.expectNext(3)
probe.expectNoMessage(100.millis)
probe.request(1)
probe.expectNext(4)
probe.expectNoMessage(100.millis)
// This batch spans single(5), the one-element range 6 to 6 and single(7).
probe.request(3)
probe.expectNext(5)
probe.expectNext(6)
probe.expectNext(7)
probe.expectNoMessage(100.millis)
// More demand than remaining elements: the rest arrives, followed by completion.
probe.request(10)
probe.expectNext(8)
probe.expectNext(9)
probe.expectNext(10)
probe.expectNext(11)
probe.expectComplete()
}
"find Source.single via TraversalBuilder" in assertAllStagesStopped {
  // A Source built with Source.single is recognised and exposes its element.
  TraversalBuilder.getSingleSource(Source.single("a")).get.elem should ===("a")
  // A multi-element source is not a SingleSource.
  TraversalBuilder.getSingleSource(Source(List("a", "b"))) should be(OptionVal.None)

  // A bare SingleSource stage is returned as-is.
  val singleSourceA = new SingleSource("a")
  TraversalBuilder.getSingleSource(singleSourceA) should be(OptionVal.Some(singleSourceA))

  // An async boundary or a transformed materialized value must disable the match.
  TraversalBuilder.getSingleSource(Source.single("c").async) should be(OptionVal.None)
  TraversalBuilder.getSingleSource(Source.single("d").mapMaterializedValue(_ => "Mat")) should be(OptionVal.None)
}
} }
} }

View file

@ -10,10 +10,12 @@ import akka.stream.impl.StreamLayout.AtomicModule
import akka.stream.impl.TraversalBuilder.{ AnyFunction1, AnyFunction2 } import akka.stream.impl.TraversalBuilder.{ AnyFunction1, AnyFunction2 }
import akka.stream.scaladsl.Keep import akka.stream.scaladsl.Keep
import akka.util.OptionVal import akka.util.OptionVal
import scala.language.existentials import scala.language.existentials
import scala.collection.immutable.Map.Map1 import scala.collection.immutable.Map.Map1
import akka.stream.impl.fusing.GraphStageModule
import akka.stream.impl.fusing.GraphStages.SingleSource
/** /**
* INTERNAL API * INTERNAL API
* *
@ -334,6 +336,37 @@ import scala.collection.immutable.Map.Map1
} }
slot slot
} }
/**
 * Try to find `SingleSource` or wrapped such. This is used as a
 * performance optimization in FlattenMerge and possibly other places.
 *
 * @param graph the source graph to inspect
 * @return `OptionVal.Some` with the `SingleSource` stage when `graph` is that stage itself,
 *         or a linear graph whose pending builder is an atomic `GraphStageModule` wrapping
 *         such a stage with an empty traversal so far and no async attribute;
 *         `OptionVal.None` in every other case
 */
def getSingleSource[A >: Null](graph: Graph[SourceShape[A], _]): OptionVal[SingleSource[A]] = {
  graph match {
    // The stage itself was passed in, nothing to unwrap.
    case single: SingleSource[A] @unchecked => OptionVal.Some(single)
    case _ =>
      graph.traversalBuilder match {
        case l: LinearTraversalBuilder =>
          l.pendingBuilder match {
            case OptionVal.Some(a: AtomicTraversalBuilder) =>
              a.module match {
                case m: GraphStageModule[_, _] =>
                  m.stage match {
                    case single: SingleSource[A] @unchecked =>
                      // It would be != EmptyTraversal if mapMaterializedValue was used and then we can't optimize.
                      // An async boundary also rules out handing the element over directly.
                      if ((l.traversalSoFar eq EmptyTraversal) && !l.attributes.isAsync)
                        OptionVal.Some(single)
                      else OptionVal.None
                    case _ => OptionVal.None
                  }
                case _ => OptionVal.None
              }
            case _ => OptionVal.None
          }
        case _ => OptionVal.None
      }
  }
}
} }
/** /**

View file

@ -16,15 +16,18 @@ import akka.stream.impl.SubscriptionTimeoutException
import akka.stream.stage._ import akka.stream.stage._
import akka.stream.scaladsl._ import akka.stream.scaladsl._
import akka.stream.actor.ActorSubscriberMessage import akka.stream.actor.ActorSubscriberMessage
import akka.util.OptionVal
import scala.collection.immutable import scala.collection.immutable
import scala.concurrent.duration.FiniteDuration import scala.concurrent.duration.FiniteDuration
import scala.util.control.NonFatal import scala.util.control.NonFatal
import scala.annotation.tailrec import scala.annotation.tailrec
import akka.stream.impl.{ Buffer => BufferImpl }
import akka.stream.impl.{ Buffer => BufferImpl }
import scala.collection.JavaConverters._ import scala.collection.JavaConverters._
import akka.stream.impl.TraversalBuilder
import akka.stream.impl.fusing.GraphStages.SingleSource
/** /**
* INTERNAL API * INTERNAL API
*/ */
@ -37,17 +40,25 @@ import scala.collection.JavaConverters._
override def createLogic(enclosingAttributes: Attributes) = new GraphStageLogic(shape) { override def createLogic(enclosingAttributes: Attributes) = new GraphStageLogic(shape) {
var sources = Set.empty[SubSinkInlet[T]] var sources = Set.empty[SubSinkInlet[T]]
def activeSources = sources.size var pendingSingleSources = 0
def activeSources = sources.size + pendingSingleSources
var q: BufferImpl[SubSinkInlet[T]] = _ // To be able to optimize for SingleSource without materializing them the queue may hold either
// SubSinkInlet[T] or SingleSource
var queue: BufferImpl[AnyRef] = _
override def preStart(): Unit = q = BufferImpl(breadth, materializer) override def preStart(): Unit = queue = BufferImpl(breadth, materializer)
def pushOut(): Unit = { def pushOut(): Unit = {
val src = q.dequeue() queue.dequeue() match {
case src: SubSinkInlet[T] @unchecked
push(out, src.grab()) push(out, src.grab())
if (!src.isClosed) src.pull() if (!src.isClosed) src.pull()
else removeSource(src) else removeSource(src)
case single: SingleSource[T] @unchecked
push(out, single.elem)
removeSource(single)
}
} }
setHandler(in, new InHandler { setHandler(in, new InHandler {
@ -68,10 +79,21 @@ import scala.collection.JavaConverters._
val outHandler = new OutHandler { val outHandler = new OutHandler {
// could be unavailable due to async input having been executed before this notification // could be unavailable due to async input having been executed before this notification
override def onPull(): Unit = if (q.nonEmpty && isAvailable(out)) pushOut() override def onPull(): Unit = if (queue.nonEmpty && isAvailable(out)) pushOut()
} }
def addSource(source: Graph[SourceShape[T], M]): Unit = { def addSource(source: Graph[SourceShape[T], M]): Unit = {
// If it's a SingleSource or wrapped such we can push the element directly instead of materializing it.
// Have to use AnyRef because of OptionVal null value.
TraversalBuilder.getSingleSource(source.asInstanceOf[Graph[SourceShape[AnyRef], M]]) match {
case OptionVal.Some(single)
if (isAvailable(out) && queue.isEmpty) {
push(out, single.elem.asInstanceOf[T])
} else {
queue.enqueue(single)
pendingSingleSources += 1
}
case _
val sinkIn = new SubSinkInlet[T]("FlattenMergeSink") val sinkIn = new SubSinkInlet[T]("FlattenMergeSink")
sinkIn.setHandler(new InHandler { sinkIn.setHandler(new InHandler {
override def onPush(): Unit = { override def onPush(): Unit = {
@ -79,7 +101,7 @@ import scala.collection.JavaConverters._
push(out, sinkIn.grab()) push(out, sinkIn.grab())
sinkIn.pull() sinkIn.pull()
} else { } else {
q.enqueue(sinkIn) queue.enqueue(sinkIn)
} }
} }
override def onUpstreamFinish(): Unit = if (!sinkIn.isAvailable) removeSource(sinkIn) override def onUpstreamFinish(): Unit = if (!sinkIn.isAvailable) removeSource(sinkIn)
@ -89,10 +111,16 @@ import scala.collection.JavaConverters._
val graph = Source.fromGraph(source).to(sinkIn.sink) val graph = Source.fromGraph(source).to(sinkIn.sink)
interpreter.subFusingMaterializer.materialize(graph, defaultAttributes = enclosingAttributes) interpreter.subFusingMaterializer.materialize(graph, defaultAttributes = enclosingAttributes)
} }
}
def removeSource(src: SubSinkInlet[T]): Unit = { def removeSource(src: AnyRef): Unit = {
val pullSuppressed = activeSources == breadth val pullSuppressed = activeSources == breadth
sources -= src src match {
case sub: SubSinkInlet[T] @unchecked
sources -= sub
case _: SingleSource[_]
pendingSingleSources -= 1
}
if (pullSuppressed) tryPull(in) if (pullSuppressed) tryPull(in)
if (activeSources == 0 && isClosed(in)) completeStage() if (activeSources == 0 && isClosed(in)) completeStage()
} }

View file

@ -17,16 +17,17 @@ import akka.stream.{ Outlet, SourceShape, _ }
import akka.util.ConstantFun import akka.util.ConstantFun
import akka.{ Done, NotUsed } import akka.{ Done, NotUsed }
import org.reactivestreams.{ Publisher, Subscriber } import org.reactivestreams.{ Publisher, Subscriber }
import scala.annotation.tailrec import scala.annotation.tailrec
import scala.annotation.unchecked.uncheckedVariance import scala.annotation.unchecked.uncheckedVariance
import scala.collection.immutable import scala.collection.immutable
import scala.concurrent.duration.FiniteDuration import scala.concurrent.duration.FiniteDuration
import scala.concurrent.{ Future, Promise } import scala.concurrent.{ Future, Promise }
import akka.stream.stage.GraphStageWithMaterializedValue
import akka.stream.stage.GraphStageWithMaterializedValue
import scala.compat.java8.FutureConverters._ import scala.compat.java8.FutureConverters._
import akka.stream.impl.fusing.GraphStageModule
/** /**
* A `Source` is a set of stream processing steps that has one open output. It can comprise * A `Source` is a set of stream processing steps that has one open output. It can comprise
* any number of internal sources and transformations that are wired together, or it can be * any number of internal sources and transformations that are wired together, or it can be