Merge pull request #19619 from agolubev/agolubev-#19291-TestGraphStage

=str #19291 add TestGraphStages for Sink and Source
This commit is contained in:
Roland Kuhn 2016-02-12 14:30:29 +01:00
commit 432b77a9a2
8 changed files with 118 additions and 121 deletions

View file

@ -6,18 +6,18 @@ package akka.stream.io
import java.io.{ IOException, InputStream }
import java.util.concurrent.TimeoutException
import akka.actor.{ ActorSystem, NoSerializationVerificationNeeded }
import akka.actor.ActorSystem
import akka.stream._
import akka.stream.impl.StreamSupervisor.Children
import akka.stream.impl.io.InputStreamSinkStage
import akka.stream.impl.{ ActorMaterializerImpl, StreamSupervisor }
import akka.stream.scaladsl.{ Source, Keep, Sink, StreamConverters }
import akka.stream.stage.InHandler
import akka.stream.testkit.AkkaSpec
import akka.stream.scaladsl.{ Keep, Source, StreamConverters }
import akka.stream.testkit.Utils._
import akka.stream.testkit.scaladsl.TestSource
import akka.stream.testkit.{ AkkaSpec, GraphStageMessages, TestSinkStage }
import akka.testkit.TestProbe
import akka.util.ByteString
import scala.concurrent.duration._
import scala.concurrent.forkjoin.ThreadLocalRandom
import scala.concurrent.{ Await, Future }
@ -39,44 +39,12 @@ class InputStreamSinkSpec extends AkkaSpec(UnboundedMailboxConfig) {
val byteString = randomByteString(3)
val byteArray = byteString.toArray
private[this] def readN(is: InputStream, n: Int): (Int, ByteString) = {
def readN(is: InputStream, n: Int): (Int, ByteString) = {
val buf = new Array[Byte](n)
val r = is.read(buf)
(r, ByteString.fromArray(buf, 0, r))
}
object InputStreamSinkTestMessages {
case object Push extends NoSerializationVerificationNeeded
case object Finish extends NoSerializationVerificationNeeded
case class Failure(ex: Throwable) extends NoSerializationVerificationNeeded
}
def testSink(probe: TestProbe): Sink[ByteString, InputStream] = {
class InputStreamSinkTestStage(val timeout: FiniteDuration)
extends InputStreamSinkStage(timeout) {
override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = {
val (logic, inputStream) = super.createLogicAndMaterializedValue(inheritedAttributes)
val inHandler = logic.handlers(in.id).asInstanceOf[InHandler]
logic.handlers(in.id) = new InHandler {
override def onPush(): Unit = {
probe.ref ! InputStreamSinkTestMessages.Push
inHandler.onPush()
}
override def onUpstreamFinish(): Unit = {
probe.ref ! InputStreamSinkTestMessages.Finish
inHandler.onUpstreamFinish()
}
override def onUpstreamFailure(ex: Throwable): Unit = {
probe.ref ! InputStreamSinkTestMessages.Failure(ex)
inHandler.onUpstreamFailure(ex)
}
}
(logic, inputStream)
}
}
Sink.fromGraph(new InputStreamSinkTestStage(timeout))
}
def testSink(probe: TestProbe) = TestSinkStage(new InputStreamSinkStage(timeout), probe)
"InputStreamSink" must {
"read bytes from InputStream" in assertAllStagesStopped {
@ -88,9 +56,10 @@ class InputStreamSinkSpec extends AkkaSpec(UnboundedMailboxConfig) {
"read bytes correctly if requested by InputStream not in chunk size" in assertAllStagesStopped {
val sinkProbe = TestProbe()
val byteString2 = randomByteString(3)
val inputStream = Source(byteString :: byteString2 :: Nil).runWith(testSink(sinkProbe))
val inputStream = Source(byteString :: byteString2 :: Nil)
.runWith(testSink(sinkProbe))
sinkProbe.expectMsgAllOf(InputStreamSinkTestMessages.Push, InputStreamSinkTestMessages.Push)
sinkProbe.expectMsgAllOf(GraphStageMessages.Push, GraphStageMessages.Push)
readN(inputStream, 2) should ===((2, byteString.take(2)))
readN(inputStream, 2) should ===((2, byteString.drop(2) ++ byteString2.take(1)))
@ -142,14 +111,15 @@ class InputStreamSinkSpec extends AkkaSpec(UnboundedMailboxConfig) {
"return all data when upstream is completed" in assertAllStagesStopped {
val sinkProbe = TestProbe()
val (probe, inputStream) = TestSource.probe[ByteString].toMat(testSink(sinkProbe))(Keep.both).run()
val (probe, inputStream) = TestSource.probe[ByteString]
.toMat(testSink(sinkProbe))(Keep.both).run()
val bytes = randomByteString(1)
probe.sendNext(bytes)
sinkProbe.expectMsg(InputStreamSinkTestMessages.Push)
sinkProbe.expectMsg(GraphStageMessages.Push)
probe.sendComplete()
sinkProbe.expectMsg(InputStreamSinkTestMessages.Finish)
sinkProbe.expectMsg(GraphStageMessages.UpstreamFinish)
readN(inputStream, 3) should ===((1, bytes))
}
@ -177,10 +147,11 @@ class InputStreamSinkSpec extends AkkaSpec(UnboundedMailboxConfig) {
"successfully read several chunks at once" in assertAllStagesStopped {
val bytes = List.fill(4)(randomByteString(4))
val sinkProbe = TestProbe()
val inputStream = Source[ByteString](bytes).runWith(testSink(sinkProbe))
val inputStream = Source[ByteString](bytes)
.runWith(testSink(sinkProbe))
//need to wait while all elements arrive to sink
bytes foreach { _ sinkProbe.expectMsg(InputStreamSinkTestMessages.Push) }
bytes foreach { _ sinkProbe.expectMsg(GraphStageMessages.Push) }
for (i 0 to 1)
readN(inputStream, 8) should ===((8, bytes(i * 2) ++ bytes(i * 2 + 1)))
@ -195,7 +166,7 @@ class InputStreamSinkSpec extends AkkaSpec(UnboundedMailboxConfig) {
val inputStream = Source(bytes1 :: bytes2 :: Nil).runWith(testSink(sinkProbe))
//need to wait while both elements arrive to sink
sinkProbe.expectMsgAllOf(InputStreamSinkTestMessages.Push, InputStreamSinkTestMessages.Push)
sinkProbe.expectMsgAllOf(GraphStageMessages.Push, GraphStageMessages.Push)
readN(inputStream, 15) should ===((15, bytes1 ++ bytes2.take(5)))
readN(inputStream, 15) should ===((5, bytes2.drop(5)))
@ -218,12 +189,12 @@ class InputStreamSinkSpec extends AkkaSpec(UnboundedMailboxConfig) {
val ex = new RuntimeException("Stream failed.") with NoStackTrace
probe.sendNext(byteString)
sinkProbe.expectMsg(InputStreamSinkTestMessages.Push)
sinkProbe.expectMsg(GraphStageMessages.Push)
readN(inputStream, byteString.size) should ===((byteString.size, byteString))
probe.sendError(ex)
sinkProbe.expectMsg(InputStreamSinkTestMessages.Failure(ex))
sinkProbe.expectMsg(GraphStageMessages.Failure(ex))
val e = intercept[IOException] { Await.result(Future(inputStream.read()), timeout) }
e.getCause should ===(ex)
}

View file

@ -3,16 +3,15 @@
*/
package akka.stream.io
import java.io.{ IOException, OutputStream }
import java.io.IOException
import java.util.concurrent.TimeoutException
import akka.actor.{ ActorSystem, NoSerializationVerificationNeeded }
import akka.actor.ActorSystem
import akka.stream._
import akka.stream.impl.StreamSupervisor.Children
import akka.stream.impl.io.OutputStreamSourceStage
import akka.stream.impl.{ ActorMaterializerImpl, StreamSupervisor }
import akka.stream.scaladsl.{ Keep, Source, StreamConverters }
import akka.stream.stage.OutHandler
import akka.stream.scaladsl.{ Keep, StreamConverters }
import akka.stream.testkit.Utils._
import akka.stream.testkit._
import akka.stream.testkit.scaladsl.TestSink
@ -40,34 +39,6 @@ class OutputStreamSourceSpec extends AkkaSpec(UnboundedMailboxConfig) {
def expectSuccess[T](f: Future[T], value: T) =
Await.result(f, timeout) should be(value)
object OutputStreamSourceTestMessages {
case object Pull extends NoSerializationVerificationNeeded
case object Finish extends NoSerializationVerificationNeeded
}
def testSource(probe: TestProbe): Source[ByteString, OutputStream] = {
class OutputStreamSourceTestStage(val timeout: FiniteDuration)
extends OutputStreamSourceStage(timeout) {
override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = {
val (logic, inputStream) = super.createLogicAndMaterializedValue(inheritedAttributes)
val outHandler = logic.handlers(out.id).asInstanceOf[OutHandler]
logic.handlers(out.id) = new OutHandler {
override def onDownstreamFinish(): Unit = {
probe.ref ! OutputStreamSourceTestMessages.Finish
outHandler.onDownstreamFinish()
}
override def onPull(): Unit = {
probe.ref ! OutputStreamSourceTestMessages.Pull
outHandler.onPull()
}
}
(logic, inputStream)
}
}
Source.fromGraph(new OutputStreamSourceTestStage(timeout))
}
"OutputStreamSource" must {
"read bytes from OutputStream" in assertAllStagesStopped {
val (outputStream, probe) = StreamConverters.asOutputStream().toMat(TestSink.probe[ByteString])(Keep.both).run
@ -161,18 +132,19 @@ class OutputStreamSourceSpec extends AkkaSpec(UnboundedMailboxConfig) {
"throw IOException when writing to the stream after the subscriber has cancelled the reactive stream" in assertAllStagesStopped {
val sourceProbe = TestProbe()
val (outputStream, probe) = testSource(sourceProbe).toMat(TestSink.probe[ByteString])(Keep.both).run
val (outputStream, probe) = TestSourceStage(new OutputStreamSourceStage(timeout), sourceProbe)
.toMat(TestSink.probe[ByteString])(Keep.both).run
val s = probe.expectSubscription()
outputStream.write(bytesArray)
s.request(1)
sourceProbe.expectMsg(OutputStreamSourceTestMessages.Pull)
sourceProbe.expectMsg(GraphStageMessages.Pull)
probe.expectNext(byteString)
s.cancel()
sourceProbe.expectMsg(OutputStreamSourceTestMessages.Finish)
sourceProbe.expectMsg(GraphStageMessages.DownstreamFinish)
the[Exception] thrownBy outputStream.write(bytesArray) shouldBe a[IOException]
}
}

View file

@ -9,7 +9,7 @@ import akka.stream._
import akka.stream.impl.QueueSource
import akka.stream.stage.OutHandler
import akka.stream.testkit.Utils._
import akka.stream.testkit.{ AkkaSpec, TestSubscriber }
import akka.stream.testkit._
import akka.testkit.TestProbe
import scala.concurrent.duration._
import scala.concurrent.{ Future, _ }
@ -25,35 +25,6 @@ class QueueSourceSpec extends AkkaSpec {
expectMsg(QueueOfferResult.Enqueued)
}
object SourceTestMessages {
case object Pull extends NoSerializationVerificationNeeded
case object Finish extends NoSerializationVerificationNeeded
}
def testSource(maxBuffer: Int, overflowStrategy: OverflowStrategy, probe: TestProbe): Source[Int, SourceQueue[Int]] = {
class QueueSourceTestStage(maxBuffer: Int, overflowStrategy: OverflowStrategy)
extends QueueSource[Int](maxBuffer, overflowStrategy) {
override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = {
val (logic, inputStream) = super.createLogicAndMaterializedValue(inheritedAttributes)
val outHandler = logic.handlers(out.id).asInstanceOf[OutHandler]
logic.handlers(out.id) = new OutHandler {
override def onPull(): Unit = {
probe.ref ! SourceTestMessages.Pull
outHandler.onPull()
}
override def onDownstreamFinish(): Unit = {
probe.ref ! SourceTestMessages.Finish
outHandler.onDownstreamFinish()
}
}
(logic, inputStream)
}
}
Source.fromGraph(new QueueSourceTestStage(maxBuffer, overflowStrategy))
}
"A QueueSourceSpec" must {
"emit received messages to the stream" in {
@ -139,11 +110,12 @@ class QueueSourceSpec extends AkkaSpec {
"remember pull from downstream to send offered element immediately" in assertAllStagesStopped {
val s = TestSubscriber.manualProbe[Int]()
val probe = TestProbe()
val queue = testSource(1, OverflowStrategy.dropHead, probe).to(Sink.fromSubscriber(s)).run()
val queue = TestSourceStage(new QueueSource[Int](1, OverflowStrategy.dropHead), probe)
.to(Sink.fromSubscriber(s)).run()
val sub = s.expectSubscription
sub.request(1)
probe.expectMsg(SourceTestMessages.Pull)
probe.expectMsg(GraphStageMessages.Pull)
assertSuccess(queue.offer(1))
s.expectNext(1)
sub.cancel()