From 4bf10b77dc2bfb7e59bfcee68e33862e0f882d03 Mon Sep 17 00:00:00 2001 From: Yoshitaka Fujii Date: Mon, 15 Apr 2019 13:11:43 +0200 Subject: [PATCH] [#24220] Supports Supervision in Partition stage extract pullIfAnyOutIsAvailable --- .../stream/scaladsl/GraphPartitionSpec.scala | 75 +++++++++++++++++++ .../scala/akka/stream/javadsl/Graph.scala | 2 + .../scala/akka/stream/scaladsl/Graph.scala | 52 +++++++++---- 3 files changed, 115 insertions(+), 14 deletions(-) diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphPartitionSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphPartitionSpec.scala index 54b28ceb19..0c1147e422 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphPartitionSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphPartitionSpec.scala @@ -10,6 +10,10 @@ import akka.stream.testkit.scaladsl.StreamTestKit._ import scala.concurrent.Await import scala.concurrent.duration._ +import akka.stream.ActorAttributes +import akka.stream.Supervision +import akka.stream.testkit.Utils.TE + class GraphPartitionSpec extends StreamSpec { val settings = ActorMaterializerSettings(system).withInputBuffer(initialSize = 2, maxSize = 16) @@ -220,6 +224,77 @@ class GraphPartitionSpec extends StreamSpec { "partitioner must return an index in the range [0,1]. returned: [-1] for input [java.lang.Integer].")) } + "partition to three subscribers, with Resume supervision" in assertAllStagesStopped { + + val (s1, s2, s3) = RunnableGraph + .fromGraph(GraphDSL.create(Sink.seq[Int], Sink.seq[Int], Sink.seq[Int])(Tuple3.apply) { + implicit b => (sink1, sink2, sink3) => + val partition = b.add(Partition[Int](3, { + case g if g > 3 => 0 + case l if l < 3 => 1 + case e if e == 3 => throw TE("Resume") + })) + Source(List(1, 2, 3, 4, 5)) ~> partition.in + partition.out(0) ~> sink1.in + partition.out(1) ~> sink2.in + partition.out(2) ~> sink3.in + ClosedShape + }) + .withAttributes(ActorAttributes.supervisionStrategy(_ => Supervision.Resume)) + .run() + + s1.futureValue.toSet should ===(Set(4, 5)) + s2.futureValue.toSet should ===(Set(1, 2)) + s3.futureValue.toSet should ===(Set()) + } + + "partition to three subscribers, with Restart supervision" in assertAllStagesStopped { + val (s1, s2, s3) = RunnableGraph + .fromGraph(GraphDSL.create(Sink.seq[Int], Sink.seq[Int], Sink.seq[Int])(Tuple3.apply) { + implicit b => (sink1, sink2, sink3) => + val partition = b.add(Partition[Int](3, { + case g if g > 3 => 0 + case l if l < 3 => 1 + case e if e == 3 => throw TE("Restart") + })) + Source(List(1, 2, 3, 4, 5)) ~> partition.in + partition.out(0) ~> sink1.in + partition.out(1) ~> sink2.in + partition.out(2) ~> sink3.in + ClosedShape + }) + .withAttributes(ActorAttributes.supervisionStrategy(_ => Supervision.Restart)) + .run() + + s1.futureValue.toSet should ===(Set(4, 5)) + s2.futureValue.toSet should ===(Set(1, 2)) + s3.futureValue.toSet should ===(Set()) + } + + "support supervision for PartitionOutOfBoundsException" in assertAllStagesStopped { + + val (s1, s2, s3) = RunnableGraph + .fromGraph(GraphDSL.create(Sink.seq[Int], Sink.seq[Int], Sink.seq[Int])(Tuple3.apply) { + implicit b => (sink1, sink2, sink3) => + val partition = b.add(Partition[Int](3, { + case g if g > 3 => 0 + case l if l < 3 => 1 + case e if e == 3 => -1 // out of bounds + })) + Source(List(1, 2, 3, 4, 5)) ~> partition.in + partition.out(0) ~> sink1.in + partition.out(1) ~> sink2.in + partition.out(2) ~> sink3.in + ClosedShape + }) + .withAttributes(ActorAttributes.supervisionStrategy(_ => Supervision.Resume)) + .run() + + s1.futureValue.toSet should ===(Set(4, 5)) + s2.futureValue.toSet should ===(Set(1, 2)) + s3.futureValue.toSet should ===(Set()) + } + } "divertTo must send matching elements to the sink" in assertAllStagesStopped { diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/Graph.scala b/akka-stream/src/main/scala/akka/stream/javadsl/Graph.scala index c65126bdae..a8083f41e3 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/Graph.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/Graph.scala @@ -215,6 +215,8 @@ object Broadcast { * Fan-out the stream to several streams. emitting an incoming upstream element to one downstream consumer according * to the partitioner function applied to the element * + * Adheres to the [[ActorAttributes.SupervisionStrategy]] attribute. + * * '''Emits when''' all of the outputs stops backpressuring and there is an input element available * * '''Backpressures when''' one of the outputs backpressure diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala index afd9d1b3f2..e3f712bf9b 100755 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala @@ -15,13 +15,14 @@ import akka.stream.impl.fusing.GraphStages import akka.stream.scaladsl.Partition.PartitionOutOfBoundsException import akka.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler } import akka.util.ConstantFun - import scala.annotation.tailrec import scala.annotation.unchecked.uncheckedVariance import scala.collection.{ immutable, mutable } import scala.concurrent.Promise import scala.util.control.{ NoStackTrace, NonFatal } +import akka.stream.ActorAttributes.SupervisionStrategy + /** * INTERNAL API * @@ -770,6 +771,8 @@ object Partition { * Fan-out the stream to several streams. emitting an incoming upstream element to one downstream consumer according * to the partitioner function applied to the element * + * Adheres to the [[ActorAttributes.SupervisionStrategy]] attribute. + * * '''Emits when''' emits when an element is available from the input and the chosen output has demand * * '''Backpressures when''' the currently chosen output back-pressures @@ -793,27 +796,48 @@ final class Partition[T](val outputPorts: Int, val partitioner: T => Int, val ea override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler { + lazy val decider = inheritedAttributes.mandatoryAttribute[SupervisionStrategy].decider private var outPendingElem: Any = null private var outPendingIdx: Int = _ private var downstreamRunning = outputPorts def onPush() = { val elem = grab(in) - val idx = partitioner(elem) - if (idx < 0 || idx >= outputPorts) { - failStage(PartitionOutOfBoundsException( - s"partitioner must return an index in the range [0,${outputPorts - 1}]. returned: [$idx] for input [${elem.getClass.getName}].")) - } else if (!isClosed(out(idx))) { - if (isAvailable(out(idx))) { - push(out(idx), elem) - if (out.exists(isAvailable(_))) - pull(in) - } else { - outPendingElem = elem - outPendingIdx = idx + val idx = + try { + val i = partitioner(elem) + if (i < 0 || i >= outputPorts) + throw PartitionOutOfBoundsException( + s"partitioner must return an index in the range [0,${outputPorts - 1}]. returned: [$i] for input " + + s"[${elem.getClass.getName}].") + else i + } catch { + case NonFatal(ex) => + decider(ex) match { + case Supervision.Stop => failStage(ex) + case Supervision.Restart => pull(in) + case Supervision.Resume => pull(in) + } + Int.MinValue } - } else if (out.exists(isAvailable(_))) + if (idx != Int.MinValue) { + if (!isClosed(out(idx))) { + if (isAvailable(out(idx))) { + push(out(idx), elem) + pullIfAnyOutIsAvailable() + } else { + outPendingElem = elem + outPendingIdx = idx + } + + } else + pullIfAnyOutIsAvailable() + } + } + + private def pullIfAnyOutIsAvailable(): Unit = { + if (out.exists(isAvailable(_))) pull(in) }