diff --git a/akka-stream-tests/src/test/java/akka/stream/javadsl/FlowTest.java b/akka-stream-tests/src/test/java/akka/stream/javadsl/FlowTest.java index c3ad09a65b..846d85013a 100644 --- a/akka-stream-tests/src/test/java/akka/stream/javadsl/FlowTest.java +++ b/akka-stream-tests/src/test/java/akka/stream/javadsl/FlowTest.java @@ -533,6 +533,61 @@ public class FlowTest extends StreamTest { assertEquals("ABC", result); } + @Test + public void mustBeAbleToUseBatch() throws Exception { + final JavaTestKit probe = new JavaTestKit(system); + final List input = Arrays.asList("A", "B", "C"); + final Flow flow = Flow.of(String.class).batch(3L, new Function() { + @Override + public String apply(String s) throws Exception { + return s; + } + }, new Function2() { + @Override + public String apply(String aggr, String in) throws Exception { + return aggr + in; + } + }); + Future future = Source.from(input).via(flow).runFold("", new Function2() { + @Override + public String apply(String aggr, String in) throws Exception { + return aggr + in; + } + }, materializer); + String result = Await.result(future, probe.dilated(FiniteDuration.create(3, TimeUnit.SECONDS))); + assertEquals("ABC", result); + } + + @Test + public void mustBeAbleToUseBatchWeighted() throws Exception { + final JavaTestKit probe = new JavaTestKit(system); + final List input = Arrays.asList("A", "B", "C"); + final Flow flow = Flow.of(String.class).batchWeighted(3L, new Function() { + @Override + public Object apply(String s) throws Exception { + return 1L; + } + }, new Function() { + @Override + public String apply(String s) throws Exception { + return s; + } + }, new Function2() { + @Override + public String apply(String aggr, String in) throws Exception { + return aggr + in; + } + }); + Future future = Source.from(input).via(flow).runFold("", new Function2() { + @Override + public String apply(String aggr, String in) throws Exception { + return aggr + in; + } + }, materializer); + String result = Await.result(future, probe.dilated(FiniteDuration.create(3, TimeUnit.SECONDS))); + assertEquals("ABC", result); + } + @Test public void mustBeAbleToUseExpand() throws Exception { final JavaTestKit probe = new JavaTestKit(system); diff --git a/akka-stream/src/main/scala/akka/stream/impl/ConstantFun.scala b/akka-stream/src/main/scala/akka/stream/impl/ConstantFun.scala index c06ca93c87..b60c360bad 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/ConstantFun.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/ConstantFun.scala @@ -20,4 +20,8 @@ private[akka] object ConstantFun { def javaIdentityFunction[T]: JFun[T, T] = JavaIdentityFunction.asInstanceOf[JFun[T, T]] def scalaIdentityFunction[T]: T ⇒ T = conforms + + def returnZero[T](t: T): Long = 0L + + def returnOne[T](t: T): Long = 1L } diff --git a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala index 94592b6f67..e33ee8481c 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala @@ -48,6 +48,8 @@ private[stream] object Stages { val intersperse = name("intersperse") val buffer = name("buffer") val conflate = name("conflate") + val batch = name("batch") + val batchWeighted = name("batchWeighted") val expand = name("expand") val mapConcat = name("mapConcat") val detacher = name("detacher") diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala index 3817a5abe8..ca239ad4ca 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala @@ -478,27 +478,29 @@ private[akka] final case class Conflate[In, Out](seed: In ⇒ Out, aggregate: (O override def restart(): Conflate[In, Out] = copy() } -private[akka] sealed abstract class AbstractBatch[In, Out](max: Long, costFn: In ⇒ Long, seed: In ⇒ Out, - aggregate: (Out, In) ⇒ Out, val in: Inlet[In], - val out: Outlet[Out]) extends GraphStage[FlowShape[In, Out]] { +private[akka] final case class Batch[In, Out](max: Long, costFn: In ⇒ Long, seed: In ⇒ Out, aggregate: (Out, In) ⇒ Out) + extends GraphStage[FlowShape[In, Out]] { + + val in = Inlet[In]("Batch.in") + val out = Outlet[Out]("Batch.out") override val shape: FlowShape[In, Out] = FlowShape.of(in, out) override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { - private var agg: Any = null + private var agg: Out = null.asInstanceOf[Out] private var left: Long = max - private var pending: Any = null + private var pending: In = null.asInstanceOf[In] private def flush(): Unit = { - push(out, agg.asInstanceOf[Out]) + push(out, agg) left = max if (pending != null) { - val elem = pending.asInstanceOf[In] + val elem = pending agg = seed(elem) left -= costFn(elem) - pending = null + pending = null.asInstanceOf[In] } else { - agg = null + agg = null.asInstanceOf[Out] } } @@ -516,7 +518,7 @@ private[akka] sealed abstract class AbstractBatch[In, Out](max: Long, costFn: In pending = elem } else { left -= cost - agg = aggregate(agg.asInstanceOf[Out], elem) + agg = aggregate(agg, elem) } if (isAvailable(out)) flush() @@ -531,20 +533,16 @@ private[akka] sealed abstract class AbstractBatch[In, Out](max: Long, costFn: In setHandler(out, new OutHandler { override def onPull(): Unit = { - //if upstream finished, we still might emit up to 2 more elements (whatever agg is + possibly a pending heavy element) - if (isClosed(in)) { - if (agg == null) completeStage() + if (agg == null) { + if (isClosed(in)) completeStage() + else if (!hasBeenPulled(in)) pull(in) + } else if (isClosed(in)) { + push(out, agg) + if (pending == null) completeStage() else { - push(out, agg.asInstanceOf[Out]) - if (pending == null) completeStage() - else { - agg = seed(pending.asInstanceOf[In]) - pending = null - } + agg = seed(pending) + pending = null.asInstanceOf[In] } - } else if (agg == null) { - if (!hasBeenPulled(in)) - pull(in) } else { flush() if (!hasBeenPulled(in)) pull(in) @@ -554,101 +552,6 @@ private[akka] sealed abstract class AbstractBatch[In, Out](max: Long, costFn: In } } -private[akka] final case class BatchWeighted[I, O](max: Long, costFn: I ⇒ Long, seed: I ⇒ O, aggregate: (O, I) ⇒ O) - extends AbstractBatch(max, costFn, seed, aggregate, Inlet[I]("BatchWeighted.in"), Outlet[O]("BatchWeighted.out")) { - override def initialAttributes = Attributes.name("BatchWeighted") -} - -private[akka] final case class Batch[I, O](max: Long, seed: I ⇒ O, aggregate: (O, I) ⇒ O) - extends AbstractBatch(max, { _: I ⇒ 1L }, seed, aggregate, Inlet[I]("Batch.in"), Outlet[O]("Batch.out")) { - override def initialAttributes = Attributes.name("Batch") -} - -//private[akka] final case class Conflate[I, O](seed: I ⇒ O, aggregate: (O, I) ⇒ O) -// extends AbstractBatch(Long.MaxValue, { _: I ⇒ 0L }, seed, aggregate, Inlet[I]("Conflate.in"), Outlet[O]("Conflate.out")) { -// override def initialAttributes = Attributes.name("Conflate") -//} - -/** - * INTERNAL API - */ -private[akka] final case class AggregateWeighted[In, Out](max: Long, costFn: In ⇒ Long, seed: In ⇒ Out, - aggregate: (Out, In) ⇒ Out, - decider: Supervision.Decider) extends DetachedStage[In, Out] { - private var agg: Any = null - private var left: Long = max - private var pending: Any = null - - private[this] def flush(ctx: DetachedContext[Out]) = { - val result = agg.asInstanceOf[Out] - agg = null - left = max - if (pending != null) { - val elem = pending.asInstanceOf[In] - agg = seed(elem) - left -= costFn(elem) - pending = null - } - ctx.pushAndPull(result) - } - - override def onPush(elem: In, ctx: DetachedContext[Out]): UpstreamDirective = { - val cost = costFn(elem) - if (agg == null) { - left -= cost - agg = seed(elem) - } else if (left <= 0 || left - cost < 0) { - pending = elem - } else { - left -= cost - agg = aggregate(agg.asInstanceOf[Out], elem) - } - - if (!ctx.isHoldingDownstream && pending == null) ctx.pull() - else if (!ctx.isHoldingDownstream) ctx.holdUpstream() - else flush(ctx) - } - - override def onPull(ctx: DetachedContext[Out]): DownstreamDirective = { - //if ctx.isFinishing, we still might emit up to 2 more elements (whatever agg is + possibly a pending heavy element) - if (ctx.isFinishing) { - //agg != null since we already checked it in onUpstreamFinish - val result = agg.asInstanceOf[Out] - if (pending == null) ctx.pushAndFinish(result) - else { - val elem = pending.asInstanceOf[In] - agg = seed(elem) - pending = null - ctx.push(result) - } - } else if (ctx.isHoldingBoth) flush(ctx) - else if (agg == null) ctx.holdDownstream() - else { - val result = agg.asInstanceOf[Out] - left = max - if (pending != null) { - val elem = pending.asInstanceOf[In] - agg = seed(elem) - left -= costFn(elem) - pending = null - } else { - agg = null - } - if (ctx.isHoldingUpstream) ctx.pushAndPull(result) - else ctx.push(result) - } - } - - override def onUpstreamFinish(ctx: DetachedContext[Out]): TerminationDirective = { - if (agg == null) ctx.finish() - else ctx.absorbTermination() - } - - override def decide(t: Throwable): Supervision.Directive = decider(t) - - override def restart(): AggregateWeighted[In, Out] = copy() -} - /** * INTERNAL API */ diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala b/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala index 0f5deb8e14..136511f26b 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala @@ -834,7 +834,7 @@ final class Flow[-In, +Out, +Mat](delegate: scaladsl.Flow[In, Out, Mat]) extends * @param seed Provides the first state for a batched value using the first unconsumed element as a start * @param aggregate Takes the currently batched value and the current pending element to produce a new aggregate */ - def batch[S](max: Long, seed: function.Function[Out, S])(aggregate: function.Function2[S, Out, S]): javadsl.Flow[In, S, Mat] = + def batch[S](max: Long, seed: function.Function[Out, S], aggregate: function.Function2[S, Out, S]): javadsl.Flow[In, S, Mat] = new Flow(delegate.batch(max, seed.apply)(aggregate.apply)) /** @@ -865,7 +865,7 @@ final class Flow[-In, +Out, +Mat](delegate: scaladsl.Flow[In, Out, Mat]) extends * @param seed Provides the first state for a batched value using the first unconsumed element as a start * @param aggregate Takes the currently batched value and the current pending element to produce a new batch */ - def batchWeighted[S](max: Long, costFn: function.Function[Out, Long], seed: function.Function[Out, S])(aggregate: function.Function2[S, Out, S]): javadsl.Flow[In, S, Mat] = + def batchWeighted[S](max: Long, costFn: function.Function[Out, Long], seed: function.Function[Out, S], aggregate: function.Function2[S, Out, S]): javadsl.Flow[In, S, Mat] = new Flow(delegate.batchWeighted(max, costFn.apply, seed.apply)(aggregate.apply)) /** diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala b/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala index e0ada6f680..2aa378591e 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala @@ -1263,7 +1263,7 @@ final class Source[+Out, +Mat](delegate: scaladsl.Source[Out, Mat]) extends Grap * @param seed Provides the first state for a batched value using the first unconsumed element as a start * @param aggregate Takes the currently batched value and the current pending element to produce a new aggregate */ - def batch[S](max: Long, seed: function.Function[Out, S])(aggregate: function.Function2[S, Out, S]): javadsl.Source[S, Mat] = + def batch[S](max: Long, seed: function.Function[Out, S],aggregate: function.Function2[S, Out, S]): javadsl.Source[S, Mat] = new Source(delegate.batch(max, seed.apply)(aggregate.apply)) /** @@ -1294,7 +1294,7 @@ final class Source[+Out, +Mat](delegate: scaladsl.Source[Out, Mat]) extends Grap * @param seed Provides the first state for a batched value using the first unconsumed element as a start * @param aggregate Takes the currently batched value and the current pending element to produce a new batch */ - def batchWeighted[S](max: Long, costFn: function.Function[Out, Long], seed: function.Function[Out, S])(aggregate: function.Function2[S, Out, S]): javadsl.Source[S, Mat] = + def batchWeighted[S](max: Long, costFn: function.Function[Out, Long], seed: function.Function[Out, S],aggregate: function.Function2[S, Out, S]): javadsl.Source[S, Mat] = new Source(delegate.batchWeighted(max, costFn.apply, seed.apply)(aggregate.apply)) /** diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/SubFlow.scala b/akka-stream/src/main/scala/akka/stream/javadsl/SubFlow.scala index c58410c6a8..7198d01dc0 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/SubFlow.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/SubFlow.scala @@ -676,7 +676,7 @@ class SubFlow[-In, +Out, +Mat](delegate: scaladsl.SubFlow[Out, Mat, scaladsl.Flo * @param seed Provides the first state for a batched value using the first unconsumed element as a start * @param aggregate Takes the currently batched value and the current pending element to produce a new aggregate */ - def batch[S](max: Long, seed: function.Function[Out, S])(aggregate: function.Function2[S, Out, S]): SubFlow[In, S, Mat] = + def batch[S](max: Long, seed: function.Function[Out, S], aggregate: function.Function2[S, Out, S]): SubFlow[In, S, Mat] = new SubFlow(delegate.batch(max, seed.apply)(aggregate.apply)) /** @@ -707,7 +707,7 @@ class SubFlow[-In, +Out, +Mat](delegate: scaladsl.SubFlow[Out, Mat, scaladsl.Flo * @param seed Provides the first state for a batched value using the first unconsumed element as a start * @param aggregate Takes the currently batched value and the current pending element to produce a new batch */ - def batchWeighted[S](max: Long, costFn: function.Function[Out, Long], seed: function.Function[Out, S])(aggregate: function.Function2[S, Out, S]): SubFlow[In, S, Mat] = + def batchWeighted[S](max: Long, costFn: function.Function[Out, Long], seed: function.Function[Out, S], aggregate: function.Function2[S, Out, S]): SubFlow[In, S, Mat] = new SubFlow(delegate.batchWeighted(max, costFn.apply, seed.apply)(aggregate.apply)) /** diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/SubSource.scala b/akka-stream/src/main/scala/akka/stream/javadsl/SubSource.scala index 970408133c..08bec7646a 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/SubSource.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/SubSource.scala @@ -672,7 +672,7 @@ class SubSource[+Out, +Mat](delegate: scaladsl.SubFlow[Out, Mat, scaladsl.Source * @param seed Provides the first state for a batched value using the first unconsumed element as a start * @param aggregate Takes the currently batched value and the current pending element to produce a new aggregate */ - def batch[S](max: Long, seed: function.Function[Out, S])(aggregate: function.Function2[S, Out, S]): SubSource[S, Mat] = + def batch[S](max: Long, seed: function.Function[Out, S], aggregate: function.Function2[S, Out, S]): SubSource[S, Mat] = new SubSource(delegate.batch(max, seed.apply)(aggregate.apply)) /** @@ -703,7 +703,7 @@ class SubSource[+Out, +Mat](delegate: scaladsl.SubFlow[Out, Mat, scaladsl.Source * @param seed Provides the first state for a batched value using the first unconsumed element as a start * @param aggregate Takes the currently batched value and the current pending element to produce a new batch */ - def batchWeighted[S](max: Long, costFn: function.Function[Out, Long], seed: function.Function[Out, S])(aggregate: function.Function2[S, Out, S]): SubSource[S, Mat] = + def batchWeighted[S](max: Long, costFn: function.Function[Out, Long], seed: function.Function[Out, S], aggregate: function.Function2[S, Out, S]): SubSource[S, Mat] = new SubSource(delegate.batchWeighted(max, costFn.apply, seed.apply)(aggregate.apply)) /** diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala index 9c8e93ece1..f3bbf6cdc6 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala @@ -906,7 +906,8 @@ trait FlowOps[+Out, +Mat] { * * See also [[FlowOps.limit]], [[FlowOps.limitWeighted]] [[FlowOps.batch]] [[FlowOps.batchWeighted]] */ - def conflate[S](seed: Out ⇒ S)(aggregate: (S, Out) ⇒ S): Repr[S] = andThen(Conflate(seed, aggregate)) + def conflate[S](seed: Out ⇒ S)(aggregate: (S, Out) ⇒ S): Repr[S] = //andThen(Conflate(seed, aggregate)) + via(Batch(1L, ConstantFun.returnZero[Out], seed, aggregate).withAttributes(DefaultAttributes.conflate)) /** * Allows a faster upstream to progress independently of a slower subscriber by aggregating elements into batches @@ -931,7 +932,7 @@ trait FlowOps[+Out, +Mat] { * @param aggregate Takes the currently batched value and the current pending element to produce a new aggregate */ def batch[S](max: Long, seed: Out ⇒ S)(aggregate: (S, Out) ⇒ S): Repr[S] = - via(Batch(max, seed, aggregate)) + via(Batch(max, ConstantFun.returnOne[Out], seed, aggregate).withAttributes(DefaultAttributes.batch)) /** * Allows a faster upstream to progress independently of a slower subscriber by aggregating elements into batches @@ -962,7 +963,7 @@ trait FlowOps[+Out, +Mat] { * @param aggregate Takes the currently batched value and the current pending element to produce a new batch */ def batchWeighted[S](max: Long, costFn: Out ⇒ Long, seed: Out ⇒ S)(aggregate: (S, Out) ⇒ S): Repr[S] = - via(BatchWeighted(max, costFn, seed, aggregate)) + via(Batch(max, costFn, seed, aggregate).withAttributes(DefaultAttributes.batchWeighted)) /** * Allows a faster downstream to progress independently of a slower publisher by extrapolating elements from an older