From 85e06b35b13fecdcb2b8b74b865f8774f0de1132 Mon Sep 17 00:00:00 2001 From: Patrik Nordwall Date: Mon, 4 Apr 2016 15:59:52 +0200 Subject: [PATCH] str: Support stash in ActorPublisher, #17037 --- .../stream/actor/ActorPublisherSpec.scala | 37 +++++++++++++++- .../akka/stream/actor/ActorPublisher.scala | 42 ++++++++++++++----- 2 files changed, 67 insertions(+), 12 deletions(-) diff --git a/akka-stream-tests/src/test/scala/akka/stream/actor/ActorPublisherSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/actor/ActorPublisherSpec.scala index e760a7ac33..d4a8f8d986 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/actor/ActorPublisherSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/actor/ActorPublisherSpec.scala @@ -10,10 +10,10 @@ import akka.stream.testkit._ import akka.stream.testkit.Utils._ import akka.testkit.TestEvent.Mute import akka.testkit.{ AkkaSpec, EventFilter, ImplicitSender, TestProbe } - import scala.annotation.tailrec import scala.concurrent.duration._ import scala.util.control.NoStackTrace +import akka.actor.Stash object ActorPublisherSpec { @@ -29,6 +29,12 @@ object ActorPublisherSpec { else p } + def testPublisherWithStashProps(probe: ActorRef, useTestDispatcher: Boolean = true): Props = { + val p = Props(new TestPublisherWithStash(probe)) + if (useTestDispatcher) p.withDispatcher("akka.test.stream-dispatcher") + else p + } + case class TotalDemand(elements: Long) case class Produce(elem: String) case class Err(reason: String) @@ -53,6 +59,19 @@ object ActorPublisherSpec { } } + class TestPublisherWithStash(probe: ActorRef) extends TestPublisher(probe) with Stash { + + override def receive = stashing + + def stashing: Receive = { + case "unstash" ⇒ + unstashAll() + context.become(super.receive) + case _ ⇒ stash() + } + + } + def senderProps: Props = Props[Sender].withDispatcher("akka.test.stream-dispatcher") class Sender extends ActorPublisher[Int] { @@ -445,6 +464,22 @@ class ActorPublisherSpec extends AkkaSpec(ActorPublisherSpec.config) with Implic expectMsgType[String] should include("my-dispatcher1") } + "handle stash" in { + val probe = TestProbe() + val ref = system.actorOf(testPublisherWithStashProps(probe.ref)) + val p = ActorPublisher[String](ref) + val s = TestSubscriber.probe[String]() + p.subscribe(s) + s.request(2) + s.request(3) + ref ! "unstash" + probe.expectMsg(TotalDemand(5)) + probe.expectMsg(TotalDemand(5)) + s.request(4) + probe.expectMsg(TotalDemand(9)) + s.cancel() + } + } } diff --git a/akka-stream/src/main/scala/akka/stream/actor/ActorPublisher.scala b/akka-stream/src/main/scala/akka/stream/actor/ActorPublisher.scala index 3efc9cfa66..939a077f55 100644 --- a/akka-stream/src/main/scala/akka/stream/actor/ActorPublisher.scala +++ b/akka-stream/src/main/scala/akka/stream/actor/ActorPublisher.scala @@ -45,7 +45,18 @@ object ActorPublisherMessage { * more elements. * @param n number of requested elements */ - final case class Request(n: Long) extends ActorPublisherMessage with NoSerializationVerificationNeeded + final case class Request(n: Long) extends ActorPublisherMessage with NoSerializationVerificationNeeded { + private var processed = false + /** + * INTERNAL API: needed for stash support + */ + private[akka] def markProcessed(): Unit = processed = true + + /** + * INTERNAL API: needed for stash support + */ + private[akka] def isProcessed(): Boolean = processed + } /** * This message is delivered to the [[ActorPublisher]] actor when the stream subscriber cancels the @@ -256,15 +267,21 @@ trait ActorPublisher[T] extends Actor { * INTERNAL API */ protected[akka] override def aroundReceive(receive: Receive, msg: Any): Unit = msg match { - case Request(n) ⇒ - if (n < 1) { - if (lifecycleState == Active) - onError(numberOfElementsInRequestMustBePositiveException) + case req @ Request(n) ⇒ + if (req.isProcessed()) { + // it's an unstashed Request, demand is already handled + super.aroundReceive(receive, req) } else { - demand += n - if (demand < 0) - demand = Long.MaxValue // Long overflow, Reactive Streams Spec 3:17: effectively unbounded - super.aroundReceive(receive, msg) + if (n < 1) { + if (lifecycleState == Active) + onError(numberOfElementsInRequestMustBePositiveException) + } else { + demand += n + if (demand < 0) + demand = Long.MaxValue // Long overflow, Reactive Streams Spec 3:17: effectively unbounded + req.markProcessed() + super.aroundReceive(receive, req) + } } case Subscribe(sub: Subscriber[_]) ⇒ @@ -293,8 +310,11 @@ trait ActorPublisher[T] extends Actor { } case Cancel ⇒ - cancelSelf() - super.aroundReceive(receive, msg) + if (lifecycleState != Canceled) { + // possible to receive again in case of stash + cancelSelf() + super.aroundReceive(receive, msg) + } case SubscriptionTimeoutExceeded ⇒ if (!scheduledSubscriptionTimeout.isCancelled) {