diff --git a/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala b/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala index bf512ba4f4..756c1e836d 100644 --- a/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala +++ b/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala @@ -315,6 +315,12 @@ object TestSubscriber { def receiveWhile[T](max: Duration = Duration.Undefined, idle: Duration = Duration.Inf, messages: Int = Int.MaxValue)(f: PartialFunction[SubscriberEvent, T]): immutable.Seq[T] = probe.receiveWhile(max, idle, messages)(f.asInstanceOf[PartialFunction[AnyRef, T]]) + def receiveWithin(max: FiniteDuration, messages: Int = Int.MaxValue): immutable.Seq[I] = + probe.receiveWhile(max, max, messages) { + case OnNext(i) ⇒ Some(i.asInstanceOf[I]) + case _ ⇒ None + }.flatten + def within[T](max: FiniteDuration)(f: ⇒ T): T = probe.within(0.seconds, max)(f) def onSubscribe(subscription: Subscription): Unit = probe.ref ! OnSubscribe(subscription) diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowStageSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowStageSpec.scala index b13d91efea..c39eb36068 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowStageSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowStageSpec.scala @@ -15,9 +15,10 @@ import akka.stream.testkit._ import akka.stream.testkit.Utils._ import akka.testkit.{ EventFilter, TestProbe } import com.typesafe.config.ConfigFactory - import scala.concurrent.duration._ import scala.util.control.NoStackTrace +import akka.stream.testkit.scaladsl.TestSink +import akka.stream.testkit.scaladsl.TestSource class FlowStageSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.debug.receive=off\nakka.loglevel=INFO")) { @@ -222,8 +223,7 @@ class FlowStageSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.debug } "allow early finish" in assertAllStagesStopped { - val p = TestPublisher.manualProbe[Int]() - val p2 = Source(p). + val (p1, p2) = TestSource.probe[Int]. transform(() ⇒ new PushStage[Int, Int] { var s = "" override def onPush(element: Int, ctx: Context[Int]) = { @@ -233,18 +233,14 @@ class FlowStageSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.debug else ctx.push(element) } - }). - runWith(Sink.publisher) - val proc = p.expectSubscription - val c = TestSubscriber.manualProbe[Int]() - p2.subscribe(c) - val s = c.expectSubscription() - s.request(10) - proc.sendNext(1) - proc.sendNext(2) - c.expectNext(1) - c.expectComplete() - proc.expectCancellation() + }) + .toMat(TestSink.probe[Int])(Keep.both).run + p2.request(10) + p1.sendNext(1) + .sendNext(2) + p2.expectNext(1) + .expectComplete() + p1.expectCancellation() } "report error when exception is thrown" in assertAllStagesStopped { @@ -261,16 +257,13 @@ class FlowStageSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.debug } } }). - runWith(Sink.publisher) - val subscriber = TestSubscriber.manualProbe[Int]() - p2.subscribe(subscriber) - val subscription = subscriber.expectSubscription() + runWith(TestSink.probe[Int]) EventFilter[IllegalArgumentException]("two not allowed") intercept { - subscription.request(100) - subscriber.expectNext(1) - subscriber.expectNext(1) - subscriber.expectError().getMessage should be("two not allowed") - subscriber.expectNoMsg(200.millis) + p2.request(100) + .expectNext(1) + .expectNext(1) + .expectError().getMessage should be("two not allowed") + p2.expectNoMsg(200.millis) } } @@ -288,65 +281,56 @@ class FlowStageSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.debug } }). filter(elem ⇒ elem != 1). // it's undefined if element 1 got through before the error or not - runWith(Sink.publisher) - val subscriber = TestSubscriber.manualProbe[Int]() - p2.subscribe(subscriber) - val subscription = subscriber.expectSubscription() + runWith(TestSink.probe[Int]) EventFilter[IllegalArgumentException]("two not allowed") intercept { - subscription.request(100) - subscriber.expectNext(100) - subscriber.expectNext(101) - subscriber.expectComplete() - subscriber.expectNoMsg(200.millis) + p2.request(100) + .expectNext(100) + .expectNext(101) + .expectComplete() + .expectNoMsg(200.millis) } } "support cancel as expected" in assertAllStagesStopped { - val p = Source(List(1, 2, 3)).runWith(Sink.publisher) - val p2 = Source(p). + val p = Source(1 to 100).runWith(Sink.publisher) + val received = Source(p). transform(() ⇒ new StatefulStage[Int, Int] { override def initial = new State { override def onPush(elem: Int, ctx: Context[Int]) = emit(Iterator(elem, elem), ctx) } - }). - runWith(Sink.publisher) - val subscriber = TestSubscriber.manualProbe[Int]() - p2.subscribe(subscriber) - val subscription = subscriber.expectSubscription() - subscription.request(2) - subscriber.expectNext(1) - subscription.cancel() - subscriber.expectNext(1) - subscriber.expectNoMsg(500.millis) - subscription.request(2) - subscriber.expectNoMsg(200.millis) + }) + .runWith(TestSink.probe[Int]()) + .request(1000) + .expectNext(1) + .cancel() + .receiveWithin(1.second) + received.size should be < 200 + received.foldLeft((true, 1)) { + case ((flag, last), next) ⇒ (flag && (last == next || last == next - 1), next) + }._1 should be(true) } "support producing elements from empty inputs" in assertAllStagesStopped { val p = Source(List.empty[Int]).runWith(Sink.publisher) - val p2 = Source(p). + Source(p). transform(() ⇒ new StatefulStage[Int, Int] { override def initial = new State { override def onPush(elem: Int, ctx: Context[Int]) = ctx.pull() } override def onUpstreamFinish(ctx: Context[Int]) = terminationEmit(Iterator(1, 2, 3), ctx) - }). - runWith(Sink.publisher) - val subscriber = TestSubscriber.manualProbe[Int]() - p2.subscribe(subscriber) - val subscription = subscriber.expectSubscription() - subscription.request(4) - subscriber.expectNext(1) - subscriber.expectNext(2) - subscriber.expectNext(3) - subscriber.expectComplete() + }) + .runWith(TestSink.probe[Int]) + .request(4) + .expectNext(1) + .expectNext(2) + .expectNext(3) + .expectComplete() } "support converting onComplete into onError" in { - val subscriber = TestSubscriber.manualProbe[Int]() Source(List(5, 1, 2, 3)).transform(() ⇒ new PushStage[Int, Int] { var expectedNumberOfElements: Option[Int] = None var count = 0 @@ -365,15 +349,12 @@ class FlowStageSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.debug throw new RuntimeException(s"Expected $expected, got $count") with NoStackTrace case _ ⇒ ctx.finish() } - }).to(Sink(subscriber)).run() - - val subscription = subscriber.expectSubscription() - subscription.request(10) - - subscriber.expectNext(1) - subscriber.expectNext(2) - subscriber.expectNext(3) - subscriber.expectError().getMessage should be("Expected 5, got 3") + }).runWith(TestSink.probe[Int]) + .request(10) + .expectNext(1) + .expectNext(2) + .expectNext(3) + .expectError().getMessage should be("Expected 5, got 3") } "be safe to reuse" in { @@ -387,17 +368,15 @@ class FlowStageSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.debug } }) - val s1 = TestSubscriber.manualProbe[Int]() - flow.to(Sink(s1)).run() - s1.expectSubscription().request(3) - s1.expectNext(1, 2, 3) - s1.expectComplete() + flow.runWith(TestSink.probe[Int]) + .request(3) + .expectNext(1, 2, 3) + .expectComplete() - val s2 = TestSubscriber.manualProbe[Int]() - flow.to(Sink(s2)).run() - s2.expectSubscription().request(3) - s2.expectNext(1, 2, 3) - s2.expectComplete() + flow.runWith(TestSink.probe[Int]) + .request(3) + .expectNext(1, 2, 3) + .expectComplete() } "handle early cancelation" in assertAllStagesStopped {