Fix stack overflow in stream converters (#27305)

This commit is contained in:
Johan Andrén 2019-07-09 13:58:26 +02:00 committed by GitHub
parent 0037998bfb
commit 6e69bc8713
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 426 additions and 273 deletions

View file

@ -4,14 +4,7 @@
package akka.stream.scaladsl package akka.stream.scaladsl
import java.util
import java.util.function
import java.util.function.{ BiConsumer, BinaryOperator, Supplier, ToIntFunction }
import java.util.stream.Collector.Characteristics
import java.util.stream.{ Collector, Collectors }
import akka.stream._ import akka.stream._
import akka.stream.testkit.Utils._
import akka.stream.testkit._ import akka.stream.testkit._
import akka.stream.testkit.scaladsl.TestSink import akka.stream.testkit.scaladsl.TestSink
import akka.testkit.DefaultTimeout import akka.testkit.DefaultTimeout
@ -19,7 +12,7 @@ import com.github.ghik.silencer.silent
import org.reactivestreams.Publisher import org.reactivestreams.Publisher
import org.scalatest.concurrent.ScalaFutures import org.scalatest.concurrent.ScalaFutures
import scala.concurrent.{ Await, Future } import scala.concurrent.Future
import scala.concurrent.duration._ import scala.concurrent.duration._
@silent // tests deprecated APIs @silent // tests deprecated APIs
@ -212,135 +205,6 @@ class SinkSpec extends StreamSpec with DefaultTimeout with ScalaFutures {
} }
} }
"Java collector Sink" must {
class TestCollector(
_supplier: () => Supplier[Array[Int]],
_accumulator: () => BiConsumer[Array[Int], Int],
_combiner: () => BinaryOperator[Array[Int]],
_finisher: () => function.Function[Array[Int], Int])
extends Collector[Int, Array[Int], Int] {
override def supplier(): Supplier[Array[Int]] = _supplier()
override def combiner(): BinaryOperator[Array[Int]] = _combiner()
override def finisher(): function.Function[Array[Int], Int] = _finisher()
override def accumulator(): BiConsumer[Array[Int], Int] = _accumulator()
override def characteristics(): util.Set[Characteristics] = util.Collections.emptySet()
}
val intIdentity: ToIntFunction[Int] = new ToIntFunction[Int] {
override def applyAsInt(value: Int): Int = value
}
def supplier(): Supplier[Array[Int]] = new Supplier[Array[Int]] {
override def get(): Array[Int] = new Array(1)
}
def accumulator(): BiConsumer[Array[Int], Int] = new BiConsumer[Array[Int], Int] {
override def accept(a: Array[Int], b: Int): Unit = a(0) = intIdentity.applyAsInt(b)
}
def combiner(): BinaryOperator[Array[Int]] = new BinaryOperator[Array[Int]] {
override def apply(a: Array[Int], b: Array[Int]): Array[Int] = {
a(0) += b(0); a
}
}
def finisher(): function.Function[Array[Int], Int] = new function.Function[Array[Int], Int] {
override def apply(a: Array[Int]): Int = a(0)
}
"work in the happy case" in {
Source(1 to 100)
.map(_.toString)
.runWith(StreamConverters.javaCollector(() => Collectors.joining(", ")))
.futureValue should ===((1 to 100).mkString(", "))
}
"work parallelly in the happy case" in {
Source(1 to 100)
.runWith(StreamConverters.javaCollectorParallelUnordered(4)(() => Collectors.summingInt[Int](intIdentity)))
.futureValue
.toInt should ===(5050)
}
"be reusable" in {
val sink = StreamConverters.javaCollector[Int, Integer](() => Collectors.summingInt[Int](intIdentity))
Source(1 to 4).runWith(sink).futureValue.toInt should ===(10)
// Collector has state so it preserves all previous elements that went though
Source(4 to 6).runWith(sink).futureValue.toInt should ===(15)
}
"be reusable with parallel version" in {
val sink = StreamConverters.javaCollectorParallelUnordered(4)(() => Collectors.summingInt[Int](intIdentity))
Source(1 to 4).runWith(sink).futureValue.toInt should ===(10)
Source(4 to 6).runWith(sink).futureValue.toInt should ===(15)
}
"fail if getting the supplier fails" in {
def failedSupplier(): Supplier[Array[Int]] = throw TE("")
val future = Source(1 to 100).runWith(StreamConverters.javaCollector(() =>
new TestCollector(failedSupplier _, accumulator _, combiner _, finisher _)))
a[TE] shouldBe thrownBy {
Await.result(future, 300.millis)
}
}
"fail if the supplier fails" in {
def failedSupplier(): Supplier[Array[Int]] = new Supplier[Array[Int]] {
override def get(): Array[Int] = throw TE("")
}
val future = Source(1 to 100).runWith(StreamConverters.javaCollector(() =>
new TestCollector(failedSupplier _, accumulator _, combiner _, finisher _)))
a[TE] shouldBe thrownBy {
Await.result(future, 300.millis)
}
}
"fail if getting the accumulator fails" in {
def failedAccumulator(): BiConsumer[Array[Int], Int] = throw TE("")
val future = Source(1 to 100).runWith(StreamConverters.javaCollector(() =>
new TestCollector(supplier _, failedAccumulator _, combiner _, finisher _)))
a[TE] shouldBe thrownBy {
Await.result(future, 300.millis)
}
}
"fail if the accumulator fails" in {
def failedAccumulator(): BiConsumer[Array[Int], Int] = new BiConsumer[Array[Int], Int] {
override def accept(a: Array[Int], b: Int): Unit = throw TE("")
}
val future = Source(1 to 100).runWith(StreamConverters.javaCollector(() =>
new TestCollector(supplier _, failedAccumulator _, combiner _, finisher _)))
a[TE] shouldBe thrownBy {
Await.result(future, 300.millis)
}
}
"fail if getting the finisher fails" in {
def failedFinisher(): function.Function[Array[Int], Int] = throw TE("")
val future = Source(1 to 100).runWith(StreamConverters.javaCollector(() =>
new TestCollector(supplier _, accumulator _, combiner _, failedFinisher _)))
a[TE] shouldBe thrownBy {
Await.result(future, 300.millis)
}
}
"fail if the finisher fails" in {
def failedFinisher(): function.Function[Array[Int], Int] = new function.Function[Array[Int], Int] {
override def apply(a: Array[Int]): Int = throw TE("")
}
val future = Source(1 to 100).runWith(StreamConverters.javaCollector(() =>
new TestCollector(supplier _, accumulator _, combiner _, failedFinisher _)))
a[TE] shouldBe thrownBy {
Await.result(future, 300.millis)
}
}
}
"The ignore sink" should { "The ignore sink" should {
"fail its materialized value on abrupt materializer termination" in { "fail its materialized value on abrupt materializer termination" in {

View file

@ -7,8 +7,7 @@ package akka.stream.scaladsl
import akka.testkit.DefaultTimeout import akka.testkit.DefaultTimeout
import org.scalatest.time.{ Millis, Span } import org.scalatest.time.{ Millis, Span }
import scala.concurrent.{ Await, Future } import scala.concurrent.Future
import scala.concurrent.duration._
import akka.stream.testkit.Utils.TE import akka.stream.testkit.Utils.TE
import com.github.ghik.silencer.silent import com.github.ghik.silencer.silent
//#imports //#imports
@ -19,8 +18,6 @@ import akka.stream.testkit._
import akka.NotUsed import akka.NotUsed
import akka.testkit.EventFilter import akka.testkit.EventFilter
import scala.collection.immutable import scala.collection.immutable
import java.util
import java.util.stream.BaseStream
import akka.stream.testkit.scaladsl.TestSink import akka.stream.testkit.scaladsl.TestSink
@ -364,116 +361,6 @@ class SourceSpec extends StreamSpec with DefaultTimeout {
} }
} }
"Java Stream source" must {
import scala.compat.java8.FunctionConverters._
import java.util.stream.{ IntStream, Stream }
def javaStreamInts =
IntStream.iterate(1, { i: Int =>
i + 1
}.asJava)
"work with Java collections" in {
val list = new java.util.LinkedList[Integer]()
list.add(0)
list.add(1)
list.add(2)
StreamConverters.fromJavaStream(() => list.stream()).map(_.intValue).runWith(Sink.seq).futureValue should ===(
List(0, 1, 2))
}
"work with primitive streams" in {
StreamConverters
.fromJavaStream(() => IntStream.rangeClosed(1, 10))
.map(_.intValue)
.runWith(Sink.seq)
.futureValue should ===(1 to 10)
}
"work with an empty stream" in {
StreamConverters.fromJavaStream(() => Stream.empty[Int]()).runWith(Sink.seq).futureValue should ===(Nil)
}
"work with an infinite stream" in {
StreamConverters.fromJavaStream(() => javaStreamInts).take(1000).runFold(0)(_ + _).futureValue should ===(500500)
}
"work with a filtered stream" in {
StreamConverters
.fromJavaStream(() =>
javaStreamInts.filter({ i: Int =>
i % 2 == 0
}.asJava))
.take(1000)
.runFold(0)(_ + _)
.futureValue should ===(1001000)
}
"properly report errors during iteration" in {
import akka.stream.testkit.Utils.TE
// Filtering is lazy on Java Stream
val failyFilter: Int => Boolean = i => throw TE("failing filter")
a[TE] must be thrownBy {
Await.result(
StreamConverters.fromJavaStream(() => javaStreamInts.filter(failyFilter.asJava)).runWith(Sink.ignore),
3.seconds)
}
}
"close the underlying stream when completed" in {
@volatile var closed = false
final class EmptyStream[A] extends BaseStream[A, EmptyStream[A]] {
override def unordered(): EmptyStream[A] = this
override def sequential(): EmptyStream[A] = this
override def parallel(): EmptyStream[A] = this
override def isParallel: Boolean = false
override def spliterator(): util.Spliterator[A] = ???
override def onClose(closeHandler: Runnable): EmptyStream[A] = ???
override def iterator(): util.Iterator[A] = new util.Iterator[A] {
override def next(): A = ???
override def hasNext: Boolean = false
}
override def close(): Unit = closed = true
}
Await.ready(StreamConverters.fromJavaStream(() => new EmptyStream[Unit]).runWith(Sink.ignore), 3.seconds)
closed should ===(true)
}
"close the underlying stream when failed" in {
@volatile var closed = false
final class FailingStream[A] extends BaseStream[A, FailingStream[A]] {
override def unordered(): FailingStream[A] = this
override def sequential(): FailingStream[A] = this
override def parallel(): FailingStream[A] = this
override def isParallel: Boolean = false
override def spliterator(): util.Spliterator[A] = ???
override def onClose(closeHandler: Runnable): FailingStream[A] = ???
override def iterator(): util.Iterator[A] = new util.Iterator[A] {
override def next(): A = throw new RuntimeException("ouch")
override def hasNext: Boolean = true
}
override def close(): Unit = closed = true
}
Await.ready(StreamConverters.fromJavaStream(() => new FailingStream[Unit]).runWith(Sink.ignore), 3.seconds)
closed should ===(true)
}
}
"Source pre-materialization" must { "Source pre-materialization" must {
"materialize the source and connect it to a publisher" in { "materialize the source and connect it to a publisher" in {

View file

@ -0,0 +1,309 @@
/*
* Copyright (C) 2009-2019 Lightbend Inc. <https://www.lightbend.com>
*/
package akka.stream.scaladsl
import java.util
import java.util.function.{ BiConsumer, BinaryOperator, Supplier, ToIntFunction }
import java.util.stream.Collector.Characteristics
import java.util.stream.{ BaseStream, Collector, Collectors }
import akka.stream.ActorMaterializer
import akka.stream.testkit.StreamSpec
import akka.stream.testkit.Utils.TE
import akka.testkit.DefaultTimeout
import org.scalatest.time.{ Millis, Span }
import scala.concurrent.Await
import scala.concurrent.duration._
class StreamConvertersSpec extends StreamSpec with DefaultTimeout {
implicit val materializer = ActorMaterializer()
implicit val config = PatienceConfig(timeout = Span(timeout.duration.toMillis, Millis))
"Java Stream source" must {
import scala.compat.java8.FunctionConverters._
import java.util.stream.{ IntStream, Stream }
def javaStreamInts =
IntStream.iterate(1, { i: Int =>
i + 1
}.asJava)
"work with Java collections" in {
val list = new java.util.LinkedList[Integer]()
list.add(0)
list.add(1)
list.add(2)
StreamConverters.fromJavaStream(() => list.stream()).map(_.intValue).runWith(Sink.seq).futureValue should ===(
List(0, 1, 2))
}
"work with primitive streams" in {
StreamConverters
.fromJavaStream(() => IntStream.rangeClosed(1, 10))
.map(_.intValue)
.runWith(Sink.seq)
.futureValue should ===(1 to 10)
}
"work with an empty stream" in {
StreamConverters.fromJavaStream(() => Stream.empty[Int]()).runWith(Sink.seq).futureValue should ===(Nil)
}
"work with an infinite stream" in {
StreamConverters.fromJavaStream(() => javaStreamInts).take(1000).runFold(0)(_ + _).futureValue should ===(500500)
}
"work with a filtered stream" in {
StreamConverters
.fromJavaStream(() =>
javaStreamInts.filter({ i: Int =>
i % 2 == 0
}.asJava))
.take(1000)
.runFold(0)(_ + _)
.futureValue should ===(1001000)
}
"properly report errors during iteration" in {
import akka.stream.testkit.Utils.TE
// Filtering is lazy on Java Stream
val failyFilter: Int => Boolean = _ => throw TE("failing filter")
a[TE] must be thrownBy {
Await.result(
StreamConverters.fromJavaStream(() => javaStreamInts.filter(failyFilter.asJava)).runWith(Sink.ignore),
3.seconds)
}
}
"close the underlying stream when completed" in {
@volatile var closed = false
final class EmptyStream[A] extends BaseStream[A, EmptyStream[A]] {
override def unordered(): EmptyStream[A] = this
override def sequential(): EmptyStream[A] = this
override def parallel(): EmptyStream[A] = this
override def isParallel: Boolean = false
override def spliterator(): util.Spliterator[A] = ???
override def onClose(closeHandler: Runnable): EmptyStream[A] = ???
override def iterator(): util.Iterator[A] = new util.Iterator[A] {
override def next(): A = ???
override def hasNext: Boolean = false
}
override def close(): Unit = closed = true
}
Await.ready(StreamConverters.fromJavaStream(() => new EmptyStream[Unit]).runWith(Sink.ignore), 3.seconds)
closed should ===(true)
}
"close the underlying stream when failed" in {
@volatile var closed = false
final class FailingStream[A] extends BaseStream[A, FailingStream[A]] {
override def unordered(): FailingStream[A] = this
override def sequential(): FailingStream[A] = this
override def parallel(): FailingStream[A] = this
override def isParallel: Boolean = false
override def spliterator(): util.Spliterator[A] = ???
override def onClose(closeHandler: Runnable): FailingStream[A] = ???
override def iterator(): util.Iterator[A] = new util.Iterator[A] {
override def next(): A = throw new TE("ouch")
override def hasNext: Boolean = true
}
override def close(): Unit = closed = true
}
Await.ready(StreamConverters.fromJavaStream(() => new FailingStream[Unit]).runWith(Sink.ignore), 3.seconds)
closed should ===(true)
}
// Reproducer for #24304: collecting a long stream must not overflow the stack.
// The sum is asserted as well, so a silently wrong or truncated result also fails the test.
"not throw stack overflow with a large source" in {
  val elementCount = 100000
  Source
    .repeat(Integer.valueOf(1))
    .take(elementCount)
    .runWith(StreamConverters.javaCollector[Integer, Integer] { () =>
      Collectors.summingInt(new ToIntFunction[Integer] {
        def applyAsInt(value: Integer): Int = value
      })
    })
    .futureValue should ===(Integer.valueOf(elementCount))
}
"not share collector across materializations" in {
val stream = Source
.repeat(1)
.take(10)
.toMat(StreamConverters.javaCollector[Int, Integer] { () =>
Collectors.summingInt(new ToIntFunction[Int] {
def applyAsInt(value: Int): Int = value
})
})(Keep.right)
stream.run().futureValue should ===(Integer.valueOf(10))
stream.run().futureValue should ===(Integer.valueOf(10))
}
}
"Java collector Sink" must {
/**
 * Collector whose component functions are looked up through the given thunks each time the
 * corresponding `Collector` method is called. A test can therefore make either the lookup of a
 * component, or the component itself, throw. Reports no characteristics.
 */
class TestCollector(
    _supplier: () => Supplier[Array[Int]],
    _accumulator: () => BiConsumer[Array[Int], Int],
    _combiner: () => BinaryOperator[Array[Int]],
    _finisher: () => java.util.function.Function[Array[Int], Int])
    extends Collector[Int, Array[Int], Int] {

  // every accessor evaluates its thunk again; nothing is cached
  override def supplier(): Supplier[Array[Int]] = _supplier.apply()

  override def accumulator(): BiConsumer[Array[Int], Int] = _accumulator.apply()

  override def combiner(): BinaryOperator[Array[Int]] = _combiner.apply()

  override def finisher(): java.util.function.Function[Array[Int], Int] = _finisher.apply()

  override def characteristics(): util.Set[Characteristics] =
    util.Collections.emptySet[Characteristics]()
}
val intIdentity: ToIntFunction[Int] = new ToIntFunction[Int] {
override def applyAsInt(value: Int): Int = value
}
def supplier(): Supplier[Array[Int]] = new Supplier[Array[Int]] {
override def get(): Array[Int] = new Array(1)
}
def accumulator(): BiConsumer[Array[Int], Int] = new BiConsumer[Array[Int], Int] {
override def accept(a: Array[Int], b: Int): Unit = a(0) = intIdentity.applyAsInt(b)
}
def combiner(): BinaryOperator[Array[Int]] = new BinaryOperator[Array[Int]] {
override def apply(a: Array[Int], b: Array[Int]): Array[Int] = {
a(0) += b(0); a
}
}
def finisher(): java.util.function.Function[Array[Int], Int] = new java.util.function.Function[Array[Int], Int] {
override def apply(a: Array[Int]): Int = a(0)
}
"work in the happy case" in {
Source(1 to 100)
.map(_.toString)
.runWith(StreamConverters.javaCollector(() => Collectors.joining(", ")))
.futureValue should ===((1 to 100).mkString(", "))
}
"work with an empty source" in {
Source
.empty[Int]
.map(_.toString)
.runWith(StreamConverters.javaCollector(() => Collectors.joining(", ")))
.futureValue should ===("")
}
"work parallelly in the happy case" in {
Source(1 to 100)
.runWith(StreamConverters.javaCollectorParallelUnordered(4)(() => Collectors.summingInt[Int](intIdentity)))
.futureValue
.toInt should ===(5050)
}
"work parallelly with an empty source" in {
Source
.empty[Int]
.map(_.toString)
.runWith(StreamConverters.javaCollectorParallelUnordered(4)(() => Collectors.joining(", ")))
.futureValue should ===("")
}
"be reusable" in {
val sink = StreamConverters.javaCollector[Int, Integer](() => Collectors.summingInt[Int](intIdentity))
Source(1 to 4).runWith(sink).futureValue.toInt should ===(10)
// The factory creates a fresh collector for each materialization, so the second run sums only its own elements (4 + 5 + 6)
Source(4 to 6).runWith(sink).futureValue.toInt should ===(15)
}
"be reusable with parallel version" in {
val sink = StreamConverters.javaCollectorParallelUnordered(4)(() => Collectors.summingInt[Int](intIdentity))
Source(1 to 4).runWith(sink).futureValue.toInt should ===(10)
Source(4 to 6).runWith(sink).futureValue.toInt should ===(15)
}
"fail if getting the supplier fails" in {
def failedSupplier(): Supplier[Array[Int]] = throw TE("")
val future = Source(1 to 100).runWith(StreamConverters.javaCollector(() =>
new TestCollector(failedSupplier _, accumulator _, combiner _, finisher _)))
a[TE] shouldBe thrownBy {
Await.result(future, 300.millis)
}
}
"fail if the supplier fails" in {
def failedSupplier(): Supplier[Array[Int]] = new Supplier[Array[Int]] {
override def get(): Array[Int] = throw TE("")
}
val future = Source(1 to 100).runWith(StreamConverters.javaCollector(() =>
new TestCollector(failedSupplier _, accumulator _, combiner _, finisher _)))
a[TE] shouldBe thrownBy {
Await.result(future, 300.millis)
}
}
"fail if getting the accumulator fails" in {
def failedAccumulator(): BiConsumer[Array[Int], Int] = throw TE("")
val future = Source(1 to 100).runWith(StreamConverters.javaCollector(() =>
new TestCollector(supplier _, failedAccumulator _, combiner _, finisher _)))
a[TE] shouldBe thrownBy {
Await.result(future, 300.millis)
}
}
"fail if the accumulator fails" in {
def failedAccumulator(): BiConsumer[Array[Int], Int] = new BiConsumer[Array[Int], Int] {
override def accept(a: Array[Int], b: Int): Unit = throw TE("")
}
val future = Source(1 to 100).runWith(StreamConverters.javaCollector(() =>
new TestCollector(supplier _, failedAccumulator _, combiner _, finisher _)))
a[TE] shouldBe thrownBy {
Await.result(future, 300.millis)
}
}
"fail if getting the finisher fails" in {
def failedFinisher(): java.util.function.Function[Array[Int], Int] = throw TE("")
val future = Source(1 to 100).runWith(StreamConverters.javaCollector(() =>
new TestCollector(supplier _, accumulator _, combiner _, failedFinisher _)))
a[TE] shouldBe thrownBy {
Await.result(future, 300.millis)
}
}
"fail if the finisher fails" in {
def failedFinisher(): java.util.function.Function[Array[Int], Int] =
new java.util.function.Function[Array[Int], Int] {
override def apply(a: Array[Int]): Int = throw TE("")
}
val future = Source(1 to 100).runWith(StreamConverters.javaCollector(() =>
new TestCollector(supplier _, accumulator _, combiner _, failedFinisher _)))
a[TE] shouldBe thrownBy {
Await.result(future, 300.millis)
}
}
}
}

View file

@ -140,3 +140,7 @@ ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.io.InputStreamPubl
# #19980 subscription timeouts for streams # #19980 subscription timeouts for streams
ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.stream.impl.ActorProcessorImpl.subTimeoutHandling") ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.stream.impl.ActorProcessorImpl.subTimeoutHandling")
ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.impl.FanoutOutputs.this") ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.impl.FanoutOutputs.this")
# #24304 stack overflow in StreamConverters
ProblemFilters.exclude[IncompatibleTemplateDefProblem]("akka.stream.impl.CollectorState")
ProblemFilters.exclude[IncompatibleTemplateDefProblem]("akka.stream.impl.ReducerState")

View file

@ -4,6 +4,8 @@
package akka.stream.impl package akka.stream.impl
import java.util.function.BinaryOperator
import scala.annotation.unchecked.uncheckedVariance import scala.annotation.unchecked.uncheckedVariance
import scala.collection.immutable import scala.collection.immutable
import scala.collection.mutable import scala.collection.mutable
@ -13,7 +15,6 @@ import scala.util.Failure
import scala.util.Success import scala.util.Success
import scala.util.Try import scala.util.Try
import scala.util.control.NonFatal import scala.util.control.NonFatal
import akka.NotUsed import akka.NotUsed
import akka.actor.ActorRef import akka.actor.ActorRef
import akka.actor.Props import akka.actor.Props
@ -428,21 +429,81 @@ import org.reactivestreams.Subscriber
} }
} }
/**
* INTERNAL API
*
* Helper class to be able to express collection as a fold using mutable data without
* accidentally sharing state between materializations
*/
@InternalApi private[akka] trait CollectorState[T, R] {
// The collector's intermediate result container as accumulated so far. Typed `Any` because the
// implementations work with `Collector[T, Any, R]`, i.e. the accumulation type is not tracked.
def accumulated(): Any
// Feeds `elem` into the accumulation and returns the state to use for the next element
// (the implementations return either a new state object or `this`).
def update(elem: T): CollectorState[T, R]
// Applies the collector's finisher to the accumulation and returns the final result.
def finish(): R
}
/** /**
* INTERNAL API * INTERNAL API
* *
* Helper class to be able to express collection as a fold using mutable data * Helper class to be able to express collection as a fold using mutable data
*/ */
@InternalApi private[akka] final class CollectorState[T, R](val collector: java.util.stream.Collector[T, Any, R]) { @InternalApi private[akka] final class FirstCollectorState[T, R](
lazy val accumulated = collector.supplier().get() collectorFactory: () => java.util.stream.Collector[T, Any, R])
private lazy val accumulator = collector.accumulator() extends CollectorState[T, R] {
def update(elem: T): CollectorState[T, R] = { override def update(elem: T): CollectorState[T, R] = {
// on first update, return a new mutable collector to ensure not
// sharing collector between streams
val collector = collectorFactory()
val accumulator = collector.accumulator()
val accumulated = collector.supplier().get()
accumulator.accept(accumulated, elem)
new MutableCollectorState(collector, accumulator, accumulated)
}
// Reached only when the accumulation is requested before a first element was accepted:
// hand out an empty container from a collector created just for this call.
override def accumulated(): Any =
  collectorFactory().supplier().get()
// Reached only when the stream completed without any element: run the finisher of a newly
// created collector over an empty container supplied by that same collector.
override def finish(): R = {
  val freshCollector = collectorFactory()
  // look up the finisher before asking for the container, as the single-expression form did
  val finisher = freshCollector.finisher()
  finisher.apply(freshCollector.supplier().get())
}
}
/**
* INTERNAL API
*
* Helper class to be able to express collection as a fold using mutable data
*/
@InternalApi private[akka] final class MutableCollectorState[T, R](
collector: java.util.stream.Collector[T, Any, R],
accumulator: java.util.function.BiConsumer[Any, T],
val accumulated: Any)
extends CollectorState[T, R] {
override def update(elem: T): CollectorState[T, R] = {
accumulator.accept(accumulated, elem) accumulator.accept(accumulated, elem)
this this
} }
def finish(): R = collector.finisher().apply(accumulated) override def finish(): R = {
// called on completion; this state only exists after a first element was accumulated (see FirstCollectorState.update)
collector.finisher().apply(accumulated)
}
}
/**
* INTERNAL API
*
* Helper class to be able to express reduce as a fold for parallel collector without
* accidentally sharing state between materializations
*/
@InternalApi private[akka] trait ReducerState[T, R] {
def update(batch: Any): ReducerState[T, R]
def finish(): R
} }
/** /**
@ -450,13 +511,38 @@ import org.reactivestreams.Subscriber
* *
* Helper class to be able to express reduce as a fold for parallel collector * Helper class to be able to express reduce as a fold for parallel collector
*/ */
@InternalApi private[akka] final class ReducerState[T, R](val collector: java.util.stream.Collector[T, Any, R]) { @InternalApi private[akka] final class FirstReducerState[T, R](
private var reduced: Any = null.asInstanceOf[Any] collectorFactory: () => java.util.stream.Collector[T, Any, R])
private lazy val combiner = collector.combiner() extends ReducerState[T, R] {
def update(batch: Any): ReducerState[T, R] = { def update(batch: Any): ReducerState[T, R] = {
if (reduced == null) reduced = batch // on first update, return a new mutable collector to ensure not
else reduced = combiner(reduced, batch) // sharing collector between streams
val collector = collectorFactory()
val combiner = collector.combiner()
new MutableReducerState(collector, combiner, batch)
}
/**
 * Produces the result when the stream completed before any batch was reduced.
 *
 * The finisher is applied to an empty accumulation container obtained from the collector's
 * supplier rather than to `null`. A finisher is written against the container type and will
 * usually dereference it, so `null` would surface as a `NullPointerException` instead of the
 * collector's result for no input. This matches what `FirstCollectorState.finish` does.
 */
def finish(): R = {
  // only called if completed without elements
  val collector = collectorFactory()
  collector.finisher().apply(collector.supplier().get())
}
}
/**
* INTERNAL API
*
* Helper class to be able to express reduce as a fold for parallel collector
*/
@InternalApi private[akka] final class MutableReducerState[T, R](
val collector: java.util.stream.Collector[T, Any, R],
val combiner: BinaryOperator[Any],
var reduced: Any)
extends ReducerState[T, R] {
def update(batch: Any): ReducerState[T, R] = {
reduced = combiner(reduced, batch)
this this
} }

View file

@ -111,11 +111,14 @@ object StreamConverters {
*/ */
def javaCollector[T, R](collectorFactory: () => java.util.stream.Collector[T, _ <: Any, R]): Sink[T, Future[R]] = def javaCollector[T, R](collectorFactory: () => java.util.stream.Collector[T, _ <: Any, R]): Sink[T, Future[R]] =
Flow[T] Flow[T]
.fold(() => new CollectorState[T, R](collectorFactory().asInstanceOf[Collector[T, Any, R]])) { .fold {
(state, elem) => () => new FirstCollectorState[T, R](collectorFactory.asInstanceOf[() => java.util.stream.Collector[T, Any, R]]): CollectorState[
state().update(elem) T,
R]
} { (state, elem) =>
state.update(elem)
} }
.map(state => state().finish()) .map(state => state.finish())
.toMat(Sink.head)(Keep.right) .toMat(Sink.head)(Keep.right)
.withAttributes(DefaultAttributes.javaCollector) .withAttributes(DefaultAttributes.javaCollector)
@ -136,14 +139,14 @@ object StreamConverters {
Sink Sink
.fromGraph(GraphDSL.create(Sink.head[R]) { implicit b => sink => .fromGraph(GraphDSL.create(Sink.head[R]) { implicit b => sink =>
import GraphDSL.Implicits._ import GraphDSL.Implicits._
val collector = collectorFactory().asInstanceOf[Collector[T, Any, R]] val factory = collectorFactory.asInstanceOf[() => Collector[T, Any, R]]
val balance = b.add(Balance[T](parallelism)) val balance = b.add(Balance[T](parallelism))
val merge = b.add(Merge[() => CollectorState[T, R]](parallelism)) val merge = b.add(Merge[CollectorState[T, R]](parallelism))
for (i <- 0 until parallelism) { for (i <- 0 until parallelism) {
val worker = Flow[T] val worker = Flow[T]
.fold(() => new CollectorState(collector)) { (state, elem) => () => .fold(new FirstCollectorState(factory): CollectorState[T, R]) { (state, elem) =>
state().update(elem) state.update(elem)
} }
.async .async
@ -151,10 +154,10 @@ object StreamConverters {
} }
merge.out merge.out
.fold(() => new ReducerState(collector)) { (state, elem) => () => .fold(new FirstReducerState(factory): ReducerState[T, R]) { (state, elem) =>
state().update(elem().accumulated) state.update(elem.accumulated())
} }
.map(state => state().finish()) ~> sink.in .map(state => state.finish()) ~> sink.in
SinkShape(balance.in) SinkShape(balance.in)
}) })