From c39dd6506edfaa858c7de5f5c561c95e426c8b5c Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Thu, 9 Jan 2020 16:16:17 +0100 Subject: [PATCH] stream: filter out elements without demand This will also mean that completion will not be blocked by elements that will later be filtered out. One particular use case of that would be a kind of partitioning use case, where you put several streams behind a broadcast and each consumer will filter out elements not handled there. In that case, the broadcast can get head-of-line blocked when one of the consumers currently has no demand but also wouldn't have to handle any elements because they would all be filtered out. --- .../stream/impl/fusing/InterpreterSpec.scala | 27 ++++++++-------- .../akka/stream/scaladsl/FlowFilterSpec.scala | 28 ++++++++++++++++ .../scala/akka/stream/impl/fusing/Ops.scala | 32 ++++++++++++++----- 3 files changed, 66 insertions(+), 21 deletions(-) diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala index 5c605f5577..581d0e5d60 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala @@ -84,10 +84,10 @@ class InterpreterSpec extends StreamSpec with GraphInterpreterSpecKit { Doubler(), Filter((x: Int) => x != 0)) { - lastEvents() should be(Set.empty) + lastEvents() should be(Set(RequestOne)) downstream.requestOne() - lastEvents() should be(Set(RequestOne)) + lastEvents() should be(Set.empty) upstream.onNext(0) lastEvents() should be(Set(RequestOne)) @@ -96,10 +96,10 @@ class InterpreterSpec extends StreamSpec with GraphInterpreterSpecKit { lastEvents() should be(Set(OnNext(1))) downstream.requestOne() - lastEvents() should be(Set(OnNext(1))) + lastEvents() should be(Set(OnNext(1), RequestOne)) downstream.requestOne() - lastEvents() should be(Set(RequestOne)) + lastEvents() should be(Set.empty) upstream.onComplete() lastEvents() should be(Set(OnComplete)) @@ -109,22 +109,22 @@ class InterpreterSpec extends StreamSpec with GraphInterpreterSpecKit { Filter((x: Int) => x != 0), Doubler()) { - lastEvents() should be(Set.empty) + lastEvents() should be(Set(RequestOne)) downstream.requestOne() - lastEvents() should be(Set(RequestOne)) + lastEvents() should be(Set.empty) upstream.onNext(0) lastEvents() should be(Set(RequestOne)) upstream.onNext(1) - lastEvents() should be(Set(OnNext(1))) + lastEvents() should be(Set(OnNext(1), RequestOne)) downstream.requestOne() lastEvents() should be(Set(OnNext(1))) downstream.requestOne() - lastEvents() should be(Set(RequestOne)) + lastEvents() should be(Set.empty) downstream.cancel() lastEvents() should be(Set(Cancel(SubscriptionWithCancelException.NoMoreElementsNeeded))) @@ -152,22 +152,23 @@ class InterpreterSpec extends StreamSpec with GraphInterpreterSpecKit { takeTwo, Map((x: Int) => x + 1)) { - lastEvents() should be(Set.empty) + lastEvents() should be(Set(RequestOne)) downstream.requestOne() - lastEvents() should be(Set(RequestOne)) + lastEvents() should be(Set.empty) upstream.onNext(0) lastEvents() should be(Set(RequestOne)) upstream.onNext(1) - lastEvents() should be(Set(OnNext(2))) + lastEvents() should be(Set(OnNext(2), RequestOne)) downstream.requestOne() - lastEvents() should be(Set(RequestOne)) + lastEvents() should be(Set.empty) upstream.onNext(2) - lastEvents() should be(Set(Cancel(SubscriptionWithCancelException.StageWasCompleted), OnComplete, OnNext(3))) + lastEvents() should be( + Set(RequestOne, Cancel(SubscriptionWithCancelException.StageWasCompleted), OnComplete, OnNext(3))) } "implement fold" in new OneBoundedSetup[Int](Fold(0, (agg: Int, x: Int) => agg + x)) { diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowFilterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowFilterSpec.scala index 1b89633e51..a5d4408499 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowFilterSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowFilterSpec.scala @@ -12,7 +12,9 @@ import akka.stream.Supervision._ import akka.stream.testkit._ import akka.stream.testkit.scaladsl.StreamTestKit._ import akka.stream.testkit.scaladsl.TestSink +import akka.stream.testkit.scaladsl.TestSource +import scala.concurrent.duration._ import scala.util.control.NoStackTrace class FlowFilterSpec extends StreamSpec(""" @@ -60,6 +62,32 @@ class FlowFilterSpec extends StreamSpec(""" .expectComplete() } + "filter out elements without demand" in assertAllStagesStopped { + val (inProbe, outProbe) = + TestSource + .probe[Int] + .filter(_ > 1000) + .toMat(TestSink.probe[Int])(Keep.both) + .addAttributes(Attributes.inputBuffer(1, 1)) + .run() + + outProbe.ensureSubscription() + // none of those should fail even without demand + inProbe.sendNext(1).sendNext(2).sendNext(3).sendNext(4).sendNext(5).sendNext(1001).sendNext(1002) + + // now the buffer should be full (1 internal buffer, 1 async buffer at source probe) + + inProbe.expectNoMessage(100.millis).pending shouldBe 0L + + inProbe.sendComplete() // to test later that completion is buffered as well + + outProbe.requestNext(1001) + outProbe.requestNext(1002) + outProbe.expectComplete() + } + "complete without demand if remaining elements are filtered out" in assertAllStagesStopped { + Source(1 to 1000).filter(_ > 1000).runWith(TestSink.probe[Int]).ensureSubscription().expectComplete() + } } "A FilterNot" must { diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala index 27b943c97e..355ccfa37a 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala @@ -79,14 +79,19 @@ import com.github.ghik.silencer.silent new GraphStageLogic(shape) with OutHandler with InHandler { def decider = inheritedAttributes.mandatoryAttribute[SupervisionStrategy].decider - override def onPush(): Unit = { + private var buffer: OptionVal[T] = OptionVal.none + + override def preStart(): Unit = pull(in) + override def onPush(): Unit = try { val elem = grab(in) - if (p(elem)) { - push(out, elem) - } else { - pull(in) - } + if (p(elem)) + if (isAvailable(out)) { + push(out, elem) + pull(in) + } else + buffer = OptionVal.Some(elem) + else pull(in) } catch { case NonFatal(ex) => decider(ex) match { @@ -94,9 +99,20 @@ import com.github.ghik.silencer.silent case _ => pull(in) } } - } - override def onPull(): Unit = pull(in) + override def onPull(): Unit = + buffer match { + case OptionVal.Some(value) => + push(out, value) + buffer = OptionVal.none + if (!isClosed(in)) pull(in) + else completeStage() + case _ => // already pulled + } + + override def onUpstreamFinish(): Unit = + if (buffer.isEmpty) super.onUpstreamFinish() + // else onPull will complete setHandlers(in, out, this) }