diff --git a/akka-bench-jmh-dev/src/main/scala/akka/stream/InterpreterBenchmark.scala b/akka-bench-jmh-dev/src/main/scala/akka/stream/InterpreterBenchmark.scala index c3af314393..c997b6f4c7 100644 --- a/akka-bench-jmh-dev/src/main/scala/akka/stream/InterpreterBenchmark.scala +++ b/akka-bench-jmh-dev/src/main/scala/akka/stream/InterpreterBenchmark.scala @@ -1,7 +1,7 @@ package akka.stream import akka.event._ -import akka.stream.impl.fusing.{ GraphInterpreterSpecKit, GraphStages, Map => MapStage, OneBoundedInterpreter } +import akka.stream.impl.fusing.{ GraphInterpreterSpecKit, GraphStages, Map => MapStage } import akka.stream.impl.fusing.GraphStages.Identity import akka.stream.impl.fusing.GraphInterpreter.{ DownstreamBoundaryStageLogic, UpstreamBoundaryStageLogic } import akka.stream.stage._ @@ -18,10 +18,10 @@ class InterpreterBenchmark { import InterpreterBenchmark._ // manual, and not via @Param, because we want @OperationsPerInvocation on our tests - final val data100k = (1 to 100000).toVector + final val data100k: Vector[Int] = (1 to 100000).toVector @Param(Array("1", "5", "10")) - val numberOfIds = 0 + val numberOfIds: Int = 0 @Benchmark @OperationsPerInvocation(100000) @@ -47,33 +47,13 @@ class InterpreterBenchmark { } } } - - @Benchmark - @OperationsPerInvocation(100000) - def onebounded_interpreter_100k_elements() { - val lock = new Lock() - lock.acquire() - val sink = OneBoundedDataSink(data100k.size) - val ops = Vector.fill(numberOfIds)(new PushPullStage[Int, Int] { - override def onPull(ctx: _root_.akka.stream.stage.Context[Int]) = ctx.pull() - override def onPush(elem: Int, ctx: _root_.akka.stream.stage.Context[Int]) = ctx.push(elem) - }) - val interpreter = new OneBoundedInterpreter(OneBoundedDataSource(data100k) +: ops :+ sink, - (op, ctx, event) ⇒ (), - Logging(NoopBus, classOf[InterpreterBenchmark]), - null, - Attributes.none, - forkLimit = 100, overflowToHeap = false) - interpreter.init() - sink.requestOne() - } } object InterpreterBenchmark { case class GraphDataSource[T](override val toString: String, data: Vector[T]) extends UpstreamBoundaryStageLogic[T] { - var idx = 0 - val out = Outlet[T]("out") + var idx: Int = 0 + override val out: akka.stream.Outlet[T] = Outlet[T]("out") out.id = 0 setHandler(out, new OutHandler { @@ -81,8 +61,7 @@ object InterpreterBenchmark { if (idx < data.size) { push(out, data(idx)) idx += 1 - } - else { + } else { completeStage() } } @@ -91,7 +70,7 @@ object InterpreterBenchmark { } case class GraphDataSink[T](override val toString: String, var expected: Int) extends DownstreamBoundaryStageLogic[T] { - val in = Inlet[T]("in") + override val in: akka.stream.Inlet[T] = Inlet[T]("in") in.id = 0 setHandler(in, new InHandler { @@ -104,49 +83,7 @@ object InterpreterBenchmark { override def onUpstreamFailure(ex: Throwable): Unit = failStage(ex) }) - def requestOne() = pull(in) - } - - case class OneBoundedDataSource[T](data: Vector[T]) extends BoundaryStage { - var idx = 0 - - override def onDownstreamFinish(ctx: BoundaryContext): TerminationDirective = { - ctx.finish() - } - - override def onPull(ctx: BoundaryContext): Directive = { - if (idx < data.size) { - idx += 1 - ctx.push(data(idx - 1)) - } - else { - ctx.finish() - } - } - - override def onPush(elem: Any, ctx: BoundaryContext): Directive = - throw new UnsupportedOperationException("Cannot push the boundary") - } - - case class OneBoundedDataSink(var expected: Int) extends BoundaryStage { - override def onPush(elem: Any, ctx: BoundaryContext): Directive = { - expected -= 1 - if (expected == 
0) ctx.exit()
-      else ctx.pull()
-    }
-
-    override def onUpstreamFinish(ctx: BoundaryContext): TerminationDirective = {
-      ctx.finish()
-    }
-
-    override def onUpstreamFailure(cause: Throwable, ctx: BoundaryContext): TerminationDirective = {
-      ctx.finish()
-    }
-
-    override def onPull(ctx: BoundaryContext): Directive =
-      throw new UnsupportedOperationException("Cannot pull the boundary")
-
-    def requestOne(): Unit = enterAndPull()
+    def requestOne(): Unit = pull(in)
   }
 
   val NoopBus = new LoggingBus {
diff --git a/akka-docs-dev/rst/java/code/docs/http/javadsl/server/HttpServerExampleDocTest.java b/akka-docs-dev/rst/java/code/docs/http/javadsl/server/HttpServerExampleDocTest.java
index a5b2c4990c..a163e305cf 100644
--- a/akka-docs-dev/rst/java/code/docs/http/javadsl/server/HttpServerExampleDocTest.java
+++ b/akka-docs-dev/rst/java/code/docs/http/javadsl/server/HttpServerExampleDocTest.java
@@ -159,8 +159,8 @@ public class HttpServerExampleDocTest {
         Flow.of(HttpRequest.class)
             .via(failureDetection)
             .map(request -> {
-                Source<ByteString, ?> bytes = request.entity().getDataBytes();
-                HttpEntity.Chunked entity = HttpEntities.create(ContentTypes.TEXT_PLAIN, (Source) bytes);
+                Source<ByteString, Object> bytes = request.entity().getDataBytes();
+                HttpEntity.Chunked entity = HttpEntities.create(ContentTypes.TEXT_PLAIN, bytes);
 
                 return HttpResponse.create()
                     .withEntity(entity);
diff --git a/akka-docs-dev/rst/migration-guide-1.0-2.x.rst b/akka-docs-dev/rst/migration-guide-1.0-2.x.rst
new file mode 100644
index 0000000000..65e09e2433
--- /dev/null
+++ b/akka-docs-dev/rst/migration-guide-1.0-2.x.rst
@@ -0,0 +1,263 @@
+.. _migration-2.0:
+
+############################
+ Migration Guide 1.0 to 2.x
+############################
+
+The 2.0 release contains some structural changes that require some
+simple, mechanical source-level changes in client code.
+
+
+Introduced proper named constructor methods instead of ``wrap()``
+==================================================================
+
+There were several, unrelated uses of ``wrap()`` which made it hard to find and hard to understand the intention of
+the call. Therefore these use-cases now have methods with different names, helping Java 8 type inference (by reducing
+the number of overloads) and making the relevant methods easier to find in the documentation.
+
+Creating a Flow from other stages
+---------------------------------
+
+It was possible to create a ``Flow`` from a graph with the correct shape (``FlowShape``) using ``wrap()``. Now this
+must be done with the more descriptive method ``Flow.fromGraph()``.
+
+It was possible to create a ``Flow`` from a ``Source`` and a ``Sink`` using ``wrap()``. Now this functionality can
+be accessed through the more descriptive methods ``Flow.fromSinkAndSource`` and ``Flow.fromSinkAndSourceMat``.
+
+Creating a BidiFlow from other stages
+-------------------------------------
+
+It was possible to create a ``BidiFlow`` from a graph with the correct shape (``BidiShape``) using ``wrap()``. Now this
+must be done with the more descriptive method ``BidiFlow.fromGraph()``.
+
+It was possible to create a ``BidiFlow`` from two ``Flow`` s using ``wrap()``. Now this functionality can
+be accessed through the more descriptive methods ``BidiFlow.fromFlows`` and ``BidiFlow.fromFlowsMat``.
+
+It was possible to create a ``BidiFlow`` from two functions using ``apply()`` (Scala DSL) or ``create()`` (Java DSL).
+Now this functionality can be accessed through the more descriptive method ``BidiFlow.fromFunctions``.
+
+Update procedure
+----------------
+
+1. Replace all uses of ``Flow.wrap`` when it converts a ``Graph`` to a ``Flow`` with ``Flow.fromGraph``
+2. Replace all uses of ``Flow.wrap`` when it converts a ``Source`` and ``Sink`` to a ``Flow`` with
+   ``Flow.fromSinkAndSource`` or ``Flow.fromSinkAndSourceMat``
+3. Replace all uses of ``BidiFlow.wrap`` when it converts a ``Graph`` to a ``BidiFlow`` with ``BidiFlow.fromGraph``
+4. Replace all uses of ``BidiFlow.wrap`` when it converts two ``Flow`` s to a ``BidiFlow`` with
+   ``BidiFlow.fromFlows`` or ``BidiFlow.fromFlowsMat``
+5. Replace all uses of ``BidiFlow.apply()`` (Scala DSL) or ``BidiFlow.create()`` (Java DSL) when it converts two
+   functions to a ``BidiFlow`` with ``BidiFlow.fromFunctions``
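+
+As an illustration, a sketch of the renamed constructors (the element types and stages
+below are placeholders, not taken from real code):
+
+.. code-block:: scala
+
+  import akka.stream.scaladsl._
+
+  // 1.0: Flow.wrap(graph) where graph has a FlowShape
+  // 2.0: Flow.fromGraph(graph)
+
+  // was Flow.wrap(sink, source) in 1.0
+  val flow = Flow.fromSinkAndSource(Sink.ignore, Source.single("elem"))
+
+  // was BidiFlow(f1, f2) (Scala DSL) or BidiFlow.create(f1, f2) (Java DSL) in 1.0
+  val codec = BidiFlow.fromFunctions((i: Int) => i.toLong, (l: Long) => l.toInt)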
+
+FlowGraph builder methods have been renamed
+===========================================
+
+There is now only one graph creation method, called ``create``, which is analogous to the old ``partial`` method. For
+closed graphs it is now explicitly required to return ``ClosedShape`` at the end of the builder block.
+
+Update procedure
+----------------
+
+1. Replace all occurrences of ``FlowGraph.partial()`` with ``FlowGraph.create()``
+2. For graphs that were closed, add ``ClosedShape`` as the return value of the builder block
+
+Methods that create Source, Sink, Flow from Graphs have been removed
+====================================================================
+
+Previously there were convenience methods available on ``Sink``, ``Source``, ``Flow`` and ``BidiFlow`` to create
+these DSL elements from a graph builder directly. Now this requires two explicit steps, which reduces the number of
+overloaded methods (helping Java 8 type inference) and also reduces the number of ways these elements can be created.
+There is only one graph creation method to learn (``FlowGraph.create``) and then only one conversion method to use:
+``fromGraph()``.
+
+This means that the following methods have been removed:
+ - ``adapt()`` method on ``Source``, ``Sink``, ``Flow`` and ``BidiFlow`` (both DSLs)
+ - ``apply()`` overloads providing a graph ``Builder`` on ``Source``, ``Sink``, ``Flow`` and ``BidiFlow`` (Scala DSL)
+ - ``create()`` overloads providing a graph ``Builder`` on ``Source``, ``Sink``, ``Flow`` and ``BidiFlow`` (Java DSL)
+
+Update procedure
+----------------
+
+Every place where a ``Source``, ``Sink``, ``Flow`` or ``BidiFlow`` is created from a graph builder has to
+be replaced with two steps:
+
+1. Create a ``Graph`` with the correct ``Shape`` using ``FlowGraph.create`` (e.g. for a ``Source`` this means first
+   creating a ``Graph`` with ``SourceShape``)
+2. Create the required DSL element by calling ``fromGraph()`` on the required DSL element (e.g. ``Source.fromGraph``),
+   passing the graph created in the previous step
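+
+A sketch covering both of the previous two sections (the concrete stages are only
+illustrative):
+
+.. code-block:: scala
+
+  import akka.stream.{ ClosedShape, SourceShape }
+  import akka.stream.scaladsl._
+
+  // closed graph: the builder block now has to return ClosedShape
+  val runnable = RunnableGraph.fromGraph(FlowGraph.create() { implicit b =>
+    import FlowGraph.Implicits._
+    Source(1 to 10) ~> Sink.ignore
+    ClosedShape
+  })
+
+  // partial graph with a SourceShape, converted in a second, explicit step
+  val graph = FlowGraph.create() { implicit b =>
+    import FlowGraph.Implicits._
+    val merge = b.add(Merge[Int](2))
+    Source.single(1) ~> merge
+    Source.single(2) ~> merge
+    SourceShape(merge.out)
+  }
+  val source = Source.fromGraph(graph)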
+
+Some graph Builder methods in the Java DSL have been renamed
+============================================================
+
+Due to the high number of overloads, Java 8 type inference suffered, and it was also hard to figure out when to use
+which method. Therefore various redundant methods have been removed.
+
+Update procedure
+----------------
+
+1. All uses of ``builder.addEdge(Outlet, Inlet)`` should be replaced by the alternative ``builder.from(…).to(…)``
+2. All uses of ``builder.addEdge(Outlet, FlowShape, Inlet)`` should be replaced by ``builder.from(…).via(…).to(…)``
+
+``Builder.source`` => use ``builder.from(…).via(…).to(…)``
+``Builder.flow`` => use ``builder.from(…).via(…).to(…)``
+``Builder.sink`` => use ``builder.from(…).via(…).to(…)``
+
+TODO: code example
+
+Builder overloads from the Scala DSL have been removed
+======================================================
+
+``scaladsl.Builder.addEdge(Outlet, Inlet)`` => use the graph DSL (``~>`` and ``<~``)
+``scaladsl.Builder.addEdge(Outlet, FlowShape, Inlet)`` => use the graph DSL (``~>`` and ``<~``)
+
+Source constructor name changes
+===============================
+
+``Source.lazyEmpty`` has been replaced by ``Source.maybe``, which materializes a ``Promise`` that can be completed
+with one or zero elements by providing an ``Option``. This is different from ``lazyEmpty``, which only allowed
+completion to be sent, but no elements.
+
+The ``apply()`` and ``from()`` overloads on ``Source`` that provide a tick source (``Source(delay, interval, tick)``)
+are replaced by the named method ``Source.tick()`` to reduce the number of overloads and to make the function more
+discoverable.
+
+Update procedure
+----------------
+
+1. Replace all uses of ``Source(delay, interval, tick)`` and ``Source.from(delay, interval, tick)`` with the method
+   ``Source.tick()``
+2. All uses of ``Source.lazyEmpty`` should be replaced by ``Source.maybe`` and the materialized ``Promise`` completed
+   with a ``None`` (an empty ``Option``)
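+
+A sketch (assuming an implicit materializer is in scope; the element types are
+illustrative):
+
+.. code-block:: scala
+
+  import scala.concurrent.duration._
+  import akka.stream.scaladsl._
+
+  // was Source(1.second, 1.second, "tick") in 1.0
+  val ticks = Source.tick(1.second, 1.second, "tick")
+
+  // was Source.lazyEmpty in 1.0; run() yields the materialized Promise[Option[Int]]
+  val promise = Source.maybe[Int].to(Sink.ignore).run()
+  promise.success(None) // complete without emitting any element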
+
+``Flow.empty()`` has been removed from the Java DSL
+===================================================
+
+The ``empty()`` method has been removed since it behaves exactly the same as ``create()``, creating a ``Flow`` with no
+transformations added yet.
+
+Update procedure
+----------------
+
+1. Replace all uses of ``Flow.empty()`` with ``Flow.create()``.
+
+TODO: code example
+
+``flatten(FlattenStrategy)`` has been replaced by named counterparts
+====================================================================
+
+To simplify type inference in Java 8 and to make the method more discoverable, ``flatten(FlattenStrategy.concat)``
+has been removed and replaced with the named method ``flattenConcat()``.
+
+Update procedure
+----------------
+
+1. Replace all occurrences of ``flatten(FlattenStrategy.concat)`` with ``flattenConcat()``
+
+TODO: code example
+
+FlexiMerge and FlexiRoute have been replaced by GraphStage
+==========================================================
+
+The ``FlexiMerge`` and ``FlexiRoute`` DSLs have been removed since they provided an abstraction that was too limiting,
+and a better abstraction has been created, called ``GraphStage``. ``GraphStage`` can express fan-in and
+fan-out stages, but many other constructs as well, with possibly multiple input and output ports (e.g. a ``BidiStage``).
+
+This new abstraction provides a more uniform way to create custom stream processing stages of arbitrary ``Shape``. In
+fact, all of the built-in fan-in and fan-out stages are now implemented in terms of ``GraphStage``.
+
+Update procedure
+----------------
+
+*There is no simple update procedure. The affected stages must be ported to the new* ``GraphStage`` *DSL manually. Please
+read the* ``GraphStage`` *documentation (TODO) for details.*
+
+Variance of Inlet and Outlet (Scala DSL)
+========================================
+
+Scala uses *declaration site variance*, which was cumbersome in the case of ``Inlet`` and ``Outlet`` as they are
+purely symbolic objects containing no fields or methods, and they are used both in input and output locations (wiring
+an ``Outlet`` into an ``Inlet``; reading in a stage from an ``Inlet``). For these reasons all users of these
+port abstractions now use *use-site variance* (just like Java variance works). In general this does not affect user
+code except in the case of custom shapes, which now require ``@uncheckedVariance`` annotations on their ``Inlet`` and
+``Outlet`` members (since these are now invariant, but the Scala compiler does not know that they have no fields or
+methods that would violate variance constraints).
+
+This change does not affect Java DSL users.
+
+TODO: code example
+
+Update procedure
+----------------
+
+1. All custom shapes must use ``@uncheckedVariance`` on their ``Inlet`` and ``Outlet`` members.
+
+Semantic change in ``isHoldingUpstream`` in the DetachedStage DSL
+=================================================================
+
+The ``isHoldingUpstream`` method used to return true if the upstream port was in holding state and a completion arrived
+(inside the ``onUpstreamFinish`` callback). Now it returns ``false`` when the upstream is completed.
+
+Update procedure
+----------------
+
+1. Those stages that relied on the previous behavior need to introduce an extra ``Boolean`` field with initial value
+   ``false``
+2. This field must be set on every call to ``holdUpstream()`` (and variants).
+3. On completion, instead of calling ``isHoldingUpstream``, read this variable.
+
+TODO: code example
+
+
+AsyncStage has been replaced by GraphStage
+==========================================
+
+Due to its complexity and relative inflexibility ``AsyncStage`` has been removed.
+
+TODO explanation
+
+Update procedure
+----------------
+
+1. The subclass of ``AsyncStage`` should be replaced by ``GraphStage``
+2. The new subclass must define an ``in`` and ``out`` port (``Inlet`` and ``Outlet`` instance) and override the ``shape``
+   method, returning a ``FlowShape``
+3. An instance of ``GraphStageLogic`` must be returned by overriding ``createLogic()``. The original processing logic and
+   state will be encapsulated in this ``GraphStageLogic``
+4. Using ``setHandler(port, handler)``, an ``InHandler`` instance should be set on ``in`` and an ``OutHandler`` should
+   be set on ``out``
+5. ``onPush``, ``onUpstreamFinish`` and ``onUpstreamFailure`` are now available in the ``InHandler`` subclass created
+   by the user
+6. ``onPull`` and ``onDownstreamFinish`` are now available in the ``OutHandler`` subclass created by the user
+7. The callbacks above no longer take an extra ``ctx`` context parameter.
+8. ``onPush`` only signals the stage; the actual element can be obtained by calling ``grab(in)``
+9. ``ctx.push(elem)`` is now ``push(out, elem)``
+10. ``ctx.pull()`` is now ``pull(in)``
+11. ``ctx.finish()`` is now ``completeStage()``
+12. ``ctx.pushAndFinish(elem)`` is now simply two calls: ``push(out, elem); completeStage()``
+13. ``ctx.fail(cause)`` is now ``failStage(cause)``
+14. ``ctx.isFinishing()`` is now ``isClosed(in)``
+15. ``ctx.absorbTermination()`` can be replaced with ``if (isAvailable(shape.outlet)) ``
+16. ``ctx.pushAndPull(elem)`` can be replaced with ``push(out, elem); pull(in)``
+17. ``ctx.holdUpstreamAndPush(elem)`` and ``ctx.holdDownstreamAndPull()`` can be replaced by simply ``push(out, elem)``
+    and ``pull(in)`` respectively
+18. The following calls should be removed: ``ctx.ignore()``, ``ctx.holdUpstream()`` and ``ctx.holdDownstream()``.
+19. ``ctx.isHoldingUpstream()`` can be replaced with ``isAvailable(out)``
+20. ``ctx.isHoldingDownstream()`` can be replaced with ``!(isClosed(in) || hasBeenPulled(in))``
+21. ``ctx.getAsyncCallback()`` is now ``getAsyncCallback(callback)``, which takes the callback as a parameter. This
+    corresponds to the ``onAsyncInput()`` callback in the original ``AsyncStage``
+
+We show the necessary steps in terms of an example below.
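+
+Following the steps above, a sketch of a trivial map-like stage in the new DSL
+(``ToUpperStage`` is a made-up example, not an existing stage):
+
+.. code-block:: scala
+
+  import akka.stream.{ Attributes, FlowShape, Inlet, Outlet }
+  import akka.stream.stage._
+
+  class ToUpperStage extends GraphStage[FlowShape[String, String]] {
+    val in = Inlet[String]("ToUpper.in")
+    val out = Outlet[String]("ToUpper.out")
+    override val shape = FlowShape(in, out)
+
+    override def createLogic(effectiveAttributes: Attributes): GraphStageLogic =
+      new GraphStageLogic(shape) {
+        setHandler(in, new InHandler {
+          // ctx.push(elem) has become push(out, elem); the element is read with grab(in)
+          override def onPush(): Unit = push(out, grab(in).toUpperCase)
+        })
+        setHandler(out, new OutHandler {
+          // ctx.pull() has become pull(in)
+          override def onPull(): Unit = pull(in)
+        })
+      }
+  }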
+
diff --git a/akka-docs-dev/rst/scala/code/docs/http/scaladsl/HttpServerExampleSpec.scala b/akka-docs-dev/rst/scala/code/docs/http/scaladsl/HttpServerExampleSpec.scala
index 8834a98e37..19b4a8fe3d 100644
--- a/akka-docs-dev/rst/scala/code/docs/http/scaladsl/HttpServerExampleSpec.scala
+++ b/akka-docs-dev/rst/scala/code/docs/http/scaladsl/HttpServerExampleSpec.scala
@@ -352,7 +352,7 @@ class HttpServerExampleSpec extends WordSpec with Matchers {
           pathEnd {
             (put | parameter('method ! "put")) {
               // form extraction from multipart or www-url-encoded forms
-              formFields('email, 'total.as[Money]).as(Order) { order =>
+              formFields(('email, 'total.as[Money])).as(Order) { order =>
                 complete {
                   // complete with serialized Future result
                   (myDbActor ? Update(order)).mapTo[TransactionResult]
@@ -373,7 +373,7 @@ class HttpServerExampleSpec extends WordSpec with Matchers {
         path("items") {
           get {
             // parameters to case class extraction
-            parameters('size.as[Int], 'color ?, 'dangerous ? "no")
+            parameters(('size.as[Int], 'color ?, 'dangerous ? "no"))
              .as(OrderItem) { orderItem =>
                // ... route using case class instance created from
                // required and optional query parameters
diff --git a/akka-docs-dev/rst/scala/code/docs/http/scaladsl/server/directives/FutureDirectivesExamplesSpec.scala b/akka-docs-dev/rst/scala/code/docs/http/scaladsl/server/directives/FutureDirectivesExamplesSpec.scala
index 2f6cbb62dc..6beb509d1b 100644
--- a/akka-docs-dev/rst/scala/code/docs/http/scaladsl/server/directives/FutureDirectivesExamplesSpec.scala
+++ b/akka-docs-dev/rst/scala/code/docs/http/scaladsl/server/directives/FutureDirectivesExamplesSpec.scala
@@ -26,9 +26,6 @@ class FutureDirectivesExamplesSpec extends RoutingSpec {
       ctx.complete((InternalServerError, "Unsuccessful future!"))
     }
 
-  val resourceActor = system.actorOf(Props(new Actor {
-    def receive = { case _ => sender ! 
"resource" } - })) implicit val responseTimeout = Timeout(2, TimeUnit.SECONDS) "onComplete" in { diff --git a/akka-docs-dev/rst/scala/code/docs/http/scaladsl/server/directives/RouteDirectivesExamplesSpec.scala b/akka-docs-dev/rst/scala/code/docs/http/scaladsl/server/directives/RouteDirectivesExamplesSpec.scala index 49ff680801..74bbaac0ca 100644 --- a/akka-docs-dev/rst/scala/code/docs/http/scaladsl/server/directives/RouteDirectivesExamplesSpec.scala +++ b/akka-docs-dev/rst/scala/code/docs/http/scaladsl/server/directives/RouteDirectivesExamplesSpec.scala @@ -7,6 +7,7 @@ package directives import akka.http.scaladsl.model._ import akka.http.scaladsl.server.{ Route, ValidationRejection } +import akka.testkit.EventFilter class RouteDirectivesExamplesSpec extends RoutingSpec { @@ -83,7 +84,7 @@ class RouteDirectivesExamplesSpec extends RoutingSpec { } } - "failwith-examples" in { + "failwith-examples" in EventFilter[RuntimeException](start = "Error during processing of request", occurrences = 1).intercept { val route = path("foo") { failWith(new RuntimeException("Oops.")) diff --git a/akka-docs-dev/rst/scala/code/docs/stream/RateTransformationDocSpec.scala b/akka-docs-dev/rst/scala/code/docs/stream/RateTransformationDocSpec.scala index 8365457e13..1152a648c5 100644 --- a/akka-docs-dev/rst/scala/code/docs/stream/RateTransformationDocSpec.scala +++ b/akka-docs-dev/rst/scala/code/docs/stream/RateTransformationDocSpec.scala @@ -12,6 +12,7 @@ import scala.math._ import scala.concurrent.Await import scala.concurrent.duration._ import scala.collection.immutable +import akka.testkit.TestLatch class RateTransformationDocSpec extends AkkaSpec { @@ -84,9 +85,14 @@ class RateTransformationDocSpec extends AkkaSpec { case (lastElement, drift) => ((lastElement, drift), (lastElement, drift + 1)) } //#expand-drift + val latch = TestLatch(2) + val realDriftFlow = Flow[Double] + .expand(d => { latch.countDown(); (d, 0) }) { + case (lastElement, drift) => ((lastElement, drift), (lastElement, drift + 1)) + } val (pub, sub) = TestSource.probe[Double] - .via(driftFlow) + .via(realDriftFlow) .toMat(TestSink.probe[(Double, Int)])(Keep.both) .run() @@ -98,6 +104,7 @@ class RateTransformationDocSpec extends AkkaSpec { sub.requestNext((1.0, 2)) pub.sendNext(2.0) + Await.ready(latch, 1.second) sub.requestNext((2.0, 0)) } diff --git a/akka-docs-dev/rst/scala/code/docs/stream/ReactiveStreamsDocSpec.scala b/akka-docs-dev/rst/scala/code/docs/stream/ReactiveStreamsDocSpec.scala index 079f5c3f30..a001829af3 100644 --- a/akka-docs-dev/rst/scala/code/docs/stream/ReactiveStreamsDocSpec.scala +++ b/akka-docs-dev/rst/scala/code/docs/stream/ReactiveStreamsDocSpec.scala @@ -139,7 +139,7 @@ class ReactiveStreamsDocSpec extends AkkaSpec { // An example Processor factory def createProcessor: Processor[Int, Int] = Flow[Int].toProcessor.run() - val flow: Flow[Int, Int, Unit] = Flow(() => createProcessor) + val flow: Flow[Int, Int, Unit] = Flow.fromProcessor(() => createProcessor) //#use-processor } diff --git a/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeByteStrings.scala b/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeByteStrings.scala index 4909d21f57..db7a994681 100644 --- a/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeByteStrings.scala +++ b/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeByteStrings.scala @@ -27,9 +27,15 @@ class RecipeByteStrings extends RecipeSpec { override def onPull(ctx: Context[ByteString]): SyncDirective = emitChunkOrPull(ctx) + override def onUpstreamFinish(ctx: 
Context[ByteString]): TerminationDirective = + if (buffer.nonEmpty) ctx.absorbTermination() + else ctx.finish() + private def emitChunkOrPull(ctx: Context[ByteString]): SyncDirective = { - if (buffer.isEmpty) ctx.pull() - else { + if (buffer.isEmpty) { + if (ctx.isFinishing) ctx.finish() + else ctx.pull() + } else { val (emit, nextBuffer) = buffer.splitAt(chunkSize) buffer = nextBuffer ctx.push(emit) diff --git a/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeKeepAlive.scala b/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeKeepAlive.scala index 6fe5589576..c7ac25835e 100644 --- a/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeKeepAlive.scala +++ b/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeKeepAlive.scala @@ -41,6 +41,8 @@ class RecipeKeepAlive extends RecipeSpec { val subscription = sub.expectSubscription() + // FIXME RK: remove (because I think this cannot deterministically be tested and it might also not do what it should anymore) + tickPub.sendNext(()) // pending data will overcome the keepalive diff --git a/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeMissedTicks.scala b/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeMissedTicks.scala index d8117490e7..1829655862 100644 --- a/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeMissedTicks.scala +++ b/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeMissedTicks.scala @@ -2,8 +2,9 @@ package docs.stream.cookbook import akka.stream.scaladsl._ import akka.stream.testkit._ - import scala.concurrent.duration._ +import akka.testkit.TestLatch +import scala.concurrent.Await class RecipeMissedTicks extends RecipeSpec { @@ -22,8 +23,12 @@ class RecipeMissedTicks extends RecipeSpec { Flow[Tick].conflate(seed = (_) => 0)( (missedTicks, tick) => missedTicks + 1) //#missed-ticks + val latch = TestLatch(3) + val realMissedTicks: Flow[Tick, Int, Unit] = + Flow[Tick].conflate(seed = (_) => 0)( + (missedTicks, tick) => { latch.countDown(); missedTicks + 1 }) - tickStream.via(missedTicks).to(sink).run() + tickStream.via(realMissedTicks).to(sink).run() pub.sendNext(()) pub.sendNext(()) @@ -31,6 +36,8 @@ class RecipeMissedTicks extends RecipeSpec { pub.sendNext(()) val subscription = sub.expectSubscription() + Await.ready(latch, 1.second) + subscription.request(1) sub.expectNext(3) diff --git a/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeSimpleDrop.scala b/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeSimpleDrop.scala index 4e18bfc439..f8509d29f6 100644 --- a/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeSimpleDrop.scala +++ b/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeSimpleDrop.scala @@ -2,8 +2,9 @@ package docs.stream.cookbook import akka.stream.scaladsl.{ Flow, Sink, Source } import akka.stream.testkit._ - import scala.concurrent.duration._ +import akka.testkit.TestLatch +import scala.concurrent.Await class RecipeSimpleDrop extends RecipeSpec { @@ -15,13 +16,16 @@ class RecipeSimpleDrop extends RecipeSpec { val droppyStream: Flow[Message, Message, Unit] = Flow[Message].conflate(seed = identity)((lastMessage, newMessage) => newMessage) //#simple-drop + val latch = TestLatch(2) + val realDroppyStream = + Flow[Message].conflate(seed = identity)((lastMessage, newMessage) => { latch.countDown(); newMessage }) val pub = TestPublisher.probe[Message]() val sub = TestSubscriber.manualProbe[Message]() val messageSource = Source(pub) val sink = Sink(sub) - messageSource.via(droppyStream).to(sink).run() + 
messageSource.via(realDroppyStream).to(sink).run()
 
     val subscription = sub.expectSubscription()
     sub.expectNoMsg(100.millis)
@@ -30,6 +34,8 @@ class RecipeSimpleDrop extends RecipeSpec {
       pub.sendNext("2")
       pub.sendNext("3")
 
+      Await.ready(latch, 1.second)
+
       subscription.request(1)
       sub.expectNext("3")
diff --git a/akka-docs-dev/src/test/resources/application.conf b/akka-docs-dev/src/test/resources/application.conf
new file mode 100644
index 0000000000..dafc521805
--- /dev/null
+++ b/akka-docs-dev/src/test/resources/application.conf
@@ -0,0 +1 @@
+akka.loggers = ["akka.testkit.TestEventListener"]
\ No newline at end of file
diff --git a/akka-http-core/src/main/java/akka/http/javadsl/model/HttpEntity.java b/akka-http-core/src/main/java/akka/http/javadsl/model/HttpEntity.java
index 70b8b4fc25..9c2627afac 100644
--- a/akka-http-core/src/main/java/akka/http/javadsl/model/HttpEntity.java
+++ b/akka-http-core/src/main/java/akka/http/javadsl/model/HttpEntity.java
@@ -81,7 +81,7 @@ public interface HttpEntity {
     /**
      * Returns a stream of data bytes this entity consists of.
      */
-    public abstract Source<ByteString, ?> getDataBytes();
+    public abstract Source<ByteString, Object> getDataBytes();
 
     /**
      * Returns a future of a strict entity that contains the same data as this entity
diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala
index 8abbe7d9b2..9a93592426 100644
--- a/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala
+++ b/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala
@@ -117,7 +117,7 @@ private[http] object OutgoingConnectionBlueprint {
 
     val shape = new FanInShape2(requests, responses, out)
 
-    override def createLogic = new GraphStageLogic(shape) {
+    override def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) {
       passAlong(requests, out, doFinish = false, doFail = true)
       setHandler(out, eagerTerminateOutput)
 
@@ -147,7 +147,7 @@ private[http] object OutgoingConnectionBlueprint {
 
     val shape = new FanInShape2(dataInput, methodBypassInput, out)
 
-    override def createLogic = new GraphStageLogic(shape) {
+    override def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) {
       // each connection uses a single (private) response parser instance for all its responses
       // which builds a cache of all header instances seen on that connection
       val parser = rootParser.createShallowCopy()
diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/client/PoolConductor.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/client/PoolConductor.scala
index feb9809891..3662170935 100644
--- a/akka-http-core/src/main/scala/akka/http/impl/engine/client/PoolConductor.scala
+++ b/akka-http-core/src/main/scala/akka/http/impl/engine/client/PoolConductor.scala
@@ -115,7 +115,7 @@ private object PoolConductor {
 
     override val shape = new FanInShape2(ctxIn, slotIn, out)
 
-    override def createLogic = new GraphStageLogic(shape) {
+    override def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) {
       val slotStates = Array.fill[SlotState](slotCount)(Unconnected)
       var nextSlot = 0
 
@@ -207,7 +207,7 @@ private object PoolConductor {
 
     override val shape = new UniformFanOutShape[SwitchCommand, RequestContext](slotCount)
 
-    override def createLogic = new GraphStageLogic(shape) {
+    override def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) {
shape.outArray foreach { setHandler(_, ignoreTerminateOutput) } val in = shape.in diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/client/PoolSlot.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/client/PoolSlot.scala index c28331fd5c..7ac40d9998 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/client/PoolSlot.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/client/PoolSlot.scala @@ -57,10 +57,10 @@ private object PoolSlot { import FlowGraph.Implicits._ val slotProcessor = b.add { - Flow[RequestContext].andThenMat { () ⇒ + Flow.fromProcessor { () ⇒ val actor = system.actorOf(Props(new SlotProcessor(slotIx, connectionFlow, settings)).withDeploy(Deploy.local), slotProcessorActorName.next()) - (ActorProcessor[RequestContext, List[ProcessorOut]](actor), ()) + ActorProcessor[RequestContext, List[ProcessorOut]](actor) }.mapConcat(identity) } val split = b.add(Broadcast[ProcessorOut](2)) diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala index f7feddcdbf..7b8806b52b 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala @@ -179,7 +179,7 @@ private[http] object HttpServerBluePrint { override val shape = new FanInShape3(bypassInput, oneHundredContinue, applicationInput, out) - override def createLogic = new GraphStageLogic(shape) { + override def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) { var requestStart: RequestStart = _ setHandler(bypassInput, new InHandler { @@ -334,7 +334,7 @@ private[http] object HttpServerBluePrint { override val shape = new FanOutShape2(in, httpOut, wsOut) - override def createLogic = new GraphStageLogic(shape) { + override def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) { var target = httpOut setHandler(in, new InHandler { @@ -362,7 +362,7 @@ private[http] object HttpServerBluePrint { override val shape = new FanInShape2(httpIn, wsIn, out) - override def createLogic = new GraphStageLogic(shape) { + override def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) { var websocketHandlerWasInstalled = false setHandler(httpIn, conditionalTerminateInput(() ⇒ !websocketHandlerWasInstalled)) @@ -407,7 +407,7 @@ private[http] object HttpServerBluePrint { override val shape = new FanInShape2(bytes, token, out) - override def createLogic = new GraphStageLogic(shape) { + override def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) { passAlong(bytes, out, doFinish = true, doFail = true) passAlong(token, out, doFinish = false, doFail = true) setHandler(out, eagerTerminateOutput) diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameEvent.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameEvent.scala index 35c8cc54b3..c2e0d3045f 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameEvent.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameEvent.scala @@ -7,12 +7,16 @@ package akka.http.impl.engine.ws import akka.http.impl.engine.ws.Protocol.Opcode import akka.util.ByteString +private[http] sealed trait FrameEventOrError + +private[http] final case class FrameError(p: ProtocolException) extends FrameEventOrError + /** * The low-level Websocket framing model. 
* * INTERNAL API */ -private[http] sealed trait FrameEvent { +private[http] sealed trait FrameEvent extends FrameEventOrError { def data: ByteString def lastPart: Boolean def withData(data: ByteString): FrameEvent diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameEventParser.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameEventParser.scala index ca08022b2a..44c74fbfd2 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameEventParser.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameEventParser.scala @@ -4,11 +4,10 @@ package akka.http.impl.engine.ws -import akka.http.impl.util.{ ByteReader, ByteStringParserStage } -import akka.stream.stage.{ StageState, SyncDirective, Context } import akka.util.ByteString - import scala.annotation.tailrec +import akka.stream.io.ByteStringParser +import akka.stream.Attributes /** * Streaming parser for the Websocket framing protocol as defined in RFC6455 @@ -36,108 +35,81 @@ import scala.annotation.tailrec * * INTERNAL API */ -private[http] class FrameEventParser extends ByteStringParserStage[FrameEvent] { - protected def onTruncation(ctx: Context[FrameEvent]): SyncDirective = - ctx.fail(new ProtocolException("Data truncated")) +private[http] object FrameEventParser extends ByteStringParser[FrameEvent] { + import ByteStringParser._ - def initial: StageState[ByteString, FrameEvent] = ReadFrameHeader + override def createLogic(attr: Attributes) = new ParsingLogic { + startWith(ReadFrameHeader) - object ReadFrameHeader extends ByteReadingState { - def read(reader: ByteReader, ctx: Context[FrameEvent]): SyncDirective = { - import Protocol._ - - val flagsAndOp = reader.readByte() - val maskAndLength = reader.readByte() - - val flags = flagsAndOp & FLAGS_MASK - val op = flagsAndOp & OP_MASK - - val maskBit = (maskAndLength & MASK_MASK) != 0 - val length7 = maskAndLength & LENGTH_MASK - - val length = - length7 match { - case 126 ⇒ reader.readShortBE().toLong - case 127 ⇒ reader.readLongBE() - case x ⇒ x.toLong - } - - if (length < 0) ctx.fail(new ProtocolException("Highest bit of 64bit length was set")) - - val mask = - if (maskBit) Some(reader.readIntBE()) - else None - - def isFlagSet(mask: Int): Boolean = (flags & mask) != 0 - val header = - FrameHeader(Opcode.forCode(op.toByte), - mask, - length, - fin = isFlagSet(FIN_MASK), - rsv1 = isFlagSet(RSV1_MASK), - rsv2 = isFlagSet(RSV2_MASK), - rsv3 = isFlagSet(RSV3_MASK)) - - val data = reader.remainingData - val takeNow = (header.length min Int.MaxValue).toInt - val thisFrameData = data.take(takeNow) - val remaining = data.drop(takeNow) - - val nextState = - if (thisFrameData.length == length) ReadFrameHeader - else readData(length - thisFrameData.length) - - pushAndBecomeWithRemaining(FrameStart(header, thisFrameData.compact), nextState, remaining, ctx) - } - } - - def readData(_remaining: Long): State = - new State { - var remaining = _remaining - def onPush(elem: ByteString, ctx: Context[FrameEvent]): SyncDirective = - if (elem.size < remaining) { - remaining -= elem.size - ctx.push(FrameData(elem, lastPart = false)) - } else { - require(remaining <= Int.MaxValue) // safe because, remaining <= elem.size <= Int.MaxValue - val frameData = elem.take(remaining.toInt) - val remainingData = elem.drop(remaining.toInt) - - pushAndBecomeWithRemaining(FrameData(frameData.compact, lastPart = true), ReadFrameHeader, remainingData, ctx) - } + trait Step extends ParseStep[FrameEvent] { + override def onTruncation(): Unit = failStage(new 
ProtocolException("Data truncated")) } - def becomeWithRemaining(nextState: State, remainingData: ByteString, ctx: Context[FrameEvent]): SyncDirective = { - become(nextState) - nextState.onPush(remainingData, ctx) - } - def pushAndBecomeWithRemaining(elem: FrameEvent, nextState: State, remainingData: ByteString, ctx: Context[FrameEvent]): SyncDirective = - if (remainingData.isEmpty) { - become(nextState) - ctx.push(elem) - } else { - become(waitForPull(nextState, remainingData)) - ctx.push(elem) - } + object ReadFrameHeader extends Step { + override def parse(reader: ByteReader): (FrameEvent, Step) = { + import Protocol._ - def waitForPull(nextState: State, remainingData: ByteString): State = - new State { - def onPush(elem: ByteString, ctx: Context[FrameEvent]): SyncDirective = - throw new IllegalStateException("Mustn't push in this state") + val flagsAndOp = reader.readByte() + val maskAndLength = reader.readByte() - override def onPull(ctx: Context[FrameEvent]): SyncDirective = { - become(nextState) - nextState.onPush(remainingData, ctx) + val flags = flagsAndOp & FLAGS_MASK + val op = flagsAndOp & OP_MASK + + val maskBit = (maskAndLength & MASK_MASK) != 0 + val length7 = maskAndLength & LENGTH_MASK + + val length = + length7 match { + case 126 ⇒ reader.readShortBE().toLong + case 127 ⇒ reader.readLongBE() + case x ⇒ x.toLong + } + + if (length < 0) throw new ProtocolException("Highest bit of 64bit length was set") + + val mask = + if (maskBit) Some(reader.readIntBE()) + else None + + def isFlagSet(mask: Int): Boolean = (flags & mask) != 0 + val header = + FrameHeader(Opcode.forCode(op.toByte), + mask, + length, + fin = isFlagSet(FIN_MASK), + rsv1 = isFlagSet(RSV1_MASK), + rsv2 = isFlagSet(RSV2_MASK), + rsv3 = isFlagSet(RSV3_MASK)) + + val takeNow = (header.length min reader.remainingSize).toInt + val thisFrameData = reader.take(takeNow) + + val nextState = + if (thisFrameData.length == length) ReadFrameHeader + else new ReadData(length - thisFrameData.length) + + (FrameStart(header, thisFrameData.compact), nextState) } } -} -object FrameEventParser { + class ReadData(_remaining: Long) extends Step { + var remaining = _remaining + override def parse(reader: ByteReader): (FrameEvent, Step) = + if (reader.remainingSize < remaining) { + remaining -= reader.remainingSize + (FrameData(reader.takeAll(), lastPart = false), this) + } else { + (FrameData(reader.take(remaining.toInt), lastPart = true), ReadFrameHeader) + } + } + } + def mask(bytes: ByteString, _mask: Option[Int]): ByteString = _mask match { case Some(m) ⇒ mask(bytes, m)._1 case None ⇒ bytes } + def mask(bytes: ByteString, mask: Int): (ByteString, Int) = { @tailrec def rec(bytes: Array[Byte], offset: Int, mask: Int): Int = if (offset >= bytes.length) mask diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameHandler.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameHandler.scala index c6557ca0c5..ceb56c25fe 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameHandler.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameHandler.scala @@ -19,10 +19,10 @@ import scala.util.control.NonFatal */ private[http] object FrameHandler { - def create(server: Boolean): Flow[FrameEvent, Output, Unit] = - Flow[FrameEvent].transform(() ⇒ new HandlerStage(server)) + def create(server: Boolean): Flow[FrameEventOrError, Output, Unit] = + Flow[FrameEventOrError].transform(() ⇒ new HandlerStage(server)) - private class HandlerStage(server: Boolean) extends 
StatefulStage[FrameEvent, Output] { + private class HandlerStage(server: Boolean) extends StatefulStage[FrameEventOrError, Output] { type Ctx = Context[Output] def initial: State = Idle @@ -79,11 +79,6 @@ private[http] object FrameHandler { } } - private object Closed extends State { - def onPush(elem: FrameEvent, ctx: Ctx): SyncDirective = - ctx.pull() // ignore - } - private def becomeAndHandleWith(newState: State, part: FrameEvent)(implicit ctx: Ctx): SyncDirective = { become(newState) current.onPush(part, ctx) @@ -132,7 +127,7 @@ private[http] object FrameHandler { } private object CloseAfterPeerClosed extends State { - def onPush(elem: FrameEvent, ctx: Context[Output]): SyncDirective = + def onPush(elem: FrameEventOrError, ctx: Context[Output]): SyncDirective = elem match { case FrameStart(FrameHeader(Opcode.Close, _, length, _, _, _, _), data) ⇒ become(WaitForPeerTcpClose) @@ -141,7 +136,7 @@ private[http] object FrameHandler { } } private object WaitForPeerTcpClose extends State { - def onPush(elem: FrameEvent, ctx: Context[Output]): SyncDirective = + def onPush(elem: FrameEventOrError, ctx: Context[Output]): SyncDirective = ctx.pull() // ignore } @@ -168,10 +163,11 @@ private[http] object FrameHandler { def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective - def onPush(part: FrameEvent, ctx: Ctx): SyncDirective = + def onPush(part: FrameEventOrError, ctx: Ctx): SyncDirective = part match { case data: FrameData ⇒ handleFrameData(data)(ctx) case start: FrameStart ⇒ handleFrameStart(start)(ctx) + case FrameError(ex) ⇒ ctx.fail(ex) } } } diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Masking.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Masking.scala index 9b618b5b10..65a571eb3e 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Masking.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Masking.scala @@ -15,13 +15,19 @@ import akka.stream.stage.{ SyncDirective, Context, StatefulStage } * INTERNAL API */ private[http] object Masking { - def apply(serverSide: Boolean, maskRandom: () ⇒ Random): BidiFlow[ /* net in */ FrameEvent, /* app out */ FrameEvent, /* app in */ FrameEvent, /* net out */ FrameEvent, Unit] = + def apply(serverSide: Boolean, maskRandom: () ⇒ Random): BidiFlow[ /* net in */ FrameEvent, /* app out */ FrameEventOrError, /* app in */ FrameEvent, /* net out */ FrameEvent, Unit] = BidiFlow.fromFlowsMat(unmaskIf(serverSide), maskIf(!serverSide, maskRandom))(Keep.none) def maskIf(condition: Boolean, maskRandom: () ⇒ Random): Flow[FrameEvent, FrameEvent, Unit] = - if (condition) Flow[FrameEvent].transform(() ⇒ new Masking(maskRandom())) // new random per materialization + if (condition) + Flow[FrameEvent] + .transform(() ⇒ new Masking(maskRandom())) // new random per materialization + .map { + case f: FrameEvent ⇒ f + case FrameError(ex) ⇒ throw ex + } else Flow[FrameEvent] - def unmaskIf(condition: Boolean): Flow[FrameEvent, FrameEvent, Unit] = + def unmaskIf(condition: Boolean): Flow[FrameEvent, FrameEventOrError, Unit] = if (condition) Flow[FrameEvent].transform(() ⇒ new Unmasking()) else Flow[FrameEvent] @@ -41,19 +47,25 @@ private[http] object Masking { } /** Implements both masking and unmasking which is mostly symmetric (because of XOR) */ - private abstract class Masker extends StatefulStage[FrameEvent, FrameEvent] { + private abstract class Masker extends StatefulStage[FrameEvent, FrameEventOrError] { def 
extractMask(header: FrameHeader): Int def setNewMask(header: FrameHeader, mask: Int): FrameHeader def initial: State = Idle - object Idle extends State { - def onPush(part: FrameEvent, ctx: Context[FrameEvent]): SyncDirective = + private object Idle extends State { + def onPush(part: FrameEvent, ctx: Context[FrameEventOrError]): SyncDirective = part match { case start @ FrameStart(header, data) ⇒ - val mask = extractMask(header) - become(new Running(mask)) - current.onPush(start.copy(header = setNewMask(header, mask)), ctx) + try { + val mask = extractMask(header) + become(new Running(mask)) + current.onPush(start.copy(header = setNewMask(header, mask)), ctx) + } catch { + case p: ProtocolException ⇒ + become(Done) + ctx.push(FrameError(p)) + } case _: FrameData ⇒ ctx.fail(new IllegalStateException("unexpected FrameData (need FrameStart first)")) } @@ -61,7 +73,7 @@ private[http] object Masking { private class Running(initialMask: Int) extends State { var mask = initialMask - def onPush(part: FrameEvent, ctx: Context[FrameEvent]): SyncDirective = { + def onPush(part: FrameEvent, ctx: Context[FrameEventOrError]): SyncDirective = { if (part.lastPart) become(Idle) val (masked, newMask) = FrameEventParser.mask(part.data, mask) @@ -69,5 +81,8 @@ private[http] object Masking { ctx.push(part.withData(data = masked)) } } + private object Done extends State { + def onPush(part: FrameEvent, ctx: Context[FrameEventOrError]): SyncDirective = ctx.pull() + } } } diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Websocket.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Websocket.scala index ca313bfbb8..0ac57b7a5d 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Websocket.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Websocket.scala @@ -40,12 +40,12 @@ private[http] object Websocket { /** The lowest layer that implements the binary protocol */ def framing: BidiFlow[ByteString, FrameEvent, FrameEvent, ByteString, Unit] = BidiFlow.fromFlowsMat( - Flow[ByteString].transform(() ⇒ new FrameEventParser), + Flow[ByteString].via(FrameEventParser), Flow[FrameEvent].transform(() ⇒ new FrameEventRenderer))(Keep.none) .named("ws-framing") /** The layer that handles masking using the rules defined in the specification */ - def masking(serverSide: Boolean, maskingRandomFactory: () ⇒ Random): BidiFlow[FrameEvent, FrameEvent, FrameEvent, FrameEvent, Unit] = + def masking(serverSide: Boolean, maskingRandomFactory: () ⇒ Random): BidiFlow[FrameEvent, FrameEventOrError, FrameEvent, FrameEvent, Unit] = Masking(serverSide, maskingRandomFactory) .named("ws-masking") @@ -55,7 +55,7 @@ private[http] object Websocket { */ def frameHandling(serverSide: Boolean = true, closeTimeout: FiniteDuration, - log: LoggingAdapter): BidiFlow[FrameEvent, FrameHandler.Output, FrameOutHandler.Input, FrameStart, Unit] = + log: LoggingAdapter): BidiFlow[FrameEventOrError, FrameHandler.Output, FrameOutHandler.Input, FrameStart, Unit] = BidiFlow.fromFlowsMat( FrameHandler.create(server = serverSide), FrameOutHandler.create(serverSide, closeTimeout, log))(Keep.none) @@ -156,7 +156,7 @@ private[http] object Websocket { val shape = new FanOutShape2(in, bypass, user) - def createLogic = new GraphStageLogic(shape) { + def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) { setHandler(in, new InHandler { override def onPush(): Unit = { @@ -187,7 +187,7 @@ private[http] object Websocket { val shape = new FanInShape3(bypass, user, tick, out) - def createLogic = 
new GraphStageLogic(shape) { + def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) { passAlong(bypass, out, doFinish = true, doFail = true) passAlong(user, out, doFinish = false, doFail = false) @@ -210,7 +210,7 @@ private[http] object Websocket { val shape = new FlowShape(in, out) - def createLogic = new GraphStageLogic(shape) { + def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) { setHandler(out, new OutHandler { override def onPull(): Unit = pull(in) }) diff --git a/akka-http-core/src/main/scala/akka/http/impl/util/ByteReader.scala b/akka-http-core/src/main/scala/akka/http/impl/util/ByteReader.scala index d46ab53ca0..3a03d3dea5 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/util/ByteReader.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/util/ByteReader.scala @@ -21,17 +21,17 @@ private[akka] class ByteReader(input: ByteString) { def currentOffset: Int = off def remainingData: ByteString = input.drop(off) - def fromStartToHere: ByteString = input.take(currentOffset) + def fromStartToHere: ByteString = input.take(off) def readByte(): Int = if (off < input.length) { val x = input(off) off += 1 - x.toInt & 0xFF + x & 0xFF } else throw NeedMoreData def readShortLE(): Int = readByte() | (readByte() << 8) def readIntLE(): Int = readShortLE() | (readShortLE() << 16) - def readLongLE(): Long = (readIntBE() & 0xffffffffL) | ((readIntLE() & 0xffffffffL) << 32) + def readLongLE(): Long = (readIntLE() & 0xffffffffL) | ((readIntLE() & 0xffffffffL) << 32) def readShortBE(): Int = (readByte() << 8) | readByte() def readIntBE(): Int = (readShortBE() << 16) | readShortBE() diff --git a/akka-http-core/src/main/scala/akka/http/impl/util/package.scala b/akka-http-core/src/main/scala/akka/http/impl/util/package.scala index e86da57357..dc10cfea7f 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/util/package.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/util/package.scala @@ -128,7 +128,7 @@ package object util { package util { import akka.http.scaladsl.model.{ ContentType, HttpEntity } - import akka.stream.{ Outlet, Inlet, FlowShape } + import akka.stream.{ Attributes, Outlet, Inlet, FlowShape } import scala.concurrent.duration.FiniteDuration private[http] class ToStrict(timeout: FiniteDuration, contentType: ContentType) @@ -138,7 +138,7 @@ package util { val out = Outlet[HttpEntity.Strict]("out") override val shape = FlowShape(in, out) - override def createLogic: GraphStageLogic = new GraphStageLogic(shape) { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) { var bytes = ByteString.newBuilder private var emptyStream = false diff --git a/akka-http-core/src/main/scala/akka/http/scaladsl/model/HttpEntity.scala b/akka-http-core/src/main/scala/akka/http/scaladsl/model/HttpEntity.scala index f29829b103..e8db1d88c2 100644 --- a/akka-http-core/src/main/scala/akka/http/scaladsl/model/HttpEntity.scala +++ b/akka-http-core/src/main/scala/akka/http/scaladsl/model/HttpEntity.scala @@ -75,7 +75,7 @@ sealed trait HttpEntity extends jm.HttpEntity { def withContentType(contentType: ContentType): HttpEntity /** Java API */ - def getDataBytes: stream.javadsl.Source[ByteString, _] = stream.javadsl.Source.fromGraph(dataBytes) + def getDataBytes: stream.javadsl.Source[ByteString, AnyRef] = stream.javadsl.Source.fromGraph(dataBytes.asInstanceOf[Source[ByteString, AnyRef]]) /** Java API */ def getContentLengthOption: japi.Option[JLong] = @@ -147,9 +147,11 @@ object HttpEntity { 
def apply(contentType: ContentType, data: Source[ByteString, Any]): Chunked = Chunked.fromData(contentType, data) - def apply(contentType: ContentType, file: File, chunkSize: Int = SynchronousFileSource.DefaultChunkSize): UniversalEntity = { + def apply(contentType: ContentType, file: File, chunkSize: Int = -1): UniversalEntity = { val fileLength = file.length - if (fileLength > 0) Default(contentType, fileLength, SynchronousFileSource(file, chunkSize)) + if (fileLength > 0) + Default(contentType, fileLength, + if (chunkSize > 0) SynchronousFileSource(file, chunkSize) else SynchronousFileSource(file)) else empty(contentType) } diff --git a/akka-http-core/src/main/scala/akka/http/scaladsl/model/Multipart.scala b/akka-http-core/src/main/scala/akka/http/scaladsl/model/Multipart.scala index 346f90cd22..530546a8e6 100644 --- a/akka-http-core/src/main/scala/akka/http/scaladsl/model/Multipart.scala +++ b/akka-http-core/src/main/scala/akka/http/scaladsl/model/Multipart.scala @@ -194,7 +194,7 @@ object Multipart { * To create an instance with several parts or for multiple files, use * ``FormData(BodyPart.fromFile("field1", ...), BodyPart.fromFile("field2", ...)`` */ - def fromFile(name: String, contentType: ContentType, file: File, chunkSize: Int = SynchronousFileSource.DefaultChunkSize): FormData = + def fromFile(name: String, contentType: ContentType, file: File, chunkSize: Int = -1): FormData = FormData(Source.single(BodyPart.fromFile(name, contentType, file, chunkSize))) /** @@ -237,7 +237,7 @@ object Multipart { /** * Creates a BodyPart backed by a File that will be streamed using a SynchronousFileSource. */ - def fromFile(name: String, contentType: ContentType, file: File, chunkSize: Int = SynchronousFileSource.DefaultChunkSize): BodyPart = + def fromFile(name: String, contentType: ContentType, file: File, chunkSize: Int = -1): BodyPart = BodyPart(name, HttpEntity(contentType, file, chunkSize), Map("filename" -> file.getName)) def unapply(value: BodyPart): Option[(String, BodyPartEntity, Map[String, String], immutable.Seq[HttpHeader])] = diff --git a/akka-http-core/src/test/resources/reference.conf b/akka-http-core/src/test/resources/reference.conf index ab48718a51..1660c0e30d 100644 --- a/akka-http-core/src/test/resources/reference.conf +++ b/akka-http-core/src/test/resources/reference.conf @@ -1,4 +1,5 @@ akka { + loggers = ["akka.testkit.TestEventListener"] actor { serialize-creators = on serialize-messages = on diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/client/ConnectionPoolSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/client/ConnectionPoolSpec.scala index 8a76f3fb63..e287b92add 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/client/ConnectionPoolSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/client/ConnectionPoolSpec.scala @@ -57,8 +57,8 @@ class ConnectionPoolSpec extends AkkaSpec(""" requestIn.sendNext(HttpRequest(uri = "/") -> 42) - acceptIncomingConnection() responseOutSub.request(1) + acceptIncomingConnection() val (Success(response), 42) = responseOut.expectNext() response.headers should contain(RawHeader("Req-Host", s"$serverHostName:$serverPort")) } @@ -116,8 +116,8 @@ class ConnectionPoolSpec extends AkkaSpec(""" val (requestIn, responseOut, responseOutSub, hcp) = cachedHostConnectionPool[Int]() requestIn.sendNext(HttpRequest(uri = "/a") -> 42) - acceptIncomingConnection() responseOutSub.request(1) + acceptIncomingConnection() val (Success(response1), 42) = responseOut.expectNext() connNr(response1) 
shouldEqual 1 @@ -222,8 +222,8 @@ class ConnectionPoolSpec extends AkkaSpec(""" requestIn.sendNext(HttpRequest(uri = "/") -> 42) - acceptIncomingConnection() responseOutSub.request(1) + acceptIncomingConnection() val (Success(_), 42) = responseOut.expectNext() } } @@ -346,7 +346,7 @@ class ConnectionPoolSpec extends AkkaSpec(""" def flowTestBench[T, Mat](poolFlow: Flow[(HttpRequest, T), (Try[HttpResponse], T), Mat]) = { val requestIn = TestPublisher.probe[(HttpRequest, T)]() val responseOut = TestSubscriber.manualProbe[(Try[HttpResponse], T)] - val hcp = Source(requestIn).viaMat(poolFlow)(Keep.right).toMat(Sink(responseOut))(Keep.left).run() + val hcp = Source(requestIn).viaMat(poolFlow)(Keep.right).to(Sink(responseOut)).run() val responseOutSub = responseOut.expectSubscription() (requestIn, responseOut, responseOutSub, hcp) } diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/client/LowLevelOutgoingConnectionSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/client/LowLevelOutgoingConnectionSpec.scala index 4d56fcca37..9150e1540d 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/client/LowLevelOutgoingConnectionSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/client/LowLevelOutgoingConnectionSpec.scala @@ -62,7 +62,7 @@ class LowLevelOutgoingConnectionSpec extends AkkaSpec("akka.loggers = []\n akka. | |""") val sub = probe.expectSubscription() - sub.expectRequest(4) + sub.expectRequest() sub.sendNext(ByteString("ABC")) expectWireData("ABC") sub.sendNext(ByteString("DEF")) @@ -228,7 +228,7 @@ class LowLevelOutgoingConnectionSpec extends AkkaSpec("akka.loggers = []\n akka. | |""") val sub = probe.expectSubscription() - sub.expectRequest(4) + sub.expectRequest() sub.sendNext(ByteString("ABC")) expectWireData("ABC") sub.sendNext(ByteString("DEF")) @@ -254,7 +254,7 @@ class LowLevelOutgoingConnectionSpec extends AkkaSpec("akka.loggers = []\n akka. 
| |""") val sub = probe.expectSubscription() - sub.expectRequest(4) + sub.expectRequest() sub.sendNext(ByteString("ABC")) expectWireData("ABC") sub.sendNext(ByteString("DEF")) diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/client/TlsEndpointVerificationSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/client/TlsEndpointVerificationSpec.scala index c0a5aab683..71ccc11977 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/client/TlsEndpointVerificationSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/client/TlsEndpointVerificationSpec.scala @@ -6,20 +6,18 @@ package akka.http.impl.engine.client import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.concurrent.ScalaFutures - import akka.stream.ActorMaterializer import akka.stream.io._ import akka.stream.scaladsl._ import akka.stream.testkit.AkkaSpec - import akka.http.impl.util._ - import akka.http.scaladsl.{ HttpsContext, Http } import akka.http.scaladsl.model.{ StatusCodes, HttpResponse, HttpRequest } import akka.http.scaladsl.model.headers.Host import org.scalatest.time.{ Span, Seconds } - import scala.concurrent.Future +import akka.testkit.EventFilter +import javax.net.ssl.SSLException class TlsEndpointVerificationSpec extends AkkaSpec(""" #akka.loggers = [] @@ -30,7 +28,7 @@ class TlsEndpointVerificationSpec extends AkkaSpec(""" val timeout = Timeout(Span(3, Seconds)) "The client implementation" should { - "not accept certificates signed by unknown CA" in { + "not accept certificates signed by unknown CA" in EventFilter[SSLException](occurrences = 1).intercept { val pipe = pipeline(Http().defaultClientHttpsContext, hostname = "akka.example.org") // default context doesn't include custom CA whenReady(pipe(HttpRequest(uri = "https://akka.example.org/")).failed, timeout) { e ⇒ diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/FramingSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/FramingSpec.scala index f5e5aac239..095eef7e51 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/FramingSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/FramingSpec.scala @@ -307,13 +307,12 @@ class FramingSpec extends FreeSpec with Matchers with WithMaterializerSpec { } private def parseToEvents(bytes: Seq[ByteString]): immutable.Seq[FrameEvent] = - Source(bytes.toVector).transform(newParser).runFold(Vector.empty[FrameEvent])(_ :+ _) + Source(bytes.toVector).via(FrameEventParser).runFold(Vector.empty[FrameEvent])(_ :+ _) .awaitResult(1.second) private def renderToByteString(events: immutable.Seq[FrameEvent]): ByteString = Source(events).transform(newRenderer).runFold(ByteString.empty)(_ ++ _) .awaitResult(1.second) - protected def newParser(): Stage[ByteString, FrameEvent] = new FrameEventParser protected def newRenderer(): Stage[FrameEvent, ByteString] = new FrameEventRenderer import scala.language.implicitConversions diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala index 49b4024b93..90583939af 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala @@ -13,6 +13,7 @@ import akka.stream.testkit._ import akka.util.ByteString import akka.http.scaladsl.model.ws._ import Protocol.Opcode +import akka.testkit.EventFilter class MessageSpec extends FreeSpec with Matchers with 
WithMaterializerSpec { import WSTestUtils._ @@ -595,15 +596,18 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { netIn.expectCancellation() } "if user handler fails" in new ServerTestSetup { - messageOut.sendError(new RuntimeException("Oops, user handler failed!")) - expectCloseCodeOnNetwork(Protocol.CloseCodes.UnexpectedCondition) + EventFilter[RuntimeException](message = "Oops, user handler failed!", occurrences = 1) + .intercept { + messageOut.sendError(new RuntimeException("Oops, user handler failed!")) + expectCloseCodeOnNetwork(Protocol.CloseCodes.UnexpectedCondition) - expectNoNetworkData() // wait for peer to close regularly - pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true)) + expectNoNetworkData() // wait for peer to close regularly + pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true)) - expectComplete(messageIn) - netOut.expectComplete() - netIn.expectCancellation() + expectComplete(messageIn) + netOut.expectComplete() + netIn.expectCancellation() + } } "if peer closes with invalid close frame" - { "close code outside of the valid range" in new ServerTestSetup { @@ -828,7 +832,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { Source(netIn) .via(printEvent("netIn")) - .transform(() ⇒ new FrameEventParser) + .via(FrameEventParser) .via(Websocket.stack(serverSide, maskingRandomFactory = Randoms.SecureRandomInstances, closeTimeout = closeTimeout, log = system.log).join(messageHandler)) .via(printEvent("frameRendererIn")) .transform(() ⇒ new FrameEventRenderer) diff --git a/akka-http-core/src/test/scala/akka/http/scaladsl/ClientServerSpec.scala b/akka-http-core/src/test/scala/akka/http/scaladsl/ClientServerSpec.scala index 244fded983..85929003fc 100644 --- a/akka-http-core/src/test/scala/akka/http/scaladsl/ClientServerSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/scaladsl/ClientServerSpec.scala @@ -24,8 +24,8 @@ import akka.http.scaladsl.model.HttpMethods._ import akka.http.scaladsl.model._ import akka.http.scaladsl.model.headers._ import akka.http.impl.util._ - import scala.util.{ Failure, Try, Success } +import java.net.BindException class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll { val testConf: Config = ConfigFactory.parseString(""" @@ -56,7 +56,7 @@ class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll { sub.cancel() } - "report failure if bind fails" in { + "report failure if bind fails" in EventFilter[BindException](occurrences = 2).intercept { val (_, hostname, port) = TestUtils.temporaryServerHostnameAndPort() val binding = Http().bind(hostname, port) val probe1 = TestSubscriber.manualProbe[Http.IncomingConnection]() @@ -170,7 +170,8 @@ class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll { // waiting for the timeout to happen on the client intercept[StreamTcpException] { Await.result(clientsResponseFuture, 2.second) } - (System.nanoTime() - serverReceivedRequestAtNanos).millis should be >= theIdleTimeout + val fudge = 100.millis + ((System.nanoTime() - serverReceivedRequestAtNanos).nanos + fudge) should be >= theIdleTimeout } "log materialization errors in `bindAndHandle`" which { diff --git a/akka-http-tests/src/test/resources/reference.conf b/akka-http-tests/src/test/resources/reference.conf index ab48718a51..1660c0e30d 100644 --- a/akka-http-tests/src/test/resources/reference.conf +++ b/akka-http-tests/src/test/resources/reference.conf @@ -1,4 +1,5 @@ akka { + loggers = 
["akka.testkit.TestEventListener"] actor { serialize-creators = on serialize-messages = on diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/BasicRouteSpecs.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/BasicRouteSpecs.scala index c732923631..68bcc0afd9 100644 --- a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/BasicRouteSpecs.scala +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/BasicRouteSpecs.scala @@ -7,6 +7,7 @@ package akka.http.scaladsl.server import akka.http.scaladsl.model import model.HttpMethods._ import model.StatusCodes +import akka.testkit.EventFilter class BasicRouteSpecs extends RoutingSpec { @@ -134,7 +135,7 @@ class BasicRouteSpecs extends RoutingSpec { case object MyException extends RuntimeException "Route sealing" should { - "catch route execution exceptions" in { + "catch route execution exceptions" in EventFilter[MyException.type](occurrences = 1).intercept { Get("/abc") ~> Route.seal { get { ctx ⇒ throw MyException @@ -143,7 +144,7 @@ class BasicRouteSpecs extends RoutingSpec { status shouldEqual StatusCodes.InternalServerError } } - "catch route building exceptions" in { + "catch route building exceptions" in EventFilter[MyException.type](occurrences = 1).intercept { Get("/abc") ~> Route.seal { get { throw MyException @@ -152,7 +153,7 @@ class BasicRouteSpecs extends RoutingSpec { status shouldEqual StatusCodes.InternalServerError } } - "convert all rejections to responses" in { + "convert all rejections to responses" in EventFilter[RuntimeException](occurrences = 1).intercept { object MyRejection extends Rejection Get("/abc") ~> Route.seal { get { diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/ExecutionDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/ExecutionDirectivesSpec.scala index 86e5cd6eb7..f5891ecdcb 100644 --- a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/ExecutionDirectivesSpec.scala +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/ExecutionDirectivesSpec.scala @@ -7,8 +7,8 @@ package directives import akka.http.scaladsl.model.{ MediaTypes, MediaRanges, StatusCodes } import akka.http.scaladsl.model.headers._ - import scala.concurrent.Future +import akka.testkit.EventFilter class ExecutionDirectivesSpec extends RoutingSpec { object MyException extends RuntimeException @@ -51,7 +51,7 @@ class ExecutionDirectivesSpec extends RoutingSpec { } } } - "not interfere with alternative routes" in { + "not interfere with alternative routes" in EventFilter[MyException.type](occurrences = 1).intercept { Get("/abc") ~> get { handleExceptions(handler)(reject) ~ { ctx ⇒ @@ -62,22 +62,22 @@ class ExecutionDirectivesSpec extends RoutingSpec { responseAs[String] shouldEqual "There was an internal server error." } } - "not handle other exceptions" in { + "not handle other exceptions" in EventFilter[RuntimeException](occurrences = 1, message = "buh").intercept { Get("/abc") ~> get { handleExceptions(handler) { - throw new RuntimeException + throw new RuntimeException("buh") } } ~> check { status shouldEqual StatusCodes.InternalServerError responseAs[String] shouldEqual "There was an internal server error." 
} } - "always fall back to a default content type" in { + "always fall back to a default content type" in EventFilter[RuntimeException](occurrences = 2, message = "buh2").intercept { Get("/abc") ~> Accept(MediaTypes.`application/json`) ~> get { handleExceptions(handler) { - throw new RuntimeException + throw new RuntimeException("buh2") } } ~> check { status shouldEqual StatusCodes.InternalServerError @@ -87,7 +87,7 @@ class ExecutionDirectivesSpec extends RoutingSpec { Get("/abc") ~> Accept(MediaTypes.`text/xml`, MediaRanges.`*/*`.withQValue(0f)) ~> get { handleExceptions(handler) { - throw new RuntimeException + throw new RuntimeException("buh2") } } ~> check { status shouldEqual StatusCodes.InternalServerError diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/FutureDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/FutureDirectivesSpec.scala index fc0c6e7130..2dc05f770d 100644 --- a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/FutureDirectivesSpec.scala +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/FutureDirectivesSpec.scala @@ -6,8 +6,8 @@ package akka.http.scaladsl.server package directives import akka.http.scaladsl.model.StatusCodes - import scala.concurrent.Future +import akka.testkit.EventFilter class FutureDirectivesSpec extends RoutingSpec { @@ -56,7 +56,7 @@ class FutureDirectivesSpec extends RoutingSpec { responseAs[String] shouldEqual "yes" } } - "propagate the exception in the failure case" in { + "propagate the exception in the failure case" in EventFilter[Exception](occurrences = 1, message = "XXX").intercept { Get() ~> onSuccess(Future.failed(TestException)) { echoComplete } ~> check { status shouldEqual StatusCodes.InternalServerError } @@ -67,7 +67,7 @@ class FutureDirectivesSpec extends RoutingSpec { responseAs[String] shouldEqual s"Oops. akka.http.scaladsl.server.directives.FutureDirectivesSpec$$TestException: EX when ok" } } - "catch an exception in the failure case" in { + "catch an exception in the failure case" in EventFilter[Exception](occurrences = 1, message = "XXX").intercept { Get() ~> onSuccess(Future.failed(TestException)) { throwTestException("EX when ") } ~> check { status shouldEqual StatusCodes.InternalServerError responseAs[String] shouldEqual "There was an internal server error." 
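The recurring edit across these specs pairs the new loggers = ["akka.testkit.TestEventListener"] test configuration with EventFilter(...).intercept wrappers: the filter asserts that the expected error is logged exactly the stated number of times and keeps it out of the test output. A minimal, self-contained sketch of the pattern follows; the system name "demo" and the message "boom" are illustrative, not taken from this patch:

import akka.actor.ActorSystem
import akka.testkit.EventFilter
import com.typesafe.config.ConfigFactory

object EventFilterSketch extends App {
  // TestEventListener must be installed for EventFilter to observe log
  // events, which is what the reference.conf changes in this patch enable.
  implicit val system: ActorSystem = ActorSystem("demo", ConfigFactory.parseString(
    """akka.loggers = ["akka.testkit.TestEventListener"]"""))

  // Succeeds only if the wrapped code logs exactly one error whose message
  // matches "boom"; the matched event is suppressed from the test output.
  EventFilter[RuntimeException](message = "boom", occurrences = 1).intercept {
    system.log.error(new RuntimeException("boom"), "boom")
  }

  system.terminate()
}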
diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/RouteDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/RouteDirectivesSpec.scala index 09f5bc03f1..7de787fb30 100644 --- a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/RouteDirectivesSpec.scala +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/RouteDirectivesSpec.scala @@ -8,7 +8,6 @@ import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport import akka.http.scaladsl.marshallers.xml.ScalaXmlSupport import akka.stream.scaladsl.Sink import org.scalatest.FreeSpec - import scala.concurrent.{ Future, Promise } import akka.http.scaladsl.marshallers.xml.ScalaXmlSupport._ import akka.http.scaladsl.marshalling._ @@ -18,8 +17,8 @@ import akka.http.impl.util._ import headers._ import StatusCodes._ import MediaTypes._ - import scala.xml.NodeSeq +import akka.testkit.EventFilter class RouteDirectivesSpec extends FreeSpec with GenericRoutingSpec { @@ -47,7 +46,7 @@ class RouteDirectivesSpec extends FreeSpec with GenericRoutingSpec { "for successful futures and marshalling" in { Get() ~> complete(Promise.successful("yes").future) ~> check { responseAs[String] shouldEqual "yes" } } - "for failed futures and marshalling" in { + "for failed futures and marshalling" in EventFilter[RuntimeException](occurrences = 1).intercept { object TestException extends RuntimeException Get() ~> complete(Promise.failed[String](TestException).future) ~> check { diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/SecurityDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/SecurityDirectivesSpec.scala index 8d3a5f8f4f..a6ae7308e0 100644 --- a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/SecurityDirectivesSpec.scala +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/SecurityDirectivesSpec.scala @@ -9,6 +9,7 @@ import scala.concurrent.Future import akka.http.scaladsl.model._ import akka.http.scaladsl.model.headers._ import akka.http.scaladsl.server.AuthenticationFailedRejection.{ CredentialsRejected, CredentialsMissing } +import akka.testkit.EventFilter class SecurityDirectivesSpec extends RoutingSpec { val dontBasicAuth = authenticateBasicAsync[String]("MyRealm", _ ⇒ Future.successful(None)) @@ -60,11 +61,13 @@ class SecurityDirectivesSpec extends RoutingSpec { } "properly handle exceptions thrown in its inner route" in { object TestException extends RuntimeException - Get() ~> Authorization(BasicHttpCredentials("Alice", "")) ~> { - Route.seal { - doBasicAuth { _ ⇒ throw TestException } - } - } ~> check { status shouldEqual StatusCodes.InternalServerError } + EventFilter[TestException.type](occurrences = 1).intercept { + Get() ~> Authorization(BasicHttpCredentials("Alice", "")) ~> { + Route.seal { + doBasicAuth { _ ⇒ throw TestException } + } + } ~> check { status shouldEqual StatusCodes.InternalServerError } + } } } "bearer token authentication" should { @@ -108,11 +111,13 @@ class SecurityDirectivesSpec extends RoutingSpec { } "properly handle exceptions thrown in its inner route" in { object TestException extends RuntimeException - Get() ~> Authorization(OAuth2BearerToken("myToken")) ~> { - Route.seal { - doOAuth2Auth { _ ⇒ throw TestException } - } - } ~> check { status shouldEqual StatusCodes.InternalServerError } + EventFilter[TestException.type](occurrences = 1).intercept { + Get() ~> Authorization(OAuth2BearerToken("myToken")) ~> { + 
Route.seal { + doOAuth2Auth { _ ⇒ throw TestException } + } + } ~> check { status shouldEqual StatusCodes.InternalServerError } + } } } "authentication directives" should { diff --git a/akka-stream-tck/src/test/scala/akka/stream/tck/FusableProcessorTest.scala b/akka-stream-tck/src/test/scala/akka/stream/tck/FusableProcessorTest.scala index fe4fd1895f..717c04da22 100644 --- a/akka-stream-tck/src/test/scala/akka/stream/tck/FusableProcessorTest.scala +++ b/akka-stream-tck/src/test/scala/akka/stream/tck/FusableProcessorTest.scala @@ -3,11 +3,11 @@ */ package akka.stream.tck -import akka.stream.{ ActorMaterializer, ActorMaterializerSettings } +import akka.stream.impl.Stages +import akka.stream._ import akka.stream.impl.Stages.Identity import akka.stream.scaladsl.Flow import org.reactivestreams.Processor -import akka.stream.Attributes class FusableProcessorTest extends AkkaIdentityProcessorVerification[Int] { @@ -18,7 +18,7 @@ class FusableProcessorTest extends AkkaIdentityProcessorVerification[Int] { implicit val materializer = ActorMaterializer(settings)(system) // withAttributes "wraps" the underlying identity and protects it from automatic removal - Flow[Int].andThen[Int](Identity()).named("identity").toProcessor.run() + Flow[Int].via(Stages.identityGraph.asInstanceOf[Graph[FlowShape[Int, Int], Unit]]).named("identity").toProcessor.run() } override def createElement(element: Int): Int = element diff --git a/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala b/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala index e3b1d41545..94d39c368b 100644 --- a/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala +++ b/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala @@ -9,11 +9,12 @@ import akka.stream.impl.StreamLayout.Module import akka.stream.impl._ import akka.testkit.TestProbe import org.reactivestreams.{ Publisher, Subscriber, Subscription } - import scala.annotation.tailrec import scala.collection.immutable import scala.concurrent.duration._ import scala.language.existentials +import java.io.StringWriter +import java.io.PrintWriter /** * Provides factory methods for various Publishers. @@ -183,7 +184,16 @@ object TestSubscriber { final case class OnSubscribe(subscription: Subscription) extends SubscriberEvent final case class OnNext[I](element: I) extends SubscriberEvent final case object OnComplete extends SubscriberEvent - final case class OnError(cause: Throwable) extends SubscriberEvent + final case class OnError(cause: Throwable) extends SubscriberEvent { + override def toString: String = { + val str = new StringWriter + val out = new PrintWriter(str) + out.print("OnError(") + cause.printStackTrace(out) + out.print(")") + str.toString + } + } /** * Probe that implements [[org.reactivestreams.Subscriber]] interface. 
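The OnError override just above is purely a diagnostics aid: when a TestSubscriber expectation fails on an unexpected OnError, the event's toString now carries the failure's full stack trace rather than a one-line summary. The underlying JDK idiom, sketched in isolation (the helper name stackTraceAsString is illustrative):

import java.io.{ PrintWriter, StringWriter }

// Renders a Throwable's complete stack trace into a String, using the same
// StringWriter/PrintWriter idiom as the OnError.toString above.
def stackTraceAsString(cause: Throwable): String = {
  val str = new StringWriter
  val out = new PrintWriter(str)
  cause.printStackTrace(out)
  out.flush() // PrintWriter(Writer) writes through; the flush is defensive
  str.toString
}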
diff --git a/akka-stream-tests/src/test/java/akka/stream/javadsl/AttributesTest.java b/akka-stream-tests/src/test/java/akka/stream/javadsl/AttributesTest.java index 33332a6cb2..caaf023fef 100644 --- a/akka-stream-tests/src/test/java/akka/stream/javadsl/AttributesTest.java +++ b/akka-stream-tests/src/test/java/akka/stream/javadsl/AttributesTest.java @@ -12,7 +12,7 @@ import org.junit.Test; import akka.stream.Attributes; public class AttributesTest { - + final Attributes attributes = Attributes.name("a") .and(Attributes.name("b")) @@ -27,12 +27,12 @@ public class AttributesTest { Collections.singletonList(new Attributes.InputBuffer(1, 2)), attributes.getAttributeList(Attributes.InputBuffer.class)); } - + @Test public void mustGetAttributeByClass() { assertEquals( - new Attributes.Name("a"), + new Attributes.Name("b"), attributes.getAttribute(Attributes.Name.class, new Attributes.Name("default"))); } - + } diff --git a/akka-stream-tests/src/test/java/akka/stream/javadsl/FlowTest.java b/akka-stream-tests/src/test/java/akka/stream/javadsl/FlowTest.java index ebaa33bbba..db6c3d33e4 100644 --- a/akka-stream-tests/src/test/java/akka/stream/javadsl/FlowTest.java +++ b/akka-stream-tests/src/test/java/akka/stream/javadsl/FlowTest.java @@ -16,6 +16,7 @@ import akka.stream.testkit.AkkaSpec; import akka.stream.testkit.TestPublisher; import akka.testkit.JavaTestKit; import org.junit.ClassRule; +import org.junit.Ignore; import org.junit.Test; import org.reactivestreams.Publisher; import scala.concurrent.Await; @@ -179,7 +180,7 @@ public class FlowTest extends StreamTest { } - @Test + @Ignore("StatefulStage to be converted to GraphStage when Java Api is available (#18817)") @Test public void mustBeAbleToUseTransform() { final JavaTestKit probe = new JavaTestKit(system); final Iterable input = Arrays.asList(0, 1, 2, 3, 4, 5, 6, 7); @@ -204,15 +205,15 @@ public class FlowTest extends StreamTest { return emit(Arrays.asList(element, element).iterator(), ctx); } } - + }; } - + @Override public TerminationDirective onUpstreamFinish(Context ctx) { return terminationEmit(Collections.singletonList(sum).iterator(), ctx); } - + }; } }); @@ -355,12 +356,12 @@ public class FlowTest extends StreamTest { return new akka.japi.function.Creator>() { @Override public PushPullStage create() throws Exception { - return new PushPullStage() { + return new PushPullStage() { @Override public SyncDirective onPush(T element, Context ctx) { return ctx.push(element); } - + @Override public SyncDirective onPull(Context ctx) { return ctx.pull(); @@ -374,17 +375,17 @@ public class FlowTest extends StreamTest { public void mustBeAbleToUseMerge() throws Exception { final Flow f1 = Flow.of(String.class).transform(FlowTest.this. op()).named("f1"); - final Flow f2 = + final Flow f2 = Flow.of(String.class).transform(FlowTest.this. op()).named("f2"); @SuppressWarnings("unused") - final Flow f3 = + final Flow f3 = Flow.of(String.class).transform(FlowTest.this. 
op()).named("f3"); final Source in1 = Source.from(Arrays.asList("a", "b", "c")); final Source in2 = Source.from(Arrays.asList("d", "e", "f")); final Sink> publisher = Sink.publisher(); - + final Source source = Source.fromGraph( FlowGraph.create(new Function, SourceShape>() { @Override diff --git a/akka-stream-tests/src/test/java/akka/stream/javadsl/SourceTest.java b/akka-stream-tests/src/test/java/akka/stream/javadsl/SourceTest.java index 6058e72db0..1f70e19d3a 100644 --- a/akka-stream-tests/src/test/java/akka/stream/javadsl/SourceTest.java +++ b/akka-stream-tests/src/test/java/akka/stream/javadsl/SourceTest.java @@ -20,6 +20,7 @@ import akka.stream.testkit.AkkaSpec; import akka.stream.testkit.TestPublisher; import akka.testkit.JavaTestKit; import org.junit.ClassRule; +import org.junit.Ignore; import org.junit.Test; import scala.concurrent.Await; import scala.concurrent.Future; @@ -107,7 +108,7 @@ public class SourceTest extends StreamTest { probe.expectMsgEquals("()"); } - @Test + @Ignore("StatefulStage to be converted to GraphStage when Java Api is available (#18817)") @Test public void mustBeAbleToUseTransform() { final JavaTestKit probe = new JavaTestKit(system); final Iterable input = Arrays.asList(0, 1, 2, 3, 4, 5, 6, 7); @@ -415,7 +416,7 @@ public class SourceTest extends StreamTest { @Test public void mustProduceTicks() throws Exception { final JavaTestKit probe = new JavaTestKit(system); - Source tickSource = Source.from(FiniteDuration.create(1, TimeUnit.SECONDS), + Source tickSource = Source.from(FiniteDuration.create(1, TimeUnit.SECONDS), FiniteDuration.create(500, TimeUnit.MILLISECONDS), "tick"); Cancellable cancellable = tickSource.to(Sink.foreach(new Procedure() { public void apply(String elem) { @@ -457,7 +458,7 @@ public class SourceTest extends StreamTest { String result = Await.result(future2, probe.dilated(FiniteDuration.create(3, TimeUnit.SECONDS))); assertEquals("A", result); } - + @Test public void mustRepeat() throws Exception { final Future> f = Source.repeat(42).grouped(10000).runWith(Sink.> head(), materializer); @@ -465,7 +466,7 @@ public class SourceTest extends StreamTest { assertEquals(result.size(), 10000); for (Integer i: result) assertEquals(i, (Integer) 42); } - + @Test public void mustBeAbleToUseActorRefSource() throws Exception { final JavaTestKit probe = new JavaTestKit(system); diff --git a/akka-stream-tests/src/test/resources/reference.conf b/akka-stream-tests/src/test/resources/application.conf similarity index 62% rename from akka-stream-tests/src/test/resources/reference.conf rename to akka-stream-tests/src/test/resources/application.conf index ab48718a51..1660c0e30d 100644 --- a/akka-stream-tests/src/test/resources/reference.conf +++ b/akka-stream-tests/src/test/resources/application.conf @@ -1,4 +1,5 @@ akka { + loggers = ["akka.testkit.TestEventListener"] actor { serialize-creators = on serialize-messages = on diff --git a/akka-stream-tests/src/test/scala/akka/stream/DslConsistencySpec.scala b/akka-stream-tests/src/test/scala/akka/stream/DslConsistencySpec.scala index 79688dd703..cfca10dbb8 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/DslConsistencySpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/DslConsistencySpec.scala @@ -26,20 +26,16 @@ class DslConsistencySpec extends WordSpec with Matchers { Set("equals", "hashCode", "notify", "notifyAll", "wait", "toString", "getClass") ++ Set("productArity", "canEqual", "productPrefix", "copy", "productIterator", "productElement") ++ Set("create", "apply", "ops", "appendJava", 
"andThen", "andThenMat", "isIdentity", "withAttributes", "transformMaterializing") ++ - Set("asScala", "asJava") + Set("asScala", "asJava", "deprecatedAndThen", "deprecatedAndThenMat") val allowMissing: Map[Class[_], Set[String]] = Map( sFlowClass -> Set("of"), sSourceClass -> Set("adapt", "from"), sSinkClass -> Set("adapt"), - // TODO timerTransform is to be removed or replaced. See https://github.com/akka/akka/issues/16393 - jFlowClass -> Set("timerTransform"), - jSourceClass -> Set("timerTransform"), jSinkClass -> Set(), - sRunnableGraphClass -> Set("builder"), - jRunnableGraphClass → Set("graph", "cyclesAllowed")) + sRunnableGraphClass -> Set("builder")) def materializing(m: Method): Boolean = m.getParameterTypes.contains(classOf[ActorMaterializer]) diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/ActorInterpreterLifecycleSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/ActorInterpreterLifecycleSpec.scala deleted file mode 100644 index e0ec326af2..0000000000 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/ActorInterpreterLifecycleSpec.scala +++ /dev/null @@ -1,120 +0,0 @@ -/** - * Copyright (C) 2015 Typesafe Inc. - */ -package akka.stream.impl - -import akka.stream.Supervision._ -import akka.stream._ -import akka.stream.impl.fusing.{ ActorInterpreter, InterpreterLifecycleSpecKit } -import akka.stream.stage.Stage -import akka.stream.testkit.Utils.TE -import akka.stream.testkit.{ AkkaSpec, _ } - -import scala.concurrent.duration._ - -class ActorInterpreterLifecycleSpec extends AkkaSpec with InterpreterLifecycleSpecKit { - - implicit val mat = ActorMaterializer() - - class Setup(ops: List[Stage[_, _]] = List(fusing.Map({ x: Any ⇒ x }, stoppingDecider))) { - val up = TestPublisher.manualProbe[Int]() - val down = TestSubscriber.manualProbe[Int] - private val props = ActorInterpreter.props(mat.settings, ops, mat).withDispatcher("akka.test.stream-dispatcher") - val actor = system.actorOf(props) - val processor = ActorProcessorFactory[Int, Int](actor) - } - - "An ActorInterpreter" must { - - "call preStart in order on stages" in new Setup(List( - PreStartAndPostStopIdentity(onStart = _ ⇒ testActor ! "start-a"), - PreStartAndPostStopIdentity(onStart = _ ⇒ testActor ! "start-b"), - PreStartAndPostStopIdentity(onStart = _ ⇒ testActor ! "start-c"))) { - processor.subscribe(down) - val sub = down.expectSubscription() - sub.cancel() - up.subscribe(processor) - val upsub = up.expectSubscription() - upsub.expectCancellation() - - expectMsg("start-a") - expectMsg("start-b") - expectMsg("start-c") - } - - "call postStart in order on stages - when upstream completes" in new Setup(List( - PreStartAndPostStopIdentity(onStop = () ⇒ testActor ! "stop-a"), - PreStartAndPostStopIdentity(onStop = () ⇒ testActor ! "stop-b"), - PreStartAndPostStopIdentity(onStop = () ⇒ testActor ! "stop-c"))) { - processor.subscribe(down) - val sub = down.expectSubscription() - up.subscribe(processor) - val upsub = up.expectSubscription() - upsub.sendComplete() - down.expectComplete() - - expectMsg("stop-a") - expectMsg("stop-b") - expectMsg("stop-c") - } - - "call postStart in order on stages - when downstream cancels" in new Setup(List( - PreStartAndPostStopIdentity(onStop = () ⇒ testActor ! "stop-a"), - PreStartAndPostStopIdentity(onStop = () ⇒ testActor ! "stop-b"), - PreStartAndPostStopIdentity(onStop = () ⇒ testActor ! 
"stop-c"))) { - processor.subscribe(down) - val sub = down.expectSubscription() - sub.cancel() - up.subscribe(processor) - val upsub = up.expectSubscription() - upsub.expectCancellation() - - expectMsg("stop-c") - expectMsg("stop-b") - expectMsg("stop-a") - } - - "onError downstream when preStart fails" in new Setup(List( - PreStartFailer(() ⇒ throw TE("Boom!")))) { - processor.subscribe(down) - val sub = down.expectSubscription() - up.subscribe(processor) - val upsub = up.expectSubscription() - down.expectError(TE("Boom!")) - } - - "onError only once even with Supervision.restart" in new Setup(List( - PreStartFailer(() ⇒ throw TE("Boom!")))) { - processor.subscribe(down) - val sub = down.expectSubscription() - up.subscribe(processor) - val upsub = up.expectSubscription() - down.expectError(TE("Boom!")) - down.expectNoMsg(1.second) - } - - "onError downstream when preStart fails with 'most downstream' failure, when multiple stages fail" in new Setup(List( - PreStartFailer(() ⇒ throw TE("Boom 1!")), - PreStartFailer(() ⇒ throw TE("Boom 2!")), - PreStartFailer(() ⇒ throw TE("Boom 3!")))) { - processor.subscribe(down) - val sub = down.expectSubscription() - up.subscribe(processor) - val upsub = up.expectSubscription() - down.expectError(TE("Boom 3!")) - down.expectNoMsg(300.millis) - } - - "continue with stream shutdown when postStop fails" in new Setup(List( - PostStopFailer(() ⇒ throw TE("Boom!")))) { - processor.subscribe(down) - val sub = down.expectSubscription() - up.subscribe(processor) - val upsub = up.expectSubscription() - upsub.sendComplete() - down.expectComplete() - } - - } - -} diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/ActorInterpreterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/ActorInterpreterSpec.scala deleted file mode 100644 index ea3afe083b..0000000000 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/ActorInterpreterSpec.scala +++ /dev/null @@ -1,190 +0,0 @@ -/** - * Copyright (C) 2015 Typesafe Inc. - */ -package akka.stream.impl - -import akka.stream.Supervision._ -import akka.stream.impl.ReactiveStreamsCompliance.SpecViolation -import akka.stream.testkit.AkkaSpec -import akka.stream._ -import akka.stream.scaladsl._ -import akka.stream.testkit._ -import akka.stream.impl.fusing.ActorInterpreter -import akka.stream.stage.Stage -import akka.stream.stage.PushPullStage -import akka.stream.stage.Context -import akka.testkit.TestLatch -import org.reactivestreams.{ Subscription, Subscriber, Publisher } -import scala.concurrent.Await -import scala.concurrent.duration._ - -class ActorInterpreterSpec extends AkkaSpec { - import FlowGraph.Implicits._ - - implicit val mat = ActorMaterializer() - - class Setup(ops: List[Stage[_, _]] = List(fusing.Map({ x: Any ⇒ x }, stoppingDecider))) { - val up = TestPublisher.manualProbe[Int]() - val down = TestSubscriber.manualProbe[Int] - private val props = ActorInterpreter.props(mat.settings, ops, mat).withDispatcher("akka.test.stream-dispatcher") - val actor = system.actorOf(props) - val processor = ActorProcessorFactory[Int, Int](actor) - } - - "An ActorInterpreter" must { - - "pass along early cancellation" in new Setup { - processor.subscribe(down) - val sub = down.expectSubscription() - sub.cancel() - up.subscribe(processor) - val upsub = up.expectSubscription() - upsub.expectCancellation() - } - - "heed cancellation signal while large demand is outstanding" in { - val latch = TestLatch() - val infinite = new PushPullStage[Int, Int] { - override def onPush(elem: Int, ctx: Context[Int]) = ??? 
- override def onPull(ctx: Context[Int]) = { - Await.ready(latch, 5.seconds) - ctx.push(42) - } - } - val N = system.settings.config.getInt("akka.stream.materializer.output-burst-limit") - - new Setup(infinite :: Nil) { - processor.subscribe(down) - val sub = down.expectSubscription() - up.subscribe(processor) - val upsub = up.expectSubscription() - sub.request(100000000) - sub.cancel() - watch(actor) - latch.countDown() - for (i ← 1 to N) withClue(s"iteration $i: ") { - try down.expectNext(42) catch { case e: Throwable ⇒ fail(e) } - } - // now cancellation request is processed - down.expectNoMsg(500.millis) - upsub.expectCancellation() - expectTerminated(actor) - } - } - - "heed upstream failure while large demand is outstanding" in { - val latch = TestLatch() - val infinite = new PushPullStage[Int, Int] { - override def onPush(elem: Int, ctx: Context[Int]) = ??? - override def onPull(ctx: Context[Int]) = { - Await.ready(latch, 5.seconds) - ctx.push(42) - } - } - val N = system.settings.config.getInt("akka.stream.materializer.output-burst-limit") - - new Setup(infinite :: Nil) { - processor.subscribe(down) - val sub = down.expectSubscription() - up.subscribe(processor) - val upsub = up.expectSubscription() - sub.request(100000000) - val ex = new Exception("FAIL!") - upsub.sendError(ex) - latch.countDown() - for (i ← 1 to N) withClue(s"iteration $i: ") { - try down.expectNext(42) catch { case e: Throwable ⇒ fail(e) } - } - down.expectError(ex) - } - } - - "hold back upstream completion while large demand is outstanding" in { - val latch = TestLatch() - val N = 3 * system.settings.config.getInt("akka.stream.materializer.output-burst-limit") - val infinite = new PushPullStage[Int, Int] { - private var remaining = N - override def onPush(elem: Int, ctx: Context[Int]) = ??? - override def onPull(ctx: Context[Int]) = { - Await.ready(latch, 5.seconds) - remaining -= 1 - if (remaining >= 0) ctx.push(42) - else ctx.finish() - } - override def onUpstreamFinish(ctx: Context[Int]) = { - if (remaining > 0) ctx.absorbTermination() - else ctx.finish() - } - } - - new Setup(infinite :: Nil) { - processor.subscribe(down) - val sub = down.expectSubscription() - up.subscribe(processor) - val upsub = up.expectSubscription() - sub.request(100000000) - upsub.sendComplete() - latch.countDown() - for (i ← 1 to N) withClue(s"iteration $i: ") { - try down.expectNext(42) catch { case e: Throwable ⇒ fail(e) } - } - down.expectComplete() - } - } - - "satisfy large demand" in largeDemand(0) - "satisfy larger demand" in largeDemand(1) - - "handle spec violations" in { - a[AbruptTerminationException] should be thrownBy { - Await.result( - Source(new Publisher[String] { - def subscribe(s: Subscriber[_ >: String]) = { - s.onSubscribe(new Subscription { - def cancel() = () - def request(n: Long) = sys.error("test error") - }) - } - }).runFold("")(_ + _), - 3.seconds) - } - } - - "handle failed stage factories" in { - a[RuntimeException] should be thrownBy - Await.result( - Source.empty[Int].transform(() ⇒ sys.error("test error")).runWith(Sink.head), - 3.seconds) - } - - def largeDemand(extra: Int): Unit = { - val N = 3 * system.settings.config.getInt("akka.stream.materializer.output-burst-limit") - val large = new PushPullStage[Int, Int] { - private var remaining = N - override def onPush(elem: Int, ctx: Context[Int]) = ??? 
- override def onPull(ctx: Context[Int]) = { - remaining -= 1 - if (remaining >= 0) ctx.push(42) - else ctx.finish() - } - } - - new Setup(large :: Nil) { - processor.subscribe(down) - val sub = down.expectSubscription() - up.subscribe(processor) - val upsub = up.expectSubscription() - sub.request(100000000) - watch(actor) - for (i ← 1 to N) withClue(s"iteration $i: ") { - try down.expectNext(42) catch { case e: Throwable ⇒ fail(e) } - } - down.expectComplete() - upsub.expectCancellation() - expectTerminated(actor) - } - } - - } - -} diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/GraphStageLogicSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/GraphStageLogicSpec.scala new file mode 100644 index 0000000000..80bb3c6cd7 --- /dev/null +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/GraphStageLogicSpec.scala @@ -0,0 +1,177 @@ +/** + * Copyright (C) 2015 Typesafe Inc. + */ +package akka.stream.impl + +import akka.stream.testkit.AkkaSpec +import akka.stream._ +import akka.stream.scaladsl._ +import akka.stream.stage._ +import akka.stream.testkit.Utils.assertAllStagesStopped +import akka.stream.testkit.scaladsl.TestSink +import akka.stream.impl.fusing._ +import akka.stream.impl.fusing.GraphInterpreter._ +import org.scalactic.ConversionCheckedTripleEquals +import scala.concurrent.duration.Duration + +class GraphStageLogicSpec extends GraphInterpreterSpecKit with ConversionCheckedTripleEquals { + + implicit val mat = ActorMaterializer() + + object emit1234 extends GraphStage[FlowShape[Int, Int]] { + val in = Inlet[Int]("in") + val out = Outlet[Int]("out") + override val shape = FlowShape(in, out) + override def createLogic(attr: Attributes) = new GraphStageLogic(shape) { + setHandler(in, eagerTerminateInput) + setHandler(out, eagerTerminateOutput) + override def preStart(): Unit = { + emit(out, 1, () ⇒ emit(out, 2)) + emit(out, 3, () ⇒ emit(out, 4)) + } + } + } + + object emit5678 extends GraphStage[FlowShape[Int, Int]] { + val in = Inlet[Int]("in") + val out = Outlet[Int]("out") + override val shape = FlowShape(in, out) + override def createLogic(attr: Attributes) = new GraphStageLogic(shape) { + setHandler(in, new InHandler { + override def onPush(): Unit = push(out, grab(in)) + override def onUpstreamFinish(): Unit = { + emit(out, 5, () ⇒ emit(out, 6)) + emit(out, 7, () ⇒ emit(out, 8)) + completeStage() + } + }) + setHandler(out, new OutHandler { + override def onPull(): Unit = pull(in) + }) + } + } + + object passThrough extends GraphStage[FlowShape[Int, Int]] { + val in = Inlet[Int]("in") + val out = Outlet[Int]("out") + override val shape = FlowShape(in, out) + override def createLogic(attr: Attributes) = new GraphStageLogic(shape) { + setHandler(in, new InHandler { + override def onPush(): Unit = push(out, grab(in)) + override def onUpstreamFinish(): Unit = complete(out) + }) + setHandler(out, new OutHandler { + override def onPull(): Unit = pull(in) + }) + } + } + + class FusedGraph[S <: Shape](ga: GraphAssembly, s: S, a: Attributes = Attributes.none) extends Graph[S, Unit] { + override def shape = s + override val module = GraphModule(ga, s, a) + override def withAttributes(attr: Attributes) = new FusedGraph(ga, s, attr) + } + + "A GraphStageLogic" must { + + "emit all things before completing" in assertAllStagesStopped { + + Source.empty.via(emit1234.named("testStage")).runWith(TestSink.probe) + .request(5) + .expectNext(1, 2, 3, 4) + .expectComplete() + + } + + "emit all things before completing with two fused stages" in assertAllStagesStopped { + new 
Builder { + val g = new FusedGraph( + builder(emit1234, emit5678) + .connect(Upstream, emit1234.in) + .connect(emit1234.out, emit5678.in) + .connect(emit5678.out, Downstream) + .buildAssembly(), + FlowShape(emit1234.in, emit5678.out)) + + Source.empty.via(g).runWith(TestSink.probe) + .request(9) + .expectNextN(1 to 8) + .expectComplete() + } + } + + "emit all things before completing with three fused stages" in assertAllStagesStopped { + new Builder { + val g = new FusedGraph( + builder(emit1234, passThrough, emit5678) + .connect(Upstream, emit1234.in) + .connect(emit1234.out, passThrough.in) + .connect(passThrough.out, emit5678.in) + .connect(emit5678.out, Downstream) + .buildAssembly(), + FlowShape(emit1234.in, emit5678.out)) + + Source.empty.via(g).runWith(TestSink.probe) + .request(9) + .expectNextN(1 to 8) + .expectComplete() + } + } + + "invoke lifecycle hooks in the right order" in assertAllStagesStopped { + val g = new GraphStage[FlowShape[Int, Int]] { + val in = Inlet[Int]("in") + val out = Outlet[Int]("out") + override val shape = FlowShape(in, out) + override def createLogic(attr: Attributes) = new GraphStageLogic(shape) { + setHandler(in, eagerTerminateInput) + setHandler(out, new OutHandler { + override def onPull(): Unit = { + completeStage() + testActor ! "pulled" + } + }) + override def preStart(): Unit = testActor ! "preStart" + override def postStop(): Unit = testActor ! "postStop" + } + } + Source.single(1).via(g).runWith(Sink.ignore) + expectMsg("preStart") + expectMsg("pulled") + expectMsg("postStop") + } + + "not double-terminate a single stage" in new Builder { + object g extends GraphStage[FlowShape[Int, Int]] { + val in = Inlet[Int]("in") + val out = Outlet[Int]("out") + override val shape = FlowShape(in, out) + override def createLogic(attr: Attributes) = new GraphStageLogic(shape) { + setHandler(in, eagerTerminateInput) + setHandler(out, eagerTerminateOutput) + override def postStop(): Unit = testActor ! 
"postStop2" + } + } + + builder(g, passThrough) + .connect(Upstream, g.in) + .connect(g.out, passThrough.in) + .connect(passThrough.out, Downstream) + .init() + + interpreter.complete(0) + interpreter.cancel(1) + interpreter.execute(2) + + expectMsg("postStop2") + expectNoMsg(Duration.Zero) + + interpreter.isCompleted should ===(false) + interpreter.isSuspended should ===(false) + interpreter.isStageCompleted(interpreter.logics(0)) should ===(true) + interpreter.isStageCompleted(interpreter.logics(1)) should ===(false) + } + + } + +} \ No newline at end of file diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/StreamLayoutSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/StreamLayoutSpec.scala index bd9b5b25d8..29695f242b 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/StreamLayoutSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/StreamLayoutSpec.scala @@ -171,7 +171,7 @@ class StreamLayoutSpec extends AkkaSpec { override def onNext(t: Any): Unit = () } - class FlatTestMaterializer(_module: Module) extends MaterializerSession(_module) { + class FlatTestMaterializer(_module: Module) extends MaterializerSession(_module, Attributes()) { var publishers = Vector.empty[TestPublisher] var subscribers = Vector.empty[TestSubscriber] diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/ActorGraphInterpreterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/ActorGraphInterpreterSpec.scala index cc4ba30309..0eadb8a826 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/ActorGraphInterpreterSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/ActorGraphInterpreterSpec.scala @@ -47,7 +47,7 @@ class ActorGraphInterpreterSpec extends AkkaSpec { val out2 = Outlet[Int]("out2") val shape = BidiShape(in1, out1, in2, out2) - override def createLogic: GraphStageLogic = new GraphStageLogic(shape) { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { setHandler(in1, new InHandler { override def onPush(): Unit = push(out1, grab(in1)) override def onUpstreamFinish(): Unit = complete(out1) @@ -88,7 +88,7 @@ class ActorGraphInterpreterSpec extends AkkaSpec { val out2 = Outlet[Int]("out2") val shape = BidiShape(in1, out1, in2, out2) - override def createLogic: GraphStageLogic = new GraphStageLogic(shape) { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { setHandler(in1, new InHandler { override def onPush(): Unit = push(out1, grab(in1)) @@ -134,7 +134,7 @@ class ActorGraphInterpreterSpec extends AkkaSpec { val out2 = Outlet[Int]("out2") val shape = BidiShape(in1, out1, in2, out2) - override def createLogic: GraphStageLogic = new GraphStageLogic(shape) { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { setHandler(in1, new InHandler { override def onPush(): Unit = push(out1, grab(in1)) @@ -183,7 +183,7 @@ class ActorGraphInterpreterSpec extends AkkaSpec { val out2 = Outlet[Int]("out2") val shape = BidiShape(in1, out1, in2, out2) - override def createLogic: GraphStageLogic = new GraphStageLogic(shape) { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { setHandler(in1, new InHandler { override def onPush(): Unit = push(out2, grab(in1)) diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpec.scala 
b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpec.scala index ed0184b37c..56ea126be9 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpec.scala @@ -3,6 +3,7 @@ */ package akka.stream.impl.fusing +import akka.stream.Attributes import akka.stream.testkit.AkkaSpec import akka.stream.scaladsl.{ Merge, Broadcast, Balance, Zip } import GraphInterpreter._ @@ -45,6 +46,7 @@ class GraphInterpreterSpec extends GraphInterpreterSpecKit { // Constructing an assembly by hand and resolving ambiguities val assembly = new GraphAssembly( stages = Array(identity, identity), + originalAttributes = Array(Attributes.none, Attributes.none), ins = Array(identity.in, identity.in, null), inOwners = Array(0, 1, -1), outs = Array(null, identity.out, identity.out), diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpecKit.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpecKit.scala index e89666735b..6428195ebb 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpecKit.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpecKit.scala @@ -5,35 +5,32 @@ package akka.stream.impl.fusing import akka.event.Logging import akka.stream._ -import akka.stream.impl.fusing.GraphInterpreter.{ Failed, GraphAssembly, DownstreamBoundaryStageLogic, UpstreamBoundaryStageLogic } -import akka.stream.stage.{ InHandler, OutHandler, GraphStage, GraphStageLogic } +import akka.stream.impl.fusing.GraphInterpreter.{ DownstreamBoundaryStageLogic, Failed, GraphAssembly, UpstreamBoundaryStageLogic } +import akka.stream.stage.AbstractStage.PushPullGraphStage +import akka.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler, _ } import akka.stream.testkit.AkkaSpec import akka.stream.testkit.Utils.TE +import akka.stream.impl.fusing.GraphInterpreter.GraphAssembly trait GraphInterpreterSpecKit extends AkkaSpec { - sealed trait TestEvent { - def source: GraphStageLogic - } - - case class OnComplete(source: GraphStageLogic) extends TestEvent - case class Cancel(source: GraphStageLogic) extends TestEvent - case class OnError(source: GraphStageLogic, cause: Throwable) extends TestEvent - case class OnNext(source: GraphStageLogic, elem: Any) extends TestEvent - case class RequestOne(source: GraphStageLogic) extends TestEvent - case class RequestAnother(source: GraphStageLogic) extends TestEvent - - case class PreStart(source: GraphStageLogic) extends TestEvent - case class PostStop(source: GraphStageLogic) extends TestEvent - - abstract class TestSetup { - protected var lastEvent: Set[TestEvent] = Set.empty + abstract class Builder { private var _interpreter: GraphInterpreter = _ protected def interpreter: GraphInterpreter = _interpreter def stepAll(): Unit = interpreter.execute(eventLimit = Int.MaxValue) def step(): Unit = interpreter.execute(eventLimit = 1) + object Upstream extends UpstreamBoundaryStageLogic[Int] { + override val out = Outlet[Int]("up") + out.id = 0 + } + + object Downstream extends DownstreamBoundaryStageLogic[Int] { + override val in = Inlet[Int]("down") + in.id = 0 + } + class AssemblyBuilder(stages: Seq[GraphStage[_ <: Shape]]) { var upstreams = Vector.empty[(UpstreamBoundaryStageLogic[_], Inlet[_])] var downstreams = Vector.empty[(Outlet[_], DownstreamBoundaryStageLogic[_])] @@ -54,20 +51,25 @@ trait GraphInterpreterSpecKit extends 
AkkaSpec { this } - def init(): Unit = { + def buildAssembly(): GraphAssembly = { val ins = upstreams.map(_._2) ++ connections.map(_._2) val outs = connections.map(_._1) ++ downstreams.map(_._1) val inOwners = ins.map { in ⇒ stages.indexWhere(_.shape.inlets.contains(in)) } val outOwners = outs.map { out ⇒ stages.indexWhere(_.shape.outlets.contains(out)) } - val assembly = new GraphAssembly( + new GraphAssembly( stages.toArray, + Array.fill(stages.size)(Attributes.none), (ins ++ Vector.fill(downstreams.size)(null)).toArray, (inOwners ++ Vector.fill(downstreams.size)(-1)).toArray, (Vector.fill(upstreams.size)(null) ++ outs).toArray, (Vector.fill(upstreams.size)(-1) ++ outOwners).toArray) + } - val (inHandlers, outHandlers, logics, _) = assembly.materialize() + def init(): Unit = { + val assembly = buildAssembly() + + val (inHandlers, outHandlers, logics, _) = assembly.materialize(Attributes.none) _interpreter = new GraphInterpreter(assembly, NoMaterializer, Logging(system, classOf[TestSetup]), inHandlers, outHandlers, logics, (_, _, _) ⇒ ()) for ((upstream, i) ← upstreams.zipWithIndex) { @@ -83,11 +85,29 @@ trait GraphInterpreterSpecKit extends AkkaSpec { } def manualInit(assembly: GraphAssembly): Unit = { - val (inHandlers, outHandlers, logics, _) = assembly.materialize() + val (inHandlers, outHandlers, logics, _) = assembly.materialize(Attributes.none) _interpreter = new GraphInterpreter(assembly, NoMaterializer, Logging(system, classOf[TestSetup]), inHandlers, outHandlers, logics, (_, _, _) ⇒ ()) } - def builder(stages: GraphStage[_ <: Shape]*): AssemblyBuilder = new AssemblyBuilder(stages.toSeq) + def builder(stages: GraphStage[_ <: Shape]*): AssemblyBuilder = new AssemblyBuilder(stages) + } + + abstract class TestSetup extends Builder { + + sealed trait TestEvent { + def source: GraphStageLogic + } + + case class OnComplete(source: GraphStageLogic) extends TestEvent + case class Cancel(source: GraphStageLogic) extends TestEvent + case class OnError(source: GraphStageLogic, cause: Throwable) extends TestEvent + case class OnNext(source: GraphStageLogic, elem: Any) extends TestEvent + case class RequestOne(source: GraphStageLogic) extends TestEvent + case class RequestAnother(source: GraphStageLogic) extends TestEvent + case class PreStart(source: GraphStageLogic) extends TestEvent + case class PostStop(source: GraphStageLogic) extends TestEvent + + protected var lastEvent: Set[TestEvent] = Set.empty def lastEvents(): Set[TestEvent] = { val result = lastEvent @@ -158,7 +178,7 @@ trait GraphInterpreterSpecKit extends AkkaSpec { // Modified onPush that does not grab() automatically the element. This accesses some internals. 
override def onPush(): Unit = { - val internalEvent = interpreter.connectionSlots(inToConn(in.id)) + val internalEvent = interpreter.connectionSlots(portToConn(in.id)) internalEvent match { case Failed(_, elem) ⇒ lastEvent += OnNext(DownstreamPortProbe.this, elem) @@ -173,6 +193,7 @@ trait GraphInterpreterSpecKit extends AkkaSpec { private val assembly = new GraphAssembly( stages = Array.empty, + originalAttributes = Array.empty, ins = Array(null), inOwners = Array(-1), outs = Array(null), @@ -233,7 +254,7 @@ trait GraphInterpreterSpecKit extends AkkaSpec { private val sandwitchStage = new GraphStage[FlowShape[Int, Int]] { override def shape = stageshape - override def createLogic: GraphStageLogic = stage + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = stage } class UpstreamPortProbe[T] extends UpstreamProbe[T]("upstreamPort") { @@ -253,4 +274,133 @@ trait GraphInterpreterSpecKit extends AkkaSpec { .init() } + abstract class OneBoundedSetup[T](ops: Array[GraphStageWithMaterializedValue[Shape, Any]]) extends Builder { + + def this(ops: Iterable[Stage[_, _]]) = { + this(ops.map { op ⇒ + new PushPullGraphStage[Any, Any, Any]( + (_) ⇒ op.asInstanceOf[Stage[Any, Any]], + Attributes.none) + }.toArray.asInstanceOf[Array[GraphStageWithMaterializedValue[Shape, Any]]]) + } + + val upstream = new UpstreamOneBoundedProbe[T] + val downstream = new DownstreamOneBoundedPortProbe[T] + var lastEvent = Set.empty[TestEvent] + + sealed trait TestEvent + + case object OnComplete extends TestEvent + case object Cancel extends TestEvent + case class OnError(cause: Throwable) extends TestEvent + case class OnNext(elem: Any) extends TestEvent + case object RequestOne extends TestEvent + case object RequestAnother extends TestEvent + + private def run() = interpreter.execute(Int.MaxValue) + + private def initialize(): Unit = { + import GraphInterpreter.Boundary + + var i = 0 + val attributes = Array.fill[Attributes](ops.length)(Attributes.none) + val ins = Array.ofDim[Inlet[_]](ops.length + 1) + val inOwners = Array.ofDim[Int](ops.length + 1) + val outs = Array.ofDim[Outlet[_]](ops.length + 1) + val outOwners = Array.ofDim[Int](ops.length + 1) + + ins(ops.length) = null + inOwners(ops.length) = Boundary + outs(0) = null + outOwners(0) = Boundary + + while (i < ops.length) { + val stage = ops(i).asInstanceOf[PushPullGraphStage[_, _, _]] + ins(i) = stage.shape.inlet + inOwners(i) = i + outs(i + 1) = stage.shape.outlet + outOwners(i + 1) = i + i += 1 + } + + manualInit(new GraphAssembly(ops, attributes, ins, inOwners, outs, outOwners)) + interpreter.attachUpstreamBoundary(0, upstream) + interpreter.attachDownstreamBoundary(ops.length, downstream) + + interpreter.init() + + } + + initialize() + run() // Detached stages need the prefetch + + def lastEvents(): Set[TestEvent] = { + val events = lastEvent + lastEvent = Set.empty + events + } + + class UpstreamOneBoundedProbe[T] extends UpstreamBoundaryStageLogic[T] { + val out = Outlet[T]("out") + out.id = 0 + + setHandler(out, new OutHandler { + override def onPull(): Unit = { + if (lastEvent.contains(RequestOne)) lastEvent += RequestAnother + else lastEvent += RequestOne + } + + override def onDownstreamFinish(): Unit = lastEvent += Cancel + }) + + def onNext(elem: T): Unit = { + push(out, elem) + run() + } + def onComplete(): Unit = { + complete(out) + run() + } + + def onNextAndComplete(elem: T): Unit = { + push(out, elem) + complete(out) + run() + } + + def onError(ex: Throwable): Unit = { + fail(out, ex) + run() + } + } + + class 
DownstreamOneBoundedPortProbe[T] extends DownstreamBoundaryStageLogic[T] { + val in = Inlet[T]("in") + in.id = 0 + + setHandler(in, new InHandler { + + // Modified onPush that does not grab() automatically the element. This accesses some internals. + override def onPush(): Unit = { + lastEvent += OnNext(grab(in)) + } + + override def onUpstreamFinish() = lastEvent += OnComplete + override def onUpstreamFailure(ex: Throwable) = lastEvent += OnError(ex) + }) + + def requestOne(): Unit = { + pull(in) + run() + } + + def cancel(): Unit = { + cancel(in) + run() + } + + } + + } + } diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala index 4d0b32decd..00c9ff1ba1 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala @@ -4,17 +4,24 @@ package akka.stream.impl.fusing import akka.stream.stage._ +import akka.stream.testkit.AkkaSpec import akka.testkit.EventFilter import scala.util.control.NoStackTrace import akka.stream.Supervision -class InterpreterSpec extends InterpreterSpecKit { +class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { import Supervision.stoppingDecider + /* + * These tests were written for the previous version of the interpreter, the so-called OneBoundedInterpreter. + * These stages are now properly emulated by the GraphInterpreter and many of the edge cases were relevant to + * the execution model of the old one. Still, these tests are very valuable, so please do not remove. + */ + "Interpreter" must { - "implement map correctly" in new TestSetup(Seq(Map((x: Int) ⇒ x + 1, stoppingDecider))) { + "implement map correctly" in new OneBoundedSetup[Int](Seq(Map((x: Int) ⇒ x + 1, stoppingDecider))) { lastEvents() should be(Set.empty) downstream.requestOne() @@ -33,7 +40,7 @@ class InterpreterSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnComplete)) } - "implement chain of maps correctly" in new TestSetup(Seq( + "implement chain of maps correctly" in new OneBoundedSetup[Int](Seq( Map((x: Int) ⇒ x + 1, stoppingDecider), Map((x: Int) ⇒ x * 2, stoppingDecider), Map((x: Int) ⇒ x + 1, stoppingDecider))) { @@ -56,7 +63,7 @@ class InterpreterSpec extends InterpreterSpecKit { lastEvents() should be(Set(Cancel)) } - "work with only boundary ops" in new TestSetup(Seq.empty) { + "work with only boundary ops" in new OneBoundedSetup[Int](Seq.empty) { lastEvents() should be(Set.empty) downstream.requestOne() @@ -69,7 +76,7 @@ class InterpreterSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnComplete)) } - "implement one-to-many many-to-one chain correctly" in new TestSetup(Seq( + "implement one-to-many many-to-one chain correctly" in new OneBoundedSetup[Int](Seq( Doubler(), Filter((x: Int) ⇒ x != 0, stoppingDecider))) { @@ -94,7 +101,7 @@ class InterpreterSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnComplete)) } - "implement many-to-one one-to-many chain correctly" in new TestSetup(Seq( + "implement many-to-one one-to-many chain correctly" in new OneBoundedSetup[Int](Seq( Filter((x: Int) ⇒ x != 0, stoppingDecider), Doubler())) { @@ -119,7 +126,7 @@ class InterpreterSpec extends InterpreterSpecKit { lastEvents() should be(Set(Cancel)) } - "implement take" in new TestSetup(Seq(Take(2))) { + "implement take" in new OneBoundedSetup[Int](Seq(Take(2))) { lastEvents() should be(Set.empty) @@ -136,7 +143,7
@@ class InterpreterSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnNext(1), Cancel, OnComplete)) } - "implement take inside a chain" in new TestSetup(Seq( + "implement take inside a chain" in new OneBoundedSetup[Int](Seq( Filter((x: Int) ⇒ x != 0, stoppingDecider), Take(2), Map((x: Int) ⇒ x + 1, stoppingDecider))) { @@ -159,7 +166,7 @@ class InterpreterSpec extends InterpreterSpecKit { lastEvents() should be(Set(Cancel, OnComplete, OnNext(3))) } - "implement fold" in new TestSetup(Seq(Fold(0, (agg: Int, x: Int) ⇒ agg + x, stoppingDecider))) { + "implement fold" in new OneBoundedSetup[Int](Seq(Fold(0, (agg: Int, x: Int) ⇒ agg + x, stoppingDecider))) { lastEvents() should be(Set.empty) downstream.requestOne() @@ -178,7 +185,7 @@ class InterpreterSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnNext(3), OnComplete)) } - "implement fold with proper cancel" in new TestSetup(Seq(Fold(0, (agg: Int, x: Int) ⇒ agg + x, stoppingDecider))) { + "implement fold with proper cancel" in new OneBoundedSetup[Int](Seq(Fold(0, (agg: Int, x: Int) ⇒ agg + x, stoppingDecider))) { lastEvents() should be(Set.empty) @@ -198,7 +205,7 @@ class InterpreterSpec extends InterpreterSpecKit { lastEvents() should be(Set(Cancel)) } - "work if fold completes while not in a push position" in new TestSetup(Seq(Fold(0, (agg: Int, x: Int) ⇒ agg + x, stoppingDecider))) { + "work if fold completes while not in a push position" in new OneBoundedSetup[Int](Seq(Fold(0, (agg: Int, x: Int) ⇒ agg + x, stoppingDecider))) { lastEvents() should be(Set.empty) @@ -209,7 +216,7 @@ class InterpreterSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnComplete, OnNext(0))) } - "implement grouped" in new TestSetup(Seq(Grouped(3))) { + "implement grouped" in new OneBoundedSetup[Int](Seq(Grouped(3))) { lastEvents() should be(Set.empty) downstream.requestOne() @@ -234,7 +241,7 @@ class InterpreterSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnNext(Vector(3)), OnComplete)) } - "implement conflate" in new TestSetup(Seq(Conflate( + "implement conflate" in new OneBoundedSetup[Int](Seq(Conflate( (in: Int) ⇒ in, (agg: Int, x: Int) ⇒ agg + x, stoppingDecider))) { @@ -266,7 +273,7 @@ class InterpreterSpec extends InterpreterSpecKit { lastEvents() should be(Set(Cancel)) } - "implement expand" in new TestSetup(Seq(Expand( + "implement expand" in new OneBoundedSetup[Int](Seq(Expand( (in: Int) ⇒ in, (agg: Int) ⇒ (agg, agg)))) { @@ -294,7 +301,7 @@ class InterpreterSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnComplete)) } - "work with conflate-conflate" in new TestSetup(Seq( + "work with conflate-conflate" in new OneBoundedSetup[Int](Seq( Conflate( (in: Int) ⇒ in, (agg: Int, x: Int) ⇒ agg + x, @@ -332,7 +339,7 @@ class InterpreterSpec extends InterpreterSpecKit { } - "work with expand-expand" in new TestSetup(Seq( + "work with expand-expand" in new OneBoundedSetup[Int](Seq( Expand( (in: Int) ⇒ in, (agg: Int) ⇒ (agg, agg + 1)), @@ -369,7 +376,7 @@ class InterpreterSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnComplete, OnNext(12))) } - "implement conflate-expand" in new TestSetup(Seq( + "implement conflate-expand" in new OneBoundedSetup[Int](Seq( Conflate( (in: Int) ⇒ in, (agg: Int, x: Int) ⇒ agg + x, @@ -405,12 +412,7 @@ class InterpreterSpec extends InterpreterSpecKit { lastEvents() should be(Set(Cancel)) } - "implement expand-conflate" in { - pending - // Needs to detect divergent loops - } - - "implement doubler-conflate" in new TestSetup(Seq( + "implement 
doubler-conflate" in new OneBoundedSetup[Int](Seq( Doubler(), Conflate( (in: Int) ⇒ in, @@ -429,7 +431,8 @@ class InterpreterSpec extends InterpreterSpecKit { } - "work with jumpback table and completed elements" in new TestSetup(Seq( + // Note, the new interpreter has no jumpback table, still did not want to remove the test + "work with jumpback table and completed elements" in new OneBoundedSetup[Int](Seq( Map((x: Int) ⇒ x, stoppingDecider), Map((x: Int) ⇒ x, stoppingDecider), KeepGoing(), @@ -461,7 +464,7 @@ class InterpreterSpec extends InterpreterSpecKit { } - "work with pushAndFinish if upstream completes with pushAndFinish" in new TestSetup(Seq( + "work with pushAndFinish if upstream completes with pushAndFinish" in new OneBoundedSetup[Int](Seq( new PushFinishStage)) { lastEvents() should be(Set.empty) @@ -469,11 +472,11 @@ class InterpreterSpec extends InterpreterSpecKit { downstream.requestOne() lastEvents() should be(Set(RequestOne)) - upstream.onNextAndComplete("foo") - lastEvents() should be(Set(OnNext("foo"), OnComplete)) + upstream.onNextAndComplete(0) + lastEvents() should be(Set(OnNext(0), OnComplete)) } - "work with pushAndFinish if indirect upstream completes with pushAndFinish" in new TestSetup(Seq( + "work with pushAndFinish if indirect upstream completes with pushAndFinish" in new OneBoundedSetup[Int](Seq( Map((x: Any) ⇒ x, stoppingDecider), new PushFinishStage, Map((x: Any) ⇒ x, stoppingDecider))) { @@ -483,24 +486,24 @@ class InterpreterSpec extends InterpreterSpecKit { downstream.requestOne() lastEvents() should be(Set(RequestOne)) - upstream.onNextAndComplete("foo") - lastEvents() should be(Set(OnNext("foo"), OnComplete)) + upstream.onNextAndComplete(1) + lastEvents() should be(Set(OnNext(1), OnComplete)) } - "work with pushAndFinish if upstream completes with pushAndFinish and downstream immediately pulls" in new TestSetup(Seq( + "work with pushAndFinish if upstream completes with pushAndFinish and downstream immediately pulls" in new OneBoundedSetup[Int](Seq( new PushFinishStage, - Fold("", (x: String, y: String) ⇒ x + y, stoppingDecider))) { + Fold(0, (x: Int, y: Int) ⇒ x + y, stoppingDecider))) { lastEvents() should be(Set.empty) downstream.requestOne() lastEvents() should be(Set(RequestOne)) - upstream.onNextAndComplete("foo") - lastEvents() should be(Set(OnNext("foo"), OnComplete)) + upstream.onNextAndComplete(1) + lastEvents() should be(Set(OnNext(1), OnComplete)) } - "report error if pull is called while op is terminating" in new TestSetup(Seq(new PushPullStage[Any, Any] { + "report error if pull is called while op is terminating" in new OneBoundedSetup[Int](Seq(new PushPullStage[Any, Any] { override def onPull(ctx: Context[Any]): SyncDirective = ctx.pull() override def onPush(elem: Any, ctx: Context[Any]): SyncDirective = ctx.pull() override def onUpstreamFinish(ctx: Context[Any]): TerminationDirective = ctx.absorbTermination() @@ -514,12 +517,12 @@ class InterpreterSpec extends InterpreterSpecKit { val ev = lastEvents() ev.nonEmpty should be(true) ev.forall { - case OnError(_: IllegalStateException) ⇒ true - case _ ⇒ false + case OnError(_: IllegalArgumentException) ⇒ true + case _ ⇒ false } should be(true) } - "implement take-take" in new TestSetup(Seq( + "implement take-take" in new OneBoundedSetup[Int](Seq( Take(1), Take(1))) { lastEvents() should be(Set.empty) @@ -527,12 +530,12 @@ class InterpreterSpec extends InterpreterSpecKit { downstream.requestOne() lastEvents() should be(Set(RequestOne)) - upstream.onNext("foo") - lastEvents() should 
be(Set(OnNext("foo"), OnComplete, Cancel)) + upstream.onNext(1) + lastEvents() should be(Set(OnNext(1), OnComplete, Cancel)) } - "implement take-take with pushAndFinish from upstream" in new TestSetup(Seq( + "implement take-take with pushAndFinish from upstream" in new OneBoundedSetup[Int](Seq( Take(1), Take(1))) { lastEvents() should be(Set.empty) @@ -540,8 +543,8 @@ class InterpreterSpec extends InterpreterSpecKit { downstream.requestOne() lastEvents() should be(Set(RequestOne)) - upstream.onNextAndComplete("foo") - lastEvents() should be(Set(OnNext("foo"), OnComplete)) + upstream.onNextAndComplete(1) + lastEvents() should be(Set(OnNext(1), OnComplete)) } @@ -551,7 +554,7 @@ class InterpreterSpec extends InterpreterSpecKit { override def onDownstreamFinish(ctx: Context[Int]): TerminationDirective = ctx.absorbTermination() } - "not allow absorbTermination from onDownstreamFinish()" in new TestSetup(Seq( + "not allow absorbTermination from onDownstreamFinish()" in new OneBoundedSetup[Int](Seq( new InvalidAbsorbTermination)) { lastEvents() should be(Set.empty) @@ -564,4 +567,51 @@ class InterpreterSpec extends InterpreterSpecKit { } + private[akka] case class Doubler[T]() extends PushPullStage[T, T] { + var oneMore: Boolean = false + var lastElem: T = _ + + override def onPush(elem: T, ctx: Context[T]): SyncDirective = { + lastElem = elem + oneMore = true + ctx.push(elem) + } + + override def onPull(ctx: Context[T]): SyncDirective = { + if (oneMore) { + oneMore = false + ctx.push(lastElem) + } else ctx.pull() + } + } + + private[akka] case class KeepGoing[T]() extends PushPullStage[T, T] { + var lastElem: T = _ + + override def onPush(elem: T, ctx: Context[T]): SyncDirective = { + lastElem = elem + ctx.push(elem) + } + + override def onPull(ctx: Context[T]): SyncDirective = { + if (ctx.isFinishing) { + ctx.push(lastElem) + } else ctx.pull() + } + + override def onUpstreamFinish(ctx: Context[T]): TerminationDirective = ctx.absorbTermination() + } + + // This test is related to issue #17351 + private[akka] class PushFinishStage(onPostStop: () ⇒ Unit = () ⇒ ()) extends PushStage[Any, Any] { + override def onPush(elem: Any, ctx: Context[Any]): SyncDirective = + ctx.pushAndFinish(elem) + + override def onUpstreamFinish(ctx: Context[Any]): TerminationDirective = + ctx.fail(akka.stream.testkit.Utils.TE("Cannot happen")) + + override def postStop(): Unit = + onPostStop() + } + } diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpecKit.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpecKit.scala deleted file mode 100644 index 805641a5ad..0000000000 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpecKit.scala +++ /dev/null @@ -1,183 +0,0 @@ -/** - * Copyright (C) 2009-2014 Typesafe Inc. 
- */ -package akka.stream.impl.fusing - -import akka.event.Logging -import akka.stream.stage._ -import akka.stream.testkit.AkkaSpec -import akka.stream.{ ActorMaterializer, Attributes } -import akka.testkit.TestProbe - -trait InterpreterLifecycleSpecKit { - private[akka] case class PreStartAndPostStopIdentity[T]( - onStart: LifecycleContext ⇒ Unit = _ ⇒ (), - onStop: () ⇒ Unit = () ⇒ (), - onUpstreamCompleted: () ⇒ Unit = () ⇒ (), - onUpstreamFailed: Throwable ⇒ Unit = ex ⇒ ()) - extends PushStage[T, T] { - override def preStart(ctx: LifecycleContext) = onStart(ctx) - - override def onPush(elem: T, ctx: Context[T]) = ctx.push(elem) - - override def onUpstreamFinish(ctx: Context[T]): TerminationDirective = { - onUpstreamCompleted() - super.onUpstreamFinish(ctx) - } - - override def onUpstreamFailure(cause: Throwable, ctx: Context[T]): TerminationDirective = { - onUpstreamFailed(cause) - super.onUpstreamFailure(cause, ctx) - } - - override def postStop() = onStop() - } - - private[akka] case class PreStartFailer[T](pleaseThrow: () ⇒ Unit) extends PushStage[T, T] { - - override def preStart(ctx: LifecycleContext) = - pleaseThrow() - - override def onPush(elem: T, ctx: Context[T]) = ctx.push(elem) - } - - private[akka] case class PostStopFailer[T](ex: () ⇒ Throwable) extends PushStage[T, T] { - override def onUpstreamFinish(ctx: Context[T]) = ctx.finish() - override def onPush(elem: T, ctx: Context[T]) = ctx.push(elem) - - override def postStop(): Unit = throw ex() - } - - // This test is related to issue #17351 - private[akka] class PushFinishStage(onPostStop: () ⇒ Unit = () ⇒ ()) extends PushStage[Any, Any] { - override def onPush(elem: Any, ctx: Context[Any]): SyncDirective = - ctx.pushAndFinish(elem) - - override def onUpstreamFinish(ctx: Context[Any]): TerminationDirective = - ctx.fail(akka.stream.testkit.Utils.TE("Cannot happen")) - - override def postStop(): Unit = - onPostStop() - } - -} - -trait InterpreterSpecKit extends AkkaSpec with InterpreterLifecycleSpecKit { - - case object OnComplete - case object Cancel - case class OnError(cause: Throwable) - case class OnNext(elem: Any) - case object RequestOne - case object RequestAnother - - private[akka] case class Doubler[T]() extends PushPullStage[T, T] { - var oneMore: Boolean = false - var lastElem: T = _ - - override def onPush(elem: T, ctx: Context[T]): SyncDirective = { - lastElem = elem - oneMore = true - ctx.push(elem) - } - - override def onPull(ctx: Context[T]): SyncDirective = { - if (oneMore) { - oneMore = false - ctx.push(lastElem) - } else ctx.pull() - } - } - - private[akka] case class KeepGoing[T]() extends PushPullStage[T, T] { - var lastElem: T = _ - - override def onPush(elem: T, ctx: Context[T]): SyncDirective = { - lastElem = elem - ctx.push(elem) - } - - override def onPull(ctx: Context[T]): SyncDirective = { - if (ctx.isFinishing) { - ctx.push(lastElem) - } else ctx.pull() - } - - override def onUpstreamFinish(ctx: Context[T]): TerminationDirective = ctx.absorbTermination() - } - - abstract class TestSetup(ops: Seq[Stage[_, _]], forkLimit: Int = 100, overflowToHeap: Boolean = false) { - private var lastEvent: Set[Any] = Set.empty - - val upstream = new UpstreamProbe - val downstream = new DownstreamProbe - val sidechannel = TestProbe() - val interpreter = new OneBoundedInterpreter(upstream +: ops :+ downstream, - (op, ctx, event) ⇒ sidechannel.ref ! 
ActorInterpreter.AsyncInput(op, ctx, event), - Logging(system, classOf[TestSetup]), - ActorMaterializer(), - Attributes.none, - forkLimit, overflowToHeap) - interpreter.init() - - def lastEvents(): Set[Any] = { - val result = lastEvent - lastEvent = Set.empty - result - } - - private[akka] class UpstreamProbe extends BoundaryStage { - - override def onDownstreamFinish(ctx: BoundaryContext): TerminationDirective = { - lastEvent += Cancel - ctx.finish() - } - - override def onPull(ctx: BoundaryContext): Directive = { - if (lastEvent(RequestOne)) - lastEvent += RequestAnother - else - lastEvent += RequestOne - ctx.exit() - } - - override def onPush(elem: Any, ctx: BoundaryContext): Directive = - throw new UnsupportedOperationException("Cannot push the boundary") - - def onNext(elem: Any): Unit = enterAndPush(elem) - def onComplete(): Unit = enterAndFinish() - def onNextAndComplete(elem: Any): Unit = { - context.enter() - context.pushAndFinish(elem) - context.execute() - } - def onError(cause: Throwable): Unit = enterAndFail(cause) - - } - - private[akka] class DownstreamProbe extends BoundaryStage { - override def onPush(elem: Any, ctx: BoundaryContext): Directive = { - lastEvent += OnNext(elem) - ctx.exit() - } - - override def onUpstreamFinish(ctx: BoundaryContext): TerminationDirective = { - lastEvent += OnComplete - ctx.finish() - } - - override def onUpstreamFailure(cause: Throwable, ctx: BoundaryContext): TerminationDirective = { - lastEvent += OnError(cause) - ctx.finish() - } - - override def onPull(ctx: BoundaryContext): Directive = - throw new UnsupportedOperationException("Cannot pull the boundary") - - def requestOne(): Unit = enterAndPull() - - def cancel(): Unit = enterAndFinish() - } - - } -} diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterStressSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterStressSpec.scala index 4a4f427c8f..e7f13562a3 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterStressSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterStressSpec.scala @@ -3,18 +3,27 @@ */ package akka.stream.impl.fusing -import akka.stream.Supervision +import akka.stream.{ Attributes, Shape, Supervision } +import akka.stream.stage.AbstractStage.PushPullGraphStage +import akka.stream.stage.GraphStageWithMaterializedValue +import akka.stream.testkit.AkkaSpec -class InterpreterStressSpec extends InterpreterSpecKit { +class InterpreterStressSpec extends AkkaSpec with GraphInterpreterSpecKit { import Supervision.stoppingDecider val chainLength = 1000 * 1000 val halfLength = chainLength / 2 val repetition = 100 + val f = (x: Int) ⇒ x + 1 + + val map: GraphStageWithMaterializedValue[Shape, Any] = + new PushPullGraphStage[Int, Int, Unit]((_) ⇒ Map(f, stoppingDecider), Attributes.none) + .asInstanceOf[GraphStageWithMaterializedValue[Shape, Any]] + "Interpreter" must { - "work with a massive chain of maps" in new TestSetup(Seq.fill(chainLength)(Map((x: Int) ⇒ x + 1, stoppingDecider))) { + "work with a massive chain of maps" in new OneBoundedSetup[Int](Array.fill(chainLength)(map).asInstanceOf[Array[GraphStageWithMaterializedValue[Shape, Any]]]) { lastEvents() should be(Set.empty) val tstamp = System.nanoTime() @@ -32,11 +41,11 @@ class InterpreterStressSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnComplete)) val time = (System.nanoTime() - tstamp) / (1000.0 * 1000.0 * 1000.0) - // FIXME: Not a real benchmark, should be replaced by a proper JMH bench + 
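// For reference, the arithmetic behind the info() line below (the names are the vals defined
// near the top of this spec): every element traverses every map stage, so the total work is
//   chainLength * repetition            // 1,000,000 stages * 100 elements = 1e8 map calls
// and dividing by (time * 1000 * 1000), with time already converted to seconds, yields
// millions of maps per second; e.g. 1e8 calls finishing in 10 seconds report as 10 million maps/s.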
// Not a real benchmark, just a sanity check info(s"Chain finished in $time seconds ${(chainLength * repetition) / (time * 1000 * 1000)} million maps/s") } - "work with a massive chain of maps with early complete" in new TestSetup(Seq.fill(halfLength)(Map((x: Int) ⇒ x + 1, stoppingDecider)) ++ + "work with a massive chain of maps with early complete" in new OneBoundedSetup[Int](Iterable.fill(halfLength)(Map((x: Int) ⇒ x + 1, stoppingDecider)) ++ Seq(Take(repetition / 2)) ++ Seq.fill(halfLength)(Map((x: Int) ⇒ x + 1, stoppingDecider))) { @@ -60,11 +69,11 @@ class InterpreterStressSpec extends InterpreterSpecKit { lastEvents() should be(Set(Cancel, OnComplete, OnNext(0 + chainLength))) val time = (System.nanoTime() - tstamp) / (1000.0 * 1000.0 * 1000.0) - // FIXME: Not a real benchmark, should be replaced by a proper JMH bench + // Not a real benchmark, just a sanity check info(s"Chain finished in $time seconds ${(chainLength * repetition) / (time * 1000 * 1000)} million maps/s") } - "work with a massive chain of takes" in new TestSetup(Seq.fill(chainLength)(Take(1))) { + "work with a massive chain of takes" in new OneBoundedSetup[Int](Iterable.fill(chainLength)(Take(1))) { lastEvents() should be(Set.empty) downstream.requestOne() @@ -75,7 +84,7 @@ class InterpreterStressSpec extends InterpreterSpecKit { } - "work with a massive chain of drops" in new TestSetup(Seq.fill(chainLength / 1000)(Drop(1))) { + "work with a massive chain of drops" in new OneBoundedSetup[Int](Iterable.fill(chainLength / 1000)(Drop(1))) { lastEvents() should be(Set.empty) downstream.requestOne() @@ -93,12 +102,10 @@ class InterpreterStressSpec extends InterpreterSpecKit { } - "work with a massive chain of conflates by overflowing to the heap" in new TestSetup(Seq.fill(100000)(Conflate( + "work with a massive chain of conflates by overflowing to the heap" in new OneBoundedSetup[Int](Iterable.fill(100000)(Conflate( (in: Int) ⇒ in, (agg: Int, in: Int) ⇒ agg + in, - Supervision.stoppingDecider)), - forkLimit = 100, - overflowToHeap = true) { + Supervision.stoppingDecider))) { lastEvents() should be(Set(RequestOne)) diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala index 2438588ad7..238c246678 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala @@ -70,7 +70,7 @@ object InterpreterSupervisionSpec { } -class InterpreterSupervisionSpec extends InterpreterSpecKit { +class InterpreterSupervisionSpec extends GraphInterpreterSpecKit { import InterpreterSupervisionSpec._ import Supervision.stoppingDecider import Supervision.resumingDecider @@ -78,7 +78,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { "Interpreter error handling" must { - "handle external failure" in new TestSetup(Seq(Map((x: Int) ⇒ x + 1, stoppingDecider))) { + "handle external failure" in new OneBoundedSetup[Int](Seq(Map((x: Int) ⇒ x + 1, stoppingDecider))) { lastEvents() should be(Set.empty) upstream.onError(TE) @@ -86,7 +86,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { } - "emit failure when op throws" in new TestSetup(Seq(Map((x: Int) ⇒ if (x == 0) throw TE else x, stoppingDecider))) { + "emit failure when op throws" in new OneBoundedSetup[Int](Seq(Map((x: Int) ⇒ if (x == 0) throw TE else x, stoppingDecider))) { 
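// For context: stoppingDecider maps every exception to Supervision.Stop, so the TE thrown
// by the Map above is expected to fail the stream (OnError) and cancel upstream. A
// Supervision.Decider is just a Throwable ⇒ Directive; a minimal hypothetical decider that
// resumes on a chosen exception type instead would look like:
//   val lenientDecider: Supervision.Decider = {
//     case _: IllegalArgumentException ⇒ Supervision.Resume
//     case _                           ⇒ Supervision.Stop
//   }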
downstream.requestOne() lastEvents() should be(Set(RequestOne)) upstream.onNext(2) @@ -98,7 +98,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { lastEvents() should be(Set(Cancel, OnError(TE))) } - "emit failure when op throws in middle of the chain" in new TestSetup(Seq( + "emit failure when op throws in middle of the chain" in new OneBoundedSetup[Int](Seq( Map((x: Int) ⇒ x + 1, stoppingDecider), Map((x: Int) ⇒ if (x == 0) throw TE else x + 10, stoppingDecider), Map((x: Int) ⇒ x + 100, stoppingDecider))) { @@ -114,7 +114,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { lastEvents() should be(Set(Cancel, OnError(TE))) } - "resume when Map throws" in new TestSetup(Seq(Map((x: Int) ⇒ if (x == 0) throw TE else x, resumingDecider))) { + "resume when Map throws" in new OneBoundedSetup[Int](Seq(Map((x: Int) ⇒ if (x == 0) throw TE else x, resumingDecider))) { downstream.requestOne() lastEvents() should be(Set(RequestOne)) upstream.onNext(2) @@ -138,7 +138,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnNext(4))) } - "resume when Map throws in middle of the chain" in new TestSetup(Seq( + "resume when Map throws in middle of the chain" in new OneBoundedSetup[Int](Seq( Map((x: Int) ⇒ x + 1, resumingDecider), Map((x: Int) ⇒ if (x == 0) throw TE else x + 10, resumingDecider), Map((x: Int) ⇒ x + 100, resumingDecider))) { @@ -157,7 +157,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnNext(114))) } - "resume when Map throws before Grouped" in new TestSetup(Seq( + "resume when Map throws before Grouped" in new OneBoundedSetup[Int](Seq( Map((x: Int) ⇒ x + 1, resumingDecider), Map((x: Int) ⇒ if (x <= 0) throw TE else x + 10, resumingDecider), Grouped(3))) { @@ -177,7 +177,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnNext(Vector(13, 14, 15)))) } - "complete after resume when Map throws before Grouped" in new TestSetup(Seq( + "complete after resume when Map throws before Grouped" in new OneBoundedSetup[Int](Seq( Map((x: Int) ⇒ x + 1, resumingDecider), Map((x: Int) ⇒ if (x <= 0) throw TE else x + 10, resumingDecider), Grouped(1000))) { @@ -205,7 +205,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { } } - new TestSetup(Seq( + new OneBoundedSetup[Int](Seq( Map((x: Int) ⇒ x + 1, restartingDecider), stage, Map((x: Int) ⇒ x + 100, restartingDecider))) { @@ -228,14 +228,13 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { "restart when onPush throws after ctx.push" in { val stage = new RestartTestStage { override def onPush(elem: Int, ctx: Context[Int]): SyncDirective = { - val ret = ctx.push(sum) - super.onPush(elem, ctx) + val ret = ctx.push(elem) if (elem <= 0) throw TE ret } } - new TestSetup(Seq( + new OneBoundedSetup[Int](Seq( Map((x: Int) ⇒ x + 1, restartingDecider), stage, Map((x: Int) ⇒ x + 100, restartingDecider))) { @@ -248,6 +247,10 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { downstream.requestOne() lastEvents() should be(Set(RequestOne)) upstream.onNext(-1) // boom + // The element was already pushed before the exception, so there is no way to take it back + lastEvents() should be(Set(OnNext(100))) + + downstream.requestOne() lastEvents() should be(Set(RequestOne)) upstream.onNext(3) @@ -263,7 +266,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { } } - new TestSetup(Seq( + new OneBoundedSetup[Int](Seq( Map((x: Int) ⇒ x + 1, restartingDecider), stage, Map((x: Int) ⇒ x + 
100, restartingDecider))) { @@ -283,7 +286,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { } } - "resume when Filter throws" in new TestSetup(Seq( + "resume when Filter throws" in new OneBoundedSetup[Int](Seq( Filter((x: Int) ⇒ if (x == 0) throw TE else true, resumingDecider))) { downstream.requestOne() lastEvents() should be(Set(RequestOne)) @@ -299,7 +302,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnNext(3))) } - "resume when MapConcat throws" in new TestSetup(Seq( + "resume when MapConcat throws" in new OneBoundedSetup[Int](Seq( MapConcat((x: Int) ⇒ if (x == 0) throw TE else List(x, -x), resumingDecider))) { downstream.requestOne() lastEvents() should be(Set(RequestOne)) @@ -323,7 +326,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { // TODO can't get type inference to work with `pf` inlined val pf: PartialFunction[Int, Int] = { case x: Int ⇒ if (x == 0) throw TE else x } - new TestSetup(Seq( + new OneBoundedSetup[Int](Seq( Collect(pf, restartingDecider))) { downstream.requestOne() lastEvents() should be(Set(RequestOne)) @@ -340,7 +343,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { } } - "resume when Scan throws" in new TestSetup(Seq( + "resume when Scan throws" in new OneBoundedSetup[Int](Seq( Scan(1, (acc: Int, x: Int) ⇒ if (x == 10) throw TE else acc + x, resumingDecider))) { downstream.requestOne() lastEvents() should be(Set(OnNext(1))) @@ -358,7 +361,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnNext(7))) // 1 + 2 + 4 } - "restart when Scan throws" in new TestSetup(Seq( + "restart when Scan throws" in new OneBoundedSetup[Int](Seq( Scan(1, (acc: Int, x: Int) ⇒ if (x == 10) throw TE else acc + x, restartingDecider))) { downstream.requestOne() lastEvents() should be(Set(OnNext(1))) @@ -383,7 +386,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnNext(25))) // 1 + 4 + 20 } - "restart when Conflate `seed` throws" in new TestSetup(Seq(Conflate( + "restart when Conflate `seed` throws" in new OneBoundedSetup[Int](Seq(Conflate( (in: Int) ⇒ if (in == 1) throw TE else in, (agg: Int, x: Int) ⇒ agg + x, restartingDecider))) { @@ -412,7 +415,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { lastEvents() should be(Set.empty) } - "restart when Conflate `aggregate` throws" in new TestSetup(Seq(Conflate( + "restart when Conflate `aggregate` throws" in new OneBoundedSetup[Int](Seq(Conflate( (in: Int) ⇒ in, (agg: Int, x: Int) ⇒ if (x == 2) throw TE else agg + x, restartingDecider))) { @@ -447,7 +450,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { lastEvents() should be(Set(Cancel)) } - "fail when Expand `seed` throws" in new TestSetup(Seq(Expand( + "fail when Expand `seed` throws" in new OneBoundedSetup[Int](Seq(Expand( (in: Int) ⇒ if (in == 2) throw TE else in, (agg: Int) ⇒ (agg, -math.abs(agg))))) { @@ -469,7 +472,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnError(TE), Cancel)) } - "fail when Expand `extrapolate` throws" in new TestSetup(Seq(Expand( + "fail when Expand `extrapolate` throws" in new OneBoundedSetup[Int](Seq(Expand( (in: Int) ⇒ in, (agg: Int) ⇒ if (agg == 2) throw TE else (agg, -math.abs(agg))))) { @@ -493,7 +496,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { "fail when onPull throws before pushing all generated elements" in { def test(decider: Supervision.Decider, 
absorbTermination: Boolean): Unit = { - new TestSetup(Seq( + new OneBoundedSetup[Int](Seq( OneToManyTestStage(decider, absorbTermination))) { downstream.requestOne() diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/LifecycleInterpreterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/LifecycleInterpreterSpec.scala index f85e45f73b..97e7b0de5f 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/LifecycleInterpreterSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/LifecycleInterpreterSpec.scala @@ -3,16 +3,18 @@ */ package akka.stream.impl.fusing +import akka.stream.stage._ +import akka.stream.testkit.AkkaSpec import akka.stream.testkit.Utils.TE import scala.concurrent.duration._ -class LifecycleInterpreterSpec extends InterpreterSpecKit { +class LifecycleInterpreterSpec extends GraphInterpreterSpecKit { import akka.stream.Supervision._ "Interpreter" must { - "call preStart in order on stages" in new TestSetup(Seq( + "call preStart in order on stages" in new OneBoundedSetup[String](Seq( PreStartAndPostStopIdentity(onStart = _ ⇒ testActor ! "start-a"), PreStartAndPostStopIdentity(onStart = _ ⇒ testActor ! "start-b"), PreStartAndPostStopIdentity(onStart = _ ⇒ testActor ! "start-c"))) { @@ -23,7 +25,7 @@ class LifecycleInterpreterSpec extends InterpreterSpecKit { upstream.onComplete() } - "call postStop in order on stages - when upstream completes" in new TestSetup(Seq( + "call postStop in order on stages - when upstream completes" in new OneBoundedSetup[String](Seq( PreStartAndPostStopIdentity(onUpstreamCompleted = () ⇒ testActor ! "complete-a", onStop = () ⇒ testActor ! "stop-a"), PreStartAndPostStopIdentity(onUpstreamCompleted = () ⇒ testActor ! "complete-b", onStop = () ⇒ testActor ! "stop-b"), PreStartAndPostStopIdentity(onUpstreamCompleted = () ⇒ testActor ! "complete-c", onStop = () ⇒ testActor ! "stop-c"))) { @@ -37,7 +39,7 @@ class LifecycleInterpreterSpec extends InterpreterSpecKit { expectNoMsg(300.millis) } - "call postStop in order on stages - when upstream onErrors" in new TestSetup(Seq( + "call postStop in order on stages - when upstream onErrors" in new OneBoundedSetup[String](Seq( PreStartAndPostStopIdentity( onUpstreamFailed = ex ⇒ testActor ! ex.getMessage, onStop = () ⇒ testActor ! "stop-c"))) { @@ -48,7 +50,7 @@ class LifecycleInterpreterSpec extends InterpreterSpecKit { expectNoMsg(300.millis) } - "call postStop in order on stages - when downstream cancels" in new TestSetup(Seq( + "call postStop in order on stages - when downstream cancels" in new OneBoundedSetup[String](Seq( PreStartAndPostStopIdentity(onStop = () ⇒ testActor ! "stop-a"), PreStartAndPostStopIdentity(onStop = () ⇒ testActor ! "stop-b"), PreStartAndPostStopIdentity(onStop = () ⇒ testActor ! "stop-c"))) { @@ -59,7 +61,7 @@ class LifecycleInterpreterSpec extends InterpreterSpecKit { expectNoMsg(300.millis) } - "call preStart before postStop" in new TestSetup(Seq( + "call preStart before postStop" in new OneBoundedSetup[String](Seq( PreStartAndPostStopIdentity(onStart = _ ⇒ testActor ! "start-a", onStop = () ⇒ testActor ! 
"stop-a"))) { expectMsg("start-a") expectNoMsg(300.millis) @@ -68,25 +70,25 @@ class LifecycleInterpreterSpec extends InterpreterSpecKit { expectNoMsg(300.millis) } - "onError when preStart fails" in new TestSetup(Seq( + "onError when preStart fails" in new OneBoundedSetup[String](Seq( PreStartFailer(() ⇒ throw TE("Boom!")))) { lastEvents() should ===(Set(Cancel, OnError(TE("Boom!")))) } - "not blow up when postStop fails" in new TestSetup(Seq( + "not blow up when postStop fails" in new OneBoundedSetup[String](Seq( PostStopFailer(() ⇒ throw TE("Boom!")))) { upstream.onComplete() lastEvents() should ===(Set(OnComplete)) } - "onError when preStart fails with stages after" in new TestSetup(Seq( + "onError when preStart fails with stages after" in new OneBoundedSetup[String](Seq( Map((x: Int) ⇒ x, stoppingDecider), PreStartFailer(() ⇒ throw TE("Boom!")), Map((x: Int) ⇒ x, stoppingDecider))) { lastEvents() should ===(Set(Cancel, OnError(TE("Boom!")))) } - "continue with stream shutdown when postStop fails" in new TestSetup(Seq( + "continue with stream shutdown when postStop fails" in new OneBoundedSetup[String](Seq( PostStopFailer(() ⇒ throw TE("Boom!")))) { lastEvents() should ===(Set()) @@ -94,7 +96,7 @@ class LifecycleInterpreterSpec extends InterpreterSpecKit { lastEvents should ===(Set(OnComplete)) } - "postStop when pushAndFinish called if upstream completes with pushAndFinish" in new TestSetup(Seq( + "postStop when pushAndFinish called if upstream completes with pushAndFinish" in new OneBoundedSetup[String](Seq( new PushFinishStage(onPostStop = () ⇒ testActor ! "stop"))) { lastEvents() should be(Set.empty) @@ -107,7 +109,7 @@ class LifecycleInterpreterSpec extends InterpreterSpecKit { expectMsg("stop") } - "postStop when pushAndFinish called with pushAndFinish if indirect upstream completes with pushAndFinish" in new TestSetup(Seq( + "postStop when pushAndFinish called with pushAndFinish if indirect upstream completes with pushAndFinish" in new OneBoundedSetup[String](Seq( Map((x: Any) ⇒ x, stoppingDecider), new PushFinishStage(onPostStop = () ⇒ testActor ! "stop"), Map((x: Any) ⇒ x, stoppingDecider))) { @@ -122,7 +124,7 @@ class LifecycleInterpreterSpec extends InterpreterSpecKit { expectMsg("stop") } - "postStop when pushAndFinish called with pushAndFinish if upstream completes with pushAndFinish and downstream immediately pulls" in new TestSetup(Seq( + "postStop when pushAndFinish called with pushAndFinish if upstream completes with pushAndFinish and downstream immediately pulls" in new OneBoundedSetup[String](Seq( new PushFinishStage(onPostStop = () ⇒ testActor ! 
"stop"), Fold("", (x: String, y: String) ⇒ x + y, stoppingDecider))) { @@ -138,4 +140,54 @@ class LifecycleInterpreterSpec extends InterpreterSpecKit { } + private[akka] case class PreStartAndPostStopIdentity[T]( + onStart: LifecycleContext ⇒ Unit = _ ⇒ (), + onStop: () ⇒ Unit = () ⇒ (), + onUpstreamCompleted: () ⇒ Unit = () ⇒ (), + onUpstreamFailed: Throwable ⇒ Unit = ex ⇒ ()) + extends PushStage[T, T] { + override def preStart(ctx: LifecycleContext) = onStart(ctx) + + override def onPush(elem: T, ctx: Context[T]) = ctx.push(elem) + + override def onUpstreamFinish(ctx: Context[T]): TerminationDirective = { + onUpstreamCompleted() + super.onUpstreamFinish(ctx) + } + + override def onUpstreamFailure(cause: Throwable, ctx: Context[T]): TerminationDirective = { + onUpstreamFailed(cause) + super.onUpstreamFailure(cause, ctx) + } + + override def postStop() = onStop() + } + + private[akka] case class PreStartFailer[T](pleaseThrow: () ⇒ Unit) extends PushStage[T, T] { + + override def preStart(ctx: LifecycleContext) = + pleaseThrow() + + override def onPush(elem: T, ctx: Context[T]) = ctx.push(elem) + } + + private[akka] case class PostStopFailer[T](ex: () ⇒ Throwable) extends PushStage[T, T] { + override def onUpstreamFinish(ctx: Context[T]) = ctx.finish() + override def onPush(elem: T, ctx: Context[T]) = ctx.push(elem) + + override def postStop(): Unit = throw ex() + } + + // This test is related to issue #17351 + private[akka] class PushFinishStage(onPostStop: () ⇒ Unit = () ⇒ ()) extends PushStage[Any, Any] { + override def onPush(elem: Any, ctx: Context[Any]): SyncDirective = + ctx.pushAndFinish(elem) + + override def onUpstreamFinish(ctx: Context[Any]): TerminationDirective = + ctx.fail(akka.stream.testkit.Utils.TE("Cannot happen")) + + override def postStop(): Unit = + onPostStop() + } + } diff --git a/akka-stream-tests/src/test/scala/akka/stream/io/InputStreamSinkSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/io/InputStreamSinkSpec.scala index 66f31eb1df..396afcebf9 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/io/InputStreamSinkSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/io/InputStreamSinkSpec.scala @@ -55,10 +55,10 @@ class InputStreamSinkSpec extends AkkaSpec(UnboundedMailboxConfig) { class InputStreamSinkTestStage(val timeout: FiniteDuration) extends InputStreamSinkStage(timeout) { - override def createLogicAndMaterializedValue = { - val (logic, inputStream) = super.createLogicAndMaterializedValue - val inHandler = logic.inHandlers(in.id) - logic.inHandlers(in.id) = new InHandler { + override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = { + val (logic, inputStream) = super.createLogicAndMaterializedValue(inheritedAttributes) + val inHandler = logic.handlers(in.id).asInstanceOf[InHandler] + logic.handlers(in.id) = new InHandler { override def onPush(): Unit = { probe.ref ! 
InputStreamSinkTestMessages.Push inHandler.onPush() diff --git a/akka-stream-tests/src/test/scala/akka/stream/io/OutputStreamSourceSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/io/OutputStreamSourceSpec.scala index 1290c59fdb..3b36ce87fb 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/io/OutputStreamSourceSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/io/OutputStreamSourceSpec.scala @@ -49,10 +49,10 @@ class OutputStreamSourceSpec extends AkkaSpec(UnboundedMailboxConfig) { class OutputStreamSourceTestStage(val timeout: FiniteDuration) extends OutputStreamSourceStage(timeout) { - override def createLogicAndMaterializedValue = { - val (logic, inputStream) = super.createLogicAndMaterializedValue - val outHandler = logic.outHandlers(out.id) - logic.outHandlers(out.id) = new OutHandler { + override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = { + val (logic, inputStream) = super.createLogicAndMaterializedValue(inheritedAttributes) + val outHandler = logic.handlers(out.id).asInstanceOf[OutHandler] + logic.handlers(out.id) = new OutHandler { override def onDownstreamFinish(): Unit = { probe.ref ! OutputStreamSourceTestMessages.Finish outHandler.onDownstreamFinish() diff --git a/akka-stream-tests/src/test/scala/akka/stream/io/SynchronousFileSinkSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/io/SynchronousFileSinkSpec.scala index ddaa3e6bcc..18e6aad9d0 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/io/SynchronousFileSinkSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/io/SynchronousFileSinkSpec.scala @@ -107,7 +107,9 @@ class SynchronousFileSinkSpec extends AkkaSpec(UnboundedMailboxConfig) { } } + // FIXME: overriding dispatcher should be made available with dispatcher alias support in materializer (#17929) "allow overriding the dispatcher using Attributes" in assertAllStagesStopped { + pending targetFile { f ⇒ val sys = ActorSystem("dispatcher-testing", UnboundedMailboxConfig) val mat = ActorMaterializer()(sys) diff --git a/akka-stream-tests/src/test/scala/akka/stream/io/SynchronousFileSourceSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/io/SynchronousFileSourceSpec.scala index 656a5bf328..df3a2ff65c 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/io/SynchronousFileSourceSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/io/SynchronousFileSourceSpec.scala @@ -180,7 +180,9 @@ class SynchronousFileSourceSpec extends AkkaSpec(UnboundedMailboxConfig) { } finally shutdown(sys) } + //FIXME: overriding dispatcher should be made available with dispatcher alias support in materializer (#17929) "allow overriding the dispatcher using Attributes" in { + pending val sys = ActorSystem("dispatcher-testing", UnboundedMailboxConfig) val mat = ActorMaterializer()(sys) implicit val timeout = Timeout(500.millis) diff --git a/akka-stream-tests/src/test/scala/akka/stream/io/TcpSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/io/TcpSpec.scala index e3503a7012..da0f8d34a6 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/io/TcpSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/io/TcpSpec.scala @@ -13,10 +13,11 @@ import akka.stream.testkit.Utils._ import akka.stream.testkit._ import akka.stream.{ ActorMaterializer, BindFailedException, StreamTcpException } import akka.util.{ ByteString, Helpers } - import scala.collection.immutable import scala.concurrent.{ Promise, Await } import scala.concurrent.duration._ +import java.net.BindException +import 
akka.testkit.EventFilter class TcpSpec extends AkkaSpec("akka.io.tcp.windows-connection-abort-workaround-enabled=auto\nakka.stream.materializer.subscription-timeout.timeout = 3s") with TcpHelper { var demand = 0L @@ -350,12 +351,14 @@ class TcpSpec extends AkkaSpec("akka.io.tcp.windows-connection-abort-workaround- conn.flow.join(writeButIgnoreRead).run() })(Keep.left).run(), 3.seconds) - val result = Source.maybe[ByteString] + val (promise, result) = Source.maybe[ByteString] .via(Tcp().outgoingConnection(serverAddress.getHostName, serverAddress.getPort)) - .runFold(ByteString.empty)(_ ++ _) + .toMat(Sink.fold(ByteString.empty)(_ ++ _))(Keep.both) + .run() Await.result(result, 3.seconds) should ===(ByteString("Early response")) + promise.success(None) // close client upstream, no more data binding.unbind() } @@ -453,7 +456,7 @@ class TcpSpec extends AkkaSpec("akka.io.tcp.windows-connection-abort-workaround- Await.result(echoServerFinish, 1.second) } - "bind and unbind correctly" in { + "bind and unbind correctly" in EventFilter[BindException](occurrences = 2).intercept { if (Helpers.isWindows) { info("On Windows unbinding is not immediate") pending diff --git a/akka-stream-tests/src/test/scala/akka/stream/io/TlsSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/io/TlsSpec.scala index 85289179c6..86bcaf1eea 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/io/TlsSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/io/TlsSpec.scala @@ -12,7 +12,7 @@ import scala.util.Random import akka.actor.ActorSystem import akka.pattern.{ after ⇒ later } -import akka.stream.{ ClosedShape, ActorMaterializer } +import akka.stream._ import akka.stream.scaladsl._ import akka.stream.stage._ import akka.stream.testkit._ @@ -52,35 +52,28 @@ object TlsSpec { * independent of the traffic going through. The purpose is to include the last seen * element in the exception message to help in figuring out what went wrong. 
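 * Usage note: after the rewrite from an AsyncStage to a GraphStage below, the stage is
 * attached with `.via(new Timeout(6.seconds))` rather than the previous
 * `.transform(() ⇒ new Timeout(6.seconds))`, as the last hunk of this file shows.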
*/ - class Timeout(duration: FiniteDuration)(implicit system: ActorSystem) extends AsyncStage[ByteString, ByteString, Unit] { - private var last: ByteString = _ + class Timeout(duration: FiniteDuration)(implicit system: ActorSystem) extends GraphStage[FlowShape[ByteString, ByteString]] { - override def preStart(ctx: AsyncContext[ByteString, Unit]) = { - val cb = ctx.getAsyncCallback - system.scheduler.scheduleOnce(duration)(cb.invoke(()))(system.dispatcher) - } + private val in = Inlet[ByteString]("in") + private val out = Outlet[ByteString]("out") + override val shape = FlowShape(in, out) - override def onAsyncInput(u: Unit, ctx: AsyncContext[ByteString, Unit]) = - ctx.fail(new TimeoutException(s"timeout expired, last element was $last")) + override def createLogic(attr: Attributes) = new TimerGraphStageLogic(shape) { + override def preStart(): Unit = scheduleOnce((), duration) - override def onPush(elem: ByteString, ctx: AsyncContext[ByteString, Unit]) = { - last = elem - if (ctx.isHoldingDownstream) ctx.pushAndPull(elem) - else ctx.holdUpstream() - } - - override def onPull(ctx: AsyncContext[ByteString, Unit]) = - if (ctx.isFinishing) ctx.pushAndFinish(last) - else if (ctx.isHoldingUpstream) ctx.pushAndPull(last) - else ctx.holdDownstream() - - override def onUpstreamFinish(ctx: AsyncContext[ByteString, Unit]) = - if (ctx.isHoldingUpstream) ctx.absorbTermination() - else ctx.finish() - - override def onDownstreamFinish(ctx: AsyncContext[ByteString, Unit]) = { - system.log.debug("cancelled") - ctx.finish() + var last: ByteString = _ + setHandler(in, new InHandler { + override def onPush(): Unit = { + last = grab(in) + push(out, last) + } + }) + setHandler(out, new OutHandler { + override def onPull(): Unit = pull(in) + }) + override def onTimer(x: Any): Unit = { + failStage(new TimeoutException(s"timeout expired, last element was $last")) + } } } @@ -363,7 +356,7 @@ class TlsSpec extends AkkaSpec("akka.loglevel=INFO\nakka.actor.debug.receive=off .via(debug) .collect { case SessionBytes(_, b) ⇒ b } .scan(ByteString.empty)(_ ++ _) - .transform(() ⇒ new Timeout(6.seconds)) + .via(new Timeout(6.seconds)) .dropWhile(_.size < scenario.output.size) .runWith(Sink.head) diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/AttributesSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/AttributesSpec.scala new file mode 100644 index 0000000000..963679d857 --- /dev/null +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/AttributesSpec.scala @@ -0,0 +1,65 @@ +/** + * Copyright (C) 2015 Typesafe Inc. 
+ */ +package akka.stream.scaladsl + +import akka.stream.ActorMaterializer +import akka.stream.ActorMaterializerSettings +import akka.stream.Attributes +import akka.stream.Attributes._ +import akka.stream.MaterializationContext +import akka.stream.SinkShape +import akka.stream.testkit._ +import scala.concurrent.Future +import scala.concurrent.Promise +import akka.stream.impl.SinkModule +import akka.stream.impl.StreamLayout.Module +import org.scalatest.concurrent.ScalaFutures +import akka.stream.impl.BlackholeSubscriber + +object AttributesSpec { + + object AttributesSink { + def apply(): Sink[Nothing, Future[Attributes]] = + new Sink(new AttributesSink(Attributes.name("attributesSink"), Sink.shape("attributesSink"))) + } + + final class AttributesSink(val attributes: Attributes, shape: SinkShape[Nothing]) extends SinkModule[Nothing, Future[Attributes]](shape) { + override def create(context: MaterializationContext) = + (new BlackholeSubscriber(0, Promise()), Future.successful(context.effectiveAttributes)) + + override protected def newInstance(shape: SinkShape[Nothing]): SinkModule[Nothing, Future[Attributes]] = + new AttributesSink(attributes, shape) + + override def withAttributes(attr: Attributes): Module = + new AttributesSink(attr, amendShape(attr)) + } + +} + +class AttributesSpec extends AkkaSpec with ScalaFutures { + import AttributesSpec._ + + val settings = ActorMaterializerSettings(system) + .withInputBuffer(initialSize = 2, maxSize = 16) + + implicit val materializer = ActorMaterializer(settings) + + "attributes" must { + + "be overridable on a module basis" in { + val runnable = Source.empty.toMat(AttributesSink().withAttributes(Attributes.name("new-name")))(Keep.right) + whenReady(runnable.run()) { attributes ⇒ + attributes.get[Name] should contain(Name("new-name")) + } + } + + "keep the outermost attribute as the least specific" in { + val runnable = Source.empty.toMat(AttributesSink())(Keep.right).withAttributes(Attributes.name("new-name")) + whenReady(runnable.run()) { attributes ⇒ + attributes.get[Name] should contain(Name("attributesSink")) + } + } + } + +} diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowMapAsyncSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowMapAsyncSpec.scala index ecce6e93fd..faf1ee5964 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowMapAsyncSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowMapAsyncSpec.scala @@ -21,46 +21,13 @@ import scala.util.Try import scala.concurrent.ExecutionContext import scala.util.Failure import scala.util.Success +import scala.annotation.tailrec +import scala.concurrent.Promise +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.LinkedBlockingQueue +import org.scalatest.concurrent.ScalaFutures -object FlowMapAsyncSpec { - class MapAsyncOne[In, Out](f: In ⇒ Future[Out])(implicit ec: ExecutionContext) extends AsyncStage[In, Out, Try[Out]] { - private var elemInFlight: Out = _ - - override def onPush(elem: In, ctx: AsyncContext[Out, Try[Out]]) = { - val future = f(elem) - val cb = ctx.getAsyncCallback - future.onComplete(cb.invoke) - ctx.holdUpstream() - } - - override def onPull(ctx: AsyncContext[Out, Try[Out]]) = - if (elemInFlight != null) { - val e = elemInFlight - elemInFlight = null.asInstanceOf[Out] - pushIt(e, ctx) - } else ctx.holdDownstream() - - override def onAsyncInput(input: Try[Out], ctx: AsyncContext[Out, Try[Out]]) = - input match { - case Failure(ex) ⇒ ctx.fail(ex) - case Success(e) if 
ctx.isHoldingDownstream ⇒ pushIt(e, ctx) - case Success(e) ⇒ - elemInFlight = e - ctx.ignore() - } - - override def onUpstreamFinish(ctx: AsyncContext[Out, Try[Out]]) = - if (ctx.isHoldingUpstream) ctx.absorbTermination() - else ctx.finish() - - private def pushIt(elem: Out, ctx: AsyncContext[Out, Try[Out]]) = - if (ctx.isFinishing) ctx.pushAndFinish(elem) - else ctx.pushAndPull(elem) - } -} - -class FlowMapAsyncSpec extends AkkaSpec { - import FlowMapAsyncSpec._ +class FlowMapAsyncSpec extends AkkaSpec with ScalaFutures { implicit val materializer = ActorMaterializer() @@ -102,11 +69,9 @@ class FlowMapAsyncSpec extends AkkaSpec { n }).to(Sink(c)).run() val sub = c.expectSubscription() - // running 8 in parallel - probe.receiveN(8).toSet should be((1 to 8).toSet) probe.expectNoMsg(500.millis) sub.request(1) - probe.expectMsg(9) + probe.receiveN(9).toSet should be((1 to 9).toSet) probe.expectNoMsg(500.millis) sub.request(2) probe.receiveN(2).toSet should be(Set(10, 11)) @@ -246,44 +211,50 @@ class FlowMapAsyncSpec extends AkkaSpec { } - } + "not run more futures than configured" in assertAllStagesStopped { + val parallelism = 8 - "A MapAsyncOne" must { - import system.dispatcher + val counter = new AtomicInteger + val queue = new LinkedBlockingQueue[(Promise[Int], Long)] - "work in the happy case" in { - val probe = TestProbe() - val N = 100 - val f = Source(1 to N).transform(() ⇒ new MapAsyncOne(i ⇒ { - probe.ref ! i - Future { Thread.sleep(10); probe.ref ! (i + 10); i * 2 } - })).grouped(N + 10).runWith(Sink.head) - Await.result(f, 2.seconds) should ===((1 to N).map(_ * 2)) - probe.receiveN(2 * N) should ===((1 to N).flatMap(x ⇒ List(x, x + 10))) - probe.expectNoMsg(100.millis) - } + val timer = new Thread { + val delay = 50000 // nanoseconds + var count = 0 + @tailrec final override def run(): Unit = { + val cont = try { + val (promise, enqueued) = queue.take() + val wakeup = enqueued + delay + while (System.nanoTime() < wakeup) {} + counter.decrementAndGet() + promise.success(count) + count += 1 + true + } catch { + case _: InterruptedException ⇒ false + } + if (cont) run() + } + } + timer.start - "work when futures fail" in { - val probe = TestSubscriber.manualProbe[Int] - val ex = new Exception("KABOOM") - Source.single(1) - .transform(() ⇒ new MapAsyncOne(_ ⇒ Future.failed(ex))) - .runWith(Sink(probe)) - val sub = probe.expectSubscription() - sub.request(1) - probe.expectError(ex) - } + def deferred(): Future[Int] = { + if (counter.incrementAndGet() > parallelism) Future.failed(new Exception("parallelism exceeded")) + else { + val p = Promise[Int] + queue.offer(p -> System.nanoTime()) + p.future + } + } - "work when futures fail later" in { - val probe = TestSubscriber.manualProbe[Int] - val ex = new Exception("KABOOM") - Source(List(1, 2)) - .transform(() ⇒ new MapAsyncOne(x ⇒ if (x == 1) Future.successful(1) else Future.failed(ex))) - .runWith(Sink(probe)) - val sub = probe.expectSubscription() - sub.request(1) - probe.expectNext(1) - probe.expectError(ex) + try { + val N = 100000 + Source(1 to N) + .mapAsync(parallelism)(i ⇒ deferred()) + .runFold(0)((c, _) ⇒ c + 1) + .futureValue(PatienceConfig(3.seconds)) should ===(N) + } finally { + timer.interrupt() + } } } diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowMapAsyncUnorderedSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowMapAsyncUnorderedSpec.scala index b22737d70b..09173520cb 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowMapAsyncUnorderedSpec.scala +++ 
b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowMapAsyncUnorderedSpec.scala @@ -7,7 +7,6 @@ import scala.concurrent.Await import scala.concurrent.Future import scala.concurrent.duration._ import scala.util.control.NoStackTrace - import akka.stream.ActorMaterializer import akka.stream.testkit._ import akka.stream.testkit.scaladsl._ @@ -17,8 +16,15 @@ import akka.testkit.TestProbe import akka.stream.ActorAttributes.supervisionStrategy import akka.stream.Supervision.resumingDecider import akka.stream.impl.ReactiveStreamsCompliance +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.ConcurrentLinkedQueue +import scala.concurrent.Promise +import java.util.concurrent.LinkedBlockingQueue +import scala.annotation.tailrec +import org.scalatest.concurrent.ScalaFutures +import org.scalactic.ConversionCheckedTripleEquals -class FlowMapAsyncUnorderedSpec extends AkkaSpec { +class FlowMapAsyncUnorderedSpec extends AkkaSpec with ScalaFutures with ConversionCheckedTripleEquals { implicit val materializer = ActorMaterializer() @@ -54,13 +60,11 @@ class FlowMapAsyncUnorderedSpec extends AkkaSpec { n }).to(Sink(c)).run() val sub = c.expectSubscription() - // first four run immediately - probe.expectMsgAllOf(1, 2, 3, 4) c.expectNoMsg(200.millis) probe.expectNoMsg(Duration.Zero) sub.request(1) var got = Set(c.expectNext()) - probe.expectMsg(5) + probe.expectMsgAllOf(1, 2, 3, 4, 5) probe.expectNoMsg(500.millis) sub.request(25) probe.expectMsgAllOf(6 to 20: _*) @@ -176,11 +180,11 @@ class FlowMapAsyncUnorderedSpec extends AkkaSpec { .to(Sink(c)).run() val sub = c.expectSubscription() sub.request(10) - for (elem ← List("a", "c")) c.expectNext(elem) + c.expectNextUnordered("a", "c") c.expectComplete() } - "should handle cancel properly" in assertAllStagesStopped { + "handle cancel properly" in assertAllStagesStopped { val pub = TestPublisher.manualProbe[Int]() val sub = TestSubscriber.manualProbe[Int]() @@ -195,5 +199,51 @@ class FlowMapAsyncUnorderedSpec extends AkkaSpec { } + "not run more futures than configured" in assertAllStagesStopped { + val parallelism = 8 + + val counter = new AtomicInteger + val queue = new LinkedBlockingQueue[(Promise[Int], Long)] + + val timer = new Thread { + val delay = 50000 // nanoseconds + var count = 0 + @tailrec final override def run(): Unit = { + val cont = try { + val (promise, enqueued) = queue.take() + val wakeup = enqueued + delay + while (System.nanoTime() < wakeup) {} + counter.decrementAndGet() + promise.success(count) + count += 1 + true + } catch { + case _: InterruptedException ⇒ false + } + if (cont) run() + } + } + timer.start + + def deferred(): Future[Int] = { + if (counter.incrementAndGet() > parallelism) Future.failed(new Exception("parallelism exceeded")) + else { + val p = Promise[Int] + queue.offer(p -> System.nanoTime()) + p.future + } + } + + try { + val N = 100000 + Source(1 to N) + .mapAsyncUnordered(parallelism)(i ⇒ deferred()) + .runFold(0)((c, _) ⇒ c + 1) + .futureValue(PatienceConfig(3.seconds)) should ===(N) + } finally { + timer.interrupt() + } + } + } } diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowSpec.scala index e64c550c49..d05cd1f863 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowSpec.scala @@ -6,12 +6,14 @@ package akka.stream.scaladsl import akka.actor._ import akka.stream.Supervision._ import 
akka.stream.impl._ -import akka.stream.impl.fusing.ActorInterpreter -import akka.stream.stage.Stage +import akka.stream.impl.fusing.{ ActorGraphInterpreter } +import akka.stream.impl.fusing.GraphInterpreter.GraphAssembly +import akka.stream.stage.AbstractStage.PushPullGraphStage +import akka.stream.stage.{ GraphStageLogic, OutHandler, InHandler, Stage } import akka.stream.testkit.Utils._ import akka.stream.testkit._ import akka.stream.testkit.scaladsl.TestSink -import akka.stream.{ AbruptTerminationException, ActorMaterializer, ActorMaterializerSettings, Attributes } +import akka.stream._ import akka.testkit.TestEvent.{ Mute, UnMute } import akka.testkit.{ EventFilter, TestDuration } import com.typesafe.config.ConfigFactory @@ -32,7 +34,7 @@ class FlowSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.debug.rece import FlowSpec._ val settings = ActorMaterializerSettings(system) - .withInputBuffer(initialSize = 2, maxSize = 16) + .withInputBuffer(initialSize = 2, maxSize = 2) implicit val mat = ActorMaterializer(settings) @@ -40,30 +42,51 @@ class FlowSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.debug.rece val identity2: Flow[Any, Any, _] ⇒ Flow[Any, Any, _] = in ⇒ identity(in) class BrokenActorInterpreter( + _assembly: GraphAssembly, + _inHandlers: Array[InHandler], + _outHandlers: Array[OutHandler], + _logics: Array[GraphStageLogic], + _shape: Shape, _settings: ActorMaterializerSettings, - _ops: Seq[Stage[_, _]], + _mat: Materializer, brokenMessage: Any) - extends ActorInterpreter(_settings, _ops, mat, Attributes.none) { + extends ActorGraphInterpreter(_assembly, _inHandlers, _outHandlers, _logics, _shape, _settings, _mat) { import akka.stream.actor.ActorSubscriberMessage._ override protected[akka] def aroundReceive(receive: Receive, msg: Any) = { msg match { - case OnNext(m) if m == brokenMessage ⇒ + case ActorGraphInterpreter.OnNext(0, m) if m == brokenMessage ⇒ throw new NullPointerException(s"I'm so broken [$m]") case _ ⇒ super.aroundReceive(receive, msg) } } } - val faultyFlow: Flow[Any, Any, _] ⇒ Flow[Any, Any, _] = in ⇒ in.andThenMat { () ⇒ - val props = Props(new BrokenActorInterpreter(settings, List(fusing.Map({ x: Any ⇒ x }, stoppingDecider)), "a3")) + val faultyFlow: Flow[Any, Any, _] ⇒ Flow[Any, Any, _] = in ⇒ in.via({ + val stage = new PushPullGraphStage((_) ⇒ fusing.Map({ x: Any ⇒ x }, stoppingDecider), Attributes.none) + + val assembly = new GraphAssembly( + Array(stage), + Array(Attributes.none), + Array(stage.shape.inlet, null), + Array(0, -1), + Array(null, stage.shape.outlet), + Array(-1, 0)) + + val (inHandlers, outHandlers, logics, _) = assembly.materialize(Attributes.none) + + val props = Props(new BrokenActorInterpreter(assembly, inHandlers, outHandlers, logics, stage.shape, settings, mat, "a3")) .withDispatcher("akka.test.stream-dispatcher").withDeploy(Deploy.local) - val processor = ActorProcessorFactory[Any, Any](system.actorOf( - props, - "borken-stage-actor")) - (processor, ()) - } + val impl = system.actorOf(props, "borken-stage-actor") + + val subscriber = new ActorGraphInterpreter.BoundarySubscriber(impl, 0) + val publisher = new ActorPublisher[Any](impl) { override val wakeUpMsg = ActorGraphInterpreter.SubscribePending(0) } + + impl ! 
ActorGraphInterpreter.ExposedPublisher(0, publisher) + + Flow.fromSinkAndSource(Sink(subscriber), Source(publisher)) + }) val toPublisher: (Source[Any, _], ActorMaterializer) ⇒ Publisher[Any] = (f, m) ⇒ f.runWith(Sink.publisher)(m) @@ -81,15 +104,15 @@ class FlowSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.debug.rece for ((name, op) ← List("identity" -> identity, "identity2" -> identity2); n ← List(1, 2, 4)) { s"request initial elements from upstream ($name, $n)" in { - new ChainSetup(op, settings.withInputBuffer(initialSize = n, maxSize = settings.maxInputBufferSize), toPublisher) { - upstream.expectRequest(upstreamSubscription, settings.initialInputBufferSize) + new ChainSetup(op, settings.withInputBuffer(initialSize = n, maxSize = n), toPublisher) { + upstream.expectRequest(upstreamSubscription, settings.maxInputBufferSize) } } } "request more elements from upstream when downstream requests more elements" in { new ChainSetup(identity, settings, toPublisher) { - upstream.expectRequest(upstreamSubscription, settings.initialInputBufferSize) + upstream.expectRequest(upstreamSubscription, settings.maxInputBufferSize) downstreamSubscription.request(1) upstream.expectNoMsg(100.millis) downstreamSubscription.request(2) @@ -132,7 +155,7 @@ class FlowSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.debug.rece } "cancel upstream when single subscriber cancels subscription while receiving data" in { - new ChainSetup(identity, settings.withInputBuffer(initialSize = 1, maxSize = settings.maxInputBufferSize), toPublisher) { + new ChainSetup(identity, settings.withInputBuffer(initialSize = 1, maxSize = 1), toPublisher) { downstreamSubscription.request(5) upstreamSubscription.expectRequest(1) upstreamSubscription.sendNext("test") @@ -291,7 +314,7 @@ class FlowSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.debug.rece "be possible to convert to a processor, and should be able to take a Processor" in { val identity1 = Flow[Int].toProcessor - val identity2 = Flow(() ⇒ identity1.run()) + val identity2 = Flow.fromProcessor(() ⇒ identity1.run()) Await.result( Source(1 to 10).via(identity2).grouped(100).runWith(Sink.head), 3.seconds) should ===(1 to 10) diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphPreferredMergeSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMergePreferredSpec.scala similarity index 97% rename from akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphPreferredMergeSpec.scala rename to akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMergePreferredSpec.scala index 992648d5f7..6acb3625e2 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphPreferredMergeSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMergePreferredSpec.scala @@ -9,7 +9,7 @@ import akka.stream._ import scala.concurrent.Await import scala.concurrent.duration._ -class GraphPreferredMergeSpec extends TwoStreamsSetup { +class GraphMergePreferredSpec extends TwoStreamsSetup { import FlowGraph.Implicits._ override type Outputs = Int diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphStageTimersSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphStageTimersSpec.scala index 1f1343389b..30ba5cd735 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphStageTimersSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphStageTimersSpec.scala @@ -1,9 +1,9 @@ package akka.stream.scaladsl import akka.actor.ActorRef 
-import akka.stream.ActorMaterializer
+import akka.stream.{ Attributes, ActorMaterializer }
 import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage
-import akka.stream.stage.{ OutHandler, AsyncCallback, InHandler }
+import akka.stream.stage.{ TimerGraphStageLogic, OutHandler, AsyncCallback, InHandler }
 import akka.stream.testkit.{ AkkaSpec, TestPublisher }
 import akka.testkit.TestDuration
@@ -39,13 +39,17 @@ class GraphStageTimersSpec extends AkkaSpec {
   implicit val mat = ActorMaterializer()

   class TestStage(probe: ActorRef, sideChannel: SideChannel) extends SimpleLinearGraphStage[Int] {
-    override def createLogic = new SimpleLinearStageLogic {
+    override def createLogic(inheritedAttributes: Attributes) = new TimerGraphStageLogic(shape) {
       val tickCount = Iterator from 1

       setHandler(in, new InHandler {
         override def onPush() = push(out, grab(in))
       })

+      setHandler(out, new OutHandler {
+        override def onPull(): Unit = pull(in)
+      })
+
       override def preStart() = {
         sideChannel.asyncCallback = getAsyncCallback(onTestEvent)
       }
@@ -147,7 +151,7 @@ class GraphStageTimersSpec extends AkkaSpec {
   }

   class TestStage2 extends SimpleLinearGraphStage[Int] {
-    override def createLogic = new SimpleLinearStageLogic {
+    override def createLogic(inheritedAttributes: Attributes) = new TimerGraphStageLogic(shape) {
       var tickCount = 0

       override def preStart(): Unit = schedulePeriodically("tick", 100.millis)
@@ -195,13 +199,17 @@ class GraphStageTimersSpec extends AkkaSpec {
       val downstream = TestSubscriber.probe[Int]()

       Source(upstream).via(new SimpleLinearGraphStage[Int] {
-        override def createLogic = new SimpleLinearStageLogic {
+        override def createLogic(inheritedAttributes: Attributes) = new TimerGraphStageLogic(shape) {
           override def preStart(): Unit = scheduleOnce("tick", 100.millis)

           setHandler(in, new InHandler {
             override def onPush() = () // Ignore
           })

+          setHandler(out, new OutHandler {
+            override def onPull(): Unit = pull(in)
+          })
+
           override def onTimer(timerKey: Any) = throw exception
         }
       }).runWith(Sink(downstream))
diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala
index 26730799c3..a2ec35a06e 100644
--- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala
+++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala
@@ -73,7 +73,7 @@ class SourceSpec extends AkkaSpec {
   }

   "Maybe Source" must {
-    "complete materialized future with None when stream cancels" in {
+    "complete materialized future with None when stream cancels" in Utils.assertAllStagesStopped {
       val neverSource = Source.maybe[Int]
       val pubSink = Sink.publisher[Int]
@@ -90,7 +90,7 @@ class SourceSpec extends AkkaSpec {
       Await.result(f.future, 500.millis) shouldEqual None
     }

-    "allow external triggering of empty completion" in {
+    "allow external triggering of empty completion" in Utils.assertAllStagesStopped {
       val neverSource = Source.maybe[Int].filter(_ ⇒ false)
       val counterSink = Sink.fold[Int, Int](0) { (acc, _) ⇒ acc + 1 }
@@ -102,7 +102,7 @@ class SourceSpec extends AkkaSpec {
       Await.result(counterFuture, 500.millis) shouldEqual 0
     }

-    "allow external triggering of non-empty completion" in {
+    "allow external triggering of non-empty completion" in Utils.assertAllStagesStopped {
       val neverSource = Source.maybe[Int]
       val counterSink = Sink.head[Int]
@@ -114,7 +114,7 @@ class SourceSpec extends AkkaSpec {
       Await.result(counterFuture, 500.millis) shouldEqual 6
     }

-    "allow external triggering of onError" in {
+    "allow
external triggering of onError" in Utils.assertAllStagesStopped { val neverSource = Source.maybe[Int] val counterSink = Sink.fold[Int, Int](0) { (acc, _) ⇒ acc + 1 } diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SubscriberSourceSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SubscriberSourceSpec.scala index 29d2d82aa1..4bed273a0d 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SubscriberSourceSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SubscriberSourceSpec.scala @@ -11,7 +11,7 @@ import scala.concurrent.duration._ import scala.concurrent.Await -class SubscriberSourceSpec extends AkkaSpec("akka.loglevel=DEBUG\nakka.actor.debug.lifecycle=on") { +class SubscriberSourceSpec extends AkkaSpec { implicit val materializer = ActorMaterializer() diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SubstreamSubscriptionTimeoutSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SubstreamSubscriptionTimeoutSpec.scala index a49c67bb7c..2d09cb2e7c 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SubstreamSubscriptionTimeoutSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SubstreamSubscriptionTimeoutSpec.scala @@ -150,7 +150,7 @@ class SubstreamSubscriptionTimeoutSpec(conf: String) extends AkkaSpec(conf) { private def watchGroupByActor(flowNr: Int): ActorRef = { implicit val t = Timeout(300.millis) import akka.pattern.ask - val path = s"/user/$$a/flow-${flowNr}-1-publisherSource-groupBy" + val path = s"/user/$$a/flow-${flowNr}-1-groupBy" val gropByPath = system.actorSelection(path) val groupByActor = try { Await.result((gropByPath ? Identify("")).mapTo[ActorIdentity], 300.millis).ref.get diff --git a/akka-stream/src/main/boilerplate/akka/stream/scaladsl/UnzipWithApply.scala.template b/akka-stream/src/main/boilerplate/akka/stream/scaladsl/UnzipWithApply.scala.template index 10ac47fa63..5a5eba8b21 100644 --- a/akka-stream/src/main/boilerplate/akka/stream/scaladsl/UnzipWithApply.scala.template +++ b/akka-stream/src/main/boilerplate/akka/stream/scaladsl/UnzipWithApply.scala.template @@ -49,7 +49,7 @@ class UnzipWith1[In, [#A1#]](unzipper: In ⇒ ([#A1#])) extends GraphStage[FanOu [#def out0: Outlet[A1] = shape.out0# ] - override def createLogic: GraphStageLogic = new GraphStageLogic(shape) { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { var pendingCount = 1 var downstreamRunning = 1 diff --git a/akka-stream/src/main/boilerplate/akka/stream/scaladsl/ZipWithApply.scala.template b/akka-stream/src/main/boilerplate/akka/stream/scaladsl/ZipWithApply.scala.template index 9f40ba4c8b..03af15507f 100644 --- a/akka-stream/src/main/boilerplate/akka/stream/scaladsl/ZipWithApply.scala.template +++ b/akka-stream/src/main/boilerplate/akka/stream/scaladsl/ZipWithApply.scala.template @@ -30,7 +30,7 @@ class ZipWith1[[#A1#], O] (zipper: ([#A1#]) ⇒ O) extends GraphStage[FanInShape [#val in0: Inlet[A1] = shape.in0# ] - override def createLogic: GraphStageLogic = new GraphStageLogic(shape) { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { var pending = 1 private def pushAll(): Unit = push(out, zipper([#grab(in0)#])) diff --git a/akka-stream/src/main/scala/akka/stream/ActorMaterializer.scala b/akka-stream/src/main/scala/akka/stream/ActorMaterializer.scala index 254ee46aac..e072cfd40e 100644 --- a/akka-stream/src/main/scala/akka/stream/ActorMaterializer.scala +++ 
b/akka-stream/src/main/scala/akka/stream/ActorMaterializer.scala @@ -31,11 +31,11 @@ object ActorMaterializer { * the processing steps. The default `namePrefix` is `"flow"`. The actor names are built up of * `namePrefix-flowNumber-flowStepNumber-stepName`. */ - def apply(materializerSettings: Option[ActorMaterializerSettings] = None, namePrefix: Option[String] = None, optimizations: Optimizations = Optimizations.none)(implicit context: ActorRefFactory): ActorMaterializer = { + def apply(materializerSettings: Option[ActorMaterializerSettings] = None, namePrefix: Option[String] = None)(implicit context: ActorRefFactory): ActorMaterializer = { val system = actorSystemOf(context) val settings = materializerSettings getOrElse ActorMaterializerSettings(system) - apply(settings, namePrefix.getOrElse("flow"), optimizations)(context) + apply(settings, namePrefix.getOrElse("flow"))(context) } /** @@ -49,7 +49,7 @@ object ActorMaterializer { * the processing steps. The default `namePrefix` is `"flow"`. The actor names are built up of * `namePrefix-flowNumber-flowStepNumber-stepName`. */ - def apply(materializerSettings: ActorMaterializerSettings, namePrefix: String, optimizations: Optimizations)(implicit context: ActorRefFactory): ActorMaterializer = { + def apply(materializerSettings: ActorMaterializerSettings, namePrefix: String)(implicit context: ActorRefFactory): ActorMaterializer = { val haveShutDown = new AtomicBoolean(false) val system = actorSystemOf(context) @@ -60,8 +60,7 @@ object ActorMaterializer { context.actorOf(StreamSupervisor.props(materializerSettings, haveShutDown).withDispatcher(materializerSettings.dispatcher)), haveShutDown, FlowNameCounter(system).counter, - namePrefix, - optimizations) + namePrefix) } /** @@ -199,11 +198,10 @@ object ActorMaterializerSettings { supervisionDecider: Supervision.Decider, subscriptionTimeoutSettings: StreamSubscriptionTimeoutSettings, debugLogging: Boolean, - outputBurstLimit: Int, - optimizations: Optimizations) = + outputBurstLimit: Int) = new ActorMaterializerSettings( initialInputBufferSize, maxInputBufferSize, dispatcher, supervisionDecider, subscriptionTimeoutSettings, debugLogging, - outputBurstLimit, optimizations) + outputBurstLimit) /** * Create [[ActorMaterializerSettings]]. 
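// A minimal sketch (not part of the patch) of the two factory overloads that remain
// once the Optimizations parameter is dropped; the system name "example" and the
// prefix "my-flows" are illustrative only.
object MaterializerExample extends App {
  import akka.actor.ActorSystem
  import akka.stream.{ ActorMaterializer, ActorMaterializerSettings }

  implicit val system = ActorSystem("example")
  val defaultMat = ActorMaterializer()                                              // settings and name prefix defaulted
  val namedMat = ActorMaterializer(ActorMaterializerSettings(system), "my-flows")   // explicit settings and prefix
}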
@@ -228,8 +226,7 @@ object ActorMaterializerSettings { supervisionDecider = Supervision.stoppingDecider, subscriptionTimeoutSettings = StreamSubscriptionTimeoutSettings(config), debugLogging = config.getBoolean("debug-logging"), - outputBurstLimit = config.getInt("output-burst-limit"), - optimizations = Optimizations.none) + outputBurstLimit = config.getInt("output-burst-limit")) /** * Java API @@ -263,8 +260,7 @@ final class ActorMaterializerSettings( val supervisionDecider: Supervision.Decider, val subscriptionTimeoutSettings: StreamSubscriptionTimeoutSettings, val debugLogging: Boolean, - val outputBurstLimit: Int, - val optimizations: Optimizations) { + val outputBurstLimit: Int) { require(initialInputBufferSize > 0, "initialInputBufferSize must be > 0") @@ -278,11 +274,10 @@ final class ActorMaterializerSettings( supervisionDecider: Supervision.Decider = this.supervisionDecider, subscriptionTimeoutSettings: StreamSubscriptionTimeoutSettings = this.subscriptionTimeoutSettings, debugLogging: Boolean = this.debugLogging, - outputBurstLimit: Int = this.outputBurstLimit, - optimizations: Optimizations = this.optimizations) = + outputBurstLimit: Int = this.outputBurstLimit) = new ActorMaterializerSettings( initialInputBufferSize, maxInputBufferSize, dispatcher, supervisionDecider, subscriptionTimeoutSettings, debugLogging, - outputBurstLimit, optimizations) + outputBurstLimit) def withInputBuffer(initialSize: Int, maxSize: Int): ActorMaterializerSettings = copy(initialInputBufferSize = initialSize, maxInputBufferSize = maxSize) @@ -316,9 +311,6 @@ final class ActorMaterializerSettings( def withDebugLogging(enable: Boolean): ActorMaterializerSettings = copy(debugLogging = enable) - def withOptimizations(optimizations: Optimizations): ActorMaterializerSettings = - copy(optimizations = optimizations) - private def requirePowerOfTwo(n: Integer, name: String): Unit = { require(n > 0, s"$name must be > 0") require((n & (n - 1)) == 0, s"$name must be a power of two") @@ -365,11 +357,3 @@ object StreamSubscriptionTimeoutTerminationMode { } -final object Optimizations { - val none: Optimizations = Optimizations(collapsing = false, elision = false, simplification = false, fusion = false) - val all: Optimizations = Optimizations(collapsing = true, elision = true, simplification = true, fusion = true) -} - -final case class Optimizations(collapsing: Boolean, elision: Boolean, simplification: Boolean, fusion: Boolean) { - def isEnabled: Boolean = collapsing || elision || simplification || fusion -} diff --git a/akka-stream/src/main/scala/akka/stream/Attributes.scala b/akka-stream/src/main/scala/akka/stream/Attributes.scala index 1732717140..0863a2b2cd 100644 --- a/akka-stream/src/main/scala/akka/stream/Attributes.scala +++ b/akka-stream/src/main/scala/akka/stream/Attributes.scala @@ -4,10 +4,10 @@ package akka.stream import akka.event.Logging - import scala.annotation.tailrec import scala.collection.immutable -import akka.stream.impl.Stages.StageModule +import scala.reflect.{ classTag, ClassTag } +import akka.stream.impl.Stages.SymbolicStage import akka.japi.function /** @@ -16,7 +16,7 @@ import akka.japi.function * * Note that more attributes for the [[ActorMaterializer]] are defined in [[ActorAttributes]]. 
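 * Within the list, later entries are more specific; ``get`` and ``getAttribute`` return the last matching attribute.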
*/ -final case class Attributes private (attributeList: immutable.Seq[Attributes.Attribute] = Nil) { +final case class Attributes(attributeList: List[Attributes.Attribute] = Nil) { import Attributes._ @@ -44,22 +44,47 @@ final case class Attributes private (attributeList: immutable.Seq[Attributes.Att } /** - * Get first attribute of a given `Class` or subclass thereof. + * Java API: Get the last (most specific) attribute of a given `Class` or subclass thereof. * If no such attribute exists the `default` value is returned. */ def getAttribute[T <: Attribute](c: Class[T], default: T): T = - attributeList.find(c.isInstance) match { - case Some(a) ⇒ c.cast(a) + getAttribute(c) match { + case Some(a) ⇒ a case None ⇒ default } + /** + * Java API: Get the last (most specific) attribute of a given `Class` or subclass thereof. + */ + def getAttribute[T <: Attribute](c: Class[T]): Option[T] = + Option(attributeList.foldLeft(null.asInstanceOf[T])((acc, attr) ⇒ if (c.isInstance(attr)) c.cast(attr) else acc)) + + /** + * Get the last (most specific) attribute of a given type parameter T `Class` or subclass thereof. + * If no such attribute exists the `default` value is returned. + */ + def get[T <: Attribute: ClassTag](default: T) = + getAttribute(classTag[T].runtimeClass.asInstanceOf[Class[T]], default) + + /** + * Get the last (most specific) attribute of a given type parameter T `Class` or subclass thereof. + */ + def get[T <: Attribute: ClassTag] = + getAttribute(classTag[T].runtimeClass.asInstanceOf[Class[T]]) + /** * Adds given attributes to the end of these attributes. */ def and(other: Attributes): Attributes = if (attributeList.isEmpty) other else if (other.attributeList.isEmpty) this - else Attributes(attributeList ++ other.attributeList) + else Attributes(attributeList ::: other.attributeList) + + /** + * Adds given attribute to the end of these attributes. 
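+   * The appended attribute becomes the last, i.e. most specific, entry as seen by ``get`` and ``getAttribute``.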
+ */ + def and(other: Attribute): Attributes = + Attributes(attributeList :+ other) /** * INTERNAL API @@ -76,12 +101,7 @@ final case class Attributes private (attributeList: immutable.Seq[Attributes.Att case Name(n) ⇒ if (buf ne null) concatNames(i, null, buf.append('-').append(n)) else if (first ne null) { - val b = new StringBuilder( - (first.length() + n.length()) match { - case x if x < 0 ⇒ throw new IllegalStateException("Names too long to concatenate") - case y if y > Int.MaxValue / 2 ⇒ Int.MaxValue - case z ⇒ Math.max(Integer.highestOneBit(z) * 2, 32) - }) + val b = new StringBuilder((first.length() + n.length()) * 2) concatNames(i, null, b.append(first).append('-').append(n)) } else concatNames(i, n, null) case _ ⇒ concatNames(i, first, buf) @@ -95,16 +115,6 @@ final case class Attributes private (attributeList: immutable.Seq[Attributes.Att } } - /** - * INTERNAL API - */ - private[akka] def logLevels: Option[LogLevels] = - attributeList.collectFirst { case l: LogLevels ⇒ l } - - private[akka] def transform(node: StageModule): StageModule = - if ((this eq Attributes.none) || (this eq node.attributes)) node - else node.withAttributes(attributes = this and node.attributes) - } /** diff --git a/akka-stream/src/main/scala/akka/stream/impl/ActorMaterializerImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/ActorMaterializerImpl.scala index 980bf20ebf..6362def624 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/ActorMaterializerImpl.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/ActorMaterializerImpl.scala @@ -10,7 +10,7 @@ import akka.dispatch.Dispatchers import akka.pattern.ask import akka.stream.actor.ActorSubscriber import akka.stream.impl.StreamLayout.Module -import akka.stream.impl.fusing.{ ActorGraphInterpreter, GraphModule, ActorInterpreter } +import akka.stream.impl.fusing.{ ActorGraphInterpreter, GraphModule } import akka.stream.impl.io.SslTlsCipherActor import akka.stream._ import akka.stream.io.SslTls.TlsModule @@ -24,14 +24,13 @@ import scala.concurrent.{ Await, ExecutionContextExecutor } /** * INTERNAL API */ -private[akka] case class ActorMaterializerImpl(val system: ActorSystem, +private[akka] case class ActorMaterializerImpl(system: ActorSystem, override val settings: ActorMaterializerSettings, dispatchers: Dispatchers, - val supervisor: ActorRef, - val haveShutDown: AtomicBoolean, + supervisor: ActorRef, + haveShutDown: AtomicBoolean, flowNameCounter: AtomicLong, - namePrefix: String, - optimizations: Optimizations) extends ActorMaterializer { + namePrefix: String) extends ActorMaterializer { import akka.stream.impl.Stages._ override def shutdown(): Unit = @@ -45,6 +44,12 @@ private[akka] case class ActorMaterializerImpl(val system: ActorSystem, private[this] def createFlowName(): String = s"$namePrefix-${nextFlowNameCount()}" + private val initialAttributes = Attributes( + Attributes.InputBuffer(settings.initialInputBufferSize, settings.maxInputBufferSize) :: + ActorAttributes.Dispatcher(settings.dispatcher) :: + ActorAttributes.SupervisionStrategy(settings.supervisionDecider) :: + Nil) + override def effectiveSettings(opAttr: Attributes): ActorMaterializerSettings = { import Attributes._ import ActorAttributes._ @@ -70,7 +75,7 @@ private[akka] case class ActorMaterializerImpl(val system: ActorSystem, throw new IllegalStateException("Attempted to call materialize() after the ActorMaterializer has been shut down.") if (StreamLayout.Debug) StreamLayout.validate(runnableGraph.module) - val session = new MaterializerSession(runnableGraph.module) { + 
val session = new MaterializerSession(runnableGraph.module, initialAttributes) { private val flowName = createFlowName() private var nextId = 0 private def stageName(attr: Attributes): String = { @@ -93,6 +98,7 @@ private[akka] case class ActorMaterializerImpl(val system: ActorSystem, assignPort(source.shape.outlet, pub.asInstanceOf[Publisher[Any]]) mat + // FIXME: Remove this, only stream-of-stream ops need it case stage: StageModule ⇒ val (processor, mat) = processorFor(stage, effectiveAttributes, effectiveSettings(effectiveAttributes)) assignPort(stage.inPort, processor) @@ -118,7 +124,7 @@ private[akka] case class ActorMaterializerImpl(val system: ActorSystem, case graph: GraphModule ⇒ val calculatedSettings = effectiveSettings(effectiveAttributes) - val (inHandlers, outHandlers, logics, mat) = graph.assembly.materialize() + val (inHandlers, outHandlers, logics, mat) = graph.assembly.materialize(effectiveAttributes) val props = ActorGraphInterpreter.props( graph.assembly, inHandlers, outHandlers, logics, graph.shape, calculatedSettings, ActorMaterializerImpl.this) @@ -137,11 +143,11 @@ private[akka] case class ActorMaterializerImpl(val system: ActorSystem, } } + // FIXME: Remove this, only stream-of-stream ops need it private def processorFor(op: StageModule, effectiveAttributes: Attributes, effectiveSettings: ActorMaterializerSettings): (Processor[Any, Any], Any) = op match { case DirectProcessor(processorFactory, _) ⇒ processorFactory() - case Identity(attr) ⇒ (new VirtualProcessor, ()) case _ ⇒ val (opprops, mat) = ActorProcessorFactory.props(ActorMaterializerImpl.this, op, effectiveAttributes) ActorProcessorFactory[Any, Any]( @@ -248,47 +254,12 @@ private[akka] object ActorProcessorFactory { // USE THIS TO AVOID CLOSING OVER THE MATERIALIZER BELOW // Also, otherwise the attributes will not affect the settings properly! 
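    // (Illustrative note) effectiveSettings folds attribute values such as InputBuffer and
    // Dispatcher into the base settings, so per-stage attributes override the
    // materializer-wide defaults seeded by initialAttributes above.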
val settings = materializer.effectiveSettings(att) - def interp(s: Stage[_, _]): (Props, Unit) = (ActorInterpreter.props(settings, List(s), materializer, att), ()) - def interpAttr(s: Stage[_, _], newAttributes: Attributes): (Props, Unit) = (ActorInterpreter.props(settings, List(s), materializer, newAttributes), ()) - def inputSizeAttr(n: Long) = { - if (n <= 0) - inputBuffer(initial = 1, max = 1) and att - else if (n <= materializer.settings.maxInputBufferSize) - inputBuffer(initial = n.toInt, max = n.toInt) and att - else - att - } op match { - case Map(f, _) ⇒ interp(fusing.Map(f, settings.supervisionDecider)) - case Filter(p, _) ⇒ interp(fusing.Filter(p, settings.supervisionDecider)) - case Drop(n, _) ⇒ interp(fusing.Drop(n)) - case Take(n, _) ⇒ interpAttr(fusing.Take(n), inputSizeAttr(n)) - case TakeWhile(p, _) ⇒ interp(fusing.TakeWhile(p, settings.supervisionDecider)) - case DropWhile(p, _) ⇒ interp(fusing.DropWhile(p, settings.supervisionDecider)) - case Collect(pf, _) ⇒ interp(fusing.Collect(pf, settings.supervisionDecider)) - case Scan(z, f, _) ⇒ interp(fusing.Scan(z, f, settings.supervisionDecider)) - case Fold(z, f, _) ⇒ interp(fusing.Fold(z, f, settings.supervisionDecider)) - case Intersperse(s, i, e, _) ⇒ interp(fusing.Intersperse(s, i, e)) - case Recover(pf, _) ⇒ interp(fusing.Recover(pf)) - case Expand(s, f, _) ⇒ interp(fusing.Expand(s, f)) - case Conflate(s, f, _) ⇒ interp(fusing.Conflate(s, f, settings.supervisionDecider)) - case Buffer(n, s, _) ⇒ interp(fusing.Buffer(n, s)) - case MapConcat(f, _) ⇒ interp(fusing.MapConcat(f, settings.supervisionDecider)) - case MapAsync(p, f, _) ⇒ interp(fusing.MapAsync(p, f, settings.supervisionDecider)) - case MapAsyncUnordered(p, f, _) ⇒ interp(fusing.MapAsyncUnordered(p, f, settings.supervisionDecider)) - case Grouped(n, _) ⇒ interp(fusing.Grouped(n)) - case Sliding(n, step, _) ⇒ interp(fusing.Sliding(n, step)) - case Log(n, e, l, _) ⇒ interp(fusing.Log(n, e, l)) - case GroupBy(f, _) ⇒ (GroupByProcessorImpl.props(settings, f), ()) - case PrefixAndTail(n, _) ⇒ (PrefixAndTailImpl.props(settings, n), ()) - case Split(d, _) ⇒ (SplitWhereProcessorImpl.props(settings, d), ()) - case ConcatAll(_) ⇒ (ConcatAllImpl.props(materializer), ()) - case StageFactory(mkStage, _) ⇒ interp(mkStage()) - case MaterializingStageFactory(mkStageAndMat, _) ⇒ - val s_m = mkStageAndMat() - (ActorInterpreter.props(settings, List(s_m._1), materializer, att), s_m._2) + case GroupBy(f, _) ⇒ (GroupByProcessorImpl.props(settings, f), ()) + case PrefixAndTail(n, _) ⇒ (PrefixAndTailImpl.props(settings, n), ()) + case Split(d, _) ⇒ (SplitWhereProcessorImpl.props(settings, d), ()) + case ConcatAll(_) ⇒ (ConcatAllImpl.props(materializer), ()) case DirectProcessor(p, m) ⇒ throw new AssertionError("DirectProcessor cannot end up in ActorProcessorFactory") - case Identity(_) ⇒ throw new AssertionError("Identity cannot end up in ActorProcessorFactory") } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/FixedSizeBuffer.scala b/akka-stream/src/main/scala/akka/stream/impl/FixedSizeBuffer.scala index 773e41fc0d..c46b2db8d0 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/FixedSizeBuffer.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/FixedSizeBuffer.scala @@ -26,6 +26,7 @@ private[akka] object FixedSizeBuffer { else new ModuloFixedSizeBuffer(size) sealed abstract class FixedSizeBuffer[T](val size: Int) { + override def toString = s"Buffer($size, $readIdx, $writeIdx)(${(readIdx until writeIdx).map(get).mkString(", ")})" private val buffer = new 
Array[AnyRef](size) protected var readIdx = 0 diff --git a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala index 485c87109d..2640e4d954 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala @@ -4,11 +4,14 @@ package akka.stream.impl import akka.event.LoggingAdapter +import akka.stream.ActorAttributes.SupervisionStrategy +import akka.stream.Supervision.Decider import akka.stream.impl.SplitDecision.SplitDecision import akka.stream.impl.StreamLayout._ -import akka.stream.{ OverflowStrategy, Attributes } +import akka.stream._ import akka.stream.Attributes._ -import akka.stream.stage.Stage +import akka.stream.stage.AbstractStage.PushPullGraphStage +import akka.stream.stage.{ GraphStageLogic, GraphStage, Stage } import org.reactivestreams.Processor import scala.collection.immutable import scala.concurrent.Future @@ -96,112 +99,121 @@ private[stream] object Stages { import DefaultAttributes._ + // FIXME: To be deprecated as soon as stream-of-stream operations are stages sealed trait StageModule extends FlowModule[Any, Any, Any] { - def attributes: Attributes def withAttributes(attributes: Attributes): StageModule override def carbonCopy: Module = withAttributes(attributes) } - final case class StageFactory(mkStage: () ⇒ Stage[_, _], attributes: Attributes = stageFactory) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + /* + * Stage that is backed by a GraphStage but can be symbolically introspected + */ + case class SymbolicGraphStage[-In, +Out, Ext](symbolicStage: SymbolicStage[In, Out]) + extends PushPullGraphStage[In, Out, Ext]( + symbolicStage.create, + symbolicStage.attributes) { } - final case class MaterializingStageFactory( - mkStageAndMaterialized: () ⇒ (Stage[_, _], Any), - attributes: Attributes = stageFactory) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + val identityGraph = SymbolicGraphStage[Any, Any, Any](Identity) + + sealed trait SymbolicStage[-In, +Out] { + def attributes: Attributes + def create(effectiveAttributes: Attributes): Stage[In, Out] + + // FIXME: No supervision hooked in yet. 
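+    // (Illustrative) the helper below reads the decider from the effective attributes:
+    // materializing under
+    //   .withAttributes(ActorAttributes.supervisionStrategy(Supervision.resumingDecider))
+    // makes supervision(attr) return the resuming decider; with no SupervisionStrategy
+    // attribute present it falls back to Supervision.stoppingDecider.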
+ + protected def supervision(attributes: Attributes): Decider = + attributes.get[SupervisionStrategy](SupervisionStrategy(Supervision.stoppingDecider)).decider + } - final case class Identity(attributes: Attributes = identityOp) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + object Identity extends SymbolicStage[Any, Any] { + override val attributes: Attributes = identityOp + + def apply[T]: SymbolicStage[T, T] = this.asInstanceOf[SymbolicStage[T, T]] + + override def create(attr: Attributes): Stage[Any, Any] = fusing.Map(identity, supervision(attr)) } - final case class Map(f: Any ⇒ Any, attributes: Attributes = map) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + final case class Map[In, Out](f: In ⇒ Out, attributes: Attributes = map) extends SymbolicStage[In, Out] { + override def create(attr: Attributes): Stage[In, Out] = fusing.Map(f, supervision(attr)) } - final case class Log(name: String, extract: Any ⇒ Any, loggingAdapter: Option[LoggingAdapter], attributes: Attributes = log) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + final case class Log[T](name: String, extract: T ⇒ Any, loggingAdapter: Option[LoggingAdapter], attributes: Attributes = log) extends SymbolicStage[T, T] { + override def create(attr: Attributes): Stage[T, T] = fusing.Log(name, extract, loggingAdapter) } - final case class Filter(p: Any ⇒ Boolean, attributes: Attributes = filter) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + final case class Filter[T](p: T ⇒ Boolean, attributes: Attributes = filter) extends SymbolicStage[T, T] { + override def create(attr: Attributes): Stage[T, T] = fusing.Filter(p, supervision(attr)) } - final case class Collect(pf: PartialFunction[Any, Any], attributes: Attributes = collect) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + final case class Collect[In, Out](pf: PartialFunction[In, Out], attributes: Attributes = collect) extends SymbolicStage[In, Out] { + override def create(attr: Attributes): Stage[In, Out] = fusing.Collect(pf, supervision(attr)) } - final case class Recover(pf: PartialFunction[Any, Any], attributes: Attributes = recover) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + final case class Recover[In, Out >: In](pf: PartialFunction[Throwable, Out], attributes: Attributes = recover) extends SymbolicStage[In, Out] { + override def create(attr: Attributes): Stage[In, Out] = fusing.Recover(pf) } - final case class MapAsync(parallelism: Int, f: Any ⇒ Future[Any], attributes: Attributes = mapAsync) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) - } - - final case class MapAsyncUnordered(parallelism: Int, f: Any ⇒ Future[Any], attributes: Attributes = mapAsyncUnordered) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) - } - - final case class Grouped(n: Int, attributes: Attributes = grouped) extends StageModule { + final case class Grouped[T](n: Int, attributes: Attributes = grouped) extends SymbolicStage[T, immutable.Seq[T]] { require(n > 0, "n must be greater than 0") - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + override def create(attr: 
Attributes): Stage[T, immutable.Seq[T]] = fusing.Grouped(n) } - final case class Sliding(n: Int, step: Int, attributes: Attributes = sliding) extends StageModule { + final case class Sliding[T](n: Int, step: Int, attributes: Attributes = sliding) extends SymbolicStage[T, immutable.Seq[T]] { require(n > 0, "n must be greater than 0") require(step > 0, "step must be greater than 0") - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + + override def create(attr: Attributes): Stage[T, immutable.Seq[T]] = fusing.Sliding(n, step) } - final case class Take(n: Long, attributes: Attributes = take) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + final case class Take[T](n: Long, attributes: Attributes = take) extends SymbolicStage[T, T] { + override def create(attr: Attributes): Stage[T, T] = fusing.Take(n) } - final case class Drop(n: Long, attributes: Attributes = drop) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + final case class Drop[T](n: Long, attributes: Attributes = drop) extends SymbolicStage[T, T] { + override def create(attr: Attributes): Stage[T, T] = fusing.Drop(n) } - final case class TakeWhile(p: Any ⇒ Boolean, attributes: Attributes = takeWhile) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) - + final case class TakeWhile[T](p: T ⇒ Boolean, attributes: Attributes = takeWhile) extends SymbolicStage[T, T] { + override def create(attr: Attributes): Stage[T, T] = fusing.TakeWhile(p, supervision(attr)) } - final case class DropWhile(p: Any ⇒ Boolean, attributes: Attributes = dropWhile) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + final case class DropWhile[T](p: T ⇒ Boolean, attributes: Attributes = dropWhile) extends SymbolicStage[T, T] { + override def create(attr: Attributes): Stage[T, T] = fusing.DropWhile(p, supervision(attr)) } - final case class Scan(zero: Any, f: (Any, Any) ⇒ Any, attributes: Attributes = scan) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + final case class Scan[In, Out](zero: Out, f: (Out, In) ⇒ Out, attributes: Attributes = scan) extends SymbolicStage[In, Out] { + override def create(attr: Attributes): Stage[In, Out] = fusing.Scan(zero, f, supervision(attr)) } - final case class Fold(zero: Any, f: (Any, Any) ⇒ Any, attributes: Attributes = fold) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + final case class Intersperse[T](start: Option[T], inject: T, end: Option[T], attributes: Attributes = intersperse) extends SymbolicStage[T, T] { + override def create(attr: Attributes) = fusing.Intersperse(start, inject, end) } - final case class Intersperse(start: Option[Any], inject: Any, end: Option[Any], attributes: Attributes = intersperse) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + final case class Fold[In, Out](zero: Out, f: (Out, In) ⇒ Out, attributes: Attributes = fold) extends SymbolicStage[In, Out] { + override def create(attr: Attributes): Stage[In, Out] = fusing.Fold(zero, f, supervision(attr)) } - final case class Buffer(size: Int, overflowStrategy: OverflowStrategy, attributes: Attributes = buffer) extends StageModule { + final case class Buffer[T](size: Int, overflowStrategy: 
OverflowStrategy, attributes: Attributes = buffer) extends SymbolicStage[T, T] { require(size > 0, s"Buffer size must be larger than zero but was [$size]") - - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + override def create(attr: Attributes): Stage[T, T] = fusing.Buffer(size, overflowStrategy) } - final case class Conflate(seed: Any ⇒ Any, aggregate: (Any, Any) ⇒ Any, attributes: Attributes = conflate) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + final case class Conflate[In, Out](seed: In ⇒ Out, aggregate: (Out, In) ⇒ Out, attributes: Attributes = conflate) extends SymbolicStage[In, Out] { + override def create(attr: Attributes): Stage[In, Out] = fusing.Conflate(seed, aggregate, supervision(attr)) } - final case class Expand(seed: Any ⇒ Any, extrapolate: Any ⇒ (Any, Any), attributes: Attributes = expand) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + final case class Expand[In, Out, Seed](seed: In ⇒ Seed, extrapolate: Seed ⇒ (Out, Seed), attributes: Attributes = expand) extends SymbolicStage[In, Out] { + override def create(attr: Attributes): Stage[In, Out] = fusing.Expand(seed, extrapolate) } - final case class MapConcat(f: Any ⇒ immutable.Iterable[Any], attributes: Attributes = mapConcat) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + final case class MapConcat[In, Out](f: In ⇒ immutable.Iterable[Out], attributes: Attributes = mapConcat) extends SymbolicStage[In, Out] { + override def create(attr: Attributes): Stage[In, Out] = fusing.MapConcat(f, supervision(attr)) } + // FIXME: These are not yet proper stages, therefore they use the deprecated StageModule infrastructure + final case class GroupBy(f: Any ⇒ Any, attributes: Attributes = groupBy) extends StageModule { override def withAttributes(attributes: Attributes) = copy(attributes = attributes) } diff --git a/akka-stream/src/main/scala/akka/stream/impl/StreamLayout.scala b/akka-stream/src/main/scala/akka/stream/impl/StreamLayout.scala index d84da3721c..674f4a309e 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/StreamLayout.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/StreamLayout.scala @@ -184,7 +184,7 @@ private[akka] object StreamLayout { downstreams, upstreams, Transform(f, if (this.isSealed) Atomic(this) else this.materializedValueComputation), - attributes) + if (this.isSealed) Attributes.none else attributes) } /** @@ -225,7 +225,7 @@ private[akka] object StreamLayout { upstreams ++ that.upstreams, // would like to optimize away this allocation for Keep.{left,right} but that breaks side-effecting transformations Combine(f.asInstanceOf[(Any, Any) ⇒ Any], matComputation1, matComputation2), - attributes) + Attributes.none) } /** @@ -595,7 +595,7 @@ private[stream] object MaterializerSession { /** * INTERNAL API */ -private[stream] abstract class MaterializerSession(val topLevel: StreamLayout.Module) { +private[stream] abstract class MaterializerSession(val topLevel: StreamLayout.Module, val initialAttributes: Attributes) { import StreamLayout._ private var subscribersStack: List[mutable.Map[InPort, Subscriber[Any]]] = @@ -653,7 +653,7 @@ private[stream] abstract class MaterializerSession(val topLevel: StreamLayout.Mo require( topLevel.isRunnable, s"The top level module cannot be materialized because it has unconnected ports: ${(topLevel.inPorts ++ topLevel.outPorts).mkString(", 
")}") - try materializeModule(topLevel, topLevel.attributes) + try materializeModule(topLevel, initialAttributes and topLevel.attributes) catch { case NonFatal(cause) ⇒ // PANIC!!! THE END OF THE MATERIALIZATION IS NEAR! diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala index ad08907bd5..abe0df2276 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala @@ -339,40 +339,41 @@ private[stream] class ActorGraphInterpreter( override def receive: Receive = { // Cases that are most likely on the hot path, in decreasing order of frequency case OnNext(id: Int, e: Any) ⇒ - if (GraphInterpreter.Debug) println(s" onNext $e id=$id") + if (GraphInterpreter.Debug) println(s"${interpreter.Name} onNext $e id=$id") inputs(id).onNext(e) runBatch() case RequestMore(id: Int, demand: Long) ⇒ - if (GraphInterpreter.Debug) println(s" request $demand id=$id") + if (GraphInterpreter.Debug) println(s"${interpreter.Name} request $demand id=$id") outputs(id).requestMore(demand) runBatch() case Resume ⇒ resumeScheduled = false if (interpreter.isSuspended) runBatch() case AsyncInput(logic, event, handler) ⇒ - if (GraphInterpreter.Debug) println(s"ASYNC $event") - if (!interpreter.isStageCompleted(logic.stageId)) { + if (GraphInterpreter.Debug) println(s"${interpreter.Name} ASYNC $event ($handler) [$logic]") + if (!interpreter.isStageCompleted(logic)) { try handler(event) catch { case NonFatal(e) ⇒ logic.failStage(e) } + interpreter.afterStageHasRun(logic) } runBatch() // Initialization and completion messages case OnError(id: Int, cause: Throwable) ⇒ - if (GraphInterpreter.Debug) println(s" onError id=$id") + if (GraphInterpreter.Debug) println(s"${interpreter.Name} onError id=$id") inputs(id).onError(cause) runBatch() case OnComplete(id: Int) ⇒ - if (GraphInterpreter.Debug) println(s" onComplete id=$id") + if (GraphInterpreter.Debug) println(s"${interpreter.Name} onComplete id=$id") inputs(id).onComplete() runBatch() case OnSubscribe(id: Int, subscription: Subscription) ⇒ subscribesPending -= 1 inputs(id).onSubscribe(subscription) case Cancel(id: Int) ⇒ - if (GraphInterpreter.Debug) println(s" cancel id=$id") + if (GraphInterpreter.Debug) println(s"${interpreter.Name} cancel id=$id") outputs(id).cancel() runBatch() case SubscribePending(id: Int) ⇒ diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorInterpreter.scala deleted file mode 100644 index c9bb2373e7..0000000000 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorInterpreter.scala +++ /dev/null @@ -1,389 +0,0 @@ -/** - * Copyright (C) 2009-2014 Typesafe Inc. 
- */ -package akka.stream.impl.fusing - -import akka.actor._ -import akka.stream.impl.ReactiveStreamsCompliance._ -import akka.stream.{ AbruptTerminationException, ActorMaterializerSettings, Attributes, ActorMaterializer } -import akka.stream.actor.ActorSubscriber.OnSubscribe -import akka.stream.actor.ActorSubscriberMessage.{ OnNext, OnError, OnComplete } -import akka.stream.impl._ -import akka.stream.impl.fusing.OneBoundedInterpreter.{ InitializationFailed, InitializationSuccessful } -import akka.stream.stage._ -import org.reactivestreams.{ Subscriber, Subscription } -import akka.event.{ Logging, LoggingAdapter } - -/** - * INTERNAL API - */ -private[akka] class BatchingActorInputBoundary(val size: Int, val name: String) - extends BoundaryStage { - - require(size > 0, "buffer size cannot be zero") - require((size & (size - 1)) == 0, "buffer size must be a power of two") - - // TODO: buffer and batch sizing heuristics - private var upstream: Subscription = _ - private val inputBuffer = Array.ofDim[AnyRef](size) - private var inputBufferElements = 0 - private var nextInputElementCursor = 0 - private var upstreamCompleted = false - private var downstreamWaiting = false - private var downstreamCanceled = false - private val IndexMask = size - 1 - - private def requestBatchSize = math.max(1, inputBuffer.length / 2) - private var batchRemaining = requestBatchSize - - val subreceive: SubReceive = new SubReceive(waitingForUpstream) - - def isFinished = upstreamCompleted && ((upstream ne null) || downstreamCanceled) - - def setDownstreamCanceled(): Unit = downstreamCanceled = true - - private def dequeue(): Any = { - val elem = inputBuffer(nextInputElementCursor) - require(elem ne null) - inputBuffer(nextInputElementCursor) = null - - batchRemaining -= 1 - if (batchRemaining == 0 && !upstreamCompleted) { - tryRequest(upstream, requestBatchSize) - batchRemaining = requestBatchSize - } - - inputBufferElements -= 1 - nextInputElementCursor = (nextInputElementCursor + 1) & IndexMask - elem - } - - private def enqueue(elem: Any): Unit = { - if (OneBoundedInterpreter.Debug) println(f" enq $elem%-19s $name") - if (!upstreamCompleted) { - if (inputBufferElements == size) throw new IllegalStateException("Input buffer overrun") - inputBuffer((nextInputElementCursor + inputBufferElements) & IndexMask) = elem.asInstanceOf[AnyRef] - inputBufferElements += 1 - } - } - - override def onPush(elem: Any, ctx: BoundaryContext): Directive = - throw new UnsupportedOperationException("BUG: Cannot push the upstream boundary") - - override def onPull(ctx: BoundaryContext): Directive = { - if (inputBufferElements > 1) ctx.push(dequeue()) - else if (inputBufferElements == 1) { - if (upstreamCompleted) ctx.pushAndFinish(dequeue()) - else ctx.push(dequeue()) - } else if (upstreamCompleted) { - ctx.finish() - } else { - downstreamWaiting = true - ctx.exit() - } - } - - override def onDownstreamFinish(ctx: BoundaryContext): TerminationDirective = { - cancel() - ctx.finish() - } - - def cancel(): Unit = { - if (!upstreamCompleted) { - upstreamCompleted = true - if (upstream ne null) tryCancel(upstream) - downstreamWaiting = false - clear() - } - } - - private def clear(): Unit = { - java.util.Arrays.fill(inputBuffer, 0, inputBuffer.length, null) - inputBufferElements = 0 - } - - private def onComplete(): Unit = - if (!upstreamCompleted) { - upstreamCompleted = true - // onUpstreamFinish is not back-pressured, stages need to deal with this - if (inputBufferElements == 0) enterAndFinish() - } - - private def 
onSubscribe(subscription: Subscription): Unit = { - require(subscription != null) - if (upstreamCompleted) - tryCancel(subscription) - else if (downstreamCanceled) { - upstreamCompleted = true - tryCancel(subscription) - } else { - upstream = subscription - // Prefetch - tryRequest(upstream, inputBuffer.length) - subreceive.become(upstreamRunning) - } - } - - // Call this when an error happens that does not come from the usual onError channel - // (exceptions while calling RS interfaces, abrupt termination etc) - def onInternalError(e: Throwable): Unit = { - if (!(upstreamCompleted || downstreamCanceled) && (upstream ne null)) { - upstream.cancel() - } - onError(e) - } - - def onError(e: Throwable): Unit = { - if (!upstreamCompleted) { - upstreamCompleted = true - enterAndFail(e) - } - } - - private def waitingForUpstream: Actor.Receive = { - case OnComplete ⇒ onComplete() - case OnSubscribe(subscription) ⇒ onSubscribe(subscription) - case OnError(cause) ⇒ onError(cause) - } - - private def upstreamRunning: Actor.Receive = { - case OnNext(element) ⇒ - enqueue(element) - if (downstreamWaiting) { - downstreamWaiting = false - enterAndPush(dequeue()) - } - - case OnComplete ⇒ onComplete() - case OnError(cause) ⇒ onError(cause) - case OnSubscribe(subscription) ⇒ tryCancel(subscription) // spec rule 2.5 - } - -} - -private[akka] object ActorOutputBoundary { - /** - * INTERNAL API. - */ - private case object ContinuePulling extends DeadLetterSuppression with NoSerializationVerificationNeeded -} - -/** - * INTERNAL API - */ -private[akka] class ActorOutputBoundary(val actor: ActorRef, - val debugLogging: Boolean, - val log: LoggingAdapter, - val outputBurstLimit: Int) - extends BoundaryStage { - import ReactiveStreamsCompliance._ - import ActorOutputBoundary._ - - private var exposedPublisher: ActorPublisher[Any] = _ - - private var subscriber: Subscriber[Any] = _ - private var downstreamDemand: Long = 0L - // This flag is only used if complete/fail is called externally since this op turns into a Finished one inside the - // interpreter (i.e. inside this op this flag has no effects since if it is completed the op will not be invoked) - private var downstreamCompleted = false - // this is true while we “hold the ball”; while “false” incoming demand will just be queued up - private var upstreamWaiting = true - // when upstream failed before we got the exposed publisher - private var upstreamFailed: Option[Throwable] = None - // the number of elements emitted during a single execution is bounded - private var burstRemaining = outputBurstLimit - - private def tryBounceBall(ctx: BoundaryContext) = { - burstRemaining -= 1 - if (burstRemaining > 0) ctx.pull() - else { - actor ! 
ContinuePulling - takeBallOut(ctx) - } - } - - private def takeBallOut(ctx: BoundaryContext) = { - upstreamWaiting = true - ctx.exit() - } - - private def tryPutBallIn() = - if (upstreamWaiting) { - burstRemaining = outputBurstLimit - upstreamWaiting = false - enterAndPull() - } - - val subreceive = new SubReceive(waitingExposedPublisher) - - private def onNext(elem: Any): Unit = { - downstreamDemand -= 1 - tryOnNext(subscriber, elem) - } - - private def complete(): Unit = { - if (!downstreamCompleted) { - downstreamCompleted = true - if (exposedPublisher ne null) exposedPublisher.shutdown(None) - if (subscriber ne null) tryOnComplete(subscriber) - } - } - - def fail(e: Throwable): Unit = { - if (!downstreamCompleted) { - downstreamCompleted = true - if (debugLogging) - log.debug("fail due to: {}", e.getMessage) - if (exposedPublisher ne null) exposedPublisher.shutdown(Some(e)) - if ((subscriber ne null) && !e.isInstanceOf[SpecViolation]) tryOnError(subscriber, e) - } else if (exposedPublisher == null && upstreamFailed.isEmpty) { - // fail called before the exposed publisher arrived, we must store it and fail when we're first able to - upstreamFailed = Some(e) - } - } - - override def onPush(elem: Any, ctx: BoundaryContext): Directive = { - onNext(elem) - if (downstreamCompleted) ctx.finish() - else if (downstreamDemand > 0) tryBounceBall(ctx) - else takeBallOut(ctx) - } - - override def onPull(ctx: BoundaryContext): Directive = - throw new UnsupportedOperationException("BUG: Cannot pull the downstream boundary") - - override def onUpstreamFinish(ctx: BoundaryContext): TerminationDirective = { - complete() - ctx.finish() - } - - override def onUpstreamFailure(cause: Throwable, ctx: BoundaryContext): TerminationDirective = { - fail(cause) - ctx.fail(cause) - } - - private def subscribePending(subscribers: Seq[Subscriber[Any]]): Unit = - subscribers foreach { sub ⇒ - if (subscriber eq null) { - subscriber = sub - tryOnSubscribe(subscriber, new ActorSubscription(actor, subscriber)) - } else - rejectAdditionalSubscriber(subscriber, s"${Logging.simpleName(this)}") - } - - protected def waitingExposedPublisher: Actor.Receive = { - case ExposedPublisher(publisher) ⇒ - upstreamFailed match { - case _: Some[_] ⇒ - publisher.shutdown(upstreamFailed) - case _ ⇒ - exposedPublisher = publisher - subreceive.become(downstreamRunning) - } - case other ⇒ - throw new IllegalStateException(s"The first message must be ExposedPublisher but was [$other]") - } - - protected def downstreamRunning: Actor.Receive = { - case SubscribePending ⇒ - subscribePending(exposedPublisher.takePendingSubscribers()) - case RequestMore(subscription, elements) ⇒ - if (elements < 1) { - enterAndFinish() - fail(ReactiveStreamsCompliance.numberOfElementsInRequestMustBePositiveException) - } else { - downstreamDemand += elements - if (downstreamDemand < 0) - downstreamDemand = Long.MaxValue // Long overflow, Reactive Streams Spec 3:17: effectively unbounded - if (OneBoundedInterpreter.Debug) { - val s = s"$downstreamDemand (+$elements)" - println(f" dem $s%-19s ${actor.path}") - } - tryPutBallIn() - } - - case ContinuePulling ⇒ - if (!downstreamCompleted && downstreamDemand > 0) tryPutBallIn() - - case Cancel(subscription) ⇒ - downstreamCompleted = true - subscriber = null - exposedPublisher.shutdown(Some(new ActorPublisher.NormalShutdownException)) - enterAndFinish() - } - -} - -/** - * INTERNAL API - */ -private[akka] object ActorInterpreter { - def props(settings: ActorMaterializerSettings, ops: Seq[Stage[_, _]], materializer: 
ActorMaterializer, attributes: Attributes = Attributes.none): Props = - Props(new ActorInterpreter(settings, ops, materializer, attributes)).withDeploy(Deploy.local) - - case class AsyncInput(op: AsyncStage[Any, Any, Any], ctx: AsyncContext[Any, Any], event: Any) extends DeadLetterSuppression with NoSerializationVerificationNeeded -} - -/** - * INTERNAL API - */ -private[akka] class ActorInterpreter(val settings: ActorMaterializerSettings, val ops: Seq[Stage[_, _]], val materializer: ActorMaterializer, val attributes: Attributes) - extends Actor with ActorLogging { - import ActorInterpreter._ - - private val upstream = new BatchingActorInputBoundary(settings.initialInputBufferSize, context.self.path.toString) - private val downstream = new ActorOutputBoundary(self, settings.debugLogging, log, settings.outputBurstLimit) - private val interpreter = - new OneBoundedInterpreter(upstream +: ops :+ downstream, - (op, ctx, event) ⇒ self ! AsyncInput(op, ctx, event), - Logging(this), - materializer, - attributes, - name = context.self.path.toString) - - interpreter.init() match { - case failed: InitializationFailed ⇒ - // the Actor will be stopped thanks to aroundReceive checking interpreter.isFinished - upstream.setDownstreamCanceled() - downstream.fail(failed.mostDownstream.ex) - case InitializationSuccessful ⇒ // ok - } - - def receive: Receive = - upstream.subreceive - .orElse[Any, Unit](downstream.subreceive) - .orElse[Any, Unit] { - case AsyncInput(op, ctx, event) ⇒ - ctx.enter() - op.onAsyncInput(event, ctx) - ctx.execute() - } - - override protected[akka] def aroundReceive(receive: Actor.Receive, msg: Any): Unit = { - super.aroundReceive(receive, msg) - - if (interpreter.isFinished) { - if (upstream.isFinished) context.stop(self) - else upstream.setDownstreamCanceled() - } - } - - override def postStop(): Unit = { - // This should handle termination while interpreter is running. If the upstream have been closed already this - // call has no effect and therefore do the right thing: nothing. - val ex = AbruptTerminationException(self) - try upstream.onInternalError(ex) - // Will only have an effect if the above call to the interpreter failed to emit a proper failure to the downstream - // otherwise this will have no effect - finally { - downstream.fail(ex) - upstream.cancel() - } - } - - override def postRestart(reason: Throwable): Unit = { - super.postRestart(reason) - throw new IllegalStateException("This actor cannot be restarted", reason) - } - -} diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala index c3a420bdb9..f03efe3c4a 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala @@ -7,9 +7,9 @@ import java.util.Arrays import akka.event.LoggingAdapter import akka.stream.stage._ -import akka.stream.{ Materializer, Shape, Inlet, Outlet } import scala.annotation.tailrec import scala.collection.immutable +import akka.stream._ import scala.util.control.NonFatal /** @@ -55,12 +55,17 @@ private[stream] object GraphInterpreter { def in: Inlet[T] } + val singleNoAttribute: Array[Attributes] = Array(Attributes.none) + /** * INTERNAL API * * A GraphAssembly represents a small stream processing graph to be executed by the interpreter. Instances of this * class **must not** be mutated after construction. 
 *
+ * The array ``originalAttributes`` may contain the attribute information of the original atomic module; otherwise
+ * it must contain ``Attributes.none`` (or the enclosing module could not override attributes defined in this array).
+ *
 * The arrays [[ins]] and [[outs]] correspond to the notion of a *connection* in the [[GraphInterpreter]]. Each slot
 * *i* contains the input and output port corresponding to connection *i*. Slots where the graph is not closed (i.e.
 * ports are exposed to the external world) are marked with *null* values. For example if an input port *p* is
@@ -88,6 +93,7 @@ private[stream] object GraphInterpreter {
   *
   */
  final class GraphAssembly(val stages: Array[GraphStageWithMaterializedValue[Shape, Any]],
+                           val originalAttributes: Array[Attributes],
                            val ins: Array[Inlet[_]],
                            val inOwners: Array[Int],
                            val outs: Array[Outlet[_]],
@@ -106,7 +112,7 @@ private[stream] object GraphInterpreter {
     * - array of the logics
     * - materialized value
     */
-    def materialize(): (Array[InHandler], Array[OutHandler], Array[GraphStageLogic], Any) = {
+    def materialize(inheritedAttributes: Attributes): (Array[InHandler], Array[OutHandler], Array[GraphStageLogic], Any) = {
      val logics = Array.ofDim[GraphStageLogic](stages.length)
      var finalMat: Any = ()
@@ -134,7 +140,7 @@ private[stream] object GraphInterpreter {
       }

       // FIXME: Support for materialized values in fused islands is not yet figured out!
-      val logicAndMat = stages(i).createLogicAndMaterializedValue
+      val logicAndMat = stages(i).createLogicAndMaterializedValue(inheritedAttributes and originalAttributes(i))
       // FIXME: Current temporary hack to support non-fused stages. If there is one stage that will be under index 0.
       if (i == 0) finalMat = logicAndMat._2
@@ -148,20 +154,21 @@ private[stream] object GraphInterpreter {
      i = 0
      while (i < connectionCount) {
        if (ins(i) ne null) {
-         val l = logics(inOwners(i))
-         l.inHandlers(ins(i).id) match {
-           case null ⇒ throw new IllegalStateException(s"no handler defined in stage $l for port ${ins(i)}")
-           case h ⇒ inHandlers(i) = h
+         val logic = logics(inOwners(i))
+         logic.handlers(ins(i).id) match {
+           case null ⇒ throw new IllegalStateException(s"no handler defined in stage $logic for port ${ins(i)}")
+           case h: InHandler ⇒ inHandlers(i) = h
          }
-         l.inToConn(ins(i).id) = i
+         logics(inOwners(i)).portToConn(ins(i).id) = i
        }
        if (outs(i) ne null) {
-         val l = logics(outOwners(i))
-         l.outHandlers(outs(i).id) match {
-           case null ⇒ throw new IllegalStateException(s"no handler defined in stage $l for port ${outs(i)}")
-           case h ⇒ outHandlers(i) = h
+         val logic = logics(outOwners(i))
+         val inCount = logic.inCount
+         logic.handlers(outs(i).id + inCount) match {
+           case null ⇒ throw new IllegalStateException(s"no handler defined in stage $logic for port ${outs(i)}")
+           case h: OutHandler ⇒ outHandlers(i) = h
          }
-         l.outToConn(outs(i).id) = i
+         logic.portToConn(outs(i).id + inCount) = i
        }
        i += 1
      }
@@ -206,6 +213,7 @@ private[stream] object GraphInterpreter {
     val assembly = new GraphAssembly(
       stages.toArray,
+      GraphInterpreter.singleNoAttribute,
       add(inlets.iterator, Array.ofDim(connectionCount), 0),
       markBoundary(Array.ofDim(connectionCount), inletsSize, connectionCount),
       add(outlets.iterator, Array.ofDim(connectionCount), inletsSize),
@@ -288,7 +296,7 @@ private[stream] object GraphInterpreter {
 *
 * Because of the FIFO construction of the queue the interpreter is fair, i.e. a pending event is always executed
 * after a bounded number of other events.
This property, together with suspendability means that even infinite cycles can - * be modeled, or even dissolved (if preempted and a "stealing" external even is injected; for example the non-cycle + * be modeled, or even dissolved (if preempted and a "stealing" external event is injected; for example the non-cycle * edge of a balance is pulled, dissolving the original cycle). */ private[stream] final class GraphInterpreter( @@ -309,7 +317,7 @@ private[stream] final class GraphInterpreter( // of the class for a full description. val portStates = Array.fill[Int](assembly.connectionCount)(InReady) - private[this] var activeStageId = Boundary + private[this] var activeStage: GraphStageLogic = _ // The number of currently running stages. Once this counter reaches zero, the interpreter is considered to be // completed @@ -323,19 +331,33 @@ private[stream] final class GraphInterpreter( // An event queue implemented as a circular buffer // FIXME: This calculates the maximum size ever needed, but most assemblies can run on a smaller queue - private[this] val eventQueue = Array.ofDim[Int](1 << Integer.highestOneBit(assembly.connectionCount)) + private[this] val eventQueue = Array.ofDim[Int](1 << (32 - Integer.numberOfLeadingZeros(assembly.connectionCount - 1))) private[this] val mask = eventQueue.length - 1 private[this] var queueHead: Int = 0 private[this] var queueTail: Int = 0 + private def queueStatus: String = { + val contents = (queueHead until queueTail).map(idx ⇒ { + val conn = eventQueue(idx & mask) + (conn, portStates(conn), connectionSlots(conn)) + }) + s"(${eventQueue.length}, $queueHead, $queueTail)(${contents.mkString(", ")})" + } + private[this] var _Name: String = _ + def Name: String = + if (_Name eq null) { + _Name = f"${System.identityHashCode(this)}%08X" + _Name + } else _Name + /** * Assign the boundary logic to a given connection. This will serve as the interface to the external world * (outside the interpreter) to process and inject events. */ def attachUpstreamBoundary(connection: Int, logic: UpstreamBoundaryStageLogic[_]): Unit = { - logic.outToConn(logic.out.id) = connection + logic.portToConn(logic.out.id + logic.inCount) = connection logic.interpreter = this - outHandlers(connection) = logic.outHandlers(0) + outHandlers(connection) = logic.handlers(0).asInstanceOf[OutHandler] } /** @@ -343,16 +365,16 @@ private[stream] final class GraphInterpreter( * (outside the interpreter) to process and inject events. */ def attachDownstreamBoundary(connection: Int, logic: DownstreamBoundaryStageLogic[_]): Unit = { - logic.inToConn(logic.in.id) = connection + logic.portToConn(logic.in.id) = connection logic.interpreter = this - inHandlers(connection) = logic.inHandlers(0) + inHandlers(connection) = logic.handlers(0).asInstanceOf[InHandler] } /** * Dynamic handler changes are communicated from a GraphStageLogic by this method. */ def setHandler(connection: Int, handler: InHandler): Unit = { - if (GraphInterpreter.Debug) println(s"SETHANDLER ${inOwnerName(connection)} (in) $handler") + if (Debug) println(s"$Name SETHANDLER ${inOwnerName(connection)} (in) $handler") inHandlers(connection) = handler } @@ -360,7 +382,7 @@ private[stream] final class GraphInterpreter( * Dynamic handler changes are communicated from a GraphStageLogic by this method. 
*/ def setHandler(connection: Int, handler: OutHandler): Unit = { - if (GraphInterpreter.Debug) println(s"SETHANDLER ${outOwnerName(connection)} (out) $handler") + if (Debug) println(s"$Name SETHANDLER ${outOwnerName(connection)} (out) $handler") outHandlers(connection) = handler } @@ -389,6 +411,7 @@ private[stream] final class GraphInterpreter( } catch { case NonFatal(e) ⇒ logic.failStage(e) } + afterStageHasRun(logic) i += 1 } } @@ -399,7 +422,8 @@ private[stream] final class GraphInterpreter( def finish(): Unit = { var i = 0 while (i < logics.length) { - if (!isStageCompleted(i)) finalizeStage(logics(i)) + val logic = logics(i) + if (!isStageCompleted(logic)) finalizeStage(logic) i += 1 } } @@ -418,84 +442,97 @@ private[stream] final class GraphInterpreter( case owner ⇒ assembly.stages(owner).toString } + // Debug name for a connection's input part + private def inLogicName(connection: Int): String = + assembly.inOwners(connection) match { + case Boundary ⇒ "DownstreamBoundary" + case owner ⇒ logics(owner).toString + } + + // Debug name for a connection's output part + private def outLogicName(connection: Int): String = + assembly.outOwners(connection) match { + case Boundary ⇒ "UpstreamBoundary" + case owner ⇒ logics(owner).toString + } + /** * Executes pending events until the given limit is met. If there were remaining events, isSuspended will return * true. */ def execute(eventLimit: Int): Unit = { - if (GraphInterpreter.Debug) println("---------------- EXECUTE") + if (Debug) println(s"$Name ---------------- EXECUTE (running=$runningStages, shutdown=${shutdownCounter.mkString(",")})") var eventsRemaining = eventLimit var connection = dequeue() while (eventsRemaining > 0 && connection != NoEvent) { try processEvent(connection) catch { case NonFatal(e) ⇒ - if (activeStageId == Boundary) throw e - else logics(activeStageId).failStage(e) + if (activeStage == null) throw e + else activeStage.failStage(e) } + afterStageHasRun(activeStage) eventsRemaining -= 1 if (eventsRemaining > 0) connection = dequeue() } + if (Debug) println(s"$Name ---------------- $queueStatus (running=$runningStages, shutdown=${shutdownCounter.mkString(",")})") // TODO: deadlock detection } // Decodes and processes a single event for the given connection private def processEvent(connection: Int): Unit = { + def safeLogics(id: Int) = + if (id == Boundary) null + else logics(id) - def processElement(elem: Any): Unit = { - if (GraphInterpreter.Debug) println(s"PUSH ${outOwnerName(connection)} -> ${inOwnerName(connection)}, $elem (${inHandlers(connection)})") - activeStageId = assembly.inOwners(connection) + def processElement(): Unit = { + if (Debug) println(s"$Name PUSH ${outOwnerName(connection)} -> ${inOwnerName(connection)}, ${connectionSlots(connection)} (${inHandlers(connection)}) [${inLogicName(connection)}]") + activeStage = safeLogics(assembly.inOwners(connection)) portStates(connection) ^= PushEndFlip inHandlers(connection).onPush() } + // this must be the state after returning without delivering any signals, to avoid double-finalization of some unlucky stage + // (this can happen if a stage completes voluntarily while connection close events are still queued) + activeStage = null val code = portStates(connection) // Manual fast decoding, fast paths are PUSH and PULL // PUSH if ((code & (Pushing | InClosed | OutClosed)) == Pushing) { - processElement(connectionSlots(connection)) + processElement() // PULL } else if ((code & (Pulling | OutClosed | InClosed)) == Pulling) { - if (GraphInterpreter.Debug)
println(s"PULL ${inOwnerName(connection)} -> ${outOwnerName(connection)} (${outHandlers(connection)})") + if (Debug) println(s"$Name PULL ${inOwnerName(connection)} -> ${outOwnerName(connection)} (${outHandlers(connection)}) [${outLogicName(connection)}]") portStates(connection) ^= PullEndFlip - activeStageId = assembly.outOwners(connection) + activeStage = safeLogics(assembly.outOwners(connection)) outHandlers(connection).onPull() // CANCEL } else if ((code & (OutClosed | InClosed)) == InClosed) { val stageId = assembly.outOwners(connection) - if (GraphInterpreter.Debug) println(s"CANCEL ${inOwnerName(connection)} -> ${outOwnerName(connection)} (${outHandlers(connection)})") + activeStage = safeLogics(stageId) + if (Debug) println(s"$Name CANCEL ${inOwnerName(connection)} -> ${outOwnerName(connection)} (${outHandlers(connection)}) [${outLogicName(connection)}]") portStates(connection) |= OutClosed - activeStageId = assembly.outOwners(connection) - outHandlers(connection).onDownstreamFinish() completeConnection(stageId) + outHandlers(connection).onDownstreamFinish() } else if ((code & (OutClosed | InClosed)) == OutClosed) { // COMPLETIONS - val stageId = assembly.inOwners(connection) - if ((code & Pushing) == 0) { // Normal completion (no push pending) - if (GraphInterpreter.Debug) println(s"COMPLETE ${outOwnerName(connection)} -> ${inOwnerName(connection)} (${inHandlers(connection)})") + if (Debug) println(s"$Name COMPLETE ${outOwnerName(connection)} -> ${inOwnerName(connection)} (${inHandlers(connection)}) [${inLogicName(connection)}]") portStates(connection) |= InClosed - activeStageId = assembly.inOwners(connection) + val stageId = assembly.inOwners(connection) + activeStage = safeLogics(stageId) + completeConnection(stageId) if ((portStates(connection) & InFailed) == 0) inHandlers(connection).onUpstreamFinish() else inHandlers(connection).onUpstreamFailure(connectionSlots(connection).asInstanceOf[Failed].ex) - completeConnection(stageId) } else { // Push is pending, first process push, then re-enqueue closing event - // Non-failure case - val code = portStates(connection) & (InClosed | InFailed) - if (code == 0) { - processElement(connectionSlots(connection)) - enqueue(connection) - } else if (code == InFailed) { - // Failure case - processElement(connectionSlots(connection).asInstanceOf[Failed].previousElem) - enqueue(connection) - } + processElement() + enqueue(connection) } } @@ -513,26 +550,26 @@ private[stream] final class GraphInterpreter( } private def enqueue(connection: Int): Unit = { + if (Debug) if (queueTail - queueHead > mask) new Exception(s"$Name internal queue full ($queueStatus) + $connection").printStackTrace() eventQueue(queueTail & mask) = connection queueTail += 1 } + def afterStageHasRun(logic: GraphStageLogic): Unit = + if (isStageCompleted(logic)) { + runningStages -= 1 + finalizeStage(logic) + } + // Returns true if the given stage is alredy completed - def isStageCompleted(stageId: Int): Boolean = stageId != Boundary && shutdownCounter(stageId) == 0 + def isStageCompleted(stage: GraphStageLogic): Boolean = stage != null && shutdownCounter(stage.stageId) == 0 // Register that a connection in which the given stage participated has been completed and therefore the stage // itself might stop, too. 
private def completeConnection(stageId: Int): Unit = { if (stageId != Boundary) { val activeConnections = shutdownCounter(stageId) - if (activeConnections > 0) { - shutdownCounter(stageId) = activeConnections - 1 - // This was the last active connection keeping this stage alive - if (activeConnections == 1) { - runningStages -= 1 - finalizeStage(logics(stageId)) - } - } + if (activeConnections > 0) shutdownCounter(stageId) = activeConnections - 1 } } @@ -564,35 +601,32 @@ private[stream] final class GraphInterpreter( private[stream] def complete(connection: Int): Unit = { val currentState = portStates(connection) + if (Debug) println(s"$Name complete($connection) [$currentState]") portStates(connection) = currentState | OutClosed - - if ((currentState & (InClosed | Pushing)) == 0) { - enqueue(connection) - } - - completeConnection(assembly.outOwners(connection)) + if ((currentState & (InClosed | Pushing | Pulling)) == 0) enqueue(connection) + if ((currentState & OutClosed) == 0) completeConnection(assembly.outOwners(connection)) } private[stream] def fail(connection: Int, ex: Throwable): Unit = { val currentState = portStates(connection) + if (Debug) println(s"$Name fail($connection, $ex) [$currentState]") portStates(connection) = currentState | (OutClosed | InFailed) if ((currentState & InClosed) == 0) { connectionSlots(connection) = Failed(ex, connectionSlots(connection)) - enqueue(connection) + if ((currentState & (Pulling | Pushing)) == 0) enqueue(connection) } - - completeConnection(assembly.outOwners(connection)) + if ((currentState & OutClosed) == 0) completeConnection(assembly.outOwners(connection)) } private[stream] def cancel(connection: Int): Unit = { val currentState = portStates(connection) + if (Debug) println(s"$Name cancel($connection) [$currentState]") portStates(connection) = currentState | InClosed if ((currentState & OutClosed) == 0) { connectionSlots(connection) = Empty - enqueue(connection) + if ((currentState & (Pulling | Pushing)) == 0) enqueue(connection) } - - completeConnection(assembly.inOwners(connection)) + if ((currentState & InClosed) == 0) completeConnection(assembly.inOwners(connection)) } -} \ No newline at end of file +} diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala index 6370e2a6a2..412105e165 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala @@ -24,21 +24,18 @@ object GraphStages { val in = Inlet[T]("in") val out = Outlet[T]("out") override val shape = FlowShape(in, out) - - protected abstract class SimpleLinearStageLogic extends GraphStageLogic(shape) { - setHandler(out, new OutHandler { - override def onPull(): Unit = pull(in) - }) - } - } class Identity[T] extends SimpleLinearGraphStage[T] { - override def createLogic: GraphStageLogic = new SimpleLinearStageLogic() { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { setHandler(in, new InHandler { override def onPush(): Unit = push(out, grab(in)) }) + + setHandler(out, new OutHandler { + override def onPull(): Unit = pull(in) + }) } override def toString = "Identity" @@ -49,7 +46,7 @@ object GraphStages { val out = Outlet[T]("out") override val shape = FlowShape(in, out) - override def createLogic: GraphStageLogic = new GraphStageLogic(shape) { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new 
GraphStageLogic(shape) { var initialized = false setHandler(in, new InHandler { @@ -99,13 +96,13 @@ object GraphStages { val out = Outlet[T]("TimerSource.out") override val shape = SourceShape(out) - override def createLogicAndMaterializedValue: (GraphStageLogic, Cancellable) = { + override def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, Cancellable) = { import TickSource._ val cancelled = new AtomicBoolean(false) val cancellable = new TickSourceCancellable(cancelled) - val logic = new GraphStageLogic(shape) { + val logic = new TimerGraphStageLogic(shape) { override def preStart() = { schedulePeriodicallyWithInitialDelay("TickTimer", initialDelay, interval) val callback = getAsyncCallback[Unit]((_) ⇒ { diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/IteratorInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/IteratorInterpreter.scala index d73eb01231..10f59bab37 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/IteratorInterpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/IteratorInterpreter.scala @@ -3,64 +3,71 @@ */ package akka.stream.impl.fusing -import akka.event.NoLogging +import akka.event.{ Logging, NoLogging } import akka.stream._ +import akka.stream.impl.fusing.GraphInterpreter.{ GraphAssembly, DownstreamBoundaryStageLogic, UpstreamBoundaryStageLogic } +import akka.stream.stage.AbstractStage.PushPullGraphStage import akka.stream.stage._ /** * INTERNAL API */ private[akka] object IteratorInterpreter { - final case class IteratorUpstream[T](input: Iterator[T]) extends PushPullStage[T, T] { + + final case class IteratorUpstream[T](input: Iterator[T]) extends UpstreamBoundaryStageLogic[T] { + val out: Outlet[T] = Outlet[T]("IteratorUpstream.out") + out.id = 0 + private var hasNext = input.hasNext - override def onPush(elem: T, ctx: Context[T]): SyncDirective = - throw new UnsupportedOperationException("IteratorUpstream operates as a source, it cannot be pushed") - - override def onPull(ctx: Context[T]): SyncDirective = { - if (!hasNext) ctx.finish() - else { - val elem = input.next() - hasNext = input.hasNext - if (!hasNext) ctx.pushAndFinish(elem) - else ctx.push(elem) + setHandler(out, new OutHandler { + override def onPull(): Unit = { + if (!hasNext) completeStage() + else { + val elem = input.next() + hasNext = input.hasNext + if (!hasNext) { + push(out, elem) + complete(out) + } else push(out, elem) + } } + }) - } - - // don't let toString consume the iterator - override def toString: String = "IteratorUpstream" + override def toString = "IteratorUpstream" } - final case class IteratorDownstream[T]() extends BoundaryStage with Iterator[T] { + final case class IteratorDownstream[T]() extends DownstreamBoundaryStageLogic[T] with Iterator[T] { + val in: Inlet[T] = Inlet[T]("IteratorDownstream.in") + in.id = 0 + private var done = false private var nextElem: T = _ private var needsPull = true private var lastFailure: Throwable = null - override def onPush(elem: Any, ctx: BoundaryContext): Directive = { - nextElem = elem.asInstanceOf[T] - needsPull = false - ctx.exit() - } + setHandler(in, new InHandler { + override def onPush(): Unit = { + nextElem = grab(in) + needsPull = false + } - override def onPull(ctx: BoundaryContext): Directive = - throw new UnsupportedOperationException("IteratorDownstream operates as a sink, it cannot be pulled") + override def onUpstreamFinish(): Unit = { + done = true + completeStage() + } - override def onUpstreamFinish(ctx: BoundaryContext): 
TerminationDirective = { - done = true - ctx.finish() - } - - override def onUpstreamFailure(cause: Throwable, ctx: BoundaryContext): TerminationDirective = { - done = true - lastFailure = cause - ctx.finish() - } + override def onUpstreamFailure(cause: Throwable): Unit = { + done = true + lastFailure = cause + completeStage() + } + }) private def pullIfNeeded(): Unit = { if (needsPull) { - enterAndPull() // will eventually result in a finish, or an onPush which exits + pull(in) + interpreter.execute(Int.MaxValue) } } @@ -84,8 +91,8 @@ private[akka] object IteratorInterpreter { // don't let toString consume the iterator override def toString: String = "IteratorDownstream" - } + } /** @@ -96,11 +103,52 @@ private[akka] class IteratorInterpreter[I, O](val input: Iterator[I], val ops: S private val upstream = IteratorUpstream(input) private val downstream = IteratorDownstream[O]() - private val interpreter = new OneBoundedInterpreter(upstream +: ops.asInstanceOf[Seq[Stage[_, _]]] :+ downstream, - (op, ctx, evt) ⇒ throw new UnsupportedOperationException("IteratorInterpreter is fully synchronous"), - NoLogging, - NoMaterializer) - interpreter.init() + + private def init(): Unit = { + import GraphInterpreter.Boundary + + var i = 0 + val length = ops.length + val attributes = Array.fill[Attributes](ops.length)(Attributes.none) + val ins = Array.ofDim[Inlet[_]](length + 1) + val inOwners = Array.ofDim[Int](length + 1) + val outs = Array.ofDim[Outlet[_]](length + 1) + val outOwners = Array.ofDim[Int](length + 1) + val stages = Array.ofDim[GraphStageWithMaterializedValue[Shape, Any]](length) + + ins(ops.length) = null + inOwners(ops.length) = Boundary + outs(0) = null + outOwners(0) = Boundary + + val opsIterator = ops.iterator + while (opsIterator.hasNext) { + val op = opsIterator.next().asInstanceOf[Stage[Any, Any]] + val stage = new PushPullGraphStage((_) ⇒ op, Attributes.none) + stages(i) = stage + ins(i) = stage.shape.inlet + inOwners(i) = i + outs(i + 1) = stage.shape.outlet + outOwners(i + 1) = i + i += 1 + } + val assembly = new GraphAssembly(stages, attributes, ins, inOwners, outs, outOwners) + + val (inHandlers, outHandlers, logics, _) = assembly.materialize(Attributes.none) + val interpreter = new GraphInterpreter( + assembly, + NoMaterializer, + NoLogging, + inHandlers, + outHandlers, + logics, + (_, _, _) ⇒ throw new UnsupportedOperationException("IteratorInterpreter does not support asynchronous events.")) + interpreter.attachUpstreamBoundary(0, upstream) + interpreter.attachDownstreamBoundary(ops.length, downstream) + interpreter.init() + } + + init() def iterator: Iterator[O] = downstream } diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/OneBoundedInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/OneBoundedInterpreter.scala deleted file mode 100644 index b8551d20fb..0000000000 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/OneBoundedInterpreter.scala +++ /dev/null @@ -1,747 +0,0 @@ -/** - * Copyright (C) 2009-2014 Typesafe Inc. 
- */ -package akka.stream.impl.fusing - -import akka.event.LoggingAdapter -import akka.stream.impl.ReactiveStreamsCompliance -import akka.stream.stage._ -import akka.stream.{ Materializer, Attributes, Supervision } - -import scala.annotation.{ switch, tailrec } -import scala.collection.{ breakOut, immutable } -import scala.util.control.NonFatal - -/** - * INTERNAL API - */ -private[akka] object OneBoundedInterpreter { - final val Debug = false - - /** INTERNAL API */ - private[akka] sealed trait InitializationStatus - /** INTERNAL API */ - private[akka] case object InitializationSuccessful extends InitializationStatus - /** INTERNAL API */ - private[akka] final case class InitializationFailed(failures: immutable.Seq[InitializationFailure]) extends InitializationStatus { - // exceptions are reverse ordered here, below methods help to avoid confusion when used from the outside - def mostUpstream = failures.last - def mostDownstream = failures.head - } - - /** INTERNAL API */ - private[akka] case class InitializationFailure(op: Int, ex: Throwable) - - /** - * INTERNAL API - * - * This artificial op is used as a boundary to prevent two forked paths of execution (complete, cancel) to cross - * paths again. When finishing an op this op is injected in its place to isolate upstream and downstream execution - * domains. - */ - private[akka] object Finished extends BoundaryStage { - override def onPush(elem: Any, ctx: BoundaryContext): UpstreamDirective = ctx.finish() - override def onPull(ctx: BoundaryContext): DownstreamDirective = ctx.finish() - override def onUpstreamFinish(ctx: BoundaryContext): TerminationDirective = ctx.exit() - override def onDownstreamFinish(ctx: BoundaryContext): TerminationDirective = ctx.exit() - override def onUpstreamFailure(cause: Throwable, ctx: BoundaryContext): TerminationDirective = ctx.exit() - } - - /** - * INTERNAL API - * - * This artificial op is used as a boundary to prevent the first forked onPush of execution of a pushFinish to enter - * the originating stage again. This stage allows the forked upstream onUpstreamFinish to pass through if there was - * no onPull called on the stage. Calling onPull on this op makes it a Finished op, which absorbs the - * onUpstreamTermination, but otherwise onUpstreamTermination results in calling finish() - */ - private[akka] object PushFinished extends BoundaryStage { - override def onPush(elem: Any, ctx: BoundaryContext): UpstreamDirective = ctx.finish() - override def onPull(ctx: BoundaryContext): DownstreamDirective = ctx.finish() - // This allows propagation of an onUpstreamFinish call. Note that if onPull has been called on this stage - // before, then the call ctx.finish() in onPull already turned this op to a normal Finished, i.e. it will no longer - // propagate onUpstreamFinish. 
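To restate the subtle difference documented above in runnable form, here is a plain-Scala model (assumed types; the real BoundaryStage ops are being deleted by this patch): Finished swallows a later onUpstreamFinish, while PushFinished forwards it once, unless an intervening onPull has downgraded it to the Finished behaviour:

sealed trait BarrierModel {
  def onUpstreamFinish(): Boolean // true = keep propagating the termination
}
case object FinishedModel extends BarrierModel {
  override def onUpstreamFinish(): Boolean = false // termination absorbed
}
final class PushFinishedModel extends BarrierModel {
  private var downgraded = false
  def onPull(): Unit = downgraded = true // models the ctx.finish() in onPull
  override def onUpstreamFinish(): Boolean = !downgraded
}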
- override def onUpstreamFinish(ctx: BoundaryContext): TerminationDirective = ctx.finish() - override def onDownstreamFinish(ctx: BoundaryContext): TerminationDirective = ctx.exit() - override def onUpstreamFailure(cause: Throwable, ctx: BoundaryContext): TerminationDirective = ctx.exit() - } -} - -/** - * INTERNAL API - * - * One-bounded interpreter for a linear chain of stream operations (graph support is possible and will be implemented - * later) - * - * The ideas in this interpreter are an amalgamation of earlier ideas, notably: - * - The original effect-tracking implementation by Johannes Rudolph -- the difference here that effects are not chained - * together as classes but the callstack is used instead and only certain combinations are allowed. - * - The on-stack reentrant implementation by Mathias Doenitz -- the difference here that reentrancy is handled by the - * interpreter itself, not user code, and the interpreter is able to use the heap when needed instead of the - * callstack. - * - The pinball interpreter by Endre Sándor Varga -- the difference here that the restriction for "one ball" is - * lifted by using isolated execution regions, completion handling is introduced and communication with the external - * world is done via boundary ops. - * - * The design goals/features of this interpreter are: - * - bounded callstack and heapless execution whenever possible - * - callstack usage should be constant for the most common ops independently of the size of the op-chain - * - allocation-free execution on the hot paths - * - enforced backpressure-safety (boundedness) on user defined ops at compile-time (and runtime in a few cases) - * - * The main driving idea of this interpreter is the concept of 1-bounded execution of well-formed free choice Petri - * nets (J. Desel and J. Esparza: Free Choice Petri Nets - https://www7.in.tum.de/~esparza/bookfc.html). Technically - * different kinds of operations partition the chain of ops into regions where *exactly one* event is active all the - * time. This "exactly one" property is enforced by proper types and runtime checks where needed. Currently there are - * three kinds of ops: - * - * - PushPullStage implementations participate in 1-bounded regions. For every external non-completion signal these - * ops produce *exactly one* signal (completion is different, explained later) therefore keeping the number of events - * the same: exactly one. - * - * - DetachedStage implementations are boundaries between 1-bounded regions. This means that they need to enforce the - * "exactly one" property both on their upstream and downstream regions. As a consequence a DetachedStage can never - * answer an onPull with a ctx.pull() or answer an onPush() with a ctx.push() since such an action would "steal" - * the event from one region (resulting in zero signals) and would inject it to the other region (resulting in two - * signals). However DetachedStages have the ability to call ctx.hold() as a response to onPush/onPull which temporarily - * takes the signal off and stops execution, at the same time putting the op in a "holding" state. If the op is in a - * holding state it contains one absorbed signal, therefore in this state the only possible command to call is - * ctx.pushAndPull() which results in two events making the balance right again: - * 1 hold + 1 external event = 2 external event - * This mechanism allows synchronization between the upstream and downstream regions which otherwise can progress - * independently. 
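The balance equation stated above (1 hold + 1 external event = 2 external events) can be checked with a tiny standalone model (an assumption-level sketch, not the removed DetachedStage API): the stage may park exactly one signal, and the only legal way to release it emits two signals, one into each neighbouring region:

final class DetachedModel {
  private var holding = false // at most one parked signal

  def hold(): Int = { // consumes the incoming signal, emits nothing now
    require(!holding, "already holding"); holding = true; 0
  }

  def pushAndPull(): Int = { // 1 incoming + 1 parked signal = 2 emitted
    require(holding, "nothing held"); holding = false; 2
  }
}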
- * - * - BoundaryStage implementations are meant to communicate with the external world. These ops do not have most of the - * safety properties enforced and should be used carefully. One important ability of BoundaryStages that they can take - * off an execution signal by calling ctx.exit(). This is typically used immediately after an external signal has - * been produced (for example an actor message). BoundaryStages can also kickstart execution by calling enter() which - * returns a context they can use to inject signals into the interpreter. There is no checks in place to enforce that - * the number of signals taken out by exit() and the number of signals returned via enter() are the same -- using this - * op type needs extra care from the implementer. - * BoundaryStages are the elements that make the interpreter *tick*, there is no other way to start the interpreter - * than using a BoundaryStage. - * - * Operations are allowed to do early completion and cancel/complete their upstreams and downstreams. It is *not* - * allowed however to do these independently to avoid isolated execution islands. The only call possible is ctx.finish() - * which is a combination of cancel/complete. - * Since onComplete is not a backpressured signal it is sometimes preferable to push a final element and then immediately - * finish. This combination is exposed as pushAndFinish() which enables op writers to propagate completion events without - * waiting for an extra round of pull. - * Another peculiarity is how to convert termination events (complete/failure) into elements. The problem - * here is that the termination events are not backpressured while elements are. This means that simply calling ctx.push() - * as a response to onUpstreamFinished() will very likely break boundedness and result in a buffer overflow somewhere. - * Therefore the only allowed command in this case is ctx.absorbTermination() which stops the propagation of the - * termination signal, and puts the op in a finishing state. Depending on whether the op has a pending pull signal it has - * not yet "consumed" by a push its onPull() handler might be called immediately. - * - * In order to execute different individual execution regions the interpreter uses the callstack to schedule these. The - * current execution forking operations are - * - ctx.finish() which starts a wave of completion and cancellation in two directions. When an op calls finish() - * it is immediately replaced by an artificial Finished op which makes sure that the two execution paths are isolated - * forever. - * - ctx.fail() which is similar to finish() - * - ctx.pushAndPull() which (as a response to a previous ctx.hold()) starts a wave of downstream push and upstream - * pull. The two execution paths are isolated by the op itself since onPull() from downstream can only be answered by hold or - * push, while onPush() from upstream can only answered by hold or pull -- it is impossible to "cross" the op. - * - ctx.pushAndFinish() which is different from the forking ops above because the execution of push and finish happens on - * the same execution region and they are order dependent, too. - * The interpreter tracks the depth of recursive forking and allows various strategies of dealing with the situation - * when this depth reaches a certain limit. In the simplest case a failure is reported (this is very useful for stress - * testing and finding callstack wasting bugs), in the other case the forked call is scheduled via a list -- i.e. 
instead - * of the stack the heap is used. - */ -private[akka] class OneBoundedInterpreter(ops: Seq[Stage[_, _]], - onAsyncInput: (AsyncStage[Any, Any, Any], AsyncContext[Any, Any], Any) ⇒ Unit, - log: LoggingAdapter, - materializer: Materializer, - attributes: Attributes = Attributes.none, - val forkLimit: Int = 100, - val overflowToHeap: Boolean = true, - val name: String = "") { - import AbstractStage._ - import OneBoundedInterpreter._ - - type UntypedOp = AbstractStage[Any, Any, Directive, Directive, Context[Any], LifecycleContext] - require(ops.nonEmpty, "OneBoundedInterpreter cannot be created without at least one Op") - - private final val pipeline: Array[UntypedOp] = ops.map(_.asInstanceOf[UntypedOp])(breakOut) - - /** - * This table is used to accelerate demand propagation upstream. All ops that implement PushStage are guaranteed - * to only do upstream propagation of demand signals, therefore it is not necessary to execute them but enough to - * "jump over" them. This means that when a chain of one million maps gets a downstream demand it is propagated - * to the upstream *in one step* instead of one million onPull() calls. - * This table maintains the positions where execution should jump from a current position when a pull event is to - * be executed. - */ - private final val jumpBacks: Array[Int] = calculateJumpBacks - - private final val Upstream = 0 - private final val Downstream = pipeline.length - 1 - - // Var to hold the current element if pushing. The only reason why this var is needed is to avoid allocations and - // make it possible for the Pushing state to be an object - private var elementInFlight: Any = _ - // Points to the current point of execution inside the pipeline - private var activeOpIndex = -1 - // Points to the last point of exit - private var lastExitedIndex = Downstream - // The current interpreter state that decides what happens at the next round - private var state: State = _ - - // Counter that keeps track of the depth of recursive forked executions - private var forkCount = 0 - // List that is used as an auxiliary stack if fork recursion depth reaches forkLimit - private var overflowStack = List.empty[(Int, State, Any)] - - private var lastOpFailing: Int = -1 - - private def pipeName(op: UntypedOp): String = { - val o = op: AbstractStage[_, _, _, _, _, _] - (o match { - case Finished ⇒ "finished" - case _: BoundaryStage ⇒ "boundary" - case _: StatefulStage[_, _] ⇒ "stateful" - case _: PushStage[_, _] ⇒ "push" - case _: PushPullStage[_, _] ⇒ "pushpull" - case _: DetachedStage[_, _] ⇒ "detached" - case _ ⇒ "other" - }) + f"(${o.bits}%04X)" - } - override def toString = - s"""|OneBoundedInterpreter($name) - | pipeline = ${pipeline map pipeName mkString ":"} - | lastExit=$lastExitedIndex activeOp=$activeOpIndex state=$state elem=$elementInFlight forks=$forkCount""".stripMargin - - @inline private def currentOp: UntypedOp = pipeline(activeOpIndex) - - // see the jumpBacks variable for explanation - private def calculateJumpBacks: Array[Int] = { - val table = Array.ofDim[Int](pipeline.length) - var nextJumpBack = -1 - for (pos ← pipeline.indices) { - table(pos) = nextJumpBack - if (!pipeline(pos).isInstanceOf[PushStage[_, _]]) nextJumpBack = pos - } - table - } - - private def updateJumpBacks(lastNonCompletedIndex: Int): Unit = { - var pos = lastNonCompletedIndex - // For every jump that would jump over us we change them to jump into us - while (pos < pipeline.length && jumpBacks(pos) < lastNonCompletedIndex) { - jumpBacks(pos) = lastNonCompletedIndex - 
pos += 1 - } - } - - private sealed trait State extends DetachedContext[Any] with BoundaryContext with AsyncContext[Any, Any] { - def enter(): Unit = throw new IllegalStateException("cannot enter an ordinary Context") - - final def execute(): Unit = OneBoundedInterpreter.this.execute() - - final def progress(): Unit = { - advance() - if (inside) run() - else exit() - } - - /** - * Override this method to do execution steps necessary after executing an op, and advance the activeOpIndex - * to another value (next or previous steps). Do NOT put code that invokes the next op, override run instead. - */ - def advance(): Unit - - /** - * Override this method to enter the current op and execute it. Do NOT put code that should be executed after the - * op has been invoked, that should be in the advance() method of the next state resulting from the invocation of - * the op. - */ - def run(): Unit - - /** - * This method shall return the bit set representing the incoming ball (if any). - */ - def incomingBall: Int - - protected def hasBits(b: Int): Boolean = ((currentOp.bits | incomingBall) & b) == b - protected def addBits(b: Int): Unit = currentOp.bits |= b - protected def removeBits(b: Int): Unit = currentOp.bits &= ~b - - protected def mustHave(b: Int): Unit = - if (!hasBits(b)) { - def format(b: Int) = { - val ballStatus = (b & BothBalls: @switch) match { - case 0 ⇒ "no balls" - case UpstreamBall ⇒ "upstream ball" - case DownstreamBall ⇒ "downstream ball" - case BothBalls ⇒ "upstream & downstream balls" - } - if ((b & NoTerminationPending) > 0) ballStatus + " and not isFinishing" - else ballStatus + " and isFinishing" - } - throw new IllegalStateException(s"operation requires [${format(b)}] while holding [${format(currentOp.bits)}] and receiving [${format(incomingBall)}]") - } - - override def push(elem: Any): DownstreamDirective = { - ReactiveStreamsCompliance.requireNonNullElement(elem) - if (currentOp.isDetached) { - if (incomingBall == UpstreamBall) - throw new IllegalStateException("Cannot push during onPush, only pull, pushAndPull or holdUpstreamAndPush") - mustHave(DownstreamBall) - } - removeBits(PrecedingWasPull | DownstreamBall) - elementInFlight = elem - state = Pushing - null - } - - override def pull(): UpstreamDirective = { - var requirements = NoTerminationPending - if (currentOp.isDetached) { - if (incomingBall == DownstreamBall) - throw new IllegalStateException("Cannot pull during onPull, only push, pushAndPull or holdDownstreamAndPull") - requirements |= UpstreamBall - } - mustHave(requirements) - removeBits(UpstreamBall) - addBits(PrecedingWasPull) - state = Pulling - null - } - - override def getAsyncCallback: AsyncCallback[Any] = { - val current = currentOp.asInstanceOf[AsyncStage[Any, Any, Any]] - val context = current.context // avoid concurrent access (to avoid @volatile) - new AsyncCallback[Any] { - override def invoke(evt: Any): Unit = onAsyncInput(current, context, evt) - } - } - - override def ignore(): AsyncDirective = { - if (incomingBall != 0) throw new IllegalStateException("Can only ignore from onAsyncInput") - exit() - } - - override def finish(): FreeDirective = { - finishCurrentOp() - fork(Completing) - state = Cancelling - null - } - - def isFinishing: Boolean = !hasBits(NoTerminationPending) - - final protected def pushAndFinishCommon(elem: Any, finishState: UntypedOp): Unit = { - finishCurrentOp(finishState) - ReactiveStreamsCompliance.requireNonNullElement(elem) - if (currentOp.isDetached) { - mustHave(DownstreamBall) - } - removeBits(DownstreamBall | 
PrecedingWasPull) - } - - override def pushAndFinish(elem: Any): DownstreamDirective = { - // Spit the execution domain in two and invoke postStop - pushAndFinishCommon(elem, Finished.asInstanceOf[UntypedOp]) - - // This MUST be an unsafeFork because the execution of PushFinish MUST strictly come before the finish execution - // path. Other forks are not order dependent because they execute on isolated execution domains which cannot - // "cross paths". This unsafeFork is relatively safe here because PushAndFinish simply absorbs all later downstream - // calls of pushAndFinish since the finish event has been scheduled already. - // It might be that there are some degenerate cases where this can blow up the stack with a very long chain but I - // am not aware of such scenario yet. If you know one, put it in InterpreterStressSpec :) - unsafeFork(PushFinish, elem) - - // Same as finish, without calling finishCurrentOp - elementInFlight = null - fork(Completing) - state = Cancelling - null - } - - override def fail(cause: Throwable): FreeDirective = { - fork(Failing(cause)) - state = Cancelling - null - } - - override def holdUpstream(): UpstreamDirective = { - mustHave(NoTerminationPending) - removeBits(PrecedingWasPull) - addBits(UpstreamBall) - exit() - } - - override def holdUpstreamAndPush(elem: Any): UpstreamDirective = { - ReactiveStreamsCompliance.requireNonNullElement(elem) - if (incomingBall != UpstreamBall) - throw new IllegalStateException("can only holdUpstreamAndPush from onPush") - mustHave(BothBallsAndNoTerminationPending) - removeBits(PrecedingWasPull | DownstreamBall) - addBits(UpstreamBall) - elementInFlight = elem - state = Pushing - null - } - - override def isHoldingUpstream: Boolean = (currentOp.bits & UpstreamBall) != 0 - - override def holdDownstream(): DownstreamDirective = { - addBits(DownstreamBall) - exit() - } - - override def holdDownstreamAndPull(): DownstreamDirective = { - if (incomingBall != DownstreamBall) - throw new IllegalStateException("can only holdDownstreamAndPull from onPull") - mustHave(BothBallsAndNoTerminationPending) - addBits(PrecedingWasPull | DownstreamBall) - removeBits(UpstreamBall) - state = Pulling - null - } - - override def isHoldingDownstream: Boolean = (currentOp.bits & DownstreamBall) != 0 - - override def pushAndPull(elem: Any): FreeDirective = { - ReactiveStreamsCompliance.requireNonNullElement(elem) - mustHave(BothBallsAndNoTerminationPending) - addBits(PrecedingWasPull) - removeBits(BothBalls) - fork(Pushing, elem) - state = Pulling - null - } - - override def absorbTermination(): TerminationDirective = { - updateJumpBacks(activeOpIndex) - removeBits(BothBallsAndNoTerminationPending) - finish() - } - - override def exit(): FreeDirective = { - elementInFlight = null - lastExitedIndex = activeOpIndex - activeOpIndex = -1 - null - } - - override def materializer: Materializer = OneBoundedInterpreter.this.materializer - override def attributes: Attributes = OneBoundedInterpreter.this.attributes - } - - private final val Pushing: State = new State { - override def advance(): Unit = activeOpIndex += 1 - override def run(): Unit = currentOp.onPush(elementInFlight, ctx = this) - override def incomingBall = UpstreamBall - override def toString = "Pushing" - } - - private final val PushFinish: State = new State { - override def advance(): Unit = activeOpIndex += 1 - override def run(): Unit = currentOp.onPush(elementInFlight, ctx = this) - - override def pushAndFinish(elem: Any): DownstreamDirective = { - // PushFinished - // Put an 
isolation barrier that will prevent the onPull of this op to be called again. This barrier - // is different from simple Finished that it allows onUpstreamTerminated to pass through, unless onPull - // has been called on the stage - pushAndFinishCommon(elem, PushFinished.asInstanceOf[UntypedOp]) - - elementInFlight = elem - state = PushFinish - null - } - - override def finish(): FreeDirective = { - state = Completing - null - } - - override def incomingBall = UpstreamBall - - override def toString = "PushFinish" - } - - private final val Pulling: State = new State { - override def advance(): Unit = { - elementInFlight = null - activeOpIndex = jumpBacks(activeOpIndex) - } - - override def run(): Unit = currentOp.onPull(ctx = this) - - override def incomingBall = DownstreamBall - - override def toString = "Pulling" - } - - private final val Completing: State = new State { - override def advance(): Unit = { - elementInFlight = null - finishCurrentOp() - activeOpIndex += 1 - } - - override def run(): Unit = { - if (!hasBits(NoTerminationPending)) exit() - else currentOp.onUpstreamFinish(ctx = this) - } - - override def finish(): FreeDirective = { - state = Completing - null - } - - override def absorbTermination(): TerminationDirective = { - removeBits(NoTerminationPending) - removeBits(UpstreamBall) - updateJumpBacks(activeOpIndex) - if (hasBits(DownstreamBall) || (!currentOp.isDetached && hasBits(PrecedingWasPull))) { - removeBits(DownstreamBall) - currentOp.onPull(ctx = Pulling) - } else exit() - null - } - - override def incomingBall = UpstreamBall - - override def toString = "Completing" - } - - private final val Cancelling: State = new State { - override def advance(): Unit = { - elementInFlight = null - finishCurrentOp() - activeOpIndex -= 1 - } - - def run(): Unit = { - if (!hasBits(NoTerminationPending)) exit() - else currentOp.onDownstreamFinish(ctx = this) - } - - override def finish(): FreeDirective = { - state = Cancelling - null - } - - override def incomingBall = DownstreamBall - - override def toString = "Cancelling" - - override def absorbTermination(): TerminationDirective = { - val ex = new UnsupportedOperationException("It is not allowed to call absorbTermination() from onDownstreamFinish.") - // This MUST be logged here, since the downstream has cancelled, i.e. 
there is noone to send onError to, the - // stage is just about to finish so noone will catch it anyway just the interpreter - log.error(ex.getMessage) - throw ex // We still throw for correctness (although a finish() would also work here) - } - - } - - private final case class Failing(cause: Throwable) extends State { - override def advance(): Unit = { - elementInFlight = null - finishCurrentOp() - activeOpIndex += 1 - } - - def run(): Unit = currentOp.onUpstreamFailure(cause, ctx = this) - - override def absorbTermination(): TerminationDirective = { - removeBits(NoTerminationPending) - removeBits(UpstreamBall) - updateJumpBacks(activeOpIndex) - if (hasBits(DownstreamBall) || (!currentOp.isDetached && hasBits(PrecedingWasPull))) { - removeBits(DownstreamBall) - currentOp.onPull(ctx = Pulling) - } else exit() - null - } - - override def incomingBall = UpstreamBall - } - - private def inside: Boolean = activeOpIndex > -1 && activeOpIndex < pipeline.length - - private def printDebug(): Unit = { - val padding = " " * activeOpIndex - val icon: String = state match { - case Pushing | PushFinish ⇒ padding + s"---> $elementInFlight" - case Pulling ⇒ - (" " * jumpBacks(activeOpIndex)) + - "<---" + - ("----" * (activeOpIndex - jumpBacks(activeOpIndex) - 1)) - case Completing ⇒ padding + "---|" - case Cancelling ⇒ padding + "|---" - case Failing(e) ⇒ padding + s"---X ${e.getMessage} => ${decide(e)}" - case other ⇒ padding + s"---? $state" - } - println(f"$icon%-24s $name") - } - - @tailrec private def execute(): Unit = { - while (inside) { - try { - if (Debug) printDebug() - state.progress() - } catch { - case NonFatal(e) if lastOpFailing != activeOpIndex ⇒ - lastOpFailing = activeOpIndex - decide(e) match { - case Supervision.Stop ⇒ state.fail(e) - case Supervision.Resume ⇒ - // reset, purpose of lastOpFailing is to avoid infinite loops when fail fails -- double fault - lastOpFailing = -1 - afterRecovery() - case Supervision.Restart ⇒ - // reset, purpose of lastOpFailing is to avoid infinite loops when fail fails -- double fault - lastOpFailing = -1 - pipeline(activeOpIndex) = pipeline(activeOpIndex).restart().asInstanceOf[UntypedOp] - afterRecovery() - } - } - } - - // FIXME push this into AbstractStage so it can be customized - def afterRecovery(): Unit = state match { - case _: EntryState ⇒ // no ball to be juggled with - case _ ⇒ state.pull() - } - - // Execute all delayed forks that were put on the heap if the fork limit has been reached - if (overflowStack.nonEmpty) { - val memo = overflowStack.head - activeOpIndex = memo._1 - state = memo._2 - elementInFlight = memo._3 - overflowStack = overflowStack.tail - execute() - } - } - - def decide(e: Throwable): Supervision.Directive = - if (state == Pulling || state == Cancelling) Supervision.Stop - else currentOp.decide(e) - - /** - * Forks off execution of the pipeline by saving current position, fully executing the effects of the given - * forkState then setting back the position to the saved value. - * By default forking is executed by using the callstack. If the depth of forking ever reaches the configured forkLimit - * this method either fails (useful for testing) or starts using the heap instead of the callstack to avoid a - * stack overflow. 
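The recursion-depth strategy described above can be sketched standalone (an assumed signature, mirroring the fork/unsafeFork pair just below): recurse on the call stack up to forkLimit, then either fail fast, which is handy for stress testing, or spill the pending fork onto a heap-allocated list and drain it iteratively once the stack has unwound:

def runForks(start: Int, children: Int => List[Int],
             forkLimit: Int = 100, overflowToHeap: Boolean = true): Unit = {
  var overflow = List.empty[Int] // the heap-based auxiliary stack
  def fork(state: Int, depth: Int): Unit =
    if (depth == forkLimit) {
      if (!overflowToHeap) throw new IllegalStateException("Fork limit reached")
      overflow ::= state // defer instead of deepening the call stack
    } else children(state).foreach(fork(_, depth + 1))
  fork(start, 0)
  while (overflow.nonEmpty) { // drain deferred forks with a fresh stack budget
    val next = overflow.head
    overflow = overflow.tail
    fork(next, 0)
  }
}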
- */ - private def fork(forkState: State, elem: Any = null): Unit = { - forkCount += 1 - if (forkCount == forkLimit) { - if (!overflowToHeap) throw new IllegalStateException("Fork limit reached") - else overflowStack ::= ((activeOpIndex, forkState, elem)) - } else unsafeFork(forkState, elem) - forkCount -= 1 - } - - /** - * Unsafe fork always uses the stack for execution. This call is needed by pushAndComplete where the forked execution - * is order dependent since the push and complete events travel in the same direction and not isolated by a boundary - */ - private def unsafeFork(forkState: State, elem: Any = null): Unit = { - val savePos = activeOpIndex - elementInFlight = elem - state = forkState - execute() - activeOpIndex = savePos - } - - /** - * Initializes all stages setting their initial context and calling [[AbstractStage.preStart]] on each. - */ - def init(): InitializationStatus = { - val failures = initBoundaries() - runDetached() - - if (failures.isEmpty) InitializationSuccessful - else { - val failure = failures.head - activeOpIndex = failure.op - currentOp.enterAndFail(failure.ex) - InitializationFailed(failures) - } - } - - def isFinished: Boolean = pipeline(Upstream) == Finished && pipeline(Downstream) == Finished - - private class EntryState(name: String, position: Int) extends State { - val entryPoint = position - - final override def enter(): Unit = { - activeOpIndex = entryPoint - if (Debug) { - val s = " " * entryPoint + "ENTR" - println(f"$s%-24s ${OneBoundedInterpreter.this.name}") - } - } - - override def run(): Unit = () - override def advance(): Unit = () - - override def incomingBall = 0 - - override def toString = s"$name($entryPoint)" - } - - /** - * This method injects a Context to each of the BoundaryStages and AsyncStages. This will be the context returned by enter(). - */ - private def initBoundaries(): List[InitializationFailure] = { - var failures: List[InitializationFailure] = Nil - var op = 0 - while (op < pipeline.length) { - (pipeline(op): Any) match { - case b: BoundaryStage ⇒ - b.context = new EntryState("boundary", op) - - case a: AsyncStage[Any, Any, Any] @unchecked ⇒ - a.context = new EntryState("async", op) - activeOpIndex = op - a.preStart(a.context) - - case a: AbstractStage[Any, Any, Any, Any, Any, Any] @unchecked ⇒ - val state = new EntryState("stage", op) - a.context = state - try a.preStart(state) catch { - case NonFatal(ex) ⇒ - failures ::= InitializationFailure(op, ex) // not logging here as 'most downstream' exception will be signaled via onError - } - } - op += 1 - } - failures - } - - private def finishCurrentOp(finishState: UntypedOp = Finished.asInstanceOf[UntypedOp]): Unit = { - try pipeline(activeOpIndex).postStop() - catch { case NonFatal(ex) ⇒ log.error(s"Stage [{}] postStop failed", ex) } - finally pipeline(activeOpIndex) = finishState - } - - /** - * Starts execution of detached regions. - * - * Since detached ops partition the pipeline into different 1-bounded domains is is necessary to inject a starting - * signal into these regions (since there is no external signal that would kick off their execution otherwise). 
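The kick-off requirement documented above is mechanically simple; the deleted runDetached body just below amounts to this standalone sketch (parameter names are assumptions): walk the pipeline from the downstream end and start one pulling wave per detached op, because no external signal would otherwise ever reach those inner regions:

def kickOffDetached(isDetached: Array[Boolean], startPullingAt: Int => Unit): Unit = {
  var op = isDetached.length - 1
  while (op >= 0) {
    if (isDetached(op)) startPullingAt(op) // inject the starting signal
    op -= 1
  }
}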
- */ - private def runDetached(): Unit = { - var op = pipeline.length - 1 - while (op >= 0) { - if (pipeline(op).isDetached) { - activeOpIndex = op - state = Pulling - execute() - } - op -= 1 - } - } - -} diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala index cf438260ac..9e1ce386c3 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala @@ -10,7 +10,6 @@ import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage import akka.stream.impl.{ FixedSizeBuffer, ReactiveStreamsCompliance } import akka.stream.stage._ import akka.stream.{ Supervision, _ } - import scala.annotation.tailrec import scala.collection.immutable import scala.collection.immutable.VectorBuilder @@ -18,6 +17,7 @@ import scala.concurrent.Future import scala.concurrent.duration.FiniteDuration import scala.util.control.NonFatal import scala.util.{ Failure, Success, Try } +import akka.stream.ActorAttributes.SupervisionStrategy /** * INTERNAL API @@ -537,149 +537,150 @@ private[akka] object MapAsync { /** * INTERNAL API */ -private[akka] final case class MapAsync[In, Out](parallelism: Int, f: In ⇒ Future[Out], decider: Supervision.Decider) - extends AsyncStage[In, Out, (Int, Try[Out])] { +private[akka] final case class MapAsync[In, Out](parallelism: Int, f: In ⇒ Future[Out]) + extends GraphStage[FlowShape[In, Out]] { import MapAsync._ - type Notification = (Int, Try[Out]) + private val in = Inlet[In]("in") + private val out = Outlet[Out]("out") - private var callback: AsyncCallback[Notification] = _ - private val elemsInFlight = FixedSizeBuffer[Try[Out]](parallelism) + override val shape = FlowShape(in, out) - override def preStart(ctx: AsyncContext[Out, Notification]): Unit = { - callback = ctx.getAsyncCallback - } + override def createLogic(inheritedAttributes: Attributes) = new GraphStageLogic(shape) { + override def toString = s"MapAsync.Logic(buffer=$buffer)" - override def decide(ex: Throwable) = decider(ex) + val decider = + inheritedAttributes.getAttribute(classOf[SupervisionStrategy]) + .map(_.decider).getOrElse(Supervision.stoppingDecider) - override def onPush(elem: In, ctx: AsyncContext[Out, Notification]) = { - val future = f(elem) - val idx = elemsInFlight.enqueue(NotYetThere) - future.onComplete(t ⇒ callback.invoke((idx, t)))(akka.dispatch.ExecutionContexts.sameThreadExecutionContext) - if (elemsInFlight.isFull) ctx.holdUpstream() - else ctx.pull() - } + val buffer = FixedSizeBuffer[Try[Out]](parallelism) + def todo = buffer.used - override def onPull(ctx: AsyncContext[Out, (Int, Try[Out])]) = { - @tailrec def rec(): DownstreamDirective = - if (elemsInFlight.isEmpty && ctx.isFinishing) ctx.finish() - else if (elemsInFlight.isEmpty || elemsInFlight.peek == NotYetThere) { - if (!elemsInFlight.isFull && ctx.isHoldingUpstream) ctx.holdDownstreamAndPull() - else ctx.holdDownstream() - } else elemsInFlight.dequeue() match { - case Failure(ex) ⇒ rec() + @tailrec private def pushOne(): Unit = + if (buffer.isEmpty) { + if (isClosed(in)) completeStage() + else if (!hasBeenPulled(in)) pull(in) + } else if (buffer.peek == NotYetThere) { + if (todo < parallelism && !hasBeenPulled(in)) tryPull(in) + } else buffer.dequeue() match { + case Failure(ex) ⇒ pushOne() case Success(elem) ⇒ - if (ctx.isHoldingUpstream) ctx.pushAndPull(elem) - else ctx.push(elem) - } - rec() - } - - override def onAsyncInput(input: (Int, Try[Out]), ctx: AsyncContext[Out, 
Notification]) = { - @tailrec def rec(): Directive = - if (elemsInFlight.isEmpty && ctx.isFinishing) ctx.finish() - else if (elemsInFlight.isEmpty || (elemsInFlight.peek eq NotYetThere)) { - if (!elemsInFlight.isFull && ctx.isHoldingUpstream) ctx.pull() - else ctx.ignore() - } else elemsInFlight.dequeue() match { - case Failure(ex) ⇒ rec() - case Success(elem) ⇒ - if (ctx.isHoldingUpstream) ctx.pushAndPull(elem) - else ctx.push(elem) + push(out, elem) + if (todo < parallelism && !hasBeenPulled(in)) tryPull(in) } - input match { - case (idx, f @ Failure(ex)) ⇒ - if (decider(ex) != Supervision.Stop) { - elemsInFlight.put(idx, f) - if (ctx.isHoldingDownstream) rec() - else ctx.ignore() - } else ctx.fail(ex) - case (idx, s: Success[_]) ⇒ - val exception = try { - ReactiveStreamsCompliance.requireNonNullElement(s.value) - elemsInFlight.put(idx, s) - null: Exception + def failOrPull(idx: Int, f: Failure[Out]) = + if (decider(f.exception) == Supervision.Stop) failStage(f.exception) + else { + buffer.put(idx, f) + if (isAvailable(out)) pushOne() + } + + val futureCB = + getAsyncCallback[(Int, Try[Out])]({ + case (idx, f: Failure[_]) ⇒ failOrPull(idx, f) + case (idx, s @ Success(elem)) ⇒ + if (elem == null) { + val ex = ReactiveStreamsCompliance.elementMustNotBeNullException + failOrPull(idx, Failure(ex)) + } else { + buffer.put(idx, s) + if (isAvailable(out)) pushOne() + } + }) + + setHandler(in, new InHandler { + override def onPush(): Unit = { + try { + val future = f(grab(in)) + val idx = buffer.enqueue(NotYetThere) + future.onComplete(result ⇒ futureCB.invoke(idx -> result))(akka.dispatch.ExecutionContexts.sameThreadExecutionContext) } catch { case NonFatal(ex) ⇒ - if (decider(ex) != Supervision.Stop) { - elemsInFlight.put(idx, Failure(ex)) - null: Exception - } else ex + if (decider(ex) == Supervision.Stop) failStage(ex) } - if (exception != null) ctx.fail(exception) - else if (ctx.isHoldingDownstream) rec() - else ctx.ignore() - } - } + if (todo < parallelism) tryPull(in) + } + override def onUpstreamFinish(): Unit = { + if (todo == 0) completeStage() + } + }) - override def onUpstreamFinish(ctx: AsyncContext[Out, Notification]) = - if (ctx.isHoldingUpstream || !elemsInFlight.isEmpty) ctx.absorbTermination() - else ctx.finish() + setHandler(out, new OutHandler { + override def onPull(): Unit = pushOne() + }) + } } /** * INTERNAL API */ -private[akka] final case class MapAsyncUnordered[In, Out](parallelism: Int, f: In ⇒ Future[Out], decider: Supervision.Decider) - extends AsyncStage[In, Out, Try[Out]] { +private[akka] final case class MapAsyncUnordered[In, Out](parallelism: Int, f: In ⇒ Future[Out]) + extends GraphStage[FlowShape[In, Out]] { - private var callback: AsyncCallback[Try[Out]] = _ - private var inFlight = 0 - private val buffer = FixedSizeBuffer[Out](parallelism) + private val in = Inlet[In]("in") + private val out = Outlet[Out]("out") - private def todo = inFlight + buffer.used + override val shape = FlowShape(in, out) - override def preStart(ctx: AsyncContext[Out, Try[Out]]): Unit = - callback = ctx.getAsyncCallback + override def createLogic(inheritedAttributes: Attributes) = new GraphStageLogic(shape) { + override def toString = s"MapAsyncUnordered.Logic(inFlight=$inFlight, buffer=$buffer)" - override def decide(ex: Throwable) = decider(ex) + val decider = + inheritedAttributes.getAttribute(classOf[SupervisionStrategy]) + .map(_.decider).getOrElse(Supervision.stoppingDecider) - override def onPush(elem: In, ctx: AsyncContext[Out, Try[Out]]) = { - val future = f(elem) - 
inFlight += 1 - future.onComplete(callback.invoke)(akka.dispatch.ExecutionContexts.sameThreadExecutionContext) - if (todo == parallelism) ctx.holdUpstream() - else ctx.pull() - } + var inFlight = 0 + val buffer = FixedSizeBuffer[Out](parallelism) + def todo = inFlight + buffer.used - override def onPull(ctx: AsyncContext[Out, Try[Out]]) = - if (buffer.isEmpty) { - if (ctx.isFinishing && inFlight == 0) ctx.finish() else ctx.holdDownstream() - } else { - val elem = buffer.dequeue() - if (ctx.isHoldingUpstream) ctx.pushAndPull(elem) - else ctx.push(elem) - } + def failOrPull(ex: Throwable) = + if (decider(ex) == Supervision.Stop) failStage(ex) + else if (isClosed(in) && todo == 0) completeStage() + else if (!hasBeenPulled(in)) tryPull(in) - override def onAsyncInput(input: Try[Out], ctx: AsyncContext[Out, Try[Out]]) = { - def ignoreOrFail(ex: Throwable) = - if (decider(ex) == Supervision.Stop) ctx.fail(ex) - else if (ctx.isHoldingUpstream) ctx.pull() - else if (ctx.isFinishing && todo == 0) ctx.finish() - else ctx.ignore() - - inFlight -= 1 - input match { - case Failure(ex) ⇒ ignoreOrFail(ex) - case Success(elem) ⇒ - if (elem == null) { - val ex = ReactiveStreamsCompliance.elementMustNotBeNullException - ignoreOrFail(ex) - } else if (ctx.isHoldingDownstream) { - if (ctx.isHoldingUpstream) ctx.pushAndPull(elem) - else ctx.push(elem) - } else { - buffer.enqueue(elem) - ctx.ignore() + val futureCB = + getAsyncCallback((result: Try[Out]) ⇒ { + inFlight -= 1 + result match { + case Failure(ex) ⇒ failOrPull(ex) + case Success(elem) ⇒ + if (elem == null) { + val ex = ReactiveStreamsCompliance.elementMustNotBeNullException + failOrPull(ex) + } else if (isAvailable(out)) { + if (!hasBeenPulled(in)) tryPull(in) + push(out, elem) + } else buffer.enqueue(elem) } - } - } + }).invoke _ - override def onUpstreamFinish(ctx: AsyncContext[Out, Try[Out]]) = - if (todo > 0) ctx.absorbTermination() - else ctx.finish() + setHandler(in, new InHandler { + override def onPush(): Unit = { + try { + val future = f(grab(in)) + inFlight += 1 + future.onComplete(futureCB)(akka.dispatch.ExecutionContexts.sameThreadExecutionContext) + } catch { + case NonFatal(ex) ⇒ + if (decider(ex) == Supervision.Stop) failStage(ex) + } + if (todo < parallelism) tryPull(in) + } + override def onUpstreamFinish(): Unit = { + if (todo == 0) completeStage() + } + }) + + setHandler(out, new OutHandler { + override def onPull(): Unit = { + if (!buffer.isEmpty) push(out, buffer.dequeue()) + else if (isClosed(in) && todo == 0) completeStage() + if (todo < parallelism && !hasBeenPulled(in)) tryPull(in) + } + }) + } } /** @@ -695,7 +696,7 @@ private[akka] final case class Log[T](name: String, extract: T ⇒ Any, logAdapt // TODO more optimisations can be done here - prepare logOnPush function etc override def preStart(ctx: LifecycleContext): Unit = { - logLevels = ctx.attributes.logLevels.getOrElse(DefaultLogLevels) + logLevels = ctx.attributes.get[LogLevels](DefaultLogLevels) log = logAdapter match { case Some(l) ⇒ l case _ ⇒ @@ -787,7 +788,7 @@ private[stream] class GroupedWithin[T](n: Int, d: FiniteDuration) extends GraphS val out = Outlet[immutable.Seq[T]]("out") val shape = FlowShape(in, out) - override def createLogic: GraphStageLogic = new GraphStageLogic(shape) { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) { private val buf: VectorBuilder[T] = new VectorBuilder // True if: // - buf is nonEmpty @@ -855,13 +856,17 @@ private[stream] class GroupedWithin[T](n: Int, d: 
FiniteDuration) extends GraphS private[stream] class TakeWithin[T](timeout: FiniteDuration) extends SimpleLinearGraphStage[T] { - override def createLogic: GraphStageLogic = new SimpleLinearStageLogic { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) { setHandler(in, new InHandler { override def onPush(): Unit = push(out, grab(in)) override def onUpstreamFinish(): Unit = completeStage() override def onUpstreamFailure(ex: Throwable): Unit = failStage(ex) }) + setHandler(out, new OutHandler { + override def onPull(): Unit = pull(in) + }) + final override protected def onTimer(key: Any): Unit = completeStage() @@ -874,7 +879,7 @@ private[stream] class TakeWithin[T](timeout: FiniteDuration) extends SimpleLinea private[stream] class DropWithin[T](timeout: FiniteDuration) extends SimpleLinearGraphStage[T] { private var allow = false - override def createLogic: GraphStageLogic = new SimpleLinearStageLogic { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) { setHandler(in, new InHandler { override def onPush(): Unit = if (allow) push(out, grab(in)) @@ -883,6 +888,10 @@ private[stream] class DropWithin[T](timeout: FiniteDuration) extends SimpleLinea override def onUpstreamFailure(ex: Throwable): Unit = failStage(ex) }) + setHandler(out, new OutHandler { + override def onPull(): Unit = pull(in) + }) + final override protected def onTimer(key: Any): Unit = allow = true override def preStart(): Unit = scheduleOnce("DropWithinTimer", timeout) diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/IOSettings.scala b/akka-stream/src/main/scala/akka/stream/impl/io/IOSettings.scala index 2e6da93a1d..51dedd0622 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/IOSettings.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/IOSettings.scala @@ -1,15 +1,12 @@ package akka.stream.impl.io -import akka.stream.ActorAttributes.Dispatcher -import akka.stream.{ ActorMaterializer, MaterializationContext } +import akka.stream.ActorAttributes +import akka.stream.Attributes private[stream] object IOSettings { - /** Picks default akka.stream.blocking-io-dispatcher or the Attributes configured one */ - def blockingIoDispatcher(context: MaterializationContext): String = { - val mat = ActorMaterializer.downcast(context.materializer) - context.effectiveAttributes.attributeList.collectFirst { case d: Dispatcher ⇒ d.dispatcher } getOrElse { - mat.system.settings.config.getString("akka.stream.blocking-io-dispatcher") - } - } -} \ No newline at end of file + final val SyncFileSourceDefaultChunkSize = 8192 + final val SyncFileSourceName = Attributes.name("synchronousFileSource") + final val SyncFileSinkName = Attributes.name("synchronousFileSink") + final val IODispatcher = ActorAttributes.Dispatcher("akka.stream.default-blocking-io-dispatcher") +} diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/IOSinks.scala b/akka-stream/src/main/scala/akka/stream/impl/io/IOSinks.scala index 4573934055..313d552619 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/IOSinks.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/IOSinks.scala @@ -4,13 +4,11 @@ package akka.stream.impl.io import java.io.{ File, OutputStream } -import java.lang.{ Long ⇒ JLong } - -import akka.stream._ import akka.stream.impl.SinkModule import akka.stream.impl.StreamLayout.Module +import akka.stream.{ ActorMaterializer, MaterializationContext, Attributes, SinkShape } +import 
akka.stream.ActorAttributes.Dispatcher import akka.util.ByteString - import scala.concurrent.{ Future, Promise } /** @@ -27,7 +25,7 @@ private[akka] final class SynchronousFileSink(f: File, append: Boolean, val attr val bytesWrittenPromise = Promise[Long]() val props = SynchronousFileSubscriber.props(f, bytesWrittenPromise, settings.maxInputBufferSize, append) - val dispatcher = IOSettings.blockingIoDispatcher(context) + val dispatcher = context.effectiveAttributes.get[Dispatcher](IOSettings.IODispatcher).dispatcher val ref = mat.actorOf(context, props.withDispatcher(dispatcher)) (akka.stream.actor.ActorSubscriber[ByteString](ref), bytesWrittenPromise.future) diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/IOSources.scala b/akka-stream/src/main/scala/akka/stream/impl/io/IOSources.scala index fa2cb7b7d2..eef2ac9e16 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/IOSources.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/IOSources.scala @@ -10,6 +10,7 @@ import java.util.concurrent.{ LinkedBlockingQueue, BlockingQueue } import akka.actor.{ ActorRef, Deploy } import akka.japi import akka.stream._ +import akka.stream.ActorAttributes.Dispatcher import akka.stream.impl.StreamLayout.Module import akka.stream.impl.{ ErrorPublisher, SourceModule } import akka.stream.scaladsl.{ Source, FlowGraph } @@ -27,13 +28,13 @@ import scala.util.control.NonFatal private[akka] final class SynchronousFileSource(f: File, chunkSize: Int, val attributes: Attributes, shape: SourceShape[ByteString]) extends SourceModule[ByteString, Future[Long]](shape) { override def create(context: MaterializationContext) = { - // FIXME rewrite to be based on AsyncStage rather than dangerous downcasts + // FIXME rewrite to be based on GraphStage rather than dangerous downcasts val mat = ActorMaterializer.downcast(context.materializer) val settings = mat.effectiveSettings(context.effectiveAttributes) val bytesReadPromise = Promise[Long]() val props = SynchronousFilePublisher.props(f, bytesReadPromise, chunkSize, settings.initialInputBufferSize, settings.maxInputBufferSize) - val dispatcher = IOSettings.blockingIoDispatcher(context) + val dispatcher = context.effectiveAttributes.get[Dispatcher](IOSettings.IODispatcher).dispatcher val ref = mat.actorOf(context, props.withDispatcher(dispatcher)) diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/InputStreamSinkStage.scala b/akka-stream/src/main/scala/akka/stream/impl/io/InputStreamSinkStage.scala index aec1109411..a049fd38ea 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/InputStreamSinkStage.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/InputStreamSinkStage.scala @@ -5,15 +5,14 @@ package akka.stream.impl.io import java.io.{ IOException, InputStream } import java.util.concurrent.{ BlockingQueue, LinkedBlockingDeque, TimeUnit } - import akka.stream.Attributes.InputBuffer import akka.stream.impl.io.InputStreamSinkStage._ import akka.stream.stage._ import akka.util.ByteString - import scala.annotation.tailrec import scala.concurrent.Future import scala.concurrent.duration.FiniteDuration +import akka.stream.Attributes private[akka] object InputStreamSinkStage { @@ -38,7 +37,7 @@ private[akka] class InputStreamSinkStage(timeout: FiniteDuration) extends SinkSt val maxBuffer = module.attributes.getAttribute(classOf[InputBuffer], InputBuffer(16, 16)).max require(maxBuffer > 0, "Buffer size must be greater than 0") - override def createLogicAndMaterializedValue: (GraphStageLogic, InputStream) = { + override def 
createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, InputStream) = { val dataQueue = new LinkedBlockingDeque[StreamToAdapterMessage](maxBuffer + 1) var pullRequestIsSent = true diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/OutputStreamSourceStage.scala b/akka-stream/src/main/scala/akka/stream/impl/io/OutputStreamSourceStage.scala index 29f8e7dedb..6519c7f3c3 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/OutputStreamSourceStage.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/OutputStreamSourceStage.scala @@ -7,6 +7,7 @@ import java.io.{ IOException, OutputStream } import java.util.concurrent.atomic.AtomicReference import java.util.concurrent.{ BlockingQueue, LinkedBlockingQueue } +import akka.stream.Attributes import akka.stream.Attributes.InputBuffer import akka.stream.impl.io.OutputStreamSourceStage._ import akka.stream.stage._ @@ -35,7 +36,7 @@ private[akka] class OutputStreamSourceStage(timeout: FiniteDuration) extends Sou val maxBuffer = module.attributes.getAttribute(classOf[InputBuffer], InputBuffer(16, 16)).max require(maxBuffer > 0, "Buffer size must be greater than 0") - override def createLogicAndMaterializedValue: (GraphStageLogic, OutputStream) = { + override def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, OutputStream) = { val dataQueue = new LinkedBlockingQueue[ByteString](maxBuffer) var flush: Option[Promise[Unit]] = None diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/TcpListenStreamActor.scala b/akka-stream/src/main/scala/akka/stream/impl/io/TcpListenStreamActor.scala index 8af8d3a578..3c25bc79e6 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/TcpListenStreamActor.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/TcpListenStreamActor.scala @@ -156,11 +156,11 @@ private[akka] class TcpListenStreamActor(localAddressPromise: Promise[InetSocket val tcpStreamActor = context.watch(context.actorOf(TcpStreamActor.inboundProps(connection, halfClose, settings))) val processor = ActorProcessor[ByteString, ByteString](tcpStreamActor) - import scala.concurrent.duration._ + import scala.concurrent.duration.FiniteDuration val handler = (idleTimeout match { case d: FiniteDuration ⇒ Flow[ByteString].join(Timeouts.idleTimeoutBidi[ByteString, ByteString](d)) case _ ⇒ Flow[ByteString] - }).andThenMat(() ⇒ (processor, ())) + }).via(Flow.fromProcessor(() ⇒ processor)) val conn = StreamTcp.IncomingConnection( connected.localAddress, diff --git a/akka-stream/src/main/scala/akka/stream/io/ByteStringParser.scala b/akka-stream/src/main/scala/akka/stream/io/ByteStringParser.scala new file mode 100644 index 0000000000..c442eefcae --- /dev/null +++ b/akka-stream/src/main/scala/akka/stream/io/ByteStringParser.scala @@ -0,0 +1,119 @@ +/** + * Copyright (C) 2015 Typesafe Inc. 
+ */ +package akka.stream.io + +import scala.util.control.NoStackTrace +import akka.stream._ +import akka.stream.stage._ +import akka.util.ByteString +import scala.annotation.tailrec + +abstract class ByteStringParser[T] extends GraphStage[FlowShape[ByteString, T]] { + import ByteStringParser._ + + private val bytesIn = Inlet[ByteString]("bytesIn") + private val objOut = Outlet[T]("objOut") + + final override val shape = FlowShape(bytesIn, objOut) + + class ParsingLogic extends GraphStageLogic(shape) { + override def preStart(): Unit = pull(bytesIn) + setHandler(objOut, eagerTerminateOutput) + + private var buffer = ByteString.empty + private var current: ParseStep[T] = FinishedParser + + final protected def startWith(step: ParseStep[T]): Unit = current = step + + @tailrec private def doParse(): Unit = + if (buffer.nonEmpty) { + val cont = try { + val reader = new ByteReader(buffer) + val (elem, next) = current.parse(reader) + emit(objOut, elem) + if (next == FinishedParser) { + completeStage() + false + } else { + buffer = reader.remainingData + current = next + true + } + } catch { + case NeedMoreData ⇒ + pull(bytesIn) + false + } + if (cont) doParse() + } else pull(bytesIn) + + setHandler(bytesIn, new InHandler { + override def onPush(): Unit = { + buffer ++= grab(bytesIn) + doParse() + } + override def onUpstreamFinish(): Unit = + if (buffer.isEmpty) completeStage() + else current.onTruncation() + }) + } +} + +object ByteStringParser { + + trait ParseStep[+T] { + def parse(reader: ByteReader): (T, ParseStep[T]) + def onTruncation(): Unit = throw new IllegalStateException("truncated data in ByteStringParser") + } + + object FinishedParser extends ParseStep[Nothing] { + def parse(reader: ByteReader) = + throw new IllegalStateException("no initial parser installed: you must use startWith(...)") + } + + val NeedMoreData = new Exception with NoStackTrace + + class ByteReader(input: ByteString) { + + private[this] var off = 0 + + def hasRemaining: Boolean = off < input.size + def remainingSize: Int = input.size - off + + def currentOffset: Int = off + def remainingData: ByteString = input.drop(off) + def fromStartToHere: ByteString = input.take(off) + + def take(n: Int): ByteString = + if (off + n <= input.length) { + val o = off + off = o + n + input.slice(o, off) + } else throw NeedMoreData + def takeAll(): ByteString = { + val ret = remainingData + off = input.size + ret + } + + def readByte(): Int = + if (off < input.length) { + val x = input(off) + off += 1 + x & 0xFF + } else throw NeedMoreData + def readShortLE(): Int = readByte() | (readByte() << 8) + def readIntLE(): Int = readShortLE() | (readShortLE() << 16) + def readLongLE(): Long = (readIntLE() & 0xffffffffL) | ((readIntLE() & 0xffffffffL) << 32) + + def readShortBE(): Int = (readByte() << 8) | readByte() + def readIntBE(): Int = (readShortBE() << 16) | readShortBE() + def readLongBE(): Long = ((readIntBE() & 0xffffffffL) << 32) | (readIntBE() & 0xffffffffL) + + def skip(numBytes: Int): Unit = + if (off + numBytes <= input.length) off += numBytes + else throw NeedMoreData + def skipZeroTerminatedString(): Unit = while (readByte() != 0) {} + } +} \ No newline at end of file diff --git a/akka-stream/src/main/scala/akka/stream/io/SynchronousFileSink.scala b/akka-stream/src/main/scala/akka/stream/io/SynchronousFileSink.scala index f5fabdbec6..797348c3b8 100644 --- a/akka-stream/src/main/scala/akka/stream/io/SynchronousFileSink.scala +++ b/akka-stream/src/main/scala/akka/stream/io/SynchronousFileSink.scala @@ -5,7 +5,6 @@ package 
akka.stream.io import java.io.File -import akka.stream.impl.io.SynchronousFileSink import akka.stream.{ Attributes, javadsl, ActorAttributes } import akka.stream.scaladsl.Sink import akka.util.ByteString @@ -16,8 +15,10 @@ import scala.concurrent.Future * Sink which writes incoming [[ByteString]]s to the given file */ object SynchronousFileSink { + import akka.stream.impl.io.IOSettings._ + import akka.stream.impl.io.SynchronousFileSink - final val DefaultAttributes = Attributes.name("synchronousFileSink") + final val DefaultAttributes = SyncFileSinkName and IODispatcher /** * Synchronous (Java 6 compatible) Sink that writes incoming [[ByteString]] elements to the given file. diff --git a/akka-stream/src/main/scala/akka/stream/io/SynchronousFileSource.scala b/akka-stream/src/main/scala/akka/stream/io/SynchronousFileSource.scala index f038595952..6ad5d763b8 100644 --- a/akka-stream/src/main/scala/akka/stream/io/SynchronousFileSource.scala +++ b/akka-stream/src/main/scala/akka/stream/io/SynchronousFileSource.scala @@ -10,9 +10,10 @@ import akka.util.ByteString import scala.concurrent.Future object SynchronousFileSource { + import akka.stream.impl.io.IOSettings._ import akka.stream.impl.io.SynchronousFileSource - final val DefaultChunkSize = 8192 - final val DefaultAttributes = Attributes.name("synchronousFileSource") + + final val DefaultAttributes = SyncFileSourceName and IODispatcher /** * Creates a synchronous (Java 6 compatible) Source from a File's contents. @@ -24,7 +25,7 @@ object SynchronousFileSource { * * It materializes a [[Future]] containing the number of bytes read from the source file upon completion. */ - def apply(f: File, chunkSize: Int = DefaultChunkSize): Source[ByteString, Future[Long]] = + def apply(f: File, chunkSize: Int = SyncFileSourceDefaultChunkSize): Source[ByteString, Future[Long]] = new Source(new SynchronousFileSource(f, chunkSize, DefaultAttributes, Source.shape("SynchronousFileSource")).nest()) // TO DISCUSS: I had to add nest() here to make the name available /** * Creates a synchronous (Java 6 compatible) Source from a File's contents. * @@ -37,7 +38,7 @@ object SynchronousFileSource { * * It materializes a [[Future]] containing the number of bytes read from the source file upon completion. */ - def create(f: File): javadsl.Source[ByteString, Future[java.lang.Long]] = create(f, DefaultChunkSize) + def create(f: File): javadsl.Source[ByteString, Future[java.lang.Long]] = create(f, SyncFileSourceDefaultChunkSize) /** * Creates a synchronous (Java 6 compatible) Source from a File's contents.
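// ---------------------------------------------------------------------------
// Illustrative usage (a minimal sketch, not part of the patch): how the file
// IO endpoints above fit together after this change. It assumes the standard
// ActorSystem/ActorMaterializer setup and a SynchronousFileSink(file) factory
// with its usual append = false default; the file names are placeholders.
import java.io.File
import scala.concurrent.Future
import akka.actor.ActorSystem
import akka.stream.ActorMaterializer
import akka.stream.io.{ SynchronousFileSink, SynchronousFileSource }

object FileCopySketch extends App {
  implicit val system = ActorSystem("io-sketch")
  implicit val materializer = ActorMaterializer()

  // apply() now defaults to IOSettings.SyncFileSourceDefaultChunkSize (8192),
  // and both stages carry the IODispatcher attribute, so their blocking IO
  // runs on "akka.stream.default-blocking-io-dispatcher" unless overridden.
  val bytesWritten: Future[Long] =
    SynchronousFileSource(new File("in.dat"))
      .runWith(SynchronousFileSink(new File("out.dat")))
}
// ---------------------------------------------------------------------------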
diff --git a/akka-stream/src/main/scala/akka/stream/io/Timeouts.scala b/akka-stream/src/main/scala/akka/stream/io/Timeouts.scala index 52e1c69a7d..fe01877147 100644 --- a/akka-stream/src/main/scala/akka/stream/io/Timeouts.scala +++ b/akka-stream/src/main/scala/akka/stream/io/Timeouts.scala @@ -5,7 +5,7 @@ import java.util.concurrent.{ TimeUnit, TimeoutException } import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage import akka.stream.scaladsl.{ BidiFlow, Flow } import akka.stream.stage._ -import akka.stream.{ BidiShape, Inlet, Outlet } +import akka.stream.{ BidiShape, Inlet, Outlet, Attributes } import scala.concurrent.duration.{ Deadline, FiniteDuration } @@ -62,7 +62,7 @@ object Timeouts { private class InitialTimeout[T](timeout: FiniteDuration) extends SimpleLinearGraphStage[T] { - override def createLogic: GraphStageLogic = new SimpleLinearStageLogic { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) { private var initialHasPassed = false setHandler(in, new InHandler { @@ -72,6 +72,10 @@ object Timeouts { } }) + setHandler(out, new OutHandler { + override def onPull(): Unit = pull(in) + }) + final override protected def onTimer(key: Any): Unit = if (!initialHasPassed) failStage(new TimeoutException(s"The first element has not yet passed through in $timeout.")) @@ -84,11 +88,15 @@ object Timeouts { private class CompletionTimeout[T](timeout: FiniteDuration) extends SimpleLinearGraphStage[T] { - override def createLogic: GraphStageLogic = new SimpleLinearStageLogic { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) { setHandler(in, new InHandler { override def onPush(): Unit = push(out, grab(in)) }) + setHandler(out, new OutHandler { + override def onPull(): Unit = pull(in) + }) + final override protected def onTimer(key: Any): Unit = failStage(new TimeoutException(s"The stream has not been completed in $timeout.")) @@ -101,7 +109,7 @@ object Timeouts { private class IdleTimeout[T](timeout: FiniteDuration) extends SimpleLinearGraphStage[T] { private var nextDeadline: Deadline = Deadline.now + timeout - override def createLogic: GraphStageLogic = new SimpleLinearStageLogic { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) { setHandler(in, new InHandler { override def onPush(): Unit = { nextDeadline = Deadline.now + timeout @@ -109,6 +117,10 @@ object Timeouts { } }) + setHandler(out, new OutHandler { + override def onPull(): Unit = pull(in) + }) + final override protected def onTimer(key: Any): Unit = if (nextDeadline.isOverdue()) failStage(new TimeoutException(s"No elements passed in the last $timeout.")) @@ -128,7 +140,7 @@ object Timeouts { override def toString = "IdleTimeoutBidi" - override def createLogic: GraphStageLogic = new GraphStageLogic(shape) { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) { private var nextDeadline: Deadline = Deadline.now + timeout setHandler(in1, new InHandler { diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala b/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala index 50dbf02e50..48fd37535e 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala @@ -22,8 +22,14 @@ object Flow { /** Create a `Flow` which can process elements of type `T`. 
*/ def create[T](): javadsl.Flow[T, T, Unit] = fromGraph(scaladsl.Flow[T]) - def create[I, O](processorFactory: function.Creator[Processor[I, O]]): javadsl.Flow[I, O, Unit] = - new Flow(scaladsl.Flow(() ⇒ processorFactory.create())) + def fromProcessor[I, O](processorFactory: function.Creator[Processor[I, O]]): javadsl.Flow[I, O, Unit] = + new Flow(scaladsl.Flow.fromProcessor(() ⇒ processorFactory.create())) + + def fromProcessorMat[I, O, Mat](processorFactory: function.Creator[Pair[Processor[I, O], Mat]]): javadsl.Flow[I, O, Mat] = + new Flow(scaladsl.Flow.fromProcessorMat { () ⇒ + val javaPair = processorFactory.create() + (javaPair.first, javaPair.second) + }) /** Create a `Flow` which can process elements of type `T`. */ def of[T](clazz: Class[T]): javadsl.Flow[T, T, Unit] = create[T]() diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala index 7ba5ee4736..5f099809bc 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala @@ -7,10 +7,12 @@ import akka.event.LoggingAdapter import akka.stream.Attributes._ import akka.stream._ import akka.stream.impl.SplitDecision._ -import akka.stream.impl.Stages.{ DirectProcessor, MaterializingStageFactory, StageModule } +import akka.stream.impl.Stages.{ SymbolicGraphStage, StageModule, DirectProcessor, SymbolicStage } import akka.stream.impl.StreamLayout.{ EmptyModule, Module } -import akka.stream.impl.fusing.{ DropWithin, GroupedWithin, TakeWithin } +import akka.stream.impl.fusing.{ DropWithin, GroupedWithin, TakeWithin, MapAsync, MapAsyncUnordered } import akka.stream.impl.{ ReactiveStreamsCompliance, ConstantFun, Stages, StreamLayout } +import akka.stream.impl.{ Stages, StreamLayout } +import akka.stream.stage.AbstractStage.{ PushPullGraphStageWithMaterializedValue, PushPullGraphStage } import akka.stream.stage._ import org.reactivestreams.{ Processor, Publisher, Subscriber, Subscription } @@ -30,7 +32,7 @@ final class Flow[-In, +Out, +Mat](private[stream] override val module: Module) override type Repr[+O, +M] = Flow[In @uncheckedVariance, O, M] - private[stream] def isIdentity: Boolean = this.module.isInstanceOf[Stages.Identity] + private[stream] def isIdentity: Boolean = this.module eq Stages.identityGraph.module def viaMat[T, Mat2, Mat3](flow: Graph[FlowShape[Out, T], Mat2])(combine: (Mat, Mat2) ⇒ Mat3): Flow[In, T, Mat3] = { if (this.isIdentity) { @@ -182,19 +184,15 @@ final class Flow[-In, +Out, +Mat](private[stream] override val module: Module) } /** INTERNAL API */ - override private[stream] def andThen[U](op: StageModule): Repr[U, Mat] = { + // FIXME: Only exists to keep old stuff alive + private[stream] override def deprecatedAndThen[U](op: StageModule): Repr[U, Mat] = { //No need to copy here, op is a fresh instance - if (op.isInstanceOf[Stages.Identity]) this.asInstanceOf[Repr[U, Mat]] - else if (this.isIdentity) new Flow(op).asInstanceOf[Repr[U, Mat]] + if (this.isIdentity) new Flow(op).asInstanceOf[Repr[U, Mat]] else new Flow(module.fuse(op, shape.outlet, op.inPort).replaceShape(FlowShape(shape.inlet, op.outPort))) } - private[stream] def andThenMat[U, Mat2](op: MaterializingStageFactory): Repr[U, Mat2] = { - if (this.isIdentity) new Flow(op).asInstanceOf[Repr[U, Mat2]] - else new Flow(module.fuse(op, shape.outlet, op.inPort, Keep.right).replaceShape(FlowShape(shape.inlet, op.outPort))) - } - - private[akka] def andThenMat[U, Mat2, O >: Out](processorFactory: () ⇒ (Processor[O, 
U], Mat2)): Repr[U, Mat2] = { + // FIXME: Only exists to keep old stuff alive + private[akka] def deprecatedAndThenMat[U, Mat2, O >: Out](processorFactory: () ⇒ (Processor[O, U], Mat2)): Repr[U, Mat2] = { val op = Stages.DirectProcessor(processorFactory.asInstanceOf[() ⇒ (Processor[Any, Any], Any)]) if (this.isIdentity) new Flow(op).asInstanceOf[Repr[U, Mat2]] else new Flow[In, U, Mat2](module.fuse(op, shape.outlet, op.inPort, Keep.right).replaceShape(FlowShape(shape.inlet, op.outPort))) @@ -244,14 +242,20 @@ final class Flow[-In, +Out, +Mat](private[stream] override val module: Module) } object Flow { - private[this] val identity: Flow[Any, Any, Unit] = new Flow[Any, Any, Unit](Stages.Identity()) + private[this] val identity: Flow[Any, Any, Unit] = new Flow[Any, Any, Unit](SymbolicGraphStage(Stages.Identity).module) /** * Creates a Flow from a Reactive Streams [[org.reactivestreams.Processor]] */ - def apply[I, O](processorFactory: () ⇒ Processor[I, O]): Flow[I, O, Unit] = { - val untypedFactory = processorFactory.asInstanceOf[() ⇒ Processor[Any, Any]] - Flow[I].andThen(DirectProcessor(() ⇒ (untypedFactory(), ()))) + def fromProcessor[I, O](processorFactory: () ⇒ Processor[I, O]): Flow[I, O, Unit] = { + fromProcessorMat(() ⇒ (processorFactory(), ())) + } + + /** + * Creates a Flow from a Reactive Streams [[org.reactivestreams.Processor]] and returns a materialized value. + */ + def fromProcessorMat[I, O, Mat](processorFactory: () ⇒ (Processor[I, O], Mat)): Flow[I, O, Mat] = { + Flow[I].deprecatedAndThenMat(processorFactory) } /** @@ -377,7 +381,7 @@ trait FlowOps[+Out, +Mat] { * '''Cancels when''' downstream cancels * */ - def recover[T >: Out](pf: PartialFunction[Throwable, T]): Repr[T, Mat] = andThen(Recover(pf.asInstanceOf[PartialFunction[Any, Any]])) + def recover[T >: Out](pf: PartialFunction[Throwable, T]): Repr[T, Mat] = andThen(Recover(pf)) /** * Transform this stream by applying the given function to each of the elements @@ -392,7 +396,7 @@ trait FlowOps[+Out, +Mat] { * '''Cancels when''' downstream cancels * */ - def map[T](f: Out ⇒ T): Repr[T, Mat] = andThen(Map(f.asInstanceOf[Any ⇒ Any])) + def map[T](f: Out ⇒ T): Repr[T, Mat] = andThen(Map(f)) /** * Transform each input element into an `Iterable` of output elements that is @@ -412,7 +416,7 @@ trait FlowOps[+Out, +Mat] { * '''Cancels when''' downstream cancels * */ - def mapConcat[T](f: Out ⇒ immutable.Iterable[T]): Repr[T, Mat] = andThen(MapConcat(f.asInstanceOf[Any ⇒ immutable.Iterable[Any]])) + def mapConcat[T](f: Out ⇒ immutable.Iterable[T]): Repr[T, Mat] = andThen(MapConcat(f)) /** * Transform this stream by applying the given function to each of the elements @@ -443,8 +447,7 @@ trait FlowOps[+Out, +Mat] { * * @see [[#mapAsyncUnordered]] */ - def mapAsync[T](parallelism: Int)(f: Out ⇒ Future[T]): Repr[T, Mat] = - andThen(MapAsync(parallelism, f.asInstanceOf[Any ⇒ Future[Any]])) + def mapAsync[T](parallelism: Int)(f: Out ⇒ Future[T]): Repr[T, Mat] = via(MapAsync(parallelism, f)) /** * Transform this stream by applying the given function to each of the elements @@ -475,8 +478,7 @@ trait FlowOps[+Out, +Mat] { * * @see [[#mapAsync]] */ - def mapAsyncUnordered[T](parallelism: Int)(f: Out ⇒ Future[T]): Repr[T, Mat] = - andThen(MapAsyncUnordered(parallelism, f.asInstanceOf[Any ⇒ Future[Any]])) + def mapAsyncUnordered[T](parallelism: Int)(f: Out ⇒ Future[T]): Repr[T, Mat] = via(MapAsyncUnordered(parallelism, f)) /** * Only pass on those elements that satisfy the given predicate. 
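// ---------------------------------------------------------------------------
// Illustrative usage (a minimal sketch, not part of the patch): the renamed
// factories Flow.fromProcessor / Flow.fromProcessorMat defined above. The
// Processor factory below is a hypothetical stand-in for any Reactive Streams
// Processor implementation.
import org.reactivestreams.Processor
import akka.stream.scaladsl.Flow

object FromProcessorSketch {
  def makeProcessor(): Processor[Int, String] = ??? // hypothetical factory

  // Plain wrapping; the materialized value is Unit.
  val f1: Flow[Int, String, Unit] = Flow.fromProcessor(() ⇒ makeProcessor())

  // Wrapping that also exposes the processor instance as the materialized value.
  val f2: Flow[Int, String, Processor[Int, String]] =
    Flow.fromProcessorMat { () ⇒
      val p = makeProcessor()
      (p, p)
    }
}
// ---------------------------------------------------------------------------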
@@ -489,7 +491,7 @@ trait FlowOps[+Out, +Mat] { * * '''Cancels when''' downstream cancels */ - def filter(p: Out ⇒ Boolean): Repr[Out, Mat] = andThen(Filter(p.asInstanceOf[Any ⇒ Boolean])) + def filter(p: Out ⇒ Boolean): Repr[Out, Mat] = andThen(Filter(p)) /** * Only pass on those elements that do NOT satisfy the given predicate. @@ -522,7 +524,7 @@ trait FlowOps[+Out, +Mat] { * * '''Cancels when''' predicate returned false or downstream cancels */ - def takeWhile(p: Out ⇒ Boolean): Repr[Out, Mat] = andThen(TakeWhile(p.asInstanceOf[Any ⇒ Boolean])) + def takeWhile(p: Out ⇒ Boolean): Repr[Out, Mat] = andThen(TakeWhile(p)) /** * Discard elements at the beginning of the stream while predicate is true. @@ -536,7 +538,7 @@ trait FlowOps[+Out, +Mat] { * * '''Cancels when''' downstream cancels */ - def dropWhile(p: Out ⇒ Boolean): Repr[Out, Mat] = andThen(DropWhile(p.asInstanceOf[Any ⇒ Boolean])) + def dropWhile(p: Out ⇒ Boolean): Repr[Out, Mat] = andThen(DropWhile(p)) /** * Transform this stream by applying the given partial function to each of the elements @@ -551,7 +553,7 @@ trait FlowOps[+Out, +Mat] { * * '''Cancels when''' downstream cancels */ - def collect[T](pf: PartialFunction[Out, T]): Repr[T, Mat] = andThen(Collect(pf.asInstanceOf[PartialFunction[Any, Any]])) + def collect[T](pf: PartialFunction[Out, T]): Repr[T, Mat] = andThen(Collect(pf)) /** * Chunk up this stream into groups of the given size, with the last group @@ -604,7 +606,7 @@ trait FlowOps[+Out, +Mat] { * * '''Cancels when''' downstream cancels */ - def scan[T](zero: T)(f: (T, Out) ⇒ T): Repr[T, Mat] = andThen(Scan(zero, f.asInstanceOf[(Any, Any) ⇒ Any])) + def scan[T](zero: T)(f: (T, Out) ⇒ T): Repr[T, Mat] = andThen(Scan(zero, f)) /** * Similar to `scan` but only emits its result when the upstream completes, @@ -623,7 +625,7 @@ trait FlowOps[+Out, +Mat] { * * '''Cancels when''' downstream cancels */ - def fold[T](zero: T)(f: (T, Out) ⇒ T): Repr[T, Mat] = andThen(Fold(zero, f.asInstanceOf[(Any, Any) ⇒ Any])) + def fold[T](zero: T)(f: (T, Out) ⇒ T): Repr[T, Mat] = andThen(Fold(zero, f)) /** * Intersperses stream with provided element, similar to how [[scala.collection.immutable.List.mkString]] @@ -778,8 +780,7 @@ trait FlowOps[+Out, +Mat] { * * '''Cancels when''' downstream cancels or timer fires */ - def takeWithin(d: FiniteDuration): Repr[Out, Mat] = - via(new TakeWithin[Out](d).withAttributes(name("takeWithin"))) + def takeWithin(d: FiniteDuration): Repr[Out, Mat] = via(new TakeWithin[Out](d).withAttributes(name("takeWithin"))) /** * Allows a faster upstream to progress independently of a slower subscriber by conflating elements into a summary @@ -800,8 +801,7 @@ trait FlowOps[+Out, +Mat] { * @param seed Provides the first state for a conflated value using the first unconsumed element as a start * @param aggregate Takes the currently aggregated value and the current pending element to produce a new aggregate */ - def conflate[S](seed: Out ⇒ S)(aggregate: (S, Out) ⇒ S): Repr[S, Mat] = - andThen(Conflate(seed.asInstanceOf[Any ⇒ Any], aggregate.asInstanceOf[(Any, Any) ⇒ Any])) + def conflate[S](seed: Out ⇒ S)(aggregate: (S, Out) ⇒ S): Repr[S, Mat] = andThen(Conflate(seed, aggregate)) /** * Allows a faster downstream to progress independently of a slower publisher by extrapolating elements from an older @@ -827,8 +827,7 @@ trait FlowOps[+Out, +Mat] { * @param extrapolate Takes the current extrapolation state to produce an output element and the next extrapolation * state.
*/ - def expand[S, U](seed: Out ⇒ S)(extrapolate: S ⇒ (U, S)): Repr[U, Mat] = - andThen(Expand(seed.asInstanceOf[Any ⇒ Any], extrapolate.asInstanceOf[Any ⇒ (Any, Any)])) + def expand[S, U](seed: Out ⇒ S)(extrapolate: S ⇒ (U, S)): Repr[U, Mat] = andThen(Expand(seed, extrapolate)) /** * Adds a fixed size buffer in the flow that allows storing elements from a faster upstream until it becomes full. @@ -849,8 +848,7 @@ trait FlowOps[+Out, +Mat] { * @param size The size of the buffer in element count * @param overflowStrategy Strategy that is used when incoming elements cannot fit inside the buffer */ - def buffer(size: Int, overflowStrategy: OverflowStrategy): Repr[Out, Mat] = - andThen(Buffer(size, overflowStrategy)) + def buffer(size: Int, overflowStrategy: OverflowStrategy): Repr[Out, Mat] = andThen(Buffer(size, overflowStrategy)) /** * Generic transformation of a stream with a custom processing [[akka.stream.stage.Stage]]. * This operator makes it possible to extend the `Flow` API when there is no specialized * operator that performs the transformation. */ def transform[T](mkStage: () ⇒ Stage[Out, T]): Repr[T, Mat] = - andThen(StageFactory(mkStage)) + via(new PushPullGraphStage((attr) ⇒ mkStage(), Attributes.none)) private[akka] def transformMaterializing[T, M](mkStageAndMaterialized: () ⇒ (Stage[Out, T], M)): Repr[T, M] = - andThenMat(MaterializingStageFactory(mkStageAndMaterialized)) + viaMat(new PushPullGraphStageWithMaterializedValue[Out, T, Unit, M]((attr) ⇒ mkStageAndMaterialized(), Attributes.none))(Keep.right) /** * Takes up to `n` elements from the stream (less than `n` only if the upstream completes before emitting `n` elements) @@ -886,7 +884,7 @@ trait FlowOps[+Out, +Mat] { * */ def prefixAndTail[U >: Out](n: Int): Repr[(immutable.Seq[Out], Source[U, Unit]), Mat] = - andThen(PrefixAndTail(n)) + deprecatedAndThen(PrefixAndTail(n)) /** * This operation demultiplexes the incoming stream into separate output @@ -918,7 +916,7 @@ trait FlowOps[+Out, +Mat] { * */ def groupBy[K, U >: Out](f: Out ⇒ K): Repr[(K, Source[U, Unit]), Mat] = - andThen(GroupBy(f.asInstanceOf[Any ⇒ Any])) + deprecatedAndThen(GroupBy(f.asInstanceOf[Any ⇒ Any])) /** * This operation applies the given predicate to all incoming elements and @@ -962,7 +960,7 @@ */ def splitWhen[U >: Out](p: Out ⇒ Boolean): Repr[Out, Mat]#Repr[Source[U, Unit], Mat] = { val f = p.asInstanceOf[Any ⇒ Boolean] - withAttributes(name("splitWhen")).andThen(Split(el ⇒ if (f(el)) SplitBefore else Continue)) + withAttributes(name("splitWhen")).deprecatedAndThen(Split(el ⇒ if (f(el)) SplitBefore else Continue)) } /** @@ -999,7 +997,7 @@ */ def splitAfter[U >: Out](p: Out ⇒ Boolean): Repr[Out, Mat]#Repr[Source[U, Unit], Mat] = { val f = p.asInstanceOf[Any ⇒ Boolean] - withAttributes(name("splitAfter")).andThen(Split(el ⇒ if (f(el)) SplitAfter else Continue)) + withAttributes(name("splitAfter")).deprecatedAndThen(Split(el ⇒ if (f(el)) SplitAfter else Continue)) } /** @@ -1016,7 +1014,7 @@ */ def flatten[U](strategy: FlattenStrategy[Out, U]): Repr[U, Mat] = strategy match { - case scaladsl.FlattenStrategy.Concat | javadsl.FlattenStrategy.Concat ⇒ andThen(ConcatAll()) + case scaladsl.FlattenStrategy.Concat | javadsl.FlattenStrategy.Concat ⇒ deprecatedAndThen(ConcatAll()) case _ ⇒ throw new IllegalArgumentException(s"Unsupported flattening strategy [${strategy.getClass.getName}]") } @@ -1211,8 +1209,9 @@ def withAttributes(attr: Attributes): Repr[Out, Mat] /** INTERNAL API
*/ - private[scaladsl] def andThen[U](op: StageModule): Repr[U, Mat] + private[scaladsl] def andThen[T](op: SymbolicStage[Out, T]): Repr[T, Mat] = + via(SymbolicGraphStage(op)) - private[scaladsl] def andThenMat[U, Mat2](op: MaterializingStageFactory): Repr[U, Mat2] + private[scaladsl] def deprecatedAndThen[U](op: StageModule): Repr[U, Mat] } diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala index 2f557e88d4..71a46d4893 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala @@ -3,11 +3,10 @@ */ package akka.stream.scaladsl -import akka.stream.impl.Stages.{ MaterializingStageFactory, StageModule } +import akka.stream.impl.Stages.{ StageModule, SymbolicStage } import akka.stream.impl._ import akka.stream.impl.StreamLayout._ import akka.stream._ -import Attributes.name import akka.stream.stage.{ OutHandler, InHandler, GraphStageLogic, GraphStage } import scala.annotation.unchecked.uncheckedVariance import scala.annotation.tailrec @@ -41,7 +40,7 @@ class Merge[T] private (val inputPorts: Int, val eagerClose: Boolean) extends Gr val out: Outlet[T] = Outlet[T]("Merge.out") override val shape: UniformFanInShape[T, T] = UniformFanInShape(out, in: _*) - override def createLogic: GraphStageLogic = new GraphStageLogic(shape) { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { private var initialized = false private val pendingQueue = Array.ofDim[Inlet[T]](inputPorts) @@ -133,7 +132,7 @@ object MergePreferred { * * '''Backpressures when''' downstream backpressures * - * '''Completes when''' all upstreams complete + * '''Completes when''' all upstreams complete (eagerClose=false) or one upstream completes (eagerClose=true) * * '''Cancels when''' downstream cancels * @@ -147,91 +146,82 @@ class MergePreferred[T] private (val secondaryPorts: Int, val eagerClose: Boolea def out: Outlet[T] = shape.out def preferred: Inlet[T] = shape.preferred - // FIXME: Factor out common stuff with Merge - override def createLogic: GraphStageLogic = new GraphStageLogic(shape) { - private var initialized = false - - private val pendingQueue = Array.ofDim[Inlet[T]](secondaryPorts) - private var pendingHead = 0 - private var pendingTail = 0 - - private var runningUpstreams = secondaryPorts + 1 - private def upstreamsClosed = runningUpstreams == 0 - - private def pending: Boolean = pendingHead != pendingTail - private def priority: Boolean = isAvailable(preferred) - - private def enqueue(in: Inlet[T]): Unit = { - pendingQueue(pendingTail % secondaryPorts) = in - pendingTail += 1 + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { + var openInputs = secondaryPorts + 1 + def onComplete(): Unit = { + openInputs -= 1 + if (eagerClose || openInputs == 0) completeStage() } - private def dequeueAndDispatch(): Unit = { - val in = pendingQueue(pendingHead % secondaryPorts) - pendingHead += 1 - push(out, grab(in)) - if (upstreamsClosed && !pending && !priority) completeStage() - else tryPull(in) - } - - // FIXME: slow iteration, try to make in a vector and inject into shape instead - (0 until secondaryPorts).map(in).foreach { i ⇒ - setHandler(i, new InHandler { - override def onPush(): Unit = { - if (isAvailable(out)) { - if (!pending) { - push(out, grab(i)) - tryPull(i) - } - } else enqueue(i) - } - - override def onUpstreamFinish() = - if (eagerClose) { - 
(0 until secondaryPorts).foreach(i ⇒ cancel(in(i))) - cancel(preferred) - runningUpstreams = 0 - if (!pending) completeStage() - } else { - runningUpstreams -= 1 - if (upstreamsClosed && !pending && !priority) completeStage() - } - }) - } - - setHandler(preferred, new InHandler { - override def onPush() = { - if (isAvailable(out)) { - push(out, grab(preferred)) - tryPull(preferred) - } - } - - override def onUpstreamFinish() = - if (eagerClose) { - (0 until secondaryPorts).foreach(i ⇒ cancel(in(i))) - runningUpstreams = 0 - if (!pending) completeStage() - } else { - runningUpstreams -= 1 - if (upstreamsClosed && !pending && !priority) completeStage() - } - }) - setHandler(out, new OutHandler { + private var first = true override def onPull(): Unit = { - if (!initialized) { - initialized = true - // FIXME: slow iteration, try to make in a vector and inject into shape instead + if (first) { + first = false tryPull(preferred) - (0 until secondaryPorts).map(in).foreach(tryPull) - } else if (priority) { - push(out, grab(preferred)) - tryPull(preferred) - } else if (pending) - dequeueAndDispatch() + shape.inSeq.foreach(tryPull) + } } }) + + val pullMe = Array.tabulate(secondaryPorts)(i ⇒ { + val port = in(i) + () ⇒ tryPull(port) + }) + + /* + * This determines the unfairness of the merge: + * - at 1 the preferred will grab 40% of the bandwidth against three equally fast secondaries + * - at 2 the preferred will grab almost all bandwidth against three equally fast secondaries + * (measured with eventLimit=1 in the GraphInterpreter, so may not be accurate) + */ + val maxEmitting = 2 + var preferredEmitting = 0 + + setHandler(preferred, new InHandler { + override def onUpstreamFinish(): Unit = onComplete() + override def onPush(): Unit = + if (preferredEmitting == maxEmitting) () // blocked + else emitPreferred() + + def emitPreferred(): Unit = { + preferredEmitting += 1 + emit(out, grab(preferred), emitted) + tryPull(preferred) + } + + val emitted = () ⇒ { + preferredEmitting -= 1 + if (isAvailable(preferred)) emitPreferred() + else if (preferredEmitting == 0) emitSecondary() + } + + def emitSecondary(): Unit = { + var i = 0 + while (i < secondaryPorts) { + val port = in(i) + if (isAvailable(port)) emit(out, grab(port), pullMe(i)) + i += 1 + } + } + }) + + var i = 0 + while (i < secondaryPorts) { + val port = in(i) + val pullPort = pullMe(i) + setHandler(port, new InHandler { + override def onPush(): Unit = { + if (preferredEmitting > 0) () // blocked + else { + emit(out, grab(port), pullPort) + } + } + override def onUpstreamFinish(): Unit = onComplete() + }) + i += 1 + } + } } @@ -266,7 +256,7 @@ class Broadcast[T](private val outputPorts: Int, eagerCancel: Boolean) extends G val out: immutable.IndexedSeq[Outlet[T]] = Vector.tabulate(outputPorts)(i ⇒ Outlet[T]("Broadcast.out" + i)) override val shape: UniformFanOutShape[T, T] = UniformFanOutShape(in, out: _*) - override def createLogic: GraphStageLogic = new GraphStageLogic(shape) { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { private var pendingCount = outputPorts private val pending = Array.fill[Boolean](outputPorts)(true) private var downstreamsRunning = outputPorts @@ -364,7 +354,7 @@ class Balance[T](val outputPorts: Int, waitForAllDownstreams: Boolean) extends G val out: immutable.IndexedSeq[Outlet[T]] = Vector.tabulate(outputPorts)(i ⇒ Outlet[T]("Balance.out" + i)) override val shape: UniformFanOutShape[T, T] = UniformFanOutShape[T, T](in, out: _*) - override def createLogic: 
GraphStageLogic = new GraphStageLogic(shape) { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { private val pendingQueue = Array.ofDim[Outlet[T]](outputPorts) private var pendingHead: Int = 0 private var pendingTail: Int = 0 @@ -531,7 +521,7 @@ class Concat[T](inputCount: Int) extends GraphStage[UniformFanInShape[T, T]] { val out: Outlet[T] = Outlet[T]("Concat.out") override val shape: UniformFanInShape[T, T] = UniformFanInShape(out, in: _*) - override def createLogic = new GraphStageLogic(shape) { + override def createLogic(inheritedAttributes: Attributes) = new GraphStageLogic(shape) { var activeStream: Int = 0 { @@ -634,7 +624,7 @@ object FlowGraph extends GraphApply { module.shape.outlet.asInstanceOf[Outlet[M]] } - private[stream] def andThen(port: OutPort, op: StageModule): Unit = { + private[stream] def deprecatedAndThen(port: OutPort, op: StageModule): Unit = { moduleInProgress = moduleInProgress .compose(op) @@ -759,17 +749,6 @@ object FlowGraph extends GraphApply { override def withAttributes(attr: Attributes): Repr[Out, Mat] = throw new UnsupportedOperationException("Cannot set attributes on chained ops from a junction output port") - override private[scaladsl] def andThen[U](op: StageModule): Repr[U, Mat] = { - b.andThen(outlet, op) - new PortOps(op.shape.outlet.asInstanceOf[Outlet[U]], b) - } - - override private[scaladsl] def andThenMat[U, Mat2](op: MaterializingStageFactory): Repr[U, Mat2] = { - // We don't track materialization here - b.andThen(outlet, op) - new PortOps(op.shape.outlet.asInstanceOf[Outlet[U]], b) - } - override def importAndGetPort(b: Builder[_]): Outlet[Out] = outlet override def via[T, Mat2](flow: Graph[FlowShape[Out, T], Mat2]): Repr[T, Mat] = @@ -777,6 +756,12 @@ object FlowGraph extends GraphApply { override def viaMat[T, Mat2, Mat3](flow: Graph[FlowShape[Out, T], Mat2])(combine: (Mat, Mat2) ⇒ Mat3) = throw new UnsupportedOperationException("Cannot use viaMat on a port") + + override private[scaladsl] def deprecatedAndThen[U](op: StageModule): PortOps[U, Mat] = { + b.deprecatedAndThen(outlet, op) + new PortOps(op.shape.outlet.asInstanceOf[Outlet[U]], b) + } + } final class DisabledPortOps[Out, Mat](msg: String) extends PortOps[Out, Mat](null, null) { diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/One2OneBidiFlow.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/One2OneBidiFlow.scala index c5f6ade5bc..f32b7a4b07 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/One2OneBidiFlow.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/One2OneBidiFlow.scala @@ -39,7 +39,7 @@ object One2OneBidiFlow { override def toString = "One2OneBidi" - override def createLogic: GraphStageLogic = new GraphStageLogic(shape) { + override def createLogic(effectiveAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { private var pending = 0 private var pullsSuppressed = 0 diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Source.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Source.scala index 58dba3ca1d..7a287147f8 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Source.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Source.scala @@ -5,7 +5,7 @@ package akka.stream.scaladsl import akka.actor.{ ActorRef, Cancellable, Props } import akka.stream.actor.ActorPublisher -import akka.stream.impl.Stages.{ DefaultAttributes, MaterializingStageFactory, StageModule } +import akka.stream.impl.Stages.{ DefaultAttributes, StageModule } 
import akka.stream.impl.StreamLayout.Module import akka.stream.impl.fusing.GraphStages.TickSource import akka.stream.impl.{ EmptyPublisher, ErrorPublisher, _ } @@ -32,7 +32,7 @@ final class Source[+Out, +Mat](private[stream] override val module: Module) override val shape: SourceShape[Out] = module.shape.asInstanceOf[SourceShape[Out]] def viaMat[T, Mat2, Mat3](flow: Graph[FlowShape[Out, T], Mat2])(combine: (Mat, Mat2) ⇒ Mat3): Source[T, Mat3] = { - if (flow.module.isInstanceOf[Stages.Identity]) this.asInstanceOf[Source[T, Mat3]] + if (flow.module eq Stages.identityGraph.module) this.asInstanceOf[Source[T, Mat3]] else { val flowCopy = flow.module.carbonCopy new Source( @@ -64,7 +64,7 @@ final class Source[+Out, +Mat](private[stream] override val module: Module) new Source(module.transformMaterializedValue(f.asInstanceOf[Any ⇒ Any])) /** INTERNAL API */ - override private[scaladsl] def andThen[U](op: StageModule): Repr[U, Mat] = { + override private[scaladsl] def deprecatedAndThen[U](op: StageModule): Repr[U, Mat] = { // No need to copy here, op is a fresh instance new Source( module @@ -72,13 +72,6 @@ final class Source[+Out, +Mat](private[stream] override val module: Module) .replaceShape(SourceShape(op.outPort))) } - override private[scaladsl] def andThenMat[U, Mat2](op: MaterializingStageFactory): Repr[U, Mat2] = { - new Source( - module - .fuse(op, shape.outlet, op.inPort, Keep.right) - .replaceShape(SourceShape(op.outPort))) - } - /** * Connect this `Source` to a `Sink` and run it. The returned value is the materialized value * of the `Sink`, e.g. the `Publisher` of a [[akka.stream.scaladsl.Sink#publisher]]. diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Tcp.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Tcp.scala index 125b208242..3f16f5258f 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Tcp.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Tcp.scala @@ -208,7 +208,7 @@ class Tcp(system: ExtendedActorSystem) extends akka.actor.Extension { case _ ⇒ Flow[ByteString] } - Flow[ByteString].andThenMat(() ⇒ { + Flow[ByteString].deprecatedAndThenMat(() ⇒ { val processorPromise = Promise[Processor[ByteString, ByteString]]() val localAddressPromise = Promise[InetSocketAddress]() manager ! 
StreamTcpManager.Connect(processorPromise, localAddressPromise, remoteAddress, localAddress, halfClose, options, diff --git a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala index 0cd99a338c..2ff8644f0c 100644 --- a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala +++ b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala @@ -5,15 +5,18 @@ package akka.stream.stage import akka.actor.{ Cancellable, DeadLetterSuppression } import akka.stream._ +import akka.stream.impl.ReactiveStreamsCompliance import akka.stream.impl.StreamLayout.Module import akka.stream.impl.fusing.{ GraphModule, GraphInterpreter } import akka.stream.impl.fusing.GraphInterpreter.GraphAssembly import scala.collection.{ immutable, mutable } import scala.concurrent.duration.FiniteDuration import scala.collection.mutable.ArrayBuffer +import scala.annotation.tailrec abstract class GraphStageWithMaterializedValue[+S <: Shape, +M] extends Graph[S, M] { - def createLogicAndMaterializedValue: (GraphStageLogic, M) + + def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, M) final override private[stream] lazy val module: Module = GraphModule( @@ -39,9 +42,10 @@ abstract class GraphStageWithMaterializedValue[+S <: Shape, +M] extends Graph[S, * logic that ties the ports together. */ abstract class GraphStage[S <: Shape] extends GraphStageWithMaterializedValue[S, Unit] { - final override def createLogicAndMaterializedValue = (createLogic, Unit) + final override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = + (createLogic(inheritedAttributes), Unit) - def createLogic: GraphStageLogic + def createLogic(inheritedAttributes: Attributes): GraphStageLogic } /** @@ -91,7 +95,7 @@ object GraphStageLogic { */ class ConditionalTerminateInput(predicate: () ⇒ Boolean) extends InHandler { override def onPush(): Unit = () - override def onUpstreamFinish(): Unit = if (predicate()) ownerStageLogic.completeStage() + override def onUpstreamFinish(): Unit = if (predicate()) inOwnerStageLogic.completeStage() } /** @@ -125,7 +129,7 @@ object GraphStageLogic { */ class ConditionalTerminateOutput(predicate: () ⇒ Boolean) extends OutHandler { override def onPull(): Unit = () - override def onDownstreamFinish(): Unit = if (predicate()) ownerStageLogic.completeStage() + override def onDownstreamFinish(): Unit = if (predicate()) outOwnerStageLogic.completeStage() } private object DoNothing extends (() ⇒ Unit) { @@ -146,24 +150,11 @@ object GraphStageLogic { * The stage logic is always stopped once all its input and output ports have been closed, i.e. it is not possible to * keep the stage alive for further processing once it does not have any open ports. 
*/ -abstract class GraphStageLogic private[stream] (inCount: Int, outCount: Int) { +abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: Int) { import GraphInterpreter._ - import TimerMessages._ import GraphStageLogic._ def this(shape: Shape) = this(shape.inlets.size, shape.outlets.size) - - private val keyToTimers = mutable.Map[Any, Timer]() - private val timerIdGen = Iterator from 1 - - private var _timerAsyncCallback: AsyncCallback[Scheduled] = _ - private def getTimerAsyncCallback: AsyncCallback[Scheduled] = { - if (_timerAsyncCallback eq null) - _timerAsyncCallback = getAsyncCallback(onInternalTimer) - - _timerAsyncCallback - } - /** * INTERNAL API */ @@ -172,20 +163,14 @@ abstract class GraphStageLogic private[stream] (inCount: Int, outCount: Int) { /** * INTERNAL API */ - private[stream] var inHandlers = Array.ofDim[InHandler](inCount) - /** - * INTERNAL API - */ - private[stream] var outHandlers = Array.ofDim[OutHandler](outCount) + // Using common array to reduce overhead for small port counts + private[stream] var handlers = Array.ofDim[Any](inCount + outCount) /** * INTERNAL API */ - private[stream] var inToConn = Array.ofDim[Int](inHandlers.length) - /** - * INTERNAL API - */ - private[stream] var outToConn = Array.ofDim[Int](outHandlers.length) + // Using common array to reduce overhead for small port counts + private[stream] var portToConn = Array.ofDim[Int](handlers.length) /** * INTERNAL API @@ -243,32 +228,35 @@ abstract class GraphStageLogic private[stream] (inCount: Int, outCount: Int) { * Assigns callbacks for the events for an [[Inlet]] */ final protected def setHandler(in: Inlet[_], handler: InHandler): Unit = { - handler.ownerStageLogic = this - inHandlers(in.id) = handler - if (_interpreter != null) _interpreter.setHandler(inToConn(in.id), handler) + handler.inOwnerStageLogic = this + handlers(in.id) = handler + if (_interpreter != null) _interpreter.setHandler(conn(in), handler) } /** * Retrieves the current callback for the events on the given [[Inlet]] */ final protected def getHandler(in: Inlet[_]): InHandler = { - inHandlers(in.id) + handlers(in.id).asInstanceOf[InHandler] } /** * Assigns callbacks for the events for an [[Outlet]] */ final protected def setHandler(out: Outlet[_], handler: OutHandler): Unit = { - handler.ownerStageLogic = this - outHandlers(out.id) = handler - if (_interpreter != null) _interpreter.setHandler(outToConn(out.id), handler) + handler.outOwnerStageLogic = this + handlers(out.id + inCount) = handler + if (_interpreter != null) _interpreter.setHandler(conn(out), handler) } + private def conn(in: Inlet[_]): Int = portToConn(in.id) + private def conn(out: Outlet[_]): Int = portToConn(out.id + inCount) + /** * Retrieves the current callback for the events on the given [[Outlet]] */ final protected def getHandler(out: Outlet[_]): OutHandler = { - outHandlers(out.id) + handlers(out.id + inCount).asInstanceOf[OutHandler] } private def getNonEmittingHandler(out: Outlet[_]): OutHandler = @@ -277,9 +265,6 @@ abstract class GraphStageLogic private[stream] (inCount: Int, outCount: Int) { case other ⇒ other } - private def conn[T](in: Inlet[T]): Int = inToConn(in.id) - private def conn[T](out: Outlet[T]): Int = outToConn(out.id) - /** * Requests an element on the given port. Calling this method twice before an element arrived will fail. * There can only be one outstanding request at any given time. 
The method [[hasBeenPulled()]] can be used @@ -374,10 +359,11 @@ abstract class GraphStageLogic private[stream] (inCount: Int, outCount: Int) { * used to check if the port is ready to be pushed or not. */ final protected def push[T](out: Outlet[T], elem: T): Unit = { - if ((interpreter.portStates(conn(out)) & (OutReady | OutClosed)) == OutReady) { + if ((interpreter.portStates(conn(out)) & (OutReady | OutClosed)) == OutReady && (elem != null)) { interpreter.push(conn(out), elem) } else { // Detailed error information should not add overhead to the hot path + ReactiveStreamsCompliance.requireNonNullElement(elem) require(isAvailable(out), "Cannot push port twice") require(!isClosed(out), "Cannot push closed port") } @@ -399,13 +385,13 @@ abstract class GraphStageLogic private[stream] (inCount: Int, outCount: Int) { */ final def completeStage(): Unit = { var i = 0 - while (i < inToConn.length) { - interpreter.cancel(inToConn(i)) - i += 1 - } - i = 0 - while (i < outToConn.length) { - interpreter.complete(outToConn(i)) + while (i < portToConn.length) { + if (i < inCount) + interpreter.cancel(portToConn(i)) + else handlers(i) match { + case e: Emitting[_] ⇒ e.addFollowUp(new EmittingCompletion(e.out, e.previous)) + case _ ⇒ interpreter.complete(portToConn(i)) + } i += 1 } } @@ -416,13 +402,11 @@ abstract class GraphStageLogic private[stream] (inCount: Int, outCount: Int) { */ final def failStage(ex: Throwable): Unit = { var i = 0 - while (i < inToConn.length) { - interpreter.cancel(inToConn(i)) - i += 1 - } - i = 0 - while (i < outToConn.length) { - interpreter.fail(outToConn(i), ex) + while (i < portToConn.length) { + if (i < inCount) + interpreter.cancel(portToConn(i)) + else + interpreter.fail(portToConn(i), ex) i += 1 } } @@ -612,22 +596,53 @@ abstract class GraphStageLogic private[stream] (inCount: Int, outCount: Int) { case _ ⇒ setHandler(out, next) } - private abstract class Emitting[T](protected val out: Outlet[T], val previous: OutHandler, andThen: () ⇒ Unit) extends OutHandler { - private var followUps: mutable.Queue[Emitting[T]] = null + private abstract class Emitting[T](val out: Outlet[T], val previous: OutHandler, andThen: () ⇒ Unit) extends OutHandler { + private var followUps: Emitting[T] = _ + private var followUpsTail: Emitting[T] = _ + private def as[U] = this.asInstanceOf[Emitting[U]] protected def followUp(): Unit = { setHandler(out, previous) andThen() if (followUps != null) { - val next = followUps.dequeue() - if (followUps.nonEmpty) next.followUps = followUps - setHandler(out, next) + getHandler(out) match { + case e: Emitting[_] ⇒ e.as[T].addFollowUps(this) + case _ ⇒ + val next = dequeue() + if (next.isInstanceOf[EmittingCompletion[_]]) complete(out) + else setHandler(out, next) + } } } def addFollowUp(e: Emitting[T]): Unit = - if (followUps == null) followUps = mutable.Queue(e) - else followUps.enqueue(e) + if (followUps == null) { + followUps = e + followUpsTail = e + } else { + followUpsTail.followUps = e + followUpsTail = e + } + + private def addFollowUps(e: Emitting[T]): Unit = + if (followUps == null) { + followUps = e.followUps + followUpsTail = e.followUpsTail + } else { + followUpsTail.followUps = e.followUps + followUpsTail = e.followUpsTail + } + + /** + * Dequeue `this` from the head of the queue, meaning that this object will + * not be retained (setHandler will install the followUp). For this reason + * the followUpsTail knowledge needs to be passed on to the next runner.
+ */ + private def dequeue(): Emitting[T] = { + val ret = followUps + ret.followUpsTail = followUpsTail + ret + } override def onDownstreamFinish(): Unit = previous.onDownstreamFinish() } @@ -652,6 +667,10 @@ abstract class GraphStageLogic private[stream] (inCount: Int, outCount: Int) { } } + private class EmittingCompletion[T](_out: Outlet[T], _previous: OutHandler) extends Emitting(_out, _previous, DoNothing) { + override def onPull(): Unit = complete(out) + } + /** * Install a handler on the given inlet that emits received elements on the * given outlet before pulling for more data. `doTerminate` controls whether @@ -659,9 +678,10 @@ abstract class GraphStageLogic private[stream] (inCount: Int, outCount: Int) { */ final protected def passAlong[Out, In <: Out](from: Inlet[In], to: Outlet[Out], doFinish: Boolean, doFail: Boolean): Unit = setHandler(from, new InHandler { + val puller = () ⇒ tryPull(from) override def onPush(): Unit = { val elem = grab(from) - emit(to, elem, () ⇒ tryPull(from)) + emit(to, elem, puller) } override def onUpstreamFinish(): Unit = if (doFinish) super.onUpstreamFinish() override def onUpstreamFailure(ex: Throwable): Unit = if (doFail) super.onUpstreamFailure(ex) @@ -669,7 +689,7 @@ abstract class GraphStageLogic private[stream] (inCount: Int, outCount: Int) { /** * Obtain a callback object that can be used asynchronously to re-enter the - * current [[AsyncStage]] with an asynchronous notification. The [[invoke()]] method of the returned + * current [[GraphStage]] with an asynchronous notification. The [[invoke()]] method of the returned * [[AsyncCallback]] is safe to be called from other threads and it will in the background thread-safely * delegate to the passed callback function. I.e. [[invoke()]] will be called by the external world and * the passed handler will be invoked eventually in a thread-safe way by the execution environment. @@ -683,6 +703,50 @@ abstract class GraphStageLogic private[stream] (inCount: Int, outCount: Int) { } } + // Internal hooks to avoid reliance on user calling super in preStart + protected[stream] def beforePreStart(): Unit = () + + // Internal hooks to avoid reliance on user calling super in postStop + protected[stream] def afterPostStop(): Unit = () + + /** + * Invoked before any external events are processed, at the startup of the stage. + */ + def preStart(): Unit = () + + /** + * Invoked after processing of external events stopped because the stage is about to stop or fail. + */ + def postStop(): Unit = () +} + +/** + * An asynchronous callback holder that is attached to a [[GraphStageLogic]]. + * Invoking [[AsyncCallback#invoke]] will eventually lead to the registered handler + * being called. + */ +trait AsyncCallback[T] { + /** + * Dispatch an asynchronous notification. This method is thread-safe and + * may be invoked from external execution contexts. 
+abstract class TimerGraphStageLogic(_shape: Shape) extends GraphStageLogic(_shape) {
+  import TimerMessages._
+
+  private val keyToTimers = mutable.Map[Any, Timer]()
+  private val timerIdGen = Iterator from 1
+
+  private var _timerAsyncCallback: AsyncCallback[Scheduled] = _
+  private def getTimerAsyncCallback: AsyncCallback[Scheduled] = {
+    if (_timerAsyncCallback eq null)
+      _timerAsyncCallback = getAsyncCallback(onInternalTimer)
+
+    _timerAsyncCallback
+  }
+
+  private def onInternalTimer(scheduled: Scheduled): Unit = {
     val Id = scheduled.timerId
     keyToTimers.get(scheduled.timerKey) match {
@@ -694,12 +758,18 @@ abstract class GraphStageLogic private[stream] (inCount: Int, outCount: Int) {
   }

   /**
-   * Schedule timer to call [[#onTimer]] periodically with the given interval.
-   * Any existing timer with the same key will automatically be canceled before
-   * adding the new timer.
+   * Will be called when the scheduled timer is triggered.
+   * @param timerKey key of the scheduled timer
    */
-  final protected def schedulePeriodically(timerKey: Any, interval: FiniteDuration): Unit =
-    schedulePeriodicallyWithInitialDelay(timerKey, interval, interval)
+  protected def onTimer(timerKey: Any): Unit = ()
+
+  // Internal hooks to avoid reliance on user calling super in postStop
+  protected[stream] override def afterPostStop(): Unit = {
+    if (keyToTimers ne null) {
+      keyToTimers.foreach { case (_, Timer(_, task)) ⇒ task.cancel() }
+      keyToTimers.clear()
+    }
+  }

   /**
    * Schedule timer to call [[#onTimer]] periodically with the given interval after the specified
@@ -751,30 +821,13 @@ abstract class GraphStageLogic private[stream] (inCount: Int, outCount: Int) {
   final protected def isTimerActive(timerKey: Any): Boolean = keyToTimers contains timerKey

   /**
-   * Will be called when the scheduled timer is triggered.
-   * @param timerKey key of the scheduled timer
+   * Schedule timer to call [[#onTimer]] periodically with the given interval.
+   * Any existing timer with the same key will automatically be canceled before
+   * adding the new timer.
    */
-  protected def onTimer(timerKey: Any): Unit = ()
+  final protected def schedulePeriodically(timerKey: Any, interval: FiniteDuration): Unit =
+    schedulePeriodicallyWithInitialDelay(timerKey, interval, interval)

-  // Internal hooks to avoid reliance on user calling super in preStart
-  protected[stream] def beforePreStart(): Unit = {
-  }
-
-  // Internal hooks to avoid reliance on user calling super in postStop
-  protected[stream] def afterPostStop(): Unit = {
-    keyToTimers.foreach { case (_, Timer(_, task)) ⇒ task.cancel() }
-    keyToTimers.clear()
-  }
-
-  /**
-   * Invoked before any external events are processed, at the startup of the stage.
-   */
-  def preStart(): Unit = ()
-
-  /**
-   * Invoked after processing of external events stopped because the stage is about to stop or fail.
-   */
-  def postStop(): Unit = ()
 }
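Review note: usage sketch for the new `TimerGraphStageLogic`. The `shape` and an `out: Outlet[Unit]` are assumptions of this hypothetical heartbeat logic; timer cleanup needs no user code since `afterPostStop()` above cancels outstanding tasks:

```scala
import scala.concurrent.duration._

// Hypothetical: emit a tick every second while downstream has demand.
new TimerGraphStageLogic(shape) {
  setHandler(out, new OutHandler {
    override def onPull(): Unit = () // demand is observed via isAvailable(out)
  })

  override def preStart(): Unit =
    schedulePeriodically(timerKey = "tick", interval = 1.second)

  override protected def onTimer(timerKey: Any): Unit =
    if (isAvailable(out)) push(out, ())
}
```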
@@ -784,7 +837,7 @@ trait InHandler {
   /**
    * INTERNAL API
    */
-  private[stream] var ownerStageLogic: GraphStageLogic = _
+  private[stream] var inOwnerStageLogic: GraphStageLogic = _

   /**
    * Called when the input port has a new element available. The actual element can be retrieved via the
@@ -795,12 +848,12 @@ trait InHandler {
   /**
    * Called when the input port is finished. After this callback no other callbacks will be called for this port.
    */
-  def onUpstreamFinish(): Unit = ownerStageLogic.completeStage()
+  def onUpstreamFinish(): Unit = inOwnerStageLogic.completeStage()

   /**
    * Called when the input port has failed. After this callback no other callbacks will be called for this port.
    */
-  def onUpstreamFailure(ex: Throwable): Unit = ownerStageLogic.failStage(ex)
+  def onUpstreamFailure(ex: Throwable): Unit = inOwnerStageLogic.failStage(ex)
 }

 /**
@@ -810,7 +863,7 @@ trait OutHandler {
   /**
    * INTERNAL API
    */
-  private[stream] var ownerStageLogic: GraphStageLogic = _
+  private[stream] var outOwnerStageLogic: GraphStageLogic = _

   /**
    * Called when the output port has received a pull, and therefore ready to emit an element, i.e. [[GraphStageLogic.push()]]
@@ -822,5 +875,5 @@ trait OutHandler {
    * Called when the output port will no longer accept any new elements. After this callback no other callbacks will
    * be called for this port.
    */
-  def onDownstreamFinish(): Unit = ownerStageLogic.completeStage()
+  def onDownstreamFinish(): Unit = outOwnerStageLogic.completeStage()
 }
\ No newline at end of file
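Review note: the `inOwnerStageLogic`/`outOwnerStageLogic` split keeps the default `completeStage`/`failStage` wiring intact when a single object implements both traits, as the combined handler in `PushPullGraphLogic` below does. A sketch of overriding one of these defaults; `buffer`, `in`, `out`, and the `emitMultiple` helper from `GraphStageLogic` are assumed to exist in the enclosing logic at this revision:

```scala
// Hypothetical InHandler: the default onUpstreamFinish() would complete the
// stage immediately; this variant drains a buffer before completing.
setHandler(in, new InHandler {
  override def onPush(): Unit = {
    buffer.enqueue(grab(in))
    pull(in)
  }
  override def onUpstreamFinish(): Unit =
    if (buffer.isEmpty) completeStage()
    else emitMultiple(out, buffer.iterator, () ⇒ completeStage())
})
```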
diff --git a/akka-stream/src/main/scala/akka/stream/stage/Stage.scala b/akka-stream/src/main/scala/akka/stream/stage/Stage.scala
index 34c2acb981..5a3cc187c7 100644
--- a/akka-stream/src/main/scala/akka/stream/stage/Stage.scala
+++ b/akka-stream/src/main/scala/akka/stream/stage/Stage.scala
@@ -3,7 +3,10 @@
  */
 package akka.stream.stage

-import akka.stream.{ Attributes, Materializer, Supervision }
+import akka.stream._
+
+import scala.annotation.unchecked.uncheckedVariance
+import scala.util.control.NonFatal

 /**
  * General interface for stream transformation.
@@ -25,76 +28,169 @@ import akka.stream.{ Attributes, Materializer, Supervision }
  * @see [[akka.stream.scaladsl.Flow#transform]]
  * @see [[akka.stream.javadsl.Flow#transform]]
  */
-sealed trait Stage[-In, Out]
+sealed trait Stage[-In, +Out]

 /**
  * INTERNAL API
  */
 private[stream] object AbstractStage {
-  final val UpstreamBall = 1
-  final val DownstreamBall = 2
-  final val PrecedingWasPull = 0x4000
-  final val NoTerminationPending = 0x8000
-  final val BothBalls = UpstreamBall | DownstreamBall
-  final val BothBallsAndNoTerminationPending = UpstreamBall | DownstreamBall | NoTerminationPending
+
+  private class PushPullGraphLogic[In, Out](
+    private val shape: FlowShape[In, Out],
+    val attributes: Attributes,
+    val stage: AbstractStage[In, Out, Directive, Directive, Context[Out], LifecycleContext])
+    extends GraphStageLogic(shape) with DetachedContext[Out] {
+
+    final override def materializer: Materializer = interpreter.materializer
+
+    private def ctx: DetachedContext[Out] = this
+
+    private var currentStage: AbstractStage[In, Out, Directive, Directive, Context[Out], LifecycleContext] = stage
+
+    {
+      // No need to keep a reference to the handler in a private val
+      val handler = new InHandler with OutHandler {
+        override def onPush(): Unit =
+          try { currentStage.onPush(grab(shape.inlet), ctx) } catch { case NonFatal(ex) ⇒ onSupervision(ex) }
+
+        override def onPull(): Unit = currentStage.onPull(ctx)
+
+        override def onUpstreamFinish(): Unit = currentStage.onUpstreamFinish(ctx)
+
+        override def onUpstreamFailure(ex: Throwable): Unit = currentStage.onUpstreamFailure(ex, ctx)
+
+        override def onDownstreamFinish(): Unit = currentStage.onDownstreamFinish(ctx)
+      }
+
+      setHandler(shape.inlet, handler)
+      setHandler(shape.outlet, handler)
+    }
+
+    private def onSupervision(ex: Throwable): Unit = {
+      currentStage.decide(ex) match {
+        case Supervision.Stop ⇒
+          failStage(ex)
+        case Supervision.Resume ⇒
+          resetAfterSupervise()
+        case Supervision.Restart ⇒
+          resetAfterSupervise()
+          currentStage.postStop()
+          currentStage = currentStage.restart().asInstanceOf[AbstractStage[In, Out, Directive, Directive, Context[Out], LifecycleContext]]
+          currentStage.preStart(ctx)
+      }
+    }
+
+    private def resetAfterSupervise(): Unit = {
+      val mustPull = currentStage.isDetached || isAvailable(shape.outlet)
+      if (!hasBeenPulled(shape.inlet) && mustPull) pull(shape.inlet)
+    }
+
+    override protected[stream] def beforePreStart(): Unit = {
+      super.beforePreStart()
+      if (currentStage.isDetached) pull(shape.inlet)
+    }
+
+    final override def push(elem: Out): DownstreamDirective = {
+      push(shape.outlet, elem)
+      null
+    }
+
+    final override def pull(): UpstreamDirective = {
+      pull(shape.inlet)
+      null
+    }
+
+    final override def finish(): FreeDirective = {
+      completeStage()
+      null
+    }
+
+    final override def pushAndFinish(elem: Out): DownstreamDirective = {
+      push(shape.outlet, elem)
+      completeStage()
+      null
+    }
+
+    final override def fail(cause: Throwable): FreeDirective = {
+      failStage(cause)
+      null
+    }
+
+    final override def isFinishing: Boolean = isClosed(shape.inlet)
+
+    final override def absorbTermination(): TerminationDirective = {
+      if (isClosed(shape.outlet)) {
+        val ex = new UnsupportedOperationException("It is not allowed to call absorbTermination() from onDownstreamFinish.")
+        // This MUST be logged here: the downstream has already cancelled, so there is no one
+        // to send onError to, and the stage is about to finish, so nothing but the interpreter
+        // would ever see the exception.
+        interpreter.log.error(ex.getMessage)
+        throw ex // We still throw for correctness (although a finish() would also work here)
+      }
+      if (isAvailable(shape.outlet)) currentStage.onPull(ctx)
+      null
+    }
+
+    override def pushAndPull(elem: Out): FreeDirective = {
+      push(shape.outlet, elem)
+      pull(shape.inlet)
+      null
+    }
+
+    final override def holdUpstreamAndPush(elem: Out): UpstreamDirective = {
+      push(shape.outlet, elem)
+      null
+    }
+
+    final override def holdDownstreamAndPull(): DownstreamDirective = {
+      pull(shape.inlet)
+      null
+    }
+
+    final override def isHoldingDownstream: Boolean = isAvailable(shape.outlet)
+
+    final override def isHoldingUpstream: Boolean = !(isClosed(shape.inlet) || hasBeenPulled(shape.inlet))
+
+    final override def holdDownstream(): DownstreamDirective = null
+
+    final override def holdUpstream(): UpstreamDirective = null
+
+    override def preStart(): Unit = currentStage.preStart(ctx)
+    override def postStop(): Unit = currentStage.postStop()
+
+    override def toString: String = s"PushPullGraphLogic($currentStage)"
+  }
+
+  class PushPullGraphStageWithMaterializedValue[-In, +Out, Ext, +Mat](
+    val factory: (Attributes) ⇒ (Stage[In, Out], Mat),
+    stageAttributes: Attributes)
+    extends GraphStageWithMaterializedValue[FlowShape[In, Out], Mat] {

+    val name = stageAttributes.nameOrDefault()
+    val shape = FlowShape(Inlet[In](name + ".in"), Outlet[Out](name + ".out"))
+
+    override def toString = name
+
+    override def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, Mat) = {
+      val effectiveAttributes = inheritedAttributes and stageAttributes
+      val stageAndMat = factory(effectiveAttributes)
+      val stage: AbstractStage[In, Out, Directive, Directive, Context[Out], LifecycleContext] =
+        stageAndMat._1.asInstanceOf[AbstractStage[In, Out, Directive, Directive, Context[Out], LifecycleContext]]
+      (new PushPullGraphLogic(shape, effectiveAttributes, stage), stageAndMat._2)
+    }
+  }
+
+  class PushPullGraphStage[-In, +Out, Ext](_factory: (Attributes) ⇒ Stage[In, Out], _stageAttributes: Attributes)
+    extends PushPullGraphStageWithMaterializedValue[In, Out, Ext, Unit]((att: Attributes) ⇒ (_factory(att), ()), _stageAttributes)
 }
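Review note: how the new wrapper is meant to be used. A hedged sketch of adapting a plain one-to-one `PushPullStage` onto the `GraphInterpreter`, replacing the removed `OneBoundedInterpreter` path; the stage itself is hypothetical, constructor shapes as defined above:

```scala
// Hypothetical: wrap a PushPullStage factory in the new GraphStage adapter.
val intToString =
  new PushPullGraphStage[Int, String, Unit](
    (_: Attributes) ⇒ new PushPullStage[Int, String] {
      override def onPush(elem: Int, ctx: Context[String]) = ctx.push(elem.toString)
      override def onPull(ctx: Context[String])            = ctx.pull()
    },
    Attributes.name("intToString"))
```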
 abstract class AbstractStage[-In, Out, PushD <: Directive, PullD <: Directive, Ctx <: Context[Out], LifeCtx <: LifecycleContext] extends Stage[In, Out] {
-  /**
-   * INTERNAL API
-   */
-  private[stream] var bits = AbstractStage.NoTerminationPending
-
-  /**
-   * INTERNAL API
-   */
-  private[stream] var context: Ctx = _

   /**
    * INTERNAL API
    */
   private[stream] def isDetached: Boolean = false

-  /**
-   * INTERNAL API
-   */
-  private[stream] def enterAndPush(elem: Out): Unit = {
-    val c = context
-    c.enter()
-    c.push(elem)
-    c.execute()
-  }
-
-  /**
-   * INTERNAL API
-   */
-  private[stream] def enterAndPull(): Unit = {
-    val c = context
-    c.enter()
-    c.pull()
-    c.execute()
-  }
-
-  /**
-   * INTERNAL API
-   */
-  private[stream] def enterAndFinish(): Unit = {
-    val c = context
-    c.enter()
-    c.finish()
-    c.execute()
-  }
-
-  /**
-   * INTERNAL API
-   */
-  private[stream] def enterAndFail(e: Throwable): Unit = {
-    val c = context
-    c.enter()
-    c.fail(e)
-    c.execute()
-  }
-
   /**
    * User overridable callback.
    *
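Review note: for orientation before the removals below, the `DetachedContext` operations that `PushPullGraphLogic` now reduces to port-state bookkeeping correspond to user code of this shape. A hedged sketch of the classic one-element buffer; directive subtyping as declared in this file is assumed:

```scala
// Hypothetical DetachedStage exercising the hold*/pushAndPull directives
// implemented by PushPullGraphLogic above.
class OneElementBuffer[T] extends DetachedStage[T, T] {
  private var buffered: Option[T] = None

  override def onPush(elem: T, ctx: DetachedContext[T]) =
    if (ctx.isHoldingDownstream) ctx.pushAndPull(elem) // straight through
    else {
      buffered = Some(elem)
      ctx.holdUpstream() // keep the element, do not pull yet
    }

  override def onPull(ctx: DetachedContext[T]) =
    buffered match {
      case Some(elem) ⇒
        buffered = None
        ctx.pushAndPull(elem) // release the buffered element and refill
      case None ⇒
        ctx.holdDownstream() // no data yet, remember the demand
    }
}
```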

@@ -285,26 +381,6 @@ abstract class DetachedStage[In, Out]
   override def decide(t: Throwable): Supervision.Directive = super.decide(t)
 }

-/**
- * This is a variant of [[DetachedStage]] that can receive asynchronous input
- * from external sources, for example timers or Future results. In order to
- * do this, obtain an [[AsyncCallback]] from the [[AsyncContext]] and attach
- * it to the asynchronous event. When the event fires an asynchronous notification
- * will be dispatched that eventually will lead to `onAsyncInput` being invoked
- * with the provided data item.
- */
-abstract class AsyncStage[In, Out, Ext]
-  extends AbstractStage[In, Out, UpstreamDirective, DownstreamDirective, AsyncContext[Out, Ext], AsyncContext[Out, Ext]] {
-  private[stream] override def isDetached = true
-
-  /**
-   * Implement this method to define the action to be taken in response to an
-   * asynchronous notification that was previously registered using
-   * [[AsyncContext#getAsyncCallback]].
-   */
-  def onAsyncInput(event: Ext, ctx: AsyncContext[Out, Ext]): Directive
-}
-
 /**
  * The behavior of [[StatefulStage]] is defined by these two methods, which
  * has the same semantics as corresponding methods in [[PushPullStage]].
@@ -507,27 +583,6 @@ abstract class StatefulStage[In, Out] extends PushPullStage[In, Out] {

 }

-/**
- * INTERNAL API
- *
- * `BoundaryStage` implementations are meant to communicate with the external world. These stages do not have most of the
- * safety properties enforced and should be used carefully. One important ability of BoundaryStages that they can take
- * off an execution signal by calling `ctx.exit()`. This is typically used immediately after an external signal has
- * been produced (for example an actor message). BoundaryStages can also kickstart execution by calling `enter()` which
- * returns a context they can use to inject signals into the interpreter. There is no checks in place to enforce that
- * the number of signals taken out by exit() and the number of signals returned via enter() are the same -- using this
- * stage type needs extra care from the implementer.
- *
- * BoundaryStages are the elements that make the interpreter *tick*, there is no other way to start the interpreter
- * than using a BoundaryStage.
- */
-private[akka] abstract class BoundaryStage extends AbstractStage[Any, Any, Directive, Directive, BoundaryContext, LifecycleContext] {
-  final override def decide(t: Throwable): Supervision.Directive = Supervision.Stop
-
-  final override def restart(): BoundaryStage =
-    throw new UnsupportedOperationException("BoundaryStage doesn't support restart")
-}
-
 /**
  * Return type from [[Context]] methods.
  */
@@ -555,16 +610,6 @@ trait LifecycleContext {
  * Passed to the callback methods of [[PushPullStage]] and [[StatefulStage]].
  */
 sealed trait Context[Out] extends LifecycleContext {
-  /**
-   * INTERNAL API
-   */
-  private[stream] def enter(): Unit
-
-  /**
-   * INTERNAL API
-   */
-  private[stream] def execute(): Unit
-
   /**
    * Push one element to downstreams.
    */
@@ -625,48 +670,3 @@ trait DetachedContext[Out] extends Context[Out] {
   def pushAndPull(elem: Out): FreeDirective

 }

-/**
- * An asynchronous callback holder that is attached to an [[AsyncContext]].
- * Invoking [[AsyncCallback#invoke]] will eventually lead to [[AsyncStage#onAsyncInput]]
- * being called.
- */
-trait AsyncCallback[T] {
-  /**
-   * Dispatch an asynchronous notification. This method is thread-safe and
-   * may be invoked from external execution contexts.
-   */
-  def invoke(t: T): Unit
-}
-
-/**
- * This kind of context is available to [[AsyncStage]]. It implements the same
- * interface as for [[DetachedStage]] with the addition of being able to obtain
- * [[AsyncCallback]] objects that allow the registration of asynchronous
- * notifications.
- */
-trait AsyncContext[Out, Ext] extends DetachedContext[Out] {
-  /**
-   * Obtain a callback object that can be used asynchronously to re-enter the
-   * current [[AsyncStage]] with an asynchronous notification. After the
-   * notification has been invoked, eventually [[AsyncStage#onAsyncInput]]
-   * will be called with the given data item.
-   *
-   * This object can be cached and reused within the same [[AsyncStage]].
-   */
-  def getAsyncCallback: AsyncCallback[Ext]
-
-  /**
-   * In response to an asynchronous notification an [[AsyncStage]] may choose
-   * to neither push nor pull nor terminate, which is represented as this
-   * directive.
-   */
-  def ignore(): AsyncDirective
-}
-
-/**
- * INTERNAL API
- */
-private[akka] trait BoundaryContext extends Context[Any] {
-  def exit(): FreeDirective
-}
-
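Review note: with `AsyncStage`, `AsyncContext`, and `BoundaryContext` gone, asynchronous re-entry goes exclusively through `GraphStageLogic.getAsyncCallback` (GraphStage.scala above). A hedged migration sketch; `Ext`, `transform`, and `externalSource.register` are placeholders for illustration, not real API:

```scala
// Before (deleted above): an AsyncStage received external events through
// onAsyncInput, registered via AsyncContext#getAsyncCallback.
//
// After: register an AsyncCallback from inside a GraphStageLogic instead.
new GraphStageLogic(shape) {
  private val onEvent: AsyncCallback[Ext] = getAsyncCallback[Ext] { event ⇒
    if (isAvailable(out)) push(out, transform(event)) // transform: placeholder
  }

  override def preStart(): Unit =
    externalSource.register(onEvent.invoke) // externalSource: placeholder
}
```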