diff --git a/akka-stream-tests/src/test/java/akka/stream/javadsl/SourceTest.java b/akka-stream-tests/src/test/java/akka/stream/javadsl/SourceTest.java index 17aef36bfe..80aa817f82 100644 --- a/akka-stream-tests/src/test/java/akka/stream/javadsl/SourceTest.java +++ b/akka-stream-tests/src/test/java/akka/stream/javadsl/SourceTest.java @@ -484,6 +484,21 @@ public class SourceTest extends StreamTest { assertEquals(result.size(), 10000); for (Integer i: result) assertEquals(i, (Integer) 42); } + + @Test + public void mustBeAbleToUseQueue() throws Exception { + final Pair, CompletionStage>> x = + Flow.of(String.class).runWith( + Source.queue(2, OverflowStrategy.fail()), + Sink.seq(), materializer); + final SourceQueueWithComplete source = x.first(); + final CompletionStage> result = x.second(); + source.offer("hello"); + source.offer("world"); + source.complete(); + assertEquals(result.toCompletableFuture().get(3, TimeUnit.SECONDS), + Arrays.asList("hello", "world")); + } @Test public void mustBeAbleToUseActorRefSource() throws Exception { diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/QueueSourceSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/QueueSourceSpec.scala index 1a4b613160..ce78c46ed6 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/QueueSourceSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/QueueSourceSpec.scala @@ -8,11 +8,16 @@ import akka.pattern.pipe import akka.stream._ import akka.stream.impl.QueueSource import akka.stream.testkit.Utils._ -import akka.stream.testkit._ import akka.testkit.{ AkkaSpec, TestProbe } import scala.concurrent.duration._ import scala.concurrent._ import akka.Done +import org.scalatest.concurrent.ScalaFutures +import akka.testkit.AkkaSpec +import akka.stream.testkit.TestSubscriber +import akka.stream.testkit.TestSourceStage +import akka.stream.testkit.GraphStageMessages +import akka.stream.testkit.scaladsl.TestSink class QueueSourceSpec extends AkkaSpec { implicit val materializer = ActorMaterializer() @@ -20,8 +25,7 @@ class QueueSourceSpec extends AkkaSpec { val pause = 300.millis def assertSuccess(f: Future[QueueOfferResult]): Unit = { - f pipeTo testActor - expectMsg(QueueOfferResult.Enqueued) + f.futureValue should ===(QueueOfferResult.Enqueued) } "A QueueSource" must { @@ -43,6 +47,25 @@ class QueueSourceSpec extends AkkaSpec { expectMsg(Done) } + "be reusable" in { + val source = Source.queue(0, OverflowStrategy.backpressure) + val q1 = source.to(Sink.ignore).run() + q1.complete() + q1.watchCompletion().futureValue should ===(Done) + val q2 = source.to(Sink.ignore).run() + q2.watchCompletion().value should ===(None) + } + + "reject elements when back-pressuring with maxBuffer=0" in { + val (source, probe) = Source.queue[Int](0, OverflowStrategy.backpressure).toMat(TestSink.probe)(Keep.both).run() + val f = source.offer(42) + val ex = source.offer(43).failed.futureValue + ex shouldBe a[IllegalStateException] + ex.getMessage should include("have to wait") + probe.requestNext() should ===(42) + f.futureValue should ===(QueueOfferResult.Enqueued) + } + "buffer when needed" in { val s = TestSubscriber.manualProbe[Int]() val queue = Source.queue(100, OverflowStrategy.dropHead).to(Sink.fromSubscriber(s)).run() @@ -121,24 +144,23 @@ class QueueSourceSpec extends AkkaSpec { } "fail offer future if user does not wait in backpressure mode" in assertAllStagesStopped { - val s = TestSubscriber.manualProbe[Int]() - val queue = Source.queue(5, OverflowStrategy.backpressure).to(Sink.fromSubscriber(s)).run() - val sub = s.expectSubscription + val (queue, probe) = Source.queue[Int](5, OverflowStrategy.backpressure).toMat(TestSink.probe)(Keep.both).run() for (i ← 1 to 5) assertSuccess(queue.offer(i)) queue.offer(6).pipeTo(testActor) - expectNoMsg(pause) - val future = queue.offer(7) - future.onFailure { case e ⇒ e.isInstanceOf[IllegalStateException] should ===(true) } - future.onSuccess { case _ ⇒ fail() } - Await.ready(future, pause) + queue.offer(7).pipeTo(testActor) + expectMsgType[Status.Failure].cause shouldBe an[IllegalStateException] - sub.request(1) - s.expectNext(1) + probe.requestNext(1) expectMsg(QueueOfferResult.Enqueued) - sub.cancel() + queue.complete() + + probe + .request(6) + .expectNext(2, 3, 4, 5, 6) + .expectComplete() } "complete watching future with failure if stream failed" in assertAllStagesStopped { @@ -215,6 +237,112 @@ class QueueSourceSpec extends AkkaSpec { sourceQueue2.watchCompletion().isCompleted should ===(false) } + "complete the stream" when { + + "buffer is empty" in { + val (source, probe) = Source.queue[Int](1, OverflowStrategy.fail).toMat(TestSink.probe)(Keep.both).run() + source.complete() + source.watchCompletion().futureValue should ===(Done) + probe + .ensureSubscription() + .expectComplete() + } + + "buffer is full" in { + val (source, probe) = Source.queue[Int](1, OverflowStrategy.fail).toMat(TestSink.probe)(Keep.both).run() + source.offer(1) + source.complete() + probe + .requestNext(1) + .expectComplete() + source.watchCompletion().futureValue should ===(Done) + } + + "buffer is full and element is pending" in { + val (source, probe) = Source.queue[Int](1, OverflowStrategy.backpressure).toMat(TestSink.probe)(Keep.both).run() + source.offer(1) + source.offer(2) + source.complete() + probe + .requestNext(1) + .requestNext(2) + .expectComplete() + source.watchCompletion().futureValue should ===(Done) + } + + "no buffer is used" in { + val (source, probe) = Source.queue[Int](0, OverflowStrategy.fail).toMat(TestSink.probe)(Keep.both).run() + source.complete() + source.watchCompletion().futureValue should ===(Done) + probe + .ensureSubscription() + .expectComplete() + } + + "no buffer is used and element is pending" in { + val (source, probe) = Source.queue[Int](0, OverflowStrategy.fail).toMat(TestSink.probe)(Keep.both).run() + source.offer(1) + source.complete() + probe + .requestNext(1) + .expectComplete() + source.watchCompletion().futureValue should ===(Done) + } + } + + "fail the stream" when { + val ex = new Exception("BUH") + + "buffer is empty" in { + val (source, probe) = Source.queue[Int](1, OverflowStrategy.fail).toMat(TestSink.probe)(Keep.both).run() + source.fail(ex) + source.watchCompletion().failed.futureValue should ===(ex) + probe + .ensureSubscription() + .expectError(ex) + } + + "buffer is full" in { + val (source, probe) = Source.queue[Int](1, OverflowStrategy.fail).toMat(TestSink.probe)(Keep.both).run() + source.offer(1) + source.fail(ex) + source.watchCompletion().failed.futureValue should ===(ex) + probe + .ensureSubscription() + .expectError(ex) + } + + "buffer is full and element is pending" in { + val (source, probe) = Source.queue[Int](1, OverflowStrategy.backpressure).toMat(TestSink.probe)(Keep.both).run() + source.offer(1) + source.offer(2) + source.fail(ex) + source.watchCompletion().failed.futureValue should ===(ex) + probe + .ensureSubscription() + .expectError(ex) + } + + "no buffer is used" in { + val (source, probe) = Source.queue[Int](0, OverflowStrategy.fail).toMat(TestSink.probe)(Keep.both).run() + source.fail(ex) + source.watchCompletion().failed.futureValue should ===(ex) + probe + .ensureSubscription() + .expectError(ex) + } + + "no buffer is used and element is pending" in { + val (source, probe) = Source.queue[Int](0, OverflowStrategy.fail).toMat(TestSink.probe)(Keep.both).run() + source.offer(1) + source.fail(ex) + source.watchCompletion().failed.futureValue should ===(ex) + probe + .ensureSubscription() + .expectError(ex) + } + } + } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/Sources.scala b/akka-stream/src/main/scala/akka/stream/impl/Sources.scala index f11b8c494c..6be339a6d8 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Sources.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Sources.scala @@ -7,7 +7,7 @@ import akka.stream.OverflowStrategies._ import akka.stream._ import akka.stream.stage._ import scala.concurrent.{ Future, Promise } -import akka.stream.scaladsl.SourceQueue +import akka.stream.scaladsl.SourceQueueWithComplete import akka.Done import java.util.concurrent.CompletionStage import scala.compat.java8.FutureConverters._ @@ -15,124 +15,173 @@ import scala.compat.java8.FutureConverters._ /** * INTERNAL API */ -final private[stream] class QueueSource[T](maxBuffer: Int, overflowStrategy: OverflowStrategy) extends GraphStageWithMaterializedValue[SourceShape[T], SourceQueue[T]] { - type Offered = Promise[QueueOfferResult] +private[stream] object QueueSource { + sealed trait Input[+T] + final case class Offer[+T](elem: T, promise: Promise[QueueOfferResult]) extends Input[T] + case object Completion extends Input[Nothing] + final case class Failure(ex: Throwable) extends Input[Nothing] +} + +/** + * INTERNAL API + */ +final private[stream] class QueueSource[T](maxBuffer: Int, overflowStrategy: OverflowStrategy) extends GraphStageWithMaterializedValue[SourceShape[T], SourceQueueWithComplete[T]] { + import QueueSource._ val out = Outlet[T]("queueSource.out") override val shape: SourceShape[T] = SourceShape.of(out) override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = { val completion = Promise[Done] - val stageLogic = new GraphStageLogic(shape) with CallbackWrapper[(T, Offered)] { + val stageLogic = new GraphStageLogic(shape) with CallbackWrapper[Input[T]] with OutHandler { var buffer: Buffer[T] = _ - var pendingOffer: Option[(T, Offered)] = None - var pulled = false + var pendingOffer: Option[Offer[T]] = None + var terminating = false override def preStart(): Unit = { if (maxBuffer > 0) buffer = Buffer(maxBuffer, materializer) initCallback(callback.invoke) } override def postStop(): Unit = stopCallback { - case (elem, promise) ⇒ promise.failure(new IllegalStateException("Stream is terminated. SourceQueue is detached")) + case Offer(elem, promise) ⇒ promise.failure(new IllegalStateException("Stream is terminated. SourceQueue is detached")) + case _ ⇒ // ignore } - private def enqueueAndSuccess(elem: T, promise: Offered): Unit = { - buffer.enqueue(elem) - promise.success(QueueOfferResult.Enqueued) + private def enqueueAndSuccess(offer: Offer[T]): Unit = { + buffer.enqueue(offer.elem) + offer.promise.success(QueueOfferResult.Enqueued) } - private def bufferElem(elem: T, promise: Offered): Unit = { + private def bufferElem(offer: Offer[T]): Unit = { if (!buffer.isFull) { - enqueueAndSuccess(elem, promise) + enqueueAndSuccess(offer) } else overflowStrategy match { case DropHead ⇒ buffer.dropHead() - enqueueAndSuccess(elem, promise) + enqueueAndSuccess(offer) case DropTail ⇒ buffer.dropTail() - enqueueAndSuccess(elem, promise) + enqueueAndSuccess(offer) case DropBuffer ⇒ buffer.clear() - enqueueAndSuccess(elem, promise) + enqueueAndSuccess(offer) case DropNew ⇒ - promise.success(QueueOfferResult.Dropped) + offer.promise.success(QueueOfferResult.Dropped) case Fail ⇒ val bufferOverflowException = new BufferOverflowException(s"Buffer overflow (max capacity was: $maxBuffer)!") - promise.success(QueueOfferResult.Failure(bufferOverflowException)) + offer.promise.success(QueueOfferResult.Failure(bufferOverflowException)) completion.failure(bufferOverflowException) failStage(bufferOverflowException) case Backpressure ⇒ pendingOffer match { case Some(_) ⇒ - promise.failure(new IllegalStateException("You have to wait for previous offer to be resolved to send another request")) + offer.promise.failure(new IllegalStateException("You have to wait for previous offer to be resolved to send another request")) case None ⇒ - pendingOffer = Some((elem, promise)) + pendingOffer = Some(offer) } } } - private val callback: AsyncCallback[(T, Offered)] = getAsyncCallback(tuple ⇒ { - val (elem, promise) = tuple + private val callback: AsyncCallback[Input[T]] = getAsyncCallback { - if (maxBuffer != 0) { - bufferElem(elem, promise) - if (pulled) { - push(out, buffer.dequeue()) - pulled = false + case offer @ Offer(elem, promise) ⇒ + if (maxBuffer != 0) { + bufferElem(offer) + if (isAvailable(out)) push(out, buffer.dequeue()) + } else if (isAvailable(out)) { + push(out, elem) + promise.success(QueueOfferResult.Enqueued) + } else if (pendingOffer.isEmpty) + pendingOffer = Some(offer) + else overflowStrategy match { + case DropHead | DropBuffer ⇒ + pendingOffer.get.promise.success(QueueOfferResult.Dropped) + pendingOffer = Some(offer) + case DropTail | DropNew ⇒ + promise.success(QueueOfferResult.Dropped) + case Fail ⇒ + val bufferOverflowException = new BufferOverflowException(s"Buffer overflow (max capacity was: $maxBuffer)!") + promise.success(QueueOfferResult.Failure(bufferOverflowException)) + completion.failure(bufferOverflowException) + failStage(bufferOverflowException) + case Backpressure ⇒ + promise.failure(new IllegalStateException("You have to wait for previous offer to be resolved to send another request")) } - } else if (pulled) { - push(out, elem) - pulled = false - promise.success(QueueOfferResult.Enqueued) - } else pendingOffer = Some(tuple) - }) - setHandler(out, new OutHandler { - override def onDownstreamFinish(): Unit = { + case Completion ⇒ + if (maxBuffer != 0 && buffer.nonEmpty || pendingOffer.nonEmpty) terminating = true + else { + completion.success(Done) + completeStage() + } + + case Failure(ex) ⇒ + completion.failure(ex) + failStage(ex) + } + + setHandler(out, this) + + override def onDownstreamFinish(): Unit = { + pendingOffer match { + case Some(Offer(elem, promise)) ⇒ + promise.success(QueueOfferResult.QueueClosed) + pendingOffer = None + case None ⇒ // do nothing + } + completion.success(Done) + completeStage() + } + + override def onPull(): Unit = { + if (maxBuffer == 0) { pendingOffer match { - case Some((elem, promise)) ⇒ - promise.success(QueueOfferResult.QueueClosed) + case Some(Offer(elem, promise)) ⇒ + push(out, elem) + promise.success(QueueOfferResult.Enqueued) pendingOffer = None - case None ⇒ // do nothing + if (terminating) { + completion.success(Done) + completeStage() + } + case None ⇒ + } + } else if (buffer.nonEmpty) { + push(out, buffer.dequeue()) + pendingOffer match { + case Some(offer) ⇒ + enqueueAndSuccess(offer) + pendingOffer = None + case None ⇒ //do nothing + } + if (terminating && buffer.isEmpty) { + completion.success(Done) + completeStage() } - completion.success(Done) - completeStage() } - - override def onPull(): Unit = { - if (maxBuffer == 0) - pendingOffer match { - case Some((elem, promise)) ⇒ - push(out, elem) - promise.success(QueueOfferResult.Enqueued) - pendingOffer = None - case None ⇒ pulled = true - } - else if (!buffer.isEmpty) { - push(out, buffer.dequeue()) - pendingOffer match { - case Some((elem, promise)) ⇒ - enqueueAndSuccess(elem, promise) - pendingOffer = None - case None ⇒ //do nothing - } - } else pulled = true - } - }) + } } - (stageLogic, new SourceQueue[T] { + (stageLogic, new SourceQueueWithComplete[T] { override def watchCompletion() = completion.future override def offer(element: T): Future[QueueOfferResult] = { - val p = Promise[QueueOfferResult]() - stageLogic.invoke((element, p)) + val p = Promise[QueueOfferResult] + stageLogic.invoke(Offer(element, p)) p.future } + override def complete(): Unit = { + stageLogic.invoke(Completion) + } + override def fail(ex: Throwable): Unit = { + stageLogic.invoke(Failure(ex)) + } }) } } -private[akka] final class SourceQueueAdapter[T](delegate: SourceQueue[T]) extends akka.stream.javadsl.SourceQueue[T] { +private[akka] final class SourceQueueAdapter[T](delegate: SourceQueueWithComplete[T]) extends akka.stream.javadsl.SourceQueueWithComplete[T] { def offer(elem: T): CompletionStage[QueueOfferResult] = delegate.offer(elem).toJava def watchCompletion(): CompletionStage[Done] = delegate.watchCompletion().toJava + def complete(): Unit = delegate.complete() + def fail(ex: Throwable): Unit = delegate.fail(ex) } diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/Queue.scala b/akka-stream/src/main/scala/akka/stream/javadsl/Queue.scala index 5215a67e72..22853d38ed 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/Queue.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/Queue.scala @@ -32,6 +32,23 @@ trait SourceQueue[T] { def watchCompletion(): CompletionStage[Done] } +/** + * This trait adds completion support to [[SourceQueue]]. + */ +trait SourceQueueWithComplete[T] extends SourceQueue[T] { + /** + * Complete the stream normally. Use `watchCompletion` to be notified of this + * operation’s success. + */ + def complete(): Unit + + /** + * Complete the stream with a failure. Use `watchCompletion` to be notified of this + * operation’s success. + */ + def fail(ex: Throwable): Unit +} + /** * Trait allows to have the queue as a sink for some stream. * "SinkQueue" pulls data from stream with backpressure mechanism. diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala b/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala index 4bb92f3451..09fe8ebd7b 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala @@ -24,6 +24,7 @@ import java.util.concurrent.CompletionStage import java.util.concurrent.CompletableFuture import scala.compat.java8.FutureConverters._ import akka.stream.impl.SourceQueueAdapter +import akka.stream.scaladsl.SourceQueueWithComplete /** Java API */ object Source { @@ -304,7 +305,7 @@ object Source { * @param bufferSize size of buffer in element count * @param overflowStrategy Strategy that is used when incoming elements cannot fit inside the buffer */ - def queue[T](bufferSize: Int, overflowStrategy: OverflowStrategy): Source[T, SourceQueue[T]] = + def queue[T](bufferSize: Int, overflowStrategy: OverflowStrategy): Source[T, SourceQueueWithComplete[T]] = new Source(scaladsl.Source.queue[T](bufferSize, overflowStrategy).mapMaterializedValue(new SourceQueueAdapter(_))) } diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Queue.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Queue.scala index 9a39ce3795..9a7fb20483 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Queue.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Queue.scala @@ -31,6 +31,23 @@ trait SourceQueue[T] { def watchCompletion(): Future[Done] } +/** + * This trait adds completion support to [[SourceQueue]]. + */ +trait SourceQueueWithComplete[T] extends SourceQueue[T] { + /** + * Complete the stream normally. Use `watchCompletion` to be notified of this + * operation’s success. + */ + def complete(): Unit + + /** + * Complete the stream with a failure. Use `watchCompletion` to be notified of this + * operation’s success. + */ + def fail(ex: Throwable): Unit +} + /** * Trait allows to have the queue as a sink for some stream. * "SinkQueue" pulls data from stream with backpressure mechanism. diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Source.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Source.scala index ec08270095..7740e27fd6 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Source.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Source.scala @@ -437,7 +437,7 @@ object Source { * @param bufferSize size of buffer in element count * @param overflowStrategy Strategy that is used when incoming elements cannot fit inside the buffer */ - def queue[T](bufferSize: Int, overflowStrategy: OverflowStrategy): Source[T, SourceQueue[T]] = + def queue[T](bufferSize: Int, overflowStrategy: OverflowStrategy): Source[T, SourceQueueWithComplete[T]] = Source.fromGraph(new QueueSource(bufferSize, overflowStrategy).withAttributes(DefaultAttributes.queueSource)) } diff --git a/project/MiMa.scala b/project/MiMa.scala index a91720c308..edd56b44e6 100644 --- a/project/MiMa.scala +++ b/project/MiMa.scala @@ -728,7 +728,10 @@ object MiMa extends AutoPlugin { // #15947 catch mailbox creation failures ProblemFilters.exclude[DirectMissingMethodProblem]("akka.actor.RepointableActorRef.point"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.actor.dungeon.Dispatch.initWithFailure") + ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.actor.dungeon.Dispatch.initWithFailure"), + + // #19877 Source.queue termination support + ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.stream.impl.SourceQueueAdapter.this") ) ) }