From a10bba9c84b7e14bc3fc13283ac976f2d927bb41 Mon Sep 17 00:00:00 2001 From: Alexander Golubev Date: Tue, 26 Jan 2016 11:47:30 -0500 Subject: [PATCH] =str #19291 add TestGraphStages for Sink and Source --- .../akka/stream/testkit/TestGraphStage.scala | 82 +++++++++++++++++++ .../akka/stream/io/InputStreamSinkSpec.scala | 67 +++++---------- .../stream/io/OutputStreamSourceSpec.scala | 42 ++-------- .../stream/scaladsl/QueueSourceSpec.scala | 36 +------- .../main/scala/akka/stream/impl/Sinks.scala | 2 +- .../main/scala/akka/stream/impl/Sources.scala | 2 +- .../stream/impl/io/InputStreamSinkStage.scala | 4 +- .../impl/io/OutputStreamSourceStage.scala | 4 +- 8 files changed, 118 insertions(+), 121 deletions(-) create mode 100644 akka-stream-testkit/src/main/scala/akka/stream/testkit/TestGraphStage.scala diff --git a/akka-stream-testkit/src/main/scala/akka/stream/testkit/TestGraphStage.scala b/akka-stream-testkit/src/main/scala/akka/stream/testkit/TestGraphStage.scala new file mode 100644 index 0000000000..e809752844 --- /dev/null +++ b/akka-stream-testkit/src/main/scala/akka/stream/testkit/TestGraphStage.scala @@ -0,0 +1,82 @@ +package akka.stream.testkit + +import akka.actor.NoSerializationVerificationNeeded +import akka.stream.scaladsl.Source +import akka.stream.stage.{ OutHandler, GraphStageWithMaterializedValue, InHandler } +import akka.stream._ +import akka.testkit.TestProbe + +object GraphStageMessages { + case object Push extends NoSerializationVerificationNeeded + case object UpstreamFinish extends NoSerializationVerificationNeeded + case class Failure(ex: Throwable) extends NoSerializationVerificationNeeded + + case object Pull extends NoSerializationVerificationNeeded + case object DownstreamFinish extends NoSerializationVerificationNeeded +} + +object TestSinkStage { + def apply[T, M](stageUnderTest: GraphStageWithMaterializedValue[SinkShape[T], M], + probe: TestProbe) = new TestSinkStage(stageUnderTest, probe) +} + +private[testkit] class TestSinkStage[T, M](stageUnderTest: GraphStageWithMaterializedValue[SinkShape[T], M], + probe: TestProbe) + extends GraphStageWithMaterializedValue[SinkShape[T], M] { + + val in = Inlet[T]("testSinkStage.in") + override val shape: SinkShape[T] = SinkShape.of(in) + + override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = { + stageUnderTest.shape.in.id = in.id + val (logic, mat) = stageUnderTest.createLogicAndMaterializedValue(inheritedAttributes) + + val inHandler = logic.handlers(in.id).asInstanceOf[InHandler] + logic.handlers(in.id) = new InHandler { + override def onPush(): Unit = { + probe.ref ! GraphStageMessages.Push + inHandler.onPush() + } + override def onUpstreamFinish(): Unit = { + probe.ref ! GraphStageMessages.UpstreamFinish + inHandler.onUpstreamFinish() + } + override def onUpstreamFailure(ex: Throwable): Unit = { + probe.ref ! GraphStageMessages.Failure(ex) + inHandler.onUpstreamFailure(ex) + } + } + (logic, mat) + } +} + +object TestSourceStage { + def apply[T, M](stageUnderTest: GraphStageWithMaterializedValue[SourceShape[T], M], + probe: TestProbe) = Source.fromGraph(new TestSourceStage(stageUnderTest, probe)) +} + +private[testkit] class TestSourceStage[T, M](stageUnderTest: GraphStageWithMaterializedValue[SourceShape[T], M], + probe: TestProbe) + extends GraphStageWithMaterializedValue[SourceShape[T], M] { + + val out = Outlet[T]("testSourceStage.out") + override val shape: SourceShape[T] = SourceShape.of(out) + + override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = { + stageUnderTest.shape.out.id = out.id + val (logic, mat) = stageUnderTest.createLogicAndMaterializedValue(inheritedAttributes) + + val outHandler = logic.handlers(out.id).asInstanceOf[OutHandler] + logic.handlers(out.id) = new OutHandler { + override def onPull(): Unit = { + probe.ref ! GraphStageMessages.Pull + outHandler.onPull() + } + override def onDownstreamFinish(): Unit = { + probe.ref ! GraphStageMessages.DownstreamFinish + outHandler.onDownstreamFinish() + } + } + (logic, mat) + } +} \ No newline at end of file diff --git a/akka-stream-tests/src/test/scala/akka/stream/io/InputStreamSinkSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/io/InputStreamSinkSpec.scala index 6940cf7861..45b4b8092e 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/io/InputStreamSinkSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/io/InputStreamSinkSpec.scala @@ -6,18 +6,18 @@ package akka.stream.io import java.io.{ IOException, InputStream } import java.util.concurrent.TimeoutException -import akka.actor.{ ActorSystem, NoSerializationVerificationNeeded } +import akka.actor.ActorSystem import akka.stream._ import akka.stream.impl.StreamSupervisor.Children import akka.stream.impl.io.InputStreamSinkStage import akka.stream.impl.{ ActorMaterializerImpl, StreamSupervisor } -import akka.stream.scaladsl.{ Source, Keep, Sink, StreamConverters } -import akka.stream.stage.InHandler -import akka.stream.testkit.AkkaSpec +import akka.stream.scaladsl.{ Keep, Source, StreamConverters } import akka.stream.testkit.Utils._ import akka.stream.testkit.scaladsl.TestSource +import akka.stream.testkit.{ AkkaSpec, GraphStageMessages, TestSinkStage } import akka.testkit.TestProbe import akka.util.ByteString + import scala.concurrent.duration._ import scala.concurrent.forkjoin.ThreadLocalRandom import scala.concurrent.{ Await, Future } @@ -39,44 +39,12 @@ class InputStreamSinkSpec extends AkkaSpec(UnboundedMailboxConfig) { val byteString = randomByteString(3) val byteArray = byteString.toArray - private[this] def readN(is: InputStream, n: Int): (Int, ByteString) = { + def readN(is: InputStream, n: Int): (Int, ByteString) = { val buf = new Array[Byte](n) val r = is.read(buf) (r, ByteString.fromArray(buf, 0, r)) } - - object InputStreamSinkTestMessages { - case object Push extends NoSerializationVerificationNeeded - case object Finish extends NoSerializationVerificationNeeded - case class Failure(ex: Throwable) extends NoSerializationVerificationNeeded - } - - def testSink(probe: TestProbe): Sink[ByteString, InputStream] = { - class InputStreamSinkTestStage(val timeout: FiniteDuration) - extends InputStreamSinkStage(timeout) { - - override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = { - val (logic, inputStream) = super.createLogicAndMaterializedValue(inheritedAttributes) - val inHandler = logic.handlers(in.id).asInstanceOf[InHandler] - logic.handlers(in.id) = new InHandler { - override def onPush(): Unit = { - probe.ref ! InputStreamSinkTestMessages.Push - inHandler.onPush() - } - override def onUpstreamFinish(): Unit = { - probe.ref ! InputStreamSinkTestMessages.Finish - inHandler.onUpstreamFinish() - } - override def onUpstreamFailure(ex: Throwable): Unit = { - probe.ref ! InputStreamSinkTestMessages.Failure(ex) - inHandler.onUpstreamFailure(ex) - } - } - (logic, inputStream) - } - } - Sink.fromGraph(new InputStreamSinkTestStage(timeout)) - } + def testSink(probe: TestProbe) = TestSinkStage(new InputStreamSinkStage(timeout), probe) "InputStreamSink" must { "read bytes from InputStream" in assertAllStagesStopped { @@ -88,9 +56,10 @@ class InputStreamSinkSpec extends AkkaSpec(UnboundedMailboxConfig) { "read bytes correctly if requested by InputStream not in chunk size" in assertAllStagesStopped { val sinkProbe = TestProbe() val byteString2 = randomByteString(3) - val inputStream = Source(byteString :: byteString2 :: Nil).runWith(testSink(sinkProbe)) + val inputStream = Source(byteString :: byteString2 :: Nil) + .runWith(testSink(sinkProbe)) - sinkProbe.expectMsgAllOf(InputStreamSinkTestMessages.Push, InputStreamSinkTestMessages.Push) + sinkProbe.expectMsgAllOf(GraphStageMessages.Push, GraphStageMessages.Push) readN(inputStream, 2) should ===((2, byteString.take(2))) readN(inputStream, 2) should ===((2, byteString.drop(2) ++ byteString2.take(1))) @@ -142,14 +111,15 @@ class InputStreamSinkSpec extends AkkaSpec(UnboundedMailboxConfig) { "return all data when upstream is completed" in assertAllStagesStopped { val sinkProbe = TestProbe() - val (probe, inputStream) = TestSource.probe[ByteString].toMat(testSink(sinkProbe))(Keep.both).run() + val (probe, inputStream) = TestSource.probe[ByteString] + .toMat(testSink(sinkProbe))(Keep.both).run() val bytes = randomByteString(1) probe.sendNext(bytes) - sinkProbe.expectMsg(InputStreamSinkTestMessages.Push) + sinkProbe.expectMsg(GraphStageMessages.Push) probe.sendComplete() - sinkProbe.expectMsg(InputStreamSinkTestMessages.Finish) + sinkProbe.expectMsg(GraphStageMessages.UpstreamFinish) readN(inputStream, 3) should ===((1, bytes)) } @@ -177,10 +147,11 @@ class InputStreamSinkSpec extends AkkaSpec(UnboundedMailboxConfig) { "successfully read several chunks at once" in assertAllStagesStopped { val bytes = List.fill(4)(randomByteString(4)) val sinkProbe = TestProbe() - val inputStream = Source[ByteString](bytes).runWith(testSink(sinkProbe)) + val inputStream = Source[ByteString](bytes) + .runWith(testSink(sinkProbe)) //need to wait while all elements arrive to sink - bytes foreach { _ ⇒ sinkProbe.expectMsg(InputStreamSinkTestMessages.Push) } + bytes foreach { _ ⇒ sinkProbe.expectMsg(GraphStageMessages.Push) } for (i ← 0 to 1) readN(inputStream, 8) should ===((8, bytes(i * 2) ++ bytes(i * 2 + 1))) @@ -195,7 +166,7 @@ class InputStreamSinkSpec extends AkkaSpec(UnboundedMailboxConfig) { val inputStream = Source(bytes1 :: bytes2 :: Nil).runWith(testSink(sinkProbe)) //need to wait while both elements arrive to sink - sinkProbe.expectMsgAllOf(InputStreamSinkTestMessages.Push, InputStreamSinkTestMessages.Push) + sinkProbe.expectMsgAllOf(GraphStageMessages.Push, GraphStageMessages.Push) readN(inputStream, 15) should ===((15, bytes1 ++ bytes2.take(5))) readN(inputStream, 15) should ===((5, bytes2.drop(5))) @@ -218,12 +189,12 @@ class InputStreamSinkSpec extends AkkaSpec(UnboundedMailboxConfig) { val ex = new RuntimeException("Stream failed.") with NoStackTrace probe.sendNext(byteString) - sinkProbe.expectMsg(InputStreamSinkTestMessages.Push) + sinkProbe.expectMsg(GraphStageMessages.Push) readN(inputStream, byteString.size) should ===((byteString.size, byteString)) probe.sendError(ex) - sinkProbe.expectMsg(InputStreamSinkTestMessages.Failure(ex)) + sinkProbe.expectMsg(GraphStageMessages.Failure(ex)) val e = intercept[IOException] { Await.result(Future(inputStream.read()), timeout) } e.getCause should ===(ex) } diff --git a/akka-stream-tests/src/test/scala/akka/stream/io/OutputStreamSourceSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/io/OutputStreamSourceSpec.scala index e18e80c224..f3433d9ddf 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/io/OutputStreamSourceSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/io/OutputStreamSourceSpec.scala @@ -3,16 +3,15 @@ */ package akka.stream.io -import java.io.{ IOException, OutputStream } +import java.io.IOException import java.util.concurrent.TimeoutException -import akka.actor.{ ActorSystem, NoSerializationVerificationNeeded } +import akka.actor.ActorSystem import akka.stream._ import akka.stream.impl.StreamSupervisor.Children import akka.stream.impl.io.OutputStreamSourceStage import akka.stream.impl.{ ActorMaterializerImpl, StreamSupervisor } -import akka.stream.scaladsl.{ Keep, Source, StreamConverters } -import akka.stream.stage.OutHandler +import akka.stream.scaladsl.{ Keep, StreamConverters } import akka.stream.testkit.Utils._ import akka.stream.testkit._ import akka.stream.testkit.scaladsl.TestSink @@ -40,34 +39,6 @@ class OutputStreamSourceSpec extends AkkaSpec(UnboundedMailboxConfig) { def expectSuccess[T](f: Future[T], value: T) = Await.result(f, timeout) should be(value) - object OutputStreamSourceTestMessages { - case object Pull extends NoSerializationVerificationNeeded - case object Finish extends NoSerializationVerificationNeeded - } - - def testSource(probe: TestProbe): Source[ByteString, OutputStream] = { - class OutputStreamSourceTestStage(val timeout: FiniteDuration) - extends OutputStreamSourceStage(timeout) { - - override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = { - val (logic, inputStream) = super.createLogicAndMaterializedValue(inheritedAttributes) - val outHandler = logic.handlers(out.id).asInstanceOf[OutHandler] - logic.handlers(out.id) = new OutHandler { - override def onDownstreamFinish(): Unit = { - probe.ref ! OutputStreamSourceTestMessages.Finish - outHandler.onDownstreamFinish() - } - override def onPull(): Unit = { - probe.ref ! OutputStreamSourceTestMessages.Pull - outHandler.onPull() - } - } - (logic, inputStream) - } - } - Source.fromGraph(new OutputStreamSourceTestStage(timeout)) - } - "OutputStreamSource" must { "read bytes from OutputStream" in assertAllStagesStopped { val (outputStream, probe) = StreamConverters.asOutputStream().toMat(TestSink.probe[ByteString])(Keep.both).run @@ -161,18 +132,19 @@ class OutputStreamSourceSpec extends AkkaSpec(UnboundedMailboxConfig) { "throw IOException when writing to the stream after the subscriber has cancelled the reactive stream" in assertAllStagesStopped { val sourceProbe = TestProbe() - val (outputStream, probe) = testSource(sourceProbe).toMat(TestSink.probe[ByteString])(Keep.both).run + val (outputStream, probe) = TestSourceStage(new OutputStreamSourceStage(timeout), sourceProbe) + .toMat(TestSink.probe[ByteString])(Keep.both).run val s = probe.expectSubscription() outputStream.write(bytesArray) s.request(1) - sourceProbe.expectMsg(OutputStreamSourceTestMessages.Pull) + sourceProbe.expectMsg(GraphStageMessages.Pull) probe.expectNext(byteString) s.cancel() - sourceProbe.expectMsg(OutputStreamSourceTestMessages.Finish) + sourceProbe.expectMsg(GraphStageMessages.DownstreamFinish) the[Exception] thrownBy outputStream.write(bytesArray) shouldBe a[IOException] } } diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/QueueSourceSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/QueueSourceSpec.scala index 71dab7f8e1..0d9a821731 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/QueueSourceSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/QueueSourceSpec.scala @@ -9,7 +9,7 @@ import akka.stream._ import akka.stream.impl.QueueSource import akka.stream.stage.OutHandler import akka.stream.testkit.Utils._ -import akka.stream.testkit.{ AkkaSpec, TestSubscriber } +import akka.stream.testkit._ import akka.testkit.TestProbe import scala.concurrent.duration._ import scala.concurrent.{ Future, _ } @@ -25,35 +25,6 @@ class QueueSourceSpec extends AkkaSpec { expectMsg(QueueOfferResult.Enqueued) } - object SourceTestMessages { - case object Pull extends NoSerializationVerificationNeeded - case object Finish extends NoSerializationVerificationNeeded - } - - def testSource(maxBuffer: Int, overflowStrategy: OverflowStrategy, probe: TestProbe): Source[Int, SourceQueue[Int]] = { - class QueueSourceTestStage(maxBuffer: Int, overflowStrategy: OverflowStrategy) - extends QueueSource[Int](maxBuffer, overflowStrategy) { - - override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = { - val (logic, inputStream) = super.createLogicAndMaterializedValue(inheritedAttributes) - val outHandler = logic.handlers(out.id).asInstanceOf[OutHandler] - logic.handlers(out.id) = new OutHandler { - override def onPull(): Unit = { - probe.ref ! SourceTestMessages.Pull - outHandler.onPull() - } - override def onDownstreamFinish(): Unit = { - probe.ref ! SourceTestMessages.Finish - outHandler.onDownstreamFinish() - } - - } - (logic, inputStream) - } - } - Source.fromGraph(new QueueSourceTestStage(maxBuffer, overflowStrategy)) - } - "A QueueSourceSpec" must { "emit received messages to the stream" in { @@ -139,11 +110,12 @@ class QueueSourceSpec extends AkkaSpec { "remember pull from downstream to send offered element immediately" in assertAllStagesStopped { val s = TestSubscriber.manualProbe[Int]() val probe = TestProbe() - val queue = testSource(1, OverflowStrategy.dropHead, probe).to(Sink.fromSubscriber(s)).run() + val queue = TestSourceStage(new QueueSource[Int](1, OverflowStrategy.dropHead), probe) + .to(Sink.fromSubscriber(s)).run() val sub = s.expectSubscription sub.request(1) - probe.expectMsg(SourceTestMessages.Pull) + probe.expectMsg(GraphStageMessages.Pull) assertSuccess(queue.offer(1)) s.expectNext(1) sub.cancel() diff --git a/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala b/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala index 1d3553f374..3bad190dae 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Sinks.scala @@ -245,7 +245,7 @@ private[akka] final class HeadOptionStage[T] extends GraphStageWithMaterializedV /** * INTERNAL API */ -private[akka] class QueueSink[T]() extends GraphStageWithMaterializedValue[SinkShape[T], SinkQueue[T]] { +final private[stream] class QueueSink[T]() extends GraphStageWithMaterializedValue[SinkShape[T], SinkQueue[T]] { type Requested[E] = Promise[Option[E]] val in = Inlet[T]("queueSink.in") diff --git a/akka-stream/src/main/scala/akka/stream/impl/Sources.scala b/akka-stream/src/main/scala/akka/stream/impl/Sources.scala index fd3d4ae0e1..160f8e5d6b 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Sources.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Sources.scala @@ -15,7 +15,7 @@ import scala.compat.java8.FutureConverters._ /** * INTERNAL API */ -private[akka] class QueueSource[T](maxBuffer: Int, overflowStrategy: OverflowStrategy) extends GraphStageWithMaterializedValue[SourceShape[T], SourceQueue[T]] { +final private[stream] class QueueSource[T](maxBuffer: Int, overflowStrategy: OverflowStrategy) extends GraphStageWithMaterializedValue[SourceShape[T], SourceQueue[T]] { type Offered = Promise[QueueOfferResult] val out = Outlet[T]("queueSource.out") diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/InputStreamSinkStage.scala b/akka-stream/src/main/scala/akka/stream/impl/io/InputStreamSinkStage.scala index 0578e21aa7..ff852f03d1 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/InputStreamSinkStage.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/InputStreamSinkStage.scala @@ -13,7 +13,7 @@ import scala.annotation.tailrec import scala.concurrent.duration.FiniteDuration import akka.stream.{ Inlet, SinkShape, Attributes } -private[akka] object InputStreamSinkStage { +private[stream] object InputStreamSinkStage { sealed trait AdapterToStageMessage case object ReadElementAcknowledgement extends AdapterToStageMessage @@ -32,7 +32,7 @@ private[akka] object InputStreamSinkStage { /** * INTERNAL API */ -private[akka] class InputStreamSinkStage(readTimeout: FiniteDuration) extends GraphStageWithMaterializedValue[SinkShape[ByteString], InputStream] { +final private[stream] class InputStreamSinkStage(readTimeout: FiniteDuration) extends GraphStageWithMaterializedValue[SinkShape[ByteString], InputStream] { val in = Inlet[ByteString]("InputStreamSink.in") override val shape: SinkShape[ByteString] = SinkShape.of(in) diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/OutputStreamSourceStage.scala b/akka-stream/src/main/scala/akka/stream/impl/io/OutputStreamSourceStage.scala index accacb2d01..d222250380 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/OutputStreamSourceStage.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/OutputStreamSourceStage.scala @@ -18,7 +18,7 @@ import scala.concurrent.{ Await, Future, Promise } import scala.util.control.NonFatal import scala.util.{ Failure, Success, Try } -private[akka] object OutputStreamSourceStage { +private[stream] object OutputStreamSourceStage { sealed trait AdapterToStageMessage case object Flush extends AdapterToStageMessage case object Close extends AdapterToStageMessage @@ -32,7 +32,7 @@ private[akka] object OutputStreamSourceStage { } } -private[akka] class OutputStreamSourceStage(writeTimeout: FiniteDuration) extends GraphStageWithMaterializedValue[SourceShape[ByteString], OutputStream] { +final private[stream] class OutputStreamSourceStage(writeTimeout: FiniteDuration) extends GraphStageWithMaterializedValue[SourceShape[ByteString], OutputStream] { val out = Outlet[ByteString]("OutputStreamSource.out") override val shape: SourceShape[ByteString] = SourceShape.of(out)