Introduces eager cancellation for divertTo

* and updates Partition to support this behavior
* not adding Partition.apply overload due to type inference issues, use constructor instead
This commit is contained in:
Viktor Klang (√) 2018-01-30 19:59:53 +01:00 committed by Patrik Nordwall
parent 5b2a4edd2c
commit 08b0d34a4c
9 changed files with 105 additions and 34 deletions

View file

@ -1085,6 +1085,8 @@ Each upstream element will either be diverted to the given sink, or the downstre
**completes** when upstream completes and no output is pending
**cancels** when any of the downstreams cancel
---------------------------------------------------------------
<br/>

View file

@ -89,13 +89,13 @@ class GraphPartitionSpec extends StreamSpec {
c2.expectComplete()
}
"cancel upstream when downstreams cancel" in assertAllStagesStopped {
"cancel upstream when all downstreams cancel if eagerCancel is false" in assertAllStagesStopped {
val p1 = TestPublisher.probe[Int]()
val c1 = TestSubscriber.probe[Int]()
val c2 = TestSubscriber.probe[Int]()
RunnableGraph.fromGraph(GraphDSL.create() { implicit b
val partition = b.add(Partition[Int](2, { case l if l < 6 0; case _ 1 }))
val partition = b.add(new Partition[Int](2, { case l if l < 6 0; case _ 1 }, false))
Source.fromPublisher(p1.getPublisher) ~> partition.in
partition.out(0) ~> Flow[Int].buffer(16, OverflowStrategy.backpressure) ~> Sink.fromSubscriber(c1)
partition.out(1) ~> Flow[Int].buffer(16, OverflowStrategy.backpressure) ~> Sink.fromSubscriber(c2)
@ -114,10 +114,38 @@ class GraphPartitionSpec extends StreamSpec {
p1Sub.sendNext(2)
c1.expectNext(2)
sub1.cancel()
p1Sub.sendNext(9)
c2.expectNext(9)
sub2.cancel()
p1Sub.expectCancellation()
}
"cancel upstream when any downstream cancel if eagerCancel is true" in assertAllStagesStopped {
val p1 = TestPublisher.probe[Int]()
val c1 = TestSubscriber.probe[Int]()
val c2 = TestSubscriber.probe[Int]()
RunnableGraph.fromGraph(GraphDSL.create() { implicit b
val partition = b.add(new Partition[Int](2, { case l if l < 6 0; case _ 1 }, true))
Source.fromPublisher(p1.getPublisher) ~> partition.in
partition.out(0) ~> Flow[Int].buffer(16, OverflowStrategy.backpressure) ~> Sink.fromSubscriber(c1)
partition.out(1) ~> Flow[Int].buffer(16, OverflowStrategy.backpressure) ~> Sink.fromSubscriber(c2)
ClosedShape
}).run()
val p1Sub = p1.expectSubscription()
val sub1 = c1.expectSubscription()
val sub2 = c2.expectSubscription()
sub1.request(3)
sub2.request(3)
p1Sub.sendNext(1)
p1Sub.sendNext(8)
c1.expectNext(1)
c2.expectNext(8)
sub1.cancel()
p1Sub.expectCancellation()
}
"work with merge" in assertAllStagesStopped {
val s = Sink.seq[Int]
val input = Set(5, 2, 9, 1, 1, 1, 10)
@ -190,4 +218,16 @@ class GraphPartitionSpec extends StreamSpec {
odd.expectComplete()
even.expectComplete()
}
"divertTo must cancel when any of the downstreams cancel" in assertAllStagesStopped {
val pub = TestPublisher.probe[Int]()
val odd = TestSubscriber.probe[Int]()
val even = TestSubscriber.probe[Int]()
Source.fromPublisher(pub.getPublisher).divertTo(Sink.fromSubscriber(odd), _ % 2 != 0).to(Sink.fromSubscriber(even)).run()
even.request(1)
pub.sendNext(2)
even.expectNext(2)
odd.cancel()
pub.expectCancellation()
}
}

View file

@ -1705,7 +1705,7 @@ final class Flow[-In, +Out, +Mat](delegate: scaladsl.Flow[In, Out, Mat]) extends
*
* '''Completes when''' upstream completes and no output is pending
*
* '''Cancels when''' when all downstreams cancel
* '''Cancels when''' any of the downstreams cancel
*/
def divertTo(that: Graph[SinkShape[Out], _], when: function.Predicate[Out]): javadsl.Flow[In, Out, Mat] =
new Flow(delegate.divertTo(that, when.test))

View file

@ -197,26 +197,48 @@ object Broadcast {
* '''Completes when''' upstream completes
*
* '''Cancels when'''
* when one of the downstreams cancel
* when any (eagerCancel=true) or all (eagerCancel=false) of the downstreams cancel
*/
object Partition {
/**
* Create a new `Partition` stage with the specified input type.
* Create a new `Partition` stage with the specified input type, `eagerCancel` is `false`.
*
* @param outputCount number of output ports
* @param partitioner function deciding which output each element will be targeted
*/
def create[T](outputCount: Int, partitioner: function.Function[T, Integer]): Graph[UniformFanOutShape[T, T], NotUsed] =
scaladsl.Partition(outputCount, partitioner = (t: T) partitioner.apply(t))
new scaladsl.Partition(outputCount, partitioner.apply)
/**
* Create a new `Partition` stage with the specified input type.
*
* @param outputCount number of output ports
* @param partitioner function deciding which output each element will be targeted
* @param eagerCancel this stage cancels, when any (true) or all (false) of the downstreams cancel
*/
def create[T](outputCount: Int, partitioner: function.Function[T, Integer], eagerCancel: Boolean): Graph[UniformFanOutShape[T, T], NotUsed] =
new scaladsl.Partition(outputCount, partitioner.apply, eagerCancel)
/**
* Create a new `Partition` stage with the specified input type, `eagerCancel` is `false`.
*
* @param clazz a type hint for this method
* @param outputCount number of output ports
* @param partitioner function deciding which output each element will be targeted
*/
def create[T](clazz: Class[T], outputCount: Int, partitioner: function.Function[T, Integer]): Graph[UniformFanOutShape[T, T], NotUsed] =
create(outputCount, partitioner)
new scaladsl.Partition(outputCount, partitioner.apply)
/**
* Create a new `Partition` stage with the specified input type.
*
* @param clazz a type hint for this method
* @param outputCount number of output ports
* @param partitioner function deciding which output each element will be targeted
* @param eagerCancel this stage cancels, when any (true) or all (false) of the downstreams cancel
*/
def create[T](clazz: Class[T], outputCount: Int, partitioner: function.Function[T, Integer], eagerCancel: Boolean): Graph[UniformFanOutShape[T, T], NotUsed] =
new scaladsl.Partition(outputCount, partitioner.apply, eagerCancel)
}

View file

@ -761,7 +761,7 @@ final class Source[+Out, +Mat](delegate: scaladsl.Source[Out, Mat]) extends Grap
*
* '''Completes when''' upstream completes and no output is pending
*
* '''Cancels when''' when all downstreams cancel
* '''Cancels when''' any of the downstreams cancel
*/
def divertTo(that: Graph[SinkShape[Out], _], when: function.Predicate[Out]): javadsl.Source[Out, Mat] =
new Source(delegate.divertTo(that, when.test))

View file

@ -1160,7 +1160,7 @@ class SubFlow[-In, +Out, +Mat](delegate: scaladsl.SubFlow[Out, Mat, scaladsl.Flo
*
* '''Completes when''' upstream completes and no output is pending
*
* '''Cancels when''' when all downstreams cancel
* '''Cancels when''' any of the downstreams cancel
*/
def divertTo(that: Graph[SinkShape[Out], _], when: function.Predicate[Out]): SubFlow[In, Out, Mat] =
new SubFlow(delegate.divertTo(that, when.test))

View file

@ -1152,7 +1152,7 @@ class SubSource[+Out, +Mat](delegate: scaladsl.SubFlow[Out, Mat, scaladsl.Source
*
* '''Completes when''' upstream completes and no output is pending
*
* '''Cancels when''' when all downstreams cancel
* '''Cancels when''' any of the downstreams cancel
*/
def divertTo(that: Graph[SinkShape[Out], _], when: function.Predicate[Out]): SubSource[Out, Mat] =
new SubSource(delegate.divertTo(that, when.test))

View file

@ -2308,14 +2308,14 @@ trait FlowOps[+Out, +Mat] {
*
* '''Completes when''' upstream completes and no output is pending
*
* '''Cancels when''' when all downstreams cancel
* '''Cancels when''' any of the downstreams cancel
*/
def divertTo(that: Graph[SinkShape[Out], _], when: Out Boolean): Repr[Out] = via(divertToGraph(that, when))
protected def divertToGraph[M](that: Graph[SinkShape[Out], M], when: Out Boolean): Graph[FlowShape[Out @uncheckedVariance, Out], M] =
GraphDSL.create(that) { implicit b r
import GraphDSL.Implicits._
val partition = b.add(Partition[Out](2, out if (when(out)) 1 else 0))
val partition = b.add(new Partition[Out](2, out if (when(out)) 1 else 0, true))
partition.out(1) ~> r
FlowShape(partition.in, partition.out(0))
}
@ -2335,7 +2335,6 @@ trait FlowOps[+Out, +Mat] {
* asynchronously.
*/
def async: Repr[Out]
}
/**

View file

@ -628,16 +628,19 @@ final class Broadcast[T](val outputPorts: Int, val eagerCancel: Boolean) extends
}
object Partition {
// FIXME make `PartitionOutOfBoundsException` a `final` class when possible
case class PartitionOutOfBoundsException(msg: String) extends IndexOutOfBoundsException(msg) with NoStackTrace
/**
* Create a new `Partition` stage with the specified input type.
* Create a new `Partition` stage with the specified input type. This method sets `eagerCancel` to `false`.
* To specify a different value for the `eagerCancel` parameter, then instantiate Partition using the constructor.
*
* If `eagerCancel` is true, partition cancels upstream if any of its downstreams cancel, if false, when all have cancelled.
*
* @param outputPorts number of output ports
* @param partitioner function deciding which output each element will be targeted
*/
def apply[T](outputPorts: Int, partitioner: T Int): Partition[T] = new Partition(outputPorts, partitioner)
*/ // FIXME BC add `eagerCancel: Boolean = false` parameter
def apply[T](outputPorts: Int, partitioner: T Int): Partition[T] = new Partition(outputPorts, partitioner, false)
}
/**
@ -650,14 +653,19 @@ object Partition {
*
* '''Completes when''' upstream completes and no output is pending
*
* '''Cancels when'''
* when all downstreams cancel
* '''Cancels when''' all downstreams have cancelled (eagerCancel=false) or one downstream cancels (eagerCancel=true)
*/
final class Partition[T](val outputPorts: Int, val partitioner: T Int) extends GraphStage[UniformFanOutShape[T, T]] {
final class Partition[T](val outputPorts: Int, val partitioner: T Int, val eagerCancel: Boolean) extends GraphStage[UniformFanOutShape[T, T]] {
/**
* Sets `eagerCancel` to `false`.
*/
@deprecated("Use the constructor which also specifies the `eagerCancel` parameter")
def this(outputPorts: Int, partitioner: T Int) = this(outputPorts, partitioner, false)
val in: Inlet[T] = Inlet[T]("Partition.in")
val out: Seq[Outlet[T]] = Seq.tabulate(outputPorts)(i Outlet[T]("Partition.out" + i))
val out: Seq[Outlet[T]] = Seq.tabulate(outputPorts)(i Outlet[T]("Partition.out" + i)) // FIXME BC make this immutable.IndexedSeq as type + Vector as concret impl
override val shape: UniformFanOutShape[T, T] = UniformFanOutShape[T, T](in, out: _*)
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler {
@ -690,11 +698,10 @@ final class Partition[T](val outputPorts: Int, val partitioner: T ⇒ Int) exten
setHandler(in, this)
out.zipWithIndex.foreach {
out.iterator.zipWithIndex.foreach {
case (o, idx)
setHandler(o, new OutHandler {
override def onPull() = {
if (outPendingElem != null) {
val elem = outPendingElem.asInstanceOf[T]
if (idx == outPendingIdx) {
@ -711,24 +718,25 @@ final class Partition[T](val outputPorts: Int, val partitioner: T ⇒ Int) exten
pull(in)
}
override def onDownstreamFinish(): Unit = {
downstreamRunning -= 1
if (downstreamRunning == 0)
completeStage()
else if (outPendingElem != null) {
if (idx == outPendingIdx) {
outPendingElem = null
if (!hasBeenPulled(in))
pull(in)
override def onDownstreamFinish(): Unit =
if (eagerCancel) completeStage()
else {
downstreamRunning -= 1
if (downstreamRunning == 0)
completeStage()
else if (outPendingElem != null) {
if (idx == outPendingIdx) {
outPendingElem = null
if (!hasBeenPulled(in))
pull(in)
}
}
}
}
})
}
}
override def toString = s"Partition($outputPorts)"
}
object Balance {