diff --git a/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala b/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala index d35bb93861..158c300cf3 100644 --- a/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala +++ b/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala @@ -21,6 +21,8 @@ import akka.testkit.TestActor.AutoPilot import akka.util.JavaDurationConverters import akka.util.ccompat._ +import scala.reflect.ClassTag + /** * Provides factory methods for various Publishers. */ @@ -30,7 +32,7 @@ object TestPublisher { trait PublisherEvent extends DeadLetterSuppression with NoSerializationVerificationNeeded final case class Subscribe(subscription: Subscription) extends PublisherEvent - final case class CancelSubscription(subscription: Subscription) extends PublisherEvent + final case class CancelSubscription(subscription: Subscription, cause: Throwable) extends PublisherEvent final case class RequestMore(subscription: Subscription, elements: Long) extends PublisherEvent final object SubscriptionDone extends NoSerializationVerificationNeeded @@ -258,6 +260,24 @@ object TestPublisher { subscription.expectCancellation() this } + + def expectCancellationWithCause(expectedCause: Throwable): Self = { + val cause = subscription.expectCancellation() + assert(cause == expectedCause, s"Expected cancellation cause to be $expectedCause but was $cause") + this + } + def expectCancellationWithCause[E <: Throwable: ClassTag](): E = subscription.expectCancellation() match { + case e: E => e + case cause => + throw new AssertionError( + s"Expected cancellation cause to be of type ${scala.reflect.classTag[E]} but was ${cause.getClass}: $cause") + } + + /** + * Java API + */ + def expectCancellationWithCause[E <: Throwable](causeClass: Class[E]): E = + expectCancellationWithCause()(ClassTag(causeClass)) } } @@ -799,6 +819,15 @@ object TestSubscriber { this } + def cancel(cause: Throwable): Self = subscription match { + case s: SubscriptionWithCancelException => + s.cancel(cause) + this + case _ => + throw new IllegalStateException( + "Tried to cancel with cause but upstream subscription doesn't support cancellation with cause") + } + /** * Request and expect a stream element. */ @@ -834,19 +863,20 @@ private[testkit] object StreamTestKit { } final case class PublisherProbeSubscription[I](subscriber: Subscriber[_ >: I], publisherProbe: TestProbe) - extends Subscription { + extends Subscription + with SubscriptionWithCancelException { def request(elements: Long): Unit = publisherProbe.ref ! RequestMore(this, elements) - def cancel(): Unit = publisherProbe.ref ! CancelSubscription(this) + def cancel(cause: Throwable): Unit = publisherProbe.ref ! CancelSubscription(this, cause) def expectRequest(n: Long): Unit = publisherProbe.expectMsg(RequestMore(this, n)) def expectRequest(): Long = publisherProbe.expectMsgPF(hint = "expecting request() signal") { case RequestMore(sub, n) if sub eq this => n } - def expectCancellation(): Unit = publisherProbe.fishForMessage(hint = "Expecting cancellation") { - case CancelSubscription(sub) if sub eq this => true - case RequestMore(sub, _) if sub eq this => false - } + def expectCancellation(): Throwable = + publisherProbe.fishForSpecificMessage[Throwable](hint = "Expecting cancellation") { + case CancelSubscription(sub, cause) if sub eq this => cause + } def sendNext(element: I): Unit = subscriber.onNext(element) def sendComplete(): Unit = subscriber.onComplete()