diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SinkForeachParallelSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SinkForeachParallelSpec.scala
new file mode 100644
index 0000000000..66fbb7dcfd
--- /dev/null
+++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SinkForeachParallelSpec.scala
@@ -0,0 +1,127 @@
+/**
+ * Copyright (C) 2015 Typesafe Inc.
+ */
+package akka.stream.scaladsl
+
+import akka.stream.ActorFlowMaterializer
+import akka.stream.ActorOperationAttributes._
+import akka.stream.Supervision._
+import akka.stream.testkit.Utils._
+import akka.stream.testkit.AkkaSpec
+import akka.testkit.{ TestLatch, TestProbe }
+
+import scala.concurrent.Await
+import scala.concurrent.duration._
+import scala.util.control.NoStackTrace
+
+class SinkForeachParallelSpec extends AkkaSpec {
+
+  implicit val mat = ActorFlowMaterializer()
+
+  "A ForeachParallel" must {
+    "produce elements in the order they are ready" in assertAllStagesStopped {
+      implicit val ec = system.dispatcher
+
+      val probe = TestProbe()
+      val latch = (1 to 4).map(_ -> TestLatch(1)).toMap
+      val p = Source(1 to 4).runWith(Sink.foreachParallel(4)((n: Int) ⇒ {
+        Await.ready(latch(n), 5.seconds)
+        probe.ref ! n
+      }))
+      latch(2).countDown()
+      probe.expectMsg(2)
+      latch(4).countDown()
+      probe.expectMsg(4)
+      latch(3).countDown()
+      probe.expectMsg(3)
+
+      assert(!p.isCompleted)
+
+      latch(1).countDown()
+      probe.expectMsg(1)
+
+      Await.result(p, 4.seconds)
+      assert(p.isCompleted)
+    }
+
+    "not run more functions in parallel than specified" in {
+      implicit val ec = system.dispatcher
+
+      val probe = TestProbe()
+      val latch = (1 to 5).map(_ -> TestLatch()).toMap
+
+      val p = Source(1 to 5).runWith(Sink.foreachParallel(4)((n: Int) ⇒ {
+        probe.ref ! n
+        Await.ready(latch(n), 5.seconds)
+      }))
+      probe.expectMsgAllOf(1, 2, 3, 4)
+      probe.expectNoMsg(200.millis)
+
+      assert(!p.isCompleted)
+
+      for (i ← 1 to 4) latch(i).countDown()
+
+      latch(5).countDown()
+      probe.expectMsg(5)
+
+      Await.result(p, 5.seconds)
+      assert(p.isCompleted)
+
+    }
+
+    "resume after function failure" in assertAllStagesStopped {
+      implicit val ec = system.dispatcher
+
+      val probe = TestProbe()
+      val latch = TestLatch(1)
+
+      val p = Source(1 to 5).runWith(Sink.foreachParallel(4)((n: Int) ⇒ {
+        if (n == 3) throw new RuntimeException("err1") with NoStackTrace
+        else {
+          probe.ref ! n
+          Await.ready(latch, 10.seconds)
+        }
+      }).withAttributes(supervisionStrategy(resumingDecider)))
+      p.onFailure { case e ⇒ assert(e.getMessage.equals("err1")); Unit }
+
+      latch.countDown()
+      probe.expectMsgAllOf(1, 2, 4)
+
+      Await.result(p, 5.seconds)
+      assert(p.isCompleted)
+    }
+
+    "finish after function throws exception" in assertAllStagesStopped {
+      val probe = TestProbe()
+      val latch = TestLatch(1)
+
+      implicit val ec = system.dispatcher
+      val p = Source(1 to 5).runWith(Sink.foreachParallel(3)((n: Int) ⇒ {
+        if (n == 3) throw new RuntimeException("err2") with NoStackTrace
+        else {
+          probe.ref ! n
+          Await.ready(latch, 10.seconds)
+        }
+      }).withAttributes(supervisionStrategy(stoppingDecider)))
+      p.onFailure { case e ⇒ assert(e.getMessage.equals("err2")); Unit }
+      p.onSuccess { case _ ⇒ fail() }
+
+      latch.countDown()
+      probe.expectMsgAllOf(1, 2)
+
+      Await.ready(p, 1.seconds)
+
+      assert(p.isCompleted)
+    }
+
+    "handle empty source" in assertAllStagesStopped {
+      implicit val ec = system.dispatcher
+
+      val p = Source(List.empty[Int]).runWith(Sink.foreachParallel(3)(a ⇒ ()))
+
+      Await.result(p, 200.seconds)
+    }
+
+  }
+
+}
diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/Sink.scala b/akka-stream/src/main/scala/akka/stream/javadsl/Sink.scala
index 69e2e66817..e6a556463b 100644
--- a/akka-stream/src/main/scala/akka/stream/javadsl/Sink.scala
+++ b/akka-stream/src/main/scala/akka/stream/javadsl/Sink.scala
@@ -9,7 +9,7 @@ import akka.stream.impl.StreamLayout
 import akka.stream.{ javadsl, scaladsl, _ }
 import org.reactivestreams.{ Publisher, Subscriber }
 
-import scala.concurrent.Future
+import scala.concurrent.{ ExecutionContext, Future }
 import scala.util.Try
 /** Java API */
 object Sink {
@@ -65,6 +65,20 @@
    */
   def foreach[T](f: function.Procedure[T]): Sink[T, Future[Unit]] =
     new Sink(scaladsl.Sink.foreach(f.apply))
+
+  /**
+   * A `Sink` that will invoke the given procedure for each received element in parallel. The sink is materialized
+   * into a [[scala.concurrent.Future]].
+   *
+   * If `f` throws an exception and the supervision decision is
+   * [[akka.stream.Supervision.Stop]] the `Future` will be completed with failure.
+   *
+   * If `f` throws an exception and the supervision decision is
+   * [[akka.stream.Supervision.Resume]] or [[akka.stream.Supervision.Restart]] the
+   * element is dropped and the stream continues.
+   */
+  def foreachParallel[T](parallel: Int)(f: function.Procedure[T])(ec: ExecutionContext): Sink[T, Future[Unit]] =
+    new Sink(scaladsl.Sink.foreachParallel(parallel)(f.apply)(ec))
 
   /**
    * A `Sink` that materializes into a [[org.reactivestreams.Publisher]]
diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Sink.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Sink.scala
index fac3b86d81..399cdfd805 100644
--- a/akka-stream/src/main/scala/akka/stream/scaladsl/Sink.scala
+++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Sink.scala
@@ -6,10 +6,9 @@ package akka.stream.scaladsl
 import akka.stream.javadsl
 import akka.actor.{ ActorRef, Props }
 import akka.stream._
-import akka.stream.impl.Stages.DefaultAttributes
+import akka.stream.impl.Stages.{ MapAsyncUnordered, DefaultAttributes }
 import akka.stream.impl.StreamLayout.Module
 import akka.stream.impl._
-import akka.stream.impl.Stages.DefaultAttributes
 import akka.stream.stage.Context
 import akka.stream.stage.PushStage
 import akka.stream.stage.SyncDirective
@@ -18,7 +17,7 @@ import akka.stream.OperationAttributes._
 import akka.stream.stage.{ TerminationDirective, Directive, Context, PushStage }
 import org.reactivestreams.{ Publisher, Subscriber }
 
-import scala.concurrent.{ Future, Promise }
+import scala.concurrent.{ ExecutionContext, Future, Promise }
 import scala.util.{ Failure, Success, Try }
 
 /**
@@ -140,6 +139,24 @@ object Sink extends SinkApply {
     Flow[T].transformMaterializing(newForeachStage).to(Sink.ignore).named("foreachSink")
   }
+
+  /**
+   * A `Sink` that will invoke the given function for each of the elements
+   * as they pass in. The sink is materialized into a [[scala.concurrent.Future]].
+   *
+   * If `f` throws an exception and the supervision decision is
+   * [[akka.stream.Supervision.Stop]] the `Future` will be completed with failure.
+   *
+   * If `f` throws an exception and the supervision decision is
+   * [[akka.stream.Supervision.Resume]] or [[akka.stream.Supervision.Restart]] the
+   * element is dropped and the stream continues.
+   *
+   * @see [[#mapAsyncUnordered]]
+   */
+  def foreachParallel[T](parallelism: Int)(f: T ⇒ Unit)(implicit ec: ExecutionContext): Sink[T, Future[Unit]] =
+    Flow[T].andThen(
+      MapAsyncUnordered(parallelism,
+        { out: T ⇒ Future(f(out)) }.asInstanceOf[Any ⇒ Future[Unit]])).toMat(Sink.ignore)(Keep.right)
 
   /**
   * A `Sink` that will invoke the given function for every received element, giving it its previous
   * output (or the given `zero` value) and the element as input.
@@ -193,10 +210,12 @@
     def newOnCompleteStage(): PushStage[T, Unit] = {
       new PushStage[T, Unit] {
         override def onPush(elem: T, ctx: Context[Unit]): SyncDirective = ctx.pull()
+
         override def onUpstreamFailure(cause: Throwable, ctx: Context[Unit]): TerminationDirective = {
           callback(Failure(cause))
           ctx.fail(cause)
         }
+
        override def onUpstreamFinish(ctx: Context[Unit]): TerminationDirective = {
           callback(Success[Unit](()))
           ctx.finish()
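
For reference, a minimal usage sketch of the new operator, written against the same 2015-era API that the spec above uses (ActorFlowMaterializer, Source(...).runWith(...)). The object and value names (ForeachParallelExample, system, done) are illustrative only and not part of the patch:

import akka.actor.ActorSystem
import akka.stream.ActorFlowMaterializer
import akka.stream.scaladsl.{ Sink, Source }

import scala.concurrent.Await
import scala.concurrent.duration._

// Illustrative sketch only: exercises Sink.foreachParallel as added by this patch.
object ForeachParallelExample extends App {
  implicit val system = ActorSystem("foreach-parallel-example")
  implicit val mat = ActorFlowMaterializer()
  import system.dispatcher // supplies the implicit ExecutionContext for the parallel invocations

  // Up to 4 invocations of the side-effecting function may run at the same time,
  // so the order of the printed numbers is not guaranteed.
  val done = Source(1 to 100).runWith(Sink.foreachParallel(4)((n: Int) ⇒ println(n)))

  // The materialized Future completes only after every element has been processed.
  Await.result(done, 10.seconds)
  system.shutdown()
}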