=str #19591 updated 'partition to three subscribers' to use Sink.seq and added a new test to cover a case where partitioner return an index out of bounds
This commit is contained in:
parent
3d9ea4415f
commit
a4d96cd4e2
2 changed files with 39 additions and 27 deletions
|
|
@ -7,10 +7,11 @@ import akka.stream.testkit._
|
|||
import akka.stream.testkit.scaladsl.TestSink
|
||||
import akka.stream.{ OverflowStrategy, ActorMaterializer, ActorMaterializerSettings, ClosedShape }
|
||||
import akka.stream.testkit.Utils._
|
||||
import org.scalatest.concurrent.ScalaFutures
|
||||
import scala.concurrent.Await
|
||||
import scala.concurrent.duration._
|
||||
|
||||
class GraphPartitionSpec extends AkkaSpec {
|
||||
class GraphPartitionSpec extends AkkaSpec with ScalaFutures {
|
||||
|
||||
val settings = ActorMaterializerSettings(system)
|
||||
.withInputBuffer(initialSize = 2, maxSize = 16)
|
||||
|
|
@ -21,34 +22,25 @@ class GraphPartitionSpec extends AkkaSpec {
|
|||
import GraphDSL.Implicits._
|
||||
|
||||
"partition to three subscribers" in assertAllStagesStopped {
|
||||
val c1 = TestSubscriber.probe[Int]()
|
||||
val c2 = TestSubscriber.probe[Int]()
|
||||
val c3 = TestSubscriber.probe[Int]()
|
||||
|
||||
RunnableGraph.fromGraph(GraphDSL.create() { implicit b ⇒
|
||||
val partition = b.add(Partition[Int](3, {
|
||||
case g if (g > 3) ⇒ 0
|
||||
case l if (l < 3) ⇒ 1
|
||||
case e if (e == 3) ⇒ 2
|
||||
}))
|
||||
Source(List(1, 2, 3, 4, 5)) ~> partition.in
|
||||
partition.out(0) ~> Sink.fromSubscriber(c1)
|
||||
partition.out(1) ~> Sink.fromSubscriber(c2)
|
||||
partition.out(2) ~> Sink.fromSubscriber(c3)
|
||||
ClosedShape
|
||||
val (s1, s2, s3) = RunnableGraph.fromGraph(GraphDSL.create(Sink.seq[Int], Sink.seq[Int], Sink.seq[Int])(Tuple3.apply) { implicit b ⇒
|
||||
(sink1, sink2, sink3) ⇒
|
||||
val partition = b.add(Partition[Int](3, {
|
||||
case g if (g > 3) ⇒ 0
|
||||
case l if (l < 3) ⇒ 1
|
||||
case e if (e == 3) ⇒ 2
|
||||
}))
|
||||
Source(List(1, 2, 3, 4, 5)) ~> partition.in
|
||||
partition.out(0) ~> sink1.in
|
||||
partition.out(1) ~> sink2.in
|
||||
partition.out(2) ~> sink3.in
|
||||
ClosedShape
|
||||
}).run()
|
||||
|
||||
c2.request(2)
|
||||
c1.request(2)
|
||||
c3.request(1)
|
||||
c2.expectNext(1)
|
||||
c2.expectNext(2)
|
||||
c3.expectNext(3)
|
||||
c1.expectNext(4)
|
||||
c1.expectNext(5)
|
||||
c1.expectComplete()
|
||||
c2.expectComplete()
|
||||
c3.expectComplete()
|
||||
s1.futureValue.toSet should ===(Set(4, 5))
|
||||
s2.futureValue.toSet should ===(Set(1, 2))
|
||||
s3.futureValue.toSet should ===(Set(3))
|
||||
|
||||
}
|
||||
|
||||
"complete stage after upstream completes" in assertAllStagesStopped {
|
||||
|
|
@ -171,5 +163,21 @@ class GraphPartitionSpec extends AkkaSpec {
|
|||
c2.expectComplete()
|
||||
}
|
||||
|
||||
"must fail stage if partitioner outcome is out of bound" in assertAllStagesStopped {
|
||||
|
||||
val c1 = TestSubscriber.probe[Int]()
|
||||
|
||||
RunnableGraph.fromGraph(GraphDSL.create() { implicit b ⇒
|
||||
val partition = b.add(Partition[Int](2, { case l if l < 0 ⇒ -1; case _ ⇒ 0 }))
|
||||
Source(List(-3)) ~> partition.in
|
||||
partition.out(0) ~> Sink.fromSubscriber(c1)
|
||||
partition.out(1) ~> Sink.ignore
|
||||
ClosedShape
|
||||
}).run()
|
||||
|
||||
c1.request(1)
|
||||
c1.expectError(Partition.PartitionOutOfBoundsException("partitioner must return an index in the range [0,1]. returned: [-1] for input [java.lang.Integer]."))
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,10 +10,12 @@ import akka.stream.impl.fusing.GraphStages
|
|||
import akka.stream.impl.fusing.GraphStages.MaterializedValueSource
|
||||
import akka.stream.impl.Stages.{ DefaultAttributes, StageModule, SymbolicStage }
|
||||
import akka.stream.impl.StreamLayout._
|
||||
import akka.stream.scaladsl.Partition.PartitionOutOfBoundsException
|
||||
import akka.stream.stage.{ OutHandler, InHandler, GraphStageLogic, GraphStage }
|
||||
import scala.annotation.unchecked.uncheckedVariance
|
||||
import scala.annotation.tailrec
|
||||
import scala.collection.immutable
|
||||
import scala.util.control.NoStackTrace
|
||||
|
||||
object Merge {
|
||||
/**
|
||||
|
|
@ -469,6 +471,8 @@ final class Broadcast[T](private val outputPorts: Int, eagerCancel: Boolean) ext
|
|||
|
||||
object Partition {
|
||||
|
||||
case class PartitionOutOfBoundsException(msg:String) extends RuntimeException(msg) with NoStackTrace
|
||||
|
||||
/**
|
||||
* Create a new `Partition` stage with the specified input type.
|
||||
*
|
||||
|
|
@ -508,7 +512,7 @@ final class Partition[T](outputPorts: Int, partitioner: T ⇒ Int) extends Graph
|
|||
val elem = grab(in)
|
||||
val idx = partitioner(elem)
|
||||
if (idx < 0 || idx >= outputPorts)
|
||||
failStage(new IndexOutOfBoundsException(s"partitioner must return an index in the range [0,${outputPorts - 1}]. returned: [$idx] for input [$elem]."))
|
||||
failStage(PartitionOutOfBoundsException(s"partitioner must return an index in the range [0,${outputPorts - 1}]. returned: [$idx] for input [${elem.getClass.getName}]."))
|
||||
else if (!isClosed(out(idx))) {
|
||||
if (isAvailable(out(idx))) {
|
||||
push(out(idx), elem)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue