diff --git a/bench-jmh/src/main/scala/org/apache/pekko/stream/CollectBenchmark.scala b/bench-jmh/src/main/scala/org/apache/pekko/stream/CollectBenchmark.scala new file mode 100644 index 0000000000..6d1e627042 --- /dev/null +++ b/bench-jmh/src/main/scala/org/apache/pekko/stream/CollectBenchmark.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * license agreements; and to You under the Apache License, version 2.0: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * This file is part of the Apache Pekko project, which was derived from Akka. + */ + +/* + * Copyright (C) 2009-2022 Lightbend Inc. + */ + +package org.apache.pekko.stream + +import com.typesafe.config.ConfigFactory +import org.apache.pekko +import org.apache.pekko.stream.ActorAttributes.SupervisionStrategy +import org.apache.pekko.stream.Attributes.SourceLocation +import org.apache.pekko.stream.impl.Stages.DefaultAttributes +import org.apache.pekko.stream.impl.fusing.Collect +import org.apache.pekko.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler } +import org.openjdk.jmh.annotations._ +import pekko.actor.ActorSystem +import pekko.stream.scaladsl._ + +import java.util.concurrent.TimeUnit +import scala.annotation.nowarn +import scala.concurrent._ +import scala.concurrent.duration._ +import scala.util.control.NonFatal +/** Holds the element count shared by both benchmark graphs and by their `@OperationsPerInvocation` annotations. */ +object CollectBenchmark { + final val OperationsPerInvocation = 10000000 +} +/** JMH throughput benchmark that runs one stream through the `Collect` stage and one through the local `SimpleCollect` stage. */ +@State(Scope.Benchmark) +@OutputTimeUnit(TimeUnit.SECONDS) +@BenchmarkMode(Array(Mode.Throughput)) +@nowarn("msg=deprecated") +class CollectBenchmark { + import CollectBenchmark._ + + private val config = ConfigFactory.parseString(""" + pekko.actor.default-dispatcher { + executor = "fork-join-executor" + fork-join-executor { + parallelism-factor = 1 + } + } + """) + + private implicit val system: ActorSystem = ActorSystem("CollectBenchmark", config) + /** Terminates the actor system after the benchmark, waiting at most 5 seconds for termination. */ + @TearDown + def shutdown(): Unit = { + Await.result(system.terminate(), 5.seconds) + } 
+ /** Graph that repeats `1` through the `Collect` stage and stops after `OperationsPerInvocation` elements. */ + private val newCollect = Source + .repeat(1) + .via(new Collect({ case elem => elem })) + .take(OperationsPerInvocation) + .toMat(Sink.ignore)(Keep.right) + /** Same graph shape as `newCollect`, but routed through the local `SimpleCollect` stage. */ + private val oldCollect = Source + .repeat(1) + .via(new SimpleCollect({ case elem => elem })) + .take(OperationsPerInvocation) + .toMat(Sink.ignore)(Keep.right) + /** Collect stage that tells `NotApplied` apart from a real result by pattern matching on the value returned by `applyOrElse`. */ + private class SimpleCollect[In, Out](pf: PartialFunction[In, Out]) + extends GraphStage[FlowShape[In, Out]] { + val in = Inlet[In]("SimpleCollect.in") + val out = Outlet[Out]("SimpleCollect.out") + override val shape = FlowShape(in, out) + + override def initialAttributes: Attributes = DefaultAttributes.collect and SourceLocation.forLambda(pf) + + def createLogic(inheritedAttributes: Attributes) = + new GraphStageLogic(shape) with InHandler with OutHandler { + private lazy val decider = inheritedAttributes.mandatoryAttribute[SupervisionStrategy].decider + import Collect.NotApplied + + override def onPush(): Unit = + try { + pf.applyOrElse(grab(in), NotApplied) match { + case NotApplied => pull(in) + case result: Out @unchecked => push(out, result) + case _ => throw new RuntimeException() + } + } catch { + case NonFatal(ex) => + decider(ex) match { + case Supervision.Stop => failStage(ex) + case Supervision.Resume => if (!hasBeenPulled(in)) pull(in) + case Supervision.Restart => if (!hasBeenPulled(in)) pull(in) + } + } + + override def onPull(): Unit = pull(in) + + setHandlers(in, out, this) + } + + override def toString = "SimpleCollect" + } + /** Runs the `SimpleCollect` graph once and blocks until its materialized future completes. */ + @Benchmark + @OperationsPerInvocation(OperationsPerInvocation) + def benchOldCollect(): Unit = + Await.result(oldCollect.run(), Duration.Inf) + /** Runs the `Collect` graph once and blocks until its materialized future completes. */ + @Benchmark + @OperationsPerInvocation(OperationsPerInvocation) + def benchNewCollect(): Unit = + Await.result(newCollect.run(), Duration.Inf) + +} diff --git a/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/CollectWhile.scala b/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/CollectWhile.scala index 11617e28ec..90833f7b41 100644 --- 
a/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/CollectWhile.scala +++ b/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/CollectWhile.scala @@ -45,16 +45,26 @@ private[pekko] final class CollectWhile[In, Out](pf: PartialFunction[In, Out]) e override final def onPush(): Unit = try { - pf.applyOrElse(grab(in), NotApplied) match { - case NotApplied => completeStage() - case result: Out @unchecked => push(out, result) - case _ => throw new RuntimeException() // won't happen, compiler exhaustiveness check pleaser + // 1. `applyOrElse` is faster than calling `pf.isDefinedAt` and then `pf.apply`. + // 2. Comparing references here instead of pattern matching can generate smaller and faster bytecode, + // e.g. just a simple `IF_ACMPNE`; the same trick is used in the `Collect` operator. + // If you are interested, see the associated PR for this change and the + // current implementation of `scala.collection.IterableOnceOps.collectFirst`. + val result = pf.applyOrElse(grab(in), NotApplied) + if (result.asInstanceOf[AnyRef] eq NotApplied) { + completeStage() + } else { + push(out, result.asInstanceOf[Out]) } } catch { case NonFatal(ex) => decider(ex) match { case Supervision.Stop => failStage(ex) - case _ => pull(in) + case _ => + // The !hasBeenPulled(in) check is not required here: the try block above + // never pulls `in` itself, so this pull(in) cannot be an additional + // pull on an inlet that has already been pulled. + pull(in) } } diff --git a/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/Ops.scala b/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/Ops.scala index 8fe84fcc70..4f2a8031a2 100644 --- a/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/Ops.scala +++ b/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/Ops.scala @@ -258,10 +258,16 @@ private[stream] object Collect { override def onPush(): Unit = try { - pf.applyOrElse(grab(in), NotApplied) match { - case NotApplied => pull(in) - case result: Out @unchecked => push(out, result) - case _ 
=> throw new RuntimeException() // won't happen, compiler exhaustiveness check pleaser + val result = pf.applyOrElse(grab(in), NotApplied) + // 1. `applyOrElse` is faster than calling `pf.isDefinedAt` and then `pf.apply`. + // 2. Comparing references here instead of pattern matching can generate smaller and faster bytecode, + // e.g. just a simple `IF_ACMPNE`; the same trick is used in the `CollectWhile` operator. + // If you are interested, see the associated PR for this change and the + // current implementation of `scala.collection.IterableOnceOps.collectFirst`. + if (result.asInstanceOf[AnyRef] eq Collect.NotApplied) { + pull(in) + } else { + push(out, result.asInstanceOf[Out]) } } catch { case NonFatal(ex) =>