diff --git a/akka-docs-dev/rst/stages-overview.rst b/akka-docs-dev/rst/stages-overview.rst index 541226a477..8ca7108735 100644 --- a/akka-docs-dev/rst/stages-overview.rst +++ b/akka-docs-dev/rst/stages-overview.rst @@ -121,9 +121,10 @@ a single output combining the elements from all of the inputs in different ways. Stage Emits when Backpressures when Completes when ===================== ========================================================================================================================= ============================================================================================================================== ===================================================================================== merge one of the inputs has an element available downstream backpressures all upstreams complete (*) +mergeSorted all of the inputs have an element available downstream backpressures all upstreams complete mergePreferred one of the inputs has an element available, preferring a defined input if multiple have elements available downstream backpressures all upstreams complete (*) -zip all of the inputs has an element available downstream backpressures any upstream completes -zipWith all of the inputs has an element available downstream backpressures any upstream completes +zip all of the inputs have an element available downstream backpressures any upstream completes +zipWith all of the inputs have an element available downstream backpressures any upstream completes concat the current stream has an element available; if the current input completes, it tries the next one downstream backpressures all upstreams complete ===================== ========================================================================================================================= ============================================================================================================================== ===================================================================================== diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala index e1730da4e2..1d880d72cf 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala @@ -396,7 +396,9 @@ private[http] object HttpServerBluePrint { outlets(2).asInstanceOf[Outlet[ByteString]]) } } - private class ProtocolSwitchStage(installHandler: Flow[FrameEvent, FrameEvent, Any] ⇒ Unit, websocketRandomFactory: () ⇒ Random, log: LoggingAdapter) extends GraphStage[ProtocolSwitchShape] { + + private class ProtocolSwitchStage(installHandler: Flow[FrameEvent, FrameEvent, Any] ⇒ Unit, + websocketRandomFactory: () ⇒ Random, log: LoggingAdapter) extends GraphStage[ProtocolSwitchShape] { private val fromNet = Inlet[ByteString]("fromNet") private val toNet = Outlet[ByteString]("toNet") @@ -409,11 +411,13 @@ private[http] object HttpServerBluePrint { def shape: ProtocolSwitchShape = ProtocolSwitchShape(fromNet, toNet, fromHttp, toHttp, fromWs, toWs) def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { + import akka.http.impl.engine.rendering.ResponseRenderingOutput._ + var websocketHandlerWasInstalled = false - setHandler(fromHttp, conditionalTerminateInput(() ⇒ !websocketHandlerWasInstalled)) + setHandler(fromHttp, ignoreTerminateInput) setHandler(toHttp, ignoreTerminateOutput) - setHandler(fromWs, conditionalTerminateInput(() ⇒ websocketHandlerWasInstalled)) + setHandler(fromWs, ignoreTerminateInput) setHandler(toWs, ignoreTerminateOutput) val pullNet = () ⇒ pull(fromNet) @@ -425,23 +429,25 @@ private[http] object HttpServerBluePrint { override def onUpstreamFailure(ex: Throwable): Unit = fail(target, ex) }) - setHandler(toNet, new OutHandler { - import akka.http.impl.engine.rendering.ResponseRenderingOutput._ - def onPull(): Unit = - if (isHttp) read(fromHttp) { - case HttpData(b) ⇒ push(toNet, b) - case SwitchToWebsocket(bytes, handlerFlow) ⇒ - push(toNet, bytes) - val frameHandler = handlerFlow match { - case Left(frameHandler) ⇒ frameHandler - case Right(messageHandler) ⇒ - Websocket.stack(serverSide = true, maskingRandomFactory = websocketRandomFactory, log = log).join(messageHandler) - } - installHandler(frameHandler) - websocketHandlerWasInstalled = true + val shutdown: () ⇒ Unit = () ⇒ completeStage() + val httpToNet: ResponseRenderingOutput ⇒ Unit = { + case HttpData(b) ⇒ push(toNet, b) + case SwitchToWebsocket(bytes, handlerFlow) ⇒ + push(toNet, bytes) + val frameHandler = handlerFlow match { + case Left(frameHandler) ⇒ frameHandler + case Right(messageHandler) ⇒ + Websocket.stack(serverSide = true, maskingRandomFactory = websocketRandomFactory, log = log).join(messageHandler) } - else - read(fromWs)(push(toNet, _)) + installHandler(frameHandler) + websocketHandlerWasInstalled = true + } + val wsToNet: ByteString ⇒ Unit = push(toNet, _) + + setHandler(toNet, new OutHandler { + def onPull(): Unit = + if (isHttp) read(fromHttp)(httpToNet, shutdown) + else read(fromWs)(wsToNet, shutdown) // toNet cancellation isn't allowed to stop this stage override def onDownstreamFinish(): Unit = () diff --git a/akka-stream-tests/src/test/scala/akka/stream/DslConsistencySpec.scala b/akka-stream-tests/src/test/scala/akka/stream/DslConsistencySpec.scala index d55d2bb76f..51843fef7e 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/DslConsistencySpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/DslConsistencySpec.scala @@ -39,7 +39,7 @@ class DslConsistencySpec extends WordSpec with Matchers { Set("create", "apply", "ops", "appendJava", "andThen", "andThenMat", "isIdentity", "withAttributes", "transformMaterializing") ++ Set("asScala", "asJava", "deprecatedAndThen", "deprecatedAndThenMat") - val graphHelpers = Set("zipGraph", "zipWithGraph", "mergeGraph", "interleaveGraph", "concatGraph", "alsoToGraph") + val graphHelpers = Set("zipGraph", "zipWithGraph", "mergeGraph", "mergeSortedGraph", "interleaveGraph", "concatGraph", "alsoToGraph") val allowMissing: Map[Class[_], Set[String]] = Map( jFlowClass -> graphHelpers, jSourceClass -> graphHelpers, diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMergeSortedSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMergeSortedSpec.scala new file mode 100644 index 0000000000..732e1b0c20 --- /dev/null +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMergeSortedSpec.scala @@ -0,0 +1,52 @@ +/** + * Copyright (C) 2015 Typesafe Inc. + */ +package akka.stream.scaladsl + +import akka.stream._ +import akka.stream.scaladsl._ +import akka.stream.testkit.TwoStreamsSetup +import org.scalacheck.Gen +import scala.util.Random +import org.scalatest.prop.GeneratorDrivenPropertyChecks +import org.scalatest.concurrent.ScalaFutures +import scala.concurrent.duration._ +import org.scalactic.ConversionCheckedTripleEquals +import org.scalacheck.Shrink + +class GraphMergeSortedSpec extends TwoStreamsSetup with GeneratorDrivenPropertyChecks with ScalaFutures with ConversionCheckedTripleEquals { + import GraphDSL.Implicits._ + + override type Outputs = Int + + override def fixture(b: GraphDSL.Builder[_]): Fixture = new Fixture(b) { + val merge = b.add(new MergeSorted[Outputs]) + + override def left: Inlet[Outputs] = merge.in0 + override def right: Inlet[Outputs] = merge.in1 + override def out: Outlet[Outputs] = merge.out + } + + implicit val patience = PatienceConfig(1.second) + implicit def noShrink[T] = Shrink[T](_ ⇒ Stream.empty) // do not shrink failures, it only destroys evidence + + "MergeSorted" must { + + "work in the nominal case" in { + val gen = Gen.listOf(Gen.oneOf(false, true)) + + forAll(gen) { picks ⇒ + val N = picks.size + val (left, right) = picks.zipWithIndex.partition(_._1) + Source(left.map(_._2)).mergeSorted(Source(right.map(_._2))) + .grouped(N max 1) + .concat(Source.single(Nil)) + .runWith(Sink.head) + .futureValue should ===(0 until N) + } + } + + commonTests() + + } +} diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala b/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala index d39db47d6a..92d68d98cd 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala @@ -15,6 +15,7 @@ import scala.collection.immutable import scala.concurrent.Future import scala.concurrent.duration.FiniteDuration import akka.japi.Util +import java.util.Comparator object Flow { @@ -1100,6 +1101,37 @@ final class Flow[-In, +Out, +Mat](delegate: scaladsl.Flow[In, Out, Mat]) extends matF: function.Function2[Mat, M, M2]): javadsl.Flow[In, T, M2] = new Flow(delegate.mergeMat(that)(combinerToScala(matF))) + /** + * Merge the given [[Source]] to this [[Flow]], taking elements as they arrive from input streams, + * picking always the smallest of the available elements (waiting for one element from each side + * to be available). This means that possible contiguity of the input streams is not exploited to avoid + * waiting for elements, this merge will block when one of the inputs does not have more elements (and + * does not complete). + * + * '''Emits when''' all of the inputs have an element available + * + * '''Backpressures when''' downstream backpressures + * + * '''Completes when''' all upstreams complete + * + * '''Cancels when''' downstream cancels + */ + def mergeSorted[U >: Out, M](that: Graph[SourceShape[U], M], comp: Comparator[U]): javadsl.Flow[In, U, Mat] = + new Flow(delegate.mergeSorted(that)(Ordering.comparatorToOrdering(comp))) + + /** + * Merge the given [[Source]] to this [[Flow]], taking elements as they arrive from input streams, + * picking always the smallest of the available elements (waiting for one element from each side + * to be available). This means that possible contiguity of the input streams is not exploited to avoid + * waiting for elements, this merge will block when one of the inputs does not have more elements (and + * does not complete). + * + * @see [[#mergeSorted]]. + */ + def mergeSortedMat[U >: Out, Mat2, Mat3](that: Graph[SourceShape[U], Mat2], comp: Comparator[U], + matF: function.Function2[Mat, Mat2, Mat3]): javadsl.Flow[In, U, Mat3] = + new Flow(delegate.mergeSortedMat(that)(combinerToScala(matF))(Ordering.comparatorToOrdering(comp))) + /** * Combine the elements of current [[Flow]] and the given [[Source]] into a stream of tuples. * @@ -1122,7 +1154,7 @@ final class Flow[-In, +Out, +Mat](delegate: scaladsl.Flow[In, Out, Mat]) extends def zipMat[T, M, M2](that: Graph[SourceShape[T], M], matF: function.Function2[Mat, M, M2]): javadsl.Flow[In, Out @uncheckedVariance Pair T, M2] = this.viaMat(Flow.fromGraph(GraphDSL.create(that, - new function.Function2[GraphDSL.Builder[M], SourceShape[T], FlowShape[Out, Out @ uncheckedVariance Pair T]] { + new function.Function2[GraphDSL.Builder[M], SourceShape[T], FlowShape[Out, Out @uncheckedVariance Pair T]] { def apply(b: GraphDSL.Builder[M], s: SourceShape[T]): FlowShape[Out, Out @uncheckedVariance Pair T] = { val zip: FanInShape2[Out, T, Out Pair T] = b.add(Zip.create[Out, T]) b.from(s).toInlet(zip.in1) diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala b/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala index 7728ce02e2..65e8190c17 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala @@ -16,7 +16,6 @@ import akka.stream.impl.{ ConstantFun, StreamLayout } import akka.stream.stage.Stage import akka.util.ByteString import org.reactivestreams.{ Publisher, Subscriber } - import scala.annotation.unchecked.uncheckedVariance import scala.collection.JavaConverters._ import scala.collection.immutable @@ -615,6 +614,37 @@ final class Source[+Out, +Mat](delegate: scaladsl.Source[Out, Mat]) extends Grap matF: function.Function2[Mat, M, M2]): javadsl.Source[T, M2] = new Source(delegate.mergeMat(that)(combinerToScala(matF))) + /** + * Merge the given [[Source]] to this [[Source]], taking elements as they arrive from input streams, + * picking always the smallest of the available elements (waiting for one element from each side + * to be available). This means that possible contiguity of the input streams is not exploited to avoid + * waiting for elements, this merge will block when one of the inputs does not have more elements (and + * does not complete). + * + * '''Emits when''' all of the inputs have an element available + * + * '''Backpressures when''' downstream backpressures + * + * '''Completes when''' all upstreams complete + * + * '''Cancels when''' downstream cancels + */ + def mergeSorted[U >: Out, M](that: Graph[SourceShape[U], M], comp: util.Comparator[U]): javadsl.Source[U, Mat] = + new Source(delegate.mergeSorted(that)(Ordering.comparatorToOrdering(comp))) + + /** + * Merge the given [[Source]] to this [[Source]], taking elements as they arrive from input streams, + * picking always the smallest of the available elements (waiting for one element from each side + * to be available). This means that possible contiguity of the input streams is not exploited to avoid + * waiting for elements, this merge will block when one of the inputs does not have more elements (and + * does not complete). + * + * @see [[#mergeSorted]]. + */ + def mergeSortedMat[U >: Out, Mat2, Mat3](that: Graph[SourceShape[U], Mat2], comp: util.Comparator[U], + matF: function.Function2[Mat, Mat2, Mat3]): javadsl.Source[U, Mat3] = + new Source(delegate.mergeSortedMat(that)(combinerToScala(matF))(Ordering.comparatorToOrdering(comp))) + /** * Combine the elements of current [[Source]] and the given one into a stream of tuples. * diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/SubFlow.scala b/akka-stream/src/main/scala/akka/stream/javadsl/SubFlow.scala index 660fc339f5..ddba19b89d 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/SubFlow.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/SubFlow.scala @@ -15,6 +15,7 @@ import scala.annotation.unchecked.uncheckedVariance import scala.concurrent.Future import scala.concurrent.duration.FiniteDuration import akka.japi.Util +import java.util.Comparator /** * A “stream of streams” sub-flow of data elements, e.g. produced by `groupBy`. @@ -769,6 +770,24 @@ class SubFlow[-In, +Out, +Mat](delegate: scaladsl.SubFlow[Out, Mat, scaladsl.Flo def interleave[T >: Out](that: Graph[SourceShape[T], _], segmentSize: Int): SubFlow[In, T, Mat] = new SubFlow(delegate.interleave(that, segmentSize)) + /** + * Merge the given [[Source]] to this [[Flow]], taking elements as they arrive from input streams, + * picking always the smallest of the available elements (waiting for one element from each side + * to be available). This means that possible contiguity of the input streams is not exploited to avoid + * waiting for elements, this merge will block when one of the inputs does not have more elements (and + * does not complete). + * + * '''Emits when''' all of the inputs have an element available + * + * '''Backpressures when''' downstream backpressures + * + * '''Completes when''' all upstreams complete + * + * '''Cancels when''' downstream cancels + */ + def mergeSorted[U >: Out, M](that: Graph[SourceShape[U], M], comp: Comparator[U]): javadsl.SubFlow[In, U, Mat] = + new SubFlow(delegate.mergeSorted(that)(Ordering.comparatorToOrdering(comp))) + /** * Combine the elements of current [[Flow]] and the given [[Source]] into a stream of tuples. * diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/SubSource.scala b/akka-stream/src/main/scala/akka/stream/javadsl/SubSource.scala index 4a3b7d093b..26c16ca7c4 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/SubSource.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/SubSource.scala @@ -15,6 +15,7 @@ import scala.annotation.unchecked.uncheckedVariance import scala.concurrent.Future import scala.concurrent.duration.FiniteDuration import akka.japi.Util +import java.util.Comparator /** * A “stream of streams” sub-flow of data elements, e.g. produced by `groupBy`. @@ -768,6 +769,24 @@ class SubSource[+Out, +Mat](delegate: scaladsl.SubFlow[Out, Mat, scaladsl.Source def interleave[T >: Out](that: Graph[SourceShape[T], _], segmentSize: Int): SubSource[T, Mat] = new SubSource(delegate.interleave(that, segmentSize)) + /** + * Merge the given [[Source]] to this [[Source]], taking elements as they arrive from input streams, + * picking always the smallest of the available elements (waiting for one element from each side + * to be available). This means that possible contiguity of the input streams is not exploited to avoid + * waiting for elements, this merge will block when one of the inputs does not have more elements (and + * does not complete). + * + * '''Emits when''' all of the inputs have an element available + * + * '''Backpressures when''' downstream backpressures + * + * '''Completes when''' all upstreams complete + * + * '''Cancels when''' downstream cancels + */ + def mergeSorted[U >: Out, M](that: Graph[SourceShape[U], M], comp: Comparator[U]): javadsl.SubSource[U, Mat] = + new SubSource(delegate.mergeSorted(that)(Ordering.comparatorToOrdering(comp))) + /** * Combine the elements of current [[Flow]] and the given [[Source]] into a stream of tuples. * diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala index 8ec28bb836..57ddcaa37b 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala @@ -336,6 +336,7 @@ final case class RunnableGraph[+Mat](private[stream] val module: StreamLayout.Mo */ trait FlowOps[+Out, +Mat] { import akka.stream.impl.Stages._ + import GraphDSL.Implicits._ type Repr[+O] <: FlowOps[O, Mat] @@ -1294,7 +1295,6 @@ trait FlowOps[+Out, +Mat] { protected def zipGraph[U, M](that: Graph[SourceShape[U], M]): Graph[FlowShape[Out @uncheckedVariance, (Out, U)], M] = GraphDSL.create(that) { implicit b ⇒ r ⇒ - import GraphDSL.Implicits._ val zip = b.add(Zip[Out, U]()) r ~> zip.in1 FlowShape(zip.in0, zip.out) @@ -1318,7 +1318,6 @@ trait FlowOps[+Out, +Mat] { protected def zipWithGraph[Out2, Out3, M](that: Graph[SourceShape[Out2], M])(combine: (Out, Out2) ⇒ Out3): Graph[FlowShape[Out @uncheckedVariance, Out3], M] = GraphDSL.create(that) { implicit b ⇒ r ⇒ - import GraphDSL.Implicits._ val zip = b.add(ZipWith[Out, Out2, Out3](combine)) r ~> zip.in1 FlowShape(zip.in0, zip.out) @@ -1354,7 +1353,6 @@ trait FlowOps[+Out, +Mat] { segmentSize: Int): Graph[FlowShape[Out @uncheckedVariance, U], M] = GraphDSL.create(that) { implicit b ⇒ r ⇒ - import GraphDSL.Implicits._ val interleave = b.add(Interleave[U](2, segmentSize)) r ~> interleave.in(1) FlowShape(interleave.in(0), interleave.out) @@ -1372,18 +1370,43 @@ trait FlowOps[+Out, +Mat] { * * '''Cancels when''' downstream cancels */ - def merge[U >: Out](that: Graph[SourceShape[U], _]): Repr[U] = + def merge[U >: Out, M](that: Graph[SourceShape[U], M]): Repr[U] = via(mergeGraph(that)) protected def mergeGraph[U >: Out, M](that: Graph[SourceShape[U], M]): Graph[FlowShape[Out @uncheckedVariance, U], M] = GraphDSL.create(that) { implicit b ⇒ r ⇒ - import GraphDSL.Implicits._ val merge = b.add(Merge[U](2)) r ~> merge.in(1) FlowShape(merge.in(0), merge.out) } + /** + * Merge the given [[Source]] to this [[Flow]], taking elements as they arrive from input streams, + * picking always the smallest of the available elements (waiting for one element from each side + * to be available). This means that possible contiguity of the input streams is not exploited to avoid + * waiting for elements, this merge will block when one of the inputs does not have more elements (and + * does not complete). + * + * '''Emits when''' all of the inputs have an element available + * + * '''Backpressures when''' downstream backpressures + * + * '''Completes when''' all upstreams complete + * + * '''Cancels when''' downstream cancels + */ + def mergeSorted[U >: Out, M](that: Graph[SourceShape[U], M])(implicit ord: Ordering[U]): Repr[U] = + via(mergeSortedGraph(that)) + + protected def mergeSortedGraph[U >: Out, M](that: Graph[SourceShape[U], M])(implicit ord: Ordering[U]): Graph[FlowShape[Out @uncheckedVariance, U], M] = + GraphDSL.create(that) { implicit b ⇒ + r ⇒ + val merge = b.add(new MergeSorted[U]) + r ~> merge.in1 + FlowShape(merge.in0, merge.out) + } + /** * Concatenate the given [[Source]] to this [[Flow]], meaning that once this * Flow’s input is exhausted and all result elements have been generated, @@ -1408,7 +1431,6 @@ trait FlowOps[+Out, +Mat] { protected def concatGraph[U >: Out, Mat2](that: Graph[SourceShape[U], Mat2]): Graph[FlowShape[Out @uncheckedVariance, U], Mat2] = GraphDSL.create(that) { implicit b ⇒ r ⇒ - import GraphDSL.Implicits._ val merge = b.add(Concat[U]()) r ~> merge.in(1) FlowShape(merge.in(0), merge.out) @@ -1560,6 +1582,18 @@ trait FlowOpsMat[+Out, +Mat] extends FlowOps[Out, Mat] { def interleaveMat[U >: Out, Mat2, Mat3](that: Graph[SourceShape[U], Mat2], request: Int)(matF: (Mat, Mat2) ⇒ Mat3): ReprMat[U, Mat3] = viaMat(interleaveGraph(that, request))(matF) + /** + * Merge the given [[Source]] to this [[Flow]], taking elements as they arrive from input streams, + * picking always the smallest of the available elements (waiting for one element from each side + * to be available). This means that possible contiguity of the input streams is not exploited to avoid + * waiting for elements, this merge will block when one of the inputs does not have more elements (and + * does not complete). + * + * @see [[#mergeSorted]]. + */ + def mergeSortedMat[U >: Out, Mat2, Mat3](that: Graph[SourceShape[U], Mat2])(matF: (Mat, Mat2) ⇒ Mat3)(implicit ord: Ordering[U]): ReprMat[U, Mat3] = + viaMat(mergeSortedGraph(that))(matF) + /** * Concatenate the given [[Source]] to this [[Flow]], meaning that once this * Flow’s input is exhausted and all result elements have been generated, diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala index f55d8af5ce..c81dab74c1 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala @@ -318,6 +318,55 @@ final class Interleave[T] private (val inputPorts: Int, val segmentSize: Int, va override def toString = "Interleave" } +/** + * Merge two pre-sorted streams such that the resulting stream is sorted. + * + * '''Emits when''' both inputs have an element available + * + * '''Backpressures when''' downstream backpressures + * + * '''Completes when''' all upstreams complete + * + * '''Cancels when''' downstream cancels + */ +final class MergeSorted[T: Ordering] extends GraphStage[FanInShape2[T, T, T]] { + private val left = Inlet[T]("left") + private val right = Inlet[T]("right") + private val out = Outlet[T]("out") + + override val shape = new FanInShape2(left, right, out) + + override def createLogic(attr: Attributes) = new GraphStageLogic(shape) { + import Ordering.Implicits._ + setHandler(left, ignoreTerminateInput) + setHandler(right, ignoreTerminateInput) + setHandler(out, eagerTerminateOutput) + + var other: T = _ + def nullOut(): Unit = other = null.asInstanceOf[T] + + def dispatch(l: T, r: T): Unit = + if (l < r) { other = r; emit(out, l, readL) } + else { other = l; emit(out, r, readR) } + + val dispatchR = dispatch(other, _: T) + val dispatchL = dispatch(_: T, other) + val passR = () ⇒ emit(out, other, () ⇒ { nullOut(); passAlong(right, out, doPull = true) }) + val passL = () ⇒ emit(out, other, () ⇒ { nullOut(); passAlong(left, out, doPull = true) }) + val readR = () ⇒ read(right)(dispatchR, passL) + val readL = () ⇒ read(left)(dispatchL, passR) + + override def preStart(): Unit = { + // all fan-in stages need to eagerly pull all inputs to get cycles started + pull(right) + read(left)(l ⇒ { + other = l + readR() + }, () ⇒ passAlong(right, out)) + } + } +} + object Broadcast { /** * Create a new `Broadcast` with the specified number of output ports. diff --git a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala index 2925a62ba6..ba9329cd9e 100644 --- a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala +++ b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala @@ -571,12 +571,19 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: * for the given inlet if suspension is needed and reinstalls the current * handler upon receiving the last `onPush()` signal (before invoking the `andThen` function). */ - final protected def readN[T](in: Inlet[T], n: Int)(andThen: Seq[T] ⇒ Unit): Unit = + final protected def readN[T](in: Inlet[T], n: Int)(andThen: Seq[T] ⇒ Unit, onClose: Seq[T] ⇒ Unit): Unit = if (n < 0) throw new IllegalArgumentException("cannot read negative number of elements") else if (n == 0) andThen(Nil) else { val result = new ArrayBuffer[T](n) var pos = 0 + def realAndThen = (elem: T) ⇒ { + result(pos) = elem + pos += 1 + if (pos == n) andThen(result) + } + def realOnClose = () ⇒ onClose(result.take(pos)) + if (isAvailable(in)) { val elem = grab(in) result(0) = elem @@ -586,20 +593,12 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: pos = 1 requireNotReading(in) pull(in) - setHandler(in, new Reading(in, n - 1, getHandler(in))(elem ⇒ { - result(pos) = elem - pos += 1 - if (pos == n) andThen(result) - })) + setHandler(in, new Reading(in, n - 1, getHandler(in))(realAndThen, realOnClose)) } } else { requireNotReading(in) if (!hasBeenPulled(in)) pull(in) - setHandler(in, new Reading(in, n, getHandler(in))(elem ⇒ { - result(pos) = elem - pos += 1 - if (pos == n) andThen(result) - })) + setHandler(in, new Reading(in, n, getHandler(in))(realAndThen, realOnClose)) } } @@ -609,14 +608,16 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: * for the given inlet if suspension is needed and reinstalls the current * handler upon receiving the `onPush()` signal (before invoking the `andThen` function). */ - final protected def read[T](in: Inlet[T])(andThen: T ⇒ Unit): Unit = { + final protected def read[T](in: Inlet[T])(andThen: T ⇒ Unit, onClose: () ⇒ Unit): Unit = { if (isAvailable(in)) { val elem = grab(in) andThen(elem) + } else if (isClosed(in)) { + onClose() } else { requireNotReading(in) if (!hasBeenPulled(in)) pull(in) - setHandler(in, new Reading(in, 1, getHandler(in))(andThen)) + setHandler(in, new Reading(in, 1, getHandler(in))(andThen, onClose)) } } @@ -640,7 +641,7 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: * Caution: for n==1 andThen is called after resetting the handler, for * other values it is called without resetting the handler. */ - private class Reading[T](in: Inlet[T], private var n: Int, val previous: InHandler)(andThen: T ⇒ Unit) extends InHandler { + private class Reading[T](in: Inlet[T], private var n: Int, val previous: InHandler)(andThen: T ⇒ Unit, onClose: () ⇒ Unit) extends InHandler { override def onPush(): Unit = { val elem = grab(in) if (n == 1) setHandler(in, previous) @@ -650,8 +651,15 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: } andThen(elem) } - override def onUpstreamFinish(): Unit = previous.onUpstreamFinish() - override def onUpstreamFailure(ex: Throwable): Unit = previous.onUpstreamFailure(ex) + override def onUpstreamFinish(): Unit = { + setHandler(in, previous) + onClose() + previous.onUpstreamFinish() + } + override def onUpstreamFailure(ex: Throwable): Unit = { + setHandler(in, previous) + previous.onUpstreamFailure(ex) + } } /** @@ -816,19 +824,30 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: /** * Install a handler on the given inlet that emits received elements on the - * given outlet before pulling for more data. `doTerminate` controls whether + * given outlet before pulling for more data. `doFinish` and `doFail` control whether * completion or failure of the given inlet shall lead to stage termination or not. + * `doPull` instructs to perform one initial pull on the `from` port. */ - final protected def passAlong[Out, In <: Out](from: Inlet[In], to: Outlet[Out], doFinish: Boolean, doFail: Boolean): Unit = - setHandler(from, new InHandler { - val puller = () ⇒ tryPull(from) + final protected def passAlong[Out, In <: Out](from: Inlet[In], to: Outlet[Out], + doFinish: Boolean = true, doFail: Boolean = true, + doPull: Boolean = false): Unit = { + class PassAlongHandler extends InHandler with (() ⇒ Unit) { + override def apply(): Unit = tryPull(from) override def onPush(): Unit = { val elem = grab(from) - emit(to, elem, puller) + emit(to, elem, this) } - override def onUpstreamFinish(): Unit = if (doFinish) super.onUpstreamFinish() - override def onUpstreamFailure(ex: Throwable): Unit = if (doFail) super.onUpstreamFailure(ex) - }) + override def onUpstreamFinish(): Unit = if (doFinish) completeStage() + override def onUpstreamFailure(ex: Throwable): Unit = if (doFail) failStage(ex) + } + val ph = new PassAlongHandler + if (_interpreter != null) { + if (isAvailable(from)) emit(to, grab(from), ph) + if (doFinish && isClosed(from)) completeStage() + } + setHandler(from, ph) + if (doPull) tryPull(from) + } /** * Obtain a callback object that can be used asynchronously to re-enter the