diff --git a/akka-docs-dev/rst/scala/code/docs/stream/StreamTestKitDocSpec.scala b/akka-docs-dev/rst/scala/code/docs/stream/StreamTestKitDocSpec.scala index f6aa372d3d..1bbaea1a4a 100644 --- a/akka-docs-dev/rst/scala/code/docs/stream/StreamTestKitDocSpec.scala +++ b/akka-docs-dev/rst/scala/code/docs/stream/StreamTestKitDocSpec.scala @@ -153,7 +153,7 @@ class StreamTestKitDocSpec extends AkkaSpec { sub.expectNextUnordered(1, 2, 3) pub.sendError(new Exception("Power surge in the linear subroutine C-47!")) - val ex = sub.expectError + val ex = sub.expectError() assert(ex.getMessage.contains("C-47")) //#test-source-and-sink } diff --git a/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala b/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala index f2f37b4716..b1899bbeb6 100644 --- a/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala +++ b/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala @@ -3,16 +3,19 @@ */ package akka.stream.testkit -import scala.language.existentials -import akka.actor.{ NoSerializationVerificationNeeded, ActorSystem, DeadLetterSuppression } +import java.util.concurrent.TimeoutException + +import akka.actor.{ ActorSystem, DeadLetterSuppression, NoSerializationVerificationNeeded } import akka.stream._ +import akka.stream.impl.StreamLayout.Module import akka.stream.impl._ import akka.testkit.TestProbe -import akka.stream.impl.StreamLayout.Module -import akka.stream.scaladsl._ import org.reactivestreams.{ Publisher, Subscriber, Subscription } + +import scala.annotation.tailrec import scala.collection.immutable import scala.concurrent.duration._ +import scala.language.existentials /** * Provides factory methods for various Publishers. @@ -185,23 +188,43 @@ object TestSubscriber { private val probe = TestProbe() + @volatile private var _subscription: Subscription = _ + private val self = this.asInstanceOf[Self] /** - * Expect and return a Subscription. + * Expect and return a [[Subscription]]. */ - def expectSubscription(): Subscription = probe.expectMsgType[OnSubscribe].subscription + def expectSubscription(): Subscription = { + _subscription = probe.expectMsgType[OnSubscribe].subscription + _subscription + } /** - * Expect [[SubscriberEvent]]. + * Expect and return [[SubscriberEvent]] (any of: `OnSubscribe`, `OnNext`, `OnError` or `OnComplete`). */ - def expectEvent(event: SubscriberEvent): Self = { + def expectEvent(): SubscriberEvent = + probe.expectMsgType[SubscriberEvent] + + /** + * Fluent DSL + * + * Expect [[SubscriberEvent]] (any of: `OnSubscribe`, `OnNext`, `OnError` or `OnComplete`). + */ + def expectEvent(event: SubscriberEvent): Self = { // TODO it's more "signal" than event, shall we rename? -- ktoso probe.expectMsg(event) self } /** - * Expect a data element. + * Expect and return a stream element. + */ + def expectNext(): I = probe.expectMsgType[OnNext[I]].element + + /** + * Fluent DSL + * + * Expect a stream element. */ def expectNext(element: I): Self = { probe.expectMsg(OnNext(element)) @@ -209,24 +232,49 @@ object TestSubscriber { } /** - * Expect multiple data elements. + * Fluent DSL + * + * Expect multiple stream elements. */ @annotation.varargs def expectNext(e1: I, e2: I, es: I*): Self = expectNextN((e1 +: e2 +: es).map(identity)(collection.breakOut)) + /** + * Fluent DSL + * + * Expect multiple stream elements in arbitrary order. + */ @annotation.varargs def expectNextUnordered(e1: I, e2: I, es: I*): Self = expectNextUnorderedN((e1 +: e2 +: es).map(identity)(collection.breakOut)) /** - * Expect and return a data element. + * Expect and return the next `n` stream elements. */ - def expectNext(): I = probe.expectMsgType[OnNext[I]].element + def expectNextN(n: Long): immutable.Seq[I] = { + val b = immutable.Seq.newBuilder[I] + var i = 0 + while (i < n) { + val next = probe.expectMsgType[OnNext[I]] + b += next.element + i += 1 + } + b.result() + } + + /** + * Fluent DSL + * Expect the given elements to be signalled in order. + */ def expectNextN(all: immutable.Seq[I]): Self = { all.foreach(e ⇒ probe.expectMsg(OnNext(e))) self } + /** + * Fluent DSL + * Expect the given elements to be signalled in any order. + */ def expectNextUnorderedN(all: immutable.Seq[I]): Self = { @annotation.tailrec def expectOneOf(all: immutable.Seq[I]): Unit = all match { case Nil ⇒ @@ -241,6 +289,8 @@ object TestSubscriber { } /** + * Fluent DSL + * * Expect completion. */ def expectComplete(): Self = { @@ -249,6 +299,13 @@ object TestSubscriber { } /** + * Expect and return the signalled [[Throwable]]. + */ + def expectError(): Throwable = probe.expectMsgType[OnError].cause + + /** + * Fluent DSL + * * Expect given [[Throwable]]. */ def expectError(cause: Throwable): Self = { @@ -257,55 +314,157 @@ object TestSubscriber { } /** - * Expect and return a [[Throwable]]. + * Expect subscription to be followed immediatly by an error signal. + * + * By default `1` demand will be signalled in order to wake up a possibly lazy upstream. + * + * See also [[#expectSubscriptionAndError(Boolean)]] if no demand should be signalled. */ - def expectError(): Throwable = probe.expectMsgType[OnError].cause + def expectSubscriptionAndError(): Throwable = { + expectSubscriptionAndError(true) + } - def expectSubscriptionAndError(cause: Throwable): Self = { + /** + * Expect subscription to be followed immediatly by an error signal. + * + * Depending on the `signalDemand` parameter demand may be signalled immediatly after obtaining the subscription + * in order to wake up a possibly lazy upstream. You can disable this by setting the `signalDemand` parameter to `false`. + * + * See also [[#expectSubscriptionAndError()]]. + */ + def expectSubscriptionAndError(signalDemand: Boolean): Throwable = { val sub = expectSubscription() - sub.request(1) + if (signalDemand) sub.request(1) + expectError() + } + + /** + * Fluent DSL + * + * Expect subscription followed by immediate stream completion. + * + * By default `1` demand will be signalled in order to wake up a possibly lazy upstream. + * + * See also [[#expectSubscriptionAndComplete(Throwable, Boolean)]] if no demand should be signalled. + */ + def expectSubscriptionAndError(cause: Throwable): Self = + expectSubscriptionAndError(cause, true) + + /** + * Fluent DSL + * + * Expect subscription followed by immediate stream completion. + * By default `1` demand will be signalled in order to wake up a possibly lazy upstream + * + * See also [[#expectSubscriptionAndError(Throwable)]]. + */ + def expectSubscriptionAndError(cause: Throwable, signalDemand: Boolean): Self = { + val sub = expectSubscription() + if (signalDemand) sub.request(1) expectError(cause) self } - def expectSubscriptionAndError(): Throwable = { - val sub = expectSubscription() - sub.request(1) - expectError() - } + /** + * Fluent DSL + * + * Expect subscription followed by immediate stream completion. + * By default `1` demand will be signalled in order to wake up a possibly lazy upstream + * + * See also [[#expectSubscriptionAndComplete(Boolean)]] if no demand should be signalled. + */ + def expectSubscriptionAndComplete(): Self = + expectSubscriptionAndComplete(true) - def expectSubscriptionAndComplete(): Self = { + /** + * Fluent DSL + * + * Expect subscription followed by immediate stream completion. + * + * Depending on the `signalDemand` parameter demand may be signalled immediatly after obtaining the subscription + * in order to wake up a possibly lazy upstream. You can disable this by setting the `signalDemand` parameter to `false`. + * + * See also [[#expectSubscriptionAndComplete]]. + */ + def expectSubscriptionAndComplete(signalDemand: Boolean): Self = { val sub = expectSubscription() - sub.request(1) + if (signalDemand) sub.request(1) expectComplete() self } - def expectNextOrError(element: I, cause: Throwable): Either[Throwable, I] = { - probe.fishForMessage(hint = s"OnNext($element) or ${cause.getClass.getName}") { - case OnNext(n) ⇒ true - case OnError(`cause`) ⇒ true + /** + * Fluent DSL + * + * Expect given next element or error signal, returning whichever was signalled. + */ + def expectNextOrError(): Either[Throwable, I] = { + probe.fishForMessage(hint = s"OnNext(_) or error") { + case OnNext(element) ⇒ true + case OnError(cause) ⇒ true } match { case OnNext(n: I @unchecked) ⇒ Right(n) case OnError(err) ⇒ Left(err) } } - def expectNextOrComplete(element: I): Self = { - probe.fishForMessage(hint = s"OnNext($element) or OnComplete") { + /** + * Fluent DSL + * Expect given next element or error signal. + */ + def expectNextOrError(element: I, cause: Throwable): Either[Throwable, I] = { + probe.fishForMessage(hint = s"OnNext($element) or ${cause.getClass.getName}") { + case OnNext(`element`) ⇒ true + case OnError(`cause`) ⇒ true + } match { + case OnNext(n: I @unchecked) ⇒ Right(n) + case OnError(err) ⇒ Left(err) + } + } + + /** + * Expect next element or stream completion - returning whichever was signalled. + */ + def expectNextOrComplete(): Either[OnComplete.type, I] = { + probe.fishForMessage(hint = s"OnNext(_) or OnComplete") { case OnNext(n) ⇒ true case OnComplete ⇒ true + } match { + case OnComplete ⇒ Left(OnComplete) + case OnNext(n: I @unchecked) ⇒ Right(n) + } + } + + /** + * Fluent DSL + * + * Expect given next element or stream completion. + */ + def expectNextOrComplete(element: I): Self = { + probe.fishForMessage(hint = s"OnNext($element) or OnComplete") { + case OnNext(`element`) ⇒ true + case OnComplete ⇒ true } self } + /** + * Fluent DSL + * + * Same as `expectNoMsg(remaining)`, but correctly treating the timeFactor. + */ def expectNoMsg(): Self = { probe.expectNoMsg() self } - def expectNoMsg(max: FiniteDuration): Self = { - probe.expectNoMsg(max) + /** + * Fluent DSL + * + * Assert that no message is received for the specified time. + */ + def expectNoMsg(remaining: FiniteDuration): Self = { + probe.expectNoMsg(remaining) self } @@ -315,12 +474,49 @@ object TestSubscriber { def receiveWhile[T](max: Duration = Duration.Undefined, idle: Duration = Duration.Inf, messages: Int = Int.MaxValue)(f: PartialFunction[SubscriberEvent, T]): immutable.Seq[T] = probe.receiveWhile(max, idle, messages)(f.asInstanceOf[PartialFunction[AnyRef, T]]) + /** + * Drains a given number of messages + */ def receiveWithin(max: FiniteDuration, messages: Int = Int.MaxValue): immutable.Seq[I] = probe.receiveWhile(max, max, messages) { case OnNext(i) ⇒ Some(i.asInstanceOf[I]) case _ ⇒ None }.flatten + /** + * Attempt to drain the stream into a strict collection (by requesting `Long.MaxValue` elements). + * + * '''Use with caution: Be warned that this may not be a good idea if the stream is infinite or its elements are very large!''' + */ + def toStrict(atMost: FiniteDuration): immutable.Seq[I] = { + val deadline = Deadline.now + atMost + val b = immutable.Seq.newBuilder[I] + + def checkDeadline(): Unit = { + if (deadline.isOverdue()) + throw new TimeoutException(s"toStrict did not drain the stream within $atMost! Accumulated elements: ${b.result()}") + } + + @tailrec def drain(): immutable.Seq[I] = + self.expectEvent() match { + case OnError(ex) ⇒ + throw new AssertionError(s"toStrict received OnError($ex) while draining stream! Accumulated elements: ${b.result()}") + case OnComplete ⇒ + checkDeadline() + b.result() + case OnNext(i: I @unchecked) ⇒ + checkDeadline() + b += i + drain() + } + + // if no subscription was obtained yet, we expect it + if (_subscription == null) self.expectSubscription() + _subscription.request(Long.MaxValue) + + drain() + } + def within[T](max: FiniteDuration)(f: ⇒ T): T = probe.within(0.seconds, max)(f) def onSubscribe(subscription: Subscription): Unit = probe.ref ! OnSubscribe(subscription) @@ -366,9 +562,6 @@ object TestSubscriber { */ private[testkit] object StreamTestKit { - import TestPublisher._ - import TestSubscriber._ - sealed trait PublisherEvent extends DeadLetterSuppression with NoSerializationVerificationNeeded final case class Subscribe(subscription: Subscription) extends PublisherEvent final case class CancelSubscription(subscription: Subscription) extends PublisherEvent diff --git a/akka-stream-testkit/src/main/scala/akka/stream/testkit/scaladsl/TestSink.scala b/akka-stream-testkit/src/main/scala/akka/stream/testkit/scaladsl/TestSink.scala index 4f93e20247..a50faf7186 100644 --- a/akka-stream-testkit/src/main/scala/akka/stream/testkit/scaladsl/TestSink.scala +++ b/akka-stream-testkit/src/main/scala/akka/stream/testkit/scaladsl/TestSink.scala @@ -7,6 +7,7 @@ import akka.actor.ActorSystem import akka.stream.Attributes.none import akka.stream._ import akka.stream.scaladsl._ +import akka.stream.testkit.TestSubscriber.Probe import akka.stream.testkit._ /** @@ -17,6 +18,7 @@ object TestSink { /** * A Sink that materialized to a [[TestSubscriber.Probe]]. */ - def probe[T]()(implicit system: ActorSystem) = new Sink[T, TestSubscriber.Probe[T]](new StreamTestKit.ProbeSink(none, SinkShape(Inlet("ProbeSink.in")))) + def probe[T]()(implicit system: ActorSystem): Sink[T, Probe[T]] = + new Sink[T, TestSubscriber.Probe[T]](new StreamTestKit.ProbeSink(none, SinkShape(Inlet("ProbeSink.in")))) } diff --git a/akka-stream-testkit/src/test/scala/akka/stream/testkit/StreamTestKitSpec.scala b/akka-stream-testkit/src/test/scala/akka/stream/testkit/StreamTestKitSpec.scala new file mode 100644 index 0000000000..09c43bb6a8 --- /dev/null +++ b/akka-stream-testkit/src/test/scala/akka/stream/testkit/StreamTestKitSpec.scala @@ -0,0 +1,125 @@ +/** + * Copyright (C) 2015 Typesafe Inc. + */ +package akka.stream.testkit + +import akka.stream._ +import akka.stream.scaladsl.Source +import akka.stream.testkit.scaladsl.TestSink + +import scala.concurrent.duration._ + +class StreamTestKitSpec extends AkkaSpec { + + implicit val mat = ActorMaterializer() + + val ex = new Exception("Boom!") + + "A TestSink Probe" must { + "#toStrict" in { + Source(1 to 4).runWith(TestSink.probe()) + .toStrict(300.millis) should ===(List(1, 2, 3, 4)) + } + + "#toStrict with failing source" in { + val msg = intercept[AssertionError] { + Source(() ⇒ new Iterator[Int] { + var i = 0 + override def hasNext: Boolean = true + override def next(): Int = { + i += 1 + i match { + case 3 ⇒ throw ex + case n ⇒ n + } + } + }).runWith(TestSink.probe()) + .toStrict(300.millis) + }.getMessage + + msg should include("Boom!") + msg should include("List(1, 2)") + } + + "#toStrict when subscription was already obtained" in { + val p = Source(1 to 4).runWith(TestSink.probe()) + p.expectSubscription() + p.toStrict(300.millis) should ===(List(1, 2, 3, 4)) + } + + "#expectNextOrError with right element" in { + Source(1 to 4).runWith(TestSink.probe()) + .request(4) + .expectNextOrError(1, ex) + } + + "#expectNextOrError with right exception" in { + Source.failed[Int](ex).runWith(TestSink.probe()) + .request(4) + .expectNextOrError(1, ex) + } + + "#expectNextOrError fail if the next element is not the expected one" in { + intercept[AssertionError] { + Source(1 to 4).runWith(TestSink.probe()) + .request(4) + .expectNextOrError(100, ex) + }.getMessage should include("OnNext(1)") + } + + "#expectError" in { + Source.failed[Int](ex).runWith(TestSink.probe()) + .request(1) + .expectError() should ===(ex) + } + + "#expectError fail if no error signalled" in { + intercept[AssertionError] { + Source(1 to 4).runWith(TestSink.probe()) + .request(1) + .expectError() + }.getMessage should include("OnNext") + } + + "#expectComplete should fail if error signalled" in { + intercept[AssertionError] { + Source.failed[Int](ex).runWith(TestSink.probe()) + .request(1) + .expectComplete() + }.getMessage should include("OnError") + } + + "#expectComplete should fail if next element signalled" in { + intercept[AssertionError] { + Source(1 to 4).runWith(TestSink.probe()) + .request(1) + .expectComplete() + }.getMessage should include("OnNext") + } + + "#expectNextOrComplete with right element" in { + Source(1 to 4).runWith(TestSink.probe()) + .request(4) + .expectNextOrComplete(1) + } + + "#expectNextOrComplete with completion" in { + Source.single(1).runWith(TestSink.probe()) + .request(4) + .expectNextOrComplete(1) + .expectNextOrComplete(1337) + } + + "#expectNextN given a number of elements" in { + Source(1 to 4).runWith(TestSink.probe()) + .request(4) + .expectNextN(4) should ===(List(1, 2, 3, 4)) + } + + "#expectNextN given specific elements" in { + Source(1 to 4).runWith(TestSink.probe()) + .request(4) + .expectNextN(4) should ===(List(1, 2, 3, 4)) + } + } +} \ No newline at end of file diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowScanSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowScanSpec.scala index 436d947e86..dbe83b5499 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowScanSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowScanSpec.scala @@ -3,16 +3,15 @@ */ package akka.stream.scaladsl -import scala.concurrent.Await -import scala.concurrent.duration._ -import scala.concurrent.forkjoin.ThreadLocalRandom.{ current ⇒ random } -import scala.collection.immutable -import akka.stream.ActorMaterializer -import akka.stream.ActorMaterializerSettings +import akka.stream.testkit.scaladsl.TestSink +import akka.stream.{ActorAttributes, ActorMaterializer, ActorMaterializerSettings, Supervision} import akka.stream.testkit.AkkaSpec import akka.stream.testkit.Utils._ -import akka.stream.ActorAttributes -import akka.stream.Supervision + +import scala.collection.immutable +import scala.concurrent.Await +import scala.concurrent.duration._ +import scala.concurrent.forkjoin.ThreadLocalRandom.{current => random} class FlowScanSpec extends AkkaSpec { @@ -42,8 +41,8 @@ class FlowScanSpec extends AkkaSpec { } "emit values promptly" in { - val f = Source.single(1).concat(Source.lazyEmpty).scan(0)(_ + _).grouped(2).runWith(Sink.head) - Await.result(f, 1.second) should be(Seq(0, 1)) + Source.single(1).concat(Source.lazyEmpty).scan(0)(_ + _).grouped(2).runWith(TestSink.probe()) + .toStrict(1.second) should ===(Seq(0, 1)) } "fail properly" in { @@ -52,8 +51,8 @@ class FlowScanSpec extends AkkaSpec { require(current > 0) old + current }.withAttributes(supervisionStrategy(Supervision.restartingDecider)) - val f = Source(List(1, 3, -1, 5, 7)).via(scan).grouped(1000).runWith(Sink.head) - Await.result(f, 1.second) should be(Seq(0, 1, 4, 0, 5, 12)) + Source(List(1, 3, -1, 5, 7)).via(scan).runWith(TestSink.probe()) + .toStrict(1.second) should ===(Seq(0, 1, 4, 0, 5, 12)) } } }