diff --git a/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala b/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala index 67cc08575b..39eb3c3a11 100644 --- a/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala +++ b/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala @@ -3,8 +3,6 @@ */ package akka.stream.testkit -import java.util.concurrent.TimeoutException - import akka.actor.{ ActorSystem, DeadLetterSuppression, NoSerializationVerificationNeeded } import akka.stream._ import akka.stream.impl.StreamLayout.Module @@ -24,6 +22,11 @@ object TestPublisher { import StreamTestKit._ + trait PublisherEvent extends DeadLetterSuppression with NoSerializationVerificationNeeded + final case class Subscribe(subscription: Subscription) extends PublisherEvent + final case class CancelSubscription(subscription: Subscription) extends PublisherEvent + final case class RequestMore(subscription: Subscription, elements: Long) extends PublisherEvent + /** * Publisher that signals complete to subscribers, after handing a void subscription. */ @@ -118,6 +121,9 @@ object TestPublisher { def receiveWhile[T](max: Duration = Duration.Undefined, idle: Duration = Duration.Inf, messages: Int = Int.MaxValue)(f: PartialFunction[PublisherEvent, T]): immutable.Seq[T] = probe.receiveWhile(max, idle, messages)(f.asInstanceOf[PartialFunction[AnyRef, T]]) + def expectEventPF[T](f: PartialFunction[PublisherEvent, T]): T = + probe.expectMsgPF[T](probe.remaining)(f.asInstanceOf[PartialFunction[Any, T]]) + def getPublisher: Publisher[I] = this } @@ -170,7 +176,11 @@ object TestPublisher { object TestSubscriber { - import StreamTestKit._ + trait SubscriberEvent extends DeadLetterSuppression with NoSerializationVerificationNeeded + final case class OnSubscribe(subscription: Subscription) extends SubscriberEvent + final case class OnNext[I](element: I) extends SubscriberEvent + final case object OnComplete extends SubscriberEvent + final case class OnError(cause: Throwable) extends SubscriberEvent /** * Probe that implements [[org.reactivestreams.Subscriber]] interface. @@ -477,6 +487,17 @@ object TestSubscriber { self } + def expectNextPF[T](f: PartialFunction[Any, T]): T = { + expectEventPF { + case OnNext(n) ⇒ + assert(f.isDefinedAt(n)) + f(n) + } + } + + def expectEventPF[T](f: PartialFunction[SubscriberEvent, T]): T = + probe.expectMsgPF[T](probe.remaining)(f.asInstanceOf[PartialFunction[Any, T]]) + /** * Receive messages for a given duration or until one does not match a given partial function. */ @@ -564,17 +585,7 @@ object TestSubscriber { * INTERNAL API */ private[testkit] object StreamTestKit { - - sealed trait PublisherEvent extends DeadLetterSuppression with NoSerializationVerificationNeeded - final case class Subscribe(subscription: Subscription) extends PublisherEvent - final case class CancelSubscription(subscription: Subscription) extends PublisherEvent - final case class RequestMore(subscription: Subscription, elements: Long) extends PublisherEvent - - sealed trait SubscriberEvent extends DeadLetterSuppression with NoSerializationVerificationNeeded - final case class OnSubscribe(subscription: Subscription) extends SubscriberEvent - final case class OnNext[I](element: I) extends SubscriberEvent - final case object OnComplete extends SubscriberEvent - final case class OnError(cause: Throwable) extends SubscriberEvent + import TestPublisher._ final case class CompletedSubscription[T](subscriber: Subscriber[T]) extends Subscription { override def request(elements: Long): Unit = subscriber.onComplete() diff --git a/akka-stream-testkit/src/test/scala/akka/stream/testkit/ScriptedTest.scala b/akka-stream-testkit/src/test/scala/akka/stream/testkit/ScriptedTest.scala index 54fe63dc9f..f17d820072 100644 --- a/akka-stream-testkit/src/test/scala/akka/stream/testkit/ScriptedTest.scala +++ b/akka-stream-testkit/src/test/scala/akka/stream/testkit/ScriptedTest.scala @@ -4,16 +4,16 @@ package akka.stream.testkit import akka.actor.ActorSystem -import akka.stream.ActorMaterializerSettings -import akka.stream.scaladsl.{ Sink, Source, Flow } -import akka.stream.testkit._ -import akka.stream.testkit.StreamTestKit._ +import akka.stream.testkit.TestPublisher._ +import akka.stream.testkit.TestSubscriber._ +import akka.stream.{ ActorMaterializer, ActorMaterializerSettings } +import akka.stream.scaladsl.{ Flow, Sink, Source } import org.reactivestreams.Publisher import org.scalatest.Matchers + import scala.annotation.tailrec import scala.concurrent.duration._ import scala.concurrent.forkjoin.ThreadLocalRandom -import akka.stream.ActorMaterializer trait ScriptedTest extends Matchers { diff --git a/akka-stream-testkit/src/test/scala/akka/stream/testkit/TestPublisherSubscriberSpec.scala b/akka-stream-testkit/src/test/scala/akka/stream/testkit/TestPublisherSubscriberSpec.scala new file mode 100644 index 0000000000..97670e8ee6 --- /dev/null +++ b/akka-stream-testkit/src/test/scala/akka/stream/testkit/TestPublisherSubscriberSpec.scala @@ -0,0 +1,62 @@ +/** + * Copyright (C) 2015 Typesafe Inc. + */ +package akka.stream.testkit + +import akka.stream.scaladsl.{ Sink, Source } +import akka.stream.testkit.TestPublisher._ +import akka.stream.testkit.TestSubscriber._ +import akka.stream.testkit.Utils._ +import akka.stream.{ ActorMaterializer, ActorMaterializerSettings } +import org.reactivestreams.Subscription + +class TestPublisherSubscriberSpec extends AkkaSpec { + + val settings = ActorMaterializerSettings(system) + .withInputBuffer(initialSize = 2, maxSize = 2) + + implicit val materializer = ActorMaterializer(settings) + + "TestPublisher and TestSubscriber" must { + + "have all events accessible from manual probes" in assertAllStagesStopped { + val upstream = TestPublisher.manualProbe[Int]() + val downstream = TestSubscriber.manualProbe[Int]() + Source(upstream).runWith(Sink.publisher)(materializer).subscribe(downstream) + + val upstreamSubscription = upstream.expectSubscription() + val downstreamSubscription: Subscription = downstream.expectEventPF { case OnSubscribe(sub) ⇒ sub } + + upstreamSubscription.sendNext(1) + downstreamSubscription.request(1) + upstream.expectEventPF { case RequestMore(_, e) ⇒ e } should ===(1) + downstream.expectEventPF { case OnNext(e) ⇒ e } should ===(1) + + upstreamSubscription.sendNext(1) + downstreamSubscription.request(1) + downstream.expectNextPF[Int] { case e: Int ⇒ e } should ===(1) + + upstreamSubscription.sendComplete() + downstream.expectEventPF { + case c @ OnComplete ⇒ + case _ ⇒ fail() + } + } + + "handle gracefully partial function that is not suitable" in assertAllStagesStopped { + val upstream = TestPublisher.manualProbe[Int]() + val downstream = TestSubscriber.manualProbe[Int]() + Source(upstream).runWith(Sink.publisher)(materializer).subscribe(downstream) + val upstreamSubscription = upstream.expectSubscription() + val downstreamSubscription: Subscription = downstream.expectEventPF { case OnSubscribe(sub) ⇒ sub } + + upstreamSubscription.sendNext(1) + downstreamSubscription.request(1) + an[AssertionError] should be thrownBy upstream.expectEventPF { case Subscribe(e) ⇒ e } + an[AssertionError] should be thrownBy downstream.expectNextPF[String] { case e: String ⇒ e } + + upstreamSubscription.sendComplete() + } + + } +}