From 556012b7eeda072137010bf6e4f6f2fa241bb3e2 Mon Sep 17 00:00:00 2001 From: Roland Kuhn Date: Sat, 31 Oct 2015 14:46:10 +0100 Subject: [PATCH] !str,htc replace and remove OneBoundedInterpreter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit main work by @drewhk with contributions from @2m and @rkuhn This work uncovered many well-hidden bugs in existing Stages, in particular StatefulStage. These were hidden by the behavior of OneBoundedInterpreter that normally behaves more orderly than it guarantees in general, especially with respect to the timeliness of delivery of upstream termination signals; the bugs were then that internal state was not flushed when onComplete arrived “too early”. --- .../akka/stream/InterpreterBenchmark.scala | 79 +- .../server/HttpServerExampleDocTest.java | 4 +- akka-docs-dev/rst/migration-guide-1.0-2.x.rst | 263 ++++++ .../http/scaladsl/HttpServerExampleSpec.scala | 4 +- .../FutureDirectivesExamplesSpec.scala | 3 - .../RouteDirectivesExamplesSpec.scala | 3 +- .../stream/RateTransformationDocSpec.scala | 9 +- .../docs/stream/ReactiveStreamsDocSpec.scala | 2 +- .../stream/cookbook/RecipeByteStrings.scala | 10 +- .../stream/cookbook/RecipeKeepAlive.scala | 2 + .../stream/cookbook/RecipeMissedTicks.scala | 11 +- .../stream/cookbook/RecipeSimpleDrop.scala | 10 +- .../src/test/resources/application.conf | 1 + .../akka/http/javadsl/model/HttpEntity.java | 2 +- .../client/OutgoingConnectionBlueprint.scala | 4 +- .../impl/engine/client/PoolConductor.scala | 4 +- .../http/impl/engine/client/PoolSlot.scala | 4 +- .../engine/server/HttpServerBluePrint.scala | 8 +- .../akka/http/impl/engine/ws/FrameEvent.scala | 6 +- .../impl/engine/ws/FrameEventParser.scala | 154 ++-- .../http/impl/engine/ws/FrameHandler.scala | 18 +- .../akka/http/impl/engine/ws/Masking.scala | 35 +- .../akka/http/impl/engine/ws/Websocket.scala | 12 +- .../akka/http/impl/util/ByteReader.scala | 6 +- .../scala/akka/http/impl/util/package.scala | 4 +- .../akka/http/scaladsl/model/HttpEntity.scala | 8 +- .../akka/http/scaladsl/model/Multipart.scala | 4 +- .../src/test/resources/reference.conf | 1 + .../engine/client/ConnectionPoolSpec.scala | 8 +- .../LowLevelOutgoingConnectionSpec.scala | 6 +- .../client/TlsEndpointVerificationSpec.scala | 8 +- .../http/impl/engine/ws/FramingSpec.scala | 3 +- .../http/impl/engine/ws/MessageSpec.scala | 20 +- .../akka/http/scaladsl/ClientServerSpec.scala | 7 +- .../src/test/resources/reference.conf | 1 + .../scaladsl/server/BasicRouteSpecs.scala | 7 +- .../directives/ExecutionDirectivesSpec.scala | 14 +- .../directives/FutureDirectivesSpec.scala | 6 +- .../directives/RouteDirectivesSpec.scala | 5 +- .../directives/SecurityDirectivesSpec.scala | 25 +- .../stream/tck/FusableProcessorTest.scala | 6 +- .../akka/stream/testkit/StreamTestKit.scala | 14 +- .../akka/stream/javadsl/AttributesTest.java | 8 +- .../java/akka/stream/javadsl/FlowTest.java | 19 +- .../java/akka/stream/javadsl/SourceTest.java | 9 +- .../{reference.conf => application.conf} | 1 + .../akka/stream/DslConsistencySpec.scala | 8 +- .../impl/ActorInterpreterLifecycleSpec.scala | 120 --- .../stream/impl/ActorInterpreterSpec.scala | 190 ----- .../stream/impl/GraphStageLogicSpec.scala | 177 +++++ .../akka/stream/impl/StreamLayoutSpec.scala | 2 +- .../fusing/ActorGraphInterpreterSpec.scala | 8 +- .../impl/fusing/GraphInterpreterSpec.scala | 2 + .../impl/fusing/GraphInterpreterSpecKit.scala | 200 ++++- .../stream/impl/fusing/InterpreterSpec.scala | 138 ++-- .../impl/fusing/InterpreterSpecKit.scala | 183 ----- .../impl/fusing/InterpreterStressSpec.scala | 31 +- .../fusing/InterpreterSupervisionSpec.scala | 49 +- .../fusing/LifecycleInterpreterSpec.scala | 78 +- .../akka/stream/io/InputStreamSinkSpec.scala | 8 +- .../stream/io/OutputStreamSourceSpec.scala | 8 +- .../stream/io/SynchronousFileSinkSpec.scala | 2 + .../stream/io/SynchronousFileSourceSpec.scala | 2 + .../test/scala/akka/stream/io/TcpSpec.scala | 11 +- .../test/scala/akka/stream/io/TlsSpec.scala | 49 +- .../akka/stream/scaladsl/AttributesSpec.scala | 65 ++ .../stream/scaladsl/FlowMapAsyncSpec.scala | 123 ++- .../scaladsl/FlowMapAsyncUnorderedSpec.scala | 64 +- .../scala/akka/stream/scaladsl/FlowSpec.scala | 61 +- ...ec.scala => GraphMergePreferredSpec.scala} | 2 +- .../scaladsl/GraphStageTimersSpec.scala | 18 +- .../akka/stream/scaladsl/SourceSpec.scala | 8 +- .../scaladsl/SubscriberSourceSpec.scala | 2 +- .../SubstreamSubscriptionTimeoutSpec.scala | 2 +- .../scaladsl/UnzipWithApply.scala.template | 2 +- .../scaladsl/ZipWithApply.scala.template | 2 +- .../scala/akka/stream/ActorMaterializer.scala | 36 +- .../main/scala/akka/stream/Attributes.scala | 56 +- .../stream/impl/ActorMaterializerImpl.scala | 67 +- .../akka/stream/impl/FixedSizeBuffer.scala | 1 + .../main/scala/akka/stream/impl/Stages.scala | 126 +-- .../scala/akka/stream/impl/StreamLayout.scala | 8 +- .../impl/fusing/ActorGraphInterpreter.scala | 15 +- .../stream/impl/fusing/ActorInterpreter.scala | 389 --------- .../stream/impl/fusing/GraphInterpreter.scala | 178 +++-- .../akka/stream/impl/fusing/GraphStages.scala | 19 +- .../impl/fusing/IteratorInterpreter.scala | 130 ++- .../impl/fusing/OneBoundedInterpreter.scala | 747 ------------------ .../scala/akka/stream/impl/fusing/Ops.scala | 247 +++--- .../akka/stream/impl/io/IOSettings.scala | 17 +- .../scala/akka/stream/impl/io/IOSinks.scala | 8 +- .../scala/akka/stream/impl/io/IOSources.scala | 5 +- .../stream/impl/io/InputStreamSinkStage.scala | 5 +- .../impl/io/OutputStreamSourceStage.scala | 3 +- .../stream/impl/io/TcpListenStreamActor.scala | 4 +- .../akka/stream/io/ByteStringParser.scala | 119 +++ .../akka/stream/io/SynchronousFileSink.scala | 5 +- .../stream/io/SynchronousFileSource.scala | 9 +- .../main/scala/akka/stream/io/Timeouts.scala | 22 +- .../main/scala/akka/stream/javadsl/Flow.scala | 10 +- .../scala/akka/stream/scaladsl/Flow.scala | 91 ++- .../scala/akka/stream/scaladsl/Graph.scala | 179 ++--- .../stream/scaladsl/One2OneBidiFlow.scala | 2 +- .../scala/akka/stream/scaladsl/Source.scala | 13 +- .../main/scala/akka/stream/scaladsl/Tcp.scala | 2 +- .../scala/akka/stream/stage/GraphStage.scala | 245 +++--- .../main/scala/akka/stream/stage/Stage.scala | 306 +++---- 107 files changed, 2456 insertions(+), 3061 deletions(-) create mode 100644 akka-docs-dev/rst/migration-guide-1.0-2.x.rst create mode 100644 akka-docs-dev/src/test/resources/application.conf rename akka-stream-tests/src/test/resources/{reference.conf => application.conf} (62%) delete mode 100644 akka-stream-tests/src/test/scala/akka/stream/impl/ActorInterpreterLifecycleSpec.scala delete mode 100644 akka-stream-tests/src/test/scala/akka/stream/impl/ActorInterpreterSpec.scala create mode 100644 akka-stream-tests/src/test/scala/akka/stream/impl/GraphStageLogicSpec.scala delete mode 100644 akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpecKit.scala create mode 100644 akka-stream-tests/src/test/scala/akka/stream/scaladsl/AttributesSpec.scala rename akka-stream-tests/src/test/scala/akka/stream/scaladsl/{GraphPreferredMergeSpec.scala => GraphMergePreferredSpec.scala} (97%) delete mode 100644 akka-stream/src/main/scala/akka/stream/impl/fusing/ActorInterpreter.scala delete mode 100644 akka-stream/src/main/scala/akka/stream/impl/fusing/OneBoundedInterpreter.scala create mode 100644 akka-stream/src/main/scala/akka/stream/io/ByteStringParser.scala diff --git a/akka-bench-jmh-dev/src/main/scala/akka/stream/InterpreterBenchmark.scala b/akka-bench-jmh-dev/src/main/scala/akka/stream/InterpreterBenchmark.scala index c3af314393..c997b6f4c7 100644 --- a/akka-bench-jmh-dev/src/main/scala/akka/stream/InterpreterBenchmark.scala +++ b/akka-bench-jmh-dev/src/main/scala/akka/stream/InterpreterBenchmark.scala @@ -1,7 +1,7 @@ package akka.stream import akka.event._ -import akka.stream.impl.fusing.{ GraphInterpreterSpecKit, GraphStages, Map => MapStage, OneBoundedInterpreter } +import akka.stream.impl.fusing.{ GraphInterpreterSpecKit, GraphStages, Map => MapStage } import akka.stream.impl.fusing.GraphStages.Identity import akka.stream.impl.fusing.GraphInterpreter.{ DownstreamBoundaryStageLogic, UpstreamBoundaryStageLogic } import akka.stream.stage._ @@ -18,10 +18,10 @@ class InterpreterBenchmark { import InterpreterBenchmark._ // manual, and not via @Param, because we want @OperationsPerInvocation on our tests - final val data100k = (1 to 100000).toVector + final val data100k: Vector[Int] = (1 to 100000).toVector @Param(Array("1", "5", "10")) - val numberOfIds = 0 + val numberOfIds: Int = 0 @Benchmark @OperationsPerInvocation(100000) @@ -47,33 +47,13 @@ class InterpreterBenchmark { } } } - - @Benchmark - @OperationsPerInvocation(100000) - def onebounded_interpreter_100k_elements() { - val lock = new Lock() - lock.acquire() - val sink = OneBoundedDataSink(data100k.size) - val ops = Vector.fill(numberOfIds)(new PushPullStage[Int, Int] { - override def onPull(ctx: _root_.akka.stream.stage.Context[Int]) = ctx.pull() - override def onPush(elem: Int, ctx: _root_.akka.stream.stage.Context[Int]) = ctx.push(elem) - }) - val interpreter = new OneBoundedInterpreter(OneBoundedDataSource(data100k) +: ops :+ sink, - (op, ctx, event) ⇒ (), - Logging(NoopBus, classOf[InterpreterBenchmark]), - null, - Attributes.none, - forkLimit = 100, overflowToHeap = false) - interpreter.init() - sink.requestOne() - } } object InterpreterBenchmark { case class GraphDataSource[T](override val toString: String, data: Vector[T]) extends UpstreamBoundaryStageLogic[T] { - var idx = 0 - val out = Outlet[T]("out") + var idx: Int = 0 + override val out: akka.stream.Outlet[T] = Outlet[T]("out") out.id = 0 setHandler(out, new OutHandler { @@ -81,8 +61,7 @@ object InterpreterBenchmark { if (idx < data.size) { push(out, data(idx)) idx += 1 - } - else { + } else { completeStage() } } @@ -91,7 +70,7 @@ object InterpreterBenchmark { } case class GraphDataSink[T](override val toString: String, var expected: Int) extends DownstreamBoundaryStageLogic[T] { - val in = Inlet[T]("in") + override val in: akka.stream.Inlet[T] = Inlet[T]("in") in.id = 0 setHandler(in, new InHandler { @@ -104,49 +83,7 @@ object InterpreterBenchmark { override def onUpstreamFailure(ex: Throwable): Unit = failStage(ex) }) - def requestOne() = pull(in) - } - - case class OneBoundedDataSource[T](data: Vector[T]) extends BoundaryStage { - var idx = 0 - - override def onDownstreamFinish(ctx: BoundaryContext): TerminationDirective = { - ctx.finish() - } - - override def onPull(ctx: BoundaryContext): Directive = { - if (idx < data.size) { - idx += 1 - ctx.push(data(idx - 1)) - } - else { - ctx.finish() - } - } - - override def onPush(elem: Any, ctx: BoundaryContext): Directive = - throw new UnsupportedOperationException("Cannot push the boundary") - } - - case class OneBoundedDataSink(var expected: Int) extends BoundaryStage { - override def onPush(elem: Any, ctx: BoundaryContext): Directive = { - expected -= 1 - if (expected == 0) ctx.exit() - else ctx.pull() - } - - override def onUpstreamFinish(ctx: BoundaryContext): TerminationDirective = { - ctx.finish() - } - - override def onUpstreamFailure(cause: Throwable, ctx: BoundaryContext): TerminationDirective = { - ctx.finish() - } - - override def onPull(ctx: BoundaryContext): Directive = - throw new UnsupportedOperationException("Cannot pull the boundary") - - def requestOne(): Unit = enterAndPull() + def requestOne(): Unit = pull(in) } val NoopBus = new LoggingBus { diff --git a/akka-docs-dev/rst/java/code/docs/http/javadsl/server/HttpServerExampleDocTest.java b/akka-docs-dev/rst/java/code/docs/http/javadsl/server/HttpServerExampleDocTest.java index a5b2c4990c..a163e305cf 100644 --- a/akka-docs-dev/rst/java/code/docs/http/javadsl/server/HttpServerExampleDocTest.java +++ b/akka-docs-dev/rst/java/code/docs/http/javadsl/server/HttpServerExampleDocTest.java @@ -159,8 +159,8 @@ public class HttpServerExampleDocTest { Flow.of(HttpRequest.class) .via(failureDetection) .map(request -> { - Source bytes = request.entity().getDataBytes(); - HttpEntity.Chunked entity = HttpEntities.create(ContentTypes.TEXT_PLAIN, (Source) bytes); + Source bytes = request.entity().getDataBytes(); + HttpEntity.Chunked entity = HttpEntities.create(ContentTypes.TEXT_PLAIN, bytes); return HttpResponse.create() .withEntity(entity); diff --git a/akka-docs-dev/rst/migration-guide-1.0-2.x.rst b/akka-docs-dev/rst/migration-guide-1.0-2.x.rst new file mode 100644 index 0000000000..65e09e2433 --- /dev/null +++ b/akka-docs-dev/rst/migration-guide-1.0-2.x.rst @@ -0,0 +1,263 @@ +.. _migration-2.0: + +############################ + Migration Guide 1.0 to 2.x +############################ + +The 2.0 release contains some structural changes that require some +simple, mechanical source-level changes in client code. + + +Introduced proper named constructor methods insted of ``wrap()`` +================================================================ + +There were several, unrelated uses of ``wrap()`` which made it hard to find and hard to understand the intention of +the call. Therefore these use-cases now have methods with different names, helping Java 8 type inference (by reducing +the number of overloads) and finding relevant methods in the documentation. + +Creating a Flow from other stages +--------------------------------- + +It was possible to create a ``Flow`` from a graph with the correct shape (``FlowShape``) using ``wrap()``. Now this +must be done with the more descriptive method ``Flow.fromGraph()``. + +It was possible to create a ``Flow`` from a ``Source`` and a ``Sink`` using ``wrap()``. Now this functionality can +be accessed trough the more descriptive methods ``Flow.fromSinkAndSource`` and ``Flow.fromSinkAndSourceMat``. + +Creating a BidiFlow from other stages +------------------------------------- + +It was possible to create a ``BidiFlow`` from a graph with the correct shape (``BidiShape``) using ``wrap()``. Now this +must be done with the more descriptive method ``BidiFlow.fromGraph()``. + +It was possible to create a ``BidiFlow`` from two ``Flow`` s using ``wrap()``. Now this functionality can +be accessed trough the more descriptive methods ``BidiFlow.fromFlows`` and ``BidiFlow.fromFlowsMat``. + +It was possible to create a ``BidiFlow`` from two functions using ``apply()`` (Scala DSL) or ``create()`` (Java DSL). +Now this functionality can be accessed trough the more descriptive method ``BidiFlow.fromFunctions``. + +Update procedure +---------------- + +1. Replace all uses of ``Flow.wrap`` when it converts a ``Graph`` to a ``Flow`` with ``Flow.fromGraph`` +2. Replace all uses of ``Flow.wrap`` when it converts a ``Source`` and ``Sink`` to a ``Flow`` with + ``Flow.fromSinkAndSource`` or ``Flow.fromSinkAndSourceMat`` +3. Replace all uses of ``BidiFlow.wrap`` when it converts a ``Graph`` to a ``BidiFlow`` with ``BidiFlow.fromGraph`` +4. Replace all uses of ``BidiFlow.wrap`` when it converts two ``Flow``s to a ``BidiFlow`` with + ``BidiFlow.fromFlows`` or ``BidiFlow.fromFlowsMat`` +5. Repplace all uses of ``BidiFlow.apply()`` (Scala DSL) or ``BidiFlow.create()`` (Java DSL) when it converts two + functions to a ``BidiFlow`` with ``BidiFlow.fromFunctions`` + +TODO: Code example + +FlowGraph builder methods have been renamed +=========================================== + +There is now only one graph creation method called ``create`` which is analogous to the old ``partial`` method. For +closed graphs now it is explicitly required to return ``ClosedShape`` at the end of the builder block. + +Update procedure +---------------- + +1. Replace all occurrences of ``FlowGraph.create()`` with ``FlowGraph.partial()`` +2. Add ``ClosedShape`` as a return value of the builder block + +TODO: Code sample + +Methods that create Source, Sink, Flow from Graphs have been removed +==================================================================== + +Previously there were convenience methods available on ``Sink``, ``Source``, ``Flow`` an ``BidiFlow`` to create +these DSL elements from a graph builder directly. Now this requires two explicit steps to reduce the number of overloaded +methods (helps Java 8 type inference) and also reduces the ways how these elements can be created. There is only one +graph creation method to learn (``FlowGraph.create``) and then there is only one conversion method to use ``fromGraph()``. + +This means that the following methods have been removed: + - ``adapt()`` method on ``Source``, ``Sink``, ``Flow`` and ``BidiFlow`` (both DSLs) + - ``apply()`` overloads providing a graph ``Builder`` on ``Source``, ``Sink``, ``Flow`` and ``BidiFlow`` (Scala DSL) + - ``create()`` overloads providing a graph ``Builder`` on ``Source``, ``Sink``, ``Flow`` and ``BidiFlow`` (Java DSL) + +Update procedure +---------------- + +Everywhere where ``Source``, ``Sink``, ``Flow`` and ``BidiFlow`` is created from a graph using a builder have to +be replaced with two steps + +1. Create a ``Graph`` with the correct ``Shape`` using ``FlowGraph.create`` (e.g.. for ``Source`` it means first + creating a ``Graph`` with ``SourceShape``) +2. Create the required DSL element by calling ``fromGraph()`` on the required DSL element (e.g. ``Source.fromGraph``) + passing the graph created in the previous step + +TODO code example + +Some graph Builder methods in the Java DSL have been renamed +============================================================ + +Due to the high number of overloads Java 8 type inference suffered, and it was also hard to figure out which time +to use which method. Therefore various redundant methods have been removed. + +Update procedure +---------------- + +1. All uses of builder.addEdge(Outlet, Inlet) should be replaced by the alternative builder.from(…).to(…) +2. All uses of builder.addEdge(Outlet, FlowShape, Inlet) should be replaced by builder.from(…).via(…).to(…) + +Builder.source => use builder.from(…).via(…).to(…) +Builder.flow => use builder.from(…).via(…).to(…) +Builder.sink => use builder.from(…).via(…).to(…) + +TODO: code example + +Builder overloads from the Scala DSL have been removed +====================================================== + +scaladsl.Builder.addEdge(Outlet, Inlet) => use the DSL (~> and <~) +scaladsl.Builder.addEdge(Outlet, FlowShape, Inlet) => use the DSL (~> and <~) + +Source constructor name changes +=============================== + +``Source.lazyEmpty`` have been replaced by ``Source.maybe`` which returns a ``Promise`` that can be completed by one or +zero elements by providing an ``Option``. This is different from ``lazyEmpty`` which only allowed completion to be +sent, but no elements. + +The ``apply()`` and ``from()`` overloads on ``Source`` that provide a tick source (``Source(delay,interval,tick)``) +are replaced by the named method ``Source.tick()`` to reduce the number of overloads and to make the function more +discoverable. + +Update procedure +---------------- + +1. Replace all uses of ``Source(delay,interval,tick)`` and ``Source.from(delay,interval,tick)`` with the method + ``Source.tick()`` +2. All uses of ``Source.lazyEmpty`` should be replaced by ``Source.maybe`` and the returned ``Promise`` completed with + a ``None`` (an empty ``Option``) + +TODO: code example + +``Flow.empty()`` has been removed from the Java DSL +=================================================== + +The ``empty()`` method has been removed since it behaves exactly the same as ``create()``, creating a ``Flow`` with no +transformations added yet. + +Update procedure +---------------- + +1. Replace all uses of ``Flow.empty()`` with ``Flow.create``. + +TODO: code example + +``flatten(FlattenStrategy)`` has been replaced by named counterparts +==================================================================== + +To simplify type inference in Java 8 and to make the method more discoverable, ``flatten(FlattenStrategy.concat)`` +has been removed and replaced with the alternative method ``flatten(FlattenStrategy.concat)``. + +Update procedure +---------------- + +1. Replace all occurences of ``flatten(FlattenStrategy.concat)`` with ``flattenConcat()`` + +TODO: code example + +FlexiMerge an FlexiRoute has been replaced by GraphStage +======================================================== + +The ``FlexiMerge`` and ``FlexiRoute`` DSLs have been removed since they provided an abstraction that was too limiting +and a better abstraction have been created which is called ``GraphStage``. ``GraphStage`` can express fan-in and +fan-out stages, but many other constructs as well with possibly multiple input and output ports (e.g. a ``BidiStage``). + +This new abstraction provides a more uniform way to crate custom stream processing stages of arbitrary ``Shape``. In +fact, all of the built-in fan-in and fan-out stages are now implemented in terms of ``GraphStage``. + +Update procedure +---------------- + +*There is no simple update procedure. The affected stages must be ported to the new ``GraphStage`` DSL manually. Please +read the* ``GraphStage`` *documentation (TODO) for details.* + +Variance of Inlet and Outlet (Scala DSL) +======================================== + +Scala uses *declaration site variance* which was cumbersome in the cases of ``Inlet`` and ``Outlet`` as they are +purely symbolic object containing no fields or methods and which are used both in input and output locations (wiring +an ``Outlet`` into an ``Inlet``; reading in a stage from an ``Inlet``). Because of this reasons all users of these +port abstractions now use *use-site variance* (just like Java variance works). This in general does not affect user +code expect the case of custom shapes, which now require ``@uncheckedVariance`` annotations on their ``Inlet`` and +``Outlet`` members (since these are now invariant, but the Scala compiler does not know that they have no fields or +methods that would violate variance constraints) + +This change does not affect Java DSL users. + +TODO: code example + +Update procedure +---------------- + +1. All custom shapes must use ``@uncheckedVariance`` on their ``Inlet`` and ``Outlet`` members. + +Semantic change in ``isHoldingUpstream`` in the DetachedStage DSL +================================================================= + +The ``isHoldingUpstream`` method used to return true if the upstream port was in holding state and a completion arrived +(inside the ``onUpstreamFinished`` callback). Now it returns ``false`` when the upstream is completed. + +Update procedure +---------------- + +1. Those stages that relied on the previous behavior need to introduce an extra ``Boolean`` field with initial value + ``false`` +2. This field must be set on every call to ``holdUpstream()`` (and variants). +3. In completion, instead of calling ``isHoldingUpstream`` read this variable instead. + +TODO: code example + + +AsyncStage has been replaced by GraphStage +========================================== + +Due to its complexity and relative inflexibility ``AsyncStage`` have been removed. + +TODO explanation + +Update procedure +---------------- + +1. The subclass of ``AsyncStage`` should be replaced by ``GraphStage`` +2. The new subclass must define an ``in`` and ``out`` port (``Inlet`` and ``Outlet`` instance) and override the ``shape`` + method returning a ``FlowShape`` +3. An instance of ``GraphStageLogic`` must be returned by overriding ``createLogic()``. The original processing logic and + state will be encapsulated in this ``GraphStageLogic`` +4. Using ``setHandler(port, handler)`` and ``InHandler`` instance should be set on ``in`` and an ``OutHandler`` should + be set on ``out`` +5. ``onPush``, ``onUpstreamFinished`` and ``onUpstreamFailed`` are now available in the ``InHandler`` subclass created + by the user +6. ``onPull`` and ``onDownstreamFinished`` are now available in the ``OutHandler`` subclass created by the user +7. the callbacks above no longer take an extra `ctxt` context parameter. +8. ``onPull`` only signals the stage, the actual element can be obtained by calling ``grab(in)`` +9. ``ctx.push(elem)`` is now ``push(out, elem)`` +10. ``ctx.pull()`` is now ``pull(in)`` +11. ``ctx.finish()`` is now ``completeStage()`` +12. ``ctx.pushAndFinish(elem)`` is now simply two calls: ``push(out, elem); completeStage()`` +13. ``ctx.fail(cause)`` is now ``failStage(cause)`` +14. ``ctx.isFinishing()`` is now ``isClosed(in)`` +15. ``ctx.absorbTermination()`` can be replaced with ``if (isAvailable(shape.outlet)) `` +16. ``ctx.pushAndPull(elem)`` can be replaced with ``push(out, elem); pull(in)`` +17. ``ctx.holdUpstreamAndPush`` and ``context.holdDownstreamAndPull`` can be replaced by simply ``push(elem)`` and + ``pull()`` respectively +18. The following calls should be removed: ``ctx.ignore()``, ``ctx.holdUpstream()`` and ``ctx.holdDownstream()``. +19. ``ctx.isHoldingUpstream()`` can be replaced with ``isAvailable(out)`` +20. ``ctx.isHoldingDowntream()`` can be replaced with ``!(isClosed(in) || hasBeenPulled(in))`` +21. ``ctx.getAsyncCallback()`` is now ``getAsyncCallback(callback)`` which now takes a callback as a parameter. This + would correspond to the ``onAsyncInput()`` callback in the original ``AsyncStage`` + +We show the necessary steps in terms of an example ``AsyncStage`` + +TODO: code sample + + + +TODO: Code example + + diff --git a/akka-docs-dev/rst/scala/code/docs/http/scaladsl/HttpServerExampleSpec.scala b/akka-docs-dev/rst/scala/code/docs/http/scaladsl/HttpServerExampleSpec.scala index 8834a98e37..19b4a8fe3d 100644 --- a/akka-docs-dev/rst/scala/code/docs/http/scaladsl/HttpServerExampleSpec.scala +++ b/akka-docs-dev/rst/scala/code/docs/http/scaladsl/HttpServerExampleSpec.scala @@ -352,7 +352,7 @@ class HttpServerExampleSpec extends WordSpec with Matchers { pathEnd { (put | parameter('method ! "put")) { // form extraction from multipart or www-url-encoded forms - formFields('email, 'total.as[Money]).as(Order) { order => + formFields(('email, 'total.as[Money])).as(Order) { order => complete { // complete with serialized Future result (myDbActor ? Update(order)).mapTo[TransactionResult] @@ -373,7 +373,7 @@ class HttpServerExampleSpec extends WordSpec with Matchers { path("items") { get { // parameters to case class extraction - parameters('size.as[Int], 'color ?, 'dangerous ? "no") + parameters(('size.as[Int], 'color ?, 'dangerous ? "no")) .as(OrderItem) { orderItem => // ... route using case class instance created from // required and optional query parameters diff --git a/akka-docs-dev/rst/scala/code/docs/http/scaladsl/server/directives/FutureDirectivesExamplesSpec.scala b/akka-docs-dev/rst/scala/code/docs/http/scaladsl/server/directives/FutureDirectivesExamplesSpec.scala index 2f6cbb62dc..6beb509d1b 100644 --- a/akka-docs-dev/rst/scala/code/docs/http/scaladsl/server/directives/FutureDirectivesExamplesSpec.scala +++ b/akka-docs-dev/rst/scala/code/docs/http/scaladsl/server/directives/FutureDirectivesExamplesSpec.scala @@ -26,9 +26,6 @@ class FutureDirectivesExamplesSpec extends RoutingSpec { ctx.complete((InternalServerError, "Unsuccessful future!")) } - val resourceActor = system.actorOf(Props(new Actor { - def receive = { case _ => sender ! "resource" } - })) implicit val responseTimeout = Timeout(2, TimeUnit.SECONDS) "onComplete" in { diff --git a/akka-docs-dev/rst/scala/code/docs/http/scaladsl/server/directives/RouteDirectivesExamplesSpec.scala b/akka-docs-dev/rst/scala/code/docs/http/scaladsl/server/directives/RouteDirectivesExamplesSpec.scala index 49ff680801..74bbaac0ca 100644 --- a/akka-docs-dev/rst/scala/code/docs/http/scaladsl/server/directives/RouteDirectivesExamplesSpec.scala +++ b/akka-docs-dev/rst/scala/code/docs/http/scaladsl/server/directives/RouteDirectivesExamplesSpec.scala @@ -7,6 +7,7 @@ package directives import akka.http.scaladsl.model._ import akka.http.scaladsl.server.{ Route, ValidationRejection } +import akka.testkit.EventFilter class RouteDirectivesExamplesSpec extends RoutingSpec { @@ -83,7 +84,7 @@ class RouteDirectivesExamplesSpec extends RoutingSpec { } } - "failwith-examples" in { + "failwith-examples" in EventFilter[RuntimeException](start = "Error during processing of request", occurrences = 1).intercept { val route = path("foo") { failWith(new RuntimeException("Oops.")) diff --git a/akka-docs-dev/rst/scala/code/docs/stream/RateTransformationDocSpec.scala b/akka-docs-dev/rst/scala/code/docs/stream/RateTransformationDocSpec.scala index 8365457e13..1152a648c5 100644 --- a/akka-docs-dev/rst/scala/code/docs/stream/RateTransformationDocSpec.scala +++ b/akka-docs-dev/rst/scala/code/docs/stream/RateTransformationDocSpec.scala @@ -12,6 +12,7 @@ import scala.math._ import scala.concurrent.Await import scala.concurrent.duration._ import scala.collection.immutable +import akka.testkit.TestLatch class RateTransformationDocSpec extends AkkaSpec { @@ -84,9 +85,14 @@ class RateTransformationDocSpec extends AkkaSpec { case (lastElement, drift) => ((lastElement, drift), (lastElement, drift + 1)) } //#expand-drift + val latch = TestLatch(2) + val realDriftFlow = Flow[Double] + .expand(d => { latch.countDown(); (d, 0) }) { + case (lastElement, drift) => ((lastElement, drift), (lastElement, drift + 1)) + } val (pub, sub) = TestSource.probe[Double] - .via(driftFlow) + .via(realDriftFlow) .toMat(TestSink.probe[(Double, Int)])(Keep.both) .run() @@ -98,6 +104,7 @@ class RateTransformationDocSpec extends AkkaSpec { sub.requestNext((1.0, 2)) pub.sendNext(2.0) + Await.ready(latch, 1.second) sub.requestNext((2.0, 0)) } diff --git a/akka-docs-dev/rst/scala/code/docs/stream/ReactiveStreamsDocSpec.scala b/akka-docs-dev/rst/scala/code/docs/stream/ReactiveStreamsDocSpec.scala index 079f5c3f30..a001829af3 100644 --- a/akka-docs-dev/rst/scala/code/docs/stream/ReactiveStreamsDocSpec.scala +++ b/akka-docs-dev/rst/scala/code/docs/stream/ReactiveStreamsDocSpec.scala @@ -139,7 +139,7 @@ class ReactiveStreamsDocSpec extends AkkaSpec { // An example Processor factory def createProcessor: Processor[Int, Int] = Flow[Int].toProcessor.run() - val flow: Flow[Int, Int, Unit] = Flow(() => createProcessor) + val flow: Flow[Int, Int, Unit] = Flow.fromProcessor(() => createProcessor) //#use-processor } diff --git a/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeByteStrings.scala b/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeByteStrings.scala index 4909d21f57..db7a994681 100644 --- a/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeByteStrings.scala +++ b/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeByteStrings.scala @@ -27,9 +27,15 @@ class RecipeByteStrings extends RecipeSpec { override def onPull(ctx: Context[ByteString]): SyncDirective = emitChunkOrPull(ctx) + override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = + if (buffer.nonEmpty) ctx.absorbTermination() + else ctx.finish() + private def emitChunkOrPull(ctx: Context[ByteString]): SyncDirective = { - if (buffer.isEmpty) ctx.pull() - else { + if (buffer.isEmpty) { + if (ctx.isFinishing) ctx.finish() + else ctx.pull() + } else { val (emit, nextBuffer) = buffer.splitAt(chunkSize) buffer = nextBuffer ctx.push(emit) diff --git a/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeKeepAlive.scala b/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeKeepAlive.scala index 6fe5589576..c7ac25835e 100644 --- a/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeKeepAlive.scala +++ b/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeKeepAlive.scala @@ -41,6 +41,8 @@ class RecipeKeepAlive extends RecipeSpec { val subscription = sub.expectSubscription() + // FIXME RK: remove (because I think this cannot deterministically be tested and it might also not do what it should anymore) + tickPub.sendNext(()) // pending data will overcome the keepalive diff --git a/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeMissedTicks.scala b/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeMissedTicks.scala index d8117490e7..1829655862 100644 --- a/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeMissedTicks.scala +++ b/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeMissedTicks.scala @@ -2,8 +2,9 @@ package docs.stream.cookbook import akka.stream.scaladsl._ import akka.stream.testkit._ - import scala.concurrent.duration._ +import akka.testkit.TestLatch +import scala.concurrent.Await class RecipeMissedTicks extends RecipeSpec { @@ -22,8 +23,12 @@ class RecipeMissedTicks extends RecipeSpec { Flow[Tick].conflate(seed = (_) => 0)( (missedTicks, tick) => missedTicks + 1) //#missed-ticks + val latch = TestLatch(3) + val realMissedTicks: Flow[Tick, Int, Unit] = + Flow[Tick].conflate(seed = (_) => 0)( + (missedTicks, tick) => { latch.countDown(); missedTicks + 1 }) - tickStream.via(missedTicks).to(sink).run() + tickStream.via(realMissedTicks).to(sink).run() pub.sendNext(()) pub.sendNext(()) @@ -31,6 +36,8 @@ class RecipeMissedTicks extends RecipeSpec { pub.sendNext(()) val subscription = sub.expectSubscription() + Await.ready(latch, 1.second) + subscription.request(1) sub.expectNext(3) diff --git a/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeSimpleDrop.scala b/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeSimpleDrop.scala index 4e18bfc439..f8509d29f6 100644 --- a/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeSimpleDrop.scala +++ b/akka-docs-dev/rst/scala/code/docs/stream/cookbook/RecipeSimpleDrop.scala @@ -2,8 +2,9 @@ package docs.stream.cookbook import akka.stream.scaladsl.{ Flow, Sink, Source } import akka.stream.testkit._ - import scala.concurrent.duration._ +import akka.testkit.TestLatch +import scala.concurrent.Await class RecipeSimpleDrop extends RecipeSpec { @@ -15,13 +16,16 @@ class RecipeSimpleDrop extends RecipeSpec { val droppyStream: Flow[Message, Message, Unit] = Flow[Message].conflate(seed = identity)((lastMessage, newMessage) => newMessage) //#simple-drop + val latch = TestLatch(2) + val realDroppyStream = + Flow[Message].conflate(seed = identity)((lastMessage, newMessage) => { latch.countDown(); newMessage }) val pub = TestPublisher.probe[Message]() val sub = TestSubscriber.manualProbe[Message]() val messageSource = Source(pub) val sink = Sink(sub) - messageSource.via(droppyStream).to(sink).run() + messageSource.via(realDroppyStream).to(sink).run() val subscription = sub.expectSubscription() sub.expectNoMsg(100.millis) @@ -30,6 +34,8 @@ class RecipeSimpleDrop extends RecipeSpec { pub.sendNext("2") pub.sendNext("3") + Await.ready(latch, 1.second) + subscription.request(1) sub.expectNext("3") diff --git a/akka-docs-dev/src/test/resources/application.conf b/akka-docs-dev/src/test/resources/application.conf new file mode 100644 index 0000000000..dafc521805 --- /dev/null +++ b/akka-docs-dev/src/test/resources/application.conf @@ -0,0 +1 @@ +akka.loggers = ["akka.testkit.TestEventListener"] \ No newline at end of file diff --git a/akka-http-core/src/main/java/akka/http/javadsl/model/HttpEntity.java b/akka-http-core/src/main/java/akka/http/javadsl/model/HttpEntity.java index 70b8b4fc25..9c2627afac 100644 --- a/akka-http-core/src/main/java/akka/http/javadsl/model/HttpEntity.java +++ b/akka-http-core/src/main/java/akka/http/javadsl/model/HttpEntity.java @@ -81,7 +81,7 @@ public interface HttpEntity { /** * Returns a stream of data bytes this entity consists of. */ - public abstract Source getDataBytes(); + public abstract Source getDataBytes(); /** * Returns a future of a strict entity that contains the same data as this entity diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala index 8abbe7d9b2..9a93592426 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala @@ -117,7 +117,7 @@ private[http] object OutgoingConnectionBlueprint { val shape = new FanInShape2(requests, responses, out) - override def createLogic = new GraphStageLogic(shape) { + override def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) { passAlong(requests, out, doFinish = false, doFail = true) setHandler(out, eagerTerminateOutput) @@ -147,7 +147,7 @@ private[http] object OutgoingConnectionBlueprint { val shape = new FanInShape2(dataInput, methodBypassInput, out) - override def createLogic = new GraphStageLogic(shape) { + override def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) { // each connection uses a single (private) response parser instance for all its responses // which builds a cache of all header instances seen on that connection val parser = rootParser.createShallowCopy() diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/client/PoolConductor.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/client/PoolConductor.scala index feb9809891..3662170935 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/client/PoolConductor.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/client/PoolConductor.scala @@ -115,7 +115,7 @@ private object PoolConductor { override val shape = new FanInShape2(ctxIn, slotIn, out) - override def createLogic = new GraphStageLogic(shape) { + override def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) { val slotStates = Array.fill[SlotState](slotCount)(Unconnected) var nextSlot = 0 @@ -207,7 +207,7 @@ private object PoolConductor { override val shape = new UniformFanOutShape[SwitchCommand, RequestContext](slotCount) - override def createLogic = new GraphStageLogic(shape) { + override def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) { shape.outArray foreach { setHandler(_, ignoreTerminateOutput) } val in = shape.in diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/client/PoolSlot.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/client/PoolSlot.scala index c28331fd5c..7ac40d9998 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/client/PoolSlot.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/client/PoolSlot.scala @@ -57,10 +57,10 @@ private object PoolSlot { import FlowGraph.Implicits._ val slotProcessor = b.add { - Flow[RequestContext].andThenMat { () ⇒ + Flow.fromProcessor { () ⇒ val actor = system.actorOf(Props(new SlotProcessor(slotIx, connectionFlow, settings)).withDeploy(Deploy.local), slotProcessorActorName.next()) - (ActorProcessor[RequestContext, List[ProcessorOut]](actor), ()) + ActorProcessor[RequestContext, List[ProcessorOut]](actor) }.mapConcat(identity) } val split = b.add(Broadcast[ProcessorOut](2)) diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala index f7feddcdbf..7b8806b52b 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala @@ -179,7 +179,7 @@ private[http] object HttpServerBluePrint { override val shape = new FanInShape3(bypassInput, oneHundredContinue, applicationInput, out) - override def createLogic = new GraphStageLogic(shape) { + override def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) { var requestStart: RequestStart = _ setHandler(bypassInput, new InHandler { @@ -334,7 +334,7 @@ private[http] object HttpServerBluePrint { override val shape = new FanOutShape2(in, httpOut, wsOut) - override def createLogic = new GraphStageLogic(shape) { + override def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) { var target = httpOut setHandler(in, new InHandler { @@ -362,7 +362,7 @@ private[http] object HttpServerBluePrint { override val shape = new FanInShape2(httpIn, wsIn, out) - override def createLogic = new GraphStageLogic(shape) { + override def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) { var websocketHandlerWasInstalled = false setHandler(httpIn, conditionalTerminateInput(() ⇒ !websocketHandlerWasInstalled)) @@ -407,7 +407,7 @@ private[http] object HttpServerBluePrint { override val shape = new FanInShape2(bytes, token, out) - override def createLogic = new GraphStageLogic(shape) { + override def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) { passAlong(bytes, out, doFinish = true, doFail = true) passAlong(token, out, doFinish = false, doFail = true) setHandler(out, eagerTerminateOutput) diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameEvent.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameEvent.scala index 35c8cc54b3..c2e0d3045f 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameEvent.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameEvent.scala @@ -7,12 +7,16 @@ package akka.http.impl.engine.ws import akka.http.impl.engine.ws.Protocol.Opcode import akka.util.ByteString +private[http] sealed trait FrameEventOrError + +private[http] final case class FrameError(p: ProtocolException) extends FrameEventOrError + /** * The low-level Websocket framing model. * * INTERNAL API */ -private[http] sealed trait FrameEvent { +private[http] sealed trait FrameEvent extends FrameEventOrError { def data: ByteString def lastPart: Boolean def withData(data: ByteString): FrameEvent diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameEventParser.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameEventParser.scala index ca08022b2a..44c74fbfd2 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameEventParser.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameEventParser.scala @@ -4,11 +4,10 @@ package akka.http.impl.engine.ws -import akka.http.impl.util.{ ByteReader, ByteStringParserStage } -import akka.stream.stage.{ StageState, SyncDirective, Context } import akka.util.ByteString - import scala.annotation.tailrec +import akka.stream.io.ByteStringParser +import akka.stream.Attributes /** * Streaming parser for the Websocket framing protocol as defined in RFC6455 @@ -36,108 +35,81 @@ import scala.annotation.tailrec * * INTERNAL API */ -private[http] class FrameEventParser extends ByteStringParserStage[FrameEvent] { - protected def onTruncation(ctx: Context[FrameEvent]): SyncDirective = - ctx.fail(new ProtocolException("Data truncated")) +private[http] object FrameEventParser extends ByteStringParser[FrameEvent] { + import ByteStringParser._ - def initial: StageState[ByteString, FrameEvent] = ReadFrameHeader + override def createLogic(attr: Attributes) = new ParsingLogic { + startWith(ReadFrameHeader) - object ReadFrameHeader extends ByteReadingState { - def read(reader: ByteReader, ctx: Context[FrameEvent]): SyncDirective = { - import Protocol._ - - val flagsAndOp = reader.readByte() - val maskAndLength = reader.readByte() - - val flags = flagsAndOp & FLAGS_MASK - val op = flagsAndOp & OP_MASK - - val maskBit = (maskAndLength & MASK_MASK) != 0 - val length7 = maskAndLength & LENGTH_MASK - - val length = - length7 match { - case 126 ⇒ reader.readShortBE().toLong - case 127 ⇒ reader.readLongBE() - case x ⇒ x.toLong - } - - if (length < 0) ctx.fail(new ProtocolException("Highest bit of 64bit length was set")) - - val mask = - if (maskBit) Some(reader.readIntBE()) - else None - - def isFlagSet(mask: Int): Boolean = (flags & mask) != 0 - val header = - FrameHeader(Opcode.forCode(op.toByte), - mask, - length, - fin = isFlagSet(FIN_MASK), - rsv1 = isFlagSet(RSV1_MASK), - rsv2 = isFlagSet(RSV2_MASK), - rsv3 = isFlagSet(RSV3_MASK)) - - val data = reader.remainingData - val takeNow = (header.length min Int.MaxValue).toInt - val thisFrameData = data.take(takeNow) - val remaining = data.drop(takeNow) - - val nextState = - if (thisFrameData.length == length) ReadFrameHeader - else readData(length - thisFrameData.length) - - pushAndBecomeWithRemaining(FrameStart(header, thisFrameData.compact), nextState, remaining, ctx) - } - } - - def readData(_remaining: Long): State = - new State { - var remaining = _remaining - def onPush(elem: ByteString, ctx: Context[FrameEvent]): SyncDirective = - if (elem.size < remaining) { - remaining -= elem.size - ctx.push(FrameData(elem, lastPart = false)) - } else { - require(remaining <= Int.MaxValue) // safe because, remaining <= elem.size <= Int.MaxValue - val frameData = elem.take(remaining.toInt) - val remainingData = elem.drop(remaining.toInt) - - pushAndBecomeWithRemaining(FrameData(frameData.compact, lastPart = true), ReadFrameHeader, remainingData, ctx) - } + trait Step extends ParseStep[FrameEvent] { + override def onTruncation(): Unit = failStage(new ProtocolException("Data truncated")) } - def becomeWithRemaining(nextState: State, remainingData: ByteString, ctx: Context[FrameEvent]): SyncDirective = { - become(nextState) - nextState.onPush(remainingData, ctx) - } - def pushAndBecomeWithRemaining(elem: FrameEvent, nextState: State, remainingData: ByteString, ctx: Context[FrameEvent]): SyncDirective = - if (remainingData.isEmpty) { - become(nextState) - ctx.push(elem) - } else { - become(waitForPull(nextState, remainingData)) - ctx.push(elem) - } + object ReadFrameHeader extends Step { + override def parse(reader: ByteReader): (FrameEvent, Step) = { + import Protocol._ - def waitForPull(nextState: State, remainingData: ByteString): State = - new State { - def onPush(elem: ByteString, ctx: Context[FrameEvent]): SyncDirective = - throw new IllegalStateException("Mustn't push in this state") + val flagsAndOp = reader.readByte() + val maskAndLength = reader.readByte() - override def onPull(ctx: Context[FrameEvent]): SyncDirective = { - become(nextState) - nextState.onPush(remainingData, ctx) + val flags = flagsAndOp & FLAGS_MASK + val op = flagsAndOp & OP_MASK + + val maskBit = (maskAndLength & MASK_MASK) != 0 + val length7 = maskAndLength & LENGTH_MASK + + val length = + length7 match { + case 126 ⇒ reader.readShortBE().toLong + case 127 ⇒ reader.readLongBE() + case x ⇒ x.toLong + } + + if (length < 0) throw new ProtocolException("Highest bit of 64bit length was set") + + val mask = + if (maskBit) Some(reader.readIntBE()) + else None + + def isFlagSet(mask: Int): Boolean = (flags & mask) != 0 + val header = + FrameHeader(Opcode.forCode(op.toByte), + mask, + length, + fin = isFlagSet(FIN_MASK), + rsv1 = isFlagSet(RSV1_MASK), + rsv2 = isFlagSet(RSV2_MASK), + rsv3 = isFlagSet(RSV3_MASK)) + + val takeNow = (header.length min reader.remainingSize).toInt + val thisFrameData = reader.take(takeNow) + + val nextState = + if (thisFrameData.length == length) ReadFrameHeader + else new ReadData(length - thisFrameData.length) + + (FrameStart(header, thisFrameData.compact), nextState) } } -} -object FrameEventParser { + class ReadData(_remaining: Long) extends Step { + var remaining = _remaining + override def parse(reader: ByteReader): (FrameEvent, Step) = + if (reader.remainingSize < remaining) { + remaining -= reader.remainingSize + (FrameData(reader.takeAll(), lastPart = false), this) + } else { + (FrameData(reader.take(remaining.toInt), lastPart = true), ReadFrameHeader) + } + } + } + def mask(bytes: ByteString, _mask: Option[Int]): ByteString = _mask match { case Some(m) ⇒ mask(bytes, m)._1 case None ⇒ bytes } + def mask(bytes: ByteString, mask: Int): (ByteString, Int) = { @tailrec def rec(bytes: Array[Byte], offset: Int, mask: Int): Int = if (offset >= bytes.length) mask diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameHandler.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameHandler.scala index c6557ca0c5..ceb56c25fe 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameHandler.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameHandler.scala @@ -19,10 +19,10 @@ import scala.util.control.NonFatal */ private[http] object FrameHandler { - def create(server: Boolean): Flow[FrameEvent, Output, Unit] = - Flow[FrameEvent].transform(() ⇒ new HandlerStage(server)) + def create(server: Boolean): Flow[FrameEventOrError, Output, Unit] = + Flow[FrameEventOrError].transform(() ⇒ new HandlerStage(server)) - private class HandlerStage(server: Boolean) extends StatefulStage[FrameEvent, Output] { + private class HandlerStage(server: Boolean) extends StatefulStage[FrameEventOrError, Output] { type Ctx = Context[Output] def initial: State = Idle @@ -79,11 +79,6 @@ private[http] object FrameHandler { } } - private object Closed extends State { - def onPush(elem: FrameEvent, ctx: Ctx): SyncDirective = - ctx.pull() // ignore - } - private def becomeAndHandleWith(newState: State, part: FrameEvent)(implicit ctx: Ctx): SyncDirective = { become(newState) current.onPush(part, ctx) @@ -132,7 +127,7 @@ private[http] object FrameHandler { } private object CloseAfterPeerClosed extends State { - def onPush(elem: FrameEvent, ctx: Context[Output]): SyncDirective = + def onPush(elem: FrameEventOrError, ctx: Context[Output]): SyncDirective = elem match { case FrameStart(FrameHeader(Opcode.Close, _, length, _, _, _, _), data) ⇒ become(WaitForPeerTcpClose) @@ -141,7 +136,7 @@ private[http] object FrameHandler { } } private object WaitForPeerTcpClose extends State { - def onPush(elem: FrameEvent, ctx: Context[Output]): SyncDirective = + def onPush(elem: FrameEventOrError, ctx: Context[Output]): SyncDirective = ctx.pull() // ignore } @@ -168,10 +163,11 @@ private[http] object FrameHandler { def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective - def onPush(part: FrameEvent, ctx: Ctx): SyncDirective = + def onPush(part: FrameEventOrError, ctx: Ctx): SyncDirective = part match { case data: FrameData ⇒ handleFrameData(data)(ctx) case start: FrameStart ⇒ handleFrameStart(start)(ctx) + case FrameError(ex) ⇒ ctx.fail(ex) } } } diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Masking.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Masking.scala index 9b618b5b10..65a571eb3e 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Masking.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Masking.scala @@ -15,13 +15,19 @@ import akka.stream.stage.{ SyncDirective, Context, StatefulStage } * INTERNAL API */ private[http] object Masking { - def apply(serverSide: Boolean, maskRandom: () ⇒ Random): BidiFlow[ /* net in */ FrameEvent, /* app out */ FrameEvent, /* app in */ FrameEvent, /* net out */ FrameEvent, Unit] = + def apply(serverSide: Boolean, maskRandom: () ⇒ Random): BidiFlow[ /* net in */ FrameEvent, /* app out */ FrameEventOrError, /* app in */ FrameEvent, /* net out */ FrameEvent, Unit] = BidiFlow.fromFlowsMat(unmaskIf(serverSide), maskIf(!serverSide, maskRandom))(Keep.none) def maskIf(condition: Boolean, maskRandom: () ⇒ Random): Flow[FrameEvent, FrameEvent, Unit] = - if (condition) Flow[FrameEvent].transform(() ⇒ new Masking(maskRandom())) // new random per materialization + if (condition) + Flow[FrameEvent] + .transform(() ⇒ new Masking(maskRandom())) // new random per materialization + .map { + case f: FrameEvent ⇒ f + case FrameError(ex) ⇒ throw ex + } else Flow[FrameEvent] - def unmaskIf(condition: Boolean): Flow[FrameEvent, FrameEvent, Unit] = + def unmaskIf(condition: Boolean): Flow[FrameEvent, FrameEventOrError, Unit] = if (condition) Flow[FrameEvent].transform(() ⇒ new Unmasking()) else Flow[FrameEvent] @@ -41,19 +47,25 @@ private[http] object Masking { } /** Implements both masking and unmasking which is mostly symmetric (because of XOR) */ - private abstract class Masker extends StatefulStage[FrameEvent, FrameEvent] { + private abstract class Masker extends StatefulStage[FrameEvent, FrameEventOrError] { def extractMask(header: FrameHeader): Int def setNewMask(header: FrameHeader, mask: Int): FrameHeader def initial: State = Idle - object Idle extends State { - def onPush(part: FrameEvent, ctx: Context[FrameEvent]): SyncDirective = + private object Idle extends State { + def onPush(part: FrameEvent, ctx: Context[FrameEventOrError]): SyncDirective = part match { case start @ FrameStart(header, data) ⇒ - val mask = extractMask(header) - become(new Running(mask)) - current.onPush(start.copy(header = setNewMask(header, mask)), ctx) + try { + val mask = extractMask(header) + become(new Running(mask)) + current.onPush(start.copy(header = setNewMask(header, mask)), ctx) + } catch { + case p: ProtocolException ⇒ + become(Done) + ctx.push(FrameError(p)) + } case _: FrameData ⇒ ctx.fail(new IllegalStateException("unexpected FrameData (need FrameStart first)")) } @@ -61,7 +73,7 @@ private[http] object Masking { private class Running(initialMask: Int) extends State { var mask = initialMask - def onPush(part: FrameEvent, ctx: Context[FrameEvent]): SyncDirective = { + def onPush(part: FrameEvent, ctx: Context[FrameEventOrError]): SyncDirective = { if (part.lastPart) become(Idle) val (masked, newMask) = FrameEventParser.mask(part.data, mask) @@ -69,5 +81,8 @@ private[http] object Masking { ctx.push(part.withData(data = masked)) } } + private object Done extends State { + def onPush(part: FrameEvent, ctx: Context[FrameEventOrError]): SyncDirective = ctx.pull() + } } } diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Websocket.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Websocket.scala index ca313bfbb8..0ac57b7a5d 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Websocket.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Websocket.scala @@ -40,12 +40,12 @@ private[http] object Websocket { /** The lowest layer that implements the binary protocol */ def framing: BidiFlow[ByteString, FrameEvent, FrameEvent, ByteString, Unit] = BidiFlow.fromFlowsMat( - Flow[ByteString].transform(() ⇒ new FrameEventParser), + Flow[ByteString].via(FrameEventParser), Flow[FrameEvent].transform(() ⇒ new FrameEventRenderer))(Keep.none) .named("ws-framing") /** The layer that handles masking using the rules defined in the specification */ - def masking(serverSide: Boolean, maskingRandomFactory: () ⇒ Random): BidiFlow[FrameEvent, FrameEvent, FrameEvent, FrameEvent, Unit] = + def masking(serverSide: Boolean, maskingRandomFactory: () ⇒ Random): BidiFlow[FrameEvent, FrameEventOrError, FrameEvent, FrameEvent, Unit] = Masking(serverSide, maskingRandomFactory) .named("ws-masking") @@ -55,7 +55,7 @@ private[http] object Websocket { */ def frameHandling(serverSide: Boolean = true, closeTimeout: FiniteDuration, - log: LoggingAdapter): BidiFlow[FrameEvent, FrameHandler.Output, FrameOutHandler.Input, FrameStart, Unit] = + log: LoggingAdapter): BidiFlow[FrameEventOrError, FrameHandler.Output, FrameOutHandler.Input, FrameStart, Unit] = BidiFlow.fromFlowsMat( FrameHandler.create(server = serverSide), FrameOutHandler.create(serverSide, closeTimeout, log))(Keep.none) @@ -156,7 +156,7 @@ private[http] object Websocket { val shape = new FanOutShape2(in, bypass, user) - def createLogic = new GraphStageLogic(shape) { + def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) { setHandler(in, new InHandler { override def onPush(): Unit = { @@ -187,7 +187,7 @@ private[http] object Websocket { val shape = new FanInShape3(bypass, user, tick, out) - def createLogic = new GraphStageLogic(shape) { + def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) { passAlong(bypass, out, doFinish = true, doFail = true) passAlong(user, out, doFinish = false, doFail = false) @@ -210,7 +210,7 @@ private[http] object Websocket { val shape = new FlowShape(in, out) - def createLogic = new GraphStageLogic(shape) { + def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) { setHandler(out, new OutHandler { override def onPull(): Unit = pull(in) }) diff --git a/akka-http-core/src/main/scala/akka/http/impl/util/ByteReader.scala b/akka-http-core/src/main/scala/akka/http/impl/util/ByteReader.scala index d46ab53ca0..3a03d3dea5 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/util/ByteReader.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/util/ByteReader.scala @@ -21,17 +21,17 @@ private[akka] class ByteReader(input: ByteString) { def currentOffset: Int = off def remainingData: ByteString = input.drop(off) - def fromStartToHere: ByteString = input.take(currentOffset) + def fromStartToHere: ByteString = input.take(off) def readByte(): Int = if (off < input.length) { val x = input(off) off += 1 - x.toInt & 0xFF + x & 0xFF } else throw NeedMoreData def readShortLE(): Int = readByte() | (readByte() << 8) def readIntLE(): Int = readShortLE() | (readShortLE() << 16) - def readLongLE(): Long = (readIntBE() & 0xffffffffL) | ((readIntLE() & 0xffffffffL) << 32) + def readLongLE(): Long = (readIntLE() & 0xffffffffL) | ((readIntLE() & 0xffffffffL) << 32) def readShortBE(): Int = (readByte() << 8) | readByte() def readIntBE(): Int = (readShortBE() << 16) | readShortBE() diff --git a/akka-http-core/src/main/scala/akka/http/impl/util/package.scala b/akka-http-core/src/main/scala/akka/http/impl/util/package.scala index e86da57357..dc10cfea7f 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/util/package.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/util/package.scala @@ -128,7 +128,7 @@ package object util { package util { import akka.http.scaladsl.model.{ ContentType, HttpEntity } - import akka.stream.{ Outlet, Inlet, FlowShape } + import akka.stream.{ Attributes, Outlet, Inlet, FlowShape } import scala.concurrent.duration.FiniteDuration private[http] class ToStrict(timeout: FiniteDuration, contentType: ContentType) @@ -138,7 +138,7 @@ package util { val out = Outlet[HttpEntity.Strict]("out") override val shape = FlowShape(in, out) - override def createLogic: GraphStageLogic = new GraphStageLogic(shape) { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) { var bytes = ByteString.newBuilder private var emptyStream = false diff --git a/akka-http-core/src/main/scala/akka/http/scaladsl/model/HttpEntity.scala b/akka-http-core/src/main/scala/akka/http/scaladsl/model/HttpEntity.scala index f29829b103..e8db1d88c2 100644 --- a/akka-http-core/src/main/scala/akka/http/scaladsl/model/HttpEntity.scala +++ b/akka-http-core/src/main/scala/akka/http/scaladsl/model/HttpEntity.scala @@ -75,7 +75,7 @@ sealed trait HttpEntity extends jm.HttpEntity { def withContentType(contentType: ContentType): HttpEntity /** Java API */ - def getDataBytes: stream.javadsl.Source[ByteString, _] = stream.javadsl.Source.fromGraph(dataBytes) + def getDataBytes: stream.javadsl.Source[ByteString, AnyRef] = stream.javadsl.Source.fromGraph(dataBytes.asInstanceOf[Source[ByteString, AnyRef]]) /** Java API */ def getContentLengthOption: japi.Option[JLong] = @@ -147,9 +147,11 @@ object HttpEntity { def apply(contentType: ContentType, data: Source[ByteString, Any]): Chunked = Chunked.fromData(contentType, data) - def apply(contentType: ContentType, file: File, chunkSize: Int = SynchronousFileSource.DefaultChunkSize): UniversalEntity = { + def apply(contentType: ContentType, file: File, chunkSize: Int = -1): UniversalEntity = { val fileLength = file.length - if (fileLength > 0) Default(contentType, fileLength, SynchronousFileSource(file, chunkSize)) + if (fileLength > 0) + Default(contentType, fileLength, + if (chunkSize > 0) SynchronousFileSource(file, chunkSize) else SynchronousFileSource(file)) else empty(contentType) } diff --git a/akka-http-core/src/main/scala/akka/http/scaladsl/model/Multipart.scala b/akka-http-core/src/main/scala/akka/http/scaladsl/model/Multipart.scala index 346f90cd22..530546a8e6 100644 --- a/akka-http-core/src/main/scala/akka/http/scaladsl/model/Multipart.scala +++ b/akka-http-core/src/main/scala/akka/http/scaladsl/model/Multipart.scala @@ -194,7 +194,7 @@ object Multipart { * To create an instance with several parts or for multiple files, use * ``FormData(BodyPart.fromFile("field1", ...), BodyPart.fromFile("field2", ...)`` */ - def fromFile(name: String, contentType: ContentType, file: File, chunkSize: Int = SynchronousFileSource.DefaultChunkSize): FormData = + def fromFile(name: String, contentType: ContentType, file: File, chunkSize: Int = -1): FormData = FormData(Source.single(BodyPart.fromFile(name, contentType, file, chunkSize))) /** @@ -237,7 +237,7 @@ object Multipart { /** * Creates a BodyPart backed by a File that will be streamed using a SynchronousFileSource. */ - def fromFile(name: String, contentType: ContentType, file: File, chunkSize: Int = SynchronousFileSource.DefaultChunkSize): BodyPart = + def fromFile(name: String, contentType: ContentType, file: File, chunkSize: Int = -1): BodyPart = BodyPart(name, HttpEntity(contentType, file, chunkSize), Map("filename" -> file.getName)) def unapply(value: BodyPart): Option[(String, BodyPartEntity, Map[String, String], immutable.Seq[HttpHeader])] = diff --git a/akka-http-core/src/test/resources/reference.conf b/akka-http-core/src/test/resources/reference.conf index ab48718a51..1660c0e30d 100644 --- a/akka-http-core/src/test/resources/reference.conf +++ b/akka-http-core/src/test/resources/reference.conf @@ -1,4 +1,5 @@ akka { + loggers = ["akka.testkit.TestEventListener"] actor { serialize-creators = on serialize-messages = on diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/client/ConnectionPoolSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/client/ConnectionPoolSpec.scala index 8a76f3fb63..e287b92add 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/client/ConnectionPoolSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/client/ConnectionPoolSpec.scala @@ -57,8 +57,8 @@ class ConnectionPoolSpec extends AkkaSpec(""" requestIn.sendNext(HttpRequest(uri = "/") -> 42) - acceptIncomingConnection() responseOutSub.request(1) + acceptIncomingConnection() val (Success(response), 42) = responseOut.expectNext() response.headers should contain(RawHeader("Req-Host", s"$serverHostName:$serverPort")) } @@ -116,8 +116,8 @@ class ConnectionPoolSpec extends AkkaSpec(""" val (requestIn, responseOut, responseOutSub, hcp) = cachedHostConnectionPool[Int]() requestIn.sendNext(HttpRequest(uri = "/a") -> 42) - acceptIncomingConnection() responseOutSub.request(1) + acceptIncomingConnection() val (Success(response1), 42) = responseOut.expectNext() connNr(response1) shouldEqual 1 @@ -222,8 +222,8 @@ class ConnectionPoolSpec extends AkkaSpec(""" requestIn.sendNext(HttpRequest(uri = "/") -> 42) - acceptIncomingConnection() responseOutSub.request(1) + acceptIncomingConnection() val (Success(_), 42) = responseOut.expectNext() } } @@ -346,7 +346,7 @@ class ConnectionPoolSpec extends AkkaSpec(""" def flowTestBench[T, Mat](poolFlow: Flow[(HttpRequest, T), (Try[HttpResponse], T), Mat]) = { val requestIn = TestPublisher.probe[(HttpRequest, T)]() val responseOut = TestSubscriber.manualProbe[(Try[HttpResponse], T)] - val hcp = Source(requestIn).viaMat(poolFlow)(Keep.right).toMat(Sink(responseOut))(Keep.left).run() + val hcp = Source(requestIn).viaMat(poolFlow)(Keep.right).to(Sink(responseOut)).run() val responseOutSub = responseOut.expectSubscription() (requestIn, responseOut, responseOutSub, hcp) } diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/client/LowLevelOutgoingConnectionSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/client/LowLevelOutgoingConnectionSpec.scala index 4d56fcca37..9150e1540d 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/client/LowLevelOutgoingConnectionSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/client/LowLevelOutgoingConnectionSpec.scala @@ -62,7 +62,7 @@ class LowLevelOutgoingConnectionSpec extends AkkaSpec("akka.loggers = []\n akka. | |""") val sub = probe.expectSubscription() - sub.expectRequest(4) + sub.expectRequest() sub.sendNext(ByteString("ABC")) expectWireData("ABC") sub.sendNext(ByteString("DEF")) @@ -228,7 +228,7 @@ class LowLevelOutgoingConnectionSpec extends AkkaSpec("akka.loggers = []\n akka. | |""") val sub = probe.expectSubscription() - sub.expectRequest(4) + sub.expectRequest() sub.sendNext(ByteString("ABC")) expectWireData("ABC") sub.sendNext(ByteString("DEF")) @@ -254,7 +254,7 @@ class LowLevelOutgoingConnectionSpec extends AkkaSpec("akka.loggers = []\n akka. | |""") val sub = probe.expectSubscription() - sub.expectRequest(4) + sub.expectRequest() sub.sendNext(ByteString("ABC")) expectWireData("ABC") sub.sendNext(ByteString("DEF")) diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/client/TlsEndpointVerificationSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/client/TlsEndpointVerificationSpec.scala index c0a5aab683..71ccc11977 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/client/TlsEndpointVerificationSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/client/TlsEndpointVerificationSpec.scala @@ -6,20 +6,18 @@ package akka.http.impl.engine.client import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.concurrent.ScalaFutures - import akka.stream.ActorMaterializer import akka.stream.io._ import akka.stream.scaladsl._ import akka.stream.testkit.AkkaSpec - import akka.http.impl.util._ - import akka.http.scaladsl.{ HttpsContext, Http } import akka.http.scaladsl.model.{ StatusCodes, HttpResponse, HttpRequest } import akka.http.scaladsl.model.headers.Host import org.scalatest.time.{ Span, Seconds } - import scala.concurrent.Future +import akka.testkit.EventFilter +import javax.net.ssl.SSLException class TlsEndpointVerificationSpec extends AkkaSpec(""" #akka.loggers = [] @@ -30,7 +28,7 @@ class TlsEndpointVerificationSpec extends AkkaSpec(""" val timeout = Timeout(Span(3, Seconds)) "The client implementation" should { - "not accept certificates signed by unknown CA" in { + "not accept certificates signed by unknown CA" in EventFilter[SSLException](occurrences = 1).intercept { val pipe = pipeline(Http().defaultClientHttpsContext, hostname = "akka.example.org") // default context doesn't include custom CA whenReady(pipe(HttpRequest(uri = "https://akka.example.org/")).failed, timeout) { e ⇒ diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/FramingSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/FramingSpec.scala index f5e5aac239..095eef7e51 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/FramingSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/FramingSpec.scala @@ -307,13 +307,12 @@ class FramingSpec extends FreeSpec with Matchers with WithMaterializerSpec { } private def parseToEvents(bytes: Seq[ByteString]): immutable.Seq[FrameEvent] = - Source(bytes.toVector).transform(newParser).runFold(Vector.empty[FrameEvent])(_ :+ _) + Source(bytes.toVector).via(FrameEventParser).runFold(Vector.empty[FrameEvent])(_ :+ _) .awaitResult(1.second) private def renderToByteString(events: immutable.Seq[FrameEvent]): ByteString = Source(events).transform(newRenderer).runFold(ByteString.empty)(_ ++ _) .awaitResult(1.second) - protected def newParser(): Stage[ByteString, FrameEvent] = new FrameEventParser protected def newRenderer(): Stage[FrameEvent, ByteString] = new FrameEventRenderer import scala.language.implicitConversions diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala index 49b4024b93..90583939af 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala @@ -13,6 +13,7 @@ import akka.stream.testkit._ import akka.util.ByteString import akka.http.scaladsl.model.ws._ import Protocol.Opcode +import akka.testkit.EventFilter class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { import WSTestUtils._ @@ -595,15 +596,18 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { netIn.expectCancellation() } "if user handler fails" in new ServerTestSetup { - messageOut.sendError(new RuntimeException("Oops, user handler failed!")) - expectCloseCodeOnNetwork(Protocol.CloseCodes.UnexpectedCondition) + EventFilter[RuntimeException](message = "Oops, user handler failed!", occurrences = 1) + .intercept { + messageOut.sendError(new RuntimeException("Oops, user handler failed!")) + expectCloseCodeOnNetwork(Protocol.CloseCodes.UnexpectedCondition) - expectNoNetworkData() // wait for peer to close regularly - pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true)) + expectNoNetworkData() // wait for peer to close regularly + pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true)) - expectComplete(messageIn) - netOut.expectComplete() - netIn.expectCancellation() + expectComplete(messageIn) + netOut.expectComplete() + netIn.expectCancellation() + } } "if peer closes with invalid close frame" - { "close code outside of the valid range" in new ServerTestSetup { @@ -828,7 +832,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { Source(netIn) .via(printEvent("netIn")) - .transform(() ⇒ new FrameEventParser) + .via(FrameEventParser) .via(Websocket.stack(serverSide, maskingRandomFactory = Randoms.SecureRandomInstances, closeTimeout = closeTimeout, log = system.log).join(messageHandler)) .via(printEvent("frameRendererIn")) .transform(() ⇒ new FrameEventRenderer) diff --git a/akka-http-core/src/test/scala/akka/http/scaladsl/ClientServerSpec.scala b/akka-http-core/src/test/scala/akka/http/scaladsl/ClientServerSpec.scala index 5fc481a5f7..a3001e0851 100644 --- a/akka-http-core/src/test/scala/akka/http/scaladsl/ClientServerSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/scaladsl/ClientServerSpec.scala @@ -24,8 +24,8 @@ import akka.http.scaladsl.model.HttpMethods._ import akka.http.scaladsl.model._ import akka.http.scaladsl.model.headers._ import akka.http.impl.util._ - import scala.util.{ Failure, Try, Success } +import java.net.BindException class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll { val testConf: Config = ConfigFactory.parseString(""" @@ -56,7 +56,7 @@ class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll { sub.cancel() } - "report failure if bind fails" in { + "report failure if bind fails" in EventFilter[BindException](occurrences = 2).intercept { val (_, hostname, port) = TestUtils.temporaryServerHostnameAndPort() val binding = Http().bind(hostname, port) val probe1 = TestSubscriber.manualProbe[Http.IncomingConnection]() @@ -173,7 +173,8 @@ class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll { case other: Throwable ⇒ Failure(other) }.get val diff = System.nanoTime() - serverReceivedRequestAtNanos - diff should be > theIdleTimeout.toNanos + val fudge = 100 * 1000 * 1000 // 100ms to account for signal propagation between idleStage and handler + diff should be > (theIdleTimeout.toNanos - fudge) } "log materialization errors in `bindAndHandle`" which { diff --git a/akka-http-tests/src/test/resources/reference.conf b/akka-http-tests/src/test/resources/reference.conf index ab48718a51..1660c0e30d 100644 --- a/akka-http-tests/src/test/resources/reference.conf +++ b/akka-http-tests/src/test/resources/reference.conf @@ -1,4 +1,5 @@ akka { + loggers = ["akka.testkit.TestEventListener"] actor { serialize-creators = on serialize-messages = on diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/BasicRouteSpecs.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/BasicRouteSpecs.scala index c732923631..68bcc0afd9 100644 --- a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/BasicRouteSpecs.scala +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/BasicRouteSpecs.scala @@ -7,6 +7,7 @@ package akka.http.scaladsl.server import akka.http.scaladsl.model import model.HttpMethods._ import model.StatusCodes +import akka.testkit.EventFilter class BasicRouteSpecs extends RoutingSpec { @@ -134,7 +135,7 @@ class BasicRouteSpecs extends RoutingSpec { case object MyException extends RuntimeException "Route sealing" should { - "catch route execution exceptions" in { + "catch route execution exceptions" in EventFilter[MyException.type](occurrences = 1).intercept { Get("/abc") ~> Route.seal { get { ctx ⇒ throw MyException @@ -143,7 +144,7 @@ class BasicRouteSpecs extends RoutingSpec { status shouldEqual StatusCodes.InternalServerError } } - "catch route building exceptions" in { + "catch route building exceptions" in EventFilter[MyException.type](occurrences = 1).intercept { Get("/abc") ~> Route.seal { get { throw MyException @@ -152,7 +153,7 @@ class BasicRouteSpecs extends RoutingSpec { status shouldEqual StatusCodes.InternalServerError } } - "convert all rejections to responses" in { + "convert all rejections to responses" in EventFilter[RuntimeException](occurrences = 1).intercept { object MyRejection extends Rejection Get("/abc") ~> Route.seal { get { diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/ExecutionDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/ExecutionDirectivesSpec.scala index 86e5cd6eb7..f5891ecdcb 100644 --- a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/ExecutionDirectivesSpec.scala +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/ExecutionDirectivesSpec.scala @@ -7,8 +7,8 @@ package directives import akka.http.scaladsl.model.{ MediaTypes, MediaRanges, StatusCodes } import akka.http.scaladsl.model.headers._ - import scala.concurrent.Future +import akka.testkit.EventFilter class ExecutionDirectivesSpec extends RoutingSpec { object MyException extends RuntimeException @@ -51,7 +51,7 @@ class ExecutionDirectivesSpec extends RoutingSpec { } } } - "not interfere with alternative routes" in { + "not interfere with alternative routes" in EventFilter[MyException.type](occurrences = 1).intercept { Get("/abc") ~> get { handleExceptions(handler)(reject) ~ { ctx ⇒ @@ -62,22 +62,22 @@ class ExecutionDirectivesSpec extends RoutingSpec { responseAs[String] shouldEqual "There was an internal server error." } } - "not handle other exceptions" in { + "not handle other exceptions" in EventFilter[RuntimeException](occurrences = 1, message = "buh").intercept { Get("/abc") ~> get { handleExceptions(handler) { - throw new RuntimeException + throw new RuntimeException("buh") } } ~> check { status shouldEqual StatusCodes.InternalServerError responseAs[String] shouldEqual "There was an internal server error." } } - "always fall back to a default content type" in { + "always fall back to a default content type" in EventFilter[RuntimeException](occurrences = 2, message = "buh2").intercept { Get("/abc") ~> Accept(MediaTypes.`application/json`) ~> get { handleExceptions(handler) { - throw new RuntimeException + throw new RuntimeException("buh2") } } ~> check { status shouldEqual StatusCodes.InternalServerError @@ -87,7 +87,7 @@ class ExecutionDirectivesSpec extends RoutingSpec { Get("/abc") ~> Accept(MediaTypes.`text/xml`, MediaRanges.`*/*`.withQValue(0f)) ~> get { handleExceptions(handler) { - throw new RuntimeException + throw new RuntimeException("buh2") } } ~> check { status shouldEqual StatusCodes.InternalServerError diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/FutureDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/FutureDirectivesSpec.scala index fc0c6e7130..2dc05f770d 100644 --- a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/FutureDirectivesSpec.scala +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/FutureDirectivesSpec.scala @@ -6,8 +6,8 @@ package akka.http.scaladsl.server package directives import akka.http.scaladsl.model.StatusCodes - import scala.concurrent.Future +import akka.testkit.EventFilter class FutureDirectivesSpec extends RoutingSpec { @@ -56,7 +56,7 @@ class FutureDirectivesSpec extends RoutingSpec { responseAs[String] shouldEqual "yes" } } - "propagate the exception in the failure case" in { + "propagate the exception in the failure case" in EventFilter[Exception](occurrences = 1, message = "XXX").intercept { Get() ~> onSuccess(Future.failed(TestException)) { echoComplete } ~> check { status shouldEqual StatusCodes.InternalServerError } @@ -67,7 +67,7 @@ class FutureDirectivesSpec extends RoutingSpec { responseAs[String] shouldEqual s"Oops. akka.http.scaladsl.server.directives.FutureDirectivesSpec$$TestException: EX when ok" } } - "catch an exception in the failure case" in { + "catch an exception in the failure case" in EventFilter[Exception](occurrences = 1, message = "XXX").intercept { Get() ~> onSuccess(Future.failed(TestException)) { throwTestException("EX when ") } ~> check { status shouldEqual StatusCodes.InternalServerError responseAs[String] shouldEqual "There was an internal server error." diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/RouteDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/RouteDirectivesSpec.scala index 09f5bc03f1..7de787fb30 100644 --- a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/RouteDirectivesSpec.scala +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/RouteDirectivesSpec.scala @@ -8,7 +8,6 @@ import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport import akka.http.scaladsl.marshallers.xml.ScalaXmlSupport import akka.stream.scaladsl.Sink import org.scalatest.FreeSpec - import scala.concurrent.{ Future, Promise } import akka.http.scaladsl.marshallers.xml.ScalaXmlSupport._ import akka.http.scaladsl.marshalling._ @@ -18,8 +17,8 @@ import akka.http.impl.util._ import headers._ import StatusCodes._ import MediaTypes._ - import scala.xml.NodeSeq +import akka.testkit.EventFilter class RouteDirectivesSpec extends FreeSpec with GenericRoutingSpec { @@ -47,7 +46,7 @@ class RouteDirectivesSpec extends FreeSpec with GenericRoutingSpec { "for successful futures and marshalling" in { Get() ~> complete(Promise.successful("yes").future) ~> check { responseAs[String] shouldEqual "yes" } } - "for failed futures and marshalling" in { + "for failed futures and marshalling" in EventFilter[RuntimeException](occurrences = 1).intercept { object TestException extends RuntimeException Get() ~> complete(Promise.failed[String](TestException).future) ~> check { diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/SecurityDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/SecurityDirectivesSpec.scala index 8d3a5f8f4f..a6ae7308e0 100644 --- a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/SecurityDirectivesSpec.scala +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/SecurityDirectivesSpec.scala @@ -9,6 +9,7 @@ import scala.concurrent.Future import akka.http.scaladsl.model._ import akka.http.scaladsl.model.headers._ import akka.http.scaladsl.server.AuthenticationFailedRejection.{ CredentialsRejected, CredentialsMissing } +import akka.testkit.EventFilter class SecurityDirectivesSpec extends RoutingSpec { val dontBasicAuth = authenticateBasicAsync[String]("MyRealm", _ ⇒ Future.successful(None)) @@ -60,11 +61,13 @@ class SecurityDirectivesSpec extends RoutingSpec { } "properly handle exceptions thrown in its inner route" in { object TestException extends RuntimeException - Get() ~> Authorization(BasicHttpCredentials("Alice", "")) ~> { - Route.seal { - doBasicAuth { _ ⇒ throw TestException } - } - } ~> check { status shouldEqual StatusCodes.InternalServerError } + EventFilter[TestException.type](occurrences = 1).intercept { + Get() ~> Authorization(BasicHttpCredentials("Alice", "")) ~> { + Route.seal { + doBasicAuth { _ ⇒ throw TestException } + } + } ~> check { status shouldEqual StatusCodes.InternalServerError } + } } } "bearer token authentication" should { @@ -108,11 +111,13 @@ class SecurityDirectivesSpec extends RoutingSpec { } "properly handle exceptions thrown in its inner route" in { object TestException extends RuntimeException - Get() ~> Authorization(OAuth2BearerToken("myToken")) ~> { - Route.seal { - doOAuth2Auth { _ ⇒ throw TestException } - } - } ~> check { status shouldEqual StatusCodes.InternalServerError } + EventFilter[TestException.type](occurrences = 1).intercept { + Get() ~> Authorization(OAuth2BearerToken("myToken")) ~> { + Route.seal { + doOAuth2Auth { _ ⇒ throw TestException } + } + } ~> check { status shouldEqual StatusCodes.InternalServerError } + } } } "authentication directives" should { diff --git a/akka-stream-tck/src/test/scala/akka/stream/tck/FusableProcessorTest.scala b/akka-stream-tck/src/test/scala/akka/stream/tck/FusableProcessorTest.scala index fe4fd1895f..717c04da22 100644 --- a/akka-stream-tck/src/test/scala/akka/stream/tck/FusableProcessorTest.scala +++ b/akka-stream-tck/src/test/scala/akka/stream/tck/FusableProcessorTest.scala @@ -3,11 +3,11 @@ */ package akka.stream.tck -import akka.stream.{ ActorMaterializer, ActorMaterializerSettings } +import akka.stream.impl.Stages +import akka.stream._ import akka.stream.impl.Stages.Identity import akka.stream.scaladsl.Flow import org.reactivestreams.Processor -import akka.stream.Attributes class FusableProcessorTest extends AkkaIdentityProcessorVerification[Int] { @@ -18,7 +18,7 @@ class FusableProcessorTest extends AkkaIdentityProcessorVerification[Int] { implicit val materializer = ActorMaterializer(settings)(system) // withAttributes "wraps" the underlying identity and protects it from automatic removal - Flow[Int].andThen[Int](Identity()).named("identity").toProcessor.run() + Flow[Int].via(Stages.identityGraph.asInstanceOf[Graph[FlowShape[Int, Int], Unit]]).named("identity").toProcessor.run() } override def createElement(element: Int): Int = element diff --git a/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala b/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala index e3b1d41545..94d39c368b 100644 --- a/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala +++ b/akka-stream-testkit/src/main/scala/akka/stream/testkit/StreamTestKit.scala @@ -9,11 +9,12 @@ import akka.stream.impl.StreamLayout.Module import akka.stream.impl._ import akka.testkit.TestProbe import org.reactivestreams.{ Publisher, Subscriber, Subscription } - import scala.annotation.tailrec import scala.collection.immutable import scala.concurrent.duration._ import scala.language.existentials +import java.io.StringWriter +import java.io.PrintWriter /** * Provides factory methods for various Publishers. @@ -183,7 +184,16 @@ object TestSubscriber { final case class OnSubscribe(subscription: Subscription) extends SubscriberEvent final case class OnNext[I](element: I) extends SubscriberEvent final case object OnComplete extends SubscriberEvent - final case class OnError(cause: Throwable) extends SubscriberEvent + final case class OnError(cause: Throwable) extends SubscriberEvent { + override def toString: String = { + val str = new StringWriter + val out = new PrintWriter(str) + out.print("OnError(") + cause.printStackTrace(out) + out.print(")") + str.toString + } + } /** * Probe that implements [[org.reactivestreams.Subscriber]] interface. diff --git a/akka-stream-tests/src/test/java/akka/stream/javadsl/AttributesTest.java b/akka-stream-tests/src/test/java/akka/stream/javadsl/AttributesTest.java index 33332a6cb2..caaf023fef 100644 --- a/akka-stream-tests/src/test/java/akka/stream/javadsl/AttributesTest.java +++ b/akka-stream-tests/src/test/java/akka/stream/javadsl/AttributesTest.java @@ -12,7 +12,7 @@ import org.junit.Test; import akka.stream.Attributes; public class AttributesTest { - + final Attributes attributes = Attributes.name("a") .and(Attributes.name("b")) @@ -27,12 +27,12 @@ public class AttributesTest { Collections.singletonList(new Attributes.InputBuffer(1, 2)), attributes.getAttributeList(Attributes.InputBuffer.class)); } - + @Test public void mustGetAttributeByClass() { assertEquals( - new Attributes.Name("a"), + new Attributes.Name("b"), attributes.getAttribute(Attributes.Name.class, new Attributes.Name("default"))); } - + } diff --git a/akka-stream-tests/src/test/java/akka/stream/javadsl/FlowTest.java b/akka-stream-tests/src/test/java/akka/stream/javadsl/FlowTest.java index ebaa33bbba..db6c3d33e4 100644 --- a/akka-stream-tests/src/test/java/akka/stream/javadsl/FlowTest.java +++ b/akka-stream-tests/src/test/java/akka/stream/javadsl/FlowTest.java @@ -16,6 +16,7 @@ import akka.stream.testkit.AkkaSpec; import akka.stream.testkit.TestPublisher; import akka.testkit.JavaTestKit; import org.junit.ClassRule; +import org.junit.Ignore; import org.junit.Test; import org.reactivestreams.Publisher; import scala.concurrent.Await; @@ -179,7 +180,7 @@ public class FlowTest extends StreamTest { } - @Test + @Ignore("StatefulStage to be converted to GraphStage when Java Api is available (#18817)") @Test public void mustBeAbleToUseTransform() { final JavaTestKit probe = new JavaTestKit(system); final Iterable input = Arrays.asList(0, 1, 2, 3, 4, 5, 6, 7); @@ -204,15 +205,15 @@ public class FlowTest extends StreamTest { return emit(Arrays.asList(element, element).iterator(), ctx); } } - + }; } - + @Override public TerminationDirective onUpstreamFinish(Context ctx) { return terminationEmit(Collections.singletonList(sum).iterator(), ctx); } - + }; } }); @@ -355,12 +356,12 @@ public class FlowTest extends StreamTest { return new akka.japi.function.Creator>() { @Override public PushPullStage create() throws Exception { - return new PushPullStage() { + return new PushPullStage() { @Override public SyncDirective onPush(T element, Context ctx) { return ctx.push(element); } - + @Override public SyncDirective onPull(Context ctx) { return ctx.pull(); @@ -374,17 +375,17 @@ public class FlowTest extends StreamTest { public void mustBeAbleToUseMerge() throws Exception { final Flow f1 = Flow.of(String.class).transform(FlowTest.this. op()).named("f1"); - final Flow f2 = + final Flow f2 = Flow.of(String.class).transform(FlowTest.this. op()).named("f2"); @SuppressWarnings("unused") - final Flow f3 = + final Flow f3 = Flow.of(String.class).transform(FlowTest.this. op()).named("f3"); final Source in1 = Source.from(Arrays.asList("a", "b", "c")); final Source in2 = Source.from(Arrays.asList("d", "e", "f")); final Sink> publisher = Sink.publisher(); - + final Source source = Source.fromGraph( FlowGraph.create(new Function, SourceShape>() { @Override diff --git a/akka-stream-tests/src/test/java/akka/stream/javadsl/SourceTest.java b/akka-stream-tests/src/test/java/akka/stream/javadsl/SourceTest.java index 6058e72db0..1f70e19d3a 100644 --- a/akka-stream-tests/src/test/java/akka/stream/javadsl/SourceTest.java +++ b/akka-stream-tests/src/test/java/akka/stream/javadsl/SourceTest.java @@ -20,6 +20,7 @@ import akka.stream.testkit.AkkaSpec; import akka.stream.testkit.TestPublisher; import akka.testkit.JavaTestKit; import org.junit.ClassRule; +import org.junit.Ignore; import org.junit.Test; import scala.concurrent.Await; import scala.concurrent.Future; @@ -107,7 +108,7 @@ public class SourceTest extends StreamTest { probe.expectMsgEquals("()"); } - @Test + @Ignore("StatefulStage to be converted to GraphStage when Java Api is available (#18817)") @Test public void mustBeAbleToUseTransform() { final JavaTestKit probe = new JavaTestKit(system); final Iterable input = Arrays.asList(0, 1, 2, 3, 4, 5, 6, 7); @@ -415,7 +416,7 @@ public class SourceTest extends StreamTest { @Test public void mustProduceTicks() throws Exception { final JavaTestKit probe = new JavaTestKit(system); - Source tickSource = Source.from(FiniteDuration.create(1, TimeUnit.SECONDS), + Source tickSource = Source.from(FiniteDuration.create(1, TimeUnit.SECONDS), FiniteDuration.create(500, TimeUnit.MILLISECONDS), "tick"); Cancellable cancellable = tickSource.to(Sink.foreach(new Procedure() { public void apply(String elem) { @@ -457,7 +458,7 @@ public class SourceTest extends StreamTest { String result = Await.result(future2, probe.dilated(FiniteDuration.create(3, TimeUnit.SECONDS))); assertEquals("A", result); } - + @Test public void mustRepeat() throws Exception { final Future> f = Source.repeat(42).grouped(10000).runWith(Sink.> head(), materializer); @@ -465,7 +466,7 @@ public class SourceTest extends StreamTest { assertEquals(result.size(), 10000); for (Integer i: result) assertEquals(i, (Integer) 42); } - + @Test public void mustBeAbleToUseActorRefSource() throws Exception { final JavaTestKit probe = new JavaTestKit(system); diff --git a/akka-stream-tests/src/test/resources/reference.conf b/akka-stream-tests/src/test/resources/application.conf similarity index 62% rename from akka-stream-tests/src/test/resources/reference.conf rename to akka-stream-tests/src/test/resources/application.conf index ab48718a51..1660c0e30d 100644 --- a/akka-stream-tests/src/test/resources/reference.conf +++ b/akka-stream-tests/src/test/resources/application.conf @@ -1,4 +1,5 @@ akka { + loggers = ["akka.testkit.TestEventListener"] actor { serialize-creators = on serialize-messages = on diff --git a/akka-stream-tests/src/test/scala/akka/stream/DslConsistencySpec.scala b/akka-stream-tests/src/test/scala/akka/stream/DslConsistencySpec.scala index 79688dd703..cfca10dbb8 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/DslConsistencySpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/DslConsistencySpec.scala @@ -26,20 +26,16 @@ class DslConsistencySpec extends WordSpec with Matchers { Set("equals", "hashCode", "notify", "notifyAll", "wait", "toString", "getClass") ++ Set("productArity", "canEqual", "productPrefix", "copy", "productIterator", "productElement") ++ Set("create", "apply", "ops", "appendJava", "andThen", "andThenMat", "isIdentity", "withAttributes", "transformMaterializing") ++ - Set("asScala", "asJava") + Set("asScala", "asJava", "deprecatedAndThen", "deprecatedAndThenMat") val allowMissing: Map[Class[_], Set[String]] = Map( sFlowClass -> Set("of"), sSourceClass -> Set("adapt", "from"), sSinkClass -> Set("adapt"), - // TODO timerTransform is to be removed or replaced. See https://github.com/akka/akka/issues/16393 - jFlowClass -> Set("timerTransform"), - jSourceClass -> Set("timerTransform"), jSinkClass -> Set(), - sRunnableGraphClass -> Set("builder"), - jRunnableGraphClass → Set("graph", "cyclesAllowed")) + sRunnableGraphClass -> Set("builder")) def materializing(m: Method): Boolean = m.getParameterTypes.contains(classOf[ActorMaterializer]) diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/ActorInterpreterLifecycleSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/ActorInterpreterLifecycleSpec.scala deleted file mode 100644 index e0ec326af2..0000000000 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/ActorInterpreterLifecycleSpec.scala +++ /dev/null @@ -1,120 +0,0 @@ -/** - * Copyright (C) 2015 Typesafe Inc. - */ -package akka.stream.impl - -import akka.stream.Supervision._ -import akka.stream._ -import akka.stream.impl.fusing.{ ActorInterpreter, InterpreterLifecycleSpecKit } -import akka.stream.stage.Stage -import akka.stream.testkit.Utils.TE -import akka.stream.testkit.{ AkkaSpec, _ } - -import scala.concurrent.duration._ - -class ActorInterpreterLifecycleSpec extends AkkaSpec with InterpreterLifecycleSpecKit { - - implicit val mat = ActorMaterializer() - - class Setup(ops: List[Stage[_, _]] = List(fusing.Map({ x: Any ⇒ x }, stoppingDecider))) { - val up = TestPublisher.manualProbe[Int]() - val down = TestSubscriber.manualProbe[Int] - private val props = ActorInterpreter.props(mat.settings, ops, mat).withDispatcher("akka.test.stream-dispatcher") - val actor = system.actorOf(props) - val processor = ActorProcessorFactory[Int, Int](actor) - } - - "An ActorInterpreter" must { - - "call preStart in order on stages" in new Setup(List( - PreStartAndPostStopIdentity(onStart = _ ⇒ testActor ! "start-a"), - PreStartAndPostStopIdentity(onStart = _ ⇒ testActor ! "start-b"), - PreStartAndPostStopIdentity(onStart = _ ⇒ testActor ! "start-c"))) { - processor.subscribe(down) - val sub = down.expectSubscription() - sub.cancel() - up.subscribe(processor) - val upsub = up.expectSubscription() - upsub.expectCancellation() - - expectMsg("start-a") - expectMsg("start-b") - expectMsg("start-c") - } - - "call postStart in order on stages - when upstream completes" in new Setup(List( - PreStartAndPostStopIdentity(onStop = () ⇒ testActor ! "stop-a"), - PreStartAndPostStopIdentity(onStop = () ⇒ testActor ! "stop-b"), - PreStartAndPostStopIdentity(onStop = () ⇒ testActor ! "stop-c"))) { - processor.subscribe(down) - val sub = down.expectSubscription() - up.subscribe(processor) - val upsub = up.expectSubscription() - upsub.sendComplete() - down.expectComplete() - - expectMsg("stop-a") - expectMsg("stop-b") - expectMsg("stop-c") - } - - "call postStart in order on stages - when downstream cancels" in new Setup(List( - PreStartAndPostStopIdentity(onStop = () ⇒ testActor ! "stop-a"), - PreStartAndPostStopIdentity(onStop = () ⇒ testActor ! "stop-b"), - PreStartAndPostStopIdentity(onStop = () ⇒ testActor ! "stop-c"))) { - processor.subscribe(down) - val sub = down.expectSubscription() - sub.cancel() - up.subscribe(processor) - val upsub = up.expectSubscription() - upsub.expectCancellation() - - expectMsg("stop-c") - expectMsg("stop-b") - expectMsg("stop-a") - } - - "onError downstream when preStart fails" in new Setup(List( - PreStartFailer(() ⇒ throw TE("Boom!")))) { - processor.subscribe(down) - val sub = down.expectSubscription() - up.subscribe(processor) - val upsub = up.expectSubscription() - down.expectError(TE("Boom!")) - } - - "onError only once even with Supervision.restart" in new Setup(List( - PreStartFailer(() ⇒ throw TE("Boom!")))) { - processor.subscribe(down) - val sub = down.expectSubscription() - up.subscribe(processor) - val upsub = up.expectSubscription() - down.expectError(TE("Boom!")) - down.expectNoMsg(1.second) - } - - "onError downstream when preStart fails with 'most downstream' failure, when multiple stages fail" in new Setup(List( - PreStartFailer(() ⇒ throw TE("Boom 1!")), - PreStartFailer(() ⇒ throw TE("Boom 2!")), - PreStartFailer(() ⇒ throw TE("Boom 3!")))) { - processor.subscribe(down) - val sub = down.expectSubscription() - up.subscribe(processor) - val upsub = up.expectSubscription() - down.expectError(TE("Boom 3!")) - down.expectNoMsg(300.millis) - } - - "continue with stream shutdown when postStop fails" in new Setup(List( - PostStopFailer(() ⇒ throw TE("Boom!")))) { - processor.subscribe(down) - val sub = down.expectSubscription() - up.subscribe(processor) - val upsub = up.expectSubscription() - upsub.sendComplete() - down.expectComplete() - } - - } - -} diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/ActorInterpreterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/ActorInterpreterSpec.scala deleted file mode 100644 index ea3afe083b..0000000000 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/ActorInterpreterSpec.scala +++ /dev/null @@ -1,190 +0,0 @@ -/** - * Copyright (C) 2015 Typesafe Inc. - */ -package akka.stream.impl - -import akka.stream.Supervision._ -import akka.stream.impl.ReactiveStreamsCompliance.SpecViolation -import akka.stream.testkit.AkkaSpec -import akka.stream._ -import akka.stream.scaladsl._ -import akka.stream.testkit._ -import akka.stream.impl.fusing.ActorInterpreter -import akka.stream.stage.Stage -import akka.stream.stage.PushPullStage -import akka.stream.stage.Context -import akka.testkit.TestLatch -import org.reactivestreams.{ Subscription, Subscriber, Publisher } -import scala.concurrent.Await -import scala.concurrent.duration._ - -class ActorInterpreterSpec extends AkkaSpec { - import FlowGraph.Implicits._ - - implicit val mat = ActorMaterializer() - - class Setup(ops: List[Stage[_, _]] = List(fusing.Map({ x: Any ⇒ x }, stoppingDecider))) { - val up = TestPublisher.manualProbe[Int]() - val down = TestSubscriber.manualProbe[Int] - private val props = ActorInterpreter.props(mat.settings, ops, mat).withDispatcher("akka.test.stream-dispatcher") - val actor = system.actorOf(props) - val processor = ActorProcessorFactory[Int, Int](actor) - } - - "An ActorInterpreter" must { - - "pass along early cancellation" in new Setup { - processor.subscribe(down) - val sub = down.expectSubscription() - sub.cancel() - up.subscribe(processor) - val upsub = up.expectSubscription() - upsub.expectCancellation() - } - - "heed cancellation signal while large demand is outstanding" in { - val latch = TestLatch() - val infinite = new PushPullStage[Int, Int] { - override def onPush(elem: Int, ctx: Context[Int]) = ??? - override def onPull(ctx: Context[Int]) = { - Await.ready(latch, 5.seconds) - ctx.push(42) - } - } - val N = system.settings.config.getInt("akka.stream.materializer.output-burst-limit") - - new Setup(infinite :: Nil) { - processor.subscribe(down) - val sub = down.expectSubscription() - up.subscribe(processor) - val upsub = up.expectSubscription() - sub.request(100000000) - sub.cancel() - watch(actor) - latch.countDown() - for (i ← 1 to N) withClue(s"iteration $i: ") { - try down.expectNext(42) catch { case e: Throwable ⇒ fail(e) } - } - // now cancellation request is processed - down.expectNoMsg(500.millis) - upsub.expectCancellation() - expectTerminated(actor) - } - } - - "heed upstream failure while large demand is outstanding" in { - val latch = TestLatch() - val infinite = new PushPullStage[Int, Int] { - override def onPush(elem: Int, ctx: Context[Int]) = ??? - override def onPull(ctx: Context[Int]) = { - Await.ready(latch, 5.seconds) - ctx.push(42) - } - } - val N = system.settings.config.getInt("akka.stream.materializer.output-burst-limit") - - new Setup(infinite :: Nil) { - processor.subscribe(down) - val sub = down.expectSubscription() - up.subscribe(processor) - val upsub = up.expectSubscription() - sub.request(100000000) - val ex = new Exception("FAIL!") - upsub.sendError(ex) - latch.countDown() - for (i ← 1 to N) withClue(s"iteration $i: ") { - try down.expectNext(42) catch { case e: Throwable ⇒ fail(e) } - } - down.expectError(ex) - } - } - - "hold back upstream completion while large demand is outstanding" in { - val latch = TestLatch() - val N = 3 * system.settings.config.getInt("akka.stream.materializer.output-burst-limit") - val infinite = new PushPullStage[Int, Int] { - private var remaining = N - override def onPush(elem: Int, ctx: Context[Int]) = ??? - override def onPull(ctx: Context[Int]) = { - Await.ready(latch, 5.seconds) - remaining -= 1 - if (remaining >= 0) ctx.push(42) - else ctx.finish() - } - override def onUpstreamFinish(ctx: Context[Int]) = { - if (remaining > 0) ctx.absorbTermination() - else ctx.finish() - } - } - - new Setup(infinite :: Nil) { - processor.subscribe(down) - val sub = down.expectSubscription() - up.subscribe(processor) - val upsub = up.expectSubscription() - sub.request(100000000) - upsub.sendComplete() - latch.countDown() - for (i ← 1 to N) withClue(s"iteration $i: ") { - try down.expectNext(42) catch { case e: Throwable ⇒ fail(e) } - } - down.expectComplete() - } - } - - "satisfy large demand" in largeDemand(0) - "satisfy larger demand" in largeDemand(1) - - "handle spec violations" in { - a[AbruptTerminationException] should be thrownBy { - Await.result( - Source(new Publisher[String] { - def subscribe(s: Subscriber[_ >: String]) = { - s.onSubscribe(new Subscription { - def cancel() = () - def request(n: Long) = sys.error("test error") - }) - } - }).runFold("")(_ + _), - 3.seconds) - } - } - - "handle failed stage factories" in { - a[RuntimeException] should be thrownBy - Await.result( - Source.empty[Int].transform(() ⇒ sys.error("test error")).runWith(Sink.head), - 3.seconds) - } - - def largeDemand(extra: Int): Unit = { - val N = 3 * system.settings.config.getInt("akka.stream.materializer.output-burst-limit") - val large = new PushPullStage[Int, Int] { - private var remaining = N - override def onPush(elem: Int, ctx: Context[Int]) = ??? - override def onPull(ctx: Context[Int]) = { - remaining -= 1 - if (remaining >= 0) ctx.push(42) - else ctx.finish() - } - } - - new Setup(large :: Nil) { - processor.subscribe(down) - val sub = down.expectSubscription() - up.subscribe(processor) - val upsub = up.expectSubscription() - sub.request(100000000) - watch(actor) - for (i ← 1 to N) withClue(s"iteration $i: ") { - try down.expectNext(42) catch { case e: Throwable ⇒ fail(e) } - } - down.expectComplete() - upsub.expectCancellation() - expectTerminated(actor) - } - } - - } - -} diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/GraphStageLogicSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/GraphStageLogicSpec.scala new file mode 100644 index 0000000000..80bb3c6cd7 --- /dev/null +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/GraphStageLogicSpec.scala @@ -0,0 +1,177 @@ +/** + * Copyright (C) 2015 Typesafe Inc. + */ +package akka.stream.impl + +import akka.stream.testkit.AkkaSpec +import akka.stream._ +import akka.stream.scaladsl._ +import akka.stream.stage._ +import akka.stream.testkit.Utils.assertAllStagesStopped +import akka.stream.testkit.scaladsl.TestSink +import akka.stream.impl.fusing._ +import akka.stream.impl.fusing.GraphInterpreter._ +import org.scalactic.ConversionCheckedTripleEquals +import scala.concurrent.duration.Duration + +class GraphStageLogicSpec extends GraphInterpreterSpecKit with ConversionCheckedTripleEquals { + + implicit val mat = ActorMaterializer() + + object emit1234 extends GraphStage[FlowShape[Int, Int]] { + val in = Inlet[Int]("in") + val out = Outlet[Int]("out") + override val shape = FlowShape(in, out) + override def createLogic(attr: Attributes) = new GraphStageLogic(shape) { + setHandler(in, eagerTerminateInput) + setHandler(out, eagerTerminateOutput) + override def preStart(): Unit = { + emit(out, 1, () ⇒ emit(out, 2)) + emit(out, 3, () ⇒ emit(out, 4)) + } + } + } + + object emit5678 extends GraphStage[FlowShape[Int, Int]] { + val in = Inlet[Int]("in") + val out = Outlet[Int]("out") + override val shape = FlowShape(in, out) + override def createLogic(attr: Attributes) = new GraphStageLogic(shape) { + setHandler(in, new InHandler { + override def onPush(): Unit = push(out, grab(in)) + override def onUpstreamFinish(): Unit = { + emit(out, 5, () ⇒ emit(out, 6)) + emit(out, 7, () ⇒ emit(out, 8)) + completeStage() + } + }) + setHandler(out, new OutHandler { + override def onPull(): Unit = pull(in) + }) + } + } + + object passThrough extends GraphStage[FlowShape[Int, Int]] { + val in = Inlet[Int]("in") + val out = Outlet[Int]("out") + override val shape = FlowShape(in, out) + override def createLogic(attr: Attributes) = new GraphStageLogic(shape) { + setHandler(in, new InHandler { + override def onPush(): Unit = push(out, grab(in)) + override def onUpstreamFinish(): Unit = complete(out) + }) + setHandler(out, new OutHandler { + override def onPull(): Unit = pull(in) + }) + } + } + + class FusedGraph[S <: Shape](ga: GraphAssembly, s: S, a: Attributes = Attributes.none) extends Graph[S, Unit] { + override def shape = s + override val module = GraphModule(ga, s, a) + override def withAttributes(attr: Attributes) = new FusedGraph(ga, s, attr) + } + + "A GraphStageLogic" must { + + "emit all things before completing" in assertAllStagesStopped { + + Source.empty.via(emit1234.named("testStage")).runWith(TestSink.probe) + .request(5) + .expectNext(1, 2, 3, 4) + .expectComplete() + + } + + "emit all things before completing with two fused stages" in assertAllStagesStopped { + new Builder { + val g = new FusedGraph( + builder(emit1234, emit5678) + .connect(Upstream, emit1234.in) + .connect(emit1234.out, emit5678.in) + .connect(emit5678.out, Downstream) + .buildAssembly(), + FlowShape(emit1234.in, emit5678.out)) + + Source.empty.via(g).runWith(TestSink.probe) + .request(9) + .expectNextN(1 to 8) + .expectComplete() + } + } + + "emit all things before completing with three fused stages" in assertAllStagesStopped { + new Builder { + val g = new FusedGraph( + builder(emit1234, passThrough, emit5678) + .connect(Upstream, emit1234.in) + .connect(emit1234.out, passThrough.in) + .connect(passThrough.out, emit5678.in) + .connect(emit5678.out, Downstream) + .buildAssembly(), + FlowShape(emit1234.in, emit5678.out)) + + Source.empty.via(g).runWith(TestSink.probe) + .request(9) + .expectNextN(1 to 8) + .expectComplete() + } + } + + "invoke lifecycle hooks in the right order" in assertAllStagesStopped { + val g = new GraphStage[FlowShape[Int, Int]] { + val in = Inlet[Int]("in") + val out = Outlet[Int]("out") + override val shape = FlowShape(in, out) + override def createLogic(attr: Attributes) = new GraphStageLogic(shape) { + setHandler(in, eagerTerminateInput) + setHandler(out, new OutHandler { + override def onPull(): Unit = { + completeStage() + testActor ! "pulled" + } + }) + override def preStart(): Unit = testActor ! "preStart" + override def postStop(): Unit = testActor ! "postStop" + } + } + Source.single(1).via(g).runWith(Sink.ignore) + expectMsg("preStart") + expectMsg("pulled") + expectMsg("postStop") + } + + "not double-terminate a single stage" in new Builder { + object g extends GraphStage[FlowShape[Int, Int]] { + val in = Inlet[Int]("in") + val out = Outlet[Int]("out") + override val shape = FlowShape(in, out) + override def createLogic(attr: Attributes) = new GraphStageLogic(shape) { + setHandler(in, eagerTerminateInput) + setHandler(out, eagerTerminateOutput) + override def postStop(): Unit = testActor ! "postStop2" + } + } + + builder(g, passThrough) + .connect(Upstream, g.in) + .connect(g.out, passThrough.in) + .connect(passThrough.out, Downstream) + .init() + + interpreter.complete(0) + interpreter.cancel(1) + interpreter.execute(2) + + expectMsg("postStop2") + expectNoMsg(Duration.Zero) + + interpreter.isCompleted should ===(false) + interpreter.isSuspended should ===(false) + interpreter.isStageCompleted(interpreter.logics(0)) should ===(true) + interpreter.isStageCompleted(interpreter.logics(1)) should ===(false) + } + + } + +} \ No newline at end of file diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/StreamLayoutSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/StreamLayoutSpec.scala index bd9b5b25d8..29695f242b 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/StreamLayoutSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/StreamLayoutSpec.scala @@ -171,7 +171,7 @@ class StreamLayoutSpec extends AkkaSpec { override def onNext(t: Any): Unit = () } - class FlatTestMaterializer(_module: Module) extends MaterializerSession(_module) { + class FlatTestMaterializer(_module: Module) extends MaterializerSession(_module, Attributes()) { var publishers = Vector.empty[TestPublisher] var subscribers = Vector.empty[TestSubscriber] diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/ActorGraphInterpreterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/ActorGraphInterpreterSpec.scala index cc4ba30309..0eadb8a826 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/ActorGraphInterpreterSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/ActorGraphInterpreterSpec.scala @@ -47,7 +47,7 @@ class ActorGraphInterpreterSpec extends AkkaSpec { val out2 = Outlet[Int]("out2") val shape = BidiShape(in1, out1, in2, out2) - override def createLogic: GraphStageLogic = new GraphStageLogic(shape) { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { setHandler(in1, new InHandler { override def onPush(): Unit = push(out1, grab(in1)) override def onUpstreamFinish(): Unit = complete(out1) @@ -88,7 +88,7 @@ class ActorGraphInterpreterSpec extends AkkaSpec { val out2 = Outlet[Int]("out2") val shape = BidiShape(in1, out1, in2, out2) - override def createLogic: GraphStageLogic = new GraphStageLogic(shape) { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { setHandler(in1, new InHandler { override def onPush(): Unit = push(out1, grab(in1)) @@ -134,7 +134,7 @@ class ActorGraphInterpreterSpec extends AkkaSpec { val out2 = Outlet[Int]("out2") val shape = BidiShape(in1, out1, in2, out2) - override def createLogic: GraphStageLogic = new GraphStageLogic(shape) { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { setHandler(in1, new InHandler { override def onPush(): Unit = push(out1, grab(in1)) @@ -183,7 +183,7 @@ class ActorGraphInterpreterSpec extends AkkaSpec { val out2 = Outlet[Int]("out2") val shape = BidiShape(in1, out1, in2, out2) - override def createLogic: GraphStageLogic = new GraphStageLogic(shape) { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { setHandler(in1, new InHandler { override def onPush(): Unit = push(out2, grab(in1)) diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpec.scala index ed0184b37c..56ea126be9 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpec.scala @@ -3,6 +3,7 @@ */ package akka.stream.impl.fusing +import akka.stream.Attributes import akka.stream.testkit.AkkaSpec import akka.stream.scaladsl.{ Merge, Broadcast, Balance, Zip } import GraphInterpreter._ @@ -45,6 +46,7 @@ class GraphInterpreterSpec extends GraphInterpreterSpecKit { // Constructing an assembly by hand and resolving ambiguities val assembly = new GraphAssembly( stages = Array(identity, identity), + originalAttributes = Array(Attributes.none, Attributes.none), ins = Array(identity.in, identity.in, null), inOwners = Array(0, 1, -1), outs = Array(null, identity.out, identity.out), diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpecKit.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpecKit.scala index e89666735b..6428195ebb 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpecKit.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/GraphInterpreterSpecKit.scala @@ -5,35 +5,32 @@ package akka.stream.impl.fusing import akka.event.Logging import akka.stream._ -import akka.stream.impl.fusing.GraphInterpreter.{ Failed, GraphAssembly, DownstreamBoundaryStageLogic, UpstreamBoundaryStageLogic } -import akka.stream.stage.{ InHandler, OutHandler, GraphStage, GraphStageLogic } +import akka.stream.impl.fusing.GraphInterpreter.{ DownstreamBoundaryStageLogic, Failed, GraphAssembly, UpstreamBoundaryStageLogic } +import akka.stream.stage.AbstractStage.PushPullGraphStage +import akka.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler, _ } import akka.stream.testkit.AkkaSpec import akka.stream.testkit.Utils.TE +import akka.stream.impl.fusing.GraphInterpreter.GraphAssembly trait GraphInterpreterSpecKit extends AkkaSpec { - sealed trait TestEvent { - def source: GraphStageLogic - } - - case class OnComplete(source: GraphStageLogic) extends TestEvent - case class Cancel(source: GraphStageLogic) extends TestEvent - case class OnError(source: GraphStageLogic, cause: Throwable) extends TestEvent - case class OnNext(source: GraphStageLogic, elem: Any) extends TestEvent - case class RequestOne(source: GraphStageLogic) extends TestEvent - case class RequestAnother(source: GraphStageLogic) extends TestEvent - - case class PreStart(source: GraphStageLogic) extends TestEvent - case class PostStop(source: GraphStageLogic) extends TestEvent - - abstract class TestSetup { - protected var lastEvent: Set[TestEvent] = Set.empty + abstract class Builder { private var _interpreter: GraphInterpreter = _ protected def interpreter: GraphInterpreter = _interpreter def stepAll(): Unit = interpreter.execute(eventLimit = Int.MaxValue) def step(): Unit = interpreter.execute(eventLimit = 1) + object Upstream extends UpstreamBoundaryStageLogic[Int] { + override val out = Outlet[Int]("up") + out.id = 0 + } + + object Downstream extends DownstreamBoundaryStageLogic[Int] { + override val in = Inlet[Int]("down") + in.id = 0 + } + class AssemblyBuilder(stages: Seq[GraphStage[_ <: Shape]]) { var upstreams = Vector.empty[(UpstreamBoundaryStageLogic[_], Inlet[_])] var downstreams = Vector.empty[(Outlet[_], DownstreamBoundaryStageLogic[_])] @@ -54,20 +51,25 @@ trait GraphInterpreterSpecKit extends AkkaSpec { this } - def init(): Unit = { + def buildAssembly(): GraphAssembly = { val ins = upstreams.map(_._2) ++ connections.map(_._2) val outs = connections.map(_._1) ++ downstreams.map(_._1) val inOwners = ins.map { in ⇒ stages.indexWhere(_.shape.inlets.contains(in)) } val outOwners = outs.map { out ⇒ stages.indexWhere(_.shape.outlets.contains(out)) } - val assembly = new GraphAssembly( + new GraphAssembly( stages.toArray, + Array.fill(stages.size)(Attributes.none), (ins ++ Vector.fill(downstreams.size)(null)).toArray, (inOwners ++ Vector.fill(downstreams.size)(-1)).toArray, (Vector.fill(upstreams.size)(null) ++ outs).toArray, (Vector.fill(upstreams.size)(-1) ++ outOwners).toArray) + } - val (inHandlers, outHandlers, logics, _) = assembly.materialize() + def init(): Unit = { + val assembly = buildAssembly() + + val (inHandlers, outHandlers, logics, _) = assembly.materialize(Attributes.none) _interpreter = new GraphInterpreter(assembly, NoMaterializer, Logging(system, classOf[TestSetup]), inHandlers, outHandlers, logics, (_, _, _) ⇒ ()) for ((upstream, i) ← upstreams.zipWithIndex) { @@ -83,11 +85,29 @@ trait GraphInterpreterSpecKit extends AkkaSpec { } def manualInit(assembly: GraphAssembly): Unit = { - val (inHandlers, outHandlers, logics, _) = assembly.materialize() + val (inHandlers, outHandlers, logics, _) = assembly.materialize(Attributes.none) _interpreter = new GraphInterpreter(assembly, NoMaterializer, Logging(system, classOf[TestSetup]), inHandlers, outHandlers, logics, (_, _, _) ⇒ ()) } - def builder(stages: GraphStage[_ <: Shape]*): AssemblyBuilder = new AssemblyBuilder(stages.toSeq) + def builder(stages: GraphStage[_ <: Shape]*): AssemblyBuilder = new AssemblyBuilder(stages) + } + + abstract class TestSetup extends Builder { + + sealed trait TestEvent { + def source: GraphStageLogic + } + + case class OnComplete(source: GraphStageLogic) extends TestEvent + case class Cancel(source: GraphStageLogic) extends TestEvent + case class OnError(source: GraphStageLogic, cause: Throwable) extends TestEvent + case class OnNext(source: GraphStageLogic, elem: Any) extends TestEvent + case class RequestOne(source: GraphStageLogic) extends TestEvent + case class RequestAnother(source: GraphStageLogic) extends TestEvent + case class PreStart(source: GraphStageLogic) extends TestEvent + case class PostStop(source: GraphStageLogic) extends TestEvent + + protected var lastEvent: Set[TestEvent] = Set.empty def lastEvents(): Set[TestEvent] = { val result = lastEvent @@ -158,7 +178,7 @@ trait GraphInterpreterSpecKit extends AkkaSpec { // Modified onPush that does not grab() automatically the element. This accesses some internals. override def onPush(): Unit = { - val internalEvent = interpreter.connectionSlots(inToConn(in.id)) + val internalEvent = interpreter.connectionSlots(portToConn(in.id)) internalEvent match { case Failed(_, elem) ⇒ lastEvent += OnNext(DownstreamPortProbe.this, elem) @@ -173,6 +193,7 @@ trait GraphInterpreterSpecKit extends AkkaSpec { private val assembly = new GraphAssembly( stages = Array.empty, + originalAttributes = Array.empty, ins = Array(null), inOwners = Array(-1), outs = Array(null), @@ -233,7 +254,7 @@ trait GraphInterpreterSpecKit extends AkkaSpec { private val sandwitchStage = new GraphStage[FlowShape[Int, Int]] { override def shape = stageshape - override def createLogic: GraphStageLogic = stage + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = stage } class UpstreamPortProbe[T] extends UpstreamProbe[T]("upstreamPort") { @@ -253,4 +274,133 @@ trait GraphInterpreterSpecKit extends AkkaSpec { .init() } + abstract class OneBoundedSetup[T](ops: Array[GraphStageWithMaterializedValue[Shape, Any]]) extends Builder { + + def this(ops: Iterable[Stage[_, _]]) = { + this(ops.map { op ⇒ + new PushPullGraphStage[Any, Any, Any]( + (_) ⇒ op.asInstanceOf[Stage[Any, Any]], + Attributes.none) + }.toArray.asInstanceOf[Array[GraphStageWithMaterializedValue[Shape, Any]]]) + } + + val upstream = new UpstreamOneBoundedProbe[T] + val downstream = new DownstreamOneBoundedPortProbe[T] + var lastEvent = Set.empty[TestEvent] + + sealed trait TestEvent + + case object OnComplete extends TestEvent + case object Cancel extends TestEvent + case class OnError(cause: Throwable) extends TestEvent + case class OnNext(elem: Any) extends TestEvent + case object RequestOne extends TestEvent + case object RequestAnother extends TestEvent + + private def run() = interpreter.execute(Int.MaxValue) + + private def initialize(): Unit = { + import GraphInterpreter.Boundary + + var i = 0 + val attributes = Array.fill[Attributes](ops.length)(Attributes.none) + val ins = Array.ofDim[Inlet[_]](ops.length + 1) + val inOwners = Array.ofDim[Int](ops.length + 1) + val outs = Array.ofDim[Outlet[_]](ops.length + 1) + val outOwners = Array.ofDim[Int](ops.length + 1) + + ins(ops.length) = null + inOwners(ops.length) = Boundary + outs(0) = null + outOwners(0) = Boundary + + while (i < ops.length) { + val stage = ops(i).asInstanceOf[PushPullGraphStage[_, _, _]] + ins(i) = stage.shape.inlet + inOwners(i) = i + outs(i + 1) = stage.shape.outlet + outOwners(i + 1) = i + i += 1 + } + + manualInit(new GraphAssembly(ops, attributes, ins, inOwners, outs, outOwners)) + interpreter.attachUpstreamBoundary(0, upstream) + interpreter.attachDownstreamBoundary(ops.length, downstream) + + interpreter.init() + + } + + initialize() + run() // Detached stages need the prefetch + + def lastEvents(): Set[TestEvent] = { + val events = lastEvent + lastEvent = Set.empty + events + } + + class UpstreamOneBoundedProbe[T] extends UpstreamBoundaryStageLogic[T] { + val out = Outlet[T]("out") + out.id = 0 + + setHandler(out, new OutHandler { + override def onPull(): Unit = { + if (lastEvent.contains(RequestOne)) lastEvent += RequestAnother + else lastEvent += RequestOne + } + + override def onDownstreamFinish(): Unit = lastEvent += Cancel + }) + + def onNext(elem: T): Unit = { + push(out, elem) + run() + } + def onComplete(): Unit = { + complete(out) + run() + } + + def onNextAndComplete(elem: T): Unit = { + push(out, elem) + complete(out) + run() + } + + def onError(ex: Throwable): Unit = { + fail(out, ex) + run() + } + } + + class DownstreamOneBoundedPortProbe[T] extends DownstreamBoundaryStageLogic[T] { + val in = Inlet[T]("in") + in.id = 0 + + setHandler(in, new InHandler { + + // Modified onPush that does not grab() automatically the element. This accesses some internals. + override def onPush(): Unit = { + lastEvent += OnNext(grab(in)) + } + + override def onUpstreamFinish() = lastEvent += OnComplete + override def onUpstreamFailure(ex: Throwable) = lastEvent += OnError(ex) + }) + + def requestOne(): Unit = { + pull(in) + run() + } + + def cancel(): Unit = { + cancel(in) + run() + } + + } + + } + } diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala index 4d0b32decd..00c9ff1ba1 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpec.scala @@ -4,17 +4,24 @@ package akka.stream.impl.fusing import akka.stream.stage._ +import akka.stream.testkit.AkkaSpec import akka.testkit.EventFilter import scala.util.control.NoStackTrace import akka.stream.Supervision -class InterpreterSpec extends InterpreterSpecKit { +class InterpreterSpec extends AkkaSpec with GraphInterpreterSpecKit { import Supervision.stoppingDecider + /* + * These tests were writtern for the previous veryion of the interpreter, the so called OneBoundedInterpreter. + * These stages are now properly emulated by the GraphInterpreter and many of the edge cases were relevant to + * the execution model of the old one. Still, these tests are very valuable, so please do not remove. + */ + "Interpreter" must { - "implement map correctly" in new TestSetup(Seq(Map((x: Int) ⇒ x + 1, stoppingDecider))) { + "implement map correctly" in new OneBoundedSetup[Int](Seq(Map((x: Int) ⇒ x + 1, stoppingDecider))) { lastEvents() should be(Set.empty) downstream.requestOne() @@ -33,7 +40,7 @@ class InterpreterSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnComplete)) } - "implement chain of maps correctly" in new TestSetup(Seq( + "implement chain of maps correctly" in new OneBoundedSetup[Int](Seq( Map((x: Int) ⇒ x + 1, stoppingDecider), Map((x: Int) ⇒ x * 2, stoppingDecider), Map((x: Int) ⇒ x + 1, stoppingDecider))) { @@ -56,7 +63,7 @@ class InterpreterSpec extends InterpreterSpecKit { lastEvents() should be(Set(Cancel)) } - "work with only boundary ops" in new TestSetup(Seq.empty) { + "work with only boundary ops" in new OneBoundedSetup[Int](Seq.empty) { lastEvents() should be(Set.empty) downstream.requestOne() @@ -69,7 +76,7 @@ class InterpreterSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnComplete)) } - "implement one-to-many many-to-one chain correctly" in new TestSetup(Seq( + "implement one-to-many many-to-one chain correctly" in new OneBoundedSetup[Int](Seq( Doubler(), Filter((x: Int) ⇒ x != 0, stoppingDecider))) { @@ -94,7 +101,7 @@ class InterpreterSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnComplete)) } - "implement many-to-one one-to-many chain correctly" in new TestSetup(Seq( + "implement many-to-one one-to-many chain correctly" in new OneBoundedSetup[Int](Seq( Filter((x: Int) ⇒ x != 0, stoppingDecider), Doubler())) { @@ -119,7 +126,7 @@ class InterpreterSpec extends InterpreterSpecKit { lastEvents() should be(Set(Cancel)) } - "implement take" in new TestSetup(Seq(Take(2))) { + "implement take" in new OneBoundedSetup[Int](Seq(Take(2))) { lastEvents() should be(Set.empty) @@ -136,7 +143,7 @@ class InterpreterSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnNext(1), Cancel, OnComplete)) } - "implement take inside a chain" in new TestSetup(Seq( + "implement take inside a chain" in new OneBoundedSetup[Int](Seq( Filter((x: Int) ⇒ x != 0, stoppingDecider), Take(2), Map((x: Int) ⇒ x + 1, stoppingDecider))) { @@ -159,7 +166,7 @@ class InterpreterSpec extends InterpreterSpecKit { lastEvents() should be(Set(Cancel, OnComplete, OnNext(3))) } - "implement fold" in new TestSetup(Seq(Fold(0, (agg: Int, x: Int) ⇒ agg + x, stoppingDecider))) { + "implement fold" in new OneBoundedSetup[Int](Seq(Fold(0, (agg: Int, x: Int) ⇒ agg + x, stoppingDecider))) { lastEvents() should be(Set.empty) downstream.requestOne() @@ -178,7 +185,7 @@ class InterpreterSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnNext(3), OnComplete)) } - "implement fold with proper cancel" in new TestSetup(Seq(Fold(0, (agg: Int, x: Int) ⇒ agg + x, stoppingDecider))) { + "implement fold with proper cancel" in new OneBoundedSetup[Int](Seq(Fold(0, (agg: Int, x: Int) ⇒ agg + x, stoppingDecider))) { lastEvents() should be(Set.empty) @@ -198,7 +205,7 @@ class InterpreterSpec extends InterpreterSpecKit { lastEvents() should be(Set(Cancel)) } - "work if fold completes while not in a push position" in new TestSetup(Seq(Fold(0, (agg: Int, x: Int) ⇒ agg + x, stoppingDecider))) { + "work if fold completes while not in a push position" in new OneBoundedSetup[Int](Seq(Fold(0, (agg: Int, x: Int) ⇒ agg + x, stoppingDecider))) { lastEvents() should be(Set.empty) @@ -209,7 +216,7 @@ class InterpreterSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnComplete, OnNext(0))) } - "implement grouped" in new TestSetup(Seq(Grouped(3))) { + "implement grouped" in new OneBoundedSetup[Int](Seq(Grouped(3))) { lastEvents() should be(Set.empty) downstream.requestOne() @@ -234,7 +241,7 @@ class InterpreterSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnNext(Vector(3)), OnComplete)) } - "implement conflate" in new TestSetup(Seq(Conflate( + "implement conflate" in new OneBoundedSetup[Int](Seq(Conflate( (in: Int) ⇒ in, (agg: Int, x: Int) ⇒ agg + x, stoppingDecider))) { @@ -266,7 +273,7 @@ class InterpreterSpec extends InterpreterSpecKit { lastEvents() should be(Set(Cancel)) } - "implement expand" in new TestSetup(Seq(Expand( + "implement expand" in new OneBoundedSetup[Int](Seq(Expand( (in: Int) ⇒ in, (agg: Int) ⇒ (agg, agg)))) { @@ -294,7 +301,7 @@ class InterpreterSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnComplete)) } - "work with conflate-conflate" in new TestSetup(Seq( + "work with conflate-conflate" in new OneBoundedSetup[Int](Seq( Conflate( (in: Int) ⇒ in, (agg: Int, x: Int) ⇒ agg + x, @@ -332,7 +339,7 @@ class InterpreterSpec extends InterpreterSpecKit { } - "work with expand-expand" in new TestSetup(Seq( + "work with expand-expand" in new OneBoundedSetup[Int](Seq( Expand( (in: Int) ⇒ in, (agg: Int) ⇒ (agg, agg + 1)), @@ -369,7 +376,7 @@ class InterpreterSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnComplete, OnNext(12))) } - "implement conflate-expand" in new TestSetup(Seq( + "implement conflate-expand" in new OneBoundedSetup[Int](Seq( Conflate( (in: Int) ⇒ in, (agg: Int, x: Int) ⇒ agg + x, @@ -405,12 +412,7 @@ class InterpreterSpec extends InterpreterSpecKit { lastEvents() should be(Set(Cancel)) } - "implement expand-conflate" in { - pending - // Needs to detect divergent loops - } - - "implement doubler-conflate" in new TestSetup(Seq( + "implement doubler-conflate" in new OneBoundedSetup[Int](Seq( Doubler(), Conflate( (in: Int) ⇒ in, @@ -429,7 +431,8 @@ class InterpreterSpec extends InterpreterSpecKit { } - "work with jumpback table and completed elements" in new TestSetup(Seq( + // Note, the new interpreter has no jumpback table, still did not want to remove the test + "work with jumpback table and completed elements" in new OneBoundedSetup[Int](Seq( Map((x: Int) ⇒ x, stoppingDecider), Map((x: Int) ⇒ x, stoppingDecider), KeepGoing(), @@ -461,7 +464,7 @@ class InterpreterSpec extends InterpreterSpecKit { } - "work with pushAndFinish if upstream completes with pushAndFinish" in new TestSetup(Seq( + "work with pushAndFinish if upstream completes with pushAndFinish" in new OneBoundedSetup[Int](Seq( new PushFinishStage)) { lastEvents() should be(Set.empty) @@ -469,11 +472,11 @@ class InterpreterSpec extends InterpreterSpecKit { downstream.requestOne() lastEvents() should be(Set(RequestOne)) - upstream.onNextAndComplete("foo") - lastEvents() should be(Set(OnNext("foo"), OnComplete)) + upstream.onNextAndComplete(0) + lastEvents() should be(Set(OnNext(0), OnComplete)) } - "work with pushAndFinish if indirect upstream completes with pushAndFinish" in new TestSetup(Seq( + "work with pushAndFinish if indirect upstream completes with pushAndFinish" in new OneBoundedSetup[Int](Seq( Map((x: Any) ⇒ x, stoppingDecider), new PushFinishStage, Map((x: Any) ⇒ x, stoppingDecider))) { @@ -483,24 +486,24 @@ class InterpreterSpec extends InterpreterSpecKit { downstream.requestOne() lastEvents() should be(Set(RequestOne)) - upstream.onNextAndComplete("foo") - lastEvents() should be(Set(OnNext("foo"), OnComplete)) + upstream.onNextAndComplete(1) + lastEvents() should be(Set(OnNext(1), OnComplete)) } - "work with pushAndFinish if upstream completes with pushAndFinish and downstream immediately pulls" in new TestSetup(Seq( + "work with pushAndFinish if upstream completes with pushAndFinish and downstream immediately pulls" in new OneBoundedSetup[Int](Seq( new PushFinishStage, - Fold("", (x: String, y: String) ⇒ x + y, stoppingDecider))) { + Fold(0, (x: Int, y: Int) ⇒ x + y, stoppingDecider))) { lastEvents() should be(Set.empty) downstream.requestOne() lastEvents() should be(Set(RequestOne)) - upstream.onNextAndComplete("foo") - lastEvents() should be(Set(OnNext("foo"), OnComplete)) + upstream.onNextAndComplete(1) + lastEvents() should be(Set(OnNext(1), OnComplete)) } - "report error if pull is called while op is terminating" in new TestSetup(Seq(new PushPullStage[Any, Any] { + "report error if pull is called while op is terminating" in new OneBoundedSetup[Int](Seq(new PushPullStage[Any, Any] { override def onPull(ctx: Context[Any]): SyncDirective = ctx.pull() override def onPush(elem: Any, ctx: Context[Any]): SyncDirective = ctx.pull() override def onUpstreamFinish(ctx: Context[Any]): TerminationDirective = ctx.absorbTermination() @@ -514,12 +517,12 @@ class InterpreterSpec extends InterpreterSpecKit { val ev = lastEvents() ev.nonEmpty should be(true) ev.forall { - case OnError(_: IllegalStateException) ⇒ true - case _ ⇒ false + case OnError(_: IllegalArgumentException) ⇒ true + case _ ⇒ false } should be(true) } - "implement take-take" in new TestSetup(Seq( + "implement take-take" in new OneBoundedSetup[Int](Seq( Take(1), Take(1))) { lastEvents() should be(Set.empty) @@ -527,12 +530,12 @@ class InterpreterSpec extends InterpreterSpecKit { downstream.requestOne() lastEvents() should be(Set(RequestOne)) - upstream.onNext("foo") - lastEvents() should be(Set(OnNext("foo"), OnComplete, Cancel)) + upstream.onNext(1) + lastEvents() should be(Set(OnNext(1), OnComplete, Cancel)) } - "implement take-take with pushAndFinish from upstream" in new TestSetup(Seq( + "implement take-take with pushAndFinish from upstream" in new OneBoundedSetup[Int](Seq( Take(1), Take(1))) { lastEvents() should be(Set.empty) @@ -540,8 +543,8 @@ class InterpreterSpec extends InterpreterSpecKit { downstream.requestOne() lastEvents() should be(Set(RequestOne)) - upstream.onNextAndComplete("foo") - lastEvents() should be(Set(OnNext("foo"), OnComplete)) + upstream.onNextAndComplete(1) + lastEvents() should be(Set(OnNext(1), OnComplete)) } @@ -551,7 +554,7 @@ class InterpreterSpec extends InterpreterSpecKit { override def onDownstreamFinish(ctx: Context[Int]): TerminationDirective = ctx.absorbTermination() } - "not allow absorbTermination from onDownstreamFinish()" in new TestSetup(Seq( + "not allow absorbTermination from onDownstreamFinish()" in new OneBoundedSetup[Int](Seq( new InvalidAbsorbTermination)) { lastEvents() should be(Set.empty) @@ -564,4 +567,51 @@ class InterpreterSpec extends InterpreterSpecKit { } + private[akka] case class Doubler[T]() extends PushPullStage[T, T] { + var oneMore: Boolean = false + var lastElem: T = _ + + override def onPush(elem: T, ctx: Context[T]): SyncDirective = { + lastElem = elem + oneMore = true + ctx.push(elem) + } + + override def onPull(ctx: Context[T]): SyncDirective = { + if (oneMore) { + oneMore = false + ctx.push(lastElem) + } else ctx.pull() + } + } + + private[akka] case class KeepGoing[T]() extends PushPullStage[T, T] { + var lastElem: T = _ + + override def onPush(elem: T, ctx: Context[T]): SyncDirective = { + lastElem = elem + ctx.push(elem) + } + + override def onPull(ctx: Context[T]): SyncDirective = { + if (ctx.isFinishing) { + ctx.push(lastElem) + } else ctx.pull() + } + + override def onUpstreamFinish(ctx: Context[T]): TerminationDirective = ctx.absorbTermination() + } + + // This test is related to issue #17351 + private[akka] class PushFinishStage(onPostStop: () ⇒ Unit = () ⇒ ()) extends PushStage[Any, Any] { + override def onPush(elem: Any, ctx: Context[Any]): SyncDirective = + ctx.pushAndFinish(elem) + + override def onUpstreamFinish(ctx: Context[Any]): TerminationDirective = + ctx.fail(akka.stream.testkit.Utils.TE("Cannot happen")) + + override def postStop(): Unit = + onPostStop() + } + } diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpecKit.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpecKit.scala deleted file mode 100644 index 805641a5ad..0000000000 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpecKit.scala +++ /dev/null @@ -1,183 +0,0 @@ -/** - * Copyright (C) 2009-2014 Typesafe Inc. - */ -package akka.stream.impl.fusing - -import akka.event.Logging -import akka.stream.stage._ -import akka.stream.testkit.AkkaSpec -import akka.stream.{ ActorMaterializer, Attributes } -import akka.testkit.TestProbe - -trait InterpreterLifecycleSpecKit { - private[akka] case class PreStartAndPostStopIdentity[T]( - onStart: LifecycleContext ⇒ Unit = _ ⇒ (), - onStop: () ⇒ Unit = () ⇒ (), - onUpstreamCompleted: () ⇒ Unit = () ⇒ (), - onUpstreamFailed: Throwable ⇒ Unit = ex ⇒ ()) - extends PushStage[T, T] { - override def preStart(ctx: LifecycleContext) = onStart(ctx) - - override def onPush(elem: T, ctx: Context[T]) = ctx.push(elem) - - override def onUpstreamFinish(ctx: Context[T]): TerminationDirective = { - onUpstreamCompleted() - super.onUpstreamFinish(ctx) - } - - override def onUpstreamFailure(cause: Throwable, ctx: Context[T]): TerminationDirective = { - onUpstreamFailed(cause) - super.onUpstreamFailure(cause, ctx) - } - - override def postStop() = onStop() - } - - private[akka] case class PreStartFailer[T](pleaseThrow: () ⇒ Unit) extends PushStage[T, T] { - - override def preStart(ctx: LifecycleContext) = - pleaseThrow() - - override def onPush(elem: T, ctx: Context[T]) = ctx.push(elem) - } - - private[akka] case class PostStopFailer[T](ex: () ⇒ Throwable) extends PushStage[T, T] { - override def onUpstreamFinish(ctx: Context[T]) = ctx.finish() - override def onPush(elem: T, ctx: Context[T]) = ctx.push(elem) - - override def postStop(): Unit = throw ex() - } - - // This test is related to issue #17351 - private[akka] class PushFinishStage(onPostStop: () ⇒ Unit = () ⇒ ()) extends PushStage[Any, Any] { - override def onPush(elem: Any, ctx: Context[Any]): SyncDirective = - ctx.pushAndFinish(elem) - - override def onUpstreamFinish(ctx: Context[Any]): TerminationDirective = - ctx.fail(akka.stream.testkit.Utils.TE("Cannot happen")) - - override def postStop(): Unit = - onPostStop() - } - -} - -trait InterpreterSpecKit extends AkkaSpec with InterpreterLifecycleSpecKit { - - case object OnComplete - case object Cancel - case class OnError(cause: Throwable) - case class OnNext(elem: Any) - case object RequestOne - case object RequestAnother - - private[akka] case class Doubler[T]() extends PushPullStage[T, T] { - var oneMore: Boolean = false - var lastElem: T = _ - - override def onPush(elem: T, ctx: Context[T]): SyncDirective = { - lastElem = elem - oneMore = true - ctx.push(elem) - } - - override def onPull(ctx: Context[T]): SyncDirective = { - if (oneMore) { - oneMore = false - ctx.push(lastElem) - } else ctx.pull() - } - } - - private[akka] case class KeepGoing[T]() extends PushPullStage[T, T] { - var lastElem: T = _ - - override def onPush(elem: T, ctx: Context[T]): SyncDirective = { - lastElem = elem - ctx.push(elem) - } - - override def onPull(ctx: Context[T]): SyncDirective = { - if (ctx.isFinishing) { - ctx.push(lastElem) - } else ctx.pull() - } - - override def onUpstreamFinish(ctx: Context[T]): TerminationDirective = ctx.absorbTermination() - } - - abstract class TestSetup(ops: Seq[Stage[_, _]], forkLimit: Int = 100, overflowToHeap: Boolean = false) { - private var lastEvent: Set[Any] = Set.empty - - val upstream = new UpstreamProbe - val downstream = new DownstreamProbe - val sidechannel = TestProbe() - val interpreter = new OneBoundedInterpreter(upstream +: ops :+ downstream, - (op, ctx, event) ⇒ sidechannel.ref ! ActorInterpreter.AsyncInput(op, ctx, event), - Logging(system, classOf[TestSetup]), - ActorMaterializer(), - Attributes.none, - forkLimit, overflowToHeap) - interpreter.init() - - def lastEvents(): Set[Any] = { - val result = lastEvent - lastEvent = Set.empty - result - } - - private[akka] class UpstreamProbe extends BoundaryStage { - - override def onDownstreamFinish(ctx: BoundaryContext): TerminationDirective = { - lastEvent += Cancel - ctx.finish() - } - - override def onPull(ctx: BoundaryContext): Directive = { - if (lastEvent(RequestOne)) - lastEvent += RequestAnother - else - lastEvent += RequestOne - ctx.exit() - } - - override def onPush(elem: Any, ctx: BoundaryContext): Directive = - throw new UnsupportedOperationException("Cannot push the boundary") - - def onNext(elem: Any): Unit = enterAndPush(elem) - def onComplete(): Unit = enterAndFinish() - def onNextAndComplete(elem: Any): Unit = { - context.enter() - context.pushAndFinish(elem) - context.execute() - } - def onError(cause: Throwable): Unit = enterAndFail(cause) - - } - - private[akka] class DownstreamProbe extends BoundaryStage { - override def onPush(elem: Any, ctx: BoundaryContext): Directive = { - lastEvent += OnNext(elem) - ctx.exit() - } - - override def onUpstreamFinish(ctx: BoundaryContext): TerminationDirective = { - lastEvent += OnComplete - ctx.finish() - } - - override def onUpstreamFailure(cause: Throwable, ctx: BoundaryContext): TerminationDirective = { - lastEvent += OnError(cause) - ctx.finish() - } - - override def onPull(ctx: BoundaryContext): Directive = - throw new UnsupportedOperationException("Cannot pull the boundary") - - def requestOne(): Unit = enterAndPull() - - def cancel(): Unit = enterAndFinish() - } - - } -} diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterStressSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterStressSpec.scala index 4a4f427c8f..e7f13562a3 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterStressSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterStressSpec.scala @@ -3,18 +3,27 @@ */ package akka.stream.impl.fusing -import akka.stream.Supervision +import akka.stream.{ Attributes, Shape, Supervision } +import akka.stream.stage.AbstractStage.PushPullGraphStage +import akka.stream.stage.GraphStageWithMaterializedValue +import akka.stream.testkit.AkkaSpec -class InterpreterStressSpec extends InterpreterSpecKit { +class InterpreterStressSpec extends AkkaSpec with GraphInterpreterSpecKit { import Supervision.stoppingDecider val chainLength = 1000 * 1000 val halfLength = chainLength / 2 val repetition = 100 + val f = (x: Int) ⇒ x + 1 + + val map: GraphStageWithMaterializedValue[Shape, Any] = + new PushPullGraphStage[Int, Int, Unit]((_) ⇒ Map(f, stoppingDecider), Attributes.none) + .asInstanceOf[GraphStageWithMaterializedValue[Shape, Any]] + "Interpreter" must { - "work with a massive chain of maps" in new TestSetup(Seq.fill(chainLength)(Map((x: Int) ⇒ x + 1, stoppingDecider))) { + "work with a massive chain of maps" in new OneBoundedSetup[Int](Array.fill(chainLength)(map).asInstanceOf[Array[GraphStageWithMaterializedValue[Shape, Any]]]) { lastEvents() should be(Set.empty) val tstamp = System.nanoTime() @@ -32,11 +41,11 @@ class InterpreterStressSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnComplete)) val time = (System.nanoTime() - tstamp) / (1000.0 * 1000.0 * 1000.0) - // FIXME: Not a real benchmark, should be replaced by a proper JMH bench + // Not a real benchmark, just for sanity check info(s"Chain finished in $time seconds ${(chainLength * repetition) / (time * 1000 * 1000)} million maps/s") } - "work with a massive chain of maps with early complete" in new TestSetup(Seq.fill(halfLength)(Map((x: Int) ⇒ x + 1, stoppingDecider)) ++ + "work with a massive chain of maps with early complete" in new OneBoundedSetup[Int](Iterable.fill(halfLength)(Map((x: Int) ⇒ x + 1, stoppingDecider)) ++ Seq(Take(repetition / 2)) ++ Seq.fill(halfLength)(Map((x: Int) ⇒ x + 1, stoppingDecider))) { @@ -60,11 +69,11 @@ class InterpreterStressSpec extends InterpreterSpecKit { lastEvents() should be(Set(Cancel, OnComplete, OnNext(0 + chainLength))) val time = (System.nanoTime() - tstamp) / (1000.0 * 1000.0 * 1000.0) - // FIXME: Not a real benchmark, should be replaced by a proper JMH bench + // Not a real benchmark, just for sanity check info(s"Chain finished in $time seconds ${(chainLength * repetition) / (time * 1000 * 1000)} million maps/s") } - "work with a massive chain of takes" in new TestSetup(Seq.fill(chainLength)(Take(1))) { + "work with a massive chain of takes" in new OneBoundedSetup[Int](Iterable.fill(chainLength)(Take(1))) { lastEvents() should be(Set.empty) downstream.requestOne() @@ -75,7 +84,7 @@ class InterpreterStressSpec extends InterpreterSpecKit { } - "work with a massive chain of drops" in new TestSetup(Seq.fill(chainLength / 1000)(Drop(1))) { + "work with a massive chain of drops" in new OneBoundedSetup[Int](Iterable.fill(chainLength / 1000)(Drop(1))) { lastEvents() should be(Set.empty) downstream.requestOne() @@ -93,12 +102,10 @@ class InterpreterStressSpec extends InterpreterSpecKit { } - "work with a massive chain of conflates by overflowing to the heap" in new TestSetup(Seq.fill(100000)(Conflate( + "work with a massive chain of conflates by overflowing to the heap" in new OneBoundedSetup[Int](Iterable.fill(100000)(Conflate( (in: Int) ⇒ in, (agg: Int, in: Int) ⇒ agg + in, - Supervision.stoppingDecider)), - forkLimit = 100, - overflowToHeap = true) { + Supervision.stoppingDecider))) { lastEvents() should be(Set(RequestOne)) diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala index 2438588ad7..238c246678 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala @@ -70,7 +70,7 @@ object InterpreterSupervisionSpec { } -class InterpreterSupervisionSpec extends InterpreterSpecKit { +class InterpreterSupervisionSpec extends GraphInterpreterSpecKit { import InterpreterSupervisionSpec._ import Supervision.stoppingDecider import Supervision.resumingDecider @@ -78,7 +78,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { "Interpreter error handling" must { - "handle external failure" in new TestSetup(Seq(Map((x: Int) ⇒ x + 1, stoppingDecider))) { + "handle external failure" in new OneBoundedSetup[Int](Seq(Map((x: Int) ⇒ x + 1, stoppingDecider))) { lastEvents() should be(Set.empty) upstream.onError(TE) @@ -86,7 +86,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { } - "emit failure when op throws" in new TestSetup(Seq(Map((x: Int) ⇒ if (x == 0) throw TE else x, stoppingDecider))) { + "emit failure when op throws" in new OneBoundedSetup[Int](Seq(Map((x: Int) ⇒ if (x == 0) throw TE else x, stoppingDecider))) { downstream.requestOne() lastEvents() should be(Set(RequestOne)) upstream.onNext(2) @@ -98,7 +98,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { lastEvents() should be(Set(Cancel, OnError(TE))) } - "emit failure when op throws in middle of the chain" in new TestSetup(Seq( + "emit failure when op throws in middle of the chain" in new OneBoundedSetup[Int](Seq( Map((x: Int) ⇒ x + 1, stoppingDecider), Map((x: Int) ⇒ if (x == 0) throw TE else x + 10, stoppingDecider), Map((x: Int) ⇒ x + 100, stoppingDecider))) { @@ -114,7 +114,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { lastEvents() should be(Set(Cancel, OnError(TE))) } - "resume when Map throws" in new TestSetup(Seq(Map((x: Int) ⇒ if (x == 0) throw TE else x, resumingDecider))) { + "resume when Map throws" in new OneBoundedSetup[Int](Seq(Map((x: Int) ⇒ if (x == 0) throw TE else x, resumingDecider))) { downstream.requestOne() lastEvents() should be(Set(RequestOne)) upstream.onNext(2) @@ -138,7 +138,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnNext(4))) } - "resume when Map throws in middle of the chain" in new TestSetup(Seq( + "resume when Map throws in middle of the chain" in new OneBoundedSetup[Int](Seq( Map((x: Int) ⇒ x + 1, resumingDecider), Map((x: Int) ⇒ if (x == 0) throw TE else x + 10, resumingDecider), Map((x: Int) ⇒ x + 100, resumingDecider))) { @@ -157,7 +157,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnNext(114))) } - "resume when Map throws before Grouped" in new TestSetup(Seq( + "resume when Map throws before Grouped" in new OneBoundedSetup[Int](Seq( Map((x: Int) ⇒ x + 1, resumingDecider), Map((x: Int) ⇒ if (x <= 0) throw TE else x + 10, resumingDecider), Grouped(3))) { @@ -177,7 +177,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnNext(Vector(13, 14, 15)))) } - "complete after resume when Map throws before Grouped" in new TestSetup(Seq( + "complete after resume when Map throws before Grouped" in new OneBoundedSetup[Int](Seq( Map((x: Int) ⇒ x + 1, resumingDecider), Map((x: Int) ⇒ if (x <= 0) throw TE else x + 10, resumingDecider), Grouped(1000))) { @@ -205,7 +205,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { } } - new TestSetup(Seq( + new OneBoundedSetup[Int](Seq( Map((x: Int) ⇒ x + 1, restartingDecider), stage, Map((x: Int) ⇒ x + 100, restartingDecider))) { @@ -228,14 +228,13 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { "restart when onPush throws after ctx.push" in { val stage = new RestartTestStage { override def onPush(elem: Int, ctx: Context[Int]): SyncDirective = { - val ret = ctx.push(sum) - super.onPush(elem, ctx) + val ret = ctx.push(elem) if (elem <= 0) throw TE ret } } - new TestSetup(Seq( + new OneBoundedSetup[Int](Seq( Map((x: Int) ⇒ x + 1, restartingDecider), stage, Map((x: Int) ⇒ x + 100, restartingDecider))) { @@ -248,6 +247,10 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { downstream.requestOne() lastEvents() should be(Set(RequestOne)) upstream.onNext(-1) // boom + // The element has been pushed before the exception, there is no way back + lastEvents() should be(Set(OnNext(100))) + + downstream.requestOne() lastEvents() should be(Set(RequestOne)) upstream.onNext(3) @@ -263,7 +266,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { } } - new TestSetup(Seq( + new OneBoundedSetup[Int](Seq( Map((x: Int) ⇒ x + 1, restartingDecider), stage, Map((x: Int) ⇒ x + 100, restartingDecider))) { @@ -283,7 +286,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { } } - "resume when Filter throws" in new TestSetup(Seq( + "resume when Filter throws" in new OneBoundedSetup[Int](Seq( Filter((x: Int) ⇒ if (x == 0) throw TE else true, resumingDecider))) { downstream.requestOne() lastEvents() should be(Set(RequestOne)) @@ -299,7 +302,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnNext(3))) } - "resume when MapConcat throws" in new TestSetup(Seq( + "resume when MapConcat throws" in new OneBoundedSetup[Int](Seq( MapConcat((x: Int) ⇒ if (x == 0) throw TE else List(x, -x), resumingDecider))) { downstream.requestOne() lastEvents() should be(Set(RequestOne)) @@ -323,7 +326,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { // TODO can't get type inference to work with `pf` inlined val pf: PartialFunction[Int, Int] = { case x: Int ⇒ if (x == 0) throw TE else x } - new TestSetup(Seq( + new OneBoundedSetup[Int](Seq( Collect(pf, restartingDecider))) { downstream.requestOne() lastEvents() should be(Set(RequestOne)) @@ -340,7 +343,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { } } - "resume when Scan throws" in new TestSetup(Seq( + "resume when Scan throws" in new OneBoundedSetup[Int](Seq( Scan(1, (acc: Int, x: Int) ⇒ if (x == 10) throw TE else acc + x, resumingDecider))) { downstream.requestOne() lastEvents() should be(Set(OnNext(1))) @@ -358,7 +361,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnNext(7))) // 1 + 2 + 4 } - "restart when Scan throws" in new TestSetup(Seq( + "restart when Scan throws" in new OneBoundedSetup[Int](Seq( Scan(1, (acc: Int, x: Int) ⇒ if (x == 10) throw TE else acc + x, restartingDecider))) { downstream.requestOne() lastEvents() should be(Set(OnNext(1))) @@ -383,7 +386,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnNext(25))) // 1 + 4 + 20 } - "restart when Conflate `seed` throws" in new TestSetup(Seq(Conflate( + "restart when Conflate `seed` throws" in new OneBoundedSetup[Int](Seq(Conflate( (in: Int) ⇒ if (in == 1) throw TE else in, (agg: Int, x: Int) ⇒ agg + x, restartingDecider))) { @@ -412,7 +415,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { lastEvents() should be(Set.empty) } - "restart when Conflate `aggregate` throws" in new TestSetup(Seq(Conflate( + "restart when Conflate `aggregate` throws" in new OneBoundedSetup[Int](Seq(Conflate( (in: Int) ⇒ in, (agg: Int, x: Int) ⇒ if (x == 2) throw TE else agg + x, restartingDecider))) { @@ -447,7 +450,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { lastEvents() should be(Set(Cancel)) } - "fail when Expand `seed` throws" in new TestSetup(Seq(Expand( + "fail when Expand `seed` throws" in new OneBoundedSetup[Int](Seq(Expand( (in: Int) ⇒ if (in == 2) throw TE else in, (agg: Int) ⇒ (agg, -math.abs(agg))))) { @@ -469,7 +472,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { lastEvents() should be(Set(OnError(TE), Cancel)) } - "fail when Expand `extrapolate` throws" in new TestSetup(Seq(Expand( + "fail when Expand `extrapolate` throws" in new OneBoundedSetup[Int](Seq(Expand( (in: Int) ⇒ in, (agg: Int) ⇒ if (agg == 2) throw TE else (agg, -math.abs(agg))))) { @@ -493,7 +496,7 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { "fail when onPull throws before pushing all generated elements" in { def test(decider: Supervision.Decider, absorbTermination: Boolean): Unit = { - new TestSetup(Seq( + new OneBoundedSetup[Int](Seq( OneToManyTestStage(decider, absorbTermination))) { downstream.requestOne() diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/LifecycleInterpreterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/LifecycleInterpreterSpec.scala index f85e45f73b..97e7b0de5f 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/LifecycleInterpreterSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/LifecycleInterpreterSpec.scala @@ -3,16 +3,18 @@ */ package akka.stream.impl.fusing +import akka.stream.stage._ +import akka.stream.testkit.AkkaSpec import akka.stream.testkit.Utils.TE import scala.concurrent.duration._ -class LifecycleInterpreterSpec extends InterpreterSpecKit { +class LifecycleInterpreterSpec extends GraphInterpreterSpecKit { import akka.stream.Supervision._ "Interpreter" must { - "call preStart in order on stages" in new TestSetup(Seq( + "call preStart in order on stages" in new OneBoundedSetup[String](Seq( PreStartAndPostStopIdentity(onStart = _ ⇒ testActor ! "start-a"), PreStartAndPostStopIdentity(onStart = _ ⇒ testActor ! "start-b"), PreStartAndPostStopIdentity(onStart = _ ⇒ testActor ! "start-c"))) { @@ -23,7 +25,7 @@ class LifecycleInterpreterSpec extends InterpreterSpecKit { upstream.onComplete() } - "call postStop in order on stages - when upstream completes" in new TestSetup(Seq( + "call postStop in order on stages - when upstream completes" in new OneBoundedSetup[String](Seq( PreStartAndPostStopIdentity(onUpstreamCompleted = () ⇒ testActor ! "complete-a", onStop = () ⇒ testActor ! "stop-a"), PreStartAndPostStopIdentity(onUpstreamCompleted = () ⇒ testActor ! "complete-b", onStop = () ⇒ testActor ! "stop-b"), PreStartAndPostStopIdentity(onUpstreamCompleted = () ⇒ testActor ! "complete-c", onStop = () ⇒ testActor ! "stop-c"))) { @@ -37,7 +39,7 @@ class LifecycleInterpreterSpec extends InterpreterSpecKit { expectNoMsg(300.millis) } - "call postStop in order on stages - when upstream onErrors" in new TestSetup(Seq( + "call postStop in order on stages - when upstream onErrors" in new OneBoundedSetup[String](Seq( PreStartAndPostStopIdentity( onUpstreamFailed = ex ⇒ testActor ! ex.getMessage, onStop = () ⇒ testActor ! "stop-c"))) { @@ -48,7 +50,7 @@ class LifecycleInterpreterSpec extends InterpreterSpecKit { expectNoMsg(300.millis) } - "call postStop in order on stages - when downstream cancels" in new TestSetup(Seq( + "call postStop in order on stages - when downstream cancels" in new OneBoundedSetup[String](Seq( PreStartAndPostStopIdentity(onStop = () ⇒ testActor ! "stop-a"), PreStartAndPostStopIdentity(onStop = () ⇒ testActor ! "stop-b"), PreStartAndPostStopIdentity(onStop = () ⇒ testActor ! "stop-c"))) { @@ -59,7 +61,7 @@ class LifecycleInterpreterSpec extends InterpreterSpecKit { expectNoMsg(300.millis) } - "call preStart before postStop" in new TestSetup(Seq( + "call preStart before postStop" in new OneBoundedSetup[String](Seq( PreStartAndPostStopIdentity(onStart = _ ⇒ testActor ! "start-a", onStop = () ⇒ testActor ! "stop-a"))) { expectMsg("start-a") expectNoMsg(300.millis) @@ -68,25 +70,25 @@ class LifecycleInterpreterSpec extends InterpreterSpecKit { expectNoMsg(300.millis) } - "onError when preStart fails" in new TestSetup(Seq( + "onError when preStart fails" in new OneBoundedSetup[String](Seq( PreStartFailer(() ⇒ throw TE("Boom!")))) { lastEvents() should ===(Set(Cancel, OnError(TE("Boom!")))) } - "not blow up when postStop fails" in new TestSetup(Seq( + "not blow up when postStop fails" in new OneBoundedSetup[String](Seq( PostStopFailer(() ⇒ throw TE("Boom!")))) { upstream.onComplete() lastEvents() should ===(Set(OnComplete)) } - "onError when preStart fails with stages after" in new TestSetup(Seq( + "onError when preStart fails with stages after" in new OneBoundedSetup[String](Seq( Map((x: Int) ⇒ x, stoppingDecider), PreStartFailer(() ⇒ throw TE("Boom!")), Map((x: Int) ⇒ x, stoppingDecider))) { lastEvents() should ===(Set(Cancel, OnError(TE("Boom!")))) } - "continue with stream shutdown when postStop fails" in new TestSetup(Seq( + "continue with stream shutdown when postStop fails" in new OneBoundedSetup[String](Seq( PostStopFailer(() ⇒ throw TE("Boom!")))) { lastEvents() should ===(Set()) @@ -94,7 +96,7 @@ class LifecycleInterpreterSpec extends InterpreterSpecKit { lastEvents should ===(Set(OnComplete)) } - "postStop when pushAndFinish called if upstream completes with pushAndFinish" in new TestSetup(Seq( + "postStop when pushAndFinish called if upstream completes with pushAndFinish" in new OneBoundedSetup[String](Seq( new PushFinishStage(onPostStop = () ⇒ testActor ! "stop"))) { lastEvents() should be(Set.empty) @@ -107,7 +109,7 @@ class LifecycleInterpreterSpec extends InterpreterSpecKit { expectMsg("stop") } - "postStop when pushAndFinish called with pushAndFinish if indirect upstream completes with pushAndFinish" in new TestSetup(Seq( + "postStop when pushAndFinish called with pushAndFinish if indirect upstream completes with pushAndFinish" in new OneBoundedSetup[String](Seq( Map((x: Any) ⇒ x, stoppingDecider), new PushFinishStage(onPostStop = () ⇒ testActor ! "stop"), Map((x: Any) ⇒ x, stoppingDecider))) { @@ -122,7 +124,7 @@ class LifecycleInterpreterSpec extends InterpreterSpecKit { expectMsg("stop") } - "postStop when pushAndFinish called with pushAndFinish if upstream completes with pushAndFinish and downstream immediately pulls" in new TestSetup(Seq( + "postStop when pushAndFinish called with pushAndFinish if upstream completes with pushAndFinish and downstream immediately pulls" in new OneBoundedSetup[String](Seq( new PushFinishStage(onPostStop = () ⇒ testActor ! "stop"), Fold("", (x: String, y: String) ⇒ x + y, stoppingDecider))) { @@ -138,4 +140,54 @@ class LifecycleInterpreterSpec extends InterpreterSpecKit { } + private[akka] case class PreStartAndPostStopIdentity[T]( + onStart: LifecycleContext ⇒ Unit = _ ⇒ (), + onStop: () ⇒ Unit = () ⇒ (), + onUpstreamCompleted: () ⇒ Unit = () ⇒ (), + onUpstreamFailed: Throwable ⇒ Unit = ex ⇒ ()) + extends PushStage[T, T] { + override def preStart(ctx: LifecycleContext) = onStart(ctx) + + override def onPush(elem: T, ctx: Context[T]) = ctx.push(elem) + + override def onUpstreamFinish(ctx: Context[T]): TerminationDirective = { + onUpstreamCompleted() + super.onUpstreamFinish(ctx) + } + + override def onUpstreamFailure(cause: Throwable, ctx: Context[T]): TerminationDirective = { + onUpstreamFailed(cause) + super.onUpstreamFailure(cause, ctx) + } + + override def postStop() = onStop() + } + + private[akka] case class PreStartFailer[T](pleaseThrow: () ⇒ Unit) extends PushStage[T, T] { + + override def preStart(ctx: LifecycleContext) = + pleaseThrow() + + override def onPush(elem: T, ctx: Context[T]) = ctx.push(elem) + } + + private[akka] case class PostStopFailer[T](ex: () ⇒ Throwable) extends PushStage[T, T] { + override def onUpstreamFinish(ctx: Context[T]) = ctx.finish() + override def onPush(elem: T, ctx: Context[T]) = ctx.push(elem) + + override def postStop(): Unit = throw ex() + } + + // This test is related to issue #17351 + private[akka] class PushFinishStage(onPostStop: () ⇒ Unit = () ⇒ ()) extends PushStage[Any, Any] { + override def onPush(elem: Any, ctx: Context[Any]): SyncDirective = + ctx.pushAndFinish(elem) + + override def onUpstreamFinish(ctx: Context[Any]): TerminationDirective = + ctx.fail(akka.stream.testkit.Utils.TE("Cannot happen")) + + override def postStop(): Unit = + onPostStop() + } + } diff --git a/akka-stream-tests/src/test/scala/akka/stream/io/InputStreamSinkSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/io/InputStreamSinkSpec.scala index 66f31eb1df..396afcebf9 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/io/InputStreamSinkSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/io/InputStreamSinkSpec.scala @@ -55,10 +55,10 @@ class InputStreamSinkSpec extends AkkaSpec(UnboundedMailboxConfig) { class InputStreamSinkTestStage(val timeout: FiniteDuration) extends InputStreamSinkStage(timeout) { - override def createLogicAndMaterializedValue = { - val (logic, inputStream) = super.createLogicAndMaterializedValue - val inHandler = logic.inHandlers(in.id) - logic.inHandlers(in.id) = new InHandler { + override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = { + val (logic, inputStream) = super.createLogicAndMaterializedValue(inheritedAttributes) + val inHandler = logic.handlers(in.id).asInstanceOf[InHandler] + logic.handlers(in.id) = new InHandler { override def onPush(): Unit = { probe.ref ! InputStreamSinkTestMessages.Push inHandler.onPush() diff --git a/akka-stream-tests/src/test/scala/akka/stream/io/OutputStreamSourceSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/io/OutputStreamSourceSpec.scala index 1290c59fdb..3b36ce87fb 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/io/OutputStreamSourceSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/io/OutputStreamSourceSpec.scala @@ -49,10 +49,10 @@ class OutputStreamSourceSpec extends AkkaSpec(UnboundedMailboxConfig) { class OutputStreamSourceTestStage(val timeout: FiniteDuration) extends OutputStreamSourceStage(timeout) { - override def createLogicAndMaterializedValue = { - val (logic, inputStream) = super.createLogicAndMaterializedValue - val outHandler = logic.outHandlers(out.id) - logic.outHandlers(out.id) = new OutHandler { + override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = { + val (logic, inputStream) = super.createLogicAndMaterializedValue(inheritedAttributes) + val outHandler = logic.handlers(out.id).asInstanceOf[OutHandler] + logic.handlers(out.id) = new OutHandler { override def onDownstreamFinish(): Unit = { probe.ref ! OutputStreamSourceTestMessages.Finish outHandler.onDownstreamFinish() diff --git a/akka-stream-tests/src/test/scala/akka/stream/io/SynchronousFileSinkSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/io/SynchronousFileSinkSpec.scala index ddaa3e6bcc..18e6aad9d0 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/io/SynchronousFileSinkSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/io/SynchronousFileSinkSpec.scala @@ -107,7 +107,9 @@ class SynchronousFileSinkSpec extends AkkaSpec(UnboundedMailboxConfig) { } } + // FIXME: overriding dispatcher should be made available with dispatcher alias support in materializer (#17929) "allow overriding the dispatcher using Attributes" in assertAllStagesStopped { + pending targetFile { f ⇒ val sys = ActorSystem("dispatcher-testing", UnboundedMailboxConfig) val mat = ActorMaterializer()(sys) diff --git a/akka-stream-tests/src/test/scala/akka/stream/io/SynchronousFileSourceSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/io/SynchronousFileSourceSpec.scala index 656a5bf328..df3a2ff65c 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/io/SynchronousFileSourceSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/io/SynchronousFileSourceSpec.scala @@ -180,7 +180,9 @@ class SynchronousFileSourceSpec extends AkkaSpec(UnboundedMailboxConfig) { } finally shutdown(sys) } + //FIXME: overriding dispatcher should be made available with dispatcher alias support in materializer (#17929) "allow overriding the dispatcher using Attributes" in { + pending val sys = ActorSystem("dispatcher-testing", UnboundedMailboxConfig) val mat = ActorMaterializer()(sys) implicit val timeout = Timeout(500.millis) diff --git a/akka-stream-tests/src/test/scala/akka/stream/io/TcpSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/io/TcpSpec.scala index e3503a7012..da0f8d34a6 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/io/TcpSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/io/TcpSpec.scala @@ -13,10 +13,11 @@ import akka.stream.testkit.Utils._ import akka.stream.testkit._ import akka.stream.{ ActorMaterializer, BindFailedException, StreamTcpException } import akka.util.{ ByteString, Helpers } - import scala.collection.immutable import scala.concurrent.{ Promise, Await } import scala.concurrent.duration._ +import java.net.BindException +import akka.testkit.EventFilter class TcpSpec extends AkkaSpec("akka.io.tcp.windows-connection-abort-workaround-enabled=auto\nakka.stream.materializer.subscription-timeout.timeout = 3s") with TcpHelper { var demand = 0L @@ -350,12 +351,14 @@ class TcpSpec extends AkkaSpec("akka.io.tcp.windows-connection-abort-workaround- conn.flow.join(writeButIgnoreRead).run() })(Keep.left).run(), 3.seconds) - val result = Source.maybe[ByteString] + val (promise, result) = Source.maybe[ByteString] .via(Tcp().outgoingConnection(serverAddress.getHostName, serverAddress.getPort)) - .runFold(ByteString.empty)(_ ++ _) + .toMat(Sink.fold(ByteString.empty)(_ ++ _))(Keep.both) + .run() Await.result(result, 3.seconds) should ===(ByteString("Early response")) + promise.success(None) // close client upstream, no more data binding.unbind() } @@ -453,7 +456,7 @@ class TcpSpec extends AkkaSpec("akka.io.tcp.windows-connection-abort-workaround- Await.result(echoServerFinish, 1.second) } - "bind and unbind correctly" in { + "bind and unbind correctly" in EventFilter[BindException](occurrences = 2).intercept { if (Helpers.isWindows) { info("On Windows unbinding is not immediate") pending diff --git a/akka-stream-tests/src/test/scala/akka/stream/io/TlsSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/io/TlsSpec.scala index 85289179c6..86bcaf1eea 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/io/TlsSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/io/TlsSpec.scala @@ -12,7 +12,7 @@ import scala.util.Random import akka.actor.ActorSystem import akka.pattern.{ after ⇒ later } -import akka.stream.{ ClosedShape, ActorMaterializer } +import akka.stream._ import akka.stream.scaladsl._ import akka.stream.stage._ import akka.stream.testkit._ @@ -52,35 +52,28 @@ object TlsSpec { * independent of the traffic going through. The purpose is to include the last seen * element in the exception message to help in figuring out what went wrong. */ - class Timeout(duration: FiniteDuration)(implicit system: ActorSystem) extends AsyncStage[ByteString, ByteString, Unit] { - private var last: ByteString = _ + class Timeout(duration: FiniteDuration)(implicit system: ActorSystem) extends GraphStage[FlowShape[ByteString, ByteString]] { - override def preStart(ctx: AsyncContext[ByteString, Unit]) = { - val cb = ctx.getAsyncCallback - system.scheduler.scheduleOnce(duration)(cb.invoke(()))(system.dispatcher) - } + private val in = Inlet[ByteString]("in") + private val out = Outlet[ByteString]("out") + override val shape = FlowShape(in, out) - override def onAsyncInput(u: Unit, ctx: AsyncContext[ByteString, Unit]) = - ctx.fail(new TimeoutException(s"timeout expired, last element was $last")) + override def createLogic(attr: Attributes) = new TimerGraphStageLogic(shape) { + override def preStart(): Unit = scheduleOnce((), duration) - override def onPush(elem: ByteString, ctx: AsyncContext[ByteString, Unit]) = { - last = elem - if (ctx.isHoldingDownstream) ctx.pushAndPull(elem) - else ctx.holdUpstream() - } - - override def onPull(ctx: AsyncContext[ByteString, Unit]) = - if (ctx.isFinishing) ctx.pushAndFinish(last) - else if (ctx.isHoldingUpstream) ctx.pushAndPull(last) - else ctx.holdDownstream() - - override def onUpstreamFinish(ctx: AsyncContext[ByteString, Unit]) = - if (ctx.isHoldingUpstream) ctx.absorbTermination() - else ctx.finish() - - override def onDownstreamFinish(ctx: AsyncContext[ByteString, Unit]) = { - system.log.debug("cancelled") - ctx.finish() + var last: ByteString = _ + setHandler(in, new InHandler { + override def onPush(): Unit = { + last = grab(in) + push(out, last) + } + }) + setHandler(out, new OutHandler { + override def onPull(): Unit = pull(in) + }) + override def onTimer(x: Any): Unit = { + failStage(new TimeoutException(s"timeout expired, last element was $last")) + } } } @@ -363,7 +356,7 @@ class TlsSpec extends AkkaSpec("akka.loglevel=INFO\nakka.actor.debug.receive=off .via(debug) .collect { case SessionBytes(_, b) ⇒ b } .scan(ByteString.empty)(_ ++ _) - .transform(() ⇒ new Timeout(6.seconds)) + .via(new Timeout(6.seconds)) .dropWhile(_.size < scenario.output.size) .runWith(Sink.head) diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/AttributesSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/AttributesSpec.scala new file mode 100644 index 0000000000..963679d857 --- /dev/null +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/AttributesSpec.scala @@ -0,0 +1,65 @@ +/** + * Copyright (C) 2015 Typesafe Inc. + */ +package akka.stream.scaladsl + +import akka.stream.ActorMaterializer +import akka.stream.ActorMaterializerSettings +import akka.stream.Attributes +import akka.stream.Attributes._ +import akka.stream.MaterializationContext +import akka.stream.SinkShape +import akka.stream.testkit._ +import scala.concurrent.Future +import scala.concurrent.Promise +import akka.stream.impl.SinkModule +import akka.stream.impl.StreamLayout.Module +import org.scalatest.concurrent.ScalaFutures +import akka.stream.impl.BlackholeSubscriber + +object AttributesSpec { + + object AttributesSink { + def apply(): Sink[Nothing, Future[Attributes]] = + new Sink(new AttributesSink(Attributes.name("attributesSink"), Sink.shape("attributesSink"))) + } + + final class AttributesSink(val attributes: Attributes, shape: SinkShape[Nothing]) extends SinkModule[Nothing, Future[Attributes]](shape) { + override def create(context: MaterializationContext) = + (new BlackholeSubscriber(0, Promise()), Future.successful(context.effectiveAttributes)) + + override protected def newInstance(shape: SinkShape[Nothing]): SinkModule[Nothing, Future[Attributes]] = + new AttributesSink(attributes, shape) + + override def withAttributes(attr: Attributes): Module = + new AttributesSink(attr, amendShape(attr)) + } + +} + +class AttributesSpec extends AkkaSpec with ScalaFutures { + import AttributesSpec._ + + val settings = ActorMaterializerSettings(system) + .withInputBuffer(initialSize = 2, maxSize = 16) + + implicit val materializer = ActorMaterializer(settings) + + "attributes" must { + + "be overridable on a module basis" in { + val runnable = Source.empty.toMat(AttributesSink().withAttributes(Attributes.name("new-name")))(Keep.right) + whenReady(runnable.run()) { attributes ⇒ + attributes.get[Name] should contain(Name("new-name")) + } + } + + "keep the outermost attribute as the least specific" in { + val runnable = Source.empty.toMat(AttributesSink())(Keep.right).withAttributes(Attributes.name("new-name")) + whenReady(runnable.run()) { attributes ⇒ + attributes.get[Name] should contain(Name("attributesSink")) + } + } + } + +} diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowMapAsyncSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowMapAsyncSpec.scala index ecce6e93fd..faf1ee5964 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowMapAsyncSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowMapAsyncSpec.scala @@ -21,46 +21,13 @@ import scala.util.Try import scala.concurrent.ExecutionContext import scala.util.Failure import scala.util.Success +import scala.annotation.tailrec +import scala.concurrent.Promise +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.LinkedBlockingQueue +import org.scalatest.concurrent.ScalaFutures -object FlowMapAsyncSpec { - class MapAsyncOne[In, Out](f: In ⇒ Future[Out])(implicit ec: ExecutionContext) extends AsyncStage[In, Out, Try[Out]] { - private var elemInFlight: Out = _ - - override def onPush(elem: In, ctx: AsyncContext[Out, Try[Out]]) = { - val future = f(elem) - val cb = ctx.getAsyncCallback - future.onComplete(cb.invoke) - ctx.holdUpstream() - } - - override def onPull(ctx: AsyncContext[Out, Try[Out]]) = - if (elemInFlight != null) { - val e = elemInFlight - elemInFlight = null.asInstanceOf[Out] - pushIt(e, ctx) - } else ctx.holdDownstream() - - override def onAsyncInput(input: Try[Out], ctx: AsyncContext[Out, Try[Out]]) = - input match { - case Failure(ex) ⇒ ctx.fail(ex) - case Success(e) if ctx.isHoldingDownstream ⇒ pushIt(e, ctx) - case Success(e) ⇒ - elemInFlight = e - ctx.ignore() - } - - override def onUpstreamFinish(ctx: AsyncContext[Out, Try[Out]]) = - if (ctx.isHoldingUpstream) ctx.absorbTermination() - else ctx.finish() - - private def pushIt(elem: Out, ctx: AsyncContext[Out, Try[Out]]) = - if (ctx.isFinishing) ctx.pushAndFinish(elem) - else ctx.pushAndPull(elem) - } -} - -class FlowMapAsyncSpec extends AkkaSpec { - import FlowMapAsyncSpec._ +class FlowMapAsyncSpec extends AkkaSpec with ScalaFutures { implicit val materializer = ActorMaterializer() @@ -102,11 +69,9 @@ class FlowMapAsyncSpec extends AkkaSpec { n }).to(Sink(c)).run() val sub = c.expectSubscription() - // running 8 in parallel - probe.receiveN(8).toSet should be((1 to 8).toSet) probe.expectNoMsg(500.millis) sub.request(1) - probe.expectMsg(9) + probe.receiveN(9).toSet should be((1 to 9).toSet) probe.expectNoMsg(500.millis) sub.request(2) probe.receiveN(2).toSet should be(Set(10, 11)) @@ -246,44 +211,50 @@ class FlowMapAsyncSpec extends AkkaSpec { } - } + "not run more futures than configured" in assertAllStagesStopped { + val parallelism = 8 - "A MapAsyncOne" must { - import system.dispatcher + val counter = new AtomicInteger + val queue = new LinkedBlockingQueue[(Promise[Int], Long)] - "work in the happy case" in { - val probe = TestProbe() - val N = 100 - val f = Source(1 to N).transform(() ⇒ new MapAsyncOne(i ⇒ { - probe.ref ! i - Future { Thread.sleep(10); probe.ref ! (i + 10); i * 2 } - })).grouped(N + 10).runWith(Sink.head) - Await.result(f, 2.seconds) should ===((1 to N).map(_ * 2)) - probe.receiveN(2 * N) should ===((1 to N).flatMap(x ⇒ List(x, x + 10))) - probe.expectNoMsg(100.millis) - } + val timer = new Thread { + val delay = 50000 // nanoseconds + var count = 0 + @tailrec final override def run(): Unit = { + val cont = try { + val (promise, enqueued) = queue.take() + val wakeup = enqueued + delay + while (System.nanoTime() < wakeup) {} + counter.decrementAndGet() + promise.success(count) + count += 1 + true + } catch { + case _: InterruptedException ⇒ false + } + if (cont) run() + } + } + timer.start - "work when futures fail" in { - val probe = TestSubscriber.manualProbe[Int] - val ex = new Exception("KABOOM") - Source.single(1) - .transform(() ⇒ new MapAsyncOne(_ ⇒ Future.failed(ex))) - .runWith(Sink(probe)) - val sub = probe.expectSubscription() - sub.request(1) - probe.expectError(ex) - } + def deferred(): Future[Int] = { + if (counter.incrementAndGet() > parallelism) Future.failed(new Exception("parallelism exceeded")) + else { + val p = Promise[Int] + queue.offer(p -> System.nanoTime()) + p.future + } + } - "work when futures fail later" in { - val probe = TestSubscriber.manualProbe[Int] - val ex = new Exception("KABOOM") - Source(List(1, 2)) - .transform(() ⇒ new MapAsyncOne(x ⇒ if (x == 1) Future.successful(1) else Future.failed(ex))) - .runWith(Sink(probe)) - val sub = probe.expectSubscription() - sub.request(1) - probe.expectNext(1) - probe.expectError(ex) + try { + val N = 100000 + Source(1 to N) + .mapAsync(parallelism)(i ⇒ deferred()) + .runFold(0)((c, _) ⇒ c + 1) + .futureValue(PatienceConfig(3.seconds)) should ===(N) + } finally { + timer.interrupt() + } } } diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowMapAsyncUnorderedSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowMapAsyncUnorderedSpec.scala index b22737d70b..09173520cb 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowMapAsyncUnorderedSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowMapAsyncUnorderedSpec.scala @@ -7,7 +7,6 @@ import scala.concurrent.Await import scala.concurrent.Future import scala.concurrent.duration._ import scala.util.control.NoStackTrace - import akka.stream.ActorMaterializer import akka.stream.testkit._ import akka.stream.testkit.scaladsl._ @@ -17,8 +16,15 @@ import akka.testkit.TestProbe import akka.stream.ActorAttributes.supervisionStrategy import akka.stream.Supervision.resumingDecider import akka.stream.impl.ReactiveStreamsCompliance +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.ConcurrentLinkedQueue +import scala.concurrent.Promise +import java.util.concurrent.LinkedBlockingQueue +import scala.annotation.tailrec +import org.scalatest.concurrent.ScalaFutures +import org.scalactic.ConversionCheckedTripleEquals -class FlowMapAsyncUnorderedSpec extends AkkaSpec { +class FlowMapAsyncUnorderedSpec extends AkkaSpec with ScalaFutures with ConversionCheckedTripleEquals { implicit val materializer = ActorMaterializer() @@ -54,13 +60,11 @@ class FlowMapAsyncUnorderedSpec extends AkkaSpec { n }).to(Sink(c)).run() val sub = c.expectSubscription() - // first four run immediately - probe.expectMsgAllOf(1, 2, 3, 4) c.expectNoMsg(200.millis) probe.expectNoMsg(Duration.Zero) sub.request(1) var got = Set(c.expectNext()) - probe.expectMsg(5) + probe.expectMsgAllOf(1, 2, 3, 4, 5) probe.expectNoMsg(500.millis) sub.request(25) probe.expectMsgAllOf(6 to 20: _*) @@ -176,11 +180,11 @@ class FlowMapAsyncUnorderedSpec extends AkkaSpec { .to(Sink(c)).run() val sub = c.expectSubscription() sub.request(10) - for (elem ← List("a", "c")) c.expectNext(elem) + c.expectNextUnordered("a", "c") c.expectComplete() } - "should handle cancel properly" in assertAllStagesStopped { + "handle cancel properly" in assertAllStagesStopped { val pub = TestPublisher.manualProbe[Int]() val sub = TestSubscriber.manualProbe[Int]() @@ -195,5 +199,51 @@ class FlowMapAsyncUnorderedSpec extends AkkaSpec { } + "not run more futures than configured" in assertAllStagesStopped { + val parallelism = 8 + + val counter = new AtomicInteger + val queue = new LinkedBlockingQueue[(Promise[Int], Long)] + + val timer = new Thread { + val delay = 50000 // nanoseconds + var count = 0 + @tailrec final override def run(): Unit = { + val cont = try { + val (promise, enqueued) = queue.take() + val wakeup = enqueued + delay + while (System.nanoTime() < wakeup) {} + counter.decrementAndGet() + promise.success(count) + count += 1 + true + } catch { + case _: InterruptedException ⇒ false + } + if (cont) run() + } + } + timer.start + + def deferred(): Future[Int] = { + if (counter.incrementAndGet() > parallelism) Future.failed(new Exception("parallelism exceeded")) + else { + val p = Promise[Int] + queue.offer(p -> System.nanoTime()) + p.future + } + } + + try { + val N = 100000 + Source(1 to N) + .mapAsyncUnordered(parallelism)(i ⇒ deferred()) + .runFold(0)((c, _) ⇒ c + 1) + .futureValue(PatienceConfig(3.seconds)) should ===(N) + } finally { + timer.interrupt() + } + } + } } diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowSpec.scala index e64c550c49..d05cd1f863 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowSpec.scala @@ -6,12 +6,14 @@ package akka.stream.scaladsl import akka.actor._ import akka.stream.Supervision._ import akka.stream.impl._ -import akka.stream.impl.fusing.ActorInterpreter -import akka.stream.stage.Stage +import akka.stream.impl.fusing.{ ActorGraphInterpreter } +import akka.stream.impl.fusing.GraphInterpreter.GraphAssembly +import akka.stream.stage.AbstractStage.PushPullGraphStage +import akka.stream.stage.{ GraphStageLogic, OutHandler, InHandler, Stage } import akka.stream.testkit.Utils._ import akka.stream.testkit._ import akka.stream.testkit.scaladsl.TestSink -import akka.stream.{ AbruptTerminationException, ActorMaterializer, ActorMaterializerSettings, Attributes } +import akka.stream._ import akka.testkit.TestEvent.{ Mute, UnMute } import akka.testkit.{ EventFilter, TestDuration } import com.typesafe.config.ConfigFactory @@ -32,7 +34,7 @@ class FlowSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.debug.rece import FlowSpec._ val settings = ActorMaterializerSettings(system) - .withInputBuffer(initialSize = 2, maxSize = 16) + .withInputBuffer(initialSize = 2, maxSize = 2) implicit val mat = ActorMaterializer(settings) @@ -40,30 +42,51 @@ class FlowSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.debug.rece val identity2: Flow[Any, Any, _] ⇒ Flow[Any, Any, _] = in ⇒ identity(in) class BrokenActorInterpreter( + _assembly: GraphAssembly, + _inHandlers: Array[InHandler], + _outHandlers: Array[OutHandler], + _logics: Array[GraphStageLogic], + _shape: Shape, _settings: ActorMaterializerSettings, - _ops: Seq[Stage[_, _]], + _mat: Materializer, brokenMessage: Any) - extends ActorInterpreter(_settings, _ops, mat, Attributes.none) { + extends ActorGraphInterpreter(_assembly, _inHandlers, _outHandlers, _logics, _shape, _settings, _mat) { import akka.stream.actor.ActorSubscriberMessage._ override protected[akka] def aroundReceive(receive: Receive, msg: Any) = { msg match { - case OnNext(m) if m == brokenMessage ⇒ + case ActorGraphInterpreter.OnNext(0, m) if m == brokenMessage ⇒ throw new NullPointerException(s"I'm so broken [$m]") case _ ⇒ super.aroundReceive(receive, msg) } } } - val faultyFlow: Flow[Any, Any, _] ⇒ Flow[Any, Any, _] = in ⇒ in.andThenMat { () ⇒ - val props = Props(new BrokenActorInterpreter(settings, List(fusing.Map({ x: Any ⇒ x }, stoppingDecider)), "a3")) + val faultyFlow: Flow[Any, Any, _] ⇒ Flow[Any, Any, _] = in ⇒ in.via({ + val stage = new PushPullGraphStage((_) ⇒ fusing.Map({ x: Any ⇒ x }, stoppingDecider), Attributes.none) + + val assembly = new GraphAssembly( + Array(stage), + Array(Attributes.none), + Array(stage.shape.inlet, null), + Array(0, -1), + Array(null, stage.shape.outlet), + Array(-1, 0)) + + val (inHandlers, outHandlers, logics, _) = assembly.materialize(Attributes.none) + + val props = Props(new BrokenActorInterpreter(assembly, inHandlers, outHandlers, logics, stage.shape, settings, mat, "a3")) .withDispatcher("akka.test.stream-dispatcher").withDeploy(Deploy.local) - val processor = ActorProcessorFactory[Any, Any](system.actorOf( - props, - "borken-stage-actor")) - (processor, ()) - } + val impl = system.actorOf(props, "borken-stage-actor") + + val subscriber = new ActorGraphInterpreter.BoundarySubscriber(impl, 0) + val publisher = new ActorPublisher[Any](impl) { override val wakeUpMsg = ActorGraphInterpreter.SubscribePending(0) } + + impl ! ActorGraphInterpreter.ExposedPublisher(0, publisher) + + Flow.fromSinkAndSource(Sink(subscriber), Source(publisher)) + }) val toPublisher: (Source[Any, _], ActorMaterializer) ⇒ Publisher[Any] = (f, m) ⇒ f.runWith(Sink.publisher)(m) @@ -81,15 +104,15 @@ class FlowSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.debug.rece for ((name, op) ← List("identity" -> identity, "identity2" -> identity2); n ← List(1, 2, 4)) { s"request initial elements from upstream ($name, $n)" in { - new ChainSetup(op, settings.withInputBuffer(initialSize = n, maxSize = settings.maxInputBufferSize), toPublisher) { - upstream.expectRequest(upstreamSubscription, settings.initialInputBufferSize) + new ChainSetup(op, settings.withInputBuffer(initialSize = n, maxSize = n), toPublisher) { + upstream.expectRequest(upstreamSubscription, settings.maxInputBufferSize) } } } "request more elements from upstream when downstream requests more elements" in { new ChainSetup(identity, settings, toPublisher) { - upstream.expectRequest(upstreamSubscription, settings.initialInputBufferSize) + upstream.expectRequest(upstreamSubscription, settings.maxInputBufferSize) downstreamSubscription.request(1) upstream.expectNoMsg(100.millis) downstreamSubscription.request(2) @@ -132,7 +155,7 @@ class FlowSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.debug.rece } "cancel upstream when single subscriber cancels subscription while receiving data" in { - new ChainSetup(identity, settings.withInputBuffer(initialSize = 1, maxSize = settings.maxInputBufferSize), toPublisher) { + new ChainSetup(identity, settings.withInputBuffer(initialSize = 1, maxSize = 1), toPublisher) { downstreamSubscription.request(5) upstreamSubscription.expectRequest(1) upstreamSubscription.sendNext("test") @@ -291,7 +314,7 @@ class FlowSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.debug.rece "be possible to convert to a processor, and should be able to take a Processor" in { val identity1 = Flow[Int].toProcessor - val identity2 = Flow(() ⇒ identity1.run()) + val identity2 = Flow.fromProcessor(() ⇒ identity1.run()) Await.result( Source(1 to 10).via(identity2).grouped(100).runWith(Sink.head), 3.seconds) should ===(1 to 10) diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphPreferredMergeSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMergePreferredSpec.scala similarity index 97% rename from akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphPreferredMergeSpec.scala rename to akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMergePreferredSpec.scala index 992648d5f7..6acb3625e2 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphPreferredMergeSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphMergePreferredSpec.scala @@ -9,7 +9,7 @@ import akka.stream._ import scala.concurrent.Await import scala.concurrent.duration._ -class GraphPreferredMergeSpec extends TwoStreamsSetup { +class GraphMergePreferredSpec extends TwoStreamsSetup { import FlowGraph.Implicits._ override type Outputs = Int diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphStageTimersSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphStageTimersSpec.scala index 1f1343389b..30ba5cd735 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphStageTimersSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/GraphStageTimersSpec.scala @@ -1,9 +1,9 @@ package akka.stream.scaladsl import akka.actor.ActorRef -import akka.stream.ActorMaterializer +import akka.stream.{ Attributes, ActorMaterializer } import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage -import akka.stream.stage.{ OutHandler, AsyncCallback, InHandler } +import akka.stream.stage.{ TimerGraphStageLogic, OutHandler, AsyncCallback, InHandler } import akka.stream.testkit.{ AkkaSpec, TestPublisher } import akka.testkit.TestDuration @@ -39,13 +39,17 @@ class GraphStageTimersSpec extends AkkaSpec { implicit val mat = ActorMaterializer() class TestStage(probe: ActorRef, sideChannel: SideChannel) extends SimpleLinearGraphStage[Int] { - override def createLogic = new SimpleLinearStageLogic { + override def createLogic(inheritedAttributes: Attributes) = new TimerGraphStageLogic(shape) { val tickCount = Iterator from 1 setHandler(in, new InHandler { override def onPush() = push(out, grab(in)) }) + setHandler(out, new OutHandler { + override def onPull(): Unit = pull(in) + }) + override def preStart() = { sideChannel.asyncCallback = getAsyncCallback(onTestEvent) } @@ -147,7 +151,7 @@ class GraphStageTimersSpec extends AkkaSpec { } class TestStage2 extends SimpleLinearGraphStage[Int] { - override def createLogic = new SimpleLinearStageLogic { + override def createLogic(inheritedAttributes: Attributes) = new TimerGraphStageLogic(shape) { var tickCount = 0 override def preStart(): Unit = schedulePeriodically("tick", 100.millis) @@ -195,13 +199,17 @@ class GraphStageTimersSpec extends AkkaSpec { val downstream = TestSubscriber.probe[Int]() Source(upstream).via(new SimpleLinearGraphStage[Int] { - override def createLogic = new SimpleLinearStageLogic { + override def createLogic(inheritedAttributes: Attributes) = new TimerGraphStageLogic(shape) { override def preStart(): Unit = scheduleOnce("tick", 100.millis) setHandler(in, new InHandler { override def onPush() = () // Ingore }) + setHandler(out, new OutHandler { + override def onPull(): Unit = pull(in) + }) + override def onTimer(timerKey: Any) = throw exception } }).runWith(Sink(downstream)) diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala index 26730799c3..a2ec35a06e 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SourceSpec.scala @@ -73,7 +73,7 @@ class SourceSpec extends AkkaSpec { } "Maybe Source" must { - "complete materialized future with None when stream cancels" in { + "complete materialized future with None when stream cancels" in Utils.assertAllStagesStopped { val neverSource = Source.maybe[Int] val pubSink = Sink.publisher[Int] @@ -90,7 +90,7 @@ class SourceSpec extends AkkaSpec { Await.result(f.future, 500.millis) shouldEqual None } - "allow external triggering of empty completion" in { + "allow external triggering of empty completion" in Utils.assertAllStagesStopped { val neverSource = Source.maybe[Int].filter(_ ⇒ false) val counterSink = Sink.fold[Int, Int](0) { (acc, _) ⇒ acc + 1 } @@ -102,7 +102,7 @@ class SourceSpec extends AkkaSpec { Await.result(counterFuture, 500.millis) shouldEqual 0 } - "allow external triggering of non-empty completion" in { + "allow external triggering of non-empty completion" in Utils.assertAllStagesStopped { val neverSource = Source.maybe[Int] val counterSink = Sink.head[Int] @@ -114,7 +114,7 @@ class SourceSpec extends AkkaSpec { Await.result(counterFuture, 500.millis) shouldEqual 6 } - "allow external triggering of onError" in { + "allow external triggering of onError" in Utils.assertAllStagesStopped { val neverSource = Source.maybe[Int] val counterSink = Sink.fold[Int, Int](0) { (acc, _) ⇒ acc + 1 } diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SubscriberSourceSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SubscriberSourceSpec.scala index 29d2d82aa1..4bed273a0d 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SubscriberSourceSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SubscriberSourceSpec.scala @@ -11,7 +11,7 @@ import scala.concurrent.duration._ import scala.concurrent.Await -class SubscriberSourceSpec extends AkkaSpec("akka.loglevel=DEBUG\nakka.actor.debug.lifecycle=on") { +class SubscriberSourceSpec extends AkkaSpec { implicit val materializer = ActorMaterializer() diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SubstreamSubscriptionTimeoutSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SubstreamSubscriptionTimeoutSpec.scala index a49c67bb7c..2d09cb2e7c 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SubstreamSubscriptionTimeoutSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/SubstreamSubscriptionTimeoutSpec.scala @@ -150,7 +150,7 @@ class SubstreamSubscriptionTimeoutSpec(conf: String) extends AkkaSpec(conf) { private def watchGroupByActor(flowNr: Int): ActorRef = { implicit val t = Timeout(300.millis) import akka.pattern.ask - val path = s"/user/$$a/flow-${flowNr}-1-publisherSource-groupBy" + val path = s"/user/$$a/flow-${flowNr}-1-groupBy" val gropByPath = system.actorSelection(path) val groupByActor = try { Await.result((gropByPath ? Identify("")).mapTo[ActorIdentity], 300.millis).ref.get diff --git a/akka-stream/src/main/boilerplate/akka/stream/scaladsl/UnzipWithApply.scala.template b/akka-stream/src/main/boilerplate/akka/stream/scaladsl/UnzipWithApply.scala.template index 10ac47fa63..5a5eba8b21 100644 --- a/akka-stream/src/main/boilerplate/akka/stream/scaladsl/UnzipWithApply.scala.template +++ b/akka-stream/src/main/boilerplate/akka/stream/scaladsl/UnzipWithApply.scala.template @@ -49,7 +49,7 @@ class UnzipWith1[In, [#A1#]](unzipper: In ⇒ ([#A1#])) extends GraphStage[FanOu [#def out0: Outlet[A1] = shape.out0# ] - override def createLogic: GraphStageLogic = new GraphStageLogic(shape) { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { var pendingCount = 1 var downstreamRunning = 1 diff --git a/akka-stream/src/main/boilerplate/akka/stream/scaladsl/ZipWithApply.scala.template b/akka-stream/src/main/boilerplate/akka/stream/scaladsl/ZipWithApply.scala.template index 9f40ba4c8b..03af15507f 100644 --- a/akka-stream/src/main/boilerplate/akka/stream/scaladsl/ZipWithApply.scala.template +++ b/akka-stream/src/main/boilerplate/akka/stream/scaladsl/ZipWithApply.scala.template @@ -30,7 +30,7 @@ class ZipWith1[[#A1#], O] (zipper: ([#A1#]) ⇒ O) extends GraphStage[FanInShape [#val in0: Inlet[A1] = shape.in0# ] - override def createLogic: GraphStageLogic = new GraphStageLogic(shape) { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { var pending = 1 private def pushAll(): Unit = push(out, zipper([#grab(in0)#])) diff --git a/akka-stream/src/main/scala/akka/stream/ActorMaterializer.scala b/akka-stream/src/main/scala/akka/stream/ActorMaterializer.scala index 254ee46aac..e072cfd40e 100644 --- a/akka-stream/src/main/scala/akka/stream/ActorMaterializer.scala +++ b/akka-stream/src/main/scala/akka/stream/ActorMaterializer.scala @@ -31,11 +31,11 @@ object ActorMaterializer { * the processing steps. The default `namePrefix` is `"flow"`. The actor names are built up of * `namePrefix-flowNumber-flowStepNumber-stepName`. */ - def apply(materializerSettings: Option[ActorMaterializerSettings] = None, namePrefix: Option[String] = None, optimizations: Optimizations = Optimizations.none)(implicit context: ActorRefFactory): ActorMaterializer = { + def apply(materializerSettings: Option[ActorMaterializerSettings] = None, namePrefix: Option[String] = None)(implicit context: ActorRefFactory): ActorMaterializer = { val system = actorSystemOf(context) val settings = materializerSettings getOrElse ActorMaterializerSettings(system) - apply(settings, namePrefix.getOrElse("flow"), optimizations)(context) + apply(settings, namePrefix.getOrElse("flow"))(context) } /** @@ -49,7 +49,7 @@ object ActorMaterializer { * the processing steps. The default `namePrefix` is `"flow"`. The actor names are built up of * `namePrefix-flowNumber-flowStepNumber-stepName`. */ - def apply(materializerSettings: ActorMaterializerSettings, namePrefix: String, optimizations: Optimizations)(implicit context: ActorRefFactory): ActorMaterializer = { + def apply(materializerSettings: ActorMaterializerSettings, namePrefix: String)(implicit context: ActorRefFactory): ActorMaterializer = { val haveShutDown = new AtomicBoolean(false) val system = actorSystemOf(context) @@ -60,8 +60,7 @@ object ActorMaterializer { context.actorOf(StreamSupervisor.props(materializerSettings, haveShutDown).withDispatcher(materializerSettings.dispatcher)), haveShutDown, FlowNameCounter(system).counter, - namePrefix, - optimizations) + namePrefix) } /** @@ -199,11 +198,10 @@ object ActorMaterializerSettings { supervisionDecider: Supervision.Decider, subscriptionTimeoutSettings: StreamSubscriptionTimeoutSettings, debugLogging: Boolean, - outputBurstLimit: Int, - optimizations: Optimizations) = + outputBurstLimit: Int) = new ActorMaterializerSettings( initialInputBufferSize, maxInputBufferSize, dispatcher, supervisionDecider, subscriptionTimeoutSettings, debugLogging, - outputBurstLimit, optimizations) + outputBurstLimit) /** * Create [[ActorMaterializerSettings]]. @@ -228,8 +226,7 @@ object ActorMaterializerSettings { supervisionDecider = Supervision.stoppingDecider, subscriptionTimeoutSettings = StreamSubscriptionTimeoutSettings(config), debugLogging = config.getBoolean("debug-logging"), - outputBurstLimit = config.getInt("output-burst-limit"), - optimizations = Optimizations.none) + outputBurstLimit = config.getInt("output-burst-limit")) /** * Java API @@ -263,8 +260,7 @@ final class ActorMaterializerSettings( val supervisionDecider: Supervision.Decider, val subscriptionTimeoutSettings: StreamSubscriptionTimeoutSettings, val debugLogging: Boolean, - val outputBurstLimit: Int, - val optimizations: Optimizations) { + val outputBurstLimit: Int) { require(initialInputBufferSize > 0, "initialInputBufferSize must be > 0") @@ -278,11 +274,10 @@ final class ActorMaterializerSettings( supervisionDecider: Supervision.Decider = this.supervisionDecider, subscriptionTimeoutSettings: StreamSubscriptionTimeoutSettings = this.subscriptionTimeoutSettings, debugLogging: Boolean = this.debugLogging, - outputBurstLimit: Int = this.outputBurstLimit, - optimizations: Optimizations = this.optimizations) = + outputBurstLimit: Int = this.outputBurstLimit) = new ActorMaterializerSettings( initialInputBufferSize, maxInputBufferSize, dispatcher, supervisionDecider, subscriptionTimeoutSettings, debugLogging, - outputBurstLimit, optimizations) + outputBurstLimit) def withInputBuffer(initialSize: Int, maxSize: Int): ActorMaterializerSettings = copy(initialInputBufferSize = initialSize, maxInputBufferSize = maxSize) @@ -316,9 +311,6 @@ final class ActorMaterializerSettings( def withDebugLogging(enable: Boolean): ActorMaterializerSettings = copy(debugLogging = enable) - def withOptimizations(optimizations: Optimizations): ActorMaterializerSettings = - copy(optimizations = optimizations) - private def requirePowerOfTwo(n: Integer, name: String): Unit = { require(n > 0, s"$name must be > 0") require((n & (n - 1)) == 0, s"$name must be a power of two") @@ -365,11 +357,3 @@ object StreamSubscriptionTimeoutTerminationMode { } -final object Optimizations { - val none: Optimizations = Optimizations(collapsing = false, elision = false, simplification = false, fusion = false) - val all: Optimizations = Optimizations(collapsing = true, elision = true, simplification = true, fusion = true) -} - -final case class Optimizations(collapsing: Boolean, elision: Boolean, simplification: Boolean, fusion: Boolean) { - def isEnabled: Boolean = collapsing || elision || simplification || fusion -} diff --git a/akka-stream/src/main/scala/akka/stream/Attributes.scala b/akka-stream/src/main/scala/akka/stream/Attributes.scala index 1732717140..0863a2b2cd 100644 --- a/akka-stream/src/main/scala/akka/stream/Attributes.scala +++ b/akka-stream/src/main/scala/akka/stream/Attributes.scala @@ -4,10 +4,10 @@ package akka.stream import akka.event.Logging - import scala.annotation.tailrec import scala.collection.immutable -import akka.stream.impl.Stages.StageModule +import scala.reflect.{ classTag, ClassTag } +import akka.stream.impl.Stages.SymbolicStage import akka.japi.function /** @@ -16,7 +16,7 @@ import akka.japi.function * * Note that more attributes for the [[ActorMaterializer]] are defined in [[ActorAttributes]]. */ -final case class Attributes private (attributeList: immutable.Seq[Attributes.Attribute] = Nil) { +final case class Attributes(attributeList: List[Attributes.Attribute] = Nil) { import Attributes._ @@ -44,22 +44,47 @@ final case class Attributes private (attributeList: immutable.Seq[Attributes.Att } /** - * Get first attribute of a given `Class` or subclass thereof. + * Java API: Get the last (most specific) attribute of a given `Class` or subclass thereof. * If no such attribute exists the `default` value is returned. */ def getAttribute[T <: Attribute](c: Class[T], default: T): T = - attributeList.find(c.isInstance) match { - case Some(a) ⇒ c.cast(a) + getAttribute(c) match { + case Some(a) ⇒ a case None ⇒ default } + /** + * Java API: Get the last (most specific) attribute of a given `Class` or subclass thereof. + */ + def getAttribute[T <: Attribute](c: Class[T]): Option[T] = + Option(attributeList.foldLeft(null.asInstanceOf[T])((acc, attr) ⇒ if (c.isInstance(attr)) c.cast(attr) else acc)) + + /** + * Get the last (most specific) attribute of a given type parameter T `Class` or subclass thereof. + * If no such attribute exists the `default` value is returned. + */ + def get[T <: Attribute: ClassTag](default: T) = + getAttribute(classTag[T].runtimeClass.asInstanceOf[Class[T]], default) + + /** + * Get the last (most specific) attribute of a given type parameter T `Class` or subclass thereof. + */ + def get[T <: Attribute: ClassTag] = + getAttribute(classTag[T].runtimeClass.asInstanceOf[Class[T]]) + /** * Adds given attributes to the end of these attributes. */ def and(other: Attributes): Attributes = if (attributeList.isEmpty) other else if (other.attributeList.isEmpty) this - else Attributes(attributeList ++ other.attributeList) + else Attributes(attributeList ::: other.attributeList) + + /** + * Adds given attribute to the end of these attributes. + */ + def and(other: Attribute): Attributes = + Attributes(attributeList :+ other) /** * INTERNAL API @@ -76,12 +101,7 @@ final case class Attributes private (attributeList: immutable.Seq[Attributes.Att case Name(n) ⇒ if (buf ne null) concatNames(i, null, buf.append('-').append(n)) else if (first ne null) { - val b = new StringBuilder( - (first.length() + n.length()) match { - case x if x < 0 ⇒ throw new IllegalStateException("Names too long to concatenate") - case y if y > Int.MaxValue / 2 ⇒ Int.MaxValue - case z ⇒ Math.max(Integer.highestOneBit(z) * 2, 32) - }) + val b = new StringBuilder((first.length() + n.length()) * 2) concatNames(i, null, b.append(first).append('-').append(n)) } else concatNames(i, n, null) case _ ⇒ concatNames(i, first, buf) @@ -95,16 +115,6 @@ final case class Attributes private (attributeList: immutable.Seq[Attributes.Att } } - /** - * INTERNAL API - */ - private[akka] def logLevels: Option[LogLevels] = - attributeList.collectFirst { case l: LogLevels ⇒ l } - - private[akka] def transform(node: StageModule): StageModule = - if ((this eq Attributes.none) || (this eq node.attributes)) node - else node.withAttributes(attributes = this and node.attributes) - } /** diff --git a/akka-stream/src/main/scala/akka/stream/impl/ActorMaterializerImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/ActorMaterializerImpl.scala index 980bf20ebf..6362def624 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/ActorMaterializerImpl.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/ActorMaterializerImpl.scala @@ -10,7 +10,7 @@ import akka.dispatch.Dispatchers import akka.pattern.ask import akka.stream.actor.ActorSubscriber import akka.stream.impl.StreamLayout.Module -import akka.stream.impl.fusing.{ ActorGraphInterpreter, GraphModule, ActorInterpreter } +import akka.stream.impl.fusing.{ ActorGraphInterpreter, GraphModule } import akka.stream.impl.io.SslTlsCipherActor import akka.stream._ import akka.stream.io.SslTls.TlsModule @@ -24,14 +24,13 @@ import scala.concurrent.{ Await, ExecutionContextExecutor } /** * INTERNAL API */ -private[akka] case class ActorMaterializerImpl(val system: ActorSystem, +private[akka] case class ActorMaterializerImpl(system: ActorSystem, override val settings: ActorMaterializerSettings, dispatchers: Dispatchers, - val supervisor: ActorRef, - val haveShutDown: AtomicBoolean, + supervisor: ActorRef, + haveShutDown: AtomicBoolean, flowNameCounter: AtomicLong, - namePrefix: String, - optimizations: Optimizations) extends ActorMaterializer { + namePrefix: String) extends ActorMaterializer { import akka.stream.impl.Stages._ override def shutdown(): Unit = @@ -45,6 +44,12 @@ private[akka] case class ActorMaterializerImpl(val system: ActorSystem, private[this] def createFlowName(): String = s"$namePrefix-${nextFlowNameCount()}" + private val initialAttributes = Attributes( + Attributes.InputBuffer(settings.initialInputBufferSize, settings.maxInputBufferSize) :: + ActorAttributes.Dispatcher(settings.dispatcher) :: + ActorAttributes.SupervisionStrategy(settings.supervisionDecider) :: + Nil) + override def effectiveSettings(opAttr: Attributes): ActorMaterializerSettings = { import Attributes._ import ActorAttributes._ @@ -70,7 +75,7 @@ private[akka] case class ActorMaterializerImpl(val system: ActorSystem, throw new IllegalStateException("Attempted to call materialize() after the ActorMaterializer has been shut down.") if (StreamLayout.Debug) StreamLayout.validate(runnableGraph.module) - val session = new MaterializerSession(runnableGraph.module) { + val session = new MaterializerSession(runnableGraph.module, initialAttributes) { private val flowName = createFlowName() private var nextId = 0 private def stageName(attr: Attributes): String = { @@ -93,6 +98,7 @@ private[akka] case class ActorMaterializerImpl(val system: ActorSystem, assignPort(source.shape.outlet, pub.asInstanceOf[Publisher[Any]]) mat + // FIXME: Remove this, only stream-of-stream ops need it case stage: StageModule ⇒ val (processor, mat) = processorFor(stage, effectiveAttributes, effectiveSettings(effectiveAttributes)) assignPort(stage.inPort, processor) @@ -118,7 +124,7 @@ private[akka] case class ActorMaterializerImpl(val system: ActorSystem, case graph: GraphModule ⇒ val calculatedSettings = effectiveSettings(effectiveAttributes) - val (inHandlers, outHandlers, logics, mat) = graph.assembly.materialize() + val (inHandlers, outHandlers, logics, mat) = graph.assembly.materialize(effectiveAttributes) val props = ActorGraphInterpreter.props( graph.assembly, inHandlers, outHandlers, logics, graph.shape, calculatedSettings, ActorMaterializerImpl.this) @@ -137,11 +143,11 @@ private[akka] case class ActorMaterializerImpl(val system: ActorSystem, } } + // FIXME: Remove this, only stream-of-stream ops need it private def processorFor(op: StageModule, effectiveAttributes: Attributes, effectiveSettings: ActorMaterializerSettings): (Processor[Any, Any], Any) = op match { case DirectProcessor(processorFactory, _) ⇒ processorFactory() - case Identity(attr) ⇒ (new VirtualProcessor, ()) case _ ⇒ val (opprops, mat) = ActorProcessorFactory.props(ActorMaterializerImpl.this, op, effectiveAttributes) ActorProcessorFactory[Any, Any]( @@ -248,47 +254,12 @@ private[akka] object ActorProcessorFactory { // USE THIS TO AVOID CLOSING OVER THE MATERIALIZER BELOW // Also, otherwise the attributes will not affect the settings properly! val settings = materializer.effectiveSettings(att) - def interp(s: Stage[_, _]): (Props, Unit) = (ActorInterpreter.props(settings, List(s), materializer, att), ()) - def interpAttr(s: Stage[_, _], newAttributes: Attributes): (Props, Unit) = (ActorInterpreter.props(settings, List(s), materializer, newAttributes), ()) - def inputSizeAttr(n: Long) = { - if (n <= 0) - inputBuffer(initial = 1, max = 1) and att - else if (n <= materializer.settings.maxInputBufferSize) - inputBuffer(initial = n.toInt, max = n.toInt) and att - else - att - } op match { - case Map(f, _) ⇒ interp(fusing.Map(f, settings.supervisionDecider)) - case Filter(p, _) ⇒ interp(fusing.Filter(p, settings.supervisionDecider)) - case Drop(n, _) ⇒ interp(fusing.Drop(n)) - case Take(n, _) ⇒ interpAttr(fusing.Take(n), inputSizeAttr(n)) - case TakeWhile(p, _) ⇒ interp(fusing.TakeWhile(p, settings.supervisionDecider)) - case DropWhile(p, _) ⇒ interp(fusing.DropWhile(p, settings.supervisionDecider)) - case Collect(pf, _) ⇒ interp(fusing.Collect(pf, settings.supervisionDecider)) - case Scan(z, f, _) ⇒ interp(fusing.Scan(z, f, settings.supervisionDecider)) - case Fold(z, f, _) ⇒ interp(fusing.Fold(z, f, settings.supervisionDecider)) - case Intersperse(s, i, e, _) ⇒ interp(fusing.Intersperse(s, i, e)) - case Recover(pf, _) ⇒ interp(fusing.Recover(pf)) - case Expand(s, f, _) ⇒ interp(fusing.Expand(s, f)) - case Conflate(s, f, _) ⇒ interp(fusing.Conflate(s, f, settings.supervisionDecider)) - case Buffer(n, s, _) ⇒ interp(fusing.Buffer(n, s)) - case MapConcat(f, _) ⇒ interp(fusing.MapConcat(f, settings.supervisionDecider)) - case MapAsync(p, f, _) ⇒ interp(fusing.MapAsync(p, f, settings.supervisionDecider)) - case MapAsyncUnordered(p, f, _) ⇒ interp(fusing.MapAsyncUnordered(p, f, settings.supervisionDecider)) - case Grouped(n, _) ⇒ interp(fusing.Grouped(n)) - case Sliding(n, step, _) ⇒ interp(fusing.Sliding(n, step)) - case Log(n, e, l, _) ⇒ interp(fusing.Log(n, e, l)) - case GroupBy(f, _) ⇒ (GroupByProcessorImpl.props(settings, f), ()) - case PrefixAndTail(n, _) ⇒ (PrefixAndTailImpl.props(settings, n), ()) - case Split(d, _) ⇒ (SplitWhereProcessorImpl.props(settings, d), ()) - case ConcatAll(_) ⇒ (ConcatAllImpl.props(materializer), ()) - case StageFactory(mkStage, _) ⇒ interp(mkStage()) - case MaterializingStageFactory(mkStageAndMat, _) ⇒ - val s_m = mkStageAndMat() - (ActorInterpreter.props(settings, List(s_m._1), materializer, att), s_m._2) + case GroupBy(f, _) ⇒ (GroupByProcessorImpl.props(settings, f), ()) + case PrefixAndTail(n, _) ⇒ (PrefixAndTailImpl.props(settings, n), ()) + case Split(d, _) ⇒ (SplitWhereProcessorImpl.props(settings, d), ()) + case ConcatAll(_) ⇒ (ConcatAllImpl.props(materializer), ()) case DirectProcessor(p, m) ⇒ throw new AssertionError("DirectProcessor cannot end up in ActorProcessorFactory") - case Identity(_) ⇒ throw new AssertionError("Identity cannot end up in ActorProcessorFactory") } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/FixedSizeBuffer.scala b/akka-stream/src/main/scala/akka/stream/impl/FixedSizeBuffer.scala index 773e41fc0d..c46b2db8d0 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/FixedSizeBuffer.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/FixedSizeBuffer.scala @@ -26,6 +26,7 @@ private[akka] object FixedSizeBuffer { else new ModuloFixedSizeBuffer(size) sealed abstract class FixedSizeBuffer[T](val size: Int) { + override def toString = s"Buffer($size, $readIdx, $writeIdx)(${(readIdx until writeIdx).map(get).mkString(", ")})" private val buffer = new Array[AnyRef](size) protected var readIdx = 0 diff --git a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala index 485c87109d..2640e4d954 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala @@ -4,11 +4,14 @@ package akka.stream.impl import akka.event.LoggingAdapter +import akka.stream.ActorAttributes.SupervisionStrategy +import akka.stream.Supervision.Decider import akka.stream.impl.SplitDecision.SplitDecision import akka.stream.impl.StreamLayout._ -import akka.stream.{ OverflowStrategy, Attributes } +import akka.stream._ import akka.stream.Attributes._ -import akka.stream.stage.Stage +import akka.stream.stage.AbstractStage.PushPullGraphStage +import akka.stream.stage.{ GraphStageLogic, GraphStage, Stage } import org.reactivestreams.Processor import scala.collection.immutable import scala.concurrent.Future @@ -96,112 +99,121 @@ private[stream] object Stages { import DefaultAttributes._ + // FIXME: To be deprecated as soon as stream-of-stream operations are stages sealed trait StageModule extends FlowModule[Any, Any, Any] { - def attributes: Attributes def withAttributes(attributes: Attributes): StageModule override def carbonCopy: Module = withAttributes(attributes) } - final case class StageFactory(mkStage: () ⇒ Stage[_, _], attributes: Attributes = stageFactory) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + /* + * Stage that is backed by a GraphStage but can be symbolically introspected + */ + case class SymbolicGraphStage[-In, +Out, Ext](symbolicStage: SymbolicStage[In, Out]) + extends PushPullGraphStage[In, Out, Ext]( + symbolicStage.create, + symbolicStage.attributes) { } - final case class MaterializingStageFactory( - mkStageAndMaterialized: () ⇒ (Stage[_, _], Any), - attributes: Attributes = stageFactory) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + val identityGraph = SymbolicGraphStage[Any, Any, Any](Identity) + + sealed trait SymbolicStage[-In, +Out] { + def attributes: Attributes + def create(effectiveAttributes: Attributes): Stage[In, Out] + + // FIXME: No supervision hooked in yet. + + protected def supervision(attributes: Attributes): Decider = + attributes.get[SupervisionStrategy](SupervisionStrategy(Supervision.stoppingDecider)).decider + } - final case class Identity(attributes: Attributes = identityOp) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + object Identity extends SymbolicStage[Any, Any] { + override val attributes: Attributes = identityOp + + def apply[T]: SymbolicStage[T, T] = this.asInstanceOf[SymbolicStage[T, T]] + + override def create(attr: Attributes): Stage[Any, Any] = fusing.Map(identity, supervision(attr)) } - final case class Map(f: Any ⇒ Any, attributes: Attributes = map) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + final case class Map[In, Out](f: In ⇒ Out, attributes: Attributes = map) extends SymbolicStage[In, Out] { + override def create(attr: Attributes): Stage[In, Out] = fusing.Map(f, supervision(attr)) } - final case class Log(name: String, extract: Any ⇒ Any, loggingAdapter: Option[LoggingAdapter], attributes: Attributes = log) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + final case class Log[T](name: String, extract: T ⇒ Any, loggingAdapter: Option[LoggingAdapter], attributes: Attributes = log) extends SymbolicStage[T, T] { + override def create(attr: Attributes): Stage[T, T] = fusing.Log(name, extract, loggingAdapter) } - final case class Filter(p: Any ⇒ Boolean, attributes: Attributes = filter) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + final case class Filter[T](p: T ⇒ Boolean, attributes: Attributes = filter) extends SymbolicStage[T, T] { + override def create(attr: Attributes): Stage[T, T] = fusing.Filter(p, supervision(attr)) } - final case class Collect(pf: PartialFunction[Any, Any], attributes: Attributes = collect) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + final case class Collect[In, Out](pf: PartialFunction[In, Out], attributes: Attributes = collect) extends SymbolicStage[In, Out] { + override def create(attr: Attributes): Stage[In, Out] = fusing.Collect(pf, supervision(attr)) } - final case class Recover(pf: PartialFunction[Any, Any], attributes: Attributes = recover) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + final case class Recover[In, Out >: In](pf: PartialFunction[Throwable, Out], attributes: Attributes = recover) extends SymbolicStage[In, Out] { + override def create(attr: Attributes): Stage[In, Out] = fusing.Recover(pf) } - final case class MapAsync(parallelism: Int, f: Any ⇒ Future[Any], attributes: Attributes = mapAsync) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) - } - - final case class MapAsyncUnordered(parallelism: Int, f: Any ⇒ Future[Any], attributes: Attributes = mapAsyncUnordered) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) - } - - final case class Grouped(n: Int, attributes: Attributes = grouped) extends StageModule { + final case class Grouped[T](n: Int, attributes: Attributes = grouped) extends SymbolicStage[T, immutable.Seq[T]] { require(n > 0, "n must be greater than 0") - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + override def create(attr: Attributes): Stage[T, immutable.Seq[T]] = fusing.Grouped(n) } - final case class Sliding(n: Int, step: Int, attributes: Attributes = sliding) extends StageModule { + final case class Sliding[T](n: Int, step: Int, attributes: Attributes = sliding) extends SymbolicStage[T, immutable.Seq[T]] { require(n > 0, "n must be greater than 0") require(step > 0, "step must be greater than 0") - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + + override def create(attr: Attributes): Stage[T, immutable.Seq[T]] = fusing.Sliding(n, step) } - final case class Take(n: Long, attributes: Attributes = take) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + final case class Take[T](n: Long, attributes: Attributes = take) extends SymbolicStage[T, T] { + override def create(attr: Attributes): Stage[T, T] = fusing.Take(n) } - final case class Drop(n: Long, attributes: Attributes = drop) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + final case class Drop[T](n: Long, attributes: Attributes = drop) extends SymbolicStage[T, T] { + override def create(attr: Attributes): Stage[T, T] = fusing.Drop(n) } - final case class TakeWhile(p: Any ⇒ Boolean, attributes: Attributes = takeWhile) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) - + final case class TakeWhile[T](p: T ⇒ Boolean, attributes: Attributes = takeWhile) extends SymbolicStage[T, T] { + override def create(attr: Attributes): Stage[T, T] = fusing.TakeWhile(p, supervision(attr)) } - final case class DropWhile(p: Any ⇒ Boolean, attributes: Attributes = dropWhile) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + final case class DropWhile[T](p: T ⇒ Boolean, attributes: Attributes = dropWhile) extends SymbolicStage[T, T] { + override def create(attr: Attributes): Stage[T, T] = fusing.DropWhile(p, supervision(attr)) } - final case class Scan(zero: Any, f: (Any, Any) ⇒ Any, attributes: Attributes = scan) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + final case class Scan[In, Out](zero: Out, f: (Out, In) ⇒ Out, attributes: Attributes = scan) extends SymbolicStage[In, Out] { + override def create(attr: Attributes): Stage[In, Out] = fusing.Scan(zero, f, supervision(attr)) } - final case class Fold(zero: Any, f: (Any, Any) ⇒ Any, attributes: Attributes = fold) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + final case class Intersperse[T](start: Option[T], inject: T, end: Option[T], attributes: Attributes = intersperse) extends SymbolicStage[T, T] { + override def create(attr: Attributes) = fusing.Intersperse(start, inject, end) } - final case class Intersperse(start: Option[Any], inject: Any, end: Option[Any], attributes: Attributes = intersperse) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + final case class Fold[In, Out](zero: Out, f: (Out, In) ⇒ Out, attributes: Attributes = fold) extends SymbolicStage[In, Out] { + override def create(attr: Attributes): Stage[In, Out] = fusing.Fold(zero, f, supervision(attr)) } - final case class Buffer(size: Int, overflowStrategy: OverflowStrategy, attributes: Attributes = buffer) extends StageModule { + final case class Buffer[T](size: Int, overflowStrategy: OverflowStrategy, attributes: Attributes = buffer) extends SymbolicStage[T, T] { require(size > 0, s"Buffer size must be larger than zero but was [$size]") - - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + override def create(attr: Attributes): Stage[T, T] = fusing.Buffer(size, overflowStrategy) } - final case class Conflate(seed: Any ⇒ Any, aggregate: (Any, Any) ⇒ Any, attributes: Attributes = conflate) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + final case class Conflate[In, Out](seed: In ⇒ Out, aggregate: (Out, In) ⇒ Out, attributes: Attributes = conflate) extends SymbolicStage[In, Out] { + override def create(attr: Attributes): Stage[In, Out] = fusing.Conflate(seed, aggregate, supervision(attr)) } - final case class Expand(seed: Any ⇒ Any, extrapolate: Any ⇒ (Any, Any), attributes: Attributes = expand) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + final case class Expand[In, Out, Seed](seed: In ⇒ Seed, extrapolate: Seed ⇒ (Out, Seed), attributes: Attributes = expand) extends SymbolicStage[In, Out] { + override def create(attr: Attributes): Stage[In, Out] = fusing.Expand(seed, extrapolate) } - final case class MapConcat(f: Any ⇒ immutable.Iterable[Any], attributes: Attributes = mapConcat) extends StageModule { - override def withAttributes(attributes: Attributes) = copy(attributes = attributes) + final case class MapConcat[In, Out](f: In ⇒ immutable.Iterable[Out], attributes: Attributes = mapConcat) extends SymbolicStage[In, Out] { + override def create(attr: Attributes): Stage[In, Out] = fusing.MapConcat(f, supervision(attr)) } + // FIXME: These are not yet proper stages, therefore they use the deprecated StageModule infrastructure + final case class GroupBy(f: Any ⇒ Any, attributes: Attributes = groupBy) extends StageModule { override def withAttributes(attributes: Attributes) = copy(attributes = attributes) } diff --git a/akka-stream/src/main/scala/akka/stream/impl/StreamLayout.scala b/akka-stream/src/main/scala/akka/stream/impl/StreamLayout.scala index d84da3721c..674f4a309e 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/StreamLayout.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/StreamLayout.scala @@ -184,7 +184,7 @@ private[akka] object StreamLayout { downstreams, upstreams, Transform(f, if (this.isSealed) Atomic(this) else this.materializedValueComputation), - attributes) + if (this.isSealed) Attributes.none else attributes) } /** @@ -225,7 +225,7 @@ private[akka] object StreamLayout { upstreams ++ that.upstreams, // would like to optimize away this allocation for Keep.{left,right} but that breaks side-effecting transformations Combine(f.asInstanceOf[(Any, Any) ⇒ Any], matComputation1, matComputation2), - attributes) + Attributes.none) } /** @@ -595,7 +595,7 @@ private[stream] object MaterializerSession { /** * INTERNAL API */ -private[stream] abstract class MaterializerSession(val topLevel: StreamLayout.Module) { +private[stream] abstract class MaterializerSession(val topLevel: StreamLayout.Module, val initialAttributes: Attributes) { import StreamLayout._ private var subscribersStack: List[mutable.Map[InPort, Subscriber[Any]]] = @@ -653,7 +653,7 @@ private[stream] abstract class MaterializerSession(val topLevel: StreamLayout.Mo require( topLevel.isRunnable, s"The top level module cannot be materialized because it has unconnected ports: ${(topLevel.inPorts ++ topLevel.outPorts).mkString(", ")}") - try materializeModule(topLevel, topLevel.attributes) + try materializeModule(topLevel, initialAttributes and topLevel.attributes) catch { case NonFatal(cause) ⇒ // PANIC!!! THE END OF THE MATERIALIZATION IS NEAR! diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala index ad08907bd5..abe0df2276 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorGraphInterpreter.scala @@ -339,40 +339,41 @@ private[stream] class ActorGraphInterpreter( override def receive: Receive = { // Cases that are most likely on the hot path, in decreasing order of frequency case OnNext(id: Int, e: Any) ⇒ - if (GraphInterpreter.Debug) println(s" onNext $e id=$id") + if (GraphInterpreter.Debug) println(s"${interpreter.Name} onNext $e id=$id") inputs(id).onNext(e) runBatch() case RequestMore(id: Int, demand: Long) ⇒ - if (GraphInterpreter.Debug) println(s" request $demand id=$id") + if (GraphInterpreter.Debug) println(s"${interpreter.Name} request $demand id=$id") outputs(id).requestMore(demand) runBatch() case Resume ⇒ resumeScheduled = false if (interpreter.isSuspended) runBatch() case AsyncInput(logic, event, handler) ⇒ - if (GraphInterpreter.Debug) println(s"ASYNC $event") - if (!interpreter.isStageCompleted(logic.stageId)) { + if (GraphInterpreter.Debug) println(s"${interpreter.Name} ASYNC $event ($handler) [$logic]") + if (!interpreter.isStageCompleted(logic)) { try handler(event) catch { case NonFatal(e) ⇒ logic.failStage(e) } + interpreter.afterStageHasRun(logic) } runBatch() // Initialization and completion messages case OnError(id: Int, cause: Throwable) ⇒ - if (GraphInterpreter.Debug) println(s" onError id=$id") + if (GraphInterpreter.Debug) println(s"${interpreter.Name} onError id=$id") inputs(id).onError(cause) runBatch() case OnComplete(id: Int) ⇒ - if (GraphInterpreter.Debug) println(s" onComplete id=$id") + if (GraphInterpreter.Debug) println(s"${interpreter.Name} onComplete id=$id") inputs(id).onComplete() runBatch() case OnSubscribe(id: Int, subscription: Subscription) ⇒ subscribesPending -= 1 inputs(id).onSubscribe(subscription) case Cancel(id: Int) ⇒ - if (GraphInterpreter.Debug) println(s" cancel id=$id") + if (GraphInterpreter.Debug) println(s"${interpreter.Name} cancel id=$id") outputs(id).cancel() runBatch() case SubscribePending(id: Int) ⇒ diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorInterpreter.scala deleted file mode 100644 index c9bb2373e7..0000000000 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorInterpreter.scala +++ /dev/null @@ -1,389 +0,0 @@ -/** - * Copyright (C) 2009-2014 Typesafe Inc. - */ -package akka.stream.impl.fusing - -import akka.actor._ -import akka.stream.impl.ReactiveStreamsCompliance._ -import akka.stream.{ AbruptTerminationException, ActorMaterializerSettings, Attributes, ActorMaterializer } -import akka.stream.actor.ActorSubscriber.OnSubscribe -import akka.stream.actor.ActorSubscriberMessage.{ OnNext, OnError, OnComplete } -import akka.stream.impl._ -import akka.stream.impl.fusing.OneBoundedInterpreter.{ InitializationFailed, InitializationSuccessful } -import akka.stream.stage._ -import org.reactivestreams.{ Subscriber, Subscription } -import akka.event.{ Logging, LoggingAdapter } - -/** - * INTERNAL API - */ -private[akka] class BatchingActorInputBoundary(val size: Int, val name: String) - extends BoundaryStage { - - require(size > 0, "buffer size cannot be zero") - require((size & (size - 1)) == 0, "buffer size must be a power of two") - - // TODO: buffer and batch sizing heuristics - private var upstream: Subscription = _ - private val inputBuffer = Array.ofDim[AnyRef](size) - private var inputBufferElements = 0 - private var nextInputElementCursor = 0 - private var upstreamCompleted = false - private var downstreamWaiting = false - private var downstreamCanceled = false - private val IndexMask = size - 1 - - private def requestBatchSize = math.max(1, inputBuffer.length / 2) - private var batchRemaining = requestBatchSize - - val subreceive: SubReceive = new SubReceive(waitingForUpstream) - - def isFinished = upstreamCompleted && ((upstream ne null) || downstreamCanceled) - - def setDownstreamCanceled(): Unit = downstreamCanceled = true - - private def dequeue(): Any = { - val elem = inputBuffer(nextInputElementCursor) - require(elem ne null) - inputBuffer(nextInputElementCursor) = null - - batchRemaining -= 1 - if (batchRemaining == 0 && !upstreamCompleted) { - tryRequest(upstream, requestBatchSize) - batchRemaining = requestBatchSize - } - - inputBufferElements -= 1 - nextInputElementCursor = (nextInputElementCursor + 1) & IndexMask - elem - } - - private def enqueue(elem: Any): Unit = { - if (OneBoundedInterpreter.Debug) println(f" enq $elem%-19s $name") - if (!upstreamCompleted) { - if (inputBufferElements == size) throw new IllegalStateException("Input buffer overrun") - inputBuffer((nextInputElementCursor + inputBufferElements) & IndexMask) = elem.asInstanceOf[AnyRef] - inputBufferElements += 1 - } - } - - override def onPush(elem: Any, ctx: BoundaryContext): Directive = - throw new UnsupportedOperationException("BUG: Cannot push the upstream boundary") - - override def onPull(ctx: BoundaryContext): Directive = { - if (inputBufferElements > 1) ctx.push(dequeue()) - else if (inputBufferElements == 1) { - if (upstreamCompleted) ctx.pushAndFinish(dequeue()) - else ctx.push(dequeue()) - } else if (upstreamCompleted) { - ctx.finish() - } else { - downstreamWaiting = true - ctx.exit() - } - } - - override def onDownstreamFinish(ctx: BoundaryContext): TerminationDirective = { - cancel() - ctx.finish() - } - - def cancel(): Unit = { - if (!upstreamCompleted) { - upstreamCompleted = true - if (upstream ne null) tryCancel(upstream) - downstreamWaiting = false - clear() - } - } - - private def clear(): Unit = { - java.util.Arrays.fill(inputBuffer, 0, inputBuffer.length, null) - inputBufferElements = 0 - } - - private def onComplete(): Unit = - if (!upstreamCompleted) { - upstreamCompleted = true - // onUpstreamFinish is not back-pressured, stages need to deal with this - if (inputBufferElements == 0) enterAndFinish() - } - - private def onSubscribe(subscription: Subscription): Unit = { - require(subscription != null) - if (upstreamCompleted) - tryCancel(subscription) - else if (downstreamCanceled) { - upstreamCompleted = true - tryCancel(subscription) - } else { - upstream = subscription - // Prefetch - tryRequest(upstream, inputBuffer.length) - subreceive.become(upstreamRunning) - } - } - - // Call this when an error happens that does not come from the usual onError channel - // (exceptions while calling RS interfaces, abrupt termination etc) - def onInternalError(e: Throwable): Unit = { - if (!(upstreamCompleted || downstreamCanceled) && (upstream ne null)) { - upstream.cancel() - } - onError(e) - } - - def onError(e: Throwable): Unit = { - if (!upstreamCompleted) { - upstreamCompleted = true - enterAndFail(e) - } - } - - private def waitingForUpstream: Actor.Receive = { - case OnComplete ⇒ onComplete() - case OnSubscribe(subscription) ⇒ onSubscribe(subscription) - case OnError(cause) ⇒ onError(cause) - } - - private def upstreamRunning: Actor.Receive = { - case OnNext(element) ⇒ - enqueue(element) - if (downstreamWaiting) { - downstreamWaiting = false - enterAndPush(dequeue()) - } - - case OnComplete ⇒ onComplete() - case OnError(cause) ⇒ onError(cause) - case OnSubscribe(subscription) ⇒ tryCancel(subscription) // spec rule 2.5 - } - -} - -private[akka] object ActorOutputBoundary { - /** - * INTERNAL API. - */ - private case object ContinuePulling extends DeadLetterSuppression with NoSerializationVerificationNeeded -} - -/** - * INTERNAL API - */ -private[akka] class ActorOutputBoundary(val actor: ActorRef, - val debugLogging: Boolean, - val log: LoggingAdapter, - val outputBurstLimit: Int) - extends BoundaryStage { - import ReactiveStreamsCompliance._ - import ActorOutputBoundary._ - - private var exposedPublisher: ActorPublisher[Any] = _ - - private var subscriber: Subscriber[Any] = _ - private var downstreamDemand: Long = 0L - // This flag is only used if complete/fail is called externally since this op turns into a Finished one inside the - // interpreter (i.e. inside this op this flag has no effects since if it is completed the op will not be invoked) - private var downstreamCompleted = false - // this is true while we “hold the ball”; while “false” incoming demand will just be queued up - private var upstreamWaiting = true - // when upstream failed before we got the exposed publisher - private var upstreamFailed: Option[Throwable] = None - // the number of elements emitted during a single execution is bounded - private var burstRemaining = outputBurstLimit - - private def tryBounceBall(ctx: BoundaryContext) = { - burstRemaining -= 1 - if (burstRemaining > 0) ctx.pull() - else { - actor ! ContinuePulling - takeBallOut(ctx) - } - } - - private def takeBallOut(ctx: BoundaryContext) = { - upstreamWaiting = true - ctx.exit() - } - - private def tryPutBallIn() = - if (upstreamWaiting) { - burstRemaining = outputBurstLimit - upstreamWaiting = false - enterAndPull() - } - - val subreceive = new SubReceive(waitingExposedPublisher) - - private def onNext(elem: Any): Unit = { - downstreamDemand -= 1 - tryOnNext(subscriber, elem) - } - - private def complete(): Unit = { - if (!downstreamCompleted) { - downstreamCompleted = true - if (exposedPublisher ne null) exposedPublisher.shutdown(None) - if (subscriber ne null) tryOnComplete(subscriber) - } - } - - def fail(e: Throwable): Unit = { - if (!downstreamCompleted) { - downstreamCompleted = true - if (debugLogging) - log.debug("fail due to: {}", e.getMessage) - if (exposedPublisher ne null) exposedPublisher.shutdown(Some(e)) - if ((subscriber ne null) && !e.isInstanceOf[SpecViolation]) tryOnError(subscriber, e) - } else if (exposedPublisher == null && upstreamFailed.isEmpty) { - // fail called before the exposed publisher arrived, we must store it and fail when we're first able to - upstreamFailed = Some(e) - } - } - - override def onPush(elem: Any, ctx: BoundaryContext): Directive = { - onNext(elem) - if (downstreamCompleted) ctx.finish() - else if (downstreamDemand > 0) tryBounceBall(ctx) - else takeBallOut(ctx) - } - - override def onPull(ctx: BoundaryContext): Directive = - throw new UnsupportedOperationException("BUG: Cannot pull the downstream boundary") - - override def onUpstreamFinish(ctx: BoundaryContext): TerminationDirective = { - complete() - ctx.finish() - } - - override def onUpstreamFailure(cause: Throwable, ctx: BoundaryContext): TerminationDirective = { - fail(cause) - ctx.fail(cause) - } - - private def subscribePending(subscribers: Seq[Subscriber[Any]]): Unit = - subscribers foreach { sub ⇒ - if (subscriber eq null) { - subscriber = sub - tryOnSubscribe(subscriber, new ActorSubscription(actor, subscriber)) - } else - rejectAdditionalSubscriber(subscriber, s"${Logging.simpleName(this)}") - } - - protected def waitingExposedPublisher: Actor.Receive = { - case ExposedPublisher(publisher) ⇒ - upstreamFailed match { - case _: Some[_] ⇒ - publisher.shutdown(upstreamFailed) - case _ ⇒ - exposedPublisher = publisher - subreceive.become(downstreamRunning) - } - case other ⇒ - throw new IllegalStateException(s"The first message must be ExposedPublisher but was [$other]") - } - - protected def downstreamRunning: Actor.Receive = { - case SubscribePending ⇒ - subscribePending(exposedPublisher.takePendingSubscribers()) - case RequestMore(subscription, elements) ⇒ - if (elements < 1) { - enterAndFinish() - fail(ReactiveStreamsCompliance.numberOfElementsInRequestMustBePositiveException) - } else { - downstreamDemand += elements - if (downstreamDemand < 0) - downstreamDemand = Long.MaxValue // Long overflow, Reactive Streams Spec 3:17: effectively unbounded - if (OneBoundedInterpreter.Debug) { - val s = s"$downstreamDemand (+$elements)" - println(f" dem $s%-19s ${actor.path}") - } - tryPutBallIn() - } - - case ContinuePulling ⇒ - if (!downstreamCompleted && downstreamDemand > 0) tryPutBallIn() - - case Cancel(subscription) ⇒ - downstreamCompleted = true - subscriber = null - exposedPublisher.shutdown(Some(new ActorPublisher.NormalShutdownException)) - enterAndFinish() - } - -} - -/** - * INTERNAL API - */ -private[akka] object ActorInterpreter { - def props(settings: ActorMaterializerSettings, ops: Seq[Stage[_, _]], materializer: ActorMaterializer, attributes: Attributes = Attributes.none): Props = - Props(new ActorInterpreter(settings, ops, materializer, attributes)).withDeploy(Deploy.local) - - case class AsyncInput(op: AsyncStage[Any, Any, Any], ctx: AsyncContext[Any, Any], event: Any) extends DeadLetterSuppression with NoSerializationVerificationNeeded -} - -/** - * INTERNAL API - */ -private[akka] class ActorInterpreter(val settings: ActorMaterializerSettings, val ops: Seq[Stage[_, _]], val materializer: ActorMaterializer, val attributes: Attributes) - extends Actor with ActorLogging { - import ActorInterpreter._ - - private val upstream = new BatchingActorInputBoundary(settings.initialInputBufferSize, context.self.path.toString) - private val downstream = new ActorOutputBoundary(self, settings.debugLogging, log, settings.outputBurstLimit) - private val interpreter = - new OneBoundedInterpreter(upstream +: ops :+ downstream, - (op, ctx, event) ⇒ self ! AsyncInput(op, ctx, event), - Logging(this), - materializer, - attributes, - name = context.self.path.toString) - - interpreter.init() match { - case failed: InitializationFailed ⇒ - // the Actor will be stopped thanks to aroundReceive checking interpreter.isFinished - upstream.setDownstreamCanceled() - downstream.fail(failed.mostDownstream.ex) - case InitializationSuccessful ⇒ // ok - } - - def receive: Receive = - upstream.subreceive - .orElse[Any, Unit](downstream.subreceive) - .orElse[Any, Unit] { - case AsyncInput(op, ctx, event) ⇒ - ctx.enter() - op.onAsyncInput(event, ctx) - ctx.execute() - } - - override protected[akka] def aroundReceive(receive: Actor.Receive, msg: Any): Unit = { - super.aroundReceive(receive, msg) - - if (interpreter.isFinished) { - if (upstream.isFinished) context.stop(self) - else upstream.setDownstreamCanceled() - } - } - - override def postStop(): Unit = { - // This should handle termination while interpreter is running. If the upstream have been closed already this - // call has no effect and therefore do the right thing: nothing. - val ex = AbruptTerminationException(self) - try upstream.onInternalError(ex) - // Will only have an effect if the above call to the interpreter failed to emit a proper failure to the downstream - // otherwise this will have no effect - finally { - downstream.fail(ex) - upstream.cancel() - } - } - - override def postRestart(reason: Throwable): Unit = { - super.postRestart(reason) - throw new IllegalStateException("This actor cannot be restarted", reason) - } - -} diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala index c3a420bdb9..f03efe3c4a 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphInterpreter.scala @@ -7,9 +7,9 @@ import java.util.Arrays import akka.event.LoggingAdapter import akka.stream.stage._ -import akka.stream.{ Materializer, Shape, Inlet, Outlet } import scala.annotation.tailrec import scala.collection.immutable +import akka.stream._ import scala.util.control.NonFatal /** @@ -55,12 +55,17 @@ private[stream] object GraphInterpreter { def in: Inlet[T] } + val singleNoAttribute: Array[Attributes] = Array(Attributes.none) + /** * INTERNAL API * * A GraphAssembly represents a small stream processing graph to be executed by the interpreter. Instances of this * class **must not** be mutated after construction. * + * The array ``originalAttributes`` may contain the attribute information of the original atomic module, otherwise + * it must contain a none (otherwise the enclosing module could not overwrite attributes defined in this array). + * * The arrays [[ins]] and [[outs]] correspond to the notion of a *connection* in the [[GraphInterpreter]]. Each slot * *i* contains the input and output port corresponding to connection *i*. Slots where the graph is not closed (i.e. * ports are exposed to the external world) are marked with *null* values. For example if an input port *p* is @@ -88,6 +93,7 @@ private[stream] object GraphInterpreter { * */ final class GraphAssembly(val stages: Array[GraphStageWithMaterializedValue[Shape, Any]], + val originalAttributes: Array[Attributes], val ins: Array[Inlet[_]], val inOwners: Array[Int], val outs: Array[Outlet[_]], @@ -106,7 +112,7 @@ private[stream] object GraphInterpreter { * - array of the logics * - materialized value */ - def materialize(): (Array[InHandler], Array[OutHandler], Array[GraphStageLogic], Any) = { + def materialize(inheritedAttributes: Attributes): (Array[InHandler], Array[OutHandler], Array[GraphStageLogic], Any) = { val logics = Array.ofDim[GraphStageLogic](stages.length) var finalMat: Any = () @@ -134,7 +140,7 @@ private[stream] object GraphInterpreter { } // FIXME: Support for materialized values in fused islands is not yet figured out! - val logicAndMat = stages(i).createLogicAndMaterializedValue + val logicAndMat = stages(i).createLogicAndMaterializedValue(inheritedAttributes and originalAttributes(i)) // FIXME: Current temporary hack to support non-fused stages. If there is one stage that will be under index 0. if (i == 0) finalMat = logicAndMat._2 @@ -148,20 +154,21 @@ private[stream] object GraphInterpreter { i = 0 while (i < connectionCount) { if (ins(i) ne null) { - val l = logics(inOwners(i)) - l.inHandlers(ins(i).id) match { - case null ⇒ throw new IllegalStateException(s"no handler defined in stage $l for port ${ins(i)}") - case h ⇒ inHandlers(i) = h + val logic = logics(inOwners(i)) + logic.handlers(ins(i).id) match { + case null ⇒ throw new IllegalStateException(s"no handler defined in stage $logic for port ${ins(i)}") + case h: InHandler ⇒ inHandlers(i) = h } - l.inToConn(ins(i).id) = i + logics(inOwners(i)).portToConn(ins(i).id) = i } if (outs(i) ne null) { - val l = logics(outOwners(i)) - l.outHandlers(outs(i).id) match { - case null ⇒ throw new IllegalStateException(s"no handler defined in stage $l for port ${outs(i)}") - case h ⇒ outHandlers(i) = h + val logic = logics(outOwners(i)) + val inCount = logic.inCount + logic.handlers(outs(i).id + inCount) match { + case null ⇒ throw new IllegalStateException(s"no handler defined in stage $logic for port ${outs(i)}") + case h: OutHandler ⇒ outHandlers(i) = h } - l.outToConn(outs(i).id) = i + logic.portToConn(outs(i).id + inCount) = i } i += 1 } @@ -206,6 +213,7 @@ private[stream] object GraphInterpreter { val assembly = new GraphAssembly( stages.toArray, + GraphInterpreter.singleNoAttribute, add(inlets.iterator, Array.ofDim(connectionCount), 0), markBoundary(Array.ofDim(connectionCount), inletsSize, connectionCount), add(outlets.iterator, Array.ofDim(connectionCount), inletsSize), @@ -288,7 +296,7 @@ private[stream] object GraphInterpreter { * * Because of the FIFO construction of the queue the interpreter is fair, i.e. a pending event is always executed * after a bounded number of other events. This property, together with suspendability means that even infinite cycles can - * be modeled, or even dissolved (if preempted and a "stealing" external even is injected; for example the non-cycle + * be modeled, or even dissolved (if preempted and a "stealing" external event is injected; for example the non-cycle * edge of a balance is pulled, dissolving the original cycle). */ private[stream] final class GraphInterpreter( @@ -309,7 +317,7 @@ private[stream] final class GraphInterpreter( // of the class for a full description. val portStates = Array.fill[Int](assembly.connectionCount)(InReady) - private[this] var activeStageId = Boundary + private[this] var activeStage: GraphStageLogic = _ // The number of currently running stages. Once this counter reaches zero, the interpreter is considered to be // completed @@ -323,19 +331,33 @@ private[stream] final class GraphInterpreter( // An event queue implemented as a circular buffer // FIXME: This calculates the maximum size ever needed, but most assemblies can run on a smaller queue - private[this] val eventQueue = Array.ofDim[Int](1 << Integer.highestOneBit(assembly.connectionCount)) + private[this] val eventQueue = Array.ofDim[Int](1 << (32 - Integer.numberOfLeadingZeros(assembly.connectionCount - 1))) private[this] val mask = eventQueue.length - 1 private[this] var queueHead: Int = 0 private[this] var queueTail: Int = 0 + private def queueStatus: String = { + val contents = (queueHead until queueTail).map(idx ⇒ { + val conn = eventQueue(idx & mask) + (conn, portStates(conn), connectionSlots(conn)) + }) + s"(${eventQueue.length}, $queueHead, $queueTail)(${contents.mkString(", ")})" + } + private[this] var _Name: String = _ + def Name: String = + if (_Name eq null) { + _Name = f"${System.identityHashCode(this)}%08X" + _Name + } else _Name + /** * Assign the boundary logic to a given connection. This will serve as the interface to the external world * (outside the interpreter) to process and inject events. */ def attachUpstreamBoundary(connection: Int, logic: UpstreamBoundaryStageLogic[_]): Unit = { - logic.outToConn(logic.out.id) = connection + logic.portToConn(logic.out.id + logic.inCount) = connection logic.interpreter = this - outHandlers(connection) = logic.outHandlers(0) + outHandlers(connection) = logic.handlers(0).asInstanceOf[OutHandler] } /** @@ -343,16 +365,16 @@ private[stream] final class GraphInterpreter( * (outside the interpreter) to process and inject events. */ def attachDownstreamBoundary(connection: Int, logic: DownstreamBoundaryStageLogic[_]): Unit = { - logic.inToConn(logic.in.id) = connection + logic.portToConn(logic.in.id) = connection logic.interpreter = this - inHandlers(connection) = logic.inHandlers(0) + inHandlers(connection) = logic.handlers(0).asInstanceOf[InHandler] } /** * Dynamic handler changes are communicated from a GraphStageLogic by this method. */ def setHandler(connection: Int, handler: InHandler): Unit = { - if (GraphInterpreter.Debug) println(s"SETHANDLER ${inOwnerName(connection)} (in) $handler") + if (Debug) println(s"$Name SETHANDLER ${inOwnerName(connection)} (in) $handler") inHandlers(connection) = handler } @@ -360,7 +382,7 @@ private[stream] final class GraphInterpreter( * Dynamic handler changes are communicated from a GraphStageLogic by this method. */ def setHandler(connection: Int, handler: OutHandler): Unit = { - if (GraphInterpreter.Debug) println(s"SETHANDLER ${outOwnerName(connection)} (out) $handler") + if (Debug) println(s"$Name SETHANDLER ${outOwnerName(connection)} (out) $handler") outHandlers(connection) = handler } @@ -389,6 +411,7 @@ private[stream] final class GraphInterpreter( } catch { case NonFatal(e) ⇒ logic.failStage(e) } + afterStageHasRun(logic) i += 1 } } @@ -399,7 +422,8 @@ private[stream] final class GraphInterpreter( def finish(): Unit = { var i = 0 while (i < logics.length) { - if (!isStageCompleted(i)) finalizeStage(logics(i)) + val logic = logics(i) + if (!isStageCompleted(logic)) finalizeStage(logic) i += 1 } } @@ -418,84 +442,97 @@ private[stream] final class GraphInterpreter( case owner ⇒ assembly.stages(owner).toString } + // Debug name for a connections input part + private def inLogicName(connection: Int): String = + assembly.inOwners(connection) match { + case Boundary ⇒ "DownstreamBoundary" + case owner ⇒ logics(owner).toString + } + + // Debug name for a connections ouput part + private def outLogicName(connection: Int): String = + assembly.outOwners(connection) match { + case Boundary ⇒ "UpstreamBoundary" + case owner ⇒ logics(owner).toString + } + /** * Executes pending events until the given limit is met. If there were remaining events, isSuspended will return * true. */ def execute(eventLimit: Int): Unit = { - if (GraphInterpreter.Debug) println("---------------- EXECUTE") + if (Debug) println(s"$Name ---------------- EXECUTE (running=$runningStages, shutdown=${shutdownCounter.mkString(",")})") var eventsRemaining = eventLimit var connection = dequeue() while (eventsRemaining > 0 && connection != NoEvent) { try processEvent(connection) catch { case NonFatal(e) ⇒ - if (activeStageId == Boundary) throw e - else logics(activeStageId).failStage(e) + if (activeStage == null) throw e + else activeStage.failStage(e) } + afterStageHasRun(activeStage) eventsRemaining -= 1 if (eventsRemaining > 0) connection = dequeue() } + if (Debug) println(s"$Name ---------------- $queueStatus (running=$runningStages, shutdown=${shutdownCounter.mkString(",")})") // TODO: deadlock detection } // Decodes and processes a single event for the given connection private def processEvent(connection: Int): Unit = { + def safeLogics(id: Int) = + if (id == Boundary) null + else logics(id) - def processElement(elem: Any): Unit = { - if (GraphInterpreter.Debug) println(s"PUSH ${outOwnerName(connection)} -> ${inOwnerName(connection)}, $elem (${inHandlers(connection)})") - activeStageId = assembly.inOwners(connection) + def processElement(): Unit = { + if (Debug) println(s"$Name PUSH ${outOwnerName(connection)} -> ${inOwnerName(connection)}, ${connectionSlots(connection)} (${inHandlers(connection)}) [${inLogicName(connection)}]") + activeStage = safeLogics(assembly.inOwners(connection)) portStates(connection) ^= PushEndFlip inHandlers(connection).onPush() } + // this must be the state after returning without delivering any signals, to avoid double-finalization of some unlucky stage + // (this can happen if a stage completes voluntarily while connection close events are still queued) + activeStage = null val code = portStates(connection) // Manual fast decoding, fast paths are PUSH and PULL // PUSH if ((code & (Pushing | InClosed | OutClosed)) == Pushing) { - processElement(connectionSlots(connection)) + processElement() // PULL } else if ((code & (Pulling | OutClosed | InClosed)) == Pulling) { - if (GraphInterpreter.Debug) println(s"PULL ${inOwnerName(connection)} -> ${outOwnerName(connection)} (${outHandlers(connection)})") + if (Debug) println(s"$Name PULL ${inOwnerName(connection)} -> ${outOwnerName(connection)} (${outHandlers(connection)}) [${outLogicName(connection)}]") portStates(connection) ^= PullEndFlip - activeStageId = assembly.outOwners(connection) + activeStage = safeLogics(assembly.outOwners(connection)) outHandlers(connection).onPull() // CANCEL } else if ((code & (OutClosed | InClosed)) == InClosed) { val stageId = assembly.outOwners(connection) - if (GraphInterpreter.Debug) println(s"CANCEL ${inOwnerName(connection)} -> ${outOwnerName(connection)} (${outHandlers(connection)})") + activeStage = safeLogics(stageId) + if (Debug) println(s"$Name CANCEL ${inOwnerName(connection)} -> ${outOwnerName(connection)} (${outHandlers(connection)}) [${outLogicName(connection)}]") portStates(connection) |= OutClosed - activeStageId = assembly.outOwners(connection) - outHandlers(connection).onDownstreamFinish() completeConnection(stageId) + outHandlers(connection).onDownstreamFinish() } else if ((code & (OutClosed | InClosed)) == OutClosed) { // COMPLETIONS - val stageId = assembly.inOwners(connection) - if ((code & Pushing) == 0) { // Normal completion (no push pending) - if (GraphInterpreter.Debug) println(s"COMPLETE ${outOwnerName(connection)} -> ${inOwnerName(connection)} (${inHandlers(connection)})") + if (Debug) println(s"$Name COMPLETE ${outOwnerName(connection)} -> ${inOwnerName(connection)} (${inHandlers(connection)}) [${inLogicName(connection)}]") portStates(connection) |= InClosed - activeStageId = assembly.inOwners(connection) + val stageId = assembly.inOwners(connection) + activeStage = safeLogics(stageId) + completeConnection(stageId) if ((portStates(connection) & InFailed) == 0) inHandlers(connection).onUpstreamFinish() else inHandlers(connection).onUpstreamFailure(connectionSlots(connection).asInstanceOf[Failed].ex) - completeConnection(stageId) } else { // Push is pending, first process push, then re-enqueue closing event - // Non-failure case - val code = portStates(connection) & (InClosed | InFailed) - if (code == 0) { - processElement(connectionSlots(connection)) - enqueue(connection) - } else if (code == InFailed) { - // Failure case - processElement(connectionSlots(connection).asInstanceOf[Failed].previousElem) - enqueue(connection) - } + processElement() + enqueue(connection) } } @@ -513,26 +550,26 @@ private[stream] final class GraphInterpreter( } private def enqueue(connection: Int): Unit = { + if (Debug) if (queueTail - queueHead > mask) new Exception(s"$Name internal queue full ($queueStatus) + $connection").printStackTrace() eventQueue(queueTail & mask) = connection queueTail += 1 } + def afterStageHasRun(logic: GraphStageLogic): Unit = + if (isStageCompleted(logic)) { + runningStages -= 1 + finalizeStage(logic) + } + // Returns true if the given stage is alredy completed - def isStageCompleted(stageId: Int): Boolean = stageId != Boundary && shutdownCounter(stageId) == 0 + def isStageCompleted(stage: GraphStageLogic): Boolean = stage != null && shutdownCounter(stage.stageId) == 0 // Register that a connection in which the given stage participated has been completed and therefore the stage // itself might stop, too. private def completeConnection(stageId: Int): Unit = { if (stageId != Boundary) { val activeConnections = shutdownCounter(stageId) - if (activeConnections > 0) { - shutdownCounter(stageId) = activeConnections - 1 - // This was the last active connection keeping this stage alive - if (activeConnections == 1) { - runningStages -= 1 - finalizeStage(logics(stageId)) - } - } + if (activeConnections > 0) shutdownCounter(stageId) = activeConnections - 1 } } @@ -564,35 +601,32 @@ private[stream] final class GraphInterpreter( private[stream] def complete(connection: Int): Unit = { val currentState = portStates(connection) + if (Debug) println(s"$Name complete($connection) [$currentState]") portStates(connection) = currentState | OutClosed - - if ((currentState & (InClosed | Pushing)) == 0) { - enqueue(connection) - } - - completeConnection(assembly.outOwners(connection)) + if ((currentState & (InClosed | Pushing | Pulling)) == 0) enqueue(connection) + if ((currentState & OutClosed) == 0) completeConnection(assembly.outOwners(connection)) } private[stream] def fail(connection: Int, ex: Throwable): Unit = { val currentState = portStates(connection) + if (Debug) println(s"$Name fail($connection, $ex) [$currentState]") portStates(connection) = currentState | (OutClosed | InFailed) if ((currentState & InClosed) == 0) { connectionSlots(connection) = Failed(ex, connectionSlots(connection)) - enqueue(connection) + if ((currentState & (Pulling | Pushing)) == 0) enqueue(connection) } - - completeConnection(assembly.outOwners(connection)) + if ((currentState & OutClosed) == 0) completeConnection(assembly.outOwners(connection)) } private[stream] def cancel(connection: Int): Unit = { val currentState = portStates(connection) + if (Debug) println(s"$Name cancel($connection) [$currentState]") portStates(connection) = currentState | InClosed if ((currentState & OutClosed) == 0) { connectionSlots(connection) = Empty - enqueue(connection) + if ((currentState & (Pulling | Pushing)) == 0) enqueue(connection) } - - completeConnection(assembly.inOwners(connection)) + if ((currentState & InClosed) == 0) completeConnection(assembly.inOwners(connection)) } -} \ No newline at end of file +} diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala index 6370e2a6a2..412105e165 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala @@ -24,21 +24,18 @@ object GraphStages { val in = Inlet[T]("in") val out = Outlet[T]("out") override val shape = FlowShape(in, out) - - protected abstract class SimpleLinearStageLogic extends GraphStageLogic(shape) { - setHandler(out, new OutHandler { - override def onPull(): Unit = pull(in) - }) - } - } class Identity[T] extends SimpleLinearGraphStage[T] { - override def createLogic: GraphStageLogic = new SimpleLinearStageLogic() { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { setHandler(in, new InHandler { override def onPush(): Unit = push(out, grab(in)) }) + + setHandler(out, new OutHandler { + override def onPull(): Unit = pull(in) + }) } override def toString = "Identity" @@ -49,7 +46,7 @@ object GraphStages { val out = Outlet[T]("out") override val shape = FlowShape(in, out) - override def createLogic: GraphStageLogic = new GraphStageLogic(shape) { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { var initialized = false setHandler(in, new InHandler { @@ -99,13 +96,13 @@ object GraphStages { val out = Outlet[T]("TimerSource.out") override val shape = SourceShape(out) - override def createLogicAndMaterializedValue: (GraphStageLogic, Cancellable) = { + override def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, Cancellable) = { import TickSource._ val cancelled = new AtomicBoolean(false) val cancellable = new TickSourceCancellable(cancelled) - val logic = new GraphStageLogic(shape) { + val logic = new TimerGraphStageLogic(shape) { override def preStart() = { schedulePeriodicallyWithInitialDelay("TickTimer", initialDelay, interval) val callback = getAsyncCallback[Unit]((_) ⇒ { diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/IteratorInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/IteratorInterpreter.scala index d73eb01231..10f59bab37 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/IteratorInterpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/IteratorInterpreter.scala @@ -3,64 +3,71 @@ */ package akka.stream.impl.fusing -import akka.event.NoLogging +import akka.event.{ Logging, NoLogging } import akka.stream._ +import akka.stream.impl.fusing.GraphInterpreter.{ GraphAssembly, DownstreamBoundaryStageLogic, UpstreamBoundaryStageLogic } +import akka.stream.stage.AbstractStage.PushPullGraphStage import akka.stream.stage._ /** * INTERNAL API */ private[akka] object IteratorInterpreter { - final case class IteratorUpstream[T](input: Iterator[T]) extends PushPullStage[T, T] { + + final case class IteratorUpstream[T](input: Iterator[T]) extends UpstreamBoundaryStageLogic[T] { + val out: Outlet[T] = Outlet[T]("IteratorUpstream.out") + out.id = 0 + private var hasNext = input.hasNext - override def onPush(elem: T, ctx: Context[T]): SyncDirective = - throw new UnsupportedOperationException("IteratorUpstream operates as a source, it cannot be pushed") - - override def onPull(ctx: Context[T]): SyncDirective = { - if (!hasNext) ctx.finish() - else { - val elem = input.next() - hasNext = input.hasNext - if (!hasNext) ctx.pushAndFinish(elem) - else ctx.push(elem) + setHandler(out, new OutHandler { + override def onPull(): Unit = { + if (!hasNext) completeStage() + else { + val elem = input.next() + hasNext = input.hasNext + if (!hasNext) { + push(out, elem) + complete(out) + } else push(out, elem) + } } + }) - } - - // don't let toString consume the iterator - override def toString: String = "IteratorUpstream" + override def toString = "IteratorUpstream" } - final case class IteratorDownstream[T]() extends BoundaryStage with Iterator[T] { + final case class IteratorDownstream[T]() extends DownstreamBoundaryStageLogic[T] with Iterator[T] { + val in: Inlet[T] = Inlet[T]("IteratorDownstream.in") + in.id = 0 + private var done = false private var nextElem: T = _ private var needsPull = true private var lastFailure: Throwable = null - override def onPush(elem: Any, ctx: BoundaryContext): Directive = { - nextElem = elem.asInstanceOf[T] - needsPull = false - ctx.exit() - } + setHandler(in, new InHandler { + override def onPush(): Unit = { + nextElem = grab(in) + needsPull = false + } - override def onPull(ctx: BoundaryContext): Directive = - throw new UnsupportedOperationException("IteratorDownstream operates as a sink, it cannot be pulled") + override def onUpstreamFinish(): Unit = { + done = true + completeStage() + } - override def onUpstreamFinish(ctx: BoundaryContext): TerminationDirective = { - done = true - ctx.finish() - } - - override def onUpstreamFailure(cause: Throwable, ctx: BoundaryContext): TerminationDirective = { - done = true - lastFailure = cause - ctx.finish() - } + override def onUpstreamFailure(cause: Throwable): Unit = { + done = true + lastFailure = cause + completeStage() + } + }) private def pullIfNeeded(): Unit = { if (needsPull) { - enterAndPull() // will eventually result in a finish, or an onPush which exits + pull(in) + interpreter.execute(Int.MaxValue) } } @@ -84,8 +91,8 @@ private[akka] object IteratorInterpreter { // don't let toString consume the iterator override def toString: String = "IteratorDownstream" - } + } /** @@ -96,11 +103,52 @@ private[akka] class IteratorInterpreter[I, O](val input: Iterator[I], val ops: S private val upstream = IteratorUpstream(input) private val downstream = IteratorDownstream[O]() - private val interpreter = new OneBoundedInterpreter(upstream +: ops.asInstanceOf[Seq[Stage[_, _]]] :+ downstream, - (op, ctx, evt) ⇒ throw new UnsupportedOperationException("IteratorInterpreter is fully synchronous"), - NoLogging, - NoMaterializer) - interpreter.init() + + private def init(): Unit = { + import GraphInterpreter.Boundary + + var i = 0 + val length = ops.length + val attributes = Array.fill[Attributes](ops.length)(Attributes.none) + val ins = Array.ofDim[Inlet[_]](length + 1) + val inOwners = Array.ofDim[Int](length + 1) + val outs = Array.ofDim[Outlet[_]](length + 1) + val outOwners = Array.ofDim[Int](length + 1) + val stages = Array.ofDim[GraphStageWithMaterializedValue[Shape, Any]](length) + + ins(ops.length) = null + inOwners(ops.length) = Boundary + outs(0) = null + outOwners(0) = Boundary + + val opsIterator = ops.iterator + while (opsIterator.hasNext) { + val op = opsIterator.next().asInstanceOf[Stage[Any, Any]] + val stage = new PushPullGraphStage((_) ⇒ op, Attributes.none) + stages(i) = stage + ins(i) = stage.shape.inlet + inOwners(i) = i + outs(i + 1) = stage.shape.outlet + outOwners(i + 1) = i + i += 1 + } + val assembly = new GraphAssembly(stages, attributes, ins, inOwners, outs, outOwners) + + val (inHandlers, outHandlers, logics, _) = assembly.materialize(Attributes.none) + val interpreter = new GraphInterpreter( + assembly, + NoMaterializer, + NoLogging, + inHandlers, + outHandlers, + logics, + (_, _, _) ⇒ throw new UnsupportedOperationException("IteratorInterpreter does not support asynchronous events.")) + interpreter.attachUpstreamBoundary(0, upstream) + interpreter.attachDownstreamBoundary(ops.length, downstream) + interpreter.init() + } + + init() def iterator: Iterator[O] = downstream } diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/OneBoundedInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/OneBoundedInterpreter.scala deleted file mode 100644 index b8551d20fb..0000000000 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/OneBoundedInterpreter.scala +++ /dev/null @@ -1,747 +0,0 @@ -/** - * Copyright (C) 2009-2014 Typesafe Inc. - */ -package akka.stream.impl.fusing - -import akka.event.LoggingAdapter -import akka.stream.impl.ReactiveStreamsCompliance -import akka.stream.stage._ -import akka.stream.{ Materializer, Attributes, Supervision } - -import scala.annotation.{ switch, tailrec } -import scala.collection.{ breakOut, immutable } -import scala.util.control.NonFatal - -/** - * INTERNAL API - */ -private[akka] object OneBoundedInterpreter { - final val Debug = false - - /** INTERNAL API */ - private[akka] sealed trait InitializationStatus - /** INTERNAL API */ - private[akka] case object InitializationSuccessful extends InitializationStatus - /** INTERNAL API */ - private[akka] final case class InitializationFailed(failures: immutable.Seq[InitializationFailure]) extends InitializationStatus { - // exceptions are reverse ordered here, below methods help to avoid confusion when used from the outside - def mostUpstream = failures.last - def mostDownstream = failures.head - } - - /** INTERNAL API */ - private[akka] case class InitializationFailure(op: Int, ex: Throwable) - - /** - * INTERNAL API - * - * This artificial op is used as a boundary to prevent two forked paths of execution (complete, cancel) to cross - * paths again. When finishing an op this op is injected in its place to isolate upstream and downstream execution - * domains. - */ - private[akka] object Finished extends BoundaryStage { - override def onPush(elem: Any, ctx: BoundaryContext): UpstreamDirective = ctx.finish() - override def onPull(ctx: BoundaryContext): DownstreamDirective = ctx.finish() - override def onUpstreamFinish(ctx: BoundaryContext): TerminationDirective = ctx.exit() - override def onDownstreamFinish(ctx: BoundaryContext): TerminationDirective = ctx.exit() - override def onUpstreamFailure(cause: Throwable, ctx: BoundaryContext): TerminationDirective = ctx.exit() - } - - /** - * INTERNAL API - * - * This artificial op is used as a boundary to prevent the first forked onPush of execution of a pushFinish to enter - * the originating stage again. This stage allows the forked upstream onUpstreamFinish to pass through if there was - * no onPull called on the stage. Calling onPull on this op makes it a Finished op, which absorbs the - * onUpstreamTermination, but otherwise onUpstreamTermination results in calling finish() - */ - private[akka] object PushFinished extends BoundaryStage { - override def onPush(elem: Any, ctx: BoundaryContext): UpstreamDirective = ctx.finish() - override def onPull(ctx: BoundaryContext): DownstreamDirective = ctx.finish() - // This allows propagation of an onUpstreamFinish call. Note that if onPull has been called on this stage - // before, then the call ctx.finish() in onPull already turned this op to a normal Finished, i.e. it will no longer - // propagate onUpstreamFinish. - override def onUpstreamFinish(ctx: BoundaryContext): TerminationDirective = ctx.finish() - override def onDownstreamFinish(ctx: BoundaryContext): TerminationDirective = ctx.exit() - override def onUpstreamFailure(cause: Throwable, ctx: BoundaryContext): TerminationDirective = ctx.exit() - } -} - -/** - * INTERNAL API - * - * One-bounded interpreter for a linear chain of stream operations (graph support is possible and will be implemented - * later) - * - * The ideas in this interpreter are an amalgamation of earlier ideas, notably: - * - The original effect-tracking implementation by Johannes Rudolph -- the difference here that effects are not chained - * together as classes but the callstack is used instead and only certain combinations are allowed. - * - The on-stack reentrant implementation by Mathias Doenitz -- the difference here that reentrancy is handled by the - * interpreter itself, not user code, and the interpreter is able to use the heap when needed instead of the - * callstack. - * - The pinball interpreter by Endre Sándor Varga -- the difference here that the restriction for "one ball" is - * lifted by using isolated execution regions, completion handling is introduced and communication with the external - * world is done via boundary ops. - * - * The design goals/features of this interpreter are: - * - bounded callstack and heapless execution whenever possible - * - callstack usage should be constant for the most common ops independently of the size of the op-chain - * - allocation-free execution on the hot paths - * - enforced backpressure-safety (boundedness) on user defined ops at compile-time (and runtime in a few cases) - * - * The main driving idea of this interpreter is the concept of 1-bounded execution of well-formed free choice Petri - * nets (J. Desel and J. Esparza: Free Choice Petri Nets - https://www7.in.tum.de/~esparza/bookfc.html). Technically - * different kinds of operations partition the chain of ops into regions where *exactly one* event is active all the - * time. This "exactly one" property is enforced by proper types and runtime checks where needed. Currently there are - * three kinds of ops: - * - * - PushPullStage implementations participate in 1-bounded regions. For every external non-completion signal these - * ops produce *exactly one* signal (completion is different, explained later) therefore keeping the number of events - * the same: exactly one. - * - * - DetachedStage implementations are boundaries between 1-bounded regions. This means that they need to enforce the - * "exactly one" property both on their upstream and downstream regions. As a consequence a DetachedStage can never - * answer an onPull with a ctx.pull() or answer an onPush() with a ctx.push() since such an action would "steal" - * the event from one region (resulting in zero signals) and would inject it to the other region (resulting in two - * signals). However DetachedStages have the ability to call ctx.hold() as a response to onPush/onPull which temporarily - * takes the signal off and stops execution, at the same time putting the op in a "holding" state. If the op is in a - * holding state it contains one absorbed signal, therefore in this state the only possible command to call is - * ctx.pushAndPull() which results in two events making the balance right again: - * 1 hold + 1 external event = 2 external event - * This mechanism allows synchronization between the upstream and downstream regions which otherwise can progress - * independently. - * - * - BoundaryStage implementations are meant to communicate with the external world. These ops do not have most of the - * safety properties enforced and should be used carefully. One important ability of BoundaryStages that they can take - * off an execution signal by calling ctx.exit(). This is typically used immediately after an external signal has - * been produced (for example an actor message). BoundaryStages can also kickstart execution by calling enter() which - * returns a context they can use to inject signals into the interpreter. There is no checks in place to enforce that - * the number of signals taken out by exit() and the number of signals returned via enter() are the same -- using this - * op type needs extra care from the implementer. - * BoundaryStages are the elements that make the interpreter *tick*, there is no other way to start the interpreter - * than using a BoundaryStage. - * - * Operations are allowed to do early completion and cancel/complete their upstreams and downstreams. It is *not* - * allowed however to do these independently to avoid isolated execution islands. The only call possible is ctx.finish() - * which is a combination of cancel/complete. - * Since onComplete is not a backpressured signal it is sometimes preferable to push a final element and then immediately - * finish. This combination is exposed as pushAndFinish() which enables op writers to propagate completion events without - * waiting for an extra round of pull. - * Another peculiarity is how to convert termination events (complete/failure) into elements. The problem - * here is that the termination events are not backpressured while elements are. This means that simply calling ctx.push() - * as a response to onUpstreamFinished() will very likely break boundedness and result in a buffer overflow somewhere. - * Therefore the only allowed command in this case is ctx.absorbTermination() which stops the propagation of the - * termination signal, and puts the op in a finishing state. Depending on whether the op has a pending pull signal it has - * not yet "consumed" by a push its onPull() handler might be called immediately. - * - * In order to execute different individual execution regions the interpreter uses the callstack to schedule these. The - * current execution forking operations are - * - ctx.finish() which starts a wave of completion and cancellation in two directions. When an op calls finish() - * it is immediately replaced by an artificial Finished op which makes sure that the two execution paths are isolated - * forever. - * - ctx.fail() which is similar to finish() - * - ctx.pushAndPull() which (as a response to a previous ctx.hold()) starts a wave of downstream push and upstream - * pull. The two execution paths are isolated by the op itself since onPull() from downstream can only be answered by hold or - * push, while onPush() from upstream can only answered by hold or pull -- it is impossible to "cross" the op. - * - ctx.pushAndFinish() which is different from the forking ops above because the execution of push and finish happens on - * the same execution region and they are order dependent, too. - * The interpreter tracks the depth of recursive forking and allows various strategies of dealing with the situation - * when this depth reaches a certain limit. In the simplest case a failure is reported (this is very useful for stress - * testing and finding callstack wasting bugs), in the other case the forked call is scheduled via a list -- i.e. instead - * of the stack the heap is used. - */ -private[akka] class OneBoundedInterpreter(ops: Seq[Stage[_, _]], - onAsyncInput: (AsyncStage[Any, Any, Any], AsyncContext[Any, Any], Any) ⇒ Unit, - log: LoggingAdapter, - materializer: Materializer, - attributes: Attributes = Attributes.none, - val forkLimit: Int = 100, - val overflowToHeap: Boolean = true, - val name: String = "") { - import AbstractStage._ - import OneBoundedInterpreter._ - - type UntypedOp = AbstractStage[Any, Any, Directive, Directive, Context[Any], LifecycleContext] - require(ops.nonEmpty, "OneBoundedInterpreter cannot be created without at least one Op") - - private final val pipeline: Array[UntypedOp] = ops.map(_.asInstanceOf[UntypedOp])(breakOut) - - /** - * This table is used to accelerate demand propagation upstream. All ops that implement PushStage are guaranteed - * to only do upstream propagation of demand signals, therefore it is not necessary to execute them but enough to - * "jump over" them. This means that when a chain of one million maps gets a downstream demand it is propagated - * to the upstream *in one step* instead of one million onPull() calls. - * This table maintains the positions where execution should jump from a current position when a pull event is to - * be executed. - */ - private final val jumpBacks: Array[Int] = calculateJumpBacks - - private final val Upstream = 0 - private final val Downstream = pipeline.length - 1 - - // Var to hold the current element if pushing. The only reason why this var is needed is to avoid allocations and - // make it possible for the Pushing state to be an object - private var elementInFlight: Any = _ - // Points to the current point of execution inside the pipeline - private var activeOpIndex = -1 - // Points to the last point of exit - private var lastExitedIndex = Downstream - // The current interpreter state that decides what happens at the next round - private var state: State = _ - - // Counter that keeps track of the depth of recursive forked executions - private var forkCount = 0 - // List that is used as an auxiliary stack if fork recursion depth reaches forkLimit - private var overflowStack = List.empty[(Int, State, Any)] - - private var lastOpFailing: Int = -1 - - private def pipeName(op: UntypedOp): String = { - val o = op: AbstractStage[_, _, _, _, _, _] - (o match { - case Finished ⇒ "finished" - case _: BoundaryStage ⇒ "boundary" - case _: StatefulStage[_, _] ⇒ "stateful" - case _: PushStage[_, _] ⇒ "push" - case _: PushPullStage[_, _] ⇒ "pushpull" - case _: DetachedStage[_, _] ⇒ "detached" - case _ ⇒ "other" - }) + f"(${o.bits}%04X)" - } - override def toString = - s"""|OneBoundedInterpreter($name) - | pipeline = ${pipeline map pipeName mkString ":"} - | lastExit=$lastExitedIndex activeOp=$activeOpIndex state=$state elem=$elementInFlight forks=$forkCount""".stripMargin - - @inline private def currentOp: UntypedOp = pipeline(activeOpIndex) - - // see the jumpBacks variable for explanation - private def calculateJumpBacks: Array[Int] = { - val table = Array.ofDim[Int](pipeline.length) - var nextJumpBack = -1 - for (pos ← pipeline.indices) { - table(pos) = nextJumpBack - if (!pipeline(pos).isInstanceOf[PushStage[_, _]]) nextJumpBack = pos - } - table - } - - private def updateJumpBacks(lastNonCompletedIndex: Int): Unit = { - var pos = lastNonCompletedIndex - // For every jump that would jump over us we change them to jump into us - while (pos < pipeline.length && jumpBacks(pos) < lastNonCompletedIndex) { - jumpBacks(pos) = lastNonCompletedIndex - pos += 1 - } - } - - private sealed trait State extends DetachedContext[Any] with BoundaryContext with AsyncContext[Any, Any] { - def enter(): Unit = throw new IllegalStateException("cannot enter an ordinary Context") - - final def execute(): Unit = OneBoundedInterpreter.this.execute() - - final def progress(): Unit = { - advance() - if (inside) run() - else exit() - } - - /** - * Override this method to do execution steps necessary after executing an op, and advance the activeOpIndex - * to another value (next or previous steps). Do NOT put code that invokes the next op, override run instead. - */ - def advance(): Unit - - /** - * Override this method to enter the current op and execute it. Do NOT put code that should be executed after the - * op has been invoked, that should be in the advance() method of the next state resulting from the invocation of - * the op. - */ - def run(): Unit - - /** - * This method shall return the bit set representing the incoming ball (if any). - */ - def incomingBall: Int - - protected def hasBits(b: Int): Boolean = ((currentOp.bits | incomingBall) & b) == b - protected def addBits(b: Int): Unit = currentOp.bits |= b - protected def removeBits(b: Int): Unit = currentOp.bits &= ~b - - protected def mustHave(b: Int): Unit = - if (!hasBits(b)) { - def format(b: Int) = { - val ballStatus = (b & BothBalls: @switch) match { - case 0 ⇒ "no balls" - case UpstreamBall ⇒ "upstream ball" - case DownstreamBall ⇒ "downstream ball" - case BothBalls ⇒ "upstream & downstream balls" - } - if ((b & NoTerminationPending) > 0) ballStatus + " and not isFinishing" - else ballStatus + " and isFinishing" - } - throw new IllegalStateException(s"operation requires [${format(b)}] while holding [${format(currentOp.bits)}] and receiving [${format(incomingBall)}]") - } - - override def push(elem: Any): DownstreamDirective = { - ReactiveStreamsCompliance.requireNonNullElement(elem) - if (currentOp.isDetached) { - if (incomingBall == UpstreamBall) - throw new IllegalStateException("Cannot push during onPush, only pull, pushAndPull or holdUpstreamAndPush") - mustHave(DownstreamBall) - } - removeBits(PrecedingWasPull | DownstreamBall) - elementInFlight = elem - state = Pushing - null - } - - override def pull(): UpstreamDirective = { - var requirements = NoTerminationPending - if (currentOp.isDetached) { - if (incomingBall == DownstreamBall) - throw new IllegalStateException("Cannot pull during onPull, only push, pushAndPull or holdDownstreamAndPull") - requirements |= UpstreamBall - } - mustHave(requirements) - removeBits(UpstreamBall) - addBits(PrecedingWasPull) - state = Pulling - null - } - - override def getAsyncCallback: AsyncCallback[Any] = { - val current = currentOp.asInstanceOf[AsyncStage[Any, Any, Any]] - val context = current.context // avoid concurrent access (to avoid @volatile) - new AsyncCallback[Any] { - override def invoke(evt: Any): Unit = onAsyncInput(current, context, evt) - } - } - - override def ignore(): AsyncDirective = { - if (incomingBall != 0) throw new IllegalStateException("Can only ignore from onAsyncInput") - exit() - } - - override def finish(): FreeDirective = { - finishCurrentOp() - fork(Completing) - state = Cancelling - null - } - - def isFinishing: Boolean = !hasBits(NoTerminationPending) - - final protected def pushAndFinishCommon(elem: Any, finishState: UntypedOp): Unit = { - finishCurrentOp(finishState) - ReactiveStreamsCompliance.requireNonNullElement(elem) - if (currentOp.isDetached) { - mustHave(DownstreamBall) - } - removeBits(DownstreamBall | PrecedingWasPull) - } - - override def pushAndFinish(elem: Any): DownstreamDirective = { - // Spit the execution domain in two and invoke postStop - pushAndFinishCommon(elem, Finished.asInstanceOf[UntypedOp]) - - // This MUST be an unsafeFork because the execution of PushFinish MUST strictly come before the finish execution - // path. Other forks are not order dependent because they execute on isolated execution domains which cannot - // "cross paths". This unsafeFork is relatively safe here because PushAndFinish simply absorbs all later downstream - // calls of pushAndFinish since the finish event has been scheduled already. - // It might be that there are some degenerate cases where this can blow up the stack with a very long chain but I - // am not aware of such scenario yet. If you know one, put it in InterpreterStressSpec :) - unsafeFork(PushFinish, elem) - - // Same as finish, without calling finishCurrentOp - elementInFlight = null - fork(Completing) - state = Cancelling - null - } - - override def fail(cause: Throwable): FreeDirective = { - fork(Failing(cause)) - state = Cancelling - null - } - - override def holdUpstream(): UpstreamDirective = { - mustHave(NoTerminationPending) - removeBits(PrecedingWasPull) - addBits(UpstreamBall) - exit() - } - - override def holdUpstreamAndPush(elem: Any): UpstreamDirective = { - ReactiveStreamsCompliance.requireNonNullElement(elem) - if (incomingBall != UpstreamBall) - throw new IllegalStateException("can only holdUpstreamAndPush from onPush") - mustHave(BothBallsAndNoTerminationPending) - removeBits(PrecedingWasPull | DownstreamBall) - addBits(UpstreamBall) - elementInFlight = elem - state = Pushing - null - } - - override def isHoldingUpstream: Boolean = (currentOp.bits & UpstreamBall) != 0 - - override def holdDownstream(): DownstreamDirective = { - addBits(DownstreamBall) - exit() - } - - override def holdDownstreamAndPull(): DownstreamDirective = { - if (incomingBall != DownstreamBall) - throw new IllegalStateException("can only holdDownstreamAndPull from onPull") - mustHave(BothBallsAndNoTerminationPending) - addBits(PrecedingWasPull | DownstreamBall) - removeBits(UpstreamBall) - state = Pulling - null - } - - override def isHoldingDownstream: Boolean = (currentOp.bits & DownstreamBall) != 0 - - override def pushAndPull(elem: Any): FreeDirective = { - ReactiveStreamsCompliance.requireNonNullElement(elem) - mustHave(BothBallsAndNoTerminationPending) - addBits(PrecedingWasPull) - removeBits(BothBalls) - fork(Pushing, elem) - state = Pulling - null - } - - override def absorbTermination(): TerminationDirective = { - updateJumpBacks(activeOpIndex) - removeBits(BothBallsAndNoTerminationPending) - finish() - } - - override def exit(): FreeDirective = { - elementInFlight = null - lastExitedIndex = activeOpIndex - activeOpIndex = -1 - null - } - - override def materializer: Materializer = OneBoundedInterpreter.this.materializer - override def attributes: Attributes = OneBoundedInterpreter.this.attributes - } - - private final val Pushing: State = new State { - override def advance(): Unit = activeOpIndex += 1 - override def run(): Unit = currentOp.onPush(elementInFlight, ctx = this) - override def incomingBall = UpstreamBall - override def toString = "Pushing" - } - - private final val PushFinish: State = new State { - override def advance(): Unit = activeOpIndex += 1 - override def run(): Unit = currentOp.onPush(elementInFlight, ctx = this) - - override def pushAndFinish(elem: Any): DownstreamDirective = { - // PushFinished - // Put an isolation barrier that will prevent the onPull of this op to be called again. This barrier - // is different from simple Finished that it allows onUpstreamTerminated to pass through, unless onPull - // has been called on the stage - pushAndFinishCommon(elem, PushFinished.asInstanceOf[UntypedOp]) - - elementInFlight = elem - state = PushFinish - null - } - - override def finish(): FreeDirective = { - state = Completing - null - } - - override def incomingBall = UpstreamBall - - override def toString = "PushFinish" - } - - private final val Pulling: State = new State { - override def advance(): Unit = { - elementInFlight = null - activeOpIndex = jumpBacks(activeOpIndex) - } - - override def run(): Unit = currentOp.onPull(ctx = this) - - override def incomingBall = DownstreamBall - - override def toString = "Pulling" - } - - private final val Completing: State = new State { - override def advance(): Unit = { - elementInFlight = null - finishCurrentOp() - activeOpIndex += 1 - } - - override def run(): Unit = { - if (!hasBits(NoTerminationPending)) exit() - else currentOp.onUpstreamFinish(ctx = this) - } - - override def finish(): FreeDirective = { - state = Completing - null - } - - override def absorbTermination(): TerminationDirective = { - removeBits(NoTerminationPending) - removeBits(UpstreamBall) - updateJumpBacks(activeOpIndex) - if (hasBits(DownstreamBall) || (!currentOp.isDetached && hasBits(PrecedingWasPull))) { - removeBits(DownstreamBall) - currentOp.onPull(ctx = Pulling) - } else exit() - null - } - - override def incomingBall = UpstreamBall - - override def toString = "Completing" - } - - private final val Cancelling: State = new State { - override def advance(): Unit = { - elementInFlight = null - finishCurrentOp() - activeOpIndex -= 1 - } - - def run(): Unit = { - if (!hasBits(NoTerminationPending)) exit() - else currentOp.onDownstreamFinish(ctx = this) - } - - override def finish(): FreeDirective = { - state = Cancelling - null - } - - override def incomingBall = DownstreamBall - - override def toString = "Cancelling" - - override def absorbTermination(): TerminationDirective = { - val ex = new UnsupportedOperationException("It is not allowed to call absorbTermination() from onDownstreamFinish.") - // This MUST be logged here, since the downstream has cancelled, i.e. there is noone to send onError to, the - // stage is just about to finish so noone will catch it anyway just the interpreter - log.error(ex.getMessage) - throw ex // We still throw for correctness (although a finish() would also work here) - } - - } - - private final case class Failing(cause: Throwable) extends State { - override def advance(): Unit = { - elementInFlight = null - finishCurrentOp() - activeOpIndex += 1 - } - - def run(): Unit = currentOp.onUpstreamFailure(cause, ctx = this) - - override def absorbTermination(): TerminationDirective = { - removeBits(NoTerminationPending) - removeBits(UpstreamBall) - updateJumpBacks(activeOpIndex) - if (hasBits(DownstreamBall) || (!currentOp.isDetached && hasBits(PrecedingWasPull))) { - removeBits(DownstreamBall) - currentOp.onPull(ctx = Pulling) - } else exit() - null - } - - override def incomingBall = UpstreamBall - } - - private def inside: Boolean = activeOpIndex > -1 && activeOpIndex < pipeline.length - - private def printDebug(): Unit = { - val padding = " " * activeOpIndex - val icon: String = state match { - case Pushing | PushFinish ⇒ padding + s"---> $elementInFlight" - case Pulling ⇒ - (" " * jumpBacks(activeOpIndex)) + - "<---" + - ("----" * (activeOpIndex - jumpBacks(activeOpIndex) - 1)) - case Completing ⇒ padding + "---|" - case Cancelling ⇒ padding + "|---" - case Failing(e) ⇒ padding + s"---X ${e.getMessage} => ${decide(e)}" - case other ⇒ padding + s"---? $state" - } - println(f"$icon%-24s $name") - } - - @tailrec private def execute(): Unit = { - while (inside) { - try { - if (Debug) printDebug() - state.progress() - } catch { - case NonFatal(e) if lastOpFailing != activeOpIndex ⇒ - lastOpFailing = activeOpIndex - decide(e) match { - case Supervision.Stop ⇒ state.fail(e) - case Supervision.Resume ⇒ - // reset, purpose of lastOpFailing is to avoid infinite loops when fail fails -- double fault - lastOpFailing = -1 - afterRecovery() - case Supervision.Restart ⇒ - // reset, purpose of lastOpFailing is to avoid infinite loops when fail fails -- double fault - lastOpFailing = -1 - pipeline(activeOpIndex) = pipeline(activeOpIndex).restart().asInstanceOf[UntypedOp] - afterRecovery() - } - } - } - - // FIXME push this into AbstractStage so it can be customized - def afterRecovery(): Unit = state match { - case _: EntryState ⇒ // no ball to be juggled with - case _ ⇒ state.pull() - } - - // Execute all delayed forks that were put on the heap if the fork limit has been reached - if (overflowStack.nonEmpty) { - val memo = overflowStack.head - activeOpIndex = memo._1 - state = memo._2 - elementInFlight = memo._3 - overflowStack = overflowStack.tail - execute() - } - } - - def decide(e: Throwable): Supervision.Directive = - if (state == Pulling || state == Cancelling) Supervision.Stop - else currentOp.decide(e) - - /** - * Forks off execution of the pipeline by saving current position, fully executing the effects of the given - * forkState then setting back the position to the saved value. - * By default forking is executed by using the callstack. If the depth of forking ever reaches the configured forkLimit - * this method either fails (useful for testing) or starts using the heap instead of the callstack to avoid a - * stack overflow. - */ - private def fork(forkState: State, elem: Any = null): Unit = { - forkCount += 1 - if (forkCount == forkLimit) { - if (!overflowToHeap) throw new IllegalStateException("Fork limit reached") - else overflowStack ::= ((activeOpIndex, forkState, elem)) - } else unsafeFork(forkState, elem) - forkCount -= 1 - } - - /** - * Unsafe fork always uses the stack for execution. This call is needed by pushAndComplete where the forked execution - * is order dependent since the push and complete events travel in the same direction and not isolated by a boundary - */ - private def unsafeFork(forkState: State, elem: Any = null): Unit = { - val savePos = activeOpIndex - elementInFlight = elem - state = forkState - execute() - activeOpIndex = savePos - } - - /** - * Initializes all stages setting their initial context and calling [[AbstractStage.preStart]] on each. - */ - def init(): InitializationStatus = { - val failures = initBoundaries() - runDetached() - - if (failures.isEmpty) InitializationSuccessful - else { - val failure = failures.head - activeOpIndex = failure.op - currentOp.enterAndFail(failure.ex) - InitializationFailed(failures) - } - } - - def isFinished: Boolean = pipeline(Upstream) == Finished && pipeline(Downstream) == Finished - - private class EntryState(name: String, position: Int) extends State { - val entryPoint = position - - final override def enter(): Unit = { - activeOpIndex = entryPoint - if (Debug) { - val s = " " * entryPoint + "ENTR" - println(f"$s%-24s ${OneBoundedInterpreter.this.name}") - } - } - - override def run(): Unit = () - override def advance(): Unit = () - - override def incomingBall = 0 - - override def toString = s"$name($entryPoint)" - } - - /** - * This method injects a Context to each of the BoundaryStages and AsyncStages. This will be the context returned by enter(). - */ - private def initBoundaries(): List[InitializationFailure] = { - var failures: List[InitializationFailure] = Nil - var op = 0 - while (op < pipeline.length) { - (pipeline(op): Any) match { - case b: BoundaryStage ⇒ - b.context = new EntryState("boundary", op) - - case a: AsyncStage[Any, Any, Any] @unchecked ⇒ - a.context = new EntryState("async", op) - activeOpIndex = op - a.preStart(a.context) - - case a: AbstractStage[Any, Any, Any, Any, Any, Any] @unchecked ⇒ - val state = new EntryState("stage", op) - a.context = state - try a.preStart(state) catch { - case NonFatal(ex) ⇒ - failures ::= InitializationFailure(op, ex) // not logging here as 'most downstream' exception will be signaled via onError - } - } - op += 1 - } - failures - } - - private def finishCurrentOp(finishState: UntypedOp = Finished.asInstanceOf[UntypedOp]): Unit = { - try pipeline(activeOpIndex).postStop() - catch { case NonFatal(ex) ⇒ log.error(s"Stage [{}] postStop failed", ex) } - finally pipeline(activeOpIndex) = finishState - } - - /** - * Starts execution of detached regions. - * - * Since detached ops partition the pipeline into different 1-bounded domains is is necessary to inject a starting - * signal into these regions (since there is no external signal that would kick off their execution otherwise). - */ - private def runDetached(): Unit = { - var op = pipeline.length - 1 - while (op >= 0) { - if (pipeline(op).isDetached) { - activeOpIndex = op - state = Pulling - execute() - } - op -= 1 - } - } - -} diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala index cf438260ac..9e1ce386c3 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala @@ -10,7 +10,6 @@ import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage import akka.stream.impl.{ FixedSizeBuffer, ReactiveStreamsCompliance } import akka.stream.stage._ import akka.stream.{ Supervision, _ } - import scala.annotation.tailrec import scala.collection.immutable import scala.collection.immutable.VectorBuilder @@ -18,6 +17,7 @@ import scala.concurrent.Future import scala.concurrent.duration.FiniteDuration import scala.util.control.NonFatal import scala.util.{ Failure, Success, Try } +import akka.stream.ActorAttributes.SupervisionStrategy /** * INTERNAL API @@ -537,149 +537,150 @@ private[akka] object MapAsync { /** * INTERNAL API */ -private[akka] final case class MapAsync[In, Out](parallelism: Int, f: In ⇒ Future[Out], decider: Supervision.Decider) - extends AsyncStage[In, Out, (Int, Try[Out])] { +private[akka] final case class MapAsync[In, Out](parallelism: Int, f: In ⇒ Future[Out]) + extends GraphStage[FlowShape[In, Out]] { import MapAsync._ - type Notification = (Int, Try[Out]) + private val in = Inlet[In]("in") + private val out = Outlet[Out]("out") - private var callback: AsyncCallback[Notification] = _ - private val elemsInFlight = FixedSizeBuffer[Try[Out]](parallelism) + override val shape = FlowShape(in, out) - override def preStart(ctx: AsyncContext[Out, Notification]): Unit = { - callback = ctx.getAsyncCallback - } + override def createLogic(inheritedAttributes: Attributes) = new GraphStageLogic(shape) { + override def toString = s"MapAsync.Logic(buffer=$buffer)" - override def decide(ex: Throwable) = decider(ex) + val decider = + inheritedAttributes.getAttribute(classOf[SupervisionStrategy]) + .map(_.decider).getOrElse(Supervision.stoppingDecider) - override def onPush(elem: In, ctx: AsyncContext[Out, Notification]) = { - val future = f(elem) - val idx = elemsInFlight.enqueue(NotYetThere) - future.onComplete(t ⇒ callback.invoke((idx, t)))(akka.dispatch.ExecutionContexts.sameThreadExecutionContext) - if (elemsInFlight.isFull) ctx.holdUpstream() - else ctx.pull() - } + val buffer = FixedSizeBuffer[Try[Out]](parallelism) + def todo = buffer.used - override def onPull(ctx: AsyncContext[Out, (Int, Try[Out])]) = { - @tailrec def rec(): DownstreamDirective = - if (elemsInFlight.isEmpty && ctx.isFinishing) ctx.finish() - else if (elemsInFlight.isEmpty || elemsInFlight.peek == NotYetThere) { - if (!elemsInFlight.isFull && ctx.isHoldingUpstream) ctx.holdDownstreamAndPull() - else ctx.holdDownstream() - } else elemsInFlight.dequeue() match { - case Failure(ex) ⇒ rec() + @tailrec private def pushOne(): Unit = + if (buffer.isEmpty) { + if (isClosed(in)) completeStage() + else if (!hasBeenPulled(in)) pull(in) + } else if (buffer.peek == NotYetThere) { + if (todo < parallelism && !hasBeenPulled(in)) tryPull(in) + } else buffer.dequeue() match { + case Failure(ex) ⇒ pushOne() case Success(elem) ⇒ - if (ctx.isHoldingUpstream) ctx.pushAndPull(elem) - else ctx.push(elem) - } - rec() - } - - override def onAsyncInput(input: (Int, Try[Out]), ctx: AsyncContext[Out, Notification]) = { - @tailrec def rec(): Directive = - if (elemsInFlight.isEmpty && ctx.isFinishing) ctx.finish() - else if (elemsInFlight.isEmpty || (elemsInFlight.peek eq NotYetThere)) { - if (!elemsInFlight.isFull && ctx.isHoldingUpstream) ctx.pull() - else ctx.ignore() - } else elemsInFlight.dequeue() match { - case Failure(ex) ⇒ rec() - case Success(elem) ⇒ - if (ctx.isHoldingUpstream) ctx.pushAndPull(elem) - else ctx.push(elem) + push(out, elem) + if (todo < parallelism && !hasBeenPulled(in)) tryPull(in) } - input match { - case (idx, f @ Failure(ex)) ⇒ - if (decider(ex) != Supervision.Stop) { - elemsInFlight.put(idx, f) - if (ctx.isHoldingDownstream) rec() - else ctx.ignore() - } else ctx.fail(ex) - case (idx, s: Success[_]) ⇒ - val exception = try { - ReactiveStreamsCompliance.requireNonNullElement(s.value) - elemsInFlight.put(idx, s) - null: Exception + def failOrPull(idx: Int, f: Failure[Out]) = + if (decider(f.exception) == Supervision.Stop) failStage(f.exception) + else { + buffer.put(idx, f) + if (isAvailable(out)) pushOne() + } + + val futureCB = + getAsyncCallback[(Int, Try[Out])]({ + case (idx, f: Failure[_]) ⇒ failOrPull(idx, f) + case (idx, s @ Success(elem)) ⇒ + if (elem == null) { + val ex = ReactiveStreamsCompliance.elementMustNotBeNullException + failOrPull(idx, Failure(ex)) + } else { + buffer.put(idx, s) + if (isAvailable(out)) pushOne() + } + }) + + setHandler(in, new InHandler { + override def onPush(): Unit = { + try { + val future = f(grab(in)) + val idx = buffer.enqueue(NotYetThere) + future.onComplete(result ⇒ futureCB.invoke(idx -> result))(akka.dispatch.ExecutionContexts.sameThreadExecutionContext) } catch { case NonFatal(ex) ⇒ - if (decider(ex) != Supervision.Stop) { - elemsInFlight.put(idx, Failure(ex)) - null: Exception - } else ex + if (decider(ex) == Supervision.Stop) failStage(ex) } - if (exception != null) ctx.fail(exception) - else if (ctx.isHoldingDownstream) rec() - else ctx.ignore() - } - } + if (todo < parallelism) tryPull(in) + } + override def onUpstreamFinish(): Unit = { + if (todo == 0) completeStage() + } + }) - override def onUpstreamFinish(ctx: AsyncContext[Out, Notification]) = - if (ctx.isHoldingUpstream || !elemsInFlight.isEmpty) ctx.absorbTermination() - else ctx.finish() + setHandler(out, new OutHandler { + override def onPull(): Unit = pushOne() + }) + } } /** * INTERNAL API */ -private[akka] final case class MapAsyncUnordered[In, Out](parallelism: Int, f: In ⇒ Future[Out], decider: Supervision.Decider) - extends AsyncStage[In, Out, Try[Out]] { +private[akka] final case class MapAsyncUnordered[In, Out](parallelism: Int, f: In ⇒ Future[Out]) + extends GraphStage[FlowShape[In, Out]] { - private var callback: AsyncCallback[Try[Out]] = _ - private var inFlight = 0 - private val buffer = FixedSizeBuffer[Out](parallelism) + private val in = Inlet[In]("in") + private val out = Outlet[Out]("out") - private def todo = inFlight + buffer.used + override val shape = FlowShape(in, out) - override def preStart(ctx: AsyncContext[Out, Try[Out]]): Unit = - callback = ctx.getAsyncCallback + override def createLogic(inheritedAttributes: Attributes) = new GraphStageLogic(shape) { + override def toString = s"MapAsyncUnordered.Logic(inFlight=$inFlight, buffer=$buffer)" - override def decide(ex: Throwable) = decider(ex) + val decider = + inheritedAttributes.getAttribute(classOf[SupervisionStrategy]) + .map(_.decider).getOrElse(Supervision.stoppingDecider) - override def onPush(elem: In, ctx: AsyncContext[Out, Try[Out]]) = { - val future = f(elem) - inFlight += 1 - future.onComplete(callback.invoke)(akka.dispatch.ExecutionContexts.sameThreadExecutionContext) - if (todo == parallelism) ctx.holdUpstream() - else ctx.pull() - } + var inFlight = 0 + val buffer = FixedSizeBuffer[Out](parallelism) + def todo = inFlight + buffer.used - override def onPull(ctx: AsyncContext[Out, Try[Out]]) = - if (buffer.isEmpty) { - if (ctx.isFinishing && inFlight == 0) ctx.finish() else ctx.holdDownstream() - } else { - val elem = buffer.dequeue() - if (ctx.isHoldingUpstream) ctx.pushAndPull(elem) - else ctx.push(elem) - } + def failOrPull(ex: Throwable) = + if (decider(ex) == Supervision.Stop) failStage(ex) + else if (isClosed(in) && todo == 0) completeStage() + else if (!hasBeenPulled(in)) tryPull(in) - override def onAsyncInput(input: Try[Out], ctx: AsyncContext[Out, Try[Out]]) = { - def ignoreOrFail(ex: Throwable) = - if (decider(ex) == Supervision.Stop) ctx.fail(ex) - else if (ctx.isHoldingUpstream) ctx.pull() - else if (ctx.isFinishing && todo == 0) ctx.finish() - else ctx.ignore() - - inFlight -= 1 - input match { - case Failure(ex) ⇒ ignoreOrFail(ex) - case Success(elem) ⇒ - if (elem == null) { - val ex = ReactiveStreamsCompliance.elementMustNotBeNullException - ignoreOrFail(ex) - } else if (ctx.isHoldingDownstream) { - if (ctx.isHoldingUpstream) ctx.pushAndPull(elem) - else ctx.push(elem) - } else { - buffer.enqueue(elem) - ctx.ignore() + val futureCB = + getAsyncCallback((result: Try[Out]) ⇒ { + inFlight -= 1 + result match { + case Failure(ex) ⇒ failOrPull(ex) + case Success(elem) ⇒ + if (elem == null) { + val ex = ReactiveStreamsCompliance.elementMustNotBeNullException + failOrPull(ex) + } else if (isAvailable(out)) { + if (!hasBeenPulled(in)) tryPull(in) + push(out, elem) + } else buffer.enqueue(elem) } - } - } + }).invoke _ - override def onUpstreamFinish(ctx: AsyncContext[Out, Try[Out]]) = - if (todo > 0) ctx.absorbTermination() - else ctx.finish() + setHandler(in, new InHandler { + override def onPush(): Unit = { + try { + val future = f(grab(in)) + inFlight += 1 + future.onComplete(futureCB)(akka.dispatch.ExecutionContexts.sameThreadExecutionContext) + } catch { + case NonFatal(ex) ⇒ + if (decider(ex) == Supervision.Stop) failStage(ex) + } + if (todo < parallelism) tryPull(in) + } + override def onUpstreamFinish(): Unit = { + if (todo == 0) completeStage() + } + }) + + setHandler(out, new OutHandler { + override def onPull(): Unit = { + if (!buffer.isEmpty) push(out, buffer.dequeue()) + else if (isClosed(in) && todo == 0) completeStage() + if (todo < parallelism && !hasBeenPulled(in)) tryPull(in) + } + }) + } } /** @@ -695,7 +696,7 @@ private[akka] final case class Log[T](name: String, extract: T ⇒ Any, logAdapt // TODO more optimisations can be done here - prepare logOnPush function etc override def preStart(ctx: LifecycleContext): Unit = { - logLevels = ctx.attributes.logLevels.getOrElse(DefaultLogLevels) + logLevels = ctx.attributes.get[LogLevels](DefaultLogLevels) log = logAdapter match { case Some(l) ⇒ l case _ ⇒ @@ -787,7 +788,7 @@ private[stream] class GroupedWithin[T](n: Int, d: FiniteDuration) extends GraphS val out = Outlet[immutable.Seq[T]]("out") val shape = FlowShape(in, out) - override def createLogic: GraphStageLogic = new GraphStageLogic(shape) { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) { private val buf: VectorBuilder[T] = new VectorBuilder // True if: // - buf is nonEmpty @@ -855,13 +856,17 @@ private[stream] class GroupedWithin[T](n: Int, d: FiniteDuration) extends GraphS private[stream] class TakeWithin[T](timeout: FiniteDuration) extends SimpleLinearGraphStage[T] { - override def createLogic: GraphStageLogic = new SimpleLinearStageLogic { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) { setHandler(in, new InHandler { override def onPush(): Unit = push(out, grab(in)) override def onUpstreamFinish(): Unit = completeStage() override def onUpstreamFailure(ex: Throwable): Unit = failStage(ex) }) + setHandler(out, new OutHandler { + override def onPull(): Unit = pull(in) + }) + final override protected def onTimer(key: Any): Unit = completeStage() @@ -874,7 +879,7 @@ private[stream] class TakeWithin[T](timeout: FiniteDuration) extends SimpleLinea private[stream] class DropWithin[T](timeout: FiniteDuration) extends SimpleLinearGraphStage[T] { private var allow = false - override def createLogic: GraphStageLogic = new SimpleLinearStageLogic { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) { setHandler(in, new InHandler { override def onPush(): Unit = if (allow) push(out, grab(in)) @@ -883,6 +888,10 @@ private[stream] class DropWithin[T](timeout: FiniteDuration) extends SimpleLinea override def onUpstreamFailure(ex: Throwable): Unit = failStage(ex) }) + setHandler(out, new OutHandler { + override def onPull(): Unit = pull(in) + }) + final override protected def onTimer(key: Any): Unit = allow = true override def preStart(): Unit = scheduleOnce("DropWithinTimer", timeout) diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/IOSettings.scala b/akka-stream/src/main/scala/akka/stream/impl/io/IOSettings.scala index 2e6da93a1d..51dedd0622 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/IOSettings.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/IOSettings.scala @@ -1,15 +1,12 @@ package akka.stream.impl.io -import akka.stream.ActorAttributes.Dispatcher -import akka.stream.{ ActorMaterializer, MaterializationContext } +import akka.stream.ActorAttributes +import akka.stream.Attributes private[stream] object IOSettings { - /** Picks default akka.stream.blocking-io-dispatcher or the Attributes configured one */ - def blockingIoDispatcher(context: MaterializationContext): String = { - val mat = ActorMaterializer.downcast(context.materializer) - context.effectiveAttributes.attributeList.collectFirst { case d: Dispatcher ⇒ d.dispatcher } getOrElse { - mat.system.settings.config.getString("akka.stream.blocking-io-dispatcher") - } - } -} \ No newline at end of file + final val SyncFileSourceDefaultChunkSize = 8192 + final val SyncFileSourceName = Attributes.name("synchronousFileSource") + final val SyncFileSinkName = Attributes.name("synchronousFileSink") + final val IODispatcher = ActorAttributes.Dispatcher("akka.stream.default-blocking-io-dispatcher") +} diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/IOSinks.scala b/akka-stream/src/main/scala/akka/stream/impl/io/IOSinks.scala index 4573934055..313d552619 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/IOSinks.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/IOSinks.scala @@ -4,13 +4,11 @@ package akka.stream.impl.io import java.io.{ File, OutputStream } -import java.lang.{ Long ⇒ JLong } - -import akka.stream._ import akka.stream.impl.SinkModule import akka.stream.impl.StreamLayout.Module +import akka.stream.{ ActorMaterializer, MaterializationContext, Attributes, SinkShape } +import akka.stream.ActorAttributes.Dispatcher import akka.util.ByteString - import scala.concurrent.{ Future, Promise } /** @@ -27,7 +25,7 @@ private[akka] final class SynchronousFileSink(f: File, append: Boolean, val attr val bytesWrittenPromise = Promise[Long]() val props = SynchronousFileSubscriber.props(f, bytesWrittenPromise, settings.maxInputBufferSize, append) - val dispatcher = IOSettings.blockingIoDispatcher(context) + val dispatcher = context.effectiveAttributes.get[Dispatcher](IOSettings.IODispatcher).dispatcher val ref = mat.actorOf(context, props.withDispatcher(dispatcher)) (akka.stream.actor.ActorSubscriber[ByteString](ref), bytesWrittenPromise.future) diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/IOSources.scala b/akka-stream/src/main/scala/akka/stream/impl/io/IOSources.scala index fa2cb7b7d2..eef2ac9e16 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/IOSources.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/IOSources.scala @@ -10,6 +10,7 @@ import java.util.concurrent.{ LinkedBlockingQueue, BlockingQueue } import akka.actor.{ ActorRef, Deploy } import akka.japi import akka.stream._ +import akka.stream.ActorAttributes.Dispatcher import akka.stream.impl.StreamLayout.Module import akka.stream.impl.{ ErrorPublisher, SourceModule } import akka.stream.scaladsl.{ Source, FlowGraph } @@ -27,13 +28,13 @@ import scala.util.control.NonFatal private[akka] final class SynchronousFileSource(f: File, chunkSize: Int, val attributes: Attributes, shape: SourceShape[ByteString]) extends SourceModule[ByteString, Future[Long]](shape) { override def create(context: MaterializationContext) = { - // FIXME rewrite to be based on AsyncStage rather than dangerous downcasts + // FIXME rewrite to be based on GraphStage rather than dangerous downcasts val mat = ActorMaterializer.downcast(context.materializer) val settings = mat.effectiveSettings(context.effectiveAttributes) val bytesReadPromise = Promise[Long]() val props = SynchronousFilePublisher.props(f, bytesReadPromise, chunkSize, settings.initialInputBufferSize, settings.maxInputBufferSize) - val dispatcher = IOSettings.blockingIoDispatcher(context) + val dispatcher = context.effectiveAttributes.get[Dispatcher](IOSettings.IODispatcher).dispatcher val ref = mat.actorOf(context, props.withDispatcher(dispatcher)) diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/InputStreamSinkStage.scala b/akka-stream/src/main/scala/akka/stream/impl/io/InputStreamSinkStage.scala index aec1109411..a049fd38ea 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/InputStreamSinkStage.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/InputStreamSinkStage.scala @@ -5,15 +5,14 @@ package akka.stream.impl.io import java.io.{ IOException, InputStream } import java.util.concurrent.{ BlockingQueue, LinkedBlockingDeque, TimeUnit } - import akka.stream.Attributes.InputBuffer import akka.stream.impl.io.InputStreamSinkStage._ import akka.stream.stage._ import akka.util.ByteString - import scala.annotation.tailrec import scala.concurrent.Future import scala.concurrent.duration.FiniteDuration +import akka.stream.Attributes private[akka] object InputStreamSinkStage { @@ -38,7 +37,7 @@ private[akka] class InputStreamSinkStage(timeout: FiniteDuration) extends SinkSt val maxBuffer = module.attributes.getAttribute(classOf[InputBuffer], InputBuffer(16, 16)).max require(maxBuffer > 0, "Buffer size must be greater than 0") - override def createLogicAndMaterializedValue: (GraphStageLogic, InputStream) = { + override def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, InputStream) = { val dataQueue = new LinkedBlockingDeque[StreamToAdapterMessage](maxBuffer + 1) var pullRequestIsSent = true diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/OutputStreamSourceStage.scala b/akka-stream/src/main/scala/akka/stream/impl/io/OutputStreamSourceStage.scala index 29f8e7dedb..6519c7f3c3 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/OutputStreamSourceStage.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/OutputStreamSourceStage.scala @@ -7,6 +7,7 @@ import java.io.{ IOException, OutputStream } import java.util.concurrent.atomic.AtomicReference import java.util.concurrent.{ BlockingQueue, LinkedBlockingQueue } +import akka.stream.Attributes import akka.stream.Attributes.InputBuffer import akka.stream.impl.io.OutputStreamSourceStage._ import akka.stream.stage._ @@ -35,7 +36,7 @@ private[akka] class OutputStreamSourceStage(timeout: FiniteDuration) extends Sou val maxBuffer = module.attributes.getAttribute(classOf[InputBuffer], InputBuffer(16, 16)).max require(maxBuffer > 0, "Buffer size must be greater than 0") - override def createLogicAndMaterializedValue: (GraphStageLogic, OutputStream) = { + override def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, OutputStream) = { val dataQueue = new LinkedBlockingQueue[ByteString](maxBuffer) var flush: Option[Promise[Unit]] = None diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/TcpListenStreamActor.scala b/akka-stream/src/main/scala/akka/stream/impl/io/TcpListenStreamActor.scala index 8af8d3a578..3c25bc79e6 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/TcpListenStreamActor.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/TcpListenStreamActor.scala @@ -156,11 +156,11 @@ private[akka] class TcpListenStreamActor(localAddressPromise: Promise[InetSocket val tcpStreamActor = context.watch(context.actorOf(TcpStreamActor.inboundProps(connection, halfClose, settings))) val processor = ActorProcessor[ByteString, ByteString](tcpStreamActor) - import scala.concurrent.duration._ + import scala.concurrent.duration.FiniteDuration val handler = (idleTimeout match { case d: FiniteDuration ⇒ Flow[ByteString].join(Timeouts.idleTimeoutBidi[ByteString, ByteString](d)) case _ ⇒ Flow[ByteString] - }).andThenMat(() ⇒ (processor, ())) + }).via(Flow.fromProcessor(() ⇒ processor)) val conn = StreamTcp.IncomingConnection( connected.localAddress, diff --git a/akka-stream/src/main/scala/akka/stream/io/ByteStringParser.scala b/akka-stream/src/main/scala/akka/stream/io/ByteStringParser.scala new file mode 100644 index 0000000000..c442eefcae --- /dev/null +++ b/akka-stream/src/main/scala/akka/stream/io/ByteStringParser.scala @@ -0,0 +1,119 @@ +/** + * Copyright (C) 2015 Typesafe Inc. + */ +package akka.stream.io + +import scala.util.control.NoStackTrace +import akka.stream._ +import akka.stream.stage._ +import akka.util.ByteString +import scala.annotation.tailrec + +abstract class ByteStringParser[T] extends GraphStage[FlowShape[ByteString, T]] { + import ByteStringParser._ + + private val bytesIn = Inlet[ByteString]("bytesIn") + private val objOut = Outlet[T]("objOut") + + final override val shape = FlowShape(bytesIn, objOut) + + class ParsingLogic extends GraphStageLogic(shape) { + override def preStart(): Unit = pull(bytesIn) + setHandler(objOut, eagerTerminateOutput) + + private var buffer = ByteString.empty + private var current: ParseStep[T] = FinishedParser + + final protected def startWith(step: ParseStep[T]): Unit = current = step + + @tailrec private def doParse(): Unit = + if (buffer.nonEmpty) { + val cont = try { + val reader = new ByteReader(buffer) + val (elem, next) = current.parse(reader) + emit(objOut, elem) + if (next == FinishedParser) { + completeStage() + false + } else { + buffer = reader.remainingData + current = next + true + } + } catch { + case NeedMoreData ⇒ + pull(bytesIn) + false + } + if (cont) doParse() + } else pull(bytesIn) + + setHandler(bytesIn, new InHandler { + override def onPush(): Unit = { + buffer ++= grab(bytesIn) + doParse() + } + override def onUpstreamFinish(): Unit = + if (buffer.isEmpty) completeStage() + else current.onTruncation() + }) + } +} + +object ByteStringParser { + + trait ParseStep[+T] { + def parse(reader: ByteReader): (T, ParseStep[T]) + def onTruncation(): Unit = throw new IllegalStateException("truncated data in ByteStringParser") + } + + object FinishedParser extends ParseStep[Nothing] { + def parse(reader: ByteReader) = + throw new IllegalStateException("no initial parser installed: you must use startWith(...)") + } + + val NeedMoreData = new Exception with NoStackTrace + + class ByteReader(input: ByteString) { + + private[this] var off = 0 + + def hasRemaining: Boolean = off < input.size + def remainingSize: Int = input.size - off + + def currentOffset: Int = off + def remainingData: ByteString = input.drop(off) + def fromStartToHere: ByteString = input.take(off) + + def take(n: Int): ByteString = + if (off + n <= input.length) { + val o = off + off = o + n + input.slice(o, off) + } else throw NeedMoreData + def takeAll(): ByteString = { + val ret = remainingData + off = input.size + ret + } + + def readByte(): Int = + if (off < input.length) { + val x = input(off) + off += 1 + x & 0xFF + } else throw NeedMoreData + def readShortLE(): Int = readByte() | (readByte() << 8) + def readIntLE(): Int = readShortLE() | (readShortLE() << 16) + def readLongLE(): Long = (readIntLE() & 0xffffffffL) | ((readIntLE() & 0xffffffffL) << 32) + + def readShortBE(): Int = (readByte() << 8) | readByte() + def readIntBE(): Int = (readShortBE() << 16) | readShortBE() + def readLongBE(): Long = ((readIntBE() & 0xffffffffL) << 32) | (readIntBE() & 0xffffffffL) + + def skip(numBytes: Int): Unit = + if (off + numBytes <= input.length) off += numBytes + else throw NeedMoreData + def skipZeroTerminatedString(): Unit = while (readByte() != 0) {} + } +} \ No newline at end of file diff --git a/akka-stream/src/main/scala/akka/stream/io/SynchronousFileSink.scala b/akka-stream/src/main/scala/akka/stream/io/SynchronousFileSink.scala index f5fabdbec6..797348c3b8 100644 --- a/akka-stream/src/main/scala/akka/stream/io/SynchronousFileSink.scala +++ b/akka-stream/src/main/scala/akka/stream/io/SynchronousFileSink.scala @@ -5,7 +5,6 @@ package akka.stream.io import java.io.File -import akka.stream.impl.io.SynchronousFileSink import akka.stream.{ Attributes, javadsl, ActorAttributes } import akka.stream.scaladsl.Sink import akka.util.ByteString @@ -16,8 +15,10 @@ import scala.concurrent.Future * Sink which writes incoming [[ByteString]]s to the given file */ object SynchronousFileSink { + import akka.stream.impl.io.IOSettings._ + import akka.stream.impl.io.SynchronousFileSink - final val DefaultAttributes = Attributes.name("synchronousFileSink") + final val DefaultAttributes = SyncFileSinkName and IODispatcher /** * Synchronous (Java 6 compatible) Sink that writes incoming [[ByteString]] elements to the given file. diff --git a/akka-stream/src/main/scala/akka/stream/io/SynchronousFileSource.scala b/akka-stream/src/main/scala/akka/stream/io/SynchronousFileSource.scala index f038595952..6ad5d763b8 100644 --- a/akka-stream/src/main/scala/akka/stream/io/SynchronousFileSource.scala +++ b/akka-stream/src/main/scala/akka/stream/io/SynchronousFileSource.scala @@ -10,9 +10,10 @@ import akka.util.ByteString import scala.concurrent.Future object SynchronousFileSource { + import akka.stream.impl.io.IOSettings._ import akka.stream.impl.io.SynchronousFileSource - final val DefaultChunkSize = 8192 - final val DefaultAttributes = Attributes.name("synchronousFileSource") + + final val DefaultAttributes = SyncFileSourceName and IODispatcher /** * Creates a synchronous (Java 6 compatible) Source from a Files contents. @@ -24,7 +25,7 @@ object SynchronousFileSource { * * It materializes a [[Future]] containing the number of bytes read from the source file upon completion. */ - def apply(f: File, chunkSize: Int = DefaultChunkSize): Source[ByteString, Future[Long]] = + def apply(f: File, chunkSize: Int = SyncFileSourceDefaultChunkSize): Source[ByteString, Future[Long]] = new Source(new SynchronousFileSource(f, chunkSize, DefaultAttributes, Source.shape("SynchronousFileSource")).nest()) // TO DISCUSS: I had to add wrap() here to make the name available /** @@ -37,7 +38,7 @@ object SynchronousFileSource { * * It materializes a [[Future]] containing the number of bytes read from the source file upon completion. */ - def create(f: File): javadsl.Source[ByteString, Future[java.lang.Long]] = create(f, DefaultChunkSize) + def create(f: File): javadsl.Source[ByteString, Future[java.lang.Long]] = create(f, SyncFileSourceDefaultChunkSize) /** * Creates a synchronous (Java 6 compatible) Source from a Files contents. diff --git a/akka-stream/src/main/scala/akka/stream/io/Timeouts.scala b/akka-stream/src/main/scala/akka/stream/io/Timeouts.scala index 52e1c69a7d..fe01877147 100644 --- a/akka-stream/src/main/scala/akka/stream/io/Timeouts.scala +++ b/akka-stream/src/main/scala/akka/stream/io/Timeouts.scala @@ -5,7 +5,7 @@ import java.util.concurrent.{ TimeUnit, TimeoutException } import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage import akka.stream.scaladsl.{ BidiFlow, Flow } import akka.stream.stage._ -import akka.stream.{ BidiShape, Inlet, Outlet } +import akka.stream.{ BidiShape, Inlet, Outlet, Attributes } import scala.concurrent.duration.{ Deadline, FiniteDuration } @@ -62,7 +62,7 @@ object Timeouts { private class InitialTimeout[T](timeout: FiniteDuration) extends SimpleLinearGraphStage[T] { - override def createLogic: GraphStageLogic = new SimpleLinearStageLogic { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) { private var initialHasPassed = false setHandler(in, new InHandler { @@ -72,6 +72,10 @@ object Timeouts { } }) + setHandler(out, new OutHandler { + override def onPull(): Unit = pull(in) + }) + final override protected def onTimer(key: Any): Unit = if (!initialHasPassed) failStage(new TimeoutException(s"The first element has not yet passed through in $timeout.")) @@ -84,11 +88,15 @@ object Timeouts { private class CompletionTimeout[T](timeout: FiniteDuration) extends SimpleLinearGraphStage[T] { - override def createLogic: GraphStageLogic = new SimpleLinearStageLogic { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) { setHandler(in, new InHandler { override def onPush(): Unit = push(out, grab(in)) }) + setHandler(out, new OutHandler { + override def onPull(): Unit = pull(in) + }) + final override protected def onTimer(key: Any): Unit = failStage(new TimeoutException(s"The stream has not been completed in $timeout.")) @@ -101,7 +109,7 @@ object Timeouts { private class IdleTimeout[T](timeout: FiniteDuration) extends SimpleLinearGraphStage[T] { private var nextDeadline: Deadline = Deadline.now + timeout - override def createLogic: GraphStageLogic = new SimpleLinearStageLogic { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) { setHandler(in, new InHandler { override def onPush(): Unit = { nextDeadline = Deadline.now + timeout @@ -109,6 +117,10 @@ object Timeouts { } }) + setHandler(out, new OutHandler { + override def onPull(): Unit = pull(in) + }) + final override protected def onTimer(key: Any): Unit = if (nextDeadline.isOverdue()) failStage(new TimeoutException(s"No elements passed in the last $timeout.")) @@ -128,7 +140,7 @@ object Timeouts { override def toString = "IdleTimeoutBidi" - override def createLogic: GraphStageLogic = new GraphStageLogic(shape) { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) { private var nextDeadline: Deadline = Deadline.now + timeout setHandler(in1, new InHandler { diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala b/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala index 50dbf02e50..48fd37535e 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala @@ -22,8 +22,14 @@ object Flow { /** Create a `Flow` which can process elements of type `T`. */ def create[T](): javadsl.Flow[T, T, Unit] = fromGraph(scaladsl.Flow[T]) - def create[I, O](processorFactory: function.Creator[Processor[I, O]]): javadsl.Flow[I, O, Unit] = - new Flow(scaladsl.Flow(() ⇒ processorFactory.create())) + def fromProcessor[I, O](processorFactory: function.Creator[Processor[I, O]]): javadsl.Flow[I, O, Unit] = + new Flow(scaladsl.Flow.fromProcessor(() ⇒ processorFactory.create())) + + def fromProcessorMat[I, O, Mat](processorFactory: function.Creator[Pair[Processor[I, O], Mat]]): javadsl.Flow[I, O, Mat] = + new Flow(scaladsl.Flow.fromProcessorMat { () ⇒ + val javaPair = processorFactory.create() + (javaPair.first, javaPair.second) + }) /** Create a `Flow` which can process elements of type `T`. */ def of[T](clazz: Class[T]): javadsl.Flow[T, T, Unit] = create[T]() diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala index 7ba5ee4736..5f099809bc 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala @@ -7,10 +7,12 @@ import akka.event.LoggingAdapter import akka.stream.Attributes._ import akka.stream._ import akka.stream.impl.SplitDecision._ -import akka.stream.impl.Stages.{ DirectProcessor, MaterializingStageFactory, StageModule } +import akka.stream.impl.Stages.{ SymbolicGraphStage, StageModule, DirectProcessor, SymbolicStage } import akka.stream.impl.StreamLayout.{ EmptyModule, Module } -import akka.stream.impl.fusing.{ DropWithin, GroupedWithin, TakeWithin } +import akka.stream.impl.fusing.{ DropWithin, GroupedWithin, TakeWithin, MapAsync, MapAsyncUnordered } import akka.stream.impl.{ ReactiveStreamsCompliance, ConstantFun, Stages, StreamLayout } +import akka.stream.impl.{ Stages, StreamLayout } +import akka.stream.stage.AbstractStage.{ PushPullGraphStageWithMaterializedValue, PushPullGraphStage } import akka.stream.stage._ import org.reactivestreams.{ Processor, Publisher, Subscriber, Subscription } @@ -30,7 +32,7 @@ final class Flow[-In, +Out, +Mat](private[stream] override val module: Module) override type Repr[+O, +M] = Flow[In @uncheckedVariance, O, M] - private[stream] def isIdentity: Boolean = this.module.isInstanceOf[Stages.Identity] + private[stream] def isIdentity: Boolean = this.module eq Stages.identityGraph.module def viaMat[T, Mat2, Mat3](flow: Graph[FlowShape[Out, T], Mat2])(combine: (Mat, Mat2) ⇒ Mat3): Flow[In, T, Mat3] = { if (this.isIdentity) { @@ -182,19 +184,15 @@ final class Flow[-In, +Out, +Mat](private[stream] override val module: Module) } /** INTERNAL API */ - override private[stream] def andThen[U](op: StageModule): Repr[U, Mat] = { + // FIXME: Only exists to keep old stuff alive + private[stream] override def deprecatedAndThen[U](op: StageModule): Repr[U, Mat] = { //No need to copy here, op is a fresh instance - if (op.isInstanceOf[Stages.Identity]) this.asInstanceOf[Repr[U, Mat]] - else if (this.isIdentity) new Flow(op).asInstanceOf[Repr[U, Mat]] + if (this.isIdentity) new Flow(op).asInstanceOf[Repr[U, Mat]] else new Flow(module.fuse(op, shape.outlet, op.inPort).replaceShape(FlowShape(shape.inlet, op.outPort))) } - private[stream] def andThenMat[U, Mat2](op: MaterializingStageFactory): Repr[U, Mat2] = { - if (this.isIdentity) new Flow(op).asInstanceOf[Repr[U, Mat2]] - else new Flow(module.fuse(op, shape.outlet, op.inPort, Keep.right).replaceShape(FlowShape(shape.inlet, op.outPort))) - } - - private[akka] def andThenMat[U, Mat2, O >: Out](processorFactory: () ⇒ (Processor[O, U], Mat2)): Repr[U, Mat2] = { + // FIXME: Only exists to keep old stuff alive + private[akka] def deprecatedAndThenMat[U, Mat2, O >: Out](processorFactory: () ⇒ (Processor[O, U], Mat2)): Repr[U, Mat2] = { val op = Stages.DirectProcessor(processorFactory.asInstanceOf[() ⇒ (Processor[Any, Any], Any)]) if (this.isIdentity) new Flow(op).asInstanceOf[Repr[U, Mat2]] else new Flow[In, U, Mat2](module.fuse(op, shape.outlet, op.inPort, Keep.right).replaceShape(FlowShape(shape.inlet, op.outPort))) @@ -244,14 +242,20 @@ final class Flow[-In, +Out, +Mat](private[stream] override val module: Module) } object Flow { - private[this] val identity: Flow[Any, Any, Unit] = new Flow[Any, Any, Unit](Stages.Identity()) + private[this] val identity: Flow[Any, Any, Unit] = new Flow[Any, Any, Unit](SymbolicGraphStage(Stages.Identity).module) /** * Creates a Flow from a Reactive Streams [[org.reactivestreams.Processor]] */ - def apply[I, O](processorFactory: () ⇒ Processor[I, O]): Flow[I, O, Unit] = { - val untypedFactory = processorFactory.asInstanceOf[() ⇒ Processor[Any, Any]] - Flow[I].andThen(DirectProcessor(() ⇒ (untypedFactory(), ()))) + def fromProcessor[I, O](processorFactory: () ⇒ Processor[I, O]): Flow[I, O, Unit] = { + fromProcessorMat(() ⇒ (processorFactory(), ())) + } + + /** + * Creates a Flow from a Reactive Streams [[org.reactivestreams.Processor]] and returns a materialized value. + */ + def fromProcessorMat[I, O, Mat](processorFactory: () ⇒ (Processor[I, O], Mat)): Flow[I, O, Mat] = { + Flow[I].deprecatedAndThenMat(processorFactory) } /** @@ -377,7 +381,7 @@ trait FlowOps[+Out, +Mat] { * '''Cancels when''' downstream cancels * */ - def recover[T >: Out](pf: PartialFunction[Throwable, T]): Repr[T, Mat] = andThen(Recover(pf.asInstanceOf[PartialFunction[Any, Any]])) + def recover[T >: Out](pf: PartialFunction[Throwable, T]): Repr[T, Mat] = andThen(Recover(pf)) /** * Transform this stream by applying the given function to each of the elements @@ -392,7 +396,7 @@ trait FlowOps[+Out, +Mat] { * '''Cancels when''' downstream cancels * */ - def map[T](f: Out ⇒ T): Repr[T, Mat] = andThen(Map(f.asInstanceOf[Any ⇒ Any])) + def map[T](f: Out ⇒ T): Repr[T, Mat] = andThen(Map(f)) /** * Transform each input element into an `Iterable` of output elements that is @@ -412,7 +416,7 @@ trait FlowOps[+Out, +Mat] { * '''Cancels when''' downstream cancels * */ - def mapConcat[T](f: Out ⇒ immutable.Iterable[T]): Repr[T, Mat] = andThen(MapConcat(f.asInstanceOf[Any ⇒ immutable.Iterable[Any]])) + def mapConcat[T](f: Out ⇒ immutable.Iterable[T]): Repr[T, Mat] = andThen(MapConcat(f)) /** * Transform this stream by applying the given function to each of the elements @@ -443,8 +447,7 @@ trait FlowOps[+Out, +Mat] { * * @see [[#mapAsyncUnordered]] */ - def mapAsync[T](parallelism: Int)(f: Out ⇒ Future[T]): Repr[T, Mat] = - andThen(MapAsync(parallelism, f.asInstanceOf[Any ⇒ Future[Any]])) + def mapAsync[T](parallelism: Int)(f: Out ⇒ Future[T]): Repr[T, Mat] = via(MapAsync(parallelism, f)) /** * Transform this stream by applying the given function to each of the elements @@ -475,8 +478,7 @@ trait FlowOps[+Out, +Mat] { * * @see [[#mapAsync]] */ - def mapAsyncUnordered[T](parallelism: Int)(f: Out ⇒ Future[T]): Repr[T, Mat] = - andThen(MapAsyncUnordered(parallelism, f.asInstanceOf[Any ⇒ Future[Any]])) + def mapAsyncUnordered[T](parallelism: Int)(f: Out ⇒ Future[T]): Repr[T, Mat] = via(MapAsyncUnordered(parallelism, f)) /** * Only pass on those elements that satisfy the given predicate. @@ -489,7 +491,7 @@ trait FlowOps[+Out, +Mat] { * * '''Cancels when''' downstream cancels */ - def filter(p: Out ⇒ Boolean): Repr[Out, Mat] = andThen(Filter(p.asInstanceOf[Any ⇒ Boolean])) + def filter(p: Out ⇒ Boolean): Repr[Out, Mat] = andThen(Filter(p)) /** * Only pass on those elements that NOT satisfy the given predicate. @@ -522,7 +524,7 @@ trait FlowOps[+Out, +Mat] { * * '''Cancels when''' predicate returned false or downstream cancels */ - def takeWhile(p: Out ⇒ Boolean): Repr[Out, Mat] = andThen(TakeWhile(p.asInstanceOf[Any ⇒ Boolean])) + def takeWhile(p: Out ⇒ Boolean): Repr[Out, Mat] = andThen(TakeWhile(p)) /** * Discard elements at the beginning of the stream while predicate is true. @@ -536,7 +538,7 @@ trait FlowOps[+Out, +Mat] { * * '''Cancels when''' downstream cancels */ - def dropWhile(p: Out ⇒ Boolean): Repr[Out, Mat] = andThen(DropWhile(p.asInstanceOf[Any ⇒ Boolean])) + def dropWhile(p: Out ⇒ Boolean): Repr[Out, Mat] = andThen(DropWhile(p)) /** * Transform this stream by applying the given partial function to each of the elements @@ -551,7 +553,7 @@ trait FlowOps[+Out, +Mat] { * * '''Cancels when''' downstream cancels */ - def collect[T](pf: PartialFunction[Out, T]): Repr[T, Mat] = andThen(Collect(pf.asInstanceOf[PartialFunction[Any, Any]])) + def collect[T](pf: PartialFunction[Out, T]): Repr[T, Mat] = andThen(Collect(pf)) /** * Chunk up this stream into groups of the given size, with the last group @@ -604,7 +606,7 @@ trait FlowOps[+Out, +Mat] { * * '''Cancels when''' downstream cancels */ - def scan[T](zero: T)(f: (T, Out) ⇒ T): Repr[T, Mat] = andThen(Scan(zero, f.asInstanceOf[(Any, Any) ⇒ Any])) + def scan[T](zero: T)(f: (T, Out) ⇒ T): Repr[T, Mat] = andThen(Scan(zero, f)) /** * Similar to `scan` but only emits its result when the upstream completes, @@ -623,7 +625,7 @@ trait FlowOps[+Out, +Mat] { * * '''Cancels when''' downstream cancels */ - def fold[T](zero: T)(f: (T, Out) ⇒ T): Repr[T, Mat] = andThen(Fold(zero, f.asInstanceOf[(Any, Any) ⇒ Any])) + def fold[T](zero: T)(f: (T, Out) ⇒ T): Repr[T, Mat] = andThen(Fold(zero, f)) /** * Intersperses stream with provided element, similar to how [[scala.collection.immutable.List.mkString]] @@ -778,8 +780,7 @@ trait FlowOps[+Out, +Mat] { * * '''Cancels when''' downstream cancels or timer fires */ - def takeWithin(d: FiniteDuration): Repr[Out, Mat] = - via(new TakeWithin[Out](d).withAttributes(name("takeWithin"))) + def takeWithin(d: FiniteDuration): Repr[Out, Mat] = via(new TakeWithin[Out](d).withAttributes(name("takeWithin"))) /** * Allows a faster upstream to progress independently of a slower subscriber by conflating elements into a summary @@ -800,8 +801,7 @@ trait FlowOps[+Out, +Mat] { * @param seed Provides the first state for a conflated value using the first unconsumed element as a start * @param aggregate Takes the currently aggregated value and the current pending element to produce a new aggregate */ - def conflate[S](seed: Out ⇒ S)(aggregate: (S, Out) ⇒ S): Repr[S, Mat] = - andThen(Conflate(seed.asInstanceOf[Any ⇒ Any], aggregate.asInstanceOf[(Any, Any) ⇒ Any])) + def conflate[S](seed: Out ⇒ S)(aggregate: (S, Out) ⇒ S): Repr[S, Mat] = andThen(Conflate(seed, aggregate)) /** * Allows a faster downstream to progress independently of a slower publisher by extrapolating elements from an older @@ -827,8 +827,7 @@ trait FlowOps[+Out, +Mat] { * @param extrapolate Takes the current extrapolation state to produce an output element and the next extrapolation * state. */ - def expand[S, U](seed: Out ⇒ S)(extrapolate: S ⇒ (U, S)): Repr[U, Mat] = - andThen(Expand(seed.asInstanceOf[Any ⇒ Any], extrapolate.asInstanceOf[Any ⇒ (Any, Any)])) + def expand[S, U](seed: Out ⇒ S)(extrapolate: S ⇒ (U, S)): Repr[U, Mat] = andThen(Expand(seed, extrapolate)) /** * Adds a fixed size buffer in the flow that allows to store elements from a faster upstream until it becomes full. @@ -849,8 +848,7 @@ trait FlowOps[+Out, +Mat] { * @param size The size of the buffer in element count * @param overflowStrategy Strategy that is used when incoming elements cannot fit inside the buffer */ - def buffer(size: Int, overflowStrategy: OverflowStrategy): Repr[Out, Mat] = - andThen(Buffer(size, overflowStrategy)) + def buffer(size: Int, overflowStrategy: OverflowStrategy): Repr[Out, Mat] = andThen(Buffer(size, overflowStrategy)) /** * Generic transformation of a stream with a custom processing [[akka.stream.stage.Stage]]. @@ -858,10 +856,10 @@ trait FlowOps[+Out, +Mat] { * operator that performs the transformation. */ def transform[T](mkStage: () ⇒ Stage[Out, T]): Repr[T, Mat] = - andThen(StageFactory(mkStage)) + via(new PushPullGraphStage((attr) ⇒ mkStage(), Attributes.none)) private[akka] def transformMaterializing[T, M](mkStageAndMaterialized: () ⇒ (Stage[Out, T], M)): Repr[T, M] = - andThenMat(MaterializingStageFactory(mkStageAndMaterialized)) + viaMat(new PushPullGraphStageWithMaterializedValue[Out, T, Unit, M]((attr) ⇒ mkStageAndMaterialized(), Attributes.none))(Keep.right) /** * Takes up to `n` elements from the stream (less than `n` only if the upstream completes before emitting `n` elements) @@ -886,7 +884,7 @@ trait FlowOps[+Out, +Mat] { * */ def prefixAndTail[U >: Out](n: Int): Repr[(immutable.Seq[Out], Source[U, Unit]), Mat] = - andThen(PrefixAndTail(n)) + deprecatedAndThen(PrefixAndTail(n)) /** * This operation demultiplexes the incoming stream into separate output @@ -918,7 +916,7 @@ trait FlowOps[+Out, +Mat] { * */ def groupBy[K, U >: Out](f: Out ⇒ K): Repr[(K, Source[U, Unit]), Mat] = - andThen(GroupBy(f.asInstanceOf[Any ⇒ Any])) + deprecatedAndThen(GroupBy(f.asInstanceOf[Any ⇒ Any])) /** * This operation applies the given predicate to all incoming elements and @@ -962,7 +960,7 @@ trait FlowOps[+Out, +Mat] { */ def splitWhen[U >: Out](p: Out ⇒ Boolean): Repr[Out, Mat]#Repr[Source[U, Unit], Mat] = { val f = p.asInstanceOf[Any ⇒ Boolean] - withAttributes(name("splitWhen")).andThen(Split(el ⇒ if (f(el)) SplitBefore else Continue)) + withAttributes(name("splitWhen")).deprecatedAndThen(Split(el ⇒ if (f(el)) SplitBefore else Continue)) } /** @@ -999,7 +997,7 @@ trait FlowOps[+Out, +Mat] { */ def splitAfter[U >: Out](p: Out ⇒ Boolean): Repr[Out, Mat]#Repr[Source[U, Unit], Mat] = { val f = p.asInstanceOf[Any ⇒ Boolean] - withAttributes(name("splitAfter")).andThen(Split(el ⇒ if (f(el)) SplitAfter else Continue)) + withAttributes(name("splitAfter")).deprecatedAndThen(Split(el ⇒ if (f(el)) SplitAfter else Continue)) } /** @@ -1016,7 +1014,7 @@ trait FlowOps[+Out, +Mat] { * */ def flatten[U](strategy: FlattenStrategy[Out, U]): Repr[U, Mat] = strategy match { - case scaladsl.FlattenStrategy.Concat | javadsl.FlattenStrategy.Concat ⇒ andThen(ConcatAll()) + case scaladsl.FlattenStrategy.Concat | javadsl.FlattenStrategy.Concat ⇒ deprecatedAndThen(ConcatAll()) case _ ⇒ throw new IllegalArgumentException(s"Unsupported flattening strategy [${strategy.getClass.getName}]") } @@ -1211,8 +1209,9 @@ trait FlowOps[+Out, +Mat] { def withAttributes(attr: Attributes): Repr[Out, Mat] /** INTERNAL API */ - private[scaladsl] def andThen[U](op: StageModule): Repr[U, Mat] + private[scaladsl] def andThen[T](op: SymbolicStage[Out, T]): Repr[T, Mat] = + via(SymbolicGraphStage(op)) - private[scaladsl] def andThenMat[U, Mat2](op: MaterializingStageFactory): Repr[U, Mat2] + private[scaladsl] def deprecatedAndThen[U](op: StageModule): Repr[U, Mat] } diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala index 2f557e88d4..71a46d4893 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Graph.scala @@ -3,11 +3,10 @@ */ package akka.stream.scaladsl -import akka.stream.impl.Stages.{ MaterializingStageFactory, StageModule } +import akka.stream.impl.Stages.{ StageModule, SymbolicStage } import akka.stream.impl._ import akka.stream.impl.StreamLayout._ import akka.stream._ -import Attributes.name import akka.stream.stage.{ OutHandler, InHandler, GraphStageLogic, GraphStage } import scala.annotation.unchecked.uncheckedVariance import scala.annotation.tailrec @@ -41,7 +40,7 @@ class Merge[T] private (val inputPorts: Int, val eagerClose: Boolean) extends Gr val out: Outlet[T] = Outlet[T]("Merge.out") override val shape: UniformFanInShape[T, T] = UniformFanInShape(out, in: _*) - override def createLogic: GraphStageLogic = new GraphStageLogic(shape) { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { private var initialized = false private val pendingQueue = Array.ofDim[Inlet[T]](inputPorts) @@ -133,7 +132,7 @@ object MergePreferred { * * '''Backpressures when''' downstream backpressures * - * '''Completes when''' all upstreams complete + * '''Completes when''' all upstreams complete (eagerClose=false) or one upstream completes (eagerClose=true) * * '''Cancels when''' downstream cancels * @@ -147,91 +146,82 @@ class MergePreferred[T] private (val secondaryPorts: Int, val eagerClose: Boolea def out: Outlet[T] = shape.out def preferred: Inlet[T] = shape.preferred - // FIXME: Factor out common stuff with Merge - override def createLogic: GraphStageLogic = new GraphStageLogic(shape) { - private var initialized = false - - private val pendingQueue = Array.ofDim[Inlet[T]](secondaryPorts) - private var pendingHead = 0 - private var pendingTail = 0 - - private var runningUpstreams = secondaryPorts + 1 - private def upstreamsClosed = runningUpstreams == 0 - - private def pending: Boolean = pendingHead != pendingTail - private def priority: Boolean = isAvailable(preferred) - - private def enqueue(in: Inlet[T]): Unit = { - pendingQueue(pendingTail % secondaryPorts) = in - pendingTail += 1 + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { + var openInputs = secondaryPorts + 1 + def onComplete(): Unit = { + openInputs -= 1 + if (eagerClose || openInputs == 0) completeStage() } - private def dequeueAndDispatch(): Unit = { - val in = pendingQueue(pendingHead % secondaryPorts) - pendingHead += 1 - push(out, grab(in)) - if (upstreamsClosed && !pending && !priority) completeStage() - else tryPull(in) - } - - // FIXME: slow iteration, try to make in a vector and inject into shape instead - (0 until secondaryPorts).map(in).foreach { i ⇒ - setHandler(i, new InHandler { - override def onPush(): Unit = { - if (isAvailable(out)) { - if (!pending) { - push(out, grab(i)) - tryPull(i) - } - } else enqueue(i) - } - - override def onUpstreamFinish() = - if (eagerClose) { - (0 until secondaryPorts).foreach(i ⇒ cancel(in(i))) - cancel(preferred) - runningUpstreams = 0 - if (!pending) completeStage() - } else { - runningUpstreams -= 1 - if (upstreamsClosed && !pending && !priority) completeStage() - } - }) - } - - setHandler(preferred, new InHandler { - override def onPush() = { - if (isAvailable(out)) { - push(out, grab(preferred)) - tryPull(preferred) - } - } - - override def onUpstreamFinish() = - if (eagerClose) { - (0 until secondaryPorts).foreach(i ⇒ cancel(in(i))) - runningUpstreams = 0 - if (!pending) completeStage() - } else { - runningUpstreams -= 1 - if (upstreamsClosed && !pending && !priority) completeStage() - } - }) - setHandler(out, new OutHandler { + private var first = true override def onPull(): Unit = { - if (!initialized) { - initialized = true - // FIXME: slow iteration, try to make in a vector and inject into shape instead + if (first) { + first = false tryPull(preferred) - (0 until secondaryPorts).map(in).foreach(tryPull) - } else if (priority) { - push(out, grab(preferred)) - tryPull(preferred) - } else if (pending) - dequeueAndDispatch() + shape.inSeq.foreach(tryPull) + } } }) + + val pullMe = Array.tabulate(secondaryPorts)(i ⇒ { + val port = in(i) + () ⇒ tryPull(port) + }) + + /* + * This determines the unfairness of the merge: + * - at 1 the preferred will grab 40% of the bandwidth against three equally fast secondaries + * - at 2 the preferred will grab almost all bandwidth against three equally fast secondaries + * (measured with eventLimit=1 in the GraphInterpreter, so may not be accurate) + */ + val maxEmitting = 2 + var preferredEmitting = 0 + + setHandler(preferred, new InHandler { + override def onUpstreamFinish(): Unit = onComplete() + override def onPush(): Unit = + if (preferredEmitting == maxEmitting) () // blocked + else emitPreferred() + + def emitPreferred(): Unit = { + preferredEmitting += 1 + emit(out, grab(preferred), emitted) + tryPull(preferred) + } + + val emitted = () ⇒ { + preferredEmitting -= 1 + if (isAvailable(preferred)) emitPreferred() + else if (preferredEmitting == 0) emitSecondary() + } + + def emitSecondary(): Unit = { + var i = 0 + while (i < secondaryPorts) { + val port = in(i) + if (isAvailable(port)) emit(out, grab(port), pullMe(i)) + i += 1 + } + } + }) + + var i = 0 + while (i < secondaryPorts) { + val port = in(i) + val pullPort = pullMe(i) + setHandler(port, new InHandler { + override def onPush(): Unit = { + if (preferredEmitting > 0) () // blocked + else { + emit(out, grab(port), pullPort) + } + } + override def onUpstreamFinish(): Unit = onComplete() + }) + i += 1 + } + } } @@ -266,7 +256,7 @@ class Broadcast[T](private val outputPorts: Int, eagerCancel: Boolean) extends G val out: immutable.IndexedSeq[Outlet[T]] = Vector.tabulate(outputPorts)(i ⇒ Outlet[T]("Broadcast.out" + i)) override val shape: UniformFanOutShape[T, T] = UniformFanOutShape(in, out: _*) - override def createLogic: GraphStageLogic = new GraphStageLogic(shape) { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { private var pendingCount = outputPorts private val pending = Array.fill[Boolean](outputPorts)(true) private var downstreamsRunning = outputPorts @@ -364,7 +354,7 @@ class Balance[T](val outputPorts: Int, waitForAllDownstreams: Boolean) extends G val out: immutable.IndexedSeq[Outlet[T]] = Vector.tabulate(outputPorts)(i ⇒ Outlet[T]("Balance.out" + i)) override val shape: UniformFanOutShape[T, T] = UniformFanOutShape[T, T](in, out: _*) - override def createLogic: GraphStageLogic = new GraphStageLogic(shape) { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { private val pendingQueue = Array.ofDim[Outlet[T]](outputPorts) private var pendingHead: Int = 0 private var pendingTail: Int = 0 @@ -531,7 +521,7 @@ class Concat[T](inputCount: Int) extends GraphStage[UniformFanInShape[T, T]] { val out: Outlet[T] = Outlet[T]("Concat.out") override val shape: UniformFanInShape[T, T] = UniformFanInShape(out, in: _*) - override def createLogic = new GraphStageLogic(shape) { + override def createLogic(inheritedAttributes: Attributes) = new GraphStageLogic(shape) { var activeStream: Int = 0 { @@ -634,7 +624,7 @@ object FlowGraph extends GraphApply { module.shape.outlet.asInstanceOf[Outlet[M]] } - private[stream] def andThen(port: OutPort, op: StageModule): Unit = { + private[stream] def deprecatedAndThen(port: OutPort, op: StageModule): Unit = { moduleInProgress = moduleInProgress .compose(op) @@ -759,17 +749,6 @@ object FlowGraph extends GraphApply { override def withAttributes(attr: Attributes): Repr[Out, Mat] = throw new UnsupportedOperationException("Cannot set attributes on chained ops from a junction output port") - override private[scaladsl] def andThen[U](op: StageModule): Repr[U, Mat] = { - b.andThen(outlet, op) - new PortOps(op.shape.outlet.asInstanceOf[Outlet[U]], b) - } - - override private[scaladsl] def andThenMat[U, Mat2](op: MaterializingStageFactory): Repr[U, Mat2] = { - // We don't track materialization here - b.andThen(outlet, op) - new PortOps(op.shape.outlet.asInstanceOf[Outlet[U]], b) - } - override def importAndGetPort(b: Builder[_]): Outlet[Out] = outlet override def via[T, Mat2](flow: Graph[FlowShape[Out, T], Mat2]): Repr[T, Mat] = @@ -777,6 +756,12 @@ object FlowGraph extends GraphApply { override def viaMat[T, Mat2, Mat3](flow: Graph[FlowShape[Out, T], Mat2])(combine: (Mat, Mat2) ⇒ Mat3) = throw new UnsupportedOperationException("Cannot use viaMat on a port") + + override private[scaladsl] def deprecatedAndThen[U](op: StageModule): PortOps[U, Mat] = { + b.deprecatedAndThen(outlet, op) + new PortOps(op.shape.outlet.asInstanceOf[Outlet[U]], b) + } + } final class DisabledPortOps[Out, Mat](msg: String) extends PortOps[Out, Mat](null, null) { diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/One2OneBidiFlow.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/One2OneBidiFlow.scala index c5f6ade5bc..f32b7a4b07 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/One2OneBidiFlow.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/One2OneBidiFlow.scala @@ -39,7 +39,7 @@ object One2OneBidiFlow { override def toString = "One2OneBidi" - override def createLogic: GraphStageLogic = new GraphStageLogic(shape) { + override def createLogic(effectiveAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { private var pending = 0 private var pullsSuppressed = 0 diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Source.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Source.scala index 58dba3ca1d..7a287147f8 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Source.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Source.scala @@ -5,7 +5,7 @@ package akka.stream.scaladsl import akka.actor.{ ActorRef, Cancellable, Props } import akka.stream.actor.ActorPublisher -import akka.stream.impl.Stages.{ DefaultAttributes, MaterializingStageFactory, StageModule } +import akka.stream.impl.Stages.{ DefaultAttributes, StageModule } import akka.stream.impl.StreamLayout.Module import akka.stream.impl.fusing.GraphStages.TickSource import akka.stream.impl.{ EmptyPublisher, ErrorPublisher, _ } @@ -32,7 +32,7 @@ final class Source[+Out, +Mat](private[stream] override val module: Module) override val shape: SourceShape[Out] = module.shape.asInstanceOf[SourceShape[Out]] def viaMat[T, Mat2, Mat3](flow: Graph[FlowShape[Out, T], Mat2])(combine: (Mat, Mat2) ⇒ Mat3): Source[T, Mat3] = { - if (flow.module.isInstanceOf[Stages.Identity]) this.asInstanceOf[Source[T, Mat3]] + if (flow.module eq Stages.identityGraph.module) this.asInstanceOf[Source[T, Mat3]] else { val flowCopy = flow.module.carbonCopy new Source( @@ -64,7 +64,7 @@ final class Source[+Out, +Mat](private[stream] override val module: Module) new Source(module.transformMaterializedValue(f.asInstanceOf[Any ⇒ Any])) /** INTERNAL API */ - override private[scaladsl] def andThen[U](op: StageModule): Repr[U, Mat] = { + override private[scaladsl] def deprecatedAndThen[U](op: StageModule): Repr[U, Mat] = { // No need to copy here, op is a fresh instance new Source( module @@ -72,13 +72,6 @@ final class Source[+Out, +Mat](private[stream] override val module: Module) .replaceShape(SourceShape(op.outPort))) } - override private[scaladsl] def andThenMat[U, Mat2](op: MaterializingStageFactory): Repr[U, Mat2] = { - new Source( - module - .fuse(op, shape.outlet, op.inPort, Keep.right) - .replaceShape(SourceShape(op.outPort))) - } - /** * Connect this `Source` to a `Sink` and run it. The returned value is the materialized value * of the `Sink`, e.g. the `Publisher` of a [[akka.stream.scaladsl.Sink#publisher]]. diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Tcp.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Tcp.scala index 125b208242..3f16f5258f 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Tcp.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Tcp.scala @@ -208,7 +208,7 @@ class Tcp(system: ExtendedActorSystem) extends akka.actor.Extension { case _ ⇒ Flow[ByteString] } - Flow[ByteString].andThenMat(() ⇒ { + Flow[ByteString].deprecatedAndThenMat(() ⇒ { val processorPromise = Promise[Processor[ByteString, ByteString]]() val localAddressPromise = Promise[InetSocketAddress]() manager ! StreamTcpManager.Connect(processorPromise, localAddressPromise, remoteAddress, localAddress, halfClose, options, diff --git a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala index 0cd99a338c..2ff8644f0c 100644 --- a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala +++ b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala @@ -5,15 +5,18 @@ package akka.stream.stage import akka.actor.{ Cancellable, DeadLetterSuppression } import akka.stream._ +import akka.stream.impl.ReactiveStreamsCompliance import akka.stream.impl.StreamLayout.Module import akka.stream.impl.fusing.{ GraphModule, GraphInterpreter } import akka.stream.impl.fusing.GraphInterpreter.GraphAssembly import scala.collection.{ immutable, mutable } import scala.concurrent.duration.FiniteDuration import scala.collection.mutable.ArrayBuffer +import scala.annotation.tailrec abstract class GraphStageWithMaterializedValue[+S <: Shape, +M] extends Graph[S, M] { - def createLogicAndMaterializedValue: (GraphStageLogic, M) + + def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, M) final override private[stream] lazy val module: Module = GraphModule( @@ -39,9 +42,10 @@ abstract class GraphStageWithMaterializedValue[+S <: Shape, +M] extends Graph[S, * logic that ties the ports together. */ abstract class GraphStage[S <: Shape] extends GraphStageWithMaterializedValue[S, Unit] { - final override def createLogicAndMaterializedValue = (createLogic, Unit) + final override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = + (createLogic(inheritedAttributes), Unit) - def createLogic: GraphStageLogic + def createLogic(inheritedAttributes: Attributes): GraphStageLogic } /** @@ -91,7 +95,7 @@ object GraphStageLogic { */ class ConditionalTerminateInput(predicate: () ⇒ Boolean) extends InHandler { override def onPush(): Unit = () - override def onUpstreamFinish(): Unit = if (predicate()) ownerStageLogic.completeStage() + override def onUpstreamFinish(): Unit = if (predicate()) inOwnerStageLogic.completeStage() } /** @@ -125,7 +129,7 @@ object GraphStageLogic { */ class ConditionalTerminateOutput(predicate: () ⇒ Boolean) extends OutHandler { override def onPull(): Unit = () - override def onDownstreamFinish(): Unit = if (predicate()) ownerStageLogic.completeStage() + override def onDownstreamFinish(): Unit = if (predicate()) outOwnerStageLogic.completeStage() } private object DoNothing extends (() ⇒ Unit) { @@ -146,24 +150,11 @@ object GraphStageLogic { * The stage logic is always stopped once all its input and output ports have been closed, i.e. it is not possible to * keep the stage alive for further processing once it does not have any open ports. */ -abstract class GraphStageLogic private[stream] (inCount: Int, outCount: Int) { +abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: Int) { import GraphInterpreter._ - import TimerMessages._ import GraphStageLogic._ def this(shape: Shape) = this(shape.inlets.size, shape.outlets.size) - - private val keyToTimers = mutable.Map[Any, Timer]() - private val timerIdGen = Iterator from 1 - - private var _timerAsyncCallback: AsyncCallback[Scheduled] = _ - private def getTimerAsyncCallback: AsyncCallback[Scheduled] = { - if (_timerAsyncCallback eq null) - _timerAsyncCallback = getAsyncCallback(onInternalTimer) - - _timerAsyncCallback - } - /** * INTERNAL API */ @@ -172,20 +163,14 @@ abstract class GraphStageLogic private[stream] (inCount: Int, outCount: Int) { /** * INTERNAL API */ - private[stream] var inHandlers = Array.ofDim[InHandler](inCount) - /** - * INTERNAL API - */ - private[stream] var outHandlers = Array.ofDim[OutHandler](outCount) + // Using common array to reduce overhead for small port counts + private[stream] var handlers = Array.ofDim[Any](inCount + outCount) /** * INTERNAL API */ - private[stream] var inToConn = Array.ofDim[Int](inHandlers.length) - /** - * INTERNAL API - */ - private[stream] var outToConn = Array.ofDim[Int](outHandlers.length) + // Using common array to reduce overhead for small port counts + private[stream] var portToConn = Array.ofDim[Int](handlers.length) /** * INTERNAL API @@ -243,32 +228,35 @@ abstract class GraphStageLogic private[stream] (inCount: Int, outCount: Int) { * Assigns callbacks for the events for an [[Inlet]] */ final protected def setHandler(in: Inlet[_], handler: InHandler): Unit = { - handler.ownerStageLogic = this - inHandlers(in.id) = handler - if (_interpreter != null) _interpreter.setHandler(inToConn(in.id), handler) + handler.inOwnerStageLogic = this + handlers(in.id) = handler + if (_interpreter != null) _interpreter.setHandler(conn(in), handler) } /** * Retrieves the current callback for the events on the given [[Inlet]] */ final protected def getHandler(in: Inlet[_]): InHandler = { - inHandlers(in.id) + handlers(in.id).asInstanceOf[InHandler] } /** * Assigns callbacks for the events for an [[Outlet]] */ final protected def setHandler(out: Outlet[_], handler: OutHandler): Unit = { - handler.ownerStageLogic = this - outHandlers(out.id) = handler - if (_interpreter != null) _interpreter.setHandler(outToConn(out.id), handler) + handler.outOwnerStageLogic = this + handlers(out.id + inCount) = handler + if (_interpreter != null) _interpreter.setHandler(conn(out), handler) } + private def conn(in: Inlet[_]): Int = portToConn(in.id) + private def conn(out: Outlet[_]): Int = portToConn(out.id + inCount) + /** * Retrieves the current callback for the events on the given [[Outlet]] */ final protected def getHandler(out: Outlet[_]): OutHandler = { - outHandlers(out.id) + handlers(out.id + inCount).asInstanceOf[OutHandler] } private def getNonEmittingHandler(out: Outlet[_]): OutHandler = @@ -277,9 +265,6 @@ abstract class GraphStageLogic private[stream] (inCount: Int, outCount: Int) { case other ⇒ other } - private def conn[T](in: Inlet[T]): Int = inToConn(in.id) - private def conn[T](out: Outlet[T]): Int = outToConn(out.id) - /** * Requests an element on the given port. Calling this method twice before an element arrived will fail. * There can only be one outstanding request at any given time. The method [[hasBeenPulled()]] can be used @@ -374,10 +359,11 @@ abstract class GraphStageLogic private[stream] (inCount: Int, outCount: Int) { * used to check if the port is ready to be pushed or not. */ final protected def push[T](out: Outlet[T], elem: T): Unit = { - if ((interpreter.portStates(conn(out)) & (OutReady | OutClosed)) == OutReady) { + if ((interpreter.portStates(conn(out)) & (OutReady | OutClosed)) == OutReady && (elem != null)) { interpreter.push(conn(out), elem) } else { // Detailed error information should not add overhead to the hot path + ReactiveStreamsCompliance.requireNonNullElement(elem) require(isAvailable(out), "Cannot push port twice") require(!isClosed(out), "Cannot pull closed port") } @@ -399,13 +385,13 @@ abstract class GraphStageLogic private[stream] (inCount: Int, outCount: Int) { */ final def completeStage(): Unit = { var i = 0 - while (i < inToConn.length) { - interpreter.cancel(inToConn(i)) - i += 1 - } - i = 0 - while (i < outToConn.length) { - interpreter.complete(outToConn(i)) + while (i < portToConn.length) { + if (i < inCount) + interpreter.cancel(portToConn(i)) + else handlers(i) match { + case e: Emitting[_] ⇒ e.addFollowUp(new EmittingCompletion(e.out, e.previous)) + case _ ⇒ interpreter.complete(portToConn(i)) + } i += 1 } } @@ -416,13 +402,11 @@ abstract class GraphStageLogic private[stream] (inCount: Int, outCount: Int) { */ final def failStage(ex: Throwable): Unit = { var i = 0 - while (i < inToConn.length) { - interpreter.cancel(inToConn(i)) - i += 1 - } - i = 0 - while (i < outToConn.length) { - interpreter.fail(outToConn(i), ex) + while (i < portToConn.length) { + if (i < inCount) + interpreter.cancel(portToConn(i)) + else + interpreter.fail(portToConn(i), ex) i += 1 } } @@ -612,22 +596,53 @@ abstract class GraphStageLogic private[stream] (inCount: Int, outCount: Int) { case _ ⇒ setHandler(out, next) } - private abstract class Emitting[T](protected val out: Outlet[T], val previous: OutHandler, andThen: () ⇒ Unit) extends OutHandler { - private var followUps: mutable.Queue[Emitting[T]] = null + private abstract class Emitting[T](val out: Outlet[T], val previous: OutHandler, andThen: () ⇒ Unit) extends OutHandler { + private var followUps: Emitting[T] = _ + private var followUpsTail: Emitting[T] = _ + private def as[U] = this.asInstanceOf[Emitting[U]] protected def followUp(): Unit = { setHandler(out, previous) andThen() if (followUps != null) { - val next = followUps.dequeue() - if (followUps.nonEmpty) next.followUps = followUps - setHandler(out, next) + getHandler(out) match { + case e: Emitting[_] ⇒ e.as[T].addFollowUps(this) + case _ ⇒ + val next = dequeue() + if (next.isInstanceOf[EmittingCompletion[_]]) complete(out) + else setHandler(out, next) + } } } def addFollowUp(e: Emitting[T]): Unit = - if (followUps == null) followUps = mutable.Queue(e) - else followUps.enqueue(e) + if (followUps == null) { + followUps = e + followUpsTail = e + } else { + followUpsTail.followUps = e + followUpsTail = e + } + + private def addFollowUps(e: Emitting[T]): Unit = + if (followUps == null) { + followUps = e.followUps + followUpsTail = e.followUpsTail + } else { + followUpsTail.followUps = e.followUps + followUpsTail = e.followUpsTail + } + + /** + * Dequeue `this` from the head of the queue, meaning that this object will + * not be retained (setHandler will install the followUp). For this reason + * the followUpsTail knowledge needs to be passed on to the next runner. + */ + private def dequeue(): Emitting[T] = { + val ret = followUps + ret.followUpsTail = followUpsTail + ret + } override def onDownstreamFinish(): Unit = previous.onDownstreamFinish() } @@ -652,6 +667,10 @@ abstract class GraphStageLogic private[stream] (inCount: Int, outCount: Int) { } } + private class EmittingCompletion[T](_out: Outlet[T], _previous: OutHandler) extends Emitting(_out, _previous, DoNothing) { + override def onPull(): Unit = complete(out) + } + /** * Install a handler on the given inlet that emits received elements on the * given outlet before pulling for more data. `doTerminate` controls whether @@ -659,9 +678,10 @@ abstract class GraphStageLogic private[stream] (inCount: Int, outCount: Int) { */ final protected def passAlong[Out, In <: Out](from: Inlet[In], to: Outlet[Out], doFinish: Boolean, doFail: Boolean): Unit = setHandler(from, new InHandler { + val puller = () ⇒ tryPull(from) override def onPush(): Unit = { val elem = grab(from) - emit(to, elem, () ⇒ tryPull(from)) + emit(to, elem, puller) } override def onUpstreamFinish(): Unit = if (doFinish) super.onUpstreamFinish() override def onUpstreamFailure(ex: Throwable): Unit = if (doFail) super.onUpstreamFailure(ex) @@ -669,7 +689,7 @@ abstract class GraphStageLogic private[stream] (inCount: Int, outCount: Int) { /** * Obtain a callback object that can be used asynchronously to re-enter the - * current [[AsyncStage]] with an asynchronous notification. The [[invoke()]] method of the returned + * current [[GraphStage]] with an asynchronous notification. The [[invoke()]] method of the returned * [[AsyncCallback]] is safe to be called from other threads and it will in the background thread-safely * delegate to the passed callback function. I.e. [[invoke()]] will be called by the external world and * the passed handler will be invoked eventually in a thread-safe way by the execution environment. @@ -683,6 +703,50 @@ abstract class GraphStageLogic private[stream] (inCount: Int, outCount: Int) { } } + // Internal hooks to avoid reliance on user calling super in preStart + protected[stream] def beforePreStart(): Unit = () + + // Internal hooks to avoid reliance on user calling super in postStop + protected[stream] def afterPostStop(): Unit = () + + /** + * Invoked before any external events are processed, at the startup of the stage. + */ + def preStart(): Unit = () + + /** + * Invoked after processing of external events stopped because the stage is about to stop or fail. + */ + def postStop(): Unit = () +} + +/** + * An asynchronous callback holder that is attached to a [[GraphStageLogic]]. + * Invoking [[AsyncCallback#invoke]] will eventually lead to the registered handler + * being called. + */ +trait AsyncCallback[T] { + /** + * Dispatch an asynchronous notification. This method is thread-safe and + * may be invoked from external execution contexts. + */ + def invoke(t: T): Unit +} + +abstract class TimerGraphStageLogic(_shape: Shape) extends GraphStageLogic(_shape) { + import TimerMessages._ + + private val keyToTimers = mutable.Map[Any, Timer]() + private val timerIdGen = Iterator from 1 + + private var _timerAsyncCallback: AsyncCallback[Scheduled] = _ + private def getTimerAsyncCallback: AsyncCallback[Scheduled] = { + if (_timerAsyncCallback eq null) + _timerAsyncCallback = getAsyncCallback(onInternalTimer) + + _timerAsyncCallback + } + private def onInternalTimer(scheduled: Scheduled): Unit = { val Id = scheduled.timerId keyToTimers.get(scheduled.timerKey) match { @@ -694,12 +758,18 @@ abstract class GraphStageLogic private[stream] (inCount: Int, outCount: Int) { } /** - * Schedule timer to call [[#onTimer]] periodically with the given interval. - * Any existing timer with the same key will automatically be canceled before - * adding the new timer. + * Will be called when the scheduled timer is triggered. + * @param timerKey key of the scheduled timer */ - final protected def schedulePeriodically(timerKey: Any, interval: FiniteDuration): Unit = - schedulePeriodicallyWithInitialDelay(timerKey, interval, interval) + protected def onTimer(timerKey: Any): Unit = () + + // Internal hooks to avoid reliance on user calling super in postStop + protected[stream] override def afterPostStop(): Unit = { + if (keyToTimers ne null) { + keyToTimers.foreach { case (_, Timer(_, task)) ⇒ task.cancel() } + keyToTimers.clear() + } + } /** * Schedule timer to call [[#onTimer]] periodically with the given interval after the specified @@ -751,30 +821,13 @@ abstract class GraphStageLogic private[stream] (inCount: Int, outCount: Int) { final protected def isTimerActive(timerKey: Any): Boolean = keyToTimers contains timerKey /** - * Will be called when the scheduled timer is triggered. - * @param timerKey key of the scheduled timer + * Schedule timer to call [[#onTimer]] periodically with the given interval. + * Any existing timer with the same key will automatically be canceled before + * adding the new timer. */ - protected def onTimer(timerKey: Any): Unit = () + final protected def schedulePeriodically(timerKey: Any, interval: FiniteDuration): Unit = + schedulePeriodicallyWithInitialDelay(timerKey, interval, interval) - // Internal hooks to avoid reliance on user calling super in preStart - protected[stream] def beforePreStart(): Unit = { - } - - // Internal hooks to avoid reliance on user calling super in postStop - protected[stream] def afterPostStop(): Unit = { - keyToTimers.foreach { case (_, Timer(_, task)) ⇒ task.cancel() } - keyToTimers.clear() - } - - /** - * Invoked before any external events are processed, at the startup of the stage. - */ - def preStart(): Unit = () - - /** - * Invoked after processing of external events stopped because the stage is about to stop or fail. - */ - def postStop(): Unit = () } /** @@ -784,7 +837,7 @@ trait InHandler { /** * INTERNAL API */ - private[stream] var ownerStageLogic: GraphStageLogic = _ + private[stream] var inOwnerStageLogic: GraphStageLogic = _ /** * Called when the input port has a new element available. The actual element can be retrieved via the @@ -795,12 +848,12 @@ trait InHandler { /** * Called when the input port is finished. After this callback no other callbacks will be called for this port. */ - def onUpstreamFinish(): Unit = ownerStageLogic.completeStage() + def onUpstreamFinish(): Unit = inOwnerStageLogic.completeStage() /** * Called when the input port has failed. After this callback no other callbacks will be called for this port. */ - def onUpstreamFailure(ex: Throwable): Unit = ownerStageLogic.failStage(ex) + def onUpstreamFailure(ex: Throwable): Unit = inOwnerStageLogic.failStage(ex) } /** @@ -810,7 +863,7 @@ trait OutHandler { /** * INTERNAL API */ - private[stream] var ownerStageLogic: GraphStageLogic = _ + private[stream] var outOwnerStageLogic: GraphStageLogic = _ /** * Called when the output port has received a pull, and therefore ready to emit an element, i.e. [[GraphStageLogic.push()]] @@ -822,5 +875,5 @@ trait OutHandler { * Called when the output port will no longer accept any new elements. After this callback no other callbacks will * be called for this port. */ - def onDownstreamFinish(): Unit = ownerStageLogic.completeStage() + def onDownstreamFinish(): Unit = outOwnerStageLogic.completeStage() } \ No newline at end of file diff --git a/akka-stream/src/main/scala/akka/stream/stage/Stage.scala b/akka-stream/src/main/scala/akka/stream/stage/Stage.scala index 34c2acb981..5a3cc187c7 100644 --- a/akka-stream/src/main/scala/akka/stream/stage/Stage.scala +++ b/akka-stream/src/main/scala/akka/stream/stage/Stage.scala @@ -3,7 +3,10 @@ */ package akka.stream.stage -import akka.stream.{ Attributes, Materializer, Supervision } +import akka.stream._ + +import scala.annotation.unchecked.uncheckedVariance +import scala.util.control.NonFatal /** * General interface for stream transformation. @@ -25,76 +28,169 @@ import akka.stream.{ Attributes, Materializer, Supervision } * @see [[akka.stream.scaladsl.Flow#transform]] * @see [[akka.stream.javadsl.Flow#transform]] */ -sealed trait Stage[-In, Out] +sealed trait Stage[-In, +Out] /** * INTERNAL API */ private[stream] object AbstractStage { - final val UpstreamBall = 1 - final val DownstreamBall = 2 - final val PrecedingWasPull = 0x4000 - final val NoTerminationPending = 0x8000 - final val BothBalls = UpstreamBall | DownstreamBall - final val BothBallsAndNoTerminationPending = UpstreamBall | DownstreamBall | NoTerminationPending + + private class PushPullGraphLogic[In, Out]( + private val shape: FlowShape[In, Out], + val attributes: Attributes, + val stage: AbstractStage[In, Out, Directive, Directive, Context[Out], LifecycleContext]) + extends GraphStageLogic(shape) with DetachedContext[Out] { + + final override def materializer: Materializer = interpreter.materializer + + private def ctx: DetachedContext[Out] = this + + private var currentStage: AbstractStage[In, Out, Directive, Directive, Context[Out], LifecycleContext] = stage + + { + // No need to refer to the handle in a private val + val handler = new InHandler with OutHandler { + override def onPush(): Unit = + try { currentStage.onPush(grab(shape.inlet), ctx) } catch { case NonFatal(ex) ⇒ onSupervision(ex) } + + override def onPull(): Unit = currentStage.onPull(ctx) + + override def onUpstreamFinish(): Unit = currentStage.onUpstreamFinish(ctx) + + override def onUpstreamFailure(ex: Throwable): Unit = currentStage.onUpstreamFailure(ex, ctx) + + override def onDownstreamFinish(): Unit = currentStage.onDownstreamFinish(ctx) + } + + setHandler(shape.inlet, handler) + setHandler(shape.outlet, handler) + } + + private def onSupervision(ex: Throwable): Unit = { + currentStage.decide(ex) match { + case Supervision.Stop ⇒ + failStage(ex) + case Supervision.Resume ⇒ + resetAfterSupervise() + case Supervision.Restart ⇒ + resetAfterSupervise() + currentStage.postStop() + currentStage = currentStage.restart().asInstanceOf[AbstractStage[In, Out, Directive, Directive, Context[Out], LifecycleContext]] + currentStage.preStart(ctx) + } + } + + private def resetAfterSupervise(): Unit = { + val mustPull = currentStage.isDetached || isAvailable(shape.outlet) + if (!hasBeenPulled(shape.inlet) && mustPull) pull(shape.inlet) + } + + override protected[stream] def beforePreStart(): Unit = { + super.beforePreStart() + if (currentStage.isDetached) pull(shape.inlet) + } + + final override def push(elem: Out): DownstreamDirective = { + push(shape.outlet, elem) + null + } + + final override def pull(): UpstreamDirective = { + pull(shape.inlet) + null + } + + final override def finish(): FreeDirective = { + completeStage() + null + } + + final override def pushAndFinish(elem: Out): DownstreamDirective = { + push(shape.outlet, elem) + completeStage() + null + } + + final override def fail(cause: Throwable): FreeDirective = { + failStage(cause) + null + } + + final override def isFinishing: Boolean = isClosed(shape.inlet) + + final override def absorbTermination(): TerminationDirective = { + if (isClosed(shape.outlet)) { + val ex = new UnsupportedOperationException("It is not allowed to call absorbTermination() from onDownstreamFinish.") + // This MUST be logged here, since the downstream has cancelled, i.e. there is noone to send onError to, the + // stage is just about to finish so noone will catch it anyway just the interpreter + + interpreter.log.error(ex.getMessage) + throw ex // We still throw for correctness (although a finish() would also work here) + } + if (isAvailable(shape.outlet)) currentStage.onPull(ctx) + null + } + + override def pushAndPull(elem: Out): FreeDirective = { + push(shape.outlet, elem) + pull(shape.inlet) + null + } + + final override def holdUpstreamAndPush(elem: Out): UpstreamDirective = { + push(shape.outlet, elem) + null + } + + final override def holdDownstreamAndPull(): DownstreamDirective = { + pull(shape.inlet) + null + } + + final override def isHoldingDownstream: Boolean = isAvailable(shape.outlet) + + final override def isHoldingUpstream: Boolean = !(isClosed(shape.inlet) || hasBeenPulled(shape.inlet)) + + final override def holdDownstream(): DownstreamDirective = null + + final override def holdUpstream(): UpstreamDirective = null + + override def preStart(): Unit = currentStage.preStart(ctx) + override def postStop(): Unit = currentStage.postStop() + + override def toString: String = s"PushPullGraphLogic($currentStage)" + } + + class PushPullGraphStageWithMaterializedValue[-In, +Out, Ext, +Mat]( + val factory: (Attributes) ⇒ (Stage[In, Out], Mat), + stageAttributes: Attributes) + extends GraphStageWithMaterializedValue[FlowShape[In, Out], Mat] { + + val name = stageAttributes.nameOrDefault() + val shape = FlowShape(Inlet[In](name + ".in"), Outlet[Out](name + ".out")) + + override def toString = name + + override def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, Mat) = { + val effectiveAttributes = inheritedAttributes and stageAttributes + val stageAndMat = factory(effectiveAttributes) + val stage: AbstractStage[In, Out, Directive, Directive, Context[Out], LifecycleContext] = + stageAndMat._1.asInstanceOf[AbstractStage[In, Out, Directive, Directive, Context[Out], LifecycleContext]] + (new PushPullGraphLogic(shape, effectiveAttributes, stage), stageAndMat._2) + } + } + + class PushPullGraphStage[-In, +Out, Ext](_factory: (Attributes) ⇒ Stage[In, Out], _stageAttributes: Attributes) + extends PushPullGraphStageWithMaterializedValue[In, Out, Ext, Unit]((att: Attributes) ⇒ (_factory(att), ()), _stageAttributes) } abstract class AbstractStage[-In, Out, PushD <: Directive, PullD <: Directive, Ctx <: Context[Out], LifeCtx <: LifecycleContext] extends Stage[In, Out] { - /** - * INTERNAL API - */ - private[stream] var bits = AbstractStage.NoTerminationPending - - /** - * INTERNAL API - */ - private[stream] var context: Ctx = _ /** * INTERNAL API */ private[stream] def isDetached: Boolean = false - /** - * INTERNAL API - */ - private[stream] def enterAndPush(elem: Out): Unit = { - val c = context - c.enter() - c.push(elem) - c.execute() - } - - /** - * INTERNAL API - */ - private[stream] def enterAndPull(): Unit = { - val c = context - c.enter() - c.pull() - c.execute() - } - - /** - * INTERNAL API - */ - private[stream] def enterAndFinish(): Unit = { - val c = context - c.enter() - c.finish() - c.execute() - } - - /** - * INTERNAL API - */ - private[stream] def enterAndFail(e: Throwable): Unit = { - val c = context - c.enter() - c.fail(e) - c.execute() - } - /** * User overridable callback. *

@@ -285,26 +381,6 @@ abstract class DetachedStage[In, Out] override def decide(t: Throwable): Supervision.Directive = super.decide(t) } -/** - * This is a variant of [[DetachedStage]] that can receive asynchronous input - * from external sources, for example timers or Future results. In order to - * do this, obtain an [[AsyncCallback]] from the [[AsyncContext]] and attach - * it to the asynchronous event. When the event fires an asynchronous notification - * will be dispatched that eventually will lead to `onAsyncInput` being invoked - * with the provided data item. - */ -abstract class AsyncStage[In, Out, Ext] - extends AbstractStage[In, Out, UpstreamDirective, DownstreamDirective, AsyncContext[Out, Ext], AsyncContext[Out, Ext]] { - private[stream] override def isDetached = true - - /** - * Implement this method to define the action to be taken in response to an - * asynchronous notification that was previously registered using - * [[AsyncContext#getAsyncCallback]]. - */ - def onAsyncInput(event: Ext, ctx: AsyncContext[Out, Ext]): Directive -} - /** * The behavior of [[StatefulStage]] is defined by these two methods, which * has the same semantics as corresponding methods in [[PushPullStage]]. @@ -507,27 +583,6 @@ abstract class StatefulStage[In, Out] extends PushPullStage[In, Out] { } -/** - * INTERNAL API - * - * `BoundaryStage` implementations are meant to communicate with the external world. These stages do not have most of the - * safety properties enforced and should be used carefully. One important ability of BoundaryStages that they can take - * off an execution signal by calling `ctx.exit()`. This is typically used immediately after an external signal has - * been produced (for example an actor message). BoundaryStages can also kickstart execution by calling `enter()` which - * returns a context they can use to inject signals into the interpreter. There is no checks in place to enforce that - * the number of signals taken out by exit() and the number of signals returned via enter() are the same -- using this - * stage type needs extra care from the implementer. - * - * BoundaryStages are the elements that make the interpreter *tick*, there is no other way to start the interpreter - * than using a BoundaryStage. - */ -private[akka] abstract class BoundaryStage extends AbstractStage[Any, Any, Directive, Directive, BoundaryContext, LifecycleContext] { - final override def decide(t: Throwable): Supervision.Directive = Supervision.Stop - - final override def restart(): BoundaryStage = - throw new UnsupportedOperationException("BoundaryStage doesn't support restart") -} - /** * Return type from [[Context]] methods. */ @@ -555,16 +610,6 @@ trait LifecycleContext { * Passed to the callback methods of [[PushPullStage]] and [[StatefulStage]]. */ sealed trait Context[Out] extends LifecycleContext { - /** - * INTERNAL API - */ - private[stream] def enter(): Unit - - /** - * INTERNAL API - */ - private[stream] def execute(): Unit - /** * Push one element to downstreams. */ @@ -625,48 +670,3 @@ trait DetachedContext[Out] extends Context[Out] { def pushAndPull(elem: Out): FreeDirective } - -/** - * An asynchronous callback holder that is attached to an [[AsyncContext]]. - * Invoking [[AsyncCallback#invoke]] will eventually lead to [[AsyncStage#onAsyncInput]] - * being called. - */ -trait AsyncCallback[T] { - /** - * Dispatch an asynchronous notification. This method is thread-safe and - * may be invoked from external execution contexts. - */ - def invoke(t: T): Unit -} - -/** - * This kind of context is available to [[AsyncStage]]. It implements the same - * interface as for [[DetachedStage]] with the addition of being able to obtain - * [[AsyncCallback]] objects that allow the registration of asynchronous - * notifications. - */ -trait AsyncContext[Out, Ext] extends DetachedContext[Out] { - /** - * Obtain a callback object that can be used asynchronously to re-enter the - * current [[AsyncStage]] with an asynchronous notification. After the - * notification has been invoked, eventually [[AsyncStage#onAsyncInput]] - * will be called with the given data item. - * - * This object can be cached and reused within the same [[AsyncStage]]. - */ - def getAsyncCallback: AsyncCallback[Ext] - /** - * In response to an asynchronous notification an [[AsyncStage]] may choose - * to neither push nor pull nor terminate, which is represented as this - * directive. - */ - def ignore(): AsyncDirective -} - -/** - * INTERNAL API - */ -private[akka] trait BoundaryContext extends Context[Any] { - def exit(): FreeDirective -} -