diff --git a/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala b/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala index 69804a0728..bb075af767 100644 --- a/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala +++ b/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala @@ -3,17 +3,21 @@ */ package akka.stream.testkit -import akka.actor.{ ActorSystem, DeadLetterSuppression, NoSerializationVerificationNeeded } +import akka.actor.{ ActorRef, ActorSystem, DeadLetterSuppression, NoSerializationVerificationNeeded } import akka.stream._ import akka.stream.impl._ -import akka.testkit.TestProbe +import akka.testkit.{ TestActor, TestProbe } import org.reactivestreams.{ Publisher, Subscriber, Subscription } + import scala.annotation.tailrec import scala.collection.immutable import scala.concurrent.duration._ import scala.language.existentials import java.io.StringWriter import java.io.PrintWriter +import java.util.concurrent.CountDownLatch + +import akka.testkit.TestActor.{ AutoPilot, NoAutoPilot } /** * Provides factory methods for various Publishers. @@ -27,6 +31,8 @@ object TestPublisher { final case class CancelSubscription(subscription: Subscription) extends PublisherEvent final case class RequestMore(subscription: Subscription, elements: Long) extends PublisherEvent + final object SubscriptionDone extends NoSerializationVerificationNeeded + /** * Publisher that signals complete to subscribers, after handing a void subscription. */ @@ -74,6 +80,15 @@ object TestPublisher { private val probe: TestProbe = TestProbe() + //this is a way to pause receiving message from probe until subscription is done + private val subscribed = new CountDownLatch(1) + probe.ignoreMsg { case SubscriptionDone ⇒ true } + probe.setAutoPilot(new TestActor.AutoPilot() { + override def run(sender: ActorRef, msg: Any): AutoPilot = { + if (msg == SubscriptionDone) subscribed.countDown() + this + } + }) private val self = this.asInstanceOf[Self] /** @@ -83,18 +98,26 @@ object TestPublisher { val subscription: PublisherProbeSubscription[I] = new PublisherProbeSubscription[I](subscriber, probe) probe.ref ! Subscribe(subscription) if (autoOnSubscribe) subscriber.onSubscribe(subscription) + probe.ref ! SubscriptionDone + } + + def executeAfterSubscription[T](f: ⇒ T): T = { + subscribed.await( + probe.testKitSettings.DefaultTimeout.duration.length, + probe.testKitSettings.DefaultTimeout.duration.unit) + f } /** * Expect a subscription. */ def expectSubscription(): PublisherProbeSubscription[I] = - probe.expectMsgType[Subscribe].subscription.asInstanceOf[PublisherProbeSubscription[I]] + executeAfterSubscription { probe.expectMsgType[Subscribe].subscription.asInstanceOf[PublisherProbeSubscription[I]] } /** * Expect demand from a given subscription. */ - def expectRequest(subscription: Subscription, n: Int): Self = { + def expectRequest(subscription: Subscription, n: Int): Self = executeAfterSubscription { probe.expectMsg(RequestMore(subscription, n)) self } @@ -102,7 +125,7 @@ object TestPublisher { /** * Expect no messages. */ - def expectNoMsg(): Self = { + def expectNoMsg(): Self = executeAfterSubscription { probe.expectNoMsg() self } @@ -110,7 +133,7 @@ object TestPublisher { /** * Expect no messages for a given duration. */ - def expectNoMsg(max: FiniteDuration): Self = { + def expectNoMsg(max: FiniteDuration): Self = executeAfterSubscription { probe.expectNoMsg(max) self } @@ -119,10 +142,10 @@ object TestPublisher { * Receive messages for a given duration or until one does not match a given partial function. */ def receiveWhile[T](max: Duration = Duration.Undefined, idle: Duration = Duration.Inf, messages: Int = Int.MaxValue)(f: PartialFunction[PublisherEvent, T]): immutable.Seq[T] = - probe.receiveWhile(max, idle, messages)(f.asInstanceOf[PartialFunction[AnyRef, T]]) + executeAfterSubscription { probe.receiveWhile(max, idle, messages)(f.asInstanceOf[PartialFunction[AnyRef, T]]) } def expectEventPF[T](f: PartialFunction[PublisherEvent, T]): T = - probe.expectMsgPF[T]()(f.asInstanceOf[PartialFunction[Any, T]]) + executeAfterSubscription { probe.expectMsgPF[T]()(f.asInstanceOf[PartialFunction[Any, T]]) } def getPublisher: Publisher[I] = this @@ -142,12 +165,12 @@ object TestPublisher { * } * }}} */ - def within[T](min: FiniteDuration, max: FiniteDuration)(f: ⇒ T): T = probe.within(min, max)(f) + def within[T](min: FiniteDuration, max: FiniteDuration)(f: ⇒ T): T = executeAfterSubscription { probe.within(min, max)(f) } /** * Same as calling `within(0 seconds, max)(f)`. */ - def within[T](max: FiniteDuration)(f: ⇒ T): T = probe.within(max)(f) + def within[T](max: FiniteDuration)(f: ⇒ T): T = executeAfterSubscription { probe.within(max)(f) } } /** diff --git a/akka-stream/src/main/scala/akka/stream/impl/ReactiveStreamsCompliance.scala b/akka-stream/src/main/scala/akka/stream/impl/ReactiveStreamsCompliance.scala index e3f3946579..aee943d796 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/ReactiveStreamsCompliance.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/ReactiveStreamsCompliance.scala @@ -109,12 +109,14 @@ import org.reactivestreams.{ Subscriber, Subscription } } final def tryRequest(subscription: Subscription, demand: Long): Unit = { + if (subscription eq null) throw new IllegalStateException("Subscription must be not null on request() call, rule 1.3") try subscription.request(demand) catch { case NonFatal(t) ⇒ throw new SignalThrewException("It is illegal to throw exceptions from request(), rule 3.16", t) } } final def tryCancel(subscription: Subscription): Unit = { + if (subscription eq null) throw new IllegalStateException("Subscription must be not null on cancel() call, rule 1.3") try subscription.cancel() catch { case NonFatal(t) ⇒ throw new SignalThrewException("It is illegal to throw exceptions from cancel(), rule 3.15", t) }