diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/One2OneBidiFlowSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/One2OneBidiFlowSpec.scala new file mode 100644 index 0000000000..c09c6ab1c1 --- /dev/null +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/One2OneBidiFlowSpec.scala @@ -0,0 +1,87 @@ +/** + * Copyright (C) 2015 Typesafe Inc. + */ +package akka.stream.scaladsl + +import java.util.concurrent.atomic.AtomicInteger +import scala.concurrent.Await +import scala.concurrent.duration._ +import org.scalactic.ConversionCheckedTripleEquals +import akka.stream.ActorMaterializer +import akka.stream.testkit.AkkaSpec +import akka.stream.testkit._ + +class One2OneBidiFlowSpec extends AkkaSpec with ConversionCheckedTripleEquals { + implicit val mat = ActorMaterializer() + + "A One2OneBidiFlow" must { + + def test(flow: Flow[Int, Int, Unit]) = + Source(List(1, 2, 3)).via(flow).grouped(10).runWith(Sink.head) + + "be fully transparent for valid one-to-one streams" in { + val f = One2OneBidiFlow[Int, Int](-1) join Flow[Int].map(_ * 2) + Await.result(test(f), 1.second) should ===(Seq(2, 4, 6)) + } + + "be fully transparent to errors" in { + val f = One2OneBidiFlow[Int, Int](-1) join Flow[Int].map(x ⇒ 10 / (x - 2)) + an[ArithmeticException] should be thrownBy Await.result(test(f), 1.second) + } + + "trigger an `OutputTruncationException` if the wrapped stream terminates early" in { + val f = One2OneBidiFlow[Int, Int](-1) join Flow[Int].filter(_ < 3) + a[One2OneBidiFlow.OutputTruncationException.type] should be thrownBy Await.result(test(f), 1.second) + } + + "trigger an `UnexpectedOutputException` if the wrapped stream produces out-of-order elements" in new Test() { + inIn.sendNext(1) + inOut.requestNext() should ===(1) + + outIn.sendNext(2) + outOut.requestNext() should ===(2) + + outOut.request(1) + outIn.sendNext(3) + outOut.expectError(new One2OneBidiFlow.UnexpectedOutputException(3)) + } + + "drop surplus output elements" in new Test() { + inIn.sendNext(1) + inOut.requestNext() should ===(1) + + outIn.sendNext(2) + outOut.requestNext() should ===(2) + + outOut.cancel() + outIn.expectCancellation() + } + + "backpressure the input side if the maximum number of pending output elements has been reached" in { + val MAX_PENDING = 24 + + val out = TestPublisher.probe[Int]() + val seen = new AtomicInteger + + Source(1 to 1000) + .log("", seen.set) + .via(One2OneBidiFlow[Int, Int](MAX_PENDING) join Flow.wrap(Sink.ignore, Source(out))(Keep.left)) + .runWith(Sink.ignore) + + Thread.sleep(50) + val x = seen.get() + (1 to 8) foreach out.sendNext + Thread.sleep(50) + seen.get should ===(x + 8) + } + } + + class Test(maxPending: Int = -1) { + val inIn = TestPublisher.probe[Int]() + val inOut = TestSubscriber.probe[Int]() + val outIn = TestPublisher.probe[Int]() + val outOut = TestSubscriber.probe[Int]() + + Source(inIn).via(One2OneBidiFlow[Int, Int](maxPending) join Flow.wrap(Sink(inOut), Source(outIn))(Keep.left)).runWith(Sink(outOut)) + } +} diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/One2OneBidiFlow.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/One2OneBidiFlow.scala new file mode 100644 index 0000000000..21e331c25a --- /dev/null +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/One2OneBidiFlow.scala @@ -0,0 +1,87 @@ +/** + * Copyright (C) 2015 Typesafe Inc. + */ +package akka.stream.scaladsl + +import akka.stream._ +import akka.stream.stage.{ OutHandler, InHandler, GraphStageLogic, GraphStage } + +import scala.concurrent.duration.Deadline +import scala.util.control.NoStackTrace + +object One2OneBidiFlow { + + case class UnexpectedOutputException(element: Any) extends RuntimeException with NoStackTrace + case object OutputTruncationException extends RuntimeException with NoStackTrace + + /** + * Creates a generic ``BidiFlow`` which verifies that another flow produces exactly one output element per + * input element, at the right time. Specifically it + * + * 1. triggers an ``UnexpectedOutputException`` if the inner flow produces an output element before having + * consumed the respective input element. + * 2. triggers an `OutputTruncationException` if the inner flow completes before having produced an output element + * for every input element. + * 3. Backpressures the input side if the maximum number of pending output elements has been reached, + * which is given via the ``maxPending`` parameter. You can use -1 to disable this feature. + * 4. Drops surplus output elements, i.e. ones that the inner flow tries to produce after the input stream + * has signalled completion. Note that no error is triggered in this case! + */ + def apply[I, O](maxPending: Int): BidiFlow[I, I, O, O, Unit] = + BidiFlow.wrap(new One2OneBidi[I, O](maxPending)) + + class One2OneBidi[I, O](maxPending: Int) extends GraphStage[BidiShape[I, I, O, O]] { + val inIn = Inlet[I]("inIn") + val inOut = Outlet[I]("inOut") + val outIn = Inlet[O]("outIn") + val outOut = Outlet[O]("outOut") + val shape = BidiShape(inIn, inOut, outIn, outOut) + + override def toString = "One2OneBidi" + + override def createLogic: GraphStageLogic = new GraphStageLogic(shape) { + private var pending = 0 + private var pullsSuppressed = 0 + + setHandler(inIn, new InHandler { + override def onPush(): Unit = { + pending += 1 + push(inOut, grab(inIn)) + } + override def onUpstreamFinish(): Unit = complete(inOut) + }) + + setHandler(inOut, new OutHandler { + override def onPull(): Unit = + if (pending < maxPending || maxPending == -1) pull(inIn) + else pullsSuppressed += 1 + override def onDownstreamFinish(): Unit = cancel(inIn) + }) + + setHandler(outIn, new InHandler { + override def onPush(): Unit = { + val element = grab(outIn) + if (pending > 0) { + pending -= 1 + push(outOut, element) + if (pullsSuppressed > 0) { + pullsSuppressed -= 1 + pull(inIn) + } + } else throw new UnexpectedOutputException(element) + } + override def onUpstreamFinish(): Unit = + if (pending == 0) complete(outOut) + else throw OutputTruncationException + }) + + setHandler(outOut, new OutHandler { + override def onPull(): Unit = pull(outIn) + override def onDownstreamFinish(): Unit = { + cancel(outIn) + cancel(inIn) // short-cut to speed up cleanup of upstream + } + }) + } + } +}