From 6e69bc87130bb9889350d829e847ed99a0ce3ca8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johan=20Andr=C3=A9n?= Date: Tue, 9 Jul 2019 13:58:26 +0200 Subject: [PATCH] Fix stack overflow in stream converters (#27305) --- .../scala/akka/stream/scaladsl/SinkSpec.scala | 138 +------- .../akka/stream/scaladsl/SourceSpec.scala | 115 +------ .../scaladsl/StreamConvertersSpec.scala | 309 ++++++++++++++++++ .../mima-filters/2.5.x.backwards.excludes | 4 + .../main/scala/akka/stream/impl/Sinks.scala | 108 +++++- .../stream/scaladsl/StreamConverters.scala | 25 +- 6 files changed, 426 insertions(+), 273 deletions(-) create mode 100644 akka-stream-tests/src/test/scala/akka/stream/scaladsl/StreamConvertersSpec.scala diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SinkSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SinkSpec.scala index 1d821cfa80..c63990b6e5 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SinkSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SinkSpec.scala @@ -4,14 +4,7 @@ package akka.stream.scaladsl -import java.util -import java.util.function -import java.util.function.{ BiConsumer, BinaryOperator, Supplier, ToIntFunction } -import java.util.stream.Collector.Characteristics -import java.util.stream.{ Collector, Collectors } - import akka.stream._ -import akka.stream.testkit.Utils._ import akka.stream.testkit._ import akka.stream.testkit.scaladsl.TestSink import akka.testkit.DefaultTimeout @@ -19,7 +12,7 @@ import com.github.ghik.silencer.silent import org.reactivestreams.Publisher import org.scalatest.concurrent.ScalaFutures -import scala.concurrent.{ Await, Future } +import scala.concurrent.Future import scala.concurrent.duration._ @silent // tests deprecated APIs @@ -212,135 +205,6 @@ class SinkSpec extends StreamSpec with DefaultTimeout with ScalaFutures { } } - "Java collector Sink" must { - - class TestCollector( - _supplier: () => Supplier[Array[Int]], - _accumulator: () => 
BiConsumer[Array[Int], Int], - _combiner: () => BinaryOperator[Array[Int]], - _finisher: () => function.Function[Array[Int], Int]) - extends Collector[Int, Array[Int], Int] { - override def supplier(): Supplier[Array[Int]] = _supplier() - override def combiner(): BinaryOperator[Array[Int]] = _combiner() - override def finisher(): function.Function[Array[Int], Int] = _finisher() - override def accumulator(): BiConsumer[Array[Int], Int] = _accumulator() - override def characteristics(): util.Set[Characteristics] = util.Collections.emptySet() - } - - val intIdentity: ToIntFunction[Int] = new ToIntFunction[Int] { - override def applyAsInt(value: Int): Int = value - } - - def supplier(): Supplier[Array[Int]] = new Supplier[Array[Int]] { - override def get(): Array[Int] = new Array(1) - } - def accumulator(): BiConsumer[Array[Int], Int] = new BiConsumer[Array[Int], Int] { - override def accept(a: Array[Int], b: Int): Unit = a(0) = intIdentity.applyAsInt(b) - } - - def combiner(): BinaryOperator[Array[Int]] = new BinaryOperator[Array[Int]] { - override def apply(a: Array[Int], b: Array[Int]): Array[Int] = { - a(0) += b(0); a - } - } - def finisher(): function.Function[Array[Int], Int] = new function.Function[Array[Int], Int] { - override def apply(a: Array[Int]): Int = a(0) - } - - "work in the happy case" in { - Source(1 to 100) - .map(_.toString) - .runWith(StreamConverters.javaCollector(() => Collectors.joining(", "))) - .futureValue should ===((1 to 100).mkString(", ")) - } - - "work parallelly in the happy case" in { - Source(1 to 100) - .runWith(StreamConverters.javaCollectorParallelUnordered(4)(() => Collectors.summingInt[Int](intIdentity))) - .futureValue - .toInt should ===(5050) - } - - "be reusable" in { - val sink = StreamConverters.javaCollector[Int, Integer](() => Collectors.summingInt[Int](intIdentity)) - Source(1 to 4).runWith(sink).futureValue.toInt should ===(10) - - // Collector has state so it preserves all previous elements that went though - Source(4 
to 6).runWith(sink).futureValue.toInt should ===(15) - } - - "be reusable with parallel version" in { - val sink = StreamConverters.javaCollectorParallelUnordered(4)(() => Collectors.summingInt[Int](intIdentity)) - - Source(1 to 4).runWith(sink).futureValue.toInt should ===(10) - Source(4 to 6).runWith(sink).futureValue.toInt should ===(15) - } - - "fail if getting the supplier fails" in { - def failedSupplier(): Supplier[Array[Int]] = throw TE("") - val future = Source(1 to 100).runWith(StreamConverters.javaCollector(() => - new TestCollector(failedSupplier _, accumulator _, combiner _, finisher _))) - a[TE] shouldBe thrownBy { - Await.result(future, 300.millis) - } - } - - "fail if the supplier fails" in { - def failedSupplier(): Supplier[Array[Int]] = new Supplier[Array[Int]] { - override def get(): Array[Int] = throw TE("") - } - val future = Source(1 to 100).runWith(StreamConverters.javaCollector(() => - new TestCollector(failedSupplier _, accumulator _, combiner _, finisher _))) - a[TE] shouldBe thrownBy { - Await.result(future, 300.millis) - } - } - - "fail if getting the accumulator fails" in { - def failedAccumulator(): BiConsumer[Array[Int], Int] = throw TE("") - - val future = Source(1 to 100).runWith(StreamConverters.javaCollector(() => - new TestCollector(supplier _, failedAccumulator _, combiner _, finisher _))) - a[TE] shouldBe thrownBy { - Await.result(future, 300.millis) - } - } - - "fail if the accumulator fails" in { - def failedAccumulator(): BiConsumer[Array[Int], Int] = new BiConsumer[Array[Int], Int] { - override def accept(a: Array[Int], b: Int): Unit = throw TE("") - } - - val future = Source(1 to 100).runWith(StreamConverters.javaCollector(() => - new TestCollector(supplier _, failedAccumulator _, combiner _, finisher _))) - a[TE] shouldBe thrownBy { - Await.result(future, 300.millis) - } - } - - "fail if getting the finisher fails" in { - def failedFinisher(): function.Function[Array[Int], Int] = throw TE("") - - val future = Source(1 to 
100).runWith(StreamConverters.javaCollector(() => - new TestCollector(supplier _, accumulator _, combiner _, failedFinisher _))) - a[TE] shouldBe thrownBy { - Await.result(future, 300.millis) - } - } - - "fail if the finisher fails" in { - def failedFinisher(): function.Function[Array[Int], Int] = new function.Function[Array[Int], Int] { - override def apply(a: Array[Int]): Int = throw TE("") - } - val future = Source(1 to 100).runWith(StreamConverters.javaCollector(() => - new TestCollector(supplier _, accumulator _, combiner _, failedFinisher _))) - a[TE] shouldBe thrownBy { - Await.result(future, 300.millis) - } - } - - } - "The ignore sink" should { "fail its materialized value on abrupt materializer termination" in { diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala index c138827db1..debe5852da 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala @@ -7,8 +7,7 @@ package akka.stream.scaladsl import akka.testkit.DefaultTimeout import org.scalatest.time.{ Millis, Span } -import scala.concurrent.{ Await, Future } -import scala.concurrent.duration._ +import scala.concurrent.Future import akka.stream.testkit.Utils.TE import com.github.ghik.silencer.silent //#imports @@ -19,8 +18,6 @@ import akka.stream.testkit._ import akka.NotUsed import akka.testkit.EventFilter import scala.collection.immutable -import java.util -import java.util.stream.BaseStream import akka.stream.testkit.scaladsl.TestSink @@ -364,116 +361,6 @@ class SourceSpec extends StreamSpec with DefaultTimeout { } } - "Java Stream source" must { - import scala.compat.java8.FunctionConverters._ - import java.util.stream.{ IntStream, Stream } - - def javaStreamInts = - IntStream.iterate(1, { i: Int => - i + 1 - }.asJava) - - "work with Java collections" in { - val list = new 
java.util.LinkedList[Integer]() - list.add(0) - list.add(1) - list.add(2) - - StreamConverters.fromJavaStream(() => list.stream()).map(_.intValue).runWith(Sink.seq).futureValue should ===( - List(0, 1, 2)) - } - - "work with primitive streams" in { - StreamConverters - .fromJavaStream(() => IntStream.rangeClosed(1, 10)) - .map(_.intValue) - .runWith(Sink.seq) - .futureValue should ===(1 to 10) - } - - "work with an empty stream" in { - StreamConverters.fromJavaStream(() => Stream.empty[Int]()).runWith(Sink.seq).futureValue should ===(Nil) - } - - "work with an infinite stream" in { - StreamConverters.fromJavaStream(() => javaStreamInts).take(1000).runFold(0)(_ + _).futureValue should ===(500500) - } - - "work with a filtered stream" in { - StreamConverters - .fromJavaStream(() => - javaStreamInts.filter({ i: Int => - i % 2 == 0 - }.asJava)) - .take(1000) - .runFold(0)(_ + _) - .futureValue should ===(1001000) - } - - "properly report errors during iteration" in { - import akka.stream.testkit.Utils.TE - // Filtering is lazy on Java Stream - - val failyFilter: Int => Boolean = i => throw TE("failing filter") - - a[TE] must be thrownBy { - Await.result( - StreamConverters.fromJavaStream(() => javaStreamInts.filter(failyFilter.asJava)).runWith(Sink.ignore), - 3.seconds) - } - } - - "close the underlying stream when completed" in { - @volatile var closed = false - - final class EmptyStream[A] extends BaseStream[A, EmptyStream[A]] { - override def unordered(): EmptyStream[A] = this - override def sequential(): EmptyStream[A] = this - override def parallel(): EmptyStream[A] = this - override def isParallel: Boolean = false - - override def spliterator(): util.Spliterator[A] = ??? - override def onClose(closeHandler: Runnable): EmptyStream[A] = ??? - - override def iterator(): util.Iterator[A] = new util.Iterator[A] { - override def next(): A = ??? 
- override def hasNext: Boolean = false - } - - override def close(): Unit = closed = true - } - - Await.ready(StreamConverters.fromJavaStream(() => new EmptyStream[Unit]).runWith(Sink.ignore), 3.seconds) - - closed should ===(true) - } - - "close the underlying stream when failed" in { - @volatile var closed = false - - final class FailingStream[A] extends BaseStream[A, FailingStream[A]] { - override def unordered(): FailingStream[A] = this - override def sequential(): FailingStream[A] = this - override def parallel(): FailingStream[A] = this - override def isParallel: Boolean = false - - override def spliterator(): util.Spliterator[A] = ??? - override def onClose(closeHandler: Runnable): FailingStream[A] = ??? - - override def iterator(): util.Iterator[A] = new util.Iterator[A] { - override def next(): A = throw new RuntimeException("ouch") - override def hasNext: Boolean = true - } - - override def close(): Unit = closed = true - } - - Await.ready(StreamConverters.fromJavaStream(() => new FailingStream[Unit]).runWith(Sink.ignore), 3.seconds) - - closed should ===(true) - } - } - "Source pre-materialization" must { "materialize the source and connect it to a publisher" in { diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/StreamConvertersSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/StreamConvertersSpec.scala new file mode 100644 index 0000000000..22605f1446 --- /dev/null +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/StreamConvertersSpec.scala @@ -0,0 +1,309 @@ +/* + * Copyright (C) 2009-2019 Lightbend Inc. 
+ */ + +package akka.stream.scaladsl + +import java.util +import java.util.function.{ BiConsumer, BinaryOperator, Supplier, ToIntFunction } +import java.util.stream.Collector.Characteristics +import java.util.stream.{ BaseStream, Collector, Collectors } + +import akka.stream.ActorMaterializer +import akka.stream.testkit.StreamSpec +import akka.stream.testkit.Utils.TE +import akka.testkit.DefaultTimeout +import org.scalatest.time.{ Millis, Span } + +import scala.concurrent.Await +import scala.concurrent.duration._ + +class StreamConvertersSpec extends StreamSpec with DefaultTimeout { + + implicit val materializer = ActorMaterializer() + implicit val config = PatienceConfig(timeout = Span(timeout.duration.toMillis, Millis)) + + "Java Stream source" must { + import scala.compat.java8.FunctionConverters._ + import java.util.stream.{ IntStream, Stream } + + def javaStreamInts = + IntStream.iterate(1, { i: Int => + i + 1 + }.asJava) + + "work with Java collections" in { + val list = new java.util.LinkedList[Integer]() + list.add(0) + list.add(1) + list.add(2) + + StreamConverters.fromJavaStream(() => list.stream()).map(_.intValue).runWith(Sink.seq).futureValue should ===( + List(0, 1, 2)) + } + + "work with primitive streams" in { + StreamConverters + .fromJavaStream(() => IntStream.rangeClosed(1, 10)) + .map(_.intValue) + .runWith(Sink.seq) + .futureValue should ===(1 to 10) + } + + "work with an empty stream" in { + StreamConverters.fromJavaStream(() => Stream.empty[Int]()).runWith(Sink.seq).futureValue should ===(Nil) + } + + "work with an infinite stream" in { + StreamConverters.fromJavaStream(() => javaStreamInts).take(1000).runFold(0)(_ + _).futureValue should ===(500500) + } + + "work with a filtered stream" in { + StreamConverters + .fromJavaStream(() => + javaStreamInts.filter({ i: Int => + i % 2 == 0 + }.asJava)) + .take(1000) + .runFold(0)(_ + _) + .futureValue should ===(1001000) + } + + "properly report errors during iteration" in { + import 
akka.stream.testkit.Utils.TE + // Filtering is lazy on Java Stream + + val failyFilter: Int => Boolean = _ => throw TE("failing filter") + + a[TE] must be thrownBy { + Await.result( + StreamConverters.fromJavaStream(() => javaStreamInts.filter(failyFilter.asJava)).runWith(Sink.ignore), + 3.seconds) + } + } + + "close the underlying stream when completed" in { + @volatile var closed = false + + final class EmptyStream[A] extends BaseStream[A, EmptyStream[A]] { + override def unordered(): EmptyStream[A] = this + override def sequential(): EmptyStream[A] = this + override def parallel(): EmptyStream[A] = this + override def isParallel: Boolean = false + + override def spliterator(): util.Spliterator[A] = ??? + override def onClose(closeHandler: Runnable): EmptyStream[A] = ??? + + override def iterator(): util.Iterator[A] = new util.Iterator[A] { + override def next(): A = ??? + override def hasNext: Boolean = false + } + + override def close(): Unit = closed = true + } + + Await.ready(StreamConverters.fromJavaStream(() => new EmptyStream[Unit]).runWith(Sink.ignore), 3.seconds) + + closed should ===(true) + } + + "close the underlying stream when failed" in { + @volatile var closed = false + + final class FailingStream[A] extends BaseStream[A, FailingStream[A]] { + override def unordered(): FailingStream[A] = this + override def sequential(): FailingStream[A] = this + override def parallel(): FailingStream[A] = this + override def isParallel: Boolean = false + + override def spliterator(): util.Spliterator[A] = ??? + override def onClose(closeHandler: Runnable): FailingStream[A] = ??? 
+ + override def iterator(): util.Iterator[A] = new util.Iterator[A] { + override def next(): A = throw new TE("ouch") + override def hasNext: Boolean = true + } + + override def close(): Unit = closed = true + } + + Await.ready(StreamConverters.fromJavaStream(() => new FailingStream[Unit]).runWith(Sink.ignore), 3.seconds) + + closed should ===(true) + } + + // Reproducer for #24304 + "not throw stack overflow with a large source" in { + Source + .repeat(Integer.valueOf(1)) + .take(100000) + .runWith(StreamConverters.javaCollector[Integer, Integer] { () => + Collectors.summingInt(new ToIntFunction[Integer] { + def applyAsInt(value: Integer): Int = value + }) + }) + .futureValue + } + + "not share collector across materializations" in { + val stream = Source + .repeat(1) + .take(10) + .toMat(StreamConverters.javaCollector[Int, Integer] { () => + Collectors.summingInt(new ToIntFunction[Int] { + def applyAsInt(value: Int): Int = value + }) + })(Keep.right) + stream.run().futureValue should ===(Integer.valueOf(10)) + stream.run().futureValue should ===(Integer.valueOf(10)) + } + + } + + "Java collector Sink" must { + + class TestCollector( + _supplier: () => Supplier[Array[Int]], + _accumulator: () => BiConsumer[Array[Int], Int], + _combiner: () => BinaryOperator[Array[Int]], + _finisher: () => java.util.function.Function[Array[Int], Int]) + extends Collector[Int, Array[Int], Int] { + override def supplier(): Supplier[Array[Int]] = _supplier() + override def combiner(): BinaryOperator[Array[Int]] = _combiner() + override def finisher(): java.util.function.Function[Array[Int], Int] = _finisher() + override def accumulator(): BiConsumer[Array[Int], Int] = _accumulator() + override def characteristics(): util.Set[Characteristics] = util.Collections.emptySet() + } + + val intIdentity: ToIntFunction[Int] = new ToIntFunction[Int] { + override def applyAsInt(value: Int): Int = value + } + + def supplier(): Supplier[Array[Int]] = new Supplier[Array[Int]] { + override def get():
Array[Int] = new Array(1) + } + def accumulator(): BiConsumer[Array[Int], Int] = new BiConsumer[Array[Int], Int] { + override def accept(a: Array[Int], b: Int): Unit = a(0) = intIdentity.applyAsInt(b) + } + + def combiner(): BinaryOperator[Array[Int]] = new BinaryOperator[Array[Int]] { + override def apply(a: Array[Int], b: Array[Int]): Array[Int] = { + a(0) += b(0); a + } + } + def finisher(): java.util.function.Function[Array[Int], Int] = new java.util.function.Function[Array[Int], Int] { + override def apply(a: Array[Int]): Int = a(0) + } + + "work in the happy case" in { + Source(1 to 100) + .map(_.toString) + .runWith(StreamConverters.javaCollector(() => Collectors.joining(", "))) + .futureValue should ===((1 to 100).mkString(", ")) + } + + "work with an empty source" in { + Source + .empty[Int] + .map(_.toString) + .runWith(StreamConverters.javaCollector(() => Collectors.joining(", "))) + .futureValue should ===("") + } + + "work parallelly in the happy case" in { + Source(1 to 100) + .runWith(StreamConverters.javaCollectorParallelUnordered(4)(() => Collectors.summingInt[Int](intIdentity))) + .futureValue + .toInt should ===(5050) + } + + "work parallelly with an empty source" in { + Source + .empty[Int] + .map(_.toString) + .runWith(StreamConverters.javaCollectorParallelUnordered(4)(() => Collectors.joining(", "))) + .futureValue should ===("") + } + + "be reusable" in { + val sink = StreamConverters.javaCollector[Int, Integer](() => Collectors.summingInt[Int](intIdentity)) + Source(1 to 4).runWith(sink).futureValue.toInt should ===(10) + + // A Collector is stateful (it retains every element that went through it), so the second run must get a fresh one + Source(4 to 6).runWith(sink).futureValue.toInt should ===(15) + } + + "be reusable with parallel version" in { + val sink = StreamConverters.javaCollectorParallelUnordered(4)(() => Collectors.summingInt[Int](intIdentity)) + + Source(1 to 4).runWith(sink).futureValue.toInt should ===(10) + Source(4 to 6).runWith(sink).futureValue.toInt should
===(15) + } + + "fail if getting the supplier fails" in { + def failedSupplier(): Supplier[Array[Int]] = throw TE("") + val future = Source(1 to 100).runWith(StreamConverters.javaCollector(() => + new TestCollector(failedSupplier _, accumulator _, combiner _, finisher _))) + a[TE] shouldBe thrownBy { + Await.result(future, 300.millis) + } + } + + "fail if the supplier fails" in { + def failedSupplier(): Supplier[Array[Int]] = new Supplier[Array[Int]] { + override def get(): Array[Int] = throw TE("") + } + val future = Source(1 to 100).runWith(StreamConverters.javaCollector(() => + new TestCollector(failedSupplier _, accumulator _, combiner _, finisher _))) + a[TE] shouldBe thrownBy { + Await.result(future, 300.millis) + } + } + + "fail if getting the accumulator fails" in { + def failedAccumulator(): BiConsumer[Array[Int], Int] = throw TE("") + + val future = Source(1 to 100).runWith(StreamConverters.javaCollector(() => + new TestCollector(supplier _, failedAccumulator _, combiner _, finisher _))) + a[TE] shouldBe thrownBy { + Await.result(future, 300.millis) + } + } + + "fail if the accumulator fails" in { + def failedAccumulator(): BiConsumer[Array[Int], Int] = new BiConsumer[Array[Int], Int] { + override def accept(a: Array[Int], b: Int): Unit = throw TE("") + } + + val future = Source(1 to 100).runWith(StreamConverters.javaCollector(() => + new TestCollector(supplier _, failedAccumulator _, combiner _, finisher _))) + a[TE] shouldBe thrownBy { + Await.result(future, 300.millis) + } + } + + "fail if getting the finisher fails" in { + def failedFinisher(): java.util.function.Function[Array[Int], Int] = throw TE("") + + val future = Source(1 to 100).runWith(StreamConverters.javaCollector(() => + new TestCollector(supplier _, accumulator _, combiner _, failedFinisher _))) + a[TE] shouldBe thrownBy { + Await.result(future, 300.millis) + } + } + + "fail if the finisher fails" in { + def failedFinisher(): java.util.function.Function[Array[Int], Int] = + new 
java.util.function.Function[Array[Int], Int] { + override def apply(a: Array[Int]): Int = throw TE("") + } + val future = Source(1 to 100).runWith(StreamConverters.javaCollector(() => + new TestCollector(supplier _, accumulator _, combiner _, failedFinisher _))) + a[TE] shouldBe thrownBy { + Await.result(future, 300.millis) + } + } + + } + +} diff --git a/akka-stream/src/main/mima-filters/2.5.x.backwards.excludes b/akka-stream/src/main/mima-filters/2.5.x.backwards.excludes index f153eaf41e..bcdfaa91ce 100644 --- a/akka-stream/src/main/mima-filters/2.5.x.backwards.excludes +++ b/akka-stream/src/main/mima-filters/2.5.x.backwards.excludes @@ -140,3 +140,7 @@ ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.io.InputStreamPubl # #19980 subscription timeouts for streams ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.stream.impl.ActorProcessorImpl.subTimeoutHandling") ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.impl.FanoutOutputs.this") + +# #24304 stack overflow in StreamConverters +ProblemFilters.exclude[IncompatibleTemplateDefProblem]("akka.stream.impl.CollectorState") +ProblemFilters.exclude[IncompatibleTemplateDefProblem]("akka.stream.impl.ReducerState") \ No newline at end of file diff --git a/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala b/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala index fdca1434f9..d6a9f292d7 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala @@ -4,6 +4,8 @@ package akka.stream.impl +import java.util.function.BinaryOperator + import scala.annotation.unchecked.uncheckedVariance import scala.collection.immutable import scala.collection.mutable @@ -13,7 +15,6 @@ import scala.util.Failure import scala.util.Success import scala.util.Try import scala.util.control.NonFatal - import akka.NotUsed import akka.actor.ActorRef import akka.actor.Props @@ -428,21 +429,81 @@ import org.reactivestreams.Subscriber } } 
+/** + * INTERNAL API + * + * Helper class to be able to express collection as a fold using mutable data without + * accidentally sharing state between materializations + */ +@InternalApi private[akka] trait CollectorState[T, R] { + def accumulated(): Any + def update(elem: T): CollectorState[T, R] + def finish(): R +} + /** * INTERNAL API * * Helper class to be able to express collection as a fold using mutable data */ -@InternalApi private[akka] final class CollectorState[T, R](val collector: java.util.stream.Collector[T, Any, R]) { - lazy val accumulated = collector.supplier().get() - private lazy val accumulator = collector.accumulator() +@InternalApi private[akka] final class FirstCollectorState[T, R]( + collectorFactory: () => java.util.stream.Collector[T, Any, R]) + extends CollectorState[T, R] { - def update(elem: T): CollectorState[T, R] = { + override def update(elem: T): CollectorState[T, R] = { + // on first update, return a new mutable collector to ensure not + // sharing collector between streams + val collector = collectorFactory() + val accumulator = collector.accumulator() + val accumulated = collector.supplier().get() + accumulator.accept(accumulated, elem) + new MutableCollectorState(collector, accumulator, accumulated) + } + + override def accumulated(): Any = { + // only called if it is asked about accumulated before accepting a first element + val collector = collectorFactory() + collector.supplier().get() + } + + override def finish(): R = { + // only called if completed without elements + val collector = collectorFactory() + collector.finisher().apply(collector.supplier().get()) + } +} + +/** + * INTERNAL API + * + * Helper class to be able to express collection as a fold using mutable data + */ +@InternalApi private[akka] final class MutableCollectorState[T, R]( + collector: java.util.stream.Collector[T, Any, R], + accumulator: java.util.function.BiConsumer[Any, T], + val accumulated: Any) + extends CollectorState[T, R] { + + override def 
update(elem: T): CollectorState[T, R] = { accumulator.accept(accumulated, elem) this } - def finish(): R = collector.finisher().apply(accumulated) + override def finish(): R = { + // called on completion after at least one element has been accumulated + collector.finisher().apply(accumulated) + } +} + +/** + * INTERNAL API + * + * Helper class to be able to express reduce as a fold for parallel collector without + * accidentally sharing state between materializations + */ +@InternalApi private[akka] trait ReducerState[T, R] { + def update(batch: Any): ReducerState[T, R] + def finish(): R } /** @@ -450,13 +511,38 @@ import org.reactivestreams.Subscriber * * Helper class to be able to express reduce as a fold for parallel collector */ -@InternalApi private[akka] final class ReducerState[T, R](val collector: java.util.stream.Collector[T, Any, R]) { - private var reduced: Any = null.asInstanceOf[Any] - private lazy val combiner = collector.combiner() +@InternalApi private[akka] final class FirstReducerState[T, R]( + collectorFactory: () => java.util.stream.Collector[T, Any, R]) + extends ReducerState[T, R] { def update(batch: Any): ReducerState[T, R] = { - if (reduced == null) reduced = batch - else reduced = combiner(reduced, batch) + // on first update, return a new mutable collector to ensure not + // sharing collector between streams + val collector = collectorFactory() + val combiner = collector.combiner() + new MutableReducerState(collector, combiner, batch) + } + + def finish(): R = { + // only called if completed without elements + val collector = collectorFactory() + collector.finisher().apply(null) + } +} + +/** + * INTERNAL API + * + * Helper class to be able to express reduce as a fold for parallel collector + */ +@InternalApi private[akka] final class MutableReducerState[T, R]( + val collector: java.util.stream.Collector[T, Any, R], + val combiner: BinaryOperator[Any], + var reduced: Any) + extends ReducerState[T, R] { + + def update(batch: Any): ReducerState[T, R] = { + reduced =
combiner(reduced, batch) this } diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/StreamConverters.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/StreamConverters.scala index a20ac7917f..5091eeb3e7 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/StreamConverters.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/StreamConverters.scala @@ -111,11 +111,14 @@ object StreamConverters { */ def javaCollector[T, R](collectorFactory: () => java.util.stream.Collector[T, _ <: Any, R]): Sink[T, Future[R]] = Flow[T] - .fold(() => new CollectorState[T, R](collectorFactory().asInstanceOf[Collector[T, Any, R]])) { - (state, elem) => () => - state().update(elem) + .fold { + new FirstCollectorState[T, R](collectorFactory.asInstanceOf[() => java.util.stream.Collector[T, Any, R]]): CollectorState[ + T, + R] + } { (state, elem) => + state.update(elem) } - .map(state => state().finish()) + .map(state => state.finish()) .toMat(Sink.head)(Keep.right) .withAttributes(DefaultAttributes.javaCollector) @@ -136,14 +139,14 @@ object StreamConverters { Sink .fromGraph(GraphDSL.create(Sink.head[R]) { implicit b => sink => import GraphDSL.Implicits._ - val collector = collectorFactory().asInstanceOf[Collector[T, Any, R]] + val factory = collectorFactory.asInstanceOf[() => Collector[T, Any, R]] val balance = b.add(Balance[T](parallelism)) - val merge = b.add(Merge[() => CollectorState[T, R]](parallelism)) + val merge = b.add(Merge[CollectorState[T, R]](parallelism)) for (i <- 0 until parallelism) { val worker = Flow[T] - .fold(() => new CollectorState(collector)) { (state, elem) => () => - state().update(elem) + .fold(new FirstCollectorState(factory): CollectorState[T, R]) { (state, elem) => + state.update(elem) } .async @@ -151,10 +154,10 @@ object StreamConverters { } merge.out - .fold(() => new ReducerState(collector)) { (state, elem) => () => - state().update(elem().accumulated) + .fold(new FirstReducerState(factory): ReducerState[T, R]) { (state, 
elem) => + state.update(elem.accumulated()) } - .map(state => state().finish()) ~> sink.in + .map(state => state.finish()) ~> sink.in SinkShape(balance.in) })