From a82f266367c6ec2a6043e1b7057022303d5d9cc4 Mon Sep 17 00:00:00 2001 From: Patrik Nordwall Date: Wed, 12 Nov 2014 10:43:39 +0100 Subject: [PATCH] !str #15236 Replace Transformer with Stage * replace all existing Transformer with Stage (PushPullStage) * use Flow[ByteString, ByteString] as encoder/decoder transformer in http * use the IteratorInterpreter for strict if possible * emit then become * emit then finish * termination emits * FlowTransformerSpec * rework types to work with Java API * rename and move things * add scaladoc --- .../engine/client/HttpClientPipeline.scala | 2 +- .../http/engine/parsing/BodyPartParser.scala | 36 +- .../engine/parsing/HttpMessageParser.scala | 27 +- .../engine/rendering/BodyPartRenderer.scala | 38 +- .../HttpRequestRendererFactory.scala | 16 +- .../HttpResponseRendererFactory.scala | 22 +- .../http/engine/rendering/RenderSupport.scala | 38 +- .../engine/server/HttpServerPipeline.scala | 23 +- .../scala/akka/http/model/HttpEntity.scala | 113 +++-- .../scala/akka/http/util/StreamUtils.scala | 161 ++++--- .../main/scala/akka/http/util/package.scala | 34 +- .../scala/akka/http/ClientServerSpec.scala | 2 +- .../src/test/scala/akka/http/TestClient.scala | 4 +- .../rendering/RequestRendererSpec.scala | 4 +- .../rendering/ResponseRendererSpec.scala | 4 +- .../akka/http/model/HttpEntitySpec.scala | 12 +- .../directives/CodingDirectivesSpec.scala | 2 +- .../scala/akka/http/coding/DataMapper.scala | 12 +- .../main/scala/akka/http/coding/Decoder.scala | 4 +- .../main/scala/akka/http/coding/Encoder.scala | 6 +- .../main/scala/akka/http/coding/Gzip.scala | 1 + .../server/directives/CodingDirectives.scala | 2 +- .../server/directives/RangeDirectives.scala | 4 +- .../MultipartUnmarshallers.scala | 13 +- .../stream/tck/FusableProcessorTest.scala | 4 +- .../stream/tck/TransformProcessorTest.scala | 11 +- .../java/akka/stream/javadsl/FlowTest.java | 133 ++---- .../akka/stream/DslConsistencySpec.scala | 2 +- .../stream/actor/ActorPublisherSpec.scala | 2 +- .../impl/fusing/InterpreterSpecKit.scala | 43 +- .../impl/fusing/IteratorInterpreterSpec.scala | 59 +-- .../scaladsl/FlowGraphCompileSpec.scala | 10 +- .../scala/akka/stream/scaladsl/FlowSpec.scala | 28 +- ...ransformSpec.scala => FlowStageSpec.scala} | 268 +++++------ .../scaladsl/FlowTransformRecoverSpec.scala | 390 ---------------- ...mizingActorBasedFlowMaterializerSpec.scala | 2 +- .../scala/akka/stream/TimerTransformer.scala | 2 +- .../main/scala/akka/stream/Transformer.scala | 18 +- .../main/scala/akka/stream/extra/Timed.scala | 38 +- .../impl/ActorBasedFlowMaterializer.scala | 78 ++-- .../akka/stream/impl/ConcatAllImpl.scala | 9 + .../stream/impl/GroupByProcessorImpl.scala | 9 + .../stream/impl/MapAsyncProcessorImpl.scala | 8 +- .../impl/MapAsyncUnorderedProcessorImpl.scala | 4 + .../akka/stream/impl/PrefixAndTailImpl.scala | 9 + .../stream/impl/SingleStreamProcessors.scala | 76 --- .../stream/impl/SplitWhenProcessorImpl.scala | 9 + .../impl/TimerTransformerProcessorsImpl.scala | 35 +- .../stream/impl/fusing/ActorInterpreter.scala | 56 ++- .../akka/stream/impl/fusing/Interpreter.scala | 172 ++++--- .../impl/fusing/IteratorInterpreter.scala | 64 ++- .../scala/akka/stream/impl/fusing/Ops.scala | 210 ++++----- .../akka/stream/javadsl/FlexiMerge.scala | 4 +- .../akka/stream/javadsl/FlexiRoute.scala | 6 +- .../main/scala/akka/stream/javadsl/Flow.scala | 42 +- .../scala/akka/stream/javadsl/Source.scala | 46 +- .../akka/stream/scaladsl/ActorFlowSink.scala | 68 +-- .../akka/stream/scaladsl/FlexiMerge.scala | 2 +- .../scala/akka/stream/scaladsl/Flow.scala | 61 +-- .../scala/akka/stream/scaladsl/Pipe.scala | 4 +- .../main/scala/akka/stream/stage/Stage.scala | 434 ++++++++++++++++++ 61 files changed, 1478 insertions(+), 1518 deletions(-) rename akka-stream-tests/src/test/scala/akka/stream/scaladsl/{FlowTransformSpec.scala => FlowStageSpec.scala} (56%) delete mode 100644 akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowTransformRecoverSpec.scala delete mode 100644 akka-stream/src/main/scala/akka/stream/impl/SingleStreamProcessors.scala create mode 100644 akka-stream/src/main/scala/akka/stream/stage/Stage.scala diff --git a/akka-http-core/src/main/scala/akka/http/engine/client/HttpClientPipeline.scala b/akka-http-core/src/main/scala/akka/http/engine/client/HttpClientPipeline.scala index 8c3a9145e1..d4edaabab6 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/client/HttpClientPipeline.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/client/HttpClientPipeline.scala @@ -97,4 +97,4 @@ private[http] class HttpClientPipeline(effectiveSettings: ClientConnectionSettin method } else HttpResponseParser.NoMethod } -} \ No newline at end of file +} diff --git a/akka-http-core/src/main/scala/akka/http/engine/parsing/BodyPartParser.scala b/akka-http-core/src/main/scala/akka/http/engine/parsing/BodyPartParser.scala index 53810b35a7..d7d33197af 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/parsing/BodyPartParser.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/parsing/BodyPartParser.scala @@ -8,8 +8,8 @@ import scala.annotation.tailrec import scala.collection.mutable.ListBuffer import akka.event.LoggingAdapter import akka.parboiled2.CharPredicate -import akka.stream.Transformer import akka.stream.scaladsl.Source +import akka.stream.stage._ import akka.util.ByteString import akka.http.model._ import akka.http.util._ @@ -24,7 +24,7 @@ private[http] final class BodyPartParser(defaultContentType: ContentType, boundary: String, log: LoggingAdapter, settings: BodyPartParser.Settings = BodyPartParser.defaultSettings) - extends Transformer[ByteString, BodyPartParser.Output] { + extends PushPullStage[ByteString, BodyPartParser.Output] { import BodyPartParser._ import settings._ @@ -52,25 +52,43 @@ private[http] final class BodyPartParser(defaultContentType: ContentType, private[this] val headerParser = HttpHeaderParser(settings, warnOnIllegalHeader) // TODO: prevent re-priming header parser from scratch private[this] val result = new ListBuffer[Output] // transformer op is currently optimized for LinearSeqs + private[this] var resultIterator: Iterator[Output] = Iterator.empty private[this] var state: ByteString ⇒ StateResult = tryParseInitialBoundary private[this] var receivedInitialBoundary = false private[this] var terminated = false - override def isComplete = terminated - def warnOnIllegalHeader(errorInfo: ErrorInfo): Unit = if (illegalHeaderWarnings) log.warning(errorInfo.withSummaryPrepended("Illegal multipart header").formatPretty) - def onNext(input: ByteString): List[Output] = { + override def onPush(input: ByteString, ctx: Context[Output]): Directive = { result.clear() try state(input) catch { - case e: ParsingException ⇒ fail(e.info) - case NotEnoughDataException ⇒ throw new IllegalStateException(NotEnoughDataException) // we are missing a try/catch{continue} wrapper somewhere + case e: ParsingException ⇒ fail(e.info) + case NotEnoughDataException ⇒ + // we are missing a try/catch{continue} wrapper somewhere + throw new IllegalStateException("unexpected NotEnoughDataException", NotEnoughDataException) } - result.toList + resultIterator = result.iterator + if (resultIterator.hasNext) ctx.push(resultIterator.next()) + else if (!terminated) ctx.pull() + else ctx.finish() } + override def onPull(ctx: Context[Output]): Directive = { + if (resultIterator.hasNext) + ctx.push(resultIterator.next()) + else if (ctx.isFinishing) { + if (terminated || !receivedInitialBoundary) + ctx.finish() + else + ctx.pushAndFinish(ParseError(ErrorInfo("Unexpected end of multipart entity"))) + } else + ctx.pull() + } + + override def onUpstreamFinish(ctx: Context[Output]): TerminationDirective = ctx.absorbTermination() + def tryParseInitialBoundary(input: ByteString): StateResult = // we don't use boyerMoore here because we are testing for the boundary *without* a // preceding CRLF and at a known location (the very beginning of the entity) @@ -223,8 +241,6 @@ private[http] final class BodyPartParser(defaultContentType: ContentType, def doubleDash(input: ByteString, offset: Int): Boolean = byteChar(input, offset) == '-' && byteChar(input, offset + 1) == '-' - override def onTermination(e: Option[Throwable]): List[BodyPartParser.Output] = - if (terminated || !receivedInitialBoundary) Nil else ParseError(ErrorInfo("Unexpected end of multipart entity")) :: Nil } private[http] object BodyPartParser { diff --git a/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpMessageParser.scala b/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpMessageParser.scala index 7f8bbf271a..a1d5359f64 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpMessageParser.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/parsing/HttpMessageParser.scala @@ -9,8 +9,8 @@ import scala.collection.mutable.ListBuffer import scala.collection.immutable import akka.parboiled2.CharUtils import akka.util.ByteString -import akka.stream.Transformer import akka.stream.scaladsl.Source +import akka.stream.stage._ import akka.http.model.parser.CharacterClasses import akka.http.model._ import headers._ @@ -21,7 +21,7 @@ import HttpProtocols._ */ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOutput <: ParserOutput](val settings: ParserSettings, val headerParser: HttpHeaderParser) - extends Transformer[ByteString, Output] { + extends StatefulStage[ByteString, Output] { import settings._ sealed trait StateResult // phantom type for ensuring soundness of our parsing method setup @@ -30,16 +30,21 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut private[this] var state: ByteString ⇒ StateResult = startNewMessage(_, 0) private[this] var protocol: HttpProtocol = `HTTP/1.1` private[this] var terminated = false - override def isComplete = terminated - def onNext(input: ByteString): immutable.Seq[Output] = { - result.clear() - try state(input) - catch { - case e: ParsingException ⇒ fail(e.status, e.info) - case NotEnoughDataException ⇒ throw new IllegalStateException // we are missing a try/catch{continue} wrapper somewhere + override def initial = new State { + override def onPush(input: ByteString, ctx: Context[Output]): Directive = { + result.clear() + try state(input) + catch { + case e: ParsingException ⇒ fail(e.status, e.info) + case NotEnoughDataException ⇒ + // we are missing a try/catch{continue} wrapper somewhere + throw new IllegalStateException("unexpected NotEnoughDataException", NotEnoughDataException) + } + val resultIterator = result.iterator + if (terminated) emitAndFinish(resultIterator, ctx) + else emit(resultIterator, ctx) } - result.toList } def startNewMessage(input: ByteString, offset: Int): StateResult = { @@ -252,4 +257,4 @@ private[http] abstract class HttpMessageParser[Output >: ParserOutput.MessageOut case Some(x) ⇒ x :: headers case None ⇒ headers } -} \ No newline at end of file +} diff --git a/akka-http-core/src/main/scala/akka/http/engine/rendering/BodyPartRenderer.scala b/akka-http-core/src/main/scala/akka/http/engine/rendering/BodyPartRenderer.scala index da36a7f045..ff367e1b39 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/rendering/BodyPartRenderer.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/rendering/BodyPartRenderer.scala @@ -12,7 +12,7 @@ import akka.http.model.headers._ import akka.http.engine.rendering.RenderSupport._ import akka.http.util._ import akka.stream.scaladsl.Source -import akka.stream.Transformer +import akka.stream.stage._ import akka.util.ByteString import HttpEntity._ @@ -24,19 +24,19 @@ private[http] object BodyPartRenderer { def streamed(boundary: String, nioCharset: Charset, partHeadersSizeHint: Int, - log: LoggingAdapter): Transformer[Multipart.BodyPart, Source[ChunkStreamPart]] = - new Transformer[Multipart.BodyPart, Source[ChunkStreamPart]] { + log: LoggingAdapter): PushPullStage[Multipart.BodyPart, Source[ChunkStreamPart]] = + new PushPullStage[Multipart.BodyPart, Source[ChunkStreamPart]] { var firstBoundaryRendered = false - def onNext(bodyPart: Multipart.BodyPart): List[Source[ChunkStreamPart]] = { + override def onPush(bodyPart: Multipart.BodyPart, ctx: Context[Source[ChunkStreamPart]]): Directive = { val r = new CustomCharsetByteStringRendering(nioCharset, partHeadersSizeHint) - def bodyPartChunks(data: Source[ByteString]): List[Source[ChunkStreamPart]] = { + def bodyPartChunks(data: Source[ByteString]): Source[ChunkStreamPart] = { val entityChunks = data.map[ChunkStreamPart](Chunk(_)) - (Source(Chunk(r.get) :: Nil) ++ entityChunks) :: Nil + chunkStream(r.get) ++ entityChunks } - def completePartRendering(): List[Source[ChunkStreamPart]] = + def completePartRendering(): Source[ChunkStreamPart] = bodyPart.entity match { case x if x.isKnownEmpty ⇒ chunkStream(r.get) case Strict(_, data) ⇒ chunkStream((r ~~ data).get) @@ -48,18 +48,26 @@ private[http] object BodyPartRenderer { firstBoundaryRendered = true renderEntityContentType(r, bodyPart.entity) renderHeaders(r, bodyPart.headers, log) - completePartRendering() + ctx.push(completePartRendering()) } - override def onTermination(e: Option[Throwable]): List[Source[ChunkStreamPart]] = - if (e.isEmpty && firstBoundaryRendered) { + override def onPull(ctx: Context[Source[ChunkStreamPart]]): Directive = { + val finishing = ctx.isFinishing + if (finishing && firstBoundaryRendered) { val r = new ByteStringRendering(boundary.length + 4) renderFinalBoundary(r, boundary) - chunkStream(r.get) - } else Nil + ctx.pushAndFinish(chunkStream(r.get)) + } else if (finishing) + ctx.finish() + else + ctx.pull() + } + + override def onUpstreamFinish(ctx: Context[Source[ChunkStreamPart]]): TerminationDirective = ctx.absorbTermination() + + private def chunkStream(byteString: ByteString): Source[ChunkStreamPart] = + Source.singleton(Chunk(byteString)) - private def chunkStream(byteString: ByteString) = - Source[ChunkStreamPart](Chunk(byteString) :: Nil) :: Nil } def strict(parts: immutable.Seq[Multipart.BodyPart.Strict], boundary: String, nioCharset: Charset, @@ -102,4 +110,4 @@ private[http] object BodyPartRenderer { case x ⇒ r ~~ x ~~ CrLf } -} \ No newline at end of file +} diff --git a/akka-http-core/src/main/scala/akka/http/engine/rendering/HttpRequestRendererFactory.scala b/akka-http-core/src/main/scala/akka/http/engine/rendering/HttpRequestRendererFactory.scala index f8e9f12810..1d9fe5875e 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/rendering/HttpRequestRendererFactory.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/rendering/HttpRequestRendererFactory.scala @@ -9,7 +9,7 @@ import scala.annotation.tailrec import akka.event.LoggingAdapter import akka.util.ByteString import akka.stream.scaladsl.Source -import akka.stream.Transformer +import akka.stream.stage._ import akka.http.model._ import akka.http.util._ import RenderSupport._ @@ -24,9 +24,9 @@ private[http] class HttpRequestRendererFactory(userAgentHeader: Option[headers.` def newRenderer: HttpRequestRenderer = new HttpRequestRenderer - final class HttpRequestRenderer extends Transformer[RequestRenderingContext, Source[ByteString]] { + final class HttpRequestRenderer extends PushStage[RequestRenderingContext, Source[ByteString]] { - def onNext(ctx: RequestRenderingContext): List[Source[ByteString]] = { + override def onPush(ctx: RequestRenderingContext, opCtx: Context[Source[ByteString]]): Directive = { val r = new ByteStringRendering(requestHeaderSizeHint) import ctx.request._ @@ -103,15 +103,15 @@ private[http] class HttpRequestRendererFactory(userAgentHeader: Option[headers.` r ~~ CrLf } - def completeRequestRendering(): List[Source[ByteString]] = + def completeRequestRendering(): Source[ByteString] = entity match { case x if x.isKnownEmpty ⇒ renderContentLength(0) - Source(r.get :: Nil) :: Nil + Source.singleton(r.get) case HttpEntity.Strict(_, data) ⇒ renderContentLength(data.length) - Source.singleton(r.get ++ data) :: Nil + Source.singleton(r.get ++ data) case HttpEntity.Default(_, contentLength, data) ⇒ renderContentLength(contentLength) @@ -126,7 +126,7 @@ private[http] class HttpRequestRendererFactory(userAgentHeader: Option[headers.` renderRequestLine() renderHeaders(headers.toList) renderEntityContentType(r, entity) - completeRequestRendering() + opCtx.push(completeRequestRendering()) } } } @@ -134,4 +134,4 @@ private[http] class HttpRequestRendererFactory(userAgentHeader: Option[headers.` /** * INTERNAL API */ -private[http] final case class RequestRenderingContext(request: HttpRequest, serverAddress: InetSocketAddress) \ No newline at end of file +private[http] final case class RequestRenderingContext(request: HttpRequest, serverAddress: InetSocketAddress) diff --git a/akka-http-core/src/main/scala/akka/http/engine/rendering/HttpResponseRendererFactory.scala b/akka-http-core/src/main/scala/akka/http/engine/rendering/HttpResponseRendererFactory.scala index 36dd2017d8..fa8b575b5c 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/rendering/HttpResponseRendererFactory.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/rendering/HttpResponseRendererFactory.scala @@ -8,7 +8,7 @@ import scala.annotation.tailrec import akka.event.LoggingAdapter import akka.util.ByteString import akka.stream.scaladsl.Source -import akka.stream.Transformer +import akka.stream.stage._ import akka.http.model._ import akka.http.util._ import RenderSupport._ @@ -50,12 +50,14 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser def newRenderer: HttpResponseRenderer = new HttpResponseRenderer - final class HttpResponseRenderer extends Transformer[ResponseRenderingContext, Source[ByteString]] { + final class HttpResponseRenderer extends PushStage[ResponseRenderingContext, Source[ByteString]] { + private[this] var close = false // signals whether the connection is to be closed after the current response - override def isComplete = close + // need this for testing + private[http] def isComplete = close - def onNext(ctx: ResponseRenderingContext): List[Source[ByteString]] = { + override def onPush(ctx: ResponseRenderingContext, opCtx: Context[Source[ByteString]]): Directive = { val r = new ByteStringRendering(responseHeaderSizeHint) import ctx.response._ @@ -134,17 +136,17 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser r ~~ `Transfer-Encoding` ~~ ChunkedBytes ~~ CrLf } - def byteStrings(entityBytes: ⇒ Source[ByteString]): List[Source[ByteString]] = + def byteStrings(entityBytes: ⇒ Source[ByteString]): Source[ByteString] = renderByteStrings(r, entityBytes, skipEntity = noEntity) - def completeResponseRendering(entity: ResponseEntity): List[Source[ByteString]] = + def completeResponseRendering(entity: ResponseEntity): Source[ByteString] = entity match { case HttpEntity.Strict(_, data) ⇒ renderHeaders(headers.toList) renderEntityContentType(r, entity) r ~~ `Content-Length` ~~ data.length ~~ CrLf ~~ CrLf val entityBytes = if (noEntity) ByteString.empty else data - Source.singleton(r.get ++ entityBytes) :: Nil + Source.singleton(r.get ++ entityBytes) case HttpEntity.Default(_, contentLength, data) ⇒ renderHeaders(headers.toList) @@ -170,7 +172,11 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser } renderStatusLine() - completeResponseRendering(entity) + val result = completeResponseRendering(entity) + if (close) + opCtx.pushAndFinish(result) + else + opCtx.push(result) } } } diff --git a/akka-http-core/src/main/scala/akka/http/engine/rendering/RenderSupport.scala b/akka-http-core/src/main/scala/akka/http/engine/rendering/RenderSupport.scala index b5418fe2f2..775cd3ebfa 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/rendering/RenderSupport.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/rendering/RenderSupport.scala @@ -9,7 +9,7 @@ import akka.stream.impl.ActorBasedFlowMaterializer import akka.util.ByteString import akka.event.LoggingAdapter import akka.stream.scaladsl._ -import akka.stream.Transformer +import akka.stream.stage._ import akka.http.model._ import akka.http.util._ import org.reactivestreams.Subscriber @@ -45,38 +45,46 @@ private object RenderSupport { r ~~ headers.`Content-Type` ~~ entity.contentType ~~ CrLf def renderByteStrings(r: ByteStringRendering, entityBytes: ⇒ Source[ByteString], - skipEntity: Boolean = false): List[Source[ByteString]] = { - val messageStart = Source(r.get :: Nil) + skipEntity: Boolean = false): Source[ByteString] = { + val messageStart = Source.singleton(r.get) val messageBytes = if (!skipEntity) messageStart ++ entityBytes else CancelSecond(messageStart, entityBytes) - messageBytes :: Nil + messageBytes } - class ChunkTransformer extends Transformer[HttpEntity.ChunkStreamPart, ByteString] { + class ChunkTransformer extends StatefulStage[HttpEntity.ChunkStreamPart, ByteString] { var lastChunkSeen = false - def onNext(chunk: HttpEntity.ChunkStreamPart): List[ByteString] = { - if (chunk.isLastChunk) lastChunkSeen = true - renderChunk(chunk) :: Nil + + override def initial = new State { + override def onPush(chunk: HttpEntity.ChunkStreamPart, ctx: Context[ByteString]): Directive = { + if (chunk.isLastChunk) + lastChunkSeen = true + ctx.push(renderChunk(chunk)) + } } - override def isComplete = lastChunkSeen - override def onTermination(e: Option[Throwable]) = if (lastChunkSeen) Nil else defaultLastChunkBytes :: Nil + + override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = + if (lastChunkSeen) super.onUpstreamFinish(ctx) + else terminationEmit(Iterator.single(defaultLastChunkBytes), ctx) } - class CheckContentLengthTransformer(length: Long) extends Transformer[ByteString, ByteString] { + class CheckContentLengthTransformer(length: Long) extends PushStage[ByteString, ByteString] { var sent = 0L - def onNext(elem: ByteString): List[ByteString] = { + + override def onPush(elem: ByteString, ctx: Context[ByteString]): Directive = { sent += elem.length if (sent > length) throw new InvalidContentLengthException(s"HTTP message had declared Content-Length $length but entity data stream amounts to more bytes") - elem :: Nil + ctx.push(elem) } - override def onTermination(e: Option[Throwable]): List[ByteString] = { + override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = { if (sent < length) throw new InvalidContentLengthException(s"HTTP message had declared Content-Length $length but entity data stream amounts to ${length - sent} bytes less") - Nil + ctx.finish() } + } private def renderChunk(chunk: HttpEntity.ChunkStreamPart): ByteString = { diff --git a/akka-http-core/src/main/scala/akka/http/engine/server/HttpServerPipeline.scala b/akka-http-core/src/main/scala/akka/http/engine/server/HttpServerPipeline.scala index bb9f99f52c..0383635da9 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/server/HttpServerPipeline.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/server/HttpServerPipeline.scala @@ -8,8 +8,8 @@ import akka.event.LoggingAdapter import akka.stream.io.StreamTcp import akka.stream.FlattenStrategy import akka.stream.FlowMaterializer -import akka.stream.Transformer import akka.stream.scaladsl._ +import akka.stream.stage._ import akka.http.engine.parsing.HttpRequestParser import akka.http.engine.rendering.{ ResponseRenderingContext, HttpResponseRendererFactory } import akka.http.model.{ StatusCode, ErrorInfo, HttpRequest, HttpResponse, HttpMethods } @@ -84,37 +84,38 @@ private[http] class HttpServerPipeline(settings: ServerSettings, log: LoggingAda * If the parser produced a ParserOutput.ParseError the error response is immediately dispatched to downstream. */ def applyApplicationBypass = - new Transformer[Any, ResponseRenderingContext] { + new PushStage[Any, ResponseRenderingContext] { var applicationResponse: HttpResponse = _ var requestStart: RequestStart = _ - def onNext(elem: Any) = elem match { + override def onPush(elem: Any, ctx: Context[ResponseRenderingContext]): Directive = elem match { case response: HttpResponse ⇒ requestStart match { case null ⇒ applicationResponse = response - Nil + ctx.pull() case x: RequestStart ⇒ requestStart = null - dispatch(x, response) + ctx.push(dispatch(x, response)) } case requestStart: RequestStart ⇒ applicationResponse match { case null ⇒ this.requestStart = requestStart - Nil + ctx.pull() case response ⇒ applicationResponse = null - dispatch(requestStart, response) + ctx.push(dispatch(requestStart, response)) } - case ParseError(status, info) ⇒ errorResponse(status, info) :: Nil + case ParseError(status, info) ⇒ + ctx.push(errorResponse(status, info)) } - def dispatch(requestStart: RequestStart, response: HttpResponse): List[ResponseRenderingContext] = { + def dispatch(requestStart: RequestStart, response: HttpResponse): ResponseRenderingContext = { import requestStart._ - ResponseRenderingContext(response, method, protocol, closeAfterResponseCompletion) :: Nil + ResponseRenderingContext(response, method, protocol, closeAfterResponseCompletion) } def errorResponse(status: StatusCode, info: ErrorInfo): ResponseRenderingContext = { @@ -123,4 +124,4 @@ private[http] class HttpServerPipeline(settings: ServerSettings, log: LoggingAda ResponseRenderingContext(HttpResponse(status, entity = msg), closeAfterResponseCompletion = true) } } -} \ No newline at end of file +} diff --git a/akka-http-core/src/main/scala/akka/http/model/HttpEntity.scala b/akka-http-core/src/main/scala/akka/http/model/HttpEntity.scala index e65337dea4..e54384f770 100644 --- a/akka-http-core/src/main/scala/akka/http/model/HttpEntity.scala +++ b/akka-http-core/src/main/scala/akka/http/model/HttpEntity.scala @@ -14,9 +14,11 @@ import scala.util.control.NonFatal import akka.util.ByteString import akka.stream.FlowMaterializer import akka.stream.scaladsl._ -import akka.stream.{ TimerTransformer, Transformer } +import akka.stream.TimerTransformer import akka.http.util._ import japi.JavaMapping.Implicits._ +import scala.util.Success +import scala.util.Failure /** * Models the entity (aka "body" or "content) of an HTTP message. @@ -70,7 +72,7 @@ sealed trait HttpEntity extends japi.HttpEntity { * This method may only throw an exception if the `transformer` function throws an exception while creating the transformer. * Any other errors are reported through the new entity data stream. */ - def transformDataBytes(transformer: () ⇒ Transformer[ByteString, ByteString]): HttpEntity + def transformDataBytes(transformer: Flow[ByteString, ByteString]): HttpEntity /** * Creates a copy of this HttpEntity with the `contentType` overridden with the given one. @@ -95,13 +97,13 @@ sealed trait BodyPartEntity extends HttpEntity with japi.BodyPartEntity { sealed trait RequestEntity extends HttpEntity with japi.RequestEntity with ResponseEntity { def withContentType(contentType: ContentType): RequestEntity - override def transformDataBytes(transformer: () ⇒ Transformer[ByteString, ByteString]): RequestEntity + override def transformDataBytes(transformer: Flow[ByteString, ByteString]): RequestEntity } /* An entity that can be used for responses */ sealed trait ResponseEntity extends HttpEntity with japi.ResponseEntity { def withContentType(contentType: ContentType): ResponseEntity - override def transformDataBytes(transformer: () ⇒ Transformer[ByteString, ByteString]): ResponseEntity + override def transformDataBytes(transformer: Flow[ByteString, ByteString]): ResponseEntity } /* An entity that can be used for requests, responses, and body parts */ sealed trait UniversalEntity extends japi.UniversalEntity with MessageEntity with BodyPartEntity { @@ -112,7 +114,7 @@ sealed trait UniversalEntity extends japi.UniversalEntity with MessageEntity wit * Transforms this' entities data bytes with a transformer that will produce exactly the number of bytes given as * ``newContentLength``. */ - def transformDataBytes(newContentLength: Long, transformer: () ⇒ Transformer[ByteString, ByteString]): UniversalEntity + def transformDataBytes(newContentLength: Long, transformer: Flow[ByteString, ByteString]): UniversalEntity } object HttpEntity { @@ -143,11 +145,20 @@ object HttpEntity { // TODO: re-establish serializability // TODO: equal/hashcode ? + object Strict { + // FIXME configurable? + private val MaxByteSize = 1L * 1024 * 1024 * 1024 + private val MaxElements = 1000 + } + /** * The model for the entity of a "regular" unchunked HTTP message with known, fixed data. */ final case class Strict(contentType: ContentType, data: ByteString) extends japi.HttpEntityStrict with UniversalEntity { + + import Strict._ + def contentLength: Long = data.length def isKnownEmpty: Boolean = data.isEmpty @@ -157,22 +168,27 @@ object HttpEntity { override def toStrict(timeout: FiniteDuration)(implicit ec: ExecutionContext, fm: FlowMaterializer) = FastFuture.successful(this) - override def transformDataBytes(transformer: () ⇒ Transformer[ByteString, ByteString]): MessageEntity = { - try { - val t = transformer() - val newData = (t.onNext(data) ++ t.onTermination(None)).join - copy(data = newData) - } catch { - case NonFatal(ex) ⇒ + override def transformDataBytes(transformer: Flow[ByteString, ByteString]): MessageEntity = + StreamUtils.runStrict(data, transformer, MaxByteSize, MaxElements) match { + case Success(Some(newData)) ⇒ + copy(data = newData) + case Success(None) ⇒ + Chunked.fromData(contentType, Source.singleton(data).via(transformer)) + case Failure(ex) ⇒ Chunked(contentType, Source.failed(ex)) } - } - override def transformDataBytes(newContentLength: Long, transformer: () ⇒ Transformer[ByteString, ByteString]): UniversalEntity = { - val t = transformer() - val newData = (t.onNext(data) ++ t.onTermination(None)).join - assert(newData.length.toLong == newContentLength, s"Transformer didn't produce as much bytes (${newData.length}:'${newData.utf8String}') as claimed ($newContentLength)") - copy(data = newData) - } + + override def transformDataBytes(newContentLength: Long, transformer: Flow[ByteString, ByteString]): UniversalEntity = + StreamUtils.runStrict(data, transformer, MaxByteSize, MaxElements) match { + case Success(Some(newData)) ⇒ + if (newData.length.toLong != newContentLength) + throw new IllegalStateException(s"Transformer didn't produce as much bytes (${newData.length}:'${newData.utf8String}') as claimed ($newContentLength)") + copy(data = newData) + case Success(None) ⇒ + Default(contentType, newContentLength, Source.singleton(data).via(transformer)) + case Failure(ex) ⇒ + Default(contentType, newContentLength, Source.failed(ex)) + } def withContentType(contentType: ContentType): Strict = if (contentType == this.contentType) this else copy(contentType = contentType) @@ -194,13 +210,11 @@ object HttpEntity { def dataBytes: Source[ByteString] = data - override def transformDataBytes(transformer: () ⇒ Transformer[ByteString, ByteString]): Chunked = { - val chunks = data.transform("transformDataBytes-Default", () ⇒ transformer().map(Chunk(_): ChunkStreamPart)) + override def transformDataBytes(transformer: Flow[ByteString, ByteString]): Chunked = + Chunked.fromData(contentType, data.via(transformer)) - HttpEntity.Chunked(contentType, chunks) - } - override def transformDataBytes(newContentLength: Long, transformer: () ⇒ Transformer[ByteString, ByteString]): UniversalEntity = - Default(contentType, newContentLength, data.transform("transformDataBytes-with-new-length-Default", transformer)) + override def transformDataBytes(newContentLength: Long, transformer: Flow[ByteString, ByteString]): UniversalEntity = + Default(contentType, newContentLength, data.via(transformer)) def withContentType(contentType: ContentType): Default = if (contentType == this.contentType) this else copy(contentType = contentType) @@ -235,9 +249,8 @@ object HttpEntity { def withContentType(contentType: ContentType): CloseDelimited = if (contentType == this.contentType) this else copy(contentType = contentType) - override def transformDataBytes(transformer: () ⇒ Transformer[ByteString, ByteString]): CloseDelimited = - HttpEntity.CloseDelimited(contentType, - data.transform("transformDataBytes-CloseDelimited", transformer)) + override def transformDataBytes(transformer: Flow[ByteString, ByteString]): CloseDelimited = + HttpEntity.CloseDelimited(contentType, data.via(transformer)) override def productPrefix = "HttpEntity.CloseDelimited" } @@ -253,9 +266,8 @@ object HttpEntity { def withContentType(contentType: ContentType): IndefiniteLength = if (contentType == this.contentType) this else copy(contentType = contentType) - override def transformDataBytes(transformer: () ⇒ Transformer[ByteString, ByteString]): IndefiniteLength = - HttpEntity.IndefiniteLength(contentType, - data.transform("transformDataBytes-IndefiniteLength", transformer)) + override def transformDataBytes(transformer: Flow[ByteString, ByteString]): IndefiniteLength = + HttpEntity.IndefiniteLength(contentType, data.via(transformer)) override def productPrefix = "HttpEntity.IndefiniteLength" } @@ -272,35 +284,16 @@ object HttpEntity { def dataBytes: Source[ByteString] = chunks.map(_.data).filter(_.nonEmpty) - override def transformDataBytes(transformer: () ⇒ Transformer[ByteString, ByteString]): Chunked = { - val newChunks = - chunks.transform("transformDataBytes-Chunked", () ⇒ new Transformer[ChunkStreamPart, ChunkStreamPart] { - val byteTransformer = transformer() - var sentLastChunk = false + override def transformDataBytes(transformer: Flow[ByteString, ByteString]): Chunked = { + val newData = + chunks.map { + case Chunk(data, "") ⇒ data + case LastChunk("", Nil) ⇒ ByteString.empty + case _ ⇒ + throw new IllegalArgumentException("Chunked.transformDataBytes not allowed for chunks with metadata") + }.via(transformer) - override def isComplete: Boolean = byteTransformer.isComplete - - def onNext(element: ChunkStreamPart): immutable.Seq[ChunkStreamPart] = element match { - case Chunk(data, ext) ⇒ Chunk(byteTransformer.onNext(data).join, ext) :: Nil - case l: LastChunk ⇒ - sentLastChunk = true - Chunk(byteTransformer.onTermination(None).join) :: l :: Nil - } - override def onError(cause: scala.Throwable): Unit = byteTransformer.onError(cause) - override def onTermination(e: Option[Throwable]): immutable.Seq[ChunkStreamPart] = { - val remaining = - if (e.isEmpty && !sentLastChunk) byteTransformer.onTermination(None) - else if (e.isDefined /* && sentLastChunk */ ) byteTransformer.onTermination(e) - else Nil - - if (remaining.nonEmpty) Chunk(remaining.join) :: Nil - else Nil - } - - override def cleanup(): Unit = byteTransformer.cleanup() - }) - - HttpEntity.Chunked(contentType, newChunks) + Chunked.fromData(contentType, newData) } def withContentType(contentType: ContentType): Chunked = @@ -365,4 +358,4 @@ object HttpEntity { def getTrailerHeaders: JIterable[japi.HttpHeader] = trailer.asJava } object LastChunk extends LastChunk("", Nil) -} \ No newline at end of file +} diff --git a/akka-http-core/src/main/scala/akka/http/util/StreamUtils.scala b/akka-http-core/src/main/scala/akka/http/util/StreamUtils.scala index 982a355840..276b855b70 100644 --- a/akka-http-core/src/main/scala/akka/http/util/StreamUtils.scala +++ b/akka-http-core/src/main/scala/akka/http/util/StreamUtils.scala @@ -4,129 +4,120 @@ package akka.http.util -import java.util.concurrent.atomic.AtomicBoolean import java.io.InputStream -import org.reactivestreams.{ Subscriber, Publisher } +import java.util.concurrent.atomic.AtomicBoolean +import scala.annotation.tailrec import scala.collection.immutable import scala.concurrent.{ ExecutionContext, Future } - +import scala.util.Try import akka.actor.Props -import akka.util.ByteString - -import akka.stream.{ impl, Transformer, FlowMaterializer } -import akka.stream.scaladsl._ - import akka.http.model.RequestEntity +import akka.stream.FlowMaterializer +import akka.stream.impl.Ast.AstNode +import akka.stream.impl.Ast.StageFactory +import akka.stream.impl.fusing.IteratorInterpreter +import akka.stream.scaladsl._ +import akka.stream.stage._ +import akka.stream.impl +import akka.util.ByteString +import org.reactivestreams.{ Subscriber, Publisher } /** * INTERNAL API */ private[http] object StreamUtils { - /** - * Maps a transformer by strictly applying the given function to each output element. - */ - def mapTransformer[T, U, V](t: Transformer[T, U], f: U ⇒ V): Transformer[T, V] = - new Transformer[T, V] { - override def isComplete: Boolean = t.isComplete - - def onNext(element: T): immutable.Seq[V] = t.onNext(element).map(f) - override def onTermination(e: Option[Throwable]): immutable.Seq[V] = t.onTermination(e).map(f) - override def onError(cause: Throwable): Unit = t.onError(cause) - override def cleanup(): Unit = t.cleanup() - } /** * Creates a transformer that will call `f` for each incoming ByteString and output its result. After the complete * input has been read it will call `finish` once to determine the final ByteString to post to the output. */ - def byteStringTransformer(f: ByteString ⇒ ByteString, finish: () ⇒ ByteString): Transformer[ByteString, ByteString] = - new Transformer[ByteString, ByteString] { - def onNext(element: ByteString): immutable.Seq[ByteString] = f(element) :: Nil + def byteStringTransformer(f: ByteString ⇒ ByteString, finish: () ⇒ ByteString): Flow[ByteString, ByteString] = { + val transformer = new PushPullStage[ByteString, ByteString] { + override def onPush(element: ByteString, ctx: Context[ByteString]): Directive = + ctx.push(f(element)) - override def onTermination(e: Option[Throwable]): immutable.Seq[ByteString] = - if (e.isEmpty) { - val last = finish() - if (last.nonEmpty) last :: Nil - else Nil - } else super.onTermination(e) + override def onPull(ctx: Context[ByteString]): Directive = + if (ctx.isFinishing) ctx.pushAndFinish(finish()) + else ctx.pull() + + override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = ctx.absorbTermination() } + Flow[ByteString].transform("transformBytes", () ⇒ transformer) + } def failedPublisher[T](ex: Throwable): Publisher[T] = impl.ErrorPublisher(ex, "failed").asInstanceOf[Publisher[T]] - def mapErrorTransformer[T](f: Throwable ⇒ Throwable): Transformer[T, T] = - new Transformer[T, T] { - def onNext(element: T): immutable.Seq[T] = immutable.Seq(element) - override def onError(cause: scala.Throwable): Unit = throw f(cause) + def mapErrorTransformer(f: Throwable ⇒ Throwable): Flow[ByteString, ByteString] = { + val transformer = new PushStage[ByteString, ByteString] { + override def onPush(element: ByteString, ctx: Context[ByteString]): Directive = + ctx.push(element) + + override def onUpstreamFailure(cause: Throwable, ctx: Context[ByteString]): TerminationDirective = + ctx.fail(f(cause)) } - def sliceBytesTransformer(start: Long, length: Long): Transformer[ByteString, ByteString] = - new Transformer[ByteString, ByteString] { - type State = Transformer[ByteString, ByteString] + Flow[ByteString].transform("transformError", () ⇒ transformer) + } + + def sliceBytesTransformer(start: Long, length: Long): Flow[ByteString, ByteString] = { + val transformer = new StatefulStage[ByteString, ByteString] { def skipping = new State { var toSkip = start - def onNext(element: ByteString): immutable.Seq[ByteString] = + override def onPush(element: ByteString, ctx: Context[ByteString]): Directive = if (element.length < toSkip) { // keep skipping toSkip -= element.length - Nil + ctx.pull() } else { become(taking(length)) // toSkip <= element.length <= Int.MaxValue - currentState.onNext(element.drop(toSkip.toInt)) + current.onPush(element.drop(toSkip.toInt), ctx) } } def taking(initiallyRemaining: Long) = new State { var remaining: Long = initiallyRemaining - def onNext(element: ByteString): immutable.Seq[ByteString] = { + override def onPush(element: ByteString, ctx: Context[ByteString]): Directive = { val data = element.take(math.min(remaining, Int.MaxValue).toInt) remaining -= data.size - if (remaining <= 0) become(finishing) - data :: Nil + if (remaining <= 0) ctx.pushAndFinish(data) + else ctx.push(data) } } - def finishing = new State { - override def isComplete: Boolean = true - def onNext(element: ByteString): immutable.Seq[ByteString] = - throw new IllegalStateException("onNext called on complete stream") - } - var currentState: State = if (start > 0) skipping else taking(length) - def become(state: State): Unit = currentState = state - - override def isComplete: Boolean = currentState.isComplete - def onNext(element: ByteString): immutable.Seq[ByteString] = currentState.onNext(element) - override def onTermination(e: Option[Throwable]): immutable.Seq[ByteString] = currentState.onTermination(e) + override def initial: State = if (start > 0) skipping else taking(length) } + Flow[ByteString].transform("sliceBytes", () ⇒ transformer) + } /** * Applies a sequence of transformers on one source and returns a sequence of sources with the result. The input source * will only be traversed once. */ - def transformMultiple[T, U](input: Source[T], transformers: immutable.Seq[() ⇒ Transformer[T, U]])(implicit materializer: FlowMaterializer): immutable.Seq[Source[U]] = + def transformMultiple(input: Source[ByteString], transformers: immutable.Seq[Flow[ByteString, ByteString]])(implicit materializer: FlowMaterializer): immutable.Seq[Source[ByteString]] = transformers match { case Nil ⇒ Nil - case Seq(one) ⇒ Vector(input.transform("transformMultipleElement", one)) + case Seq(one) ⇒ Vector(input.via(one)) case multiple ⇒ - val results = Vector.fill(multiple.size)(Sink.publisher[U]) + val results = Vector.fill(multiple.size)(Sink.publisher[ByteString]) val mat = FlowGraph { implicit b ⇒ import FlowGraphImplicits._ - val broadcast = Broadcast[T]("transformMultipleInputBroadcast") + val broadcast = Broadcast[ByteString]("transformMultipleInputBroadcast") input ~> broadcast (multiple, results).zipped.foreach { (trans, sink) ⇒ - broadcast ~> Flow[T].transform("transformMultipleElement", trans) ~> sink + broadcast ~> trans ~> sink } }.run() results.map(s ⇒ Source(mat.get(s))) } def mapEntityError(f: Throwable ⇒ Throwable): RequestEntity ⇒ RequestEntity = - _.transformDataBytes(() ⇒ mapErrorTransformer(f)) + _.transformDataBytes(mapErrorTransformer(f)) /** * Simple blocking Source backed by an InputStream. @@ -186,13 +177,53 @@ private[http] object StreamUtils { else (ErrorPublisher(new IllegalStateException("One time source can only be instantiated once"), "failed").asInstanceOf[Publisher[T]], ()) } } -} -/** - * INTERNAL API - */ -private[http] class EnhancedTransformer[T, U](val t: Transformer[T, U]) extends AnyVal { - def map[V](f: U ⇒ V): Transformer[T, V] = StreamUtils.mapTransformer(t, f) + def runStrict(sourceData: ByteString, transformer: Flow[ByteString, ByteString], maxByteSize: Long, maxElements: Int): Try[Option[ByteString]] = + Try { + transformer match { + case Pipe(ops) ⇒ + if (ops.isEmpty) + Some(sourceData) + else { + @tailrec def tryBuild(remaining: List[AstNode], acc: List[PushPullStage[ByteString, ByteString]]): List[PushPullStage[ByteString, ByteString]] = + remaining match { + case Nil ⇒ acc.reverse + case StageFactory(mkStage, _) :: tail ⇒ + mkStage() match { + case d: PushPullStage[ByteString, ByteString] ⇒ + tryBuild(tail, d :: acc) + case _ ⇒ Nil + } + case _ ⇒ Nil + } + + val strictOps = tryBuild(ops, Nil) + if (strictOps.isEmpty) + None + else { + val iter: Iterator[ByteString] = new IteratorInterpreter(Iterator.single(sourceData), strictOps).iterator + var byteSize = 0L + var result = ByteString.empty + var i = 0 + // note that iter.next() will throw exception if the stream fails, caught by the enclosing Try + while (iter.hasNext) { + i += 1 + if (i > maxElements) + throw new IllegalArgumentException(s"Too many elements produced by byte transformation, $i was greater than max allowed $maxElements elements") + val elem = iter.next() + byteSize += elem.size + if (byteSize > maxByteSize) + throw new IllegalArgumentException(s"Too large data result, $byteSize bytes was greater than max allowed $maxByteSize bytes") + result ++= elem + } + Some(result) + } + } + + case _ ⇒ None + } + } + } /** diff --git a/akka-http-core/src/main/scala/akka/http/util/package.scala b/akka-http-core/src/main/scala/akka/http/util/package.scala index 7fb3787aa2..f0be761f04 100644 --- a/akka-http-core/src/main/scala/akka/http/util/package.scala +++ b/akka-http-core/src/main/scala/akka/http/util/package.scala @@ -6,10 +6,12 @@ package akka.http import language.implicitConversions import language.higherKinds +import scala.collection.immutable import java.nio.charset.Charset import com.typesafe.config.Config -import akka.stream.{ FlowMaterializer, FlattenStrategy, Transformer } +import akka.stream.{ FlowMaterializer, FlattenStrategy } import akka.stream.scaladsl.{ Flow, Source } +import akka.stream.stage._ import scala.concurrent.duration.Duration import scala.concurrent.{ Await, Future } import scala.util.{ Failure, Success } @@ -17,7 +19,6 @@ import scala.util.matching.Regex import akka.event.LoggingAdapter import akka.util.ByteString import akka.actor._ -import scala.collection.immutable package object util { private[http] val UTF8 = Charset.forName("UTF8") @@ -41,8 +42,6 @@ package object util { new EnhancedByteStringTraversableOnce(byteStrings) private[http] implicit def enhanceByteStrings(byteStrings: Source[ByteString]): EnhancedByteStringSource = new EnhancedByteStringSource(byteStrings) - private[http] implicit def enhanceTransformer[T, U](transformer: Transformer[T, U]): EnhancedTransformer[T, U] = - new EnhancedTransformer(transformer) private[http] implicit class SourceWithHeadAndTail[T](val underlying: Source[Source[T]]) extends AnyVal { def headAndTail: Source[(T, Source[T])] = @@ -59,14 +58,18 @@ package object util { private[http] implicit class EnhancedSource[T](val underlying: Source[T]) { def printEvent(marker: String): Source[T] = underlying.transform("transform", - () ⇒ new Transformer[T, T] { - def onNext(element: T) = { + () ⇒ new PushStage[T, T] { + override def onPush(element: T, ctx: Context[T]): Directive = { println(s"$marker: $element") - element :: Nil + ctx.push(element) } - override def onTermination(e: Option[Throwable]) = { - println(s"$marker: Terminated with error $e") - Nil + override def onUpstreamFailure(cause: Throwable, ctx: Context[T]): TerminationDirective = { + println(s"$marker: Failure $cause") + super.onUpstreamFailure(cause, ctx) + } + override def onUpstreamFinish(ctx: Context[T]): TerminationDirective = { + println(s"$marker: Terminated") + super.onUpstreamFinish(ctx) } }) @@ -90,10 +93,13 @@ package object util { } } - private[http] def errorLogger(log: LoggingAdapter, msg: String): Transformer[ByteString, ByteString] = - new Transformer[ByteString, ByteString] { - def onNext(element: ByteString) = element :: Nil - override def onError(cause: Throwable): Unit = log.error(cause, msg) + private[http] def errorLogger(log: LoggingAdapter, msg: String): PushStage[ByteString, ByteString] = + new PushStage[ByteString, ByteString] { + override def onPush(element: ByteString, ctx: Context[ByteString]): Directive = ctx.push(element) + override def onUpstreamFailure(cause: Throwable, ctx: Context[ByteString]): TerminationDirective = { + log.error(cause, msg) + super.onUpstreamFailure(cause, ctx) + } } private[this] val _identityFunc: Any ⇒ Any = x ⇒ x diff --git a/akka-http-core/src/test/scala/akka/http/ClientServerSpec.scala b/akka-http-core/src/test/scala/akka/http/ClientServerSpec.scala index b6f2e04af6..5f63675b35 100644 --- a/akka-http-core/src/test/scala/akka/http/ClientServerSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/ClientServerSpec.scala @@ -184,4 +184,4 @@ class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll { } def toStrict(entity: HttpEntity): HttpEntity.Strict = Await.result(entity.toStrict(500.millis), 1.second) -} \ No newline at end of file +} diff --git a/akka-http-core/src/test/scala/akka/http/TestClient.scala b/akka-http-core/src/test/scala/akka/http/TestClient.scala index f79a37baad..d5c88bb917 100644 --- a/akka-http-core/src/test/scala/akka/http/TestClient.scala +++ b/akka-http-core/src/test/scala/akka/http/TestClient.scala @@ -38,7 +38,7 @@ object TestClient extends App { } yield response.header[headers.Server] def sendRequest(request: HttpRequest, connection: Http.OutgoingConnection): Future[HttpResponse] = { - Source(List(HttpRequest() -> 'NoContext)) + Source.singleton(HttpRequest() -> 'NoContext) .to(Sink(connection.requestSubscriber)) .run() Source(connection.responsePublisher).map(_._1).runWith(Sink.head) @@ -49,4 +49,4 @@ object TestClient extends App { case Failure(error) ⇒ println(s"Error: $error") } result onComplete { _ ⇒ system.shutdown() } -} \ No newline at end of file +} diff --git a/akka-http-core/src/test/scala/akka/http/engine/rendering/RequestRendererSpec.scala b/akka-http-core/src/test/scala/akka/http/engine/rendering/RequestRendererSpec.scala index 0f18d7b960..43f2926d0c 100644 --- a/akka-http-core/src/test/scala/akka/http/engine/rendering/RequestRendererSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/engine/rendering/RequestRendererSpec.scala @@ -252,7 +252,9 @@ class RequestRendererSpec extends FreeSpec with Matchers with BeforeAndAfterAll def renderTo(expected: String): Matcher[HttpRequest] = equal(expected.stripMarginWithNewline("\r\n")).matcher[String] compose { request ⇒ val renderer = newRenderer - val byteStringSource :: Nil = renderer.onNext(RequestRenderingContext(request, serverAddress)) + val byteStringSource = Await.result(Source.singleton(RequestRenderingContext(request, serverAddress)). + transform("renderer", () ⇒ renderer). + runWith(Sink.head), 1.second) val future = byteStringSource.grouped(1000).runWith(Sink.head).map(_.reduceLeft(_ ++ _).utf8String) Await.result(future, 250.millis) } diff --git a/akka-http-core/src/test/scala/akka/http/engine/rendering/ResponseRendererSpec.scala b/akka-http-core/src/test/scala/akka/http/engine/rendering/ResponseRendererSpec.scala index 5601023e0e..84cbb078f8 100644 --- a/akka-http-core/src/test/scala/akka/http/engine/rendering/ResponseRendererSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/engine/rendering/ResponseRendererSpec.scala @@ -402,7 +402,9 @@ class ResponseRendererSpec extends FreeSpec with Matchers with BeforeAndAfterAll def renderTo(expected: String, close: Boolean): Matcher[ResponseRenderingContext] = equal(expected.stripMarginWithNewline("\r\n") -> close).matcher[(String, Boolean)] compose { ctx ⇒ val renderer = newRenderer - val byteStringSource :: Nil = renderer.onNext(ctx) + val byteStringSource = Await.result(Source.singleton(ctx). + transform("renderer", () ⇒ renderer). + runWith(Sink.head), 1.second) val future = byteStringSource.grouped(1000).runWith(Sink.head).map(_.reduceLeft(_ ++ _).utf8String) Await.result(future, 250.millis) -> renderer.isComplete } diff --git a/akka-http-core/src/test/scala/akka/http/model/HttpEntitySpec.scala b/akka-http-core/src/test/scala/akka/http/model/HttpEntitySpec.scala index 28a37e0aa1..707ec1604b 100644 --- a/akka-http-core/src/test/scala/akka/http/model/HttpEntitySpec.scala +++ b/akka-http-core/src/test/scala/akka/http/model/HttpEntitySpec.scala @@ -16,8 +16,8 @@ import akka.util.ByteString import akka.actor.ActorSystem import akka.stream.scaladsl._ import akka.stream.FlowMaterializer -import akka.stream.Transformer import akka.http.model.HttpEntity._ +import akka.http.util.StreamUtils class HttpEntitySpec extends FreeSpec with MustMatchers with BeforeAndAfterAll { val tpe: ContentType = ContentTypes.`application/octet-stream` @@ -120,14 +120,8 @@ class HttpEntitySpec extends FreeSpec with MustMatchers with BeforeAndAfterAll { Await.result(transformed.toStrict(250.millis), 250.millis) } - def duplicateBytesTransformer(): Transformer[ByteString, ByteString] = - new Transformer[ByteString, ByteString] { - def onNext(bs: ByteString): immutable.Seq[ByteString] = - Vector(doubleChars(bs)) - - override def onTermination(e: Option[Throwable]): immutable.Seq[ByteString] = - Vector(trailer) - } + def duplicateBytesTransformer(): Flow[ByteString, ByteString] = + StreamUtils.byteStringTransformer(doubleChars, () ⇒ trailer) def trailer: ByteString = ByteString("--dup") def doubleChars(bs: ByteString): ByteString = ByteString(bs.flatMap(b ⇒ Seq(b, b)): _*) diff --git a/akka-http-tests/src/test/scala/akka/http/server/directives/CodingDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/server/directives/CodingDirectivesSpec.scala index c3f9ad5ff6..234d14fa41 100644 --- a/akka-http-tests/src/test/scala/akka/http/server/directives/CodingDirectivesSpec.scala +++ b/akka-http-tests/src/test/scala/akka/http/server/directives/CodingDirectivesSpec.scala @@ -507,4 +507,4 @@ class CodingDirectivesSpec extends RoutingSpec { be(Some(`Content-Encoding`(encoding))) compose { (_: HttpResponse).header[`Content-Encoding`] } def readAs(string: String, charset: String = "UTF8") = be(string) compose { (_: ByteString).decodeString(charset) } -} \ No newline at end of file +} diff --git a/akka-http/src/main/scala/akka/http/coding/DataMapper.scala b/akka-http/src/main/scala/akka/http/coding/DataMapper.scala index 32b997a7ae..13da7f0ec3 100644 --- a/akka-http/src/main/scala/akka/http/coding/DataMapper.scala +++ b/akka-http/src/main/scala/akka/http/coding/DataMapper.scala @@ -5,22 +5,22 @@ package akka.http.coding import akka.http.model.{ HttpRequest, HttpResponse, ResponseEntity, RequestEntity } -import akka.stream.Transformer import akka.util.ByteString +import akka.stream.scaladsl.Flow /** An abstraction to transform data bytes of HttpMessages or HttpEntities */ sealed trait DataMapper[T] { - def transformDataBytes(t: T, transformer: () ⇒ Transformer[ByteString, ByteString]): T + def transformDataBytes(t: T, transformer: Flow[ByteString, ByteString]): T } object DataMapper { implicit val mapRequestEntity: DataMapper[RequestEntity] = new DataMapper[RequestEntity] { - def transformDataBytes(t: RequestEntity, transformer: () ⇒ Transformer[ByteString, ByteString]): RequestEntity = + def transformDataBytes(t: RequestEntity, transformer: Flow[ByteString, ByteString]): RequestEntity = t.transformDataBytes(transformer) } implicit val mapResponseEntity: DataMapper[ResponseEntity] = new DataMapper[ResponseEntity] { - def transformDataBytes(t: ResponseEntity, transformer: () ⇒ Transformer[ByteString, ByteString]): ResponseEntity = + def transformDataBytes(t: ResponseEntity, transformer: Flow[ByteString, ByteString]): ResponseEntity = t.transformDataBytes(transformer) } @@ -29,7 +29,7 @@ object DataMapper { def mapMessage[T, E](entityMapper: DataMapper[E])(mapEntity: (T, E ⇒ E) ⇒ T): DataMapper[T] = new DataMapper[T] { - def transformDataBytes(t: T, transformer: () ⇒ Transformer[ByteString, ByteString]): T = + def transformDataBytes(t: T, transformer: Flow[ByteString, ByteString]): T = mapEntity(t, entityMapper.transformDataBytes(_, transformer)) } -} \ No newline at end of file +} diff --git a/akka-http/src/main/scala/akka/http/coding/Decoder.scala b/akka-http/src/main/scala/akka/http/coding/Decoder.scala index 08ee9c3cfd..b0604b622c 100644 --- a/akka-http/src/main/scala/akka/http/coding/Decoder.scala +++ b/akka-http/src/main/scala/akka/http/coding/Decoder.scala @@ -6,9 +6,9 @@ package akka.http.coding import akka.http.model._ import akka.http.util.StreamUtils -import akka.stream.Transformer import akka.util.ByteString import headers.HttpEncoding +import akka.stream.scaladsl.Flow trait Decoder { def encoding: HttpEncoding @@ -23,7 +23,7 @@ trait Decoder { def newDecompressor: Decompressor - def newDecodeTransfomer(): Transformer[ByteString, ByteString] = { + def newDecodeTransfomer(): Flow[ByteString, ByteString] = { val decompressor = newDecompressor def decodeChunk(bytes: ByteString): ByteString = decompressor.decompress(bytes) diff --git a/akka-http/src/main/scala/akka/http/coding/Encoder.scala b/akka-http/src/main/scala/akka/http/coding/Encoder.scala index 3f41ed36b0..39eb2ef484 100644 --- a/akka-http/src/main/scala/akka/http/coding/Encoder.scala +++ b/akka-http/src/main/scala/akka/http/coding/Encoder.scala @@ -7,9 +7,9 @@ package akka.http.coding import java.io.ByteArrayOutputStream import akka.http.model._ import akka.http.util.StreamUtils -import akka.stream.Transformer import akka.util.ByteString import headers._ +import akka.stream.scaladsl.Flow trait Encoder { def encoding: HttpEncoding @@ -26,7 +26,7 @@ trait Encoder { def newCompressor: Compressor - def newEncodeTransformer(): Transformer[ByteString, ByteString] = { + def newEncodeTransformer(): Flow[ByteString, ByteString] = { val compressor = newCompressor def encodeChunk(bytes: ByteString): ByteString = compressor.compressAndFlush(bytes) @@ -72,4 +72,4 @@ abstract class Compressor { def compressAndFlush(input: ByteString): ByteString /** Combines `compress` + `finish` */ def compressAndFinish(input: ByteString): ByteString -} \ No newline at end of file +} diff --git a/akka-http/src/main/scala/akka/http/coding/Gzip.scala b/akka-http/src/main/scala/akka/http/coding/Gzip.scala index eafffb8d4d..782c77793e 100644 --- a/akka-http/src/main/scala/akka/http/coding/Gzip.scala +++ b/akka-http/src/main/scala/akka/http/coding/Gzip.scala @@ -128,6 +128,7 @@ class GzipDecompressor extends DeflateDecompressor { } private def fail(msg: String) = throw new ZipException(msg) + } /** INTERNAL API */ diff --git a/akka-http/src/main/scala/akka/http/server/directives/CodingDirectives.scala b/akka-http/src/main/scala/akka/http/server/directives/CodingDirectives.scala index e09a09739b..d63f30f693 100644 --- a/akka-http/src/main/scala/akka/http/server/directives/CodingDirectives.scala +++ b/akka-http/src/main/scala/akka/http/server/directives/CodingDirectives.scala @@ -93,4 +93,4 @@ object CodingDirectives extends CodingDirectives { def theseOrDefault[T >: Coder](these: Seq[T]): Seq[T] = if (these.isEmpty) defaultCoders else these -} \ No newline at end of file +} diff --git a/akka-http/src/main/scala/akka/http/server/directives/RangeDirectives.scala b/akka-http/src/main/scala/akka/http/server/directives/RangeDirectives.scala index e38c584122..c07e0b1141 100644 --- a/akka-http/src/main/scala/akka/http/server/directives/RangeDirectives.scala +++ b/akka-http/src/main/scala/akka/http/server/directives/RangeDirectives.scala @@ -41,7 +41,7 @@ trait RangeDirectives { class IndexRange(val start: Long, val end: Long) { def length = end - start - def apply(entity: UniversalEntity): UniversalEntity = entity.transformDataBytes(length, () ⇒ StreamUtils.sliceBytesTransformer(start, length)) + def apply(entity: UniversalEntity): UniversalEntity = entity.transformDataBytes(length, StreamUtils.sliceBytesTransformer(start, length)) def distance(other: IndexRange) = mergedEnd(other) - mergedStart(other) - (length + other.length) def mergeWith(other: IndexRange) = new IndexRange(mergedStart(other), mergedEnd(other)) def contentRange(entityLength: Long) = ContentRange(start, end - 1, entityLength) @@ -73,7 +73,7 @@ trait RangeDirectives { // Therefore, ranges need to be sorted to prevent that some selected ranges already start to accumulate data // but cannot be sent out because another range is blocking the queue. val coalescedRanges = coalesceRanges(iRanges).sortBy(_.start) - val bodyPartTransformers = coalescedRanges.map(ir ⇒ () ⇒ StreamUtils.sliceBytesTransformer(ir.start, ir.length)).toVector + val bodyPartTransformers = coalescedRanges.map(ir ⇒ StreamUtils.sliceBytesTransformer(ir.start, ir.length)).toVector val bodyPartByteStreams = StreamUtils.transformMultiple(entity.dataBytes, bodyPartTransformers) val bodyParts = (coalescedRanges, bodyPartByteStreams).zipped.map { (range, bytes) ⇒ Multipart.ByteRanges.BodyPart(range.contentRange(length), HttpEntity(entity.contentType, range.length, bytes)) diff --git a/akka-http/src/main/scala/akka/http/unmarshalling/MultipartUnmarshallers.scala b/akka-http/src/main/scala/akka/http/unmarshalling/MultipartUnmarshallers.scala index 9b1a0b9b63..ba3b7ab6ac 100644 --- a/akka-http/src/main/scala/akka/http/unmarshalling/MultipartUnmarshallers.scala +++ b/akka-http/src/main/scala/akka/http/unmarshalling/MultipartUnmarshallers.scala @@ -15,6 +15,8 @@ import akka.stream.scaladsl._ import MediaRanges._ import MediaTypes._ import HttpCharsets._ +import akka.stream.impl.fusing.IteratorInterpreter +import akka.util.ByteString trait MultipartUnmarshallers { @@ -67,15 +69,18 @@ trait MultipartUnmarshallers { entity match { case HttpEntity.Strict(ContentType(mediaType: MultipartMediaType, _), data) ⇒ val builder = new VectorBuilder[BPS]() - (parser.onNext(data) ++ parser.onTermination(None)) foreach { + val iter = new IteratorInterpreter[ByteString, BodyPartParser.Output]( + Iterator.single(data), List(parser)).iterator + // note that iter.next() will throw exception if stream fails + iter.foreach { case BodyPartStart(headers, createEntity) ⇒ val entity = createEntity(Source.empty()) match { case x: HttpEntity.Strict ⇒ x - case x ⇒ throw new IllegalStateException("Unexpected entity type from strict BodyPartParser: " + x.getClass.getName) + case x ⇒ throw new IllegalStateException("Unexpected entity type from strict BodyPartParser: " + x) } builder += createStrictBodyPart(entity, headers) case ParseError(errorInfo) ⇒ throw new ParsingException(errorInfo) - case x ⇒ throw new IllegalStateException(s"Unexpected BodyPartParser result `x` in strict case") + case x ⇒ throw new IllegalStateException(s"Unexpected BodyPartParser result $x in strict case") } createStrict(mediaType, builder.result()) case _ ⇒ @@ -95,4 +100,4 @@ trait MultipartUnmarshallers { } } -object MultipartUnmarshallers extends MultipartUnmarshallers \ No newline at end of file +object MultipartUnmarshallers extends MultipartUnmarshallers diff --git a/akka-stream-tck/src/test/scala/akka/stream/tck/FusableProcessorTest.scala b/akka-stream-tck/src/test/scala/akka/stream/tck/FusableProcessorTest.scala index 69e526c120..ffccf1074b 100644 --- a/akka-stream-tck/src/test/scala/akka/stream/tck/FusableProcessorTest.scala +++ b/akka-stream-tck/src/test/scala/akka/stream/tck/FusableProcessorTest.scala @@ -4,10 +4,10 @@ package akka.stream.tck import java.util.concurrent.atomic.AtomicInteger - import akka.stream.impl.{ Ast, ActorBasedFlowMaterializer } import akka.stream.{ FlowMaterializer, MaterializerSettings } import org.reactivestreams.{ Publisher, Processor } +import akka.stream.impl.fusing.Map class FusableProcessorTest extends AkkaIdentityProcessorVerification[Int] { @@ -22,7 +22,7 @@ class FusableProcessorTest extends AkkaIdentityProcessorVerification[Int] { val flowName = getClass.getSimpleName + "-" + processorCounter.incrementAndGet() val processor = materializer.asInstanceOf[ActorBasedFlowMaterializer].processorForNode( - Ast.Fused(List(akka.stream.impl.fusing.Map[Int, Int](identity)), "identity"), flowName, 1) + Ast.Fused(List(Map[Int, Int](identity)), "identity"), flowName, 1) processor.asInstanceOf[Processor[Int, Int]] } diff --git a/akka-stream-tck/src/test/scala/akka/stream/tck/TransformProcessorTest.scala b/akka-stream-tck/src/test/scala/akka/stream/tck/TransformProcessorTest.scala index 391a82eaa2..b8ca43d12f 100644 --- a/akka-stream-tck/src/test/scala/akka/stream/tck/TransformProcessorTest.scala +++ b/akka-stream-tck/src/test/scala/akka/stream/tck/TransformProcessorTest.scala @@ -4,13 +4,14 @@ package akka.stream.tck import akka.stream.MaterializerSettings -import akka.stream.Transformer import akka.stream.impl.ActorBasedFlowMaterializer import akka.stream.impl.Ast import akka.stream.FlowMaterializer import java.util.concurrent.atomic.AtomicInteger import org.reactivestreams.Processor import org.reactivestreams.Publisher +import akka.stream.stage.PushStage +import akka.stream.stage.Context class TransformProcessorTest extends AkkaIdentityProcessorVerification[Int] { @@ -24,13 +25,13 @@ class TransformProcessorTest extends AkkaIdentityProcessorVerification[Int] { val flowName = getClass.getSimpleName + "-" + processorCounter.incrementAndGet() - val mkTransformer = () ⇒ - new Transformer[Any, Any] { - override def onNext(in: Any) = List(in) + val mkStage = () ⇒ + new PushStage[Any, Any] { + override def onPush(in: Any, ctx: Context[Any]) = ctx.push(in) } val processor = materializer.asInstanceOf[ActorBasedFlowMaterializer].processorForNode( - Ast.Transform("transform", mkTransformer), flowName, 1) + Ast.StageFactory(mkStage, "transform"), flowName, 1) processor.asInstanceOf[Processor[Int, Int]] } diff --git a/akka-stream-tests/src/test/java/akka/stream/javadsl/FlowTest.java b/akka-stream-tests/src/test/java/akka/stream/javadsl/FlowTest.java index 2c2804845e..99a29fdf6d 100644 --- a/akka-stream-tests/src/test/java/akka/stream/javadsl/FlowTest.java +++ b/akka-stream-tests/src/test/java/akka/stream/javadsl/FlowTest.java @@ -5,29 +5,25 @@ import akka.dispatch.Foreach; import akka.dispatch.Futures; import akka.dispatch.OnSuccess; import akka.japi.Pair; -import akka.japi.Util; import akka.stream.OverflowStrategy; import akka.stream.StreamTest; -import akka.stream.Transformer; +import akka.stream.stage.*; import akka.stream.javadsl.japi.*; import akka.stream.testkit.AkkaSpec; import akka.testkit.JavaTestKit; + import org.junit.ClassRule; import org.junit.Test; import org.reactivestreams.Publisher; -import scala.Option; -import scala.collection.immutable.Seq; import scala.concurrent.Await; import scala.concurrent.Future; import scala.concurrent.duration.Duration; import scala.concurrent.duration.FiniteDuration; import scala.runtime.BoxedUnit; import scala.util.Try; - import java.util.*; import java.util.concurrent.Callable; import java.util.concurrent.TimeUnit; - import static org.junit.Assert.assertEquals; public class FlowTest extends StreamTest { @@ -104,37 +100,37 @@ public class FlowTest extends StreamTest { @Test public void mustBeAbleToUseTransform() { final JavaTestKit probe = new JavaTestKit(system); - final JavaTestKit probe2 = new JavaTestKit(system); final Iterable input = Arrays.asList(0, 1, 2, 3, 4, 5, 6, 7); // duplicate each element, stop after 4 elements, and emit sum to the end - Source.from(input).transform("publish", new Creator>() { + Source.from(input).transform("publish", new Creator>() { @Override - public Transformer create() throws Exception { - return new Transformer() { + public PushPullStage create() throws Exception { + return new StatefulStage() { int sum = 0; int count = 0; @Override - public scala.collection.immutable.Seq onNext(Integer element) { - sum += element; - count += 1; - return Util.immutableSeq(new Integer[] { element, element }); + public StageState initial() { + return new StageState() { + @Override + public Directive onPush(Integer element, Context ctx) { + sum += element; + count += 1; + if (count == 4) { + return emitAndFinish(Arrays.asList(element, element, sum).iterator(), ctx); + } else { + return emit(Arrays.asList(element, element).iterator(), ctx); + } + } + + }; } - + @Override - public boolean isComplete() { - return count == 4; - } - - @Override - public scala.collection.immutable.Seq onTermination(Option e) { - return Util.immutableSingletonSeq(sum); - } - - @Override - public void cleanup() { - probe2.getRef().tell("cleanup", ActorRef.noSender()); + public TerminationDirective onUpstreamFinish(Context ctx) { + return terminationEmit(Collections.singletonList(sum).iterator(), ctx); } + }; } }).foreach(new Procedure() { @@ -152,66 +148,6 @@ public class FlowTest extends StreamTest { probe.expectMsgEquals(3); probe.expectMsgEquals(3); probe.expectMsgEquals(6); - probe2.expectMsgEquals("cleanup"); - } - - @Test - public void mustBeAbleToUseTransformRecover() { - final JavaTestKit probe = new JavaTestKit(system); - final Iterable input = Arrays.asList(0, 1, 2, 3, 4, 5); - Source.from(input).map(new Function() { - public Integer apply(Integer elem) { - if (elem == 4) { - throw new IllegalArgumentException("4 not allowed"); - } else { - return elem + elem; - } - } - }).transform("publish", new Creator>() { - @Override - public Transformer create() throws Exception { - return new Transformer() { - - @Override - public scala.collection.immutable.Seq onNext(Integer element) { - return Util.immutableSingletonSeq(element.toString()); - } - - @Override - public scala.collection.immutable.Seq onTermination(Option e) { - if (e.isEmpty()) { - return Util.immutableSeq(new String[0]); - } else { - return Util.immutableSingletonSeq(e.get().getMessage()); - } - } - - @Override - public void onError(Throwable e) { - } - - @Override - public boolean isComplete() { - return false; - } - - @Override - public void cleanup() { - } - - }; - } - }).foreach(new Procedure() { - public void apply(String elem) { - probe.getRef().tell(elem, ActorRef.noSender()); - } - }, materializer); - - probe.expectMsgEquals("0"); - probe.expectMsgEquals("2"); - probe.expectMsgEquals("4"); - probe.expectMsgEquals("6"); - probe.expectMsgEquals("4 not allowed"); } @Test @@ -291,14 +227,19 @@ public class FlowTest extends StreamTest { } - public Creator> op() { - return new akka.stream.javadsl.japi.Creator>() { + public Creator> op() { + return new akka.stream.javadsl.japi.Creator>() { @Override - public Transformer create() throws Exception { - return new Transformer() { + public PushPullStage create() throws Exception { + return new PushPullStage() { @Override - public Seq onNext(In element) { - return Util.immutableSeq(Collections.singletonList((Out) element)); // TODO needs helpers + public Directive onPush(T element, Context ctx) { + return ctx.push(element); + } + + @Override + public Directive onPull(Context ctx) { + return ctx.pull(); } }; } @@ -307,9 +248,9 @@ public class FlowTest extends StreamTest { @Test public void mustBeAbleToUseMerge() throws Exception { - final Flow f1 = Flow.of(String.class).transform("f1", this.op()); // javadsl - final Flow f2 = Flow.of(String.class).transform("f2", this.op()); // javadsl - final Flow f3 = Flow.of(String.class).transform("f2", this.op()); // javadsl + final Flow f1 = Flow.of(String.class).transform("f1", this.op()); // javadsl + final Flow f2 = Flow.of(String.class).transform("f2", this.op()); // javadsl + final Flow f3 = Flow.of(String.class).transform("f2", this.op()); // javadsl final Source in1 = Source.from(Arrays.asList("a", "b", "c")); final Source in2 = Source.from(Arrays.asList("d", "e", "f")); diff --git a/akka-stream-tests/src/test/scala/akka/stream/DslConsistencySpec.scala b/akka-stream-tests/src/test/scala/akka/stream/DslConsistencySpec.scala index eb1fe13830..253c1fa142 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/DslConsistencySpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/DslConsistencySpec.scala @@ -89,4 +89,4 @@ class DslConsistencySpec extends WordSpec with Matchers { } } -} \ No newline at end of file +} diff --git a/akka-stream-tests/src/test/scala/akka/stream/actor/ActorPublisherSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/actor/ActorPublisherSpec.scala index ebbcfd80be..afce4b7654 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/actor/ActorPublisherSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/actor/ActorPublisherSpec.scala @@ -329,7 +329,6 @@ class ActorPublisherSpec extends AkkaSpec with ImplicitSender { val timeout = 150.millis val a = system.actorOf(timeoutingProps(testActor, timeout)) val pub = ActorPublisher(a) - watch(a) // don't subscribe for `timeout` millis, so it will shut itself down expectMsg("timed-out") @@ -341,6 +340,7 @@ class ActorPublisherSpec extends AkkaSpec with ImplicitSender { expectMsg("cleaned-up") // termination is tiggered by user code + watch(a) expectTerminated(a) } diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpecKit.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpecKit.scala index 40fcb8c72f..575f1c715c 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpecKit.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSpecKit.scala @@ -4,6 +4,7 @@ package akka.stream.impl.fusing import akka.stream.testkit.AkkaSpec +import akka.stream.stage._ trait InterpreterSpecKit extends AkkaSpec { @@ -13,25 +14,25 @@ trait InterpreterSpecKit extends AkkaSpec { case class OnNext(elem: Any) case object RequestOne - private[akka] case class Doubler[T]() extends DeterministicOp[T, T] { + private[akka] case class Doubler[T]() extends PushPullStage[T, T] { var oneMore: Boolean = false var lastElem: T = _ - override def onPush(elem: T, ctxt: Context[T]): Directive = { + override def onPush(elem: T, ctx: Context[T]): Directive = { lastElem = elem oneMore = true - ctxt.push(elem) + ctx.push(elem) } - override def onPull(ctxt: Context[T]): Directive = { + override def onPull(ctx: Context[T]): Directive = { if (oneMore) { oneMore = false - ctxt.push(lastElem) - } else ctxt.pull() + ctx.push(lastElem) + } else ctx.pull() } } - abstract class TestSetup(ops: Seq[Op[_, _, _, _, _]], forkLimit: Int = 100, overflowToHeap: Boolean = false) { + abstract class TestSetup(ops: Seq[Stage[_, _]], forkLimit: Int = 100, overflowToHeap: Boolean = false) { private var lastEvent: Set[Any] = Set.empty val upstream = new UpstreamProbe @@ -45,19 +46,19 @@ trait InterpreterSpecKit extends AkkaSpec { result } - class UpstreamProbe extends BoundaryOp { + class UpstreamProbe extends BoundaryStage { - override def onDownstreamFinish(ctxt: BoundaryContext): TerminationDirective = { + override def onDownstreamFinish(ctx: BoundaryContext): TerminationDirective = { lastEvent += Cancel - ctxt.exit() + ctx.exit() } - override def onPull(ctxt: BoundaryContext): Directive = { + override def onPull(ctx: BoundaryContext): Directive = { lastEvent += RequestOne - ctxt.exit() + ctx.exit() } - override def onPush(elem: Any, ctxt: BoundaryContext): Directive = + override def onPush(elem: Any, ctx: BoundaryContext): Directive = throw new UnsupportedOperationException("Cannot push the boundary") def onNext(elem: Any): Unit = enter().push(elem) @@ -66,23 +67,23 @@ trait InterpreterSpecKit extends AkkaSpec { } - class DownstreamProbe extends BoundaryOp { - override def onPush(elem: Any, ctxt: BoundaryContext): Directive = { + class DownstreamProbe extends BoundaryStage { + override def onPush(elem: Any, ctx: BoundaryContext): Directive = { lastEvent += OnNext(elem) - ctxt.exit() + ctx.exit() } - override def onUpstreamFinish(ctxt: BoundaryContext): TerminationDirective = { + override def onUpstreamFinish(ctx: BoundaryContext): TerminationDirective = { lastEvent += OnComplete - ctxt.exit() + ctx.exit() } - override def onFailure(cause: Throwable, ctxt: BoundaryContext): TerminationDirective = { + override def onUpstreamFailure(cause: Throwable, ctx: BoundaryContext): TerminationDirective = { lastEvent += OnError(cause) - ctxt.exit() + ctx.exit() } - override def onPull(ctxt: BoundaryContext): Directive = + override def onPull(ctx: BoundaryContext): Directive = throw new UnsupportedOperationException("Cannot pull the boundary") def requestOne(): Unit = enter().pull() diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/IteratorInterpreterSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/IteratorInterpreterSpec.scala index f69b71cd29..78a12884e5 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/IteratorInterpreterSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/IteratorInterpreterSpec.scala @@ -1,9 +1,12 @@ +/** + * Copyright (C) 2014 Typesafe Inc. + */ package akka.stream.impl.fusing +import scala.collection.immutable import akka.stream.testkit.AkkaSpec import akka.util.ByteString - -import scala.collection.immutable +import akka.stream.stage._ class IteratorInterpreterSpec extends AkkaSpec { @@ -45,28 +48,32 @@ class IteratorInterpreterSpec extends AkkaSpec { "throw exceptions when chain fails" in { val itr = new IteratorInterpreter[Int, Int](List(1, 2, 3).iterator, Seq( - new TransitivePullOp[Int, Int] { - override def onPush(elem: Int, ctxt: Context[Int]): Directive = { - if (elem == 2) ctxt.fail(new ArithmeticException()) - else ctxt.push(elem) + new PushStage[Int, Int] { + override def onPush(elem: Int, ctx: Context[Int]): Directive = { + if (elem == 2) ctx.fail(new ArithmeticException()) + else ctx.push(elem) } })).iterator itr.next() should be(1) + itr.hasNext should be(true) a[ArithmeticException] should be thrownBy { itr.next() } + itr.hasNext should be(false) } "throw exceptions when op in chain throws" in { val itr = new IteratorInterpreter[Int, Int](List(1, 2, 3).iterator, Seq( - new TransitivePullOp[Int, Int] { - override def onPush(elem: Int, ctxt: Context[Int]): Directive = { + new PushStage[Int, Int] { + override def onPush(elem: Int, ctx: Context[Int]): Directive = { if (elem == 2) throw new ArithmeticException() - else ctxt.push(elem) + else ctx.push(elem) } })).iterator itr.next() should be(1) + itr.hasNext should be(true) a[ArithmeticException] should be thrownBy { itr.next() } + itr.hasNext should be(false) } "work with an empty iterator" in { @@ -108,47 +115,47 @@ class IteratorInterpreterSpec extends AkkaSpec { } // This op needs an extra pull round to finish - case class NaiveTake[T](count: Int) extends DeterministicOp[T, T] { + case class NaiveTake[T](count: Int) extends PushPullStage[T, T] { private var left: Int = count - override def onPush(elem: T, ctxt: Context[T]): Directive = { + override def onPush(elem: T, ctx: Context[T]): Directive = { left -= 1 - ctxt.push(elem) + ctx.push(elem) } - override def onPull(ctxt: Context[T]): Directive = { - if (left == 0) ctxt.finish() - else ctxt.pull() + override def onPull(ctx: Context[T]): Directive = { + if (left == 0) ctx.finish() + else ctx.pull() } } - case class ByteStringBatcher(threshold: Int, compact: Boolean = true) extends DeterministicOp[ByteString, ByteString] { + case class ByteStringBatcher(threshold: Int, compact: Boolean = true) extends PushPullStage[ByteString, ByteString] { require(threshold > 0, "Threshold must be positive") private var buf = ByteString.empty private var passthrough = false - override def onPush(elem: ByteString, ctxt: Context[ByteString]): Directive = { - if (passthrough) ctxt.push(elem) + override def onPush(elem: ByteString, ctx: Context[ByteString]): Directive = { + if (passthrough) ctx.push(elem) else { buf = buf ++ elem if (buf.size >= threshold) { val batch = if (compact) buf.compact else buf passthrough = true buf = ByteString.empty - ctxt.push(batch) - } else ctxt.pull() + ctx.push(batch) + } else ctx.pull() } } - override def onPull(ctxt: Context[ByteString]): Directive = { - if (isFinishing) ctxt.pushAndFinish(buf) - else ctxt.pull() + override def onPull(ctx: Context[ByteString]): Directive = { + if (ctx.isFinishing) ctx.pushAndFinish(buf) + else ctx.pull() } - override def onUpstreamFinish(ctxt: Context[ByteString]): TerminationDirective = { - if (passthrough || buf.isEmpty) ctxt.finish() - else ctxt.absorbTermination() + override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = { + if (passthrough || buf.isEmpty) ctx.finish() + else ctx.absorbTermination() } } diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowGraphCompileSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowGraphCompileSpec.scala index c5d7be6650..10204d9449 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowGraphCompileSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowGraphCompileSpec.scala @@ -3,10 +3,11 @@ */ package akka.stream.scaladsl -import akka.stream.{ OverflowStrategy, Transformer } +import akka.stream.OverflowStrategy import akka.stream.FlowMaterializer import akka.stream.testkit.AkkaSpec import akka.stream.testkit.StreamTestKit.{ PublisherProbe, SubscriberProbe } +import akka.stream.stage._ object FlowGraphCompileSpec { class Fruit @@ -18,9 +19,10 @@ class FlowGraphCompileSpec extends AkkaSpec { implicit val mat = FlowMaterializer() - def op[In, Out]: () ⇒ Transformer[In, Out] = { () ⇒ - new Transformer[In, Out] { - override def onNext(elem: In) = List(elem.asInstanceOf[Out]) + def op[In, Out]: () ⇒ PushStage[In, Out] = { () ⇒ + new PushStage[In, Out] { + override def onPush(elem: In, ctx: Context[Out]): Directive = + ctx.push(elem.asInstanceOf[Out]) } } diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowSpec.scala index c751c47f94..dfcb133af0 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowSpec.scala @@ -4,13 +4,10 @@ package akka.stream.scaladsl import java.util.concurrent.atomic.AtomicLong - import akka.dispatch.Dispatchers -import akka.stream.impl.fusing.{ Op, ActorInterpreter } - +import akka.stream.stage.Stage import scala.collection.immutable import scala.concurrent.duration._ - import akka.actor._ import akka.stream.{ TransformerLike, MaterializerSettings } import akka.stream.FlowMaterializer @@ -22,6 +19,8 @@ import akka.testkit._ import akka.testkit.TestEvent.{ UnMute, Mute } import com.typesafe.config.ConfigFactory import org.reactivestreams.{ Processor, Subscriber, Publisher } +import akka.stream.impl.fusing.ActorInterpreter +import scala.util.control.NoStackTrace object FlowSpec { class Fruit @@ -32,25 +31,9 @@ object FlowSpec { case class BrokenMessage(msg: String) - class BrokenTransformProcessorImpl( - _settings: MaterializerSettings, - transformer: TransformerLike[Any, Any], - brokenMessage: Any) extends TransformProcessorImpl(_settings, transformer) { - - import akka.stream.actor.ActorSubscriberMessage._ - - override protected[akka] def aroundReceive(receive: Receive, msg: Any) = { - msg match { - case OnNext(m) if m == brokenMessage ⇒ - throw new NullPointerException(s"I'm so broken [$m]") - case _ ⇒ super.aroundReceive(receive, msg) - } - } - } - class BrokenActorInterpreter( _settings: MaterializerSettings, - _ops: Seq[Op[_, _, _, _, _]], + _ops: Seq[Stage[_, _]], brokenMessage: Any) extends ActorInterpreter(_settings, _ops) { @@ -76,7 +59,6 @@ object FlowSpec { override def processorForNode[In, Out](op: AstNode, flowName: String, n: Int): Processor[In, Out] = { val props = op match { - case t: Transform ⇒ Props(new BrokenTransformProcessorImpl(settings, t.mkTransformer(), brokenMessage)) case f: Fused ⇒ Props(new BrokenActorInterpreter(settings, f.ops, brokenMessage)).withDispatcher(settings.dispatcher) case Map(f) ⇒ Props(new BrokenActorInterpreter(settings, List(fusing.Map(f)), brokenMessage)) case Filter(p) ⇒ Props(new BrokenActorInterpreter(settings, List(fusing.Filter(p)), brokenMessage)) @@ -639,6 +621,6 @@ class FlowSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.debug.rece } } - object TestException extends RuntimeException + object TestException extends RuntimeException with NoStackTrace } diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowTransformSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowStageSpec.scala similarity index 56% rename from akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowTransformSpec.scala rename to akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowStageSpec.scala index f1e7f29b42..eba691bcf2 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowTransformSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowStageSpec.scala @@ -6,15 +6,14 @@ package akka.stream.scaladsl import scala.collection.immutable.Seq import scala.concurrent.duration._ import scala.util.control.NoStackTrace - import akka.stream.FlowMaterializer import akka.stream.MaterializerSettings -import akka.stream.Transformer import akka.stream.testkit.{ AkkaSpec, StreamTestKit } import akka.testkit.{ EventFilter, TestProbe } import com.typesafe.config.ConfigFactory +import akka.stream.stage._ -class FlowTransformSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.debug.receive=off\nakka.loglevel=INFO")) { +class FlowStageSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.debug.receive=off\nakka.loglevel=INFO")) { val settings = MaterializerSettings(system) .withInputBuffer(initialSize = 2, maxSize = 2) @@ -26,11 +25,11 @@ class FlowTransformSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.d "produce one-to-one transformation as expected" in { val p = Source(List(1, 2, 3)).runWith(Sink.publisher) val p2 = Source(p). - transform("transform", () ⇒ new Transformer[Int, Int] { + transform("transform", () ⇒ new PushStage[Int, Int] { var tot = 0 - override def onNext(elem: Int) = { + override def onPush(elem: Int, ctx: Context[Int]) = { tot += elem - List(tot) + ctx.push(tot) } }). runWith(Sink.publisher) @@ -49,12 +48,23 @@ class FlowTransformSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.d "produce one-to-several transformation as expected" in { val p = Source(List(1, 2, 3)).runWith(Sink.publisher) val p2 = Source(p). - transform("transform", () ⇒ new Transformer[Int, Int] { + transform("transform", () ⇒ new StatefulStage[Int, Int] { var tot = 0 - override def onNext(elem: Int) = { - tot += elem - Vector.fill(elem)(tot) + + lazy val waitForNext = new State { + override def onPush(elem: Int, ctx: Context[Int]) = { + tot += elem + emit(Iterator.fill(elem)(tot), ctx) + } } + + override def initial = waitForNext + + override def onUpstreamFinish(ctx: Context[Int]): TerminationDirective = { + if (current eq waitForNext) ctx.finish() + else ctx.absorbTermination() + } + }). runWith(Sink.publisher) val subscriber = StreamTestKit.SubscriberProbe[Int]() @@ -75,15 +85,14 @@ class FlowTransformSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.d "produce dropping transformation as expected" in { val p = Source(List(1, 2, 3, 4)).runWith(Sink.publisher) val p2 = Source(p). - transform("transform", () ⇒ new Transformer[Int, Int] { + transform("transform", () ⇒ new PushStage[Int, Int] { var tot = 0 - override def onNext(elem: Int) = { + override def onPush(elem: Int, ctx: Context[Int]) = { tot += elem - if (elem % 2 == 0) { - Nil - } else { - List(tot) - } + if (elem % 2 == 0) + ctx.pull() + else + ctx.push(tot) } }). runWith(Sink.publisher) @@ -102,18 +111,18 @@ class FlowTransformSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.d "produce multi-step transformation as expected" in { val p = Source(List("a", "bc", "def")).runWith(Sink.publisher) val p2 = Source(p). - transform("transform", () ⇒ new Transformer[String, Int] { + transform("transform", () ⇒ new PushStage[String, Int] { var concat = "" - override def onNext(elem: String) = { + override def onPush(elem: String, ctx: Context[Int]) = { concat += elem - List(concat.length) + ctx.push(concat.length) } }). - transform("transform", () ⇒ new Transformer[Int, Int] { + transform("transform", () ⇒ new PushStage[Int, Int] { var tot = 0 - override def onNext(length: Int) = { + override def onPush(length: Int, ctx: Context[Int]) = { tot += length - List(tot) + ctx.push(tot) } }). runWith(Sink.fanoutPublisher(2, 2)) @@ -138,104 +147,41 @@ class FlowTransformSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.d c2.expectComplete() } - "invoke onComplete when done" in { + "support emit onUpstreamFinish" in { val p = Source(List("a")).runWith(Sink.publisher) val p2 = Source(p). - transform("transform", () ⇒ new Transformer[String, String] { + transform("transform", () ⇒ new StatefulStage[String, String] { var s = "" - override def onNext(element: String) = { - s += element - Nil - } - override def onTermination(e: Option[Throwable]) = List(s + "B") - }). - runWith(Sink.publisher) - val c = StreamTestKit.SubscriberProbe[String]() - p2.subscribe(c) - val s = c.expectSubscription() - s.request(1) - c.expectNext("aB") - c.expectComplete() - } - - "invoke cleanup when done" in { - val cleanupProbe = TestProbe() - val p = Source(List("a")).runWith(Sink.publisher) - val p2 = Source(p). - transform("transform", () ⇒ new Transformer[String, String] { - var s = "" - override def onNext(element: String) = { - s += element - Nil - } - override def onTermination(e: Option[Throwable]) = List(s + "B") - override def cleanup() = cleanupProbe.ref ! s - }). - runWith(Sink.publisher) - val c = StreamTestKit.SubscriberProbe[String]() - p2.subscribe(c) - val s = c.expectSubscription() - s.request(1) - c.expectNext("aB") - c.expectComplete() - cleanupProbe.expectMsg("a") - } - - "invoke cleanup when done consume" in { - val cleanupProbe = TestProbe() - val p = Source(List("a")).runWith(Sink.publisher) - Source(p). - transform("transform", () ⇒ new Transformer[String, String] { - var s = "x" - override def onNext(element: String) = { - s = element - List(element) - } - override def cleanup() = cleanupProbe.ref ! s - }). - to(Sink.ignore).run() - cleanupProbe.expectMsg("a") - } - - "invoke cleanup when done after error" in { - val cleanupProbe = TestProbe() - val p = Source(List("a", "b", "c")).runWith(Sink.publisher) - val p2 = Source(p). - transform("transform", () ⇒ new Transformer[String, String] { - var s = "" - override def onNext(in: String) = { - if (in == "b") { - throw new IllegalArgumentException("Not b") with NoStackTrace - } else { - val out = s + in - s += in.toUpperCase - List(out) + override def initial = new State { + override def onPush(element: String, ctx: Context[String]) = { + s += element + ctx.pull() } } - override def onTermination(e: Option[Throwable]) = List(s + "B") - override def cleanup() = cleanupProbe.ref ! s + override def onUpstreamFinish(ctx: Context[String]) = + terminationEmit(Iterator.single(s + "B"), ctx) }). runWith(Sink.publisher) val c = StreamTestKit.SubscriberProbe[String]() p2.subscribe(c) val s = c.expectSubscription() s.request(1) - c.expectNext("a") - s.request(1) - c.expectError() - cleanupProbe.expectMsg("A") + c.expectNext("aB") + c.expectComplete() } - "allow cancellation using isComplete" in { + "allow early finish" in { val p = StreamTestKit.PublisherProbe[Int]() val p2 = Source(p). - transform("transform", () ⇒ new Transformer[Int, Int] { + transform("transform", () ⇒ new PushStage[Int, Int] { var s = "" - override def onNext(element: Int) = { + override def onPush(element: Int, ctx: Context[Int]) = { s += element - List(element) + if (s == "1") + ctx.pushAndFinish(element) + else + ctx.push(element) } - override def isComplete = s == "1" }). runWith(Sink.publisher) val proc = p.expectSubscription @@ -250,44 +196,17 @@ class FlowTransformSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.d proc.expectCancellation() } - "call onComplete after isComplete signaled completion" in { - val cleanupProbe = TestProbe() - val p = StreamTestKit.PublisherProbe[Int]() - val p2 = Source(p). - transform("transform", () ⇒ new Transformer[Int, Int] { - var s = "" - override def onNext(element: Int) = { - s += element - List(element) - } - override def isComplete = s == "1" - override def onTermination(e: Option[Throwable]) = List(s.length + 10) - override def cleanup() = cleanupProbe.ref ! s - }). - runWith(Sink.publisher) - val proc = p.expectSubscription - val c = StreamTestKit.SubscriberProbe[Int]() - p2.subscribe(c) - val s = c.expectSubscription() - s.request(10) - proc.sendNext(1) - proc.sendNext(2) - c.expectNext(1) - c.expectNext(11) - c.expectComplete() - proc.expectCancellation() - cleanupProbe.expectMsg("1") - } - "report error when exception is thrown" in { val p = Source(List(1, 2, 3)).runWith(Sink.publisher) val p2 = Source(p). - transform("transform", () ⇒ new Transformer[Int, Int] { - override def onNext(elem: Int) = { - if (elem == 2) { - throw new IllegalArgumentException("two not allowed") - } else { - List(elem, elem) + transform("transform", () ⇒ new StatefulStage[Int, Int] { + override def initial = new State { + override def onPush(elem: Int, ctx: Context[Int]) = { + if (elem == 2) { + throw new IllegalArgumentException("two not allowed") + } else { + emit(Iterator(elem, elem), ctx) + } } } }). @@ -304,11 +223,41 @@ class FlowTransformSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.d } } + "support emit of final elements when onUpstreamFailure" in { + val p = Source(List(1, 2, 3)).runWith(Sink.publisher) + val p2 = Source(p). + map(elem ⇒ if (elem == 2) throw new IllegalArgumentException("two not allowed") else elem). + transform("transform", () ⇒ new StatefulStage[Int, Int] { + override def initial = new State { + override def onPush(elem: Int, ctx: Context[Int]) = ctx.push(elem) + } + + override def onUpstreamFailure(cause: Throwable, ctx: Context[Int]) = { + terminationEmit(Iterator(100, 101), ctx) + } + }). + filter(elem ⇒ elem != 1). // it's undefined if element 1 got through before the error or not + runWith(Sink.publisher) + val subscriber = StreamTestKit.SubscriberProbe[Int]() + p2.subscribe(subscriber) + val subscription = subscriber.expectSubscription() + EventFilter[IllegalArgumentException]("two not allowed") intercept { + subscription.request(100) + subscriber.expectNext(100) + subscriber.expectNext(101) + subscriber.expectComplete() + subscriber.expectNoMsg(200.millis) + } + } + "support cancel as expected" in { val p = Source(List(1, 2, 3)).runWith(Sink.publisher) val p2 = Source(p). - transform("transform", () ⇒ new Transformer[Int, Int] { - override def onNext(elem: Int) = List(elem, elem) + transform("transform", () ⇒ new StatefulStage[Int, Int] { + override def initial = new State { + override def onPush(elem: Int, ctx: Context[Int]) = + emit(Iterator(elem, elem), ctx) + } }). runWith(Sink.publisher) val subscriber = StreamTestKit.SubscriberProbe[Int]() @@ -326,9 +275,12 @@ class FlowTransformSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.d "support producing elements from empty inputs" in { val p = Source(List.empty[Int]).runWith(Sink.publisher) val p2 = Source(p). - transform("transform", () ⇒ new Transformer[Int, Int] { - override def onNext(elem: Int) = Nil - override def onTermination(e: Option[Throwable]) = List(1, 2, 3) + transform("transform", () ⇒ new StatefulStage[Int, Int] { + override def initial = new State { + override def onPush(elem: Int, ctx: Context[Int]) = ctx.pull() + } + override def onUpstreamFinish(ctx: Context[Int]) = + terminationEmit(Iterator(1, 2, 3), ctx) }). runWith(Sink.publisher) val subscriber = StreamTestKit.SubscriberProbe[Int]() @@ -344,26 +296,24 @@ class FlowTransformSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.d "support converting onComplete into onError" in { val subscriber = StreamTestKit.SubscriberProbe[Int]() - Source(List(5, 1, 2, 3)).transform("transform", () ⇒ new Transformer[Int, Int] { + Source(List(5, 1, 2, 3)).transform("transform", () ⇒ new PushStage[Int, Int] { var expectedNumberOfElements: Option[Int] = None var count = 0 - override def onNext(elem: Int) = + override def onPush(elem: Int, ctx: Context[Int]) = if (expectedNumberOfElements.isEmpty) { expectedNumberOfElements = Some(elem) - Nil + ctx.pull() } else { count += 1 - List(elem) + ctx.push(elem) + } + + override def onUpstreamFinish(ctx: Context[Int]) = + expectedNumberOfElements match { + case Some(expected) if count != expected ⇒ + throw new RuntimeException(s"Expected $expected, got $count") with NoStackTrace + case _ ⇒ ctx.finish() } - override def onTermination(err: Option[Throwable]) = err match { - case Some(e) ⇒ Nil - case None ⇒ - expectedNumberOfElements match { - case Some(expected) if count != expected ⇒ - throw new RuntimeException(s"Expected $expected, got $count") with NoStackTrace - case _ ⇒ Nil - } - } }).to(Sink(subscriber)).run() val subscription = subscriber.expectSubscription() @@ -377,12 +327,12 @@ class FlowTransformSpec extends AkkaSpec(ConfigFactory.parseString("akka.actor.d "be safe to reuse" in { val flow = Source(1 to 3).transform("transform", () ⇒ - new Transformer[Int, Int] { + new PushStage[Int, Int] { var count = 0 - override def onNext(elem: Int): Seq[Int] = { + override def onPush(elem: Int, ctx: Context[Int]) = { count += 1 - List(count) + ctx.push(count) } }) diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowTransformRecoverSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowTransformRecoverSpec.scala deleted file mode 100644 index bd65eb64b6..0000000000 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowTransformRecoverSpec.scala +++ /dev/null @@ -1,390 +0,0 @@ -/** - * Copyright (C) 2014 Typesafe Inc. - */ -package akka.stream.scaladsl - -import scala.collection.immutable -import scala.concurrent.duration._ -import scala.util.Failure -import scala.util.Success -import scala.util.Try -import scala.util.control.NoStackTrace - -import akka.stream.FlowMaterializer -import akka.stream.MaterializerSettings -import akka.stream.Transformer -import akka.stream.testkit.AkkaSpec -import akka.stream.testkit.StreamTestKit -import akka.testkit.EventFilter - -object FlowTransformRecoverSpec { - abstract class TryRecoveryTransformer[T, U] extends Transformer[T, U] { - def onNext(element: Try[T]): immutable.Seq[U] - - override def onNext(element: T): immutable.Seq[U] = onNext(Success(element)) - override def onError(cause: Throwable) = () - override def onTermination(cause: Option[Throwable]): immutable.Seq[U] = cause match { - case None ⇒ Nil - case Some(e) ⇒ onNext(Failure(e)) - } - } -} - -class FlowTransformRecoverSpec extends AkkaSpec { - import FlowTransformRecoverSpec._ - - val settings = MaterializerSettings(system) - .withInputBuffer(initialSize = 2, maxSize = 2) - .withFanOutBuffer(initialSize = 2, maxSize = 2) - - implicit val materializer = FlowMaterializer(settings) - - "A Flow with transformRecover operations" must { - "produce one-to-one transformation as expected" in { - val p = Source(1 to 3).runWith(Sink.publisher) - val p2 = Source(p). - transform("transform", () ⇒ new Transformer[Int, Int] { - var tot = 0 - override def onNext(elem: Int) = { - tot += elem - List(tot) - } - override def onError(e: Throwable) = () - override def onTermination(e: Option[Throwable]) = e match { - case None ⇒ Nil - case Some(_) ⇒ List(-1) - } - }). - runWith(Sink.publisher) - val subscriber = StreamTestKit.SubscriberProbe[Int]() - p2.subscribe(subscriber) - val subscription = subscriber.expectSubscription() - subscription.request(1) - subscriber.expectNext(1) - subscriber.expectNoMsg(200.millis) - subscription.request(2) - subscriber.expectNext(3) - subscriber.expectNext(6) - subscriber.expectComplete() - } - - "produce one-to-several transformation as expected" in { - val p = Source(1 to 3).runWith(Sink.publisher) - val p2 = Source(p). - transform("transform", () ⇒ new Transformer[Int, Int] { - var tot = 0 - override def onNext(elem: Int) = { - tot += elem - Vector.fill(elem)(tot) - } - override def onError(e: Throwable) = () - override def onTermination(e: Option[Throwable]) = e match { - case None ⇒ Nil - case Some(_) ⇒ List(-1) - } - }). - runWith(Sink.publisher) - val subscriber = StreamTestKit.SubscriberProbe[Int]() - p2.subscribe(subscriber) - val subscription = subscriber.expectSubscription() - subscription.request(4) - subscriber.expectNext(1) - subscriber.expectNext(3) - subscriber.expectNext(3) - subscriber.expectNext(6) - subscriber.expectNoMsg(200.millis) - subscription.request(100) - subscriber.expectNext(6) - subscriber.expectNext(6) - subscriber.expectComplete() - } - - "produce dropping transformation as expected" in { - val p = Source(1 to 4).runWith(Sink.publisher) - val p2 = Source(p). - transform("transform", () ⇒ new Transformer[Int, Int] { - var tot = 0 - override def onNext(elem: Int) = { - tot += elem - if (elem % 2 == 0) Nil else List(tot) - } - override def onError(e: Throwable) = () - override def onTermination(e: Option[Throwable]) = e match { - case None ⇒ Nil - case Some(_) ⇒ List(-1) - } - }). - runWith(Sink.publisher) - val subscriber = StreamTestKit.SubscriberProbe[Int]() - p2.subscribe(subscriber) - val subscription = subscriber.expectSubscription() - subscription.request(1) - subscriber.expectNext(1) - subscriber.expectNoMsg(200.millis) - subscription.request(1) - subscriber.expectNext(6) - subscription.request(1) - subscriber.expectComplete() - } - - "produce multi-step transformation as expected" in { - val p = Source(List("a", "bc", "def")).runWith(Sink.publisher) - val p2 = Source(p). - transform("transform", () ⇒ new TryRecoveryTransformer[String, Int] { - var concat = "" - override def onNext(element: Try[String]) = { - concat += element - List(concat.length) - } - }). - transform("transform", () ⇒ new Transformer[Int, Int] { - var tot = 0 - override def onNext(length: Int) = { - tot += length - List(tot) - } - override def onError(e: Throwable) = () - override def onTermination(e: Option[Throwable]) = e match { - case None ⇒ Nil - case Some(_) ⇒ List(-1) - } - }).runWith(Sink.fanoutPublisher(1, 1)) - val c1 = StreamTestKit.SubscriberProbe[Int]() - p2.subscribe(c1) - val sub1 = c1.expectSubscription() - val c2 = StreamTestKit.SubscriberProbe[Int]() - p2.subscribe(c2) - val sub2 = c2.expectSubscription() - sub1.request(1) - sub2.request(2) - c1.expectNext(10) - c2.expectNext(10) - c2.expectNext(31) - c1.expectNoMsg(200.millis) - sub1.request(2) - sub2.request(2) - c1.expectNext(31) - c1.expectNext(64) - c2.expectNext(64) - c1.expectComplete() - c2.expectComplete() - } - - "invoke onComplete when done" in { - val p = Source(List("a")).runWith(Sink.publisher) - val p2 = Source(p). - transform("transform", () ⇒ new TryRecoveryTransformer[String, String] { - var s = "" - override def onNext(element: Try[String]) = { - s += element - Nil - } - override def onTermination(e: Option[Throwable]) = List(s + "B") - }). - runWith(Sink.publisher) - val c = StreamTestKit.SubscriberProbe[String]() - p2.subscribe(c) - val s = c.expectSubscription() - s.request(1) - c.expectNext("Success(a)B") - c.expectComplete() - } - - "allow cancellation using isComplete" in { - val p = StreamTestKit.PublisherProbe[Int]() - val p2 = Source(p). - transform("transform", () ⇒ new TryRecoveryTransformer[Int, Int] { - var s = "" - override def onNext(element: Try[Int]) = { - s += element - List(element.get) - } - override def isComplete = s == "Success(1)" - }). - runWith(Sink.publisher) - val proc = p.expectSubscription - val c = StreamTestKit.SubscriberProbe[Int]() - p2.subscribe(c) - val s = c.expectSubscription() - s.request(10) - proc.sendNext(1) - proc.sendNext(2) - c.expectNext(1) - c.expectComplete() - proc.expectCancellation() - } - - "call onComplete after isComplete signaled completion" in { - val p = StreamTestKit.PublisherProbe[Int]() - val p2 = Source(p). - transform("transform", () ⇒ new TryRecoveryTransformer[Int, Int] { - var s = "" - override def onNext(element: Try[Int]) = { - s += element - List(element.get) - } - override def isComplete = s == "Success(1)" - override def onTermination(e: Option[Throwable]) = List(s.length + 10) - }). - runWith(Sink.publisher) - val proc = p.expectSubscription - val c = StreamTestKit.SubscriberProbe[Int]() - p2.subscribe(c) - val s = c.expectSubscription() - s.request(10) - proc.sendNext(1) - proc.sendNext(2) - c.expectNext(1) - c.expectNext(20) - c.expectComplete() - proc.expectCancellation() - } - - "report error when exception is thrown" in { - val p = Source(1 to 3).runWith(Sink.publisher) - val p2 = Source(p). - transform("transform", () ⇒ new Transformer[Int, Int] { - override def onNext(elem: Int) = { - if (elem == 2) throw new IllegalArgumentException("two not allowed") - else List(elem, elem) - } - override def onError(e: Throwable) = List(-1) - }). - runWith(Sink.publisher) - val subscriber = StreamTestKit.SubscriberProbe[Int]() - p2.subscribe(subscriber) - val subscription = subscriber.expectSubscription() - EventFilter[IllegalArgumentException]("two not allowed") intercept { - subscription.request(1) - subscriber.expectNext(1) - subscriber.expectNoMsg(200.millis) - subscription.request(100) - subscriber.expectNext(1) - subscriber.expectError().getMessage should be("two not allowed") - subscriber.expectNoMsg(200.millis) - } - } - - "report error after emitted elements" in { - EventFilter[IllegalArgumentException]("two not allowed") intercept { - val p2 = Source(1 to 3). - mapConcat { elem ⇒ - if (elem == 2) throw new IllegalArgumentException("two not allowed") - else (1 to 5).map(elem * 100 + _) - }. - transform("transform", () ⇒ new Transformer[Int, Int] { - override def onNext(elem: Int) = List(elem) - override def onError(e: Throwable) = () - override def onTermination(e: Option[Throwable]) = e match { - case None ⇒ Nil - case Some(_) ⇒ List(-1, -2, -3) - } - }). - runWith(Sink.publisher) - val subscriber = StreamTestKit.SubscriberProbe[Int]() - p2.subscribe(subscriber) - val subscription = subscriber.expectSubscription() - - subscription.request(1) - subscriber.expectNext(101) - subscriber.expectNoMsg(100.millis) - subscription.request(1) - subscriber.expectNext(102) - subscriber.expectNoMsg(100.millis) - subscription.request(1) - subscriber.expectNext(103) - subscriber.expectNoMsg(100.millis) - subscription.request(1) - subscriber.expectNext(104) - subscriber.expectNoMsg(100.millis) - subscription.request(1) - subscriber.expectNext(105) - subscriber.expectNoMsg(100.millis) - - subscription.request(1) - subscriber.expectNext(-1) - subscriber.expectNoMsg(100.millis) - subscription.request(10) - subscriber.expectNext(-2) - subscriber.expectNext(-3) - subscriber.expectComplete() - subscriber.expectNoMsg(200.millis) - } - } - - case class TE(message: String) extends RuntimeException(message) with NoStackTrace - - "transform errors in sequence with normal messages" in { - val p = StreamTestKit.PublisherProbe[Int]() - val p2 = Source(p). - transform("transform", () ⇒ new Transformer[Int, String] { - var s = "" - override def onNext(element: Int) = { - s += element.toString - List(s) - } - override def onError(ex: Throwable) = () - override def onTermination(ex: Option[Throwable]) = { - ex match { - case None ⇒ Nil - case Some(e) ⇒ - s += e.getMessage - List(s) - } - } - }). - runWith(Sink.publisher) - val proc = p.expectSubscription() - val c = StreamTestKit.SubscriberProbe[String]() - p2.subscribe(c) - val s = c.expectSubscription() - proc.sendNext(0) - proc.sendError(TE("1")) - // Request late to prove the in-sequence nature - s.request(10) - c.expectNext("0") - c.expectNext("01") - c.expectComplete() - } - - "forward errors when received and thrown" in { - val p = StreamTestKit.PublisherProbe[Int]() - val p2 = Source(p). - transform("transform", () ⇒ new Transformer[Int, Int] { - override def onNext(in: Int) = List(in) - override def onError(e: Throwable) = throw e - }). - runWith(Sink.publisher) - val proc = p.expectSubscription() - val c = StreamTestKit.SubscriberProbe[Int]() - p2.subscribe(c) - val s = c.expectSubscription() - s.request(10) - EventFilter[TE](occurrences = 1) intercept { - proc.sendError(TE("1")) - c.expectError(TE("1")) - } - } - - "support cancel as expected" in { - val p = Source(1 to 3).runWith(Sink.publisher) - val p2 = Source(p). - transform("transform", () ⇒ new Transformer[Int, Int] { - override def onNext(elem: Int) = List(elem, elem) - override def onError(e: Throwable) = List(-1) - }). - runWith(Sink.publisher) - val subscriber = StreamTestKit.SubscriberProbe[Int]() - p2.subscribe(subscriber) - val subscription = subscriber.expectSubscription() - subscription.request(2) - subscriber.expectNext(1) - subscription.cancel() - subscriber.expectNext(1) - subscriber.expectNoMsg(500.millis) - subscription.request(2) - subscriber.expectNoMsg(200.millis) - } - } - -} \ No newline at end of file diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/OptimizingActorBasedFlowMaterializerSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/OptimizingActorBasedFlowMaterializerSpec.scala index 901a815178..c17e986f57 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/OptimizingActorBasedFlowMaterializerSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/OptimizingActorBasedFlowMaterializerSpec.scala @@ -20,7 +20,7 @@ class OptimizingActorBasedFlowMaterializerSpec extends AkkaSpec with ImplicitSen val f = Source(1 to 100). drop(4). drop(5). - transform("identity", () ⇒ FlowOps.identityTransformer). + transform("identity", () ⇒ FlowOps.identityStage). filter(_ % 2 == 0). map(_ * 2). map(identity). diff --git a/akka-stream/src/main/scala/akka/stream/TimerTransformer.scala b/akka-stream/src/main/scala/akka/stream/TimerTransformer.scala index 476b5a9978..b36ae5be28 100644 --- a/akka-stream/src/main/scala/akka/stream/TimerTransformer.scala +++ b/akka-stream/src/main/scala/akka/stream/TimerTransformer.scala @@ -9,7 +9,7 @@ import scala.collection.{ immutable, mutable } import scala.concurrent.duration.FiniteDuration /** - * [[Transformer]] with support for scheduling keyed (named) timer events. + * Transformer with support for scheduling keyed (named) timer events. */ abstract class TimerTransformer[-T, +U] extends TransformerLike[T, U] { import TimerTransformer._ diff --git a/akka-stream/src/main/scala/akka/stream/Transformer.scala b/akka-stream/src/main/scala/akka/stream/Transformer.scala index 5c2b898e26..be2b5e15b0 100644 --- a/akka-stream/src/main/scala/akka/stream/Transformer.scala +++ b/akka-stream/src/main/scala/akka/stream/Transformer.scala @@ -24,8 +24,8 @@ abstract class TransformerLike[-T, +U] { * to produce a (possibly empty) sequence of elements in response to the * end-of-stream event. * - * This method is only called if [[Transformer#onError]] does not throw an exception. The default implementation - * of [[Transformer#onError]] throws the received cause forcing the error to propagate downstream immediately. + * This method is only called if [[#onError]] does not throw an exception. The default implementation + * of [[#onError]] throws the received cause forcing the error to propagate downstream immediately. * * @param e Contains a non-empty option with the error causing the termination or an empty option * if the Transformer was completed normally @@ -34,7 +34,7 @@ abstract class TransformerLike[-T, +U] { /** * Invoked when failure is signaled from upstream. If this method throws an exception, then onError is immediately - * propagated downstream. If this method completes normally then [[Transformer#onTermination]] is invoked as a final + * propagated downstream. If this method completes normally then [[#onTermination]] is invoked as a final * step, passing the original cause. */ def onError(cause: Throwable): Unit = throw cause @@ -46,15 +46,3 @@ abstract class TransformerLike[-T, +U] { } -/** - * General interface for stream transformation. - * - * It is possible to keep state in the concrete [[Transformer]] instance with - * ordinary instance variables. The [[Transformer]] is executed by an actor and - * therefore you don not have to add any additional thread safety or memory - * visibility constructs to access the state from the callback methods. - * - * @see [[akka.stream.scaladsl.Flow#transform]] - * @see [[akka.stream.javadsl.Flow#transform]] - */ -abstract class Transformer[-T, +U] extends TransformerLike[T, U] diff --git a/akka-stream/src/main/scala/akka/stream/extra/Timed.scala b/akka-stream/src/main/scala/akka/stream/extra/Timed.scala index 39f9de0215..89daa87334 100644 --- a/akka-stream/src/main/scala/akka/stream/extra/Timed.scala +++ b/akka-stream/src/main/scala/akka/stream/extra/Timed.scala @@ -8,9 +8,9 @@ import java.util.concurrent.atomic.AtomicLong import scala.concurrent.duration._ import scala.language.implicitConversions import scala.language.existentials -import akka.stream.Transformer import akka.stream.scaladsl.Source import akka.stream.scaladsl.Flow +import akka.stream.stage._ /** * Provides operations needed to implement the `timed` DSL @@ -99,41 +99,49 @@ object Timed extends TimedOps with TimedIntervalBetweenOps { } } - final class StartTimedFlow[T](ctx: TimedFlowContext) extends Transformer[T, T] { + final class StartTimedFlow[T](timedContext: TimedFlowContext) extends PushStage[T, T] { private var started = false - override def onNext(element: T) = { + override def onPush(elem: T, ctx: Context[T]): Directive = { if (!started) { - ctx.start() + timedContext.start() started = true } - - immutable.Seq(element) + ctx.push(elem) } } - final class StopTimed[T](ctx: TimedFlowContext, _onComplete: FiniteDuration ⇒ Unit) extends Transformer[T, T] { + final class StopTimed[T](timedContext: TimedFlowContext, _onComplete: FiniteDuration ⇒ Unit) extends PushStage[T, T] { - override def cleanup() { - val d = ctx.stop() + override def onPush(elem: T, ctx: Context[T]): Directive = ctx.push(elem) + + override def onUpstreamFailure(cause: Throwable, ctx: Context[T]): TerminationDirective = { + stopTime() + ctx.fail(cause) + } + override def onUpstreamFinish(ctx: Context[T]): TerminationDirective = { + stopTime() + ctx.finish() + } + private def stopTime() { + val d = timedContext.stop() _onComplete(d) } - override def onNext(element: T) = immutable.Seq(element) } - final class TimedIntervalTransformer[T](matching: T ⇒ Boolean, onInterval: FiniteDuration ⇒ Unit) extends Transformer[T, T] { + final class TimedIntervalTransformer[T](matching: T ⇒ Boolean, onInterval: FiniteDuration ⇒ Unit) extends PushStage[T, T] { private var prevNanos = 0L private var matched = 0L - override def onNext(in: T): immutable.Seq[T] = { - if (matching(in)) { - val d = updateInterval(in) + override def onPush(elem: T, ctx: Context[T]): Directive = { + if (matching(elem)) { + val d = updateInterval(elem) if (matched > 1) onInterval(d) } - immutable.Seq(in) + ctx.push(elem) } private def updateInterval(in: T): FiniteDuration = { diff --git a/akka-stream/src/main/scala/akka/stream/impl/ActorBasedFlowMaterializer.scala b/akka-stream/src/main/scala/akka/stream/impl/ActorBasedFlowMaterializer.scala index 4e45e88d64..2a239c39d7 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/ActorBasedFlowMaterializer.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/ActorBasedFlowMaterializer.scala @@ -7,18 +7,17 @@ import java.util.concurrent.atomic.AtomicLong import akka.dispatch.Dispatchers import akka.event.Logging -import akka.stream.impl.fusing.{ ActorInterpreter, Op } - +import akka.stream.impl.fusing.ActorInterpreter import scala.annotation.tailrec import scala.collection.immutable import scala.concurrent.{ ExecutionContext, Await, Future } - import akka.actor._ -import akka.stream.{ FlowMaterializer, MaterializerSettings, OverflowStrategy, TimerTransformer, Transformer } +import akka.stream.{ FlowMaterializer, MaterializerSettings, OverflowStrategy, TimerTransformer } import akka.stream.MaterializationException import akka.stream.actor.ActorSubscriber import akka.stream.impl.Zip.ZipAs import akka.stream.scaladsl._ +import akka.stream.stage._ import akka.pattern.ask import org.reactivestreams.{ Processor, Publisher, Subscriber } @@ -29,20 +28,16 @@ private[akka] object Ast { sealed abstract class AstNode { def name: String } - // FIXME Replace with Operate - final case class Transform(name: String, mkTransformer: () ⇒ Transformer[Any, Any]) extends AstNode - // FIXME Replace with Operate - final case class TimerTransform(name: String, mkTransformer: () ⇒ TimerTransformer[Any, Any]) extends AstNode - final case class Operate(mkOp: () ⇒ fusing.Op[_, _, _, _, _]) extends AstNode { - override def name = "operate" - } + final case class TimerTransform(mkStage: () ⇒ TimerTransformer[Any, Any], override val name: String) extends AstNode + + final case class StageFactory(mkStage: () ⇒ Stage[_, _], override val name: String) extends AstNode object Fused { - def apply(ops: immutable.Seq[Op[_, _, _, _, _]]): Fused = + def apply(ops: immutable.Seq[Stage[_, _]]): Fused = Fused(ops, ops.map(x ⇒ Logging.simpleName(x).toLowerCase).mkString("+")) //FIXME change to something more performant for name } - final case class Fused(ops: immutable.Seq[Op[_, _, _, _, _]], override val name: String) extends AstNode + final case class Fused(ops: immutable.Seq[Stage[_, _]], override val name: String) extends AstNode final case class Map(f: Any ⇒ Any) extends AstNode { override def name = "map" } @@ -197,7 +192,7 @@ case class ActorBasedFlowMaterializer(override val settings: MaterializerSetting //FIXME Optimize the implementation of the optimizer (no joke) // AstNodes are in reverse order, Fusable Ops are in order private[this] final def optimize(ops: List[Ast.AstNode]): (List[Ast.AstNode], Int) = { - @tailrec def analyze(rest: List[Ast.AstNode], optimized: List[Ast.AstNode], fuseCandidates: List[fusing.Op[_, _, _, _, _]]): (List[Ast.AstNode], Int) = { + @tailrec def analyze(rest: List[Ast.AstNode], optimized: List[Ast.AstNode], fuseCandidates: List[Stage[_, _]]): (List[Ast.AstNode], Int) = { //The `verify` phase def verify(rest: List[Ast.AstNode], orig: List[Ast.AstNode]): List[Ast.AstNode] = @@ -245,10 +240,10 @@ case class ActorBasedFlowMaterializer(override val settings: MaterializerSetting } // Tries to squeeze AstNode into a single fused pipeline - def ast2op(head: Ast.AstNode, prev: List[fusing.Op[_, _, _, _, _]]): List[fusing.Op[_, _, _, _, _]] = + def ast2op(head: Ast.AstNode, prev: List[Stage[_, _]]): List[Stage[_, _]] = head match { // Always-on below - case Ast.Operate(mkOp) ⇒ mkOp() :: prev + case Ast.StageFactory(mkStage, _) ⇒ mkStage() :: prev // Optimizations below case noMatch if !optimizations.fusion ⇒ prev @@ -332,7 +327,7 @@ case class ActorBasedFlowMaterializer(override val settings: MaterializerSetting val (pub, value) = createSource(flowName) (value, attachSink(pub, flowName)) } else { - val id = processorForNode[In, Out](identityTransform, flowName, 1) + val id = processorForNode[In, Out](identityStageNode, flowName, 1) (attachSource(id, flowName), attachSink(id, flowName)) } } else { @@ -344,7 +339,7 @@ case class ActorBasedFlowMaterializer(override val settings: MaterializerSetting new MaterializedPipe(source, sourceValue, sink, sinkValue) } //FIXME Should this be a dedicated AstNode? - private[this] val identityTransform = Ast.Transform("identity", () ⇒ FlowOps.identityTransformer[Any]) + private[this] val identityStageNode = Ast.StageFactory(() ⇒ FlowOps.identityStage[Any], "identity") def executionContext: ExecutionContext = dispatchers.lookup(settings.dispatcher match { case Deploy.NoDispatcherGiven ⇒ Dispatchers.DefaultDispatcherId @@ -406,7 +401,7 @@ case class ActorBasedFlowMaterializer(override val settings: MaterializerSetting (List(subscriber), publishers) case identity @ Ast.IdentityAstNode ⇒ // FIXME Why is IdentityAstNode a JunctionAStNode? - val id = List(processorForNode[In, Out](identityTransform, identity.name, 1)) // FIXME is `identity.name` appropriate/unique here? + val id = List(processorForNode[In, Out](identityStageNode, identity.name, 1)) // FIXME is `identity.name` appropriate/unique here? (id, id) } @@ -460,31 +455,26 @@ private[akka] object ActorProcessorFactory { def props(materializer: FlowMaterializer, op: AstNode): Props = { val settings = materializer.settings // USE THIS TO AVOID CLOSING OVER THE MATERIALIZER BELOW (op match { - case Fused(ops, _) ⇒ Props(new ActorInterpreter(settings, ops)) - case Map(f) ⇒ Props(new ActorInterpreter(settings, List(fusing.Map(f)))) - case Filter(p) ⇒ Props(new ActorInterpreter(settings, List(fusing.Filter(p)))) - case Drop(n) ⇒ Props(new ActorInterpreter(settings, List(fusing.Drop(n)))) - case Take(n) ⇒ Props(new ActorInterpreter(settings, List(fusing.Take(n)))) - case Collect(pf) ⇒ Props(new ActorInterpreter(settings, List(fusing.Collect(pf)))) - case Scan(z, f) ⇒ Props(new ActorInterpreter(settings, List(fusing.Scan(z, f)))) - case Expand(s, f) ⇒ Props(new ActorInterpreter(settings, List(fusing.Expand(s, f)))) - case Conflate(s, f) ⇒ Props(new ActorInterpreter(settings, List(fusing.Conflate(s, f)))) - case Buffer(n, s) ⇒ Props(new ActorInterpreter(settings, List(fusing.Buffer(n, s)))) - case MapConcat(f) ⇒ Props(new ActorInterpreter(settings, List(fusing.MapConcat(f)))) - case Operate(mkOp) ⇒ Props(new ActorInterpreter(settings, List(mkOp()))) - case MapAsync(f) ⇒ Props(new MapAsyncProcessorImpl(settings, f)) - case MapAsyncUnordered(f) ⇒ Props(new MapAsyncUnorderedProcessorImpl(settings, f)) - case Grouped(n) ⇒ Props(new ActorInterpreter(settings, List(fusing.Grouped(n)))) - case GroupBy(f) ⇒ Props(new GroupByProcessorImpl(settings, f)) - case PrefixAndTail(n) ⇒ Props(new PrefixAndTailImpl(settings, n)) - case SplitWhen(p) ⇒ Props(new SplitWhenProcessorImpl(settings, p)) - case ConcatAll ⇒ Props(new ConcatAllImpl(materializer)) //FIXME closes over the materializer, is this good? - case t: Transform ⇒ - val tr = t.mkTransformer() - Props(new TransformProcessorImpl(settings, tr)) - case t: TimerTransform ⇒ - val tr = t.mkTransformer() - Props(new TimerTransformerProcessorsImpl(settings, tr)) + case Fused(ops, _) ⇒ ActorInterpreter.props(settings, ops) + case Map(f) ⇒ ActorInterpreter.props(settings, List(fusing.Map(f))) + case Filter(p) ⇒ ActorInterpreter.props(settings, List(fusing.Filter(p))) + case Drop(n) ⇒ ActorInterpreter.props(settings, List(fusing.Drop(n))) + case Take(n) ⇒ ActorInterpreter.props(settings, List(fusing.Take(n))) + case Collect(pf) ⇒ ActorInterpreter.props(settings, List(fusing.Collect(pf))) + case Scan(z, f) ⇒ ActorInterpreter.props(settings, List(fusing.Scan(z, f))) + case Expand(s, f) ⇒ ActorInterpreter.props(settings, List(fusing.Expand(s, f))) + case Conflate(s, f) ⇒ ActorInterpreter.props(settings, List(fusing.Conflate(s, f))) + case Buffer(n, s) ⇒ ActorInterpreter.props(settings, List(fusing.Buffer(n, s))) + case MapConcat(f) ⇒ ActorInterpreter.props(settings, List(fusing.MapConcat(f))) + case MapAsync(f) ⇒ MapAsyncProcessorImpl.props(settings, f) + case MapAsyncUnordered(f) ⇒ MapAsyncUnorderedProcessorImpl.props(settings, f) + case Grouped(n) ⇒ ActorInterpreter.props(settings, List(fusing.Grouped(n))) + case GroupBy(f) ⇒ GroupByProcessorImpl.props(settings, f) + case PrefixAndTail(n) ⇒ PrefixAndTailImpl.props(settings, n) + case SplitWhen(p) ⇒ SplitWhenProcessorImpl.props(settings, p) + case ConcatAll ⇒ ConcatAllImpl.props(materializer) //FIXME closes over the materializer, is this good? + case StageFactory(mkStage, _) ⇒ ActorInterpreter.props(settings, List(mkStage())) + case TimerTransform(mkStage, _) ⇒ TimerTransformerProcessorsImpl.props(settings, mkStage()) }).withDispatcher(settings.dispatcher) } diff --git a/akka-stream/src/main/scala/akka/stream/impl/ConcatAllImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/ConcatAllImpl.scala index e9ac39f3c5..bd1e057bb7 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/ConcatAllImpl.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/ConcatAllImpl.scala @@ -5,6 +5,15 @@ package akka.stream.impl import akka.stream.FlowMaterializer import akka.stream.scaladsl.Sink +import akka.actor.Props + +/** + * INTERNAL API + */ +private[akka] object ConcatAllImpl { + def props(materializer: FlowMaterializer): Props = + Props(new ConcatAllImpl(materializer)) +} /** * INTERNAL API diff --git a/akka-stream/src/main/scala/akka/stream/impl/GroupByProcessorImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/GroupByProcessorImpl.scala index 58b4881777..4c43f174c3 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/GroupByProcessorImpl.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/GroupByProcessorImpl.scala @@ -5,6 +5,15 @@ package akka.stream.impl import akka.stream.MaterializerSettings import akka.stream.scaladsl.Source +import akka.actor.Props + +/** + * INTERNAL API + */ +private[akka] object GroupByProcessorImpl { + def props(settings: MaterializerSettings, keyFor: Any ⇒ Any): Props = + Props(new GroupByProcessorImpl(settings, keyFor)) +} /** * INTERNAL API diff --git a/akka-stream/src/main/scala/akka/stream/impl/MapAsyncProcessorImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/MapAsyncProcessorImpl.scala index 8dc78f3e7f..f55e1fed77 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/MapAsyncProcessorImpl.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/MapAsyncProcessorImpl.scala @@ -11,12 +11,16 @@ import scala.util.control.NonFatal import akka.stream.MaterializerSettings import akka.pattern.pipe import scala.annotation.tailrec +import akka.actor.Props /** * INTERNAL API */ private[akka] object MapAsyncProcessorImpl { + def props(settings: MaterializerSettings, f: Any ⇒ Future[Any]): Props = + Props(new MapAsyncProcessorImpl(settings, f)) + object FutureElement { implicit val ordering: Ordering[FutureElement] = new Ordering[FutureElement] { def compare(a: FutureElement, b: FutureElement): Int = { @@ -43,7 +47,7 @@ private[akka] class MapAsyncProcessorImpl(_settings: MaterializerSettings, f: An var doneSeqNo = 0L def gap: Long = submittedSeqNo - doneSeqNo - // TODO performance improvement: explore Endre's proposal of using an array based ring buffer addressed by + // TODO performance improvement: explore Endre's proposal of using an array based ring buffer addressed by // seqNo & Mask and explicitly storing a Gap object to denote missing pieces instead of the sorted set // keep future results arriving too early in a buffer sorted by seqNo @@ -62,7 +66,7 @@ private[akka] class MapAsyncProcessorImpl(_settings: MaterializerSettings, f: An if (iter.hasNext) { val next = iter.next() val inOrder = next.seqNo == (doneSeqNo + 1) - // stop at first missing seqNo + // stop at first missing seqNo if (inOrder) { n += 1 doneSeqNo = next.seqNo diff --git a/akka-stream/src/main/scala/akka/stream/impl/MapAsyncUnorderedProcessorImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/MapAsyncUnorderedProcessorImpl.scala index b173baa980..e6a35505f9 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/MapAsyncUnorderedProcessorImpl.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/MapAsyncUnorderedProcessorImpl.scala @@ -8,11 +8,15 @@ import scala.util.control.NonFatal import akka.stream.MaterializerSettings import akka.stream.MaterializerSettings import akka.pattern.pipe +import akka.actor.Props /** * INTERNAL API */ private[akka] object MapAsyncUnorderedProcessorImpl { + def props(settings: MaterializerSettings, f: Any ⇒ Future[Any]): Props = + Props(new MapAsyncUnorderedProcessorImpl(settings, f)) + case class FutureElement(element: Any) case class FutureFailure(cause: Throwable) } diff --git a/akka-stream/src/main/scala/akka/stream/impl/PrefixAndTailImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/PrefixAndTailImpl.scala index 8f35efd52d..94432ba0f2 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/PrefixAndTailImpl.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/PrefixAndTailImpl.scala @@ -6,6 +6,15 @@ package akka.stream.impl import scala.collection.immutable import akka.stream.MaterializerSettings import akka.stream.scaladsl.Source +import akka.actor.Props + +/** + * INTERNAL API + */ +private[akka] object PrefixAndTailImpl { + def props(settings: MaterializerSettings, takeMax: Int): Props = + Props(new PrefixAndTailImpl(settings, takeMax)) +} /** * INTERNAL API diff --git a/akka-stream/src/main/scala/akka/stream/impl/SingleStreamProcessors.scala b/akka-stream/src/main/scala/akka/stream/impl/SingleStreamProcessors.scala deleted file mode 100644 index c1eb21f90d..0000000000 --- a/akka-stream/src/main/scala/akka/stream/impl/SingleStreamProcessors.scala +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright (C) 2014 Typesafe Inc. - */ -package akka.stream.impl - -import akka.actor.Props -import akka.stream.{ MaterializerSettings, TransformerLike } - -import scala.collection.immutable -import scala.util.control.NonFatal - -/** - * INTERNAL API - */ -private[akka] class TransformProcessorImpl(_settings: MaterializerSettings, transformer: TransformerLike[Any, Any]) - extends ActorProcessorImpl(_settings) with Emit { - - var errorEvent: Option[Throwable] = None - - override def preStart(): Unit = { - super.preStart() - nextPhase(running) - } - - override def onError(e: Throwable): Unit = { - try { - transformer.onError(e) - errorEvent = Some(e) - pump() - } catch { case NonFatal(ex) ⇒ fail(ex) } - } - - object NeedsInputAndDemandOrCompletion extends TransferState { - def isReady = (primaryInputs.inputsAvailable && primaryOutputs.demandAvailable) || primaryInputs.inputsDepleted - def isCompleted = primaryOutputs.isClosed - } - - private val runningPhase: TransferPhase = TransferPhase(NeedsInputAndDemandOrCompletion) { () ⇒ - if (primaryInputs.inputsDepleted) nextPhase(terminate) - else { - emits = transformer.onNext(primaryInputs.dequeueInputElement()) - if (transformer.isComplete) emitAndThen(terminate) - else emitAndThen(running) - } - } - - def running: TransferPhase = runningPhase - - val terminate = TransferPhase(Always) { () ⇒ - emits = transformer.onTermination(errorEvent) - emitAndThen(completedPhase) - } - - override def toString: String = s"Transformer(emits=$emits, transformer=$transformer)" - - override def postStop(): Unit = try super.postStop() finally transformer.cleanup() -} - -/** - * INTERNAL API - */ -private[akka] object IdentityProcessorImpl { - def props(settings: MaterializerSettings): Props = Props(new IdentityProcessorImpl(settings)) -} - -/** - * INTERNAL API - */ -private[akka] class IdentityProcessorImpl(_settings: MaterializerSettings) extends ActorProcessorImpl(_settings) { - - val running: TransferPhase = TransferPhase(primaryInputs.NeedsInput && primaryOutputs.NeedsDemand) { () ⇒ - primaryOutputs.enqueueOutputElement(primaryInputs.dequeueInputElement()) - } - - nextPhase(running) -} diff --git a/akka-stream/src/main/scala/akka/stream/impl/SplitWhenProcessorImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/SplitWhenProcessorImpl.scala index 5f901eb8a5..375b677c69 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/SplitWhenProcessorImpl.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/SplitWhenProcessorImpl.scala @@ -5,6 +5,15 @@ package akka.stream.impl import akka.stream.MaterializerSettings import akka.stream.scaladsl.Source +import akka.actor.Props + +/** + * INTERNAL API + */ +private[akka] object SplitWhenProcessorImpl { + def props(settings: MaterializerSettings, splitPredicate: Any ⇒ Boolean): Props = + Props(new SplitWhenProcessorImpl(settings, splitPredicate)) +} /** * INTERNAL API diff --git a/akka-stream/src/main/scala/akka/stream/impl/TimerTransformerProcessorsImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/TimerTransformerProcessorsImpl.scala index d96305098c..8f0accdde6 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/TimerTransformerProcessorsImpl.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/TimerTransformerProcessorsImpl.scala @@ -7,6 +7,12 @@ import java.util.LinkedList import akka.stream.MaterializerSettings import akka.stream.TimerTransformer import scala.util.control.NonFatal +import akka.actor.Props + +private[akka] object TimerTransformerProcessorsImpl { + def props(settings: MaterializerSettings, transformer: TimerTransformer[Any, Any]): Props = + Props(new TimerTransformerProcessorsImpl(settings, transformer)) +} /** * INTERNAL API @@ -14,17 +20,29 @@ import scala.util.control.NonFatal private[akka] class TimerTransformerProcessorsImpl( _settings: MaterializerSettings, transformer: TimerTransformer[Any, Any]) - extends TransformProcessorImpl(_settings, transformer) { + extends ActorProcessorImpl(_settings) with Emit { import TimerTransformer._ + var errorEvent: Option[Throwable] = None + override def preStart(): Unit = { super.preStart() + nextPhase(running) transformer.start(context) } - override def postStop(): Unit = { - super.postStop() - transformer.stop() + override def postStop(): Unit = + try { + super.postStop() + transformer.stop() + } finally transformer.cleanup() + + override def onError(e: Throwable): Unit = { + try { + transformer.onError(e) + errorEvent = Some(e) + pump() + } catch { case NonFatal(ex) ⇒ fail(ex) } } val schedulerInputs: Inputs = new DefaultInputTransferStates { @@ -58,7 +76,7 @@ private[akka] class TimerTransformerProcessorsImpl( def isCompleted = false } - private val runningPhase: TransferPhase = TransferPhase(RunningCondition) { () ⇒ + private val running: TransferPhase = TransferPhase(RunningCondition) { () ⇒ if (primaryInputs.inputsDepleted || (transformer.isComplete && !schedulerInputs.inputsAvailable)) { nextPhase(terminate) } else if (schedulerInputs.inputsAvailable) { @@ -71,6 +89,11 @@ private[akka] class TimerTransformerProcessorsImpl( } } - override def running: TransferPhase = runningPhase + private val terminate = TransferPhase(Always) { () ⇒ + emits = transformer.onTermination(errorEvent) + emitAndThen(completedPhase) + } + + override def toString: String = s"Transformer(emits=$emits, transformer=$transformer)" } diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorInterpreter.scala index 36c8d25536..1641d49a52 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorInterpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorInterpreter.scala @@ -4,21 +4,21 @@ package akka.stream.impl.fusing import java.util.Arrays - import akka.actor.{ Actor, ActorRef } import akka.event.Logging import akka.stream.MaterializerSettings import akka.stream.actor.ActorSubscriber.OnSubscribe import akka.stream.actor.ActorSubscriberMessage.{ OnNext, OnError, OnComplete } import akka.stream.impl._ +import akka.stream.stage._ import org.reactivestreams.{ Subscriber, Subscription } - import scala.util.control.NonFatal +import akka.actor.Props /** * INTERNAL API */ -private[akka] class BatchingActorInputBoundary(val size: Int) extends BoundaryOp { +private[akka] class BatchingActorInputBoundary(val size: Int) extends BoundaryStage { require(size > 0, "buffer size cannot be zero") require((size & (size - 1)) == 0, "buffer size must be a power of two") @@ -60,25 +60,25 @@ private[akka] class BatchingActorInputBoundary(val size: Int) extends BoundaryOp } } - override def onPush(elem: Any, ctxt: BoundaryContext): Directive = + override def onPush(elem: Any, ctx: BoundaryContext): Directive = throw new UnsupportedOperationException("BUG: Cannot push the upstream boundary") - override def onPull(ctxt: BoundaryContext): Directive = { - if (inputBufferElements > 1) ctxt.push(dequeue()) + override def onPull(ctx: BoundaryContext): Directive = { + if (inputBufferElements > 1) ctx.push(dequeue()) else if (inputBufferElements == 1) { - if (upstreamCompleted) ctxt.pushAndFinish(dequeue()) - else ctxt.push(dequeue()) + if (upstreamCompleted) ctx.pushAndFinish(dequeue()) + else ctx.push(dequeue()) } else if (upstreamCompleted) { - ctxt.finish() + ctx.finish() } else { downstreamWaiting = true - ctxt.exit() + ctx.exit() } } - override def onDownstreamFinish(ctxt: BoundaryContext): TerminationDirective = { + override def onDownstreamFinish(ctx: BoundaryContext): TerminationDirective = { cancel() - ctxt.exit() + ctx.exit() } def cancel(): Unit = { @@ -143,7 +143,7 @@ private[akka] class BatchingActorInputBoundary(val size: Int) extends BoundaryOp /** * INTERNAL API */ -private[akka] class ActorOutputBoundary(val actor: ActorRef) extends BoundaryOp { +private[akka] class ActorOutputBoundary(val actor: ActorRef) extends BoundaryStage { private var exposedPublisher: ActorPublisher[Any] = _ @@ -177,27 +177,27 @@ private[akka] class ActorOutputBoundary(val actor: ActorRef) extends BoundaryOp } } - override def onPush(elem: Any, ctxt: BoundaryContext): Directive = { + override def onPush(elem: Any, ctx: BoundaryContext): Directive = { onNext(elem) - if (downstreamDemand > 0) ctxt.pull() - else if (downstreamCompleted) ctxt.finish() + if (downstreamDemand > 0) ctx.pull() + else if (downstreamCompleted) ctx.finish() else { upstreamWaiting = true - ctxt.exit() + ctx.exit() } } - override def onPull(ctxt: BoundaryContext): Directive = + override def onPull(ctx: BoundaryContext): Directive = throw new UnsupportedOperationException("BUG: Cannot pull the downstream boundary") - override def onUpstreamFinish(ctxt: BoundaryContext): TerminationDirective = { + override def onUpstreamFinish(ctx: BoundaryContext): TerminationDirective = { complete() - ctxt.finish() + ctx.finish() } - override def onFailure(cause: Throwable, ctxt: BoundaryContext): TerminationDirective = { + override def onUpstreamFailure(cause: Throwable, ctx: BoundaryContext): TerminationDirective = { fail(cause) - ctxt.fail(cause) + ctx.fail(cause) } private def subscribePending(subscribers: Seq[Subscriber[Any]]): Unit = @@ -245,7 +245,15 @@ private[akka] class ActorOutputBoundary(val actor: ActorRef) extends BoundaryOp /** * INTERNAL API */ -private[akka] class ActorInterpreter(settings: MaterializerSettings, ops: Seq[Op[_, _, _, _, _]]) +private[akka] object ActorInterpreter { + def props(settings: MaterializerSettings, ops: Seq[Stage[_, _]]): Props = + Props(new ActorInterpreter(settings, ops)) +} + +/** + * INTERNAL API + */ +private[akka] class ActorInterpreter(val settings: MaterializerSettings, val ops: Seq[Stage[_, _]]) extends Actor { private val upstream = new BatchingActorInputBoundary(settings.initialInputBufferSize) @@ -270,4 +278,4 @@ private[akka] class ActorInterpreter(settings: MaterializerSettings, ops: Seq[Op throw new IllegalStateException("This actor cannot be restarted", reason) } -} \ No newline at end of file +} diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Interpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Interpreter.scala index 51b7765dd8..c9c1c3fa28 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Interpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Interpreter.scala @@ -4,63 +4,38 @@ package akka.stream.impl.fusing import scala.annotation.tailrec +import scala.collection.breakOut import scala.util.control.NonFatal +import akka.stream.stage._ // TODO: // fix jumpback table with keep-going-on-complete ops (we might jump between otherwise isolated execution regions) // implement grouped, buffer // add recover -trait Op[In, Out, PushD <: Directive, PullD <: Directive, Ctxt <: Context[Out]] { - private[fusing] var holding = false - private[fusing] var allowedToPush = false - private[fusing] var terminationPending = false - - def isHolding: Boolean = holding - def isFinishing: Boolean = terminationPending - def onPush(elem: In, ctxt: Ctxt): PushD - def onPull(ctxt: Ctxt): PullD - def onUpstreamFinish(ctxt: Ctxt): TerminationDirective = ctxt.finish() - def onDownstreamFinish(ctxt: Ctxt): TerminationDirective = ctxt.finish() - def onFailure(cause: Throwable, ctxt: Ctxt): TerminationDirective = ctxt.fail(cause) +/** + * INTERNAL API + * + * `BoundaryStage` implementations are meant to communicate with the external world. These stages do not have most of the + * safety properties enforced and should be used carefully. One important ability of BoundaryStages that they can take + * off an execution signal by calling `ctx.exit()`. This is typically used immediately after an external signal has + * been produced (for example an actor message). BoundaryStages can also kickstart execution by calling `enter()` which + * returns a context they can use to inject signals into the interpreter. There is no checks in place to enforce that + * the number of signals taken out by exit() and the number of signals returned via enter() are the same -- using this + * stage type needs extra care from the implementer. + * + * BoundaryStages are the elements that make the interpreter *tick*, there is no other way to start the interpreter + * than using a BoundaryStage. + */ +private[akka] abstract class BoundaryStage extends AbstractStage[Any, Any, Directive, Directive, BoundaryContext] { + private[fusing] var bctx: BoundaryContext = _ + def enter(): BoundaryContext = bctx } -trait DeterministicOp[In, Out] extends Op[In, Out, Directive, Directive, Context[Out]] -trait DetachedOp[In, Out] extends Op[In, Out, UpstreamDirective, DownstreamDirective, DetachedContext[Out]] -trait BoundaryOp extends Op[Any, Any, Directive, Directive, BoundaryContext] { - private[fusing] var bctxt: BoundaryContext = _ - def enter(): BoundaryContext = bctxt -} - -trait TransitivePullOp[In, Out] extends DeterministicOp[In, Out] { - final override def onPull(ctxt: Context[Out]): Directive = ctxt.pull() -} - -sealed trait Directive -sealed trait UpstreamDirective extends Directive -sealed trait DownstreamDirective extends Directive -sealed trait TerminationDirective extends Directive -final class FreeDirective extends UpstreamDirective with DownstreamDirective with TerminationDirective - -sealed trait Context[Out] { - def push(elem: Out): DownstreamDirective - def pull(): UpstreamDirective - def finish(): FreeDirective - def pushAndFinish(elem: Out): DownstreamDirective - def fail(cause: Throwable): FreeDirective - def absorbTermination(): TerminationDirective -} - -trait DetachedContext[Out] extends Context[Out] { - def hold(): FreeDirective - def pushAndPull(elem: Out): FreeDirective -} - -trait BoundaryContext extends Context[Any] { - def exit(): FreeDirective -} - -object OneBoundedInterpreter { +/** + * INTERNAL API + */ +private[akka] object OneBoundedInterpreter { final val PhantomDirective = null /** @@ -70,16 +45,18 @@ object OneBoundedInterpreter { * paths again. When finishing an op this op is injected in its place to isolate upstream and downstream execution * domains. */ - private[akka] object Finished extends BoundaryOp { - override def onPush(elem: Any, ctxt: BoundaryContext): UpstreamDirective = ctxt.finish() - override def onPull(ctxt: BoundaryContext): DownstreamDirective = ctxt.finish() - override def onUpstreamFinish(ctxt: BoundaryContext): TerminationDirective = ctxt.exit() - override def onDownstreamFinish(ctxt: BoundaryContext): TerminationDirective = ctxt.exit() - override def onFailure(cause: Throwable, ctxt: BoundaryContext): TerminationDirective = ctxt.exit() + private[akka] object Finished extends BoundaryStage { + override def onPush(elem: Any, ctx: BoundaryContext): UpstreamDirective = ctx.finish() + override def onPull(ctx: BoundaryContext): DownstreamDirective = ctx.finish() + override def onUpstreamFinish(ctx: BoundaryContext): TerminationDirective = ctx.exit() + override def onDownstreamFinish(ctx: BoundaryContext): TerminationDirective = ctx.exit() + override def onUpstreamFailure(cause: Throwable, ctx: BoundaryContext): TerminationDirective = ctx.exit() } } /** + * INTERNAL API + * * One-bounded interpreter for a linear chain of stream operations (graph support is possible and will be implemented * later) * @@ -105,70 +82,70 @@ object OneBoundedInterpreter { * time. This "exactly one" property is enforced by proper types and runtime checks where needed. Currently there are * three kinds of ops: * - * - DeterministicOp implementations participate in 1-bounded regions. For every external non-completion signal these + * - PushPullStage implementations participate in 1-bounded regions. For every external non-completion signal these * ops produce *exactly one* signal (completion is different, explained later) therefore keeping the number of events * the same: exactly one. * - * - DetachedOp implementations are boundaries between 1-bounded regions. This means that they need to enforce the - * "exactly one" property both on their upstream and downstream regions. As a consequence a DetachedOp can never - * answer an onPull with a ctxt.pull() or answer an onPush() with a ctxt.push() since such an action would "steal" + * - DetachedStage implementations are boundaries between 1-bounded regions. This means that they need to enforce the + * "exactly one" property both on their upstream and downstream regions. As a consequence a DetachedStage can never + * answer an onPull with a ctx.pull() or answer an onPush() with a ctx.push() since such an action would "steal" * the event from one region (resulting in zero signals) and would inject it to the other region (resulting in two - * signals). However DetachedOps have the ability to call ctxt.hold() as a response to onPush/onPull which temporarily + * signals). However DetachedStages have the ability to call ctx.hold() as a response to onPush/onPull which temporarily * takes the signal off and stops execution, at the same time putting the op in a "holding" state. If the op is in a * holding state it contains one absorbed signal, therefore in this state the only possible command to call is - * ctxt.pushAndPull() which results in two events making the balance right again: + * ctx.pushAndPull() which results in two events making the balance right again: * 1 hold + 1 external event = 2 external event * This mechanism allows synchronization between the upstream and downstream regions which otherwise can progress * independently. * - * - BoundaryOp implementations are meant to communicate with the external world. These ops do not have most of the - * safety properties enforced and should be used carefully. One important ability of BoundaryOps that they can take - * off an execution signal by calling ctxt.exit(). This is typically used immediately after an external signal has - * been produced (for example an actor message). BoundaryOps can also kickstart execution by calling enter() which + * - BoundaryStage implementations are meant to communicate with the external world. These ops do not have most of the + * safety properties enforced and should be used carefully. One important ability of BoundaryStages that they can take + * off an execution signal by calling ctx.exit(). This is typically used immediately after an external signal has + * been produced (for example an actor message). BoundaryStages can also kickstart execution by calling enter() which * returns a context they can use to inject signals into the interpreter. There is no checks in place to enforce that * the number of signals taken out by exit() and the number of signals returned via enter() are the same -- using this * op type needs extra care from the implementer. - * BoundaryOps are the elements that make the interpreter *tick*, there is no other way to start the interpreter - * than using a BoundaryOp. + * BoundaryStages are the elements that make the interpreter *tick*, there is no other way to start the interpreter + * than using a BoundaryStage. * * Operations are allowed to do early completion and cancel/complete their upstreams and downstreams. It is *not* - * allowed however to do these independently to avoid isolated execution islands. The only call possible is ctxt.finish() + * allowed however to do these independently to avoid isolated execution islands. The only call possible is ctx.finish() * which is a combination of cancel/complete. * Since onComplete is not a backpressured signal it is sometimes preferable to push a final element and then immediately * finish. This combination is exposed as pushAndFinish() which enables op writers to propagate completion events without * waiting for an extra round of pull. * Another peculiarity is how to convert termination events (complete/failure) into elements. The problem - * here is that the termination events are not backpressured while elements are. This means that simply calling ctxt.push() + * here is that the termination events are not backpressured while elements are. This means that simply calling ctx.push() * as a response to onUpstreamFinished() will very likely break boundedness and result in a buffer overflow somewhere. - * Therefore the only allowed command in this case is ctxt.absorbTermination() which stops the propagation of the + * Therefore the only allowed command in this case is ctx.absorbTermination() which stops the propagation of the * termination signal, and puts the op in a finishing state. Depending on whether the op has a pending pull signal it has * not yet "consumed" by a push its onPull() handler might be called immediately. * * In order to execute different individual execution regions the interpreter uses the callstack to schedule these. The * current execution forking operations are - * - ctxt.finish() which starts a wave of completion and cancellation in two directions. When an op calls finish() + * - ctx.finish() which starts a wave of completion and cancellation in two directions. When an op calls finish() * it is immediately replaced by an artificial Finished op which makes sure that the two execution paths are isolated * forever. - * - ctxt.fail() which is similar to finish() - * - ctxt.pushAndPull() which (as a response to a previous ctxt.hold()) starts a wawe of downstream push and upstream + * - ctx.fail() which is similar to finish() + * - ctx.pushAndPull() which (as a response to a previous ctx.hold()) starts a wawe of downstream push and upstream * pull. The two execution paths are isolated by the op itself since onPull() from downstream can only answered by hold or * push, while onPush() from upstream can only answered by hold or pull -- it is impossible to "cross" the op. - * - ctxt.pushAndFinish() which is different from the forking ops above because the execution of push and finish happens on + * - ctx.pushAndFinish() which is different from the forking ops above because the execution of push and finish happens on * the same execution region and they are order dependent, too. * The interpreter tracks the depth of recursive forking and allows various strategies of dealing with the situation * when this depth reaches a certain limit. In the simplest case an error is reported (this is very useful for stress * testing and finding callstack wasting bugs), in the other case the forked call is scheduled via a list -- i.e. instead * of the stack the heap is used. */ -class OneBoundedInterpreter(ops: Seq[Op[_, _, _, _, _]], val forkLimit: Int = 100, val overflowToHeap: Boolean = true) { +private[akka] class OneBoundedInterpreter(ops: Seq[Stage[_, _]], val forkLimit: Int = 100, val overflowToHeap: Boolean = true) { import OneBoundedInterpreter._ - type UntypedOp = Op[Any, Any, Directive, Directive, DetachedContext[Any]] + type UntypedOp = AbstractStage[Any, Any, Directive, Directive, Context[Any]] require(ops.nonEmpty, "OneBoundedInterpreter cannot be created without at least one Op") - private val pipeline = ops.toArray.asInstanceOf[Array[UntypedOp]] + private val pipeline: Array[UntypedOp] = ops.map(_.asInstanceOf[UntypedOp])(breakOut) /** - * This table is used to accelerate demand propagation upstream. All ops that implement TransitivePullOp are guaranteed + * This table is used to accelerate demand propagation upstream. All ops that implement PushStage are guaranteed * to only do upstream propagation of demand signals, therefore it is not necessary to execute them but enough to * "jump over" them. This means that when a chain of one million maps gets a downstream demand it is propagated * to the upstream *in one step* instead of one million onPull() calls. @@ -199,7 +176,7 @@ class OneBoundedInterpreter(ops: Seq[Op[_, _, _, _, _]], val forkLimit: Int = 10 var nextJumpBack = -1 for (pos ← 0 until pipeline.length) { table(pos) = nextJumpBack - if (!pipeline(pos).isInstanceOf[TransitivePullOp[_, _]]) nextJumpBack = pos + if (!pipeline(pos).isInstanceOf[PushStage[_, _]]) nextJumpBack = pos } table } @@ -217,7 +194,7 @@ class OneBoundedInterpreter(ops: Seq[Op[_, _, _, _, _]], val forkLimit: Int = 10 override def pull(): UpstreamDirective = { if (pipeline(activeOp).holding) throw new IllegalStateException("Cannot pull while holding, only pushAndPull") - pipeline(activeOp).allowedToPush = !pipeline(activeOp).isInstanceOf[DetachedOp[_, _]] + pipeline(activeOp).allowedToPush = !pipeline(activeOp).isInstanceOf[DetachedStage[_, _]] state = Pulling PhantomDirective } @@ -228,6 +205,8 @@ class OneBoundedInterpreter(ops: Seq[Op[_, _, _, _, _]], val forkLimit: Int = 10 PhantomDirective } + def isFinishing: Boolean = pipeline(activeOp).terminationPending + override def pushAndFinish(elem: Any): DownstreamDirective = { pipeline(activeOp) = Finished.asInstanceOf[UntypedOp] // This MUST be an unsafeFork because the execution of PushFinish MUST strictly come before the finish execution @@ -253,6 +232,8 @@ class OneBoundedInterpreter(ops: Seq[Op[_, _, _, _, _]], val forkLimit: Int = 10 exit() } + override def isHolding: Boolean = pipeline(activeOp).holding + override def pushAndPull(elem: Any): FreeDirective = { if (!pipeline(activeOp).holding) throw new IllegalStateException("Cannot pushAndPull without holding first") pipeline(activeOp).holding = false @@ -276,14 +257,14 @@ class OneBoundedInterpreter(ops: Seq[Op[_, _, _, _, _]], val forkLimit: Int = 10 private object Pushing extends State { override def advance(): Unit = { activeOp += 1 - pipeline(activeOp).onPush(elementInFlight, ctxt = this) + pipeline(activeOp).onPush(elementInFlight, ctx = this) } } private object PushFinish extends State { override def advance(): Unit = { activeOp += 1 - pipeline(activeOp).onPush(elementInFlight, ctxt = this) + pipeline(activeOp).onPush(elementInFlight, ctx = this) } override def pushAndFinish(elem: Any): DownstreamDirective = { @@ -302,7 +283,7 @@ class OneBoundedInterpreter(ops: Seq[Op[_, _, _, _, _]], val forkLimit: Int = 10 override def advance(): Unit = { elementInFlight = null activeOp = jumpBacks(activeOp) - pipeline(activeOp).onPull(ctxt = this) + pipeline(activeOp).onPull(ctx = this) } override def hold(): FreeDirective = { @@ -317,7 +298,9 @@ class OneBoundedInterpreter(ops: Seq[Op[_, _, _, _, _]], val forkLimit: Int = 10 elementInFlight = null pipeline(activeOp) = Finished.asInstanceOf[UntypedOp] activeOp += 1 - if (!pipeline(activeOp).isFinishing) pipeline(activeOp).onUpstreamFinish(ctxt = this) + + // FIXME issue #16345, ArrayIndexOutOfBoundsException + if (!pipeline(activeOp).terminationPending) pipeline(activeOp).onUpstreamFinish(ctx = this) else exit() } @@ -330,7 +313,7 @@ class OneBoundedInterpreter(ops: Seq[Op[_, _, _, _, _]], val forkLimit: Int = 10 pipeline(activeOp).terminationPending = true pipeline(activeOp).holding = false // FIXME: This state is potentially corrupted by the jumpBackTable (not updated when jumping over) - if (pipeline(activeOp).allowedToPush) pipeline(activeOp).onPull(ctxt = Pulling) + if (pipeline(activeOp).allowedToPush) pipeline(activeOp).onPull(ctx = Pulling) else exit() PhantomDirective } @@ -341,7 +324,9 @@ class OneBoundedInterpreter(ops: Seq[Op[_, _, _, _, _]], val forkLimit: Int = 10 elementInFlight = null pipeline(activeOp) = Finished.asInstanceOf[UntypedOp] activeOp -= 1 - if (!pipeline(activeOp).isFinishing) pipeline(activeOp).onDownstreamFinish(ctxt = this) + + // FIXME issue #16345, ArrayIndexOutOfBoundsException + if (!pipeline(activeOp).terminationPending) pipeline(activeOp).onDownstreamFinish(ctx = this) else exit() } @@ -356,13 +341,13 @@ class OneBoundedInterpreter(ops: Seq[Op[_, _, _, _, _]], val forkLimit: Int = 10 elementInFlight = null pipeline(activeOp) = Finished.asInstanceOf[UntypedOp] activeOp += 1 - pipeline(activeOp).onFailure(cause, ctxt = this) + pipeline(activeOp).onUpstreamFailure(cause, ctx = this) } override def absorbTermination(): TerminationDirective = { pipeline(activeOp).terminationPending = true pipeline(activeOp).holding = false - if (pipeline(activeOp).allowedToPush) pipeline(activeOp).onPull(ctxt = Pulling) + if (pipeline(activeOp).allowedToPush) pipeline(activeOp).onPull(ctx = Pulling) else exit() PhantomDirective } @@ -421,7 +406,6 @@ class OneBoundedInterpreter(ops: Seq[Op[_, _, _, _, _]], val forkLimit: Int = 10 state = forkState execute() activeOp = savePos - PhantomDirective } def init(): Unit = { @@ -432,13 +416,15 @@ class OneBoundedInterpreter(ops: Seq[Op[_, _, _, _, _]], val forkLimit: Int = 10 def isFinished: Boolean = pipeline(Upstream) == Finished && pipeline(Downstream) == Finished /** - * This method injects a Context to each of the BoundaryOps. This will be the context returned by enter(). + * This method injects a Context to each of the BoundaryStages. This will be the context returned by enter(). */ private def initBoundaries(): Unit = { var op = 0 while (op < pipeline.length) { - if (pipeline(op).isInstanceOf[BoundaryOp]) { - pipeline(op).asInstanceOf[BoundaryOp].bctxt = new State { + // FIXME try to change this to a pattern match `case boundary: BoundaryStage` + // but that doesn't work with current Context types + if (pipeline(op).isInstanceOf[BoundaryStage]) { + pipeline(op).asInstanceOf[BoundaryStage].bctx = new State { val entryPoint = op override def advance(): Unit = () @@ -499,7 +485,7 @@ class OneBoundedInterpreter(ops: Seq[Op[_, _, _, _, _]], val forkLimit: Int = 10 private def runDetached(): Unit = { var op = pipeline.length - 1 while (op >= 0) { - if (pipeline(op).isInstanceOf[DetachedOp[_, _]]) { + if (pipeline(op).isInstanceOf[DetachedStage[_, _]]) { activeOp = op state = Pulling execute() @@ -508,4 +494,4 @@ class OneBoundedInterpreter(ops: Seq[Op[_, _, _, _, _]], val forkLimit: Int = 10 } } -} \ No newline at end of file +} diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/IteratorInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/IteratorInterpreter.scala index 480afec88a..8dead5c510 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/IteratorInterpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/IteratorInterpreter.scala @@ -3,49 +3,57 @@ */ package akka.stream.impl.fusing -object IteratorInterpreter { - case class IteratorUpstream[T](input: Iterator[T]) extends DeterministicOp[T, T] { +import akka.stream.stage._ + +/** + * INTERNAL API + */ +private[akka] object IteratorInterpreter { + final case class IteratorUpstream[T](input: Iterator[T]) extends PushPullStage[T, T] { private var hasNext = input.hasNext - override def onPush(elem: T, ctxt: Context[T]): Directive = + override def onPush(elem: T, ctx: Context[T]): Directive = throw new UnsupportedOperationException("IteratorUpstream operates as a source, it cannot be pushed") - override def onPull(ctxt: Context[T]): Directive = { - if (!hasNext) ctxt.finish() + override def onPull(ctx: Context[T]): Directive = { + if (!hasNext) ctx.finish() else { val elem = input.next() hasNext = input.hasNext - if (!hasNext) ctxt.pushAndFinish(elem) - else ctxt.push(elem) + if (!hasNext) ctx.pushAndFinish(elem) + else ctx.push(elem) } } + + // don't let toString consume the iterator + override def toString: String = "IteratorUpstream" } - case class IteratorDownstream[T]() extends BoundaryOp with Iterator[T] { + final case class IteratorDownstream[T]() extends BoundaryStage with Iterator[T] { private var done = false private var nextElem: T = _ private var needsPull = true private var lastError: Throwable = null - override def onPush(elem: Any, ctxt: BoundaryContext): Directive = { + override def onPush(elem: Any, ctx: BoundaryContext): Directive = { nextElem = elem.asInstanceOf[T] needsPull = false - ctxt.exit() + ctx.exit() } - override def onPull(ctxt: BoundaryContext): Directive = + override def onPull(ctx: BoundaryContext): Directive = throw new UnsupportedOperationException("IteratorDownstream operates as a sink, it cannot be pulled") - override def onUpstreamFinish(ctxt: BoundaryContext): TerminationDirective = { + override def onUpstreamFinish(ctx: BoundaryContext): TerminationDirective = { done = true - ctxt.finish() + ctx.finish() } - override def onFailure(cause: Throwable, ctxt: BoundaryContext): TerminationDirective = { + override def onUpstreamFailure(cause: Throwable, ctx: BoundaryContext): TerminationDirective = { done = true lastError = cause - ctxt.finish() + ctx.finish() } private def pullIfNeeded(): Unit = { @@ -56,27 +64,37 @@ object IteratorInterpreter { override def hasNext: Boolean = { if (!done) pullIfNeeded() - !(done && needsPull) + !(done && needsPull) || (lastError ne null) } override def next(): T = { - if (!hasNext) { - if (lastError != null) throw lastError - else Iterator.empty.next() + if (lastError ne null) { + val e = lastError + lastError = null + throw e + } else if (!hasNext) + Iterator.empty.next() + else { + needsPull = true + nextElem } - needsPull = true - nextElem } + // don't let toString consume the iterator + override def toString: String = "IteratorDownstream" + } } -class IteratorInterpreter[I, O](val input: Iterator[I], val ops: Seq[DeterministicOp[_, _]]) { +/** + * INTERNAL API + */ +private[akka] class IteratorInterpreter[I, O](val input: Iterator[I], val ops: Seq[PushPullStage[_, _]]) { import akka.stream.impl.fusing.IteratorInterpreter._ private val upstream = IteratorUpstream(input) private val downstream = IteratorDownstream[O]() - private val interpreter = new OneBoundedInterpreter(upstream +: ops.asInstanceOf[Seq[Op[_, _, _, _, _]]] :+ downstream) + private val interpreter = new OneBoundedInterpreter(upstream +: ops.asInstanceOf[Seq[Stage[_, _]]] :+ downstream) interpreter.init() def iterator: Iterator[O] = downstream diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala index 3a91955379..b1f9b0d030 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala @@ -3,25 +3,25 @@ */ package akka.stream.impl.fusing +import scala.collection.immutable import akka.stream.OverflowStrategy import akka.stream.impl.FixedSizeBuffer - -import scala.collection.immutable +import akka.stream.stage._ /** * INTERNAL API */ -private[akka] final case class Map[In, Out](f: In ⇒ Out) extends TransitivePullOp[In, Out] { - override def onPush(elem: In, ctxt: Context[Out]): Directive = ctxt.push(f(elem)) +private[akka] final case class Map[In, Out](f: In ⇒ Out) extends PushStage[In, Out] { + override def onPush(elem: In, ctx: Context[Out]): Directive = ctx.push(f(elem)) } /** * INTERNAL API */ -private[akka] final case class Filter[T](p: T ⇒ Boolean) extends TransitivePullOp[T, T] { - override def onPush(elem: T, ctxt: Context[T]): Directive = - if (p(elem)) ctxt.push(elem) - else ctxt.pull() +private[akka] final case class Filter[T](p: T ⇒ Boolean) extends PushStage[T, T] { + override def onPush(elem: T, ctx: Context[T]): Directive = + if (p(elem)) ctx.push(elem) + else ctx.pull() } private[akka] final object Collect { @@ -31,103 +31,103 @@ private[akka] final object Collect { final val NotApplied: Any ⇒ Any = _ ⇒ Collect.NotApplied } -private[akka] final case class Collect[In, Out](pf: PartialFunction[In, Out]) extends TransitivePullOp[In, Out] { +private[akka] final case class Collect[In, Out](pf: PartialFunction[In, Out]) extends PushStage[In, Out] { import Collect.NotApplied - override def onPush(elem: In, ctxt: Context[Out]): Directive = + override def onPush(elem: In, ctx: Context[Out]): Directive = pf.applyOrElse(elem, NotApplied) match { - case NotApplied ⇒ ctxt.pull() - case result: Out ⇒ ctxt.push(result) + case NotApplied ⇒ ctx.pull() + case result: Out ⇒ ctx.push(result) } } /** * INTERNAL API */ -private[akka] final case class MapConcat[In, Out](f: In ⇒ immutable.Seq[Out]) extends DeterministicOp[In, Out] { +private[akka] final case class MapConcat[In, Out](f: In ⇒ immutable.Seq[Out]) extends PushPullStage[In, Out] { private var currentIterator: Iterator[Out] = Iterator.empty - override def onPush(elem: In, ctxt: Context[Out]): Directive = { + override def onPush(elem: In, ctx: Context[Out]): Directive = { currentIterator = f(elem).iterator - if (currentIterator.isEmpty) ctxt.pull() - else ctxt.push(currentIterator.next()) + if (currentIterator.isEmpty) ctx.pull() + else ctx.push(currentIterator.next()) } - override def onPull(ctxt: Context[Out]): Directive = - if (currentIterator.hasNext) ctxt.push(currentIterator.next()) - else if (isFinishing) ctxt.finish() - else ctxt.pull() + override def onPull(ctx: Context[Out]): Directive = + if (currentIterator.hasNext) ctx.push(currentIterator.next()) + else if (ctx.isFinishing) ctx.finish() + else ctx.pull() - override def onUpstreamFinish(ctxt: Context[Out]): TerminationDirective = - ctxt.absorbTermination() + override def onUpstreamFinish(ctx: Context[Out]): TerminationDirective = + ctx.absorbTermination() } /** * INTERNAL API */ -private[akka] final case class Take[T](count: Int) extends TransitivePullOp[T, T] { +private[akka] final case class Take[T](count: Int) extends PushStage[T, T] { private var left: Int = count - override def onPush(elem: T, ctxt: Context[T]): Directive = { + override def onPush(elem: T, ctx: Context[T]): Directive = { left -= 1 - if (left > 0) ctxt.push(elem) - else if (left == 0) ctxt.pushAndFinish(elem) - else ctxt.finish() //Handle negative take counts + if (left > 0) ctx.push(elem) + else if (left == 0) ctx.pushAndFinish(elem) + else ctx.finish() //Handle negative take counts } } /** * INTERNAL API */ -private[akka] final case class Drop[T](count: Int) extends TransitivePullOp[T, T] { +private[akka] final case class Drop[T](count: Int) extends PushStage[T, T] { private var left: Int = count - override def onPush(elem: T, ctxt: Context[T]): Directive = + override def onPush(elem: T, ctx: Context[T]): Directive = if (left > 0) { left -= 1 - ctxt.pull() - } else ctxt.push(elem) + ctx.pull() + } else ctx.push(elem) } /** * INTERNAL API */ -private[akka] final case class Scan[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends DeterministicOp[In, Out] { +private[akka] final case class Scan[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends PushPullStage[In, Out] { private var aggregator = zero - override def onPush(elem: In, ctxt: Context[Out]): Directive = { + override def onPush(elem: In, ctx: Context[Out]): Directive = { val old = aggregator aggregator = f(old, elem) - ctxt.push(old) + ctx.push(old) } - override def onPull(ctxt: Context[Out]): Directive = - if (isFinishing) ctxt.pushAndFinish(aggregator) - else ctxt.pull() + override def onPull(ctx: Context[Out]): Directive = + if (ctx.isFinishing) ctx.pushAndFinish(aggregator) + else ctx.pull() - override def onUpstreamFinish(ctxt: Context[Out]): TerminationDirective = ctxt.absorbTermination() + override def onUpstreamFinish(ctx: Context[Out]): TerminationDirective = ctx.absorbTermination() } /** * INTERNAL API */ -private[akka] final case class Fold[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends DeterministicOp[In, Out] { +private[akka] final case class Fold[In, Out](zero: Out, f: (Out, In) ⇒ Out) extends PushPullStage[In, Out] { private var aggregator = zero - override def onPush(elem: In, ctxt: Context[Out]): Directive = { + override def onPush(elem: In, ctx: Context[Out]): Directive = { aggregator = f(aggregator, elem) - ctxt.pull() + ctx.pull() } - override def onPull(ctxt: Context[Out]): Directive = - if (isFinishing) ctxt.pushAndFinish(aggregator) - else ctxt.pull() + override def onPull(ctx: Context[Out]): Directive = + if (ctx.isFinishing) ctx.pushAndFinish(aggregator) + else ctx.pull() - override def onUpstreamFinish(ctxt: Context[Out]): TerminationDirective = ctxt.absorbTermination() + override def onUpstreamFinish(ctx: Context[Out]): TerminationDirective = ctx.absorbTermination() } /** * INTERNAL API */ -private[akka] final case class Grouped[T](n: Int) extends DeterministicOp[T, immutable.Seq[T]] { +private[akka] final case class Grouped[T](n: Int) extends PushPullStage[T, immutable.Seq[T]] { private val buf = { val b = Vector.newBuilder[T] b.sizeHint(n) @@ -135,83 +135,83 @@ private[akka] final case class Grouped[T](n: Int) extends DeterministicOp[T, imm } private var left = n - override def onPush(elem: T, ctxt: Context[immutable.Seq[T]]): Directive = { + override def onPush(elem: T, ctx: Context[immutable.Seq[T]]): Directive = { buf += elem left -= 1 if (left == 0) { val emit = buf.result() buf.clear() left = n - ctxt.push(emit) - } else ctxt.pull() + ctx.push(emit) + } else ctx.pull() } - override def onPull(ctxt: Context[immutable.Seq[T]]): Directive = - if (isFinishing) { + override def onPull(ctx: Context[immutable.Seq[T]]): Directive = + if (ctx.isFinishing) { val elem = buf.result() buf.clear() //FIXME null out the reference to the `buf`? left = n - ctxt.pushAndFinish(elem) - } else ctxt.pull() + ctx.pushAndFinish(elem) + } else ctx.pull() - override def onUpstreamFinish(ctxt: Context[immutable.Seq[T]]): TerminationDirective = - if (left == n) ctxt.finish() - else ctxt.absorbTermination() + override def onUpstreamFinish(ctx: Context[immutable.Seq[T]]): TerminationDirective = + if (left == n) ctx.finish() + else ctx.absorbTermination() } /** * INTERNAL API */ -private[akka] final case class Buffer[T](size: Int, overflowStrategy: OverflowStrategy) extends DetachedOp[T, T] { +private[akka] final case class Buffer[T](size: Int, overflowStrategy: OverflowStrategy) extends DetachedStage[T, T] { import OverflowStrategy._ private val buffer = FixedSizeBuffer(size) - override def onPush(elem: T, ctxt: DetachedContext[T]): UpstreamDirective = - if (isHolding) ctxt.pushAndPull(elem) - else enqueueAction(ctxt, elem) + override def onPush(elem: T, ctx: DetachedContext[T]): UpstreamDirective = + if (ctx.isHolding) ctx.pushAndPull(elem) + else enqueueAction(ctx, elem) - override def onPull(ctxt: DetachedContext[T]): DownstreamDirective = { - if (isFinishing) { + override def onPull(ctx: DetachedContext[T]): DownstreamDirective = { + if (ctx.isFinishing) { val elem = buffer.dequeue().asInstanceOf[T] - if (buffer.isEmpty) ctxt.pushAndFinish(elem) - else ctxt.push(elem) - } else if (isHolding) ctxt.pushAndPull(buffer.dequeue().asInstanceOf[T]) - else if (buffer.isEmpty) ctxt.hold() - else ctxt.push(buffer.dequeue().asInstanceOf[T]) + if (buffer.isEmpty) ctx.pushAndFinish(elem) + else ctx.push(elem) + } else if (ctx.isHolding) ctx.pushAndPull(buffer.dequeue().asInstanceOf[T]) + else if (buffer.isEmpty) ctx.hold() + else ctx.push(buffer.dequeue().asInstanceOf[T]) } - override def onUpstreamFinish(ctxt: DetachedContext[T]): TerminationDirective = - if (buffer.isEmpty) ctxt.finish() - else ctxt.absorbTermination() + override def onUpstreamFinish(ctx: DetachedContext[T]): TerminationDirective = + if (buffer.isEmpty) ctx.finish() + else ctx.absorbTermination() val enqueueAction: (DetachedContext[T], T) ⇒ UpstreamDirective = { overflowStrategy match { - case DropHead ⇒ { (ctxt, elem) ⇒ + case DropHead ⇒ { (ctx, elem) ⇒ if (buffer.isFull) buffer.dropHead() buffer.enqueue(elem) - ctxt.pull() + ctx.pull() } - case DropTail ⇒ { (ctxt, elem) ⇒ + case DropTail ⇒ { (ctx, elem) ⇒ if (buffer.isFull) buffer.dropTail() buffer.enqueue(elem) - ctxt.pull() + ctx.pull() } - case DropBuffer ⇒ { (ctxt, elem) ⇒ + case DropBuffer ⇒ { (ctx, elem) ⇒ if (buffer.isFull) buffer.clear() buffer.enqueue(elem) - ctxt.pull() + ctx.pull() } - case Backpressure ⇒ { (ctxt, elem) ⇒ + case Backpressure ⇒ { (ctx, elem) ⇒ buffer.enqueue(elem) - if (buffer.isFull) ctxt.hold() - else ctxt.pull() + if (buffer.isFull) ctx.hold() + else ctx.pull() } - case Error ⇒ { (ctxt, elem) ⇒ - if (buffer.isFull) ctxt.fail(new Error.BufferOverflowException(s"Buffer overflow (max capacity was: $size)!")) + case Error ⇒ { (ctx, elem) ⇒ + if (buffer.isFull) ctx.fail(new Error.BufferOverflowException(s"Buffer overflow (max capacity was: $size)!")) else { buffer.enqueue(elem) - ctxt.pull() + ctx.pull() } } } @@ -221,70 +221,70 @@ private[akka] final case class Buffer[T](size: Int, overflowStrategy: OverflowSt /** * INTERNAL API */ -private[akka] final case class Completed[T]() extends DeterministicOp[T, T] { - override def onPush(elem: T, ctxt: Context[T]): Directive = ctxt.finish() - override def onPull(ctxt: Context[T]): Directive = ctxt.finish() +private[akka] final case class Completed[T]() extends PushPullStage[T, T] { + override def onPush(elem: T, ctx: Context[T]): Directive = ctx.finish() + override def onPull(ctx: Context[T]): Directive = ctx.finish() } /** * INTERNAL API */ -private[akka] final case class Conflate[In, Out](seed: In ⇒ Out, aggregate: (Out, In) ⇒ Out) extends DetachedOp[In, Out] { +private[akka] final case class Conflate[In, Out](seed: In ⇒ Out, aggregate: (Out, In) ⇒ Out) extends DetachedStage[In, Out] { private var agg: Any = null - override def onPush(elem: In, ctxt: DetachedContext[Out]): UpstreamDirective = { + override def onPush(elem: In, ctx: DetachedContext[Out]): UpstreamDirective = { agg = if (agg == null) seed(elem) else aggregate(agg.asInstanceOf[Out], elem) - if (!isHolding) ctxt.pull() + if (!ctx.isHolding) ctx.pull() else { val result = agg.asInstanceOf[Out] agg = null - ctxt.pushAndPull(result) + ctx.pushAndPull(result) } } - override def onPull(ctxt: DetachedContext[Out]): DownstreamDirective = { - if (isFinishing) { - if (agg == null) ctxt.finish() + override def onPull(ctx: DetachedContext[Out]): DownstreamDirective = { + if (ctx.isFinishing) { + if (agg == null) ctx.finish() else { val result = agg.asInstanceOf[Out] agg = null - ctxt.pushAndFinish(result) + ctx.pushAndFinish(result) } - } else if (agg == null) ctxt.hold() + } else if (agg == null) ctx.hold() else { val result = agg.asInstanceOf[Out] agg = null - ctxt.push(result) + ctx.push(result) } } - override def onUpstreamFinish(ctxt: DetachedContext[Out]): TerminationDirective = ctxt.absorbTermination() + override def onUpstreamFinish(ctx: DetachedContext[Out]): TerminationDirective = ctx.absorbTermination() } /** * INTERNAL API */ -private[akka] final case class Expand[In, Out, Seed](seed: In ⇒ Seed, extrapolate: Seed ⇒ (Out, Seed)) extends DetachedOp[In, Out] { +private[akka] final case class Expand[In, Out, Seed](seed: In ⇒ Seed, extrapolate: Seed ⇒ (Out, Seed)) extends DetachedStage[In, Out] { private var s: Any = null - override def onPush(elem: In, ctxt: DetachedContext[Out]): UpstreamDirective = { + override def onPush(elem: In, ctx: DetachedContext[Out]): UpstreamDirective = { s = seed(elem) - if (isHolding) { + if (ctx.isHolding) { val (emit, newS) = extrapolate(s.asInstanceOf[Seed]) s = newS - ctxt.pushAndPull(emit) - } else ctxt.hold() + ctx.pushAndPull(emit) + } else ctx.hold() } - override def onPull(ctxt: DetachedContext[Out]): DownstreamDirective = { - if (s == null) ctxt.hold() + override def onPull(ctx: DetachedContext[Out]): DownstreamDirective = { + if (s == null) ctx.hold() else { val (emit, newS) = extrapolate(s.asInstanceOf[Seed]) s = newS - if (isHolding) ctxt.pushAndPull(emit) - else ctxt.push(emit) + if (ctx.isHolding) ctx.pushAndPull(emit) + else ctx.push(emit) } } diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/FlexiMerge.scala b/akka-stream/src/main/scala/akka/stream/javadsl/FlexiMerge.scala index 63701dca70..b88efc9994 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/FlexiMerge.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/FlexiMerge.scala @@ -68,7 +68,7 @@ object FlexiMerge { class ReadAny(val inputs: JList[InputHandle]) extends ReadCondition /** - * Read condition for the [[MergeLogic#State]] that will be + * Read condition for the [[FlexiMerge#State]] that will be * fulfilled when there are elements for any of the given upstream * inputs, however it differs from [[ReadAny]] in the case that both * the `preferred` and at least one other `secondary` input have demand, @@ -80,7 +80,7 @@ object FlexiMerge { class ReadPreferred(val preferred: InputHandle, val secondaries: JList[InputHandle]) extends ReadCondition /** - * Read condition for the [[MergeLogic#State]] that will be + * Read condition for the [[FlexiMerge#State]] that will be * fulfilled when there are elements for *all* of the given upstream * inputs. * diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/FlexiRoute.scala b/akka-stream/src/main/scala/akka/stream/javadsl/FlexiRoute.scala index 2ecbc074ea..6c76210021 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/FlexiRoute.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/FlexiRoute.scala @@ -122,13 +122,13 @@ object FlexiRoute { * handle cancel from downstream output. * * The `onComplete` method is called the upstream input was completed successfully. - * It returns next behavior or [[#SameState]] to keep current behavior. + * It returns next behavior or [[#sameState]] to keep current behavior. * * The `onError` method is called when the upstream input was completed with failure. * It returns next behavior or [[#SameState]] to keep current behavior. * * The `onCancel` method is called when a downstream output cancels. - * It returns next behavior or [[#SameState]] to keep current behavior. + * It returns next behavior or [[#sameState]] to keep current behavior. */ abstract class CompletionHandling[In] { def onComplete(ctx: RouteLogicContext[In, Any]): Unit @@ -144,7 +144,7 @@ object FlexiRoute { * [[RouteLogicContext#emit]]. * * The `onInput` method is called when an `element` was read from upstream. - * The function returns next behavior or [[#SameState]] to keep current behavior. + * The function returns next behavior or [[#sameState]] to keep current behavior. */ abstract class State[In, Out](val condition: DemandCondition) { def onInput(ctx: RouteLogicContext[In, Out], preferredOutput: OutputHandle, element: In): State[In, _] diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala b/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala index 36cc22c090..7561746a51 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/Flow.scala @@ -4,13 +4,12 @@ package akka.stream.javadsl import akka.stream._ - import akka.japi.Util import akka.stream.scaladsl - import scala.annotation.unchecked.uncheckedVariance import scala.concurrent.Future import scala.concurrent.duration.FiniteDuration +import akka.stream.stage.Stage object Flow { @@ -282,46 +281,31 @@ class Flow[-In, +Out](delegate: scaladsl.Flow[In, Out]) { new Flow(delegate.buffer(size, overflowStrategy)) /** - * Generic transformation of a stream: for each element the [[akka.stream.Transformer#onNext]] - * function is invoked, expecting a (possibly empty) sequence of output elements - * to be produced. - * After handing off the elements produced from one input element to the downstream - * subscribers, the [[akka.stream.Transformer#isComplete]] predicate determines whether to end - * stream processing at this point; in that case the upstream subscription is - * canceled. Before signaling normal completion to the downstream subscribers, - * the [[akka.stream.Transformer#onTermination]] function is invoked to produce a (possibly empty) - * sequence of elements in response to the end-of-stream event. - * - * [[akka.stream.Transformer#onError]] is called when failure is signaled from upstream. - * - * After normal completion or error the [[akka.stream.Transformer#cleanup]] function is called. - * - * It is possible to keep state in the concrete [[akka.stream.Transformer]] instance with - * ordinary instance variables. The [[akka.stream.Transformer]] is executed by an actor and - * therefore you do not have to add any additional thread safety or memory - * visibility constructs to access the state from the callback methods. + * Generic transformation of a stream with a custom processing [[akka.stream.stage.Stage]]. + * This operator makes it possible to extend the `Flow` API when there is no specialized + * operator that performs the transformation. * * Note that you can use [[#timerTransform]] if you need support for scheduled events in the transformer. */ - def transform[U](name: String, mkTransformer: japi.Creator[Transformer[Out, U]]): javadsl.Flow[In, U] = - new Flow(delegate.transform(name, () ⇒ mkTransformer.create())) + def transform[U](name: String, mkStage: japi.Creator[Stage[Out, U]]): javadsl.Flow[In, U] = + new Flow(delegate.transform(name, () ⇒ mkStage.create())) /** * Transformation of a stream, with additional support for scheduled events. * - * For each element the [[akka.stream.Transformer#onNext]] + * For each element the [[akka.stream.TransformerLike#onNext]] * function is invoked, expecting a (possibly empty) sequence of output elements * to be produced. * After handing off the elements produced from one input element to the downstream - * subscribers, the [[akka.stream.Transformer#isComplete]] predicate determines whether to end + * subscribers, the [[akka.stream.TransformerLike#isComplete]] predicate determines whether to end * stream processing at this point; in that case the upstream subscription is * canceled. Before signaling normal completion to the downstream subscribers, - * the [[akka.stream.Transformer#onTermination]] function is invoked to produce a (possibly empty) + * the [[akka.stream.TransformerLike#onTermination]] function is invoked to produce a (possibly empty) * sequence of elements in response to the end-of-stream event. * - * [[akka.stream.Transformer#onError]] is called when failure is signaled from upstream. + * [[akka.stream.TransformerLike#onError]] is called when failure is signaled from upstream. * - * After normal completion or error the [[akka.stream.Transformer#cleanup]] function is called. + * After normal completion or error the [[akka.stream.TransformerLike#cleanup]] function is called. * * It is possible to keep state in the concrete [[akka.stream.Transformer]] instance with * ordinary instance variables. The [[akka.stream.Transformer]] is executed by an actor and @@ -330,8 +314,8 @@ class Flow[-In, +Out](delegate: scaladsl.Flow[In, Out]) { * * Note that you can use [[#transform]] if you just need to transform elements time plays no role in the transformation. */ - def timerTransform[U](name: String, mkTransformer: japi.Creator[TimerTransformer[Out, U]]): javadsl.Flow[In, U] = - new Flow(delegate.timerTransform(name, () ⇒ mkTransformer.create())) + def timerTransform[U](name: String, mkStage: japi.Creator[TimerTransformer[Out, U]]): javadsl.Flow[In, U] = + new Flow(delegate.timerTransform(name, () ⇒ mkStage.create())) /** * Takes up to `n` elements from the stream and returns a pair containing a strict sequence of the taken element diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala b/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala index b551faed02..8dadbb5c16 100644 --- a/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala +++ b/akka-stream/src/main/scala/akka/stream/javadsl/Source.scala @@ -4,7 +4,6 @@ package akka.stream.javadsl import java.util.concurrent.Callable - import akka.actor.ActorRef import akka.actor.Props import akka.japi.Util @@ -12,13 +11,13 @@ import akka.stream._ import akka.stream.scaladsl.PropsSource import org.reactivestreams.Publisher import org.reactivestreams.Subscriber - import scala.annotation.unchecked.uncheckedVariance import scala.collection.JavaConverters._ import scala.concurrent.Future import scala.concurrent.duration.FiniteDuration import scala.language.higherKinds import scala.language.implicitConversions +import akka.stream.stage.Stage /** Java API */ object Source { @@ -387,56 +386,41 @@ class Source[+Out](delegate: scaladsl.Source[Out]) { new Source(delegate.buffer(size, overflowStrategy)) /** - * Generic transformation of a stream: for each element the [[akka.stream.Transformer#onNext]] - * function is invoked, expecting a (possibly empty) sequence of output elements - * to be produced. - * After handing off the elements produced from one input element to the downstream - * subscribers, the [[akka.stream.Transformer#isComplete]] predicate determines whether to end - * stream processing at this point; in that case the upstream subscription is - * canceled. Before signaling normal completion to the downstream subscribers, - * the [[akka.stream.Transformer#onTermination]] function is invoked to produce a (possibly empty) - * sequence of elements in response to the end-of-stream event. - * - * [[akka.stream.Transformer#onError]] is called when failure is signaled from upstream. - * - * After normal completion or error the [[akka.stream.Transformer#cleanup]] function is called. - * - * It is possible to keep state in the concrete [[akka.stream.Transformer]] instance with - * ordinary instance variables. The [[akka.stream.Transformer]] is executed by an actor and - * therefore you do not have to add any additional thread safety or memory - * visibility constructs to access the state from the callback methods. + * Generic transformation of a stream with a custom processing [[akka.stream.stage.Stage]]. + * This operator makes it possible to extend the `Flow` API when there is no specialized + * operator that performs the transformation. * * Note that you can use [[#timerTransform]] if you need support for scheduled events in the transformer. */ - def transform[U](name: String, mkTransformer: japi.Creator[Transformer[Out, U]]): javadsl.Source[U] = - new Source(delegate.transform(name, () ⇒ mkTransformer.create())) + def transform[U](name: String, mkStage: japi.Creator[Stage[Out, U]]): javadsl.Source[U] = + new Source(delegate.transform(name, () ⇒ mkStage.create())) /** * Transformation of a stream, with additional support for scheduled events. * - * For each element the [[akka.stream.Transformer#onNext]] + * For each element the [[akka.stream.TransformerLike#onNext]] * function is invoked, expecting a (possibly empty) sequence of output elements * to be produced. * After handing off the elements produced from one input element to the downstream - * subscribers, the [[akka.stream.Transformer#isComplete]] predicate determines whether to end + * subscribers, the [[akka.stream.TransformerLike#isComplete]] predicate determines whether to end * stream processing at this point; in that case the upstream subscription is * canceled. Before signaling normal completion to the downstream subscribers, - * the [[akka.stream.Transformer#onTermination]] function is invoked to produce a (possibly empty) + * the [[akka.stream.TransformerLike#onTermination]] function is invoked to produce a (possibly empty) * sequence of elements in response to the end-of-stream event. * - * [[akka.stream.Transformer#onError]] is called when failure is signaled from upstream. + * [[akka.stream.TransformerLike#onError]] is called when failure is signaled from upstream. * - * After normal completion or error the [[akka.stream.Transformer#cleanup]] function is called. + * After normal completion or error the [[akka.stream.TransformerLike#cleanup]] function is called. * - * It is possible to keep state in the concrete [[akka.stream.Transformer]] instance with - * ordinary instance variables. The [[akka.stream.Transformer]] is executed by an actor and + * It is possible to keep state in the concrete [[akka.stream.TimerTransformer]] instance with + * ordinary instance variables. The [[akka.stream.TimerTransformer]] is executed by an actor and * therefore you do not have to add any additional thread safety or memory * visibility constructs to access the state from the callback methods. * * Note that you can use [[#transform]] if you just need to transform elements time plays no role in the transformation. */ - def timerTransform[U](name: String, mkTransformer: japi.Creator[TimerTransformer[Out, U]]): javadsl.Source[U] = - new Source(delegate.timerTransform(name, () ⇒ mkTransformer.create())) + def timerTransform[U](name: String, mkStage: japi.Creator[TimerTransformer[Out, U]]): javadsl.Source[U] = + new Source(delegate.timerTransform(name, () ⇒ mkStage.create())) /** * Takes up to `n` elements from the stream and returns a pair containing a strict sequence of the taken element diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/ActorFlowSink.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/ActorFlowSink.scala index ea795d847e..fdd33d4c62 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/ActorFlowSink.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/ActorFlowSink.scala @@ -5,15 +5,14 @@ package akka.stream.scaladsl import akka.actor.ActorRef import akka.actor.Props - import scala.collection.immutable import scala.annotation.unchecked.uncheckedVariance import scala.concurrent.{ Future, Promise } import scala.util.{ Failure, Success, Try } import org.reactivestreams.{ Publisher, Subscriber, Subscription } -import akka.stream.Transformer import akka.stream.impl.{ ActorBasedFlowMaterializer, ActorProcessorFactory, FanoutProcessorImpl, BlackholeSubscriber } import java.util.concurrent.atomic.AtomicReference +import akka.stream.stage._ sealed trait ActorFlowSink[-In] extends Sink[In] { @@ -167,15 +166,15 @@ object OnCompleteSink { final case class OnCompleteSink[In](callback: Try[Unit] ⇒ Unit) extends SimpleActorFlowSink[In] { override def attach(flowPublisher: Publisher[In], materializer: ActorBasedFlowMaterializer, flowName: String) = - Source(flowPublisher).transform("onCompleteSink", () ⇒ new Transformer[In, Unit] { - override def onNext(in: In) = Nil - override def onError(e: Throwable) = () - override def onTermination(e: Option[Throwable]) = { - e match { - case None ⇒ callback(OnCompleteSink.SuccessUnit) - case Some(e) ⇒ callback(Failure(e)) - } - Nil + Source(flowPublisher).transform("onCompleteSink", () ⇒ new PushStage[In, Unit] { + override def onPush(elem: In, ctx: Context[Unit]): Directive = ctx.pull() + override def onUpstreamFailure(cause: Throwable, ctx: Context[Unit]): TerminationDirective = { + callback(Failure(cause)) + ctx.fail(cause) + } + override def onUpstreamFinish(ctx: Context[Unit]): TerminationDirective = { + callback(OnCompleteSink.SuccessUnit) + ctx.finish() } }).to(BlackholeSink).run()(materializer.withNamePrefix(flowName)) } @@ -191,15 +190,18 @@ final case class ForeachSink[In](f: In ⇒ Unit) extends KeyedActorFlowSink[In] override def attach(flowPublisher: Publisher[In], materializer: ActorBasedFlowMaterializer, flowName: String) = { val promise = Promise[Unit]() - Source(flowPublisher).transform("foreach", () ⇒ new Transformer[In, Unit] { - override def onNext(in: In) = { f(in); Nil } - override def onError(cause: Throwable): Unit = () - override def onTermination(e: Option[Throwable]) = { - e match { - case None ⇒ promise.success(()) - case Some(e) ⇒ promise.failure(e) - } - Nil + Source(flowPublisher).transform("foreach", () ⇒ new PushStage[In, Unit] { + override def onPush(elem: In, ctx: Context[Unit]): Directive = { + f(elem) + ctx.pull() + } + override def onUpstreamFailure(cause: Throwable, ctx: Context[Unit]): TerminationDirective = { + promise.failure(cause) + ctx.fail(cause) + } + override def onUpstreamFinish(ctx: Context[Unit]): TerminationDirective = { + promise.success(()) + ctx.finish() } }).to(BlackholeSink).run()(materializer.withNamePrefix(flowName)) promise.future @@ -220,16 +222,22 @@ final case class FoldSink[U, In](zero: U)(f: (U, In) ⇒ U) extends KeyedActorFl override def attach(flowPublisher: Publisher[In], materializer: ActorBasedFlowMaterializer, flowName: String) = { val promise = Promise[U]() - Source(flowPublisher).transform("fold", () ⇒ new Transformer[In, U] { - var state: U = zero - override def onNext(in: In): immutable.Seq[U] = { state = f(state, in); Nil } - override def onError(cause: Throwable) = () - override def onTermination(e: Option[Throwable]) = { - e match { - case None ⇒ promise.success(state) - case Some(e) ⇒ promise.failure(e) - } - Nil + Source(flowPublisher).transform("fold", () ⇒ new PushStage[In, U] { + private var aggregator = zero + + override def onPush(elem: In, ctx: Context[U]): Directive = { + aggregator = f(aggregator, elem) + ctx.pull() + } + + override def onUpstreamFailure(cause: Throwable, ctx: Context[U]): TerminationDirective = { + promise.failure(cause) + ctx.fail(cause) + } + + override def onUpstreamFinish(ctx: Context[U]): TerminationDirective = { + promise.success(aggregator) + ctx.finish() } }).to(BlackholeSink).run()(materializer.withNamePrefix(flowName)) diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/FlexiMerge.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/FlexiMerge.scala index 6a7a07bc82..02193acef5 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/FlexiMerge.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/FlexiMerge.scala @@ -107,7 +107,7 @@ object FlexiMerge { /** * The possibly stateful logic that reads from input via the defined [[MergeLogic#State]] and - * handles completion and error via the defined [[FlexiMerge#CompletionHandling]]. + * handles completion and error via the defined [[MergeLogic#CompletionHandling]]. * * Concrete instance is supposed to be created by implementing [[FlexiMerge#createMergeLogic]]. */ diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala index 04f7ec3f3d..3b17c6eafd 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala @@ -4,7 +4,7 @@ package akka.stream.scaladsl import akka.stream.impl.Ast._ -import akka.stream.{ TimerTransformer, Transformer, OverflowStrategy } +import akka.stream.{ TimerTransformer, TransformerLike, OverflowStrategy } import akka.util.Collections.EmptyImmutableSeq import scala.collection.immutable import scala.concurrent.duration.{ Duration, FiniteDuration } @@ -12,6 +12,7 @@ import scala.concurrent.Future import scala.language.higherKinds import akka.stream.FlowMaterializer import akka.stream.FlattenStrategy +import akka.stream.stage._ /** * A `Flow` is a set of stream processing steps that has one open input and one open output. @@ -217,8 +218,8 @@ trait FlowOps[+Out] { timerTransform("dropWithin", () ⇒ new TimerTransformer[Out, Out] { scheduleOnce(DropWithinTimerKey, d) - var delegate: Transformer[Out, Out] = - new Transformer[Out, Out] { + var delegate: TransformerLike[Out, Out] = + new TransformerLike[Out, Out] { def onNext(in: Out) = Nil } @@ -253,7 +254,7 @@ trait FlowOps[+Out] { timerTransform("takeWithin", () ⇒ new TimerTransformer[Out, Out] { scheduleOnce(TakeWithinTimerKey, d) - var delegate: Transformer[Out, Out] = FlowOps.identityTransformer[Out] + var delegate: TransformerLike[Out, Out] = FlowOps.identityTransformer[Out] override def onNext(in: Out) = delegate.onNext(in) override def isComplete = delegate.isComplete @@ -305,30 +306,14 @@ trait FlowOps[+Out] { andThen(Buffer(size, overflowStrategy)) /** - * Generic transformation of a stream: for each element the [[akka.stream.Transformer#onNext]] - * function is invoked, expecting a (possibly empty) sequence of output elements - * to be produced. - * After handing off the elements produced from one input element to the downstream - * subscribers, the [[akka.stream.Transformer#isComplete]] predicate determines whether to end - * stream processing at this point; in that case the upstream subscription is - * canceled. Before signaling normal completion to the downstream subscribers, - * the [[akka.stream.Transformer#onTermination]] function is invoked to produce a (possibly empty) - * sequence of elements in response to the end-of-stream event. - * - * [[akka.stream.Transformer#onError]] is called when failure is signaled from upstream. - * - * After normal completion or error the [[akka.stream.Transformer#cleanup]] function is called. - * - * It is possible to keep state in the concrete [[akka.stream.Transformer]] instance with - * ordinary instance variables. The [[akka.stream.Transformer]] is executed by an actor and - * therefore you do not have to add any additional thread safety or memory - * visibility constructs to access the state from the callback methods. + * Generic transformation of a stream with a custom processing [[akka.stream.stage.Stage]]. + * This operator makes it possible to extend the `Flow` API when there is no specialized + * operator that performs the transformation. * * Note that you can use [[#timerTransform]] if you need support for scheduled events in the transformer. */ - def transform[T](name: String, mkTransformer: () ⇒ Transformer[Out, T]): Repr[T] = { - andThen(Transform(name, mkTransformer.asInstanceOf[() ⇒ Transformer[Any, Any]])) - } + def transform[T](name: String, mkStage: () ⇒ Stage[Out, T]): Repr[T] = + andThen(StageFactory(mkStage, name)) /** * Takes up to `n` elements from the stream and returns a pair containing a strict sequence of the taken element @@ -381,19 +366,19 @@ trait FlowOps[+Out] { /** * Transformation of a stream, with additional support for scheduled events. * - * For each element the [[akka.stream.Transformer#onNext]] + * For each element the [[akka.stream.TransformerLike#onNext]] * function is invoked, expecting a (possibly empty) sequence of output elements * to be produced. * After handing off the elements produced from one input element to the downstream - * subscribers, the [[akka.stream.Transformer#isComplete]] predicate determines whether to end + * subscribers, the [[akka.stream.TransformerLike#isComplete]] predicate determines whether to end * stream processing at this point; in that case the upstream subscription is * canceled. Before signaling normal completion to the downstream subscribers, - * the [[akka.stream.Transformer#onTermination]] function is invoked to produce a (possibly empty) + * the [[akka.stream.TransformerLike#onTermination]] function is invoked to produce a (possibly empty) * sequence of elements in response to the end-of-stream event. * - * [[akka.stream.Transformer#onError]] is called when failure is signaled from upstream. + * [[akka.stream.TransformerLike#onError]] is called when failure is signaled from upstream. * - * After normal completion or error the [[akka.stream.Transformer#cleanup]] function is called. + * After normal completion or error the [[akka.stream.TransformerLike#cleanup]] function is called. * * It is possible to keep state in the concrete [[akka.stream.Transformer]] instance with * ordinary instance variables. The [[akka.stream.Transformer]] is executed by an actor and @@ -402,8 +387,8 @@ trait FlowOps[+Out] { * * Note that you can use [[#transform]] if you just need to transform elements time plays no role in the transformation. */ - def timerTransform[U](name: String, mkTransformer: () ⇒ TimerTransformer[Out, U]): Repr[U] = - andThen(TimerTransform(name, mkTransformer.asInstanceOf[() ⇒ TimerTransformer[Any, Any]])) + def timerTransform[U](name: String, mkStage: () ⇒ TimerTransformer[Out, U]): Repr[U] = + andThen(TimerTransform(mkStage.asInstanceOf[() ⇒ TimerTransformer[Any, Any]], name)) /** INTERNAL API */ // Storing ops in reverse order @@ -418,16 +403,20 @@ private[stream] object FlowOps { private case object DropWithinTimerKey private case object GroupedWithinTimerKey - private[this] final case object CompletedTransformer extends Transformer[Any, Any] { + private[this] final case object CompletedTransformer extends TransformerLike[Any, Any] { override def onNext(elem: Any) = Nil override def isComplete = true } - private[this] final case object IdentityTransformer extends Transformer[Any, Any] { + private[this] final case object IdentityTransformer extends TransformerLike[Any, Any] { override def onNext(elem: Any) = List(elem) } - def completedTransformer[T]: Transformer[T, T] = CompletedTransformer.asInstanceOf[Transformer[T, T]] - def identityTransformer[T]: Transformer[T, T] = IdentityTransformer.asInstanceOf[Transformer[T, T]] + def completedTransformer[T]: TransformerLike[T, T] = CompletedTransformer.asInstanceOf[TransformerLike[T, T]] + def identityTransformer[T]: TransformerLike[T, T] = IdentityTransformer.asInstanceOf[TransformerLike[T, T]] + + def identityStage[T]: Stage[T, T] = new PushStage[T, T] { + override def onPush(elem: T, ctx: Context[T]): Directive = ctx.push(elem) + } } diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Pipe.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Pipe.scala index 700f116cf0..768d364f1f 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Pipe.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Pipe.scala @@ -8,7 +8,7 @@ import scala.annotation.unchecked.uncheckedVariance import scala.language.{ existentials, higherKinds } import akka.stream.FlowMaterializer -private[stream] object Pipe { +private[akka] object Pipe { private val emptyInstance = Pipe[Any, Any](ops = Nil) def empty[T]: Pipe[T, T] = emptyInstance.asInstanceOf[Pipe[T, T]] } @@ -16,7 +16,7 @@ private[stream] object Pipe { /** * Flow with one open input and one open output. */ -private[stream] final case class Pipe[-In, +Out](ops: List[AstNode]) extends Flow[In, Out] { +private[akka] final case class Pipe[-In, +Out](ops: List[AstNode]) extends Flow[In, Out] { override type Repr[+O] = Pipe[In @uncheckedVariance, O] override private[scaladsl] def andThen[U](op: AstNode): Repr[U] = this.copy(ops = op :: ops) // FIXME raw addition of AstNodes diff --git a/akka-stream/src/main/scala/akka/stream/stage/Stage.scala b/akka-stream/src/main/scala/akka/stream/stage/Stage.scala new file mode 100644 index 0000000000..a47d2a81b5 --- /dev/null +++ b/akka-stream/src/main/scala/akka/stream/stage/Stage.scala @@ -0,0 +1,434 @@ +/** + * Copyright (C) 2014 Typesafe Inc. + */ +package akka.stream.stage + +/** + * General interface for stream transformation. + * + * Custom `Stage` implementations are intended to be used with + * [[akka.stream.scaladsl.FlowOps#transform]] or + * [[akka.stream.javadsl.Flow#transform]] to extend the `Flow` API when there + * is no specialized operator that performs the transformation. + * + * Custom implementations are subclasses of [[PushPullStage]] or + * [[DetachedStage]]. Sometimes it is convenient to extend + * [[StatefulStage]] for support of become like behavior. + * + * It is possible to keep state in the concrete `Stage` instance with + * ordinary instance variables. The `Transformer` is executed by an actor and + * therefore you don not have to add any additional thread safety or memory + * visibility constructs to access the state from the callback methods. + * + * @see [[akka.stream.scaladsl.Flow#transform]] + * @see [[akka.stream.javadsl.Flow#transform]] + */ +sealed trait Stage[-In, Out] + +private[stream] abstract class AbstractStage[-In, Out, PushD <: Directive, PullD <: Directive, Ctx <: Context[Out]] extends Stage[In, Out] { + private[stream] var holding = false + private[stream] var allowedToPush = false + private[stream] var terminationPending = false + + /** + * `onPush` is called when an element from upstream is available and there is demand from downstream, i.e. + * in `onPush` you are allowed to call [[akka.stream.stage.Context#push]] to emit one element downstreams, + * or you can absorb the element by calling [[akka.stream.stage.Context#pull]]. Note that you can only + * emit zero or one element downstream from `onPull`. + * + * To emit more than one element you have to push the remaining elements from [[#onPull]], one-by-one. + * `onPush` is not called again until `onPull` has requested more elements with + * [[akka.stream.stage.Context#pull]]. + */ + def onPush(elem: In, ctx: Ctx): PushD + + /** + * `onPull` is called when there is demand from downstream, i.e. you are allowed to push one element + * downstreams with [[akka.stream.stage.Context#push]], or request elements from upstreams with + * [[akka.stream.stage.Context#pull]] + */ + def onPull(ctx: Ctx): PullD + + /** + * `onUpstreamFinish` is called when upstream has signaled that the stream is + * successfully completed. Here you cannot call [[akka.stream.stage.Context#push]], + * because there might not be any demand from downstream. To emit additional elements before + * terminating you can use [[akka.stream.stage.Context#absorbTermination]] and push final elements + * from [[#onPull]]. The stage will then be in finishing state, which can be checked + * with [[akka.stream.stage.Context#isFinishing]]. + * + * By default the finish signal is immediately propagated with [[akka.stream.stage.Context#finish]]. + */ + def onUpstreamFinish(ctx: Ctx): TerminationDirective = ctx.finish() + + /** + * `onDownstreamFinish` is called when downstream has cancelled. + * + * By default the cancel signal is immediately propagated with [[akka.stream.stage.Context#finish]]. + */ + def onDownstreamFinish(ctx: Ctx): TerminationDirective = ctx.finish() + + /** + * `onUpstreamFailure` is called when upstream has signaled that the stream is completed + * with error. It is not called if [[#onPull]] or [[#onPush]] of the stage itself + * throws an exception. + * + * Note that elements that were emitted by upstream before the error happened might + * not have been received by this stage when `onUpstreamFailure` is called, i.e. + * errors are not backpressured and might be propagated as soon as possible. + * + * Here you cannot call [[akka.stream.stage.Context#push]], because there might not + * be any demand from downstream. To emit additional elements before terminating you + * can use [[akka.stream.stage.Context#absorbTermination]] and push final elements + * from [[#onPull]]. The stage will then be in finishing state, which can be checked + * with [[akka.stream.stage.Context#isFinishing]]. + */ + def onUpstreamFailure(cause: Throwable, ctx: Ctx): TerminationDirective = ctx.fail(cause) + +} + +/** + * `PushPullStage` implementations participate in 1-bounded regions. For every external non-completion signal these + * stages produce *exactly one* push or pull signal. + * + * [[#onPush]] is called when an element from upstream is available and there is demand from downstream, i.e. + * in `onPush` you are allowed to call [[Context#push]] to emit one element downstreams, or you can absorb the + * element by calling [[Context#pull]]. Note that you can only emit zero or one element downstream from `onPull`. + * To emit more than one element you have to push the remaining elements from [[#onPull]], one-by-one. + * `onPush` is not called again until `onPull` has requested more elements with [[Context#pull]]. + * + * [[StatefulStage]] has support for making it easy to emit more than one element from `onPush`. + * + * [[#onPull]] is called when there is demand from downstream, i.e. you are allowed to push one element + * downstreams with [[Context#push]], or request elements from upstreams with [[Context#pull]]. If you + * always perform transitive pull by calling `ctx.pull` from `onPull` you can use [[PushStage]] instead of + * `PushPullStage`. + * + * Stages are allowed to do early completion of downstream and cancel of upstream. This is done with [[Context#finish]], + * which is a combination of cancel/complete. + * + * Since onComplete is not a backpressured signal it is sometimes preferable to push a final element and then + * immediately finish. This combination is exposed as [[Context#pushAndFinish]] which enables stages to + * propagate completion events without waiting for an extra round of pull. + * + * Another peculiarity is how to convert termination events (complete/failure) into elements. The problem + * here is that the termination events are not backpressured while elements are. This means that simply calling + * [[Context#push]] as a response to [[#onUpstreamFinish]] or [[#onUpstreamFailure]] will very likely break boundedness + * and result in a buffer overflow somewhere. Therefore the only allowed command in this case is + * [[Context#absorbTermination]] which stops the propagation of the termination signal, and puts the stage in a + * [[akka.stream.stage.Context#isFinishing]] state. Depending on whether the stage has a pending pull signal it + * has not yet "consumed" by a push its [[#onPull]] handler might be called immediately or later. From + * [[#onPull]] final elements can be pushed before completing downstream with [[Context#finish]] or + * [[Context#pushAndFinish]]. + * + * [[StatefulStage]] has support for making it easy to emit final elements. + * + * All these rules are enforced by types and runtime checks where needed. Always return the `Directive` + * from the call to the [[Context]] method, and do only call [[Context]] commands once per callback. + * + * @see [[DetachedStage]] + * @see [[StatefulStage]] + * @see [[PushStage]] + */ +abstract class PushPullStage[In, Out] extends AbstractStage[In, Out, Directive, Directive, Context[Out]] + +/** + * `PushStage` is a [[PushPullStage]] that always perform transitive pull by calling `ctx.pull` from `onPull`. + */ +abstract class PushStage[In, Out] extends PushPullStage[In, Out] { + /** + * Always pulls from upstream. + */ + final override def onPull(ctx: Context[Out]): Directive = ctx.pull() +} + +/** + * `DetachedStage` can be used to implement operations similar to [[akka.stream.scaladsl.FlowOps#buffer buffer]], + * [[akka.stream.scaladsl.FlowOps#expand expand]] and [[akka.stream.scaladsl.FlowOps#conflate conflate]]. + * + * `DetachedStage` implementations are boundaries between 1-bounded regions. This means that they need to enforce the + * "exactly one" property both on their upstream and downstream regions. As a consequence a `DetachedStage` can never + * answer an [[#onPull]] with a [[Context#pull]] or answer an [[#onPush]] with a [[Context#push]] since such an action + * would "steal" the event from one region (resulting in zero signals) and would inject it to the other region + * (resulting in two signals). + * + * However, DetachedStages have the ability to call [[akka.stream.stage.DetachedContext#hold]] as a response to + * [[#onPush]] and [[akka.stream.stage.DetachedContext##onPull]] which temporarily takes the signal off and + * stops execution, at the same time putting the stage in an [[akka.stream.stage.DetachedContext#isHolding]] state. + * If the stage is in a holding state it contains one absorbed signal, therefore in this state the only possible + * command to call is [[akka.stream.stage.DetachedContext#pushAndPull]] which results in two events making the + * balance right again: 1 hold + 1 external event = 2 external event + * + * This mechanism allows synchronization between the upstream and downstream regions which otherwise can progress + * independently. + * + * @see [[PushPullStage]] + */ +abstract class DetachedStage[In, Out] extends AbstractStage[In, Out, UpstreamDirective, DownstreamDirective, DetachedContext[Out]] + +/** + * The behavior of [[StatefulStage]] is defined by these two methods, which + * has the same sematics as corresponding methods in [[PushPullStage]]. + */ +abstract class StageState[In, Out] { + def onPush(elem: In, ctx: Context[Out]): Directive + def onPull(ctx: Context[Out]): Directive = ctx.pull() +} + +/** + * INTERNAL API + */ +private[akka] object StatefulStage { + sealed trait AndThen + case object Finish extends AndThen + final case class Become(state: StageState[Any, Any]) extends AndThen + case object Stay extends AndThen +} + +/** + * `StatefulStage` is a [[PushPullStage]] that provides convenience to make some things easier. + * + * The behavior is defined in [[StageState]] instances. The initial behavior is specified + * by subclass implementing the [[#initial]] method. The behavior can be changed by using [[#become]]. + * + * Use [[#emit]] or [[#emitAndFinish]] to push more than one element from [[StageState#onPush]] or + * [[StageState#onPull]]. + * + * Use [[#terminationEmit]] to push final elements from [[#onUpstreamFinish]] or [[#onUpstreamFailure]]. + */ +abstract class StatefulStage[In, Out] extends PushPullStage[In, Out] { + import StatefulStage._ + + /** + * Scala API + */ + abstract class State extends StageState[In, Out] + + private var emitting = false + private var _current: StageState[In, Out] = _ + become(initial) + + /** + * Concrete subclass must return the initial behavior from this method. + */ + def initial: StageState[In, Out] + + /** + * Current state. + */ + final def current: StageState[In, Out] = _current + + /** + * Change the behavior to another [[StageState]]. + */ + final def become(state: StageState[In, Out]): Unit = { + require(state ne null, "New state must not be null") + _current = state + } + + /** + * Invokes current state. + */ + final override def onPush(elem: In, ctx: Context[Out]): Directive = _current.onPush(elem, ctx) + /** + * Invokes current state. + */ + final override def onPull(ctx: Context[Out]): Directive = _current.onPull(ctx) + + override def onUpstreamFinish(ctx: Context[Out]): TerminationDirective = + if (emitting) ctx.absorbTermination() + else ctx.finish() + + /** + * Scala API: Can be used from [[StageState#onPush]] or [[StageState#onPull]] to push more than one + * element downstreams. + */ + final def emit(iter: Iterator[Out], ctx: Context[Out]): Directive = emit(iter, ctx, _current) + + /** + * Java API: Can be used from [[StageState#onPush]] or [[StageState#onPull]] to push more than one + * element downstreams. + */ + final def emit(iter: java.util.Iterator[Out], ctx: Context[Out]): Directive = { + import scala.collection.JavaConverters._ + emit(iter.asScala, ctx) + } + + /** + * Scala API: Can be used from [[StageState#onPush]] or [[StageState#onPull]] to push more than one + * element downstreams and after that change behavior. + */ + final def emit(iter: Iterator[Out], ctx: Context[Out], nextState: StageState[In, Out]): Directive = { + if (emitting) throw new IllegalStateException("already in emitting state") + if (iter.isEmpty) { + become(nextState) + ctx.pull() + } else { + val elem = iter.next() + if (iter.hasNext) { + emitting = true + become(emittingState(iter, andThen = Become(nextState.asInstanceOf[StageState[Any, Any]]))) + } + ctx.push(elem) + } + } + + /** + * Java API: Can be used from [[StageState#onPush]] or [[StageState#onPull]] to push more than one + * element downstreams and after that change behavior. + */ + final def emit(iter: java.util.Iterator[Out], ctx: Context[Out], nextState: StageState[In, Out]): Directive = { + import scala.collection.JavaConverters._ + emit(iter.asScala, ctx, nextState) + } + + /** + * Scala API: Can be used from [[StageState#onPush]] or [[StageState#onPull]] to push more than one + * element downstreams and after that finish (complete downstreams, cancel upstreams). + */ + final def emitAndFinish(iter: Iterator[Out], ctx: Context[Out]): Directive = { + if (emitting) throw new IllegalStateException("already in emitting state") + if (iter.isEmpty) + ctx.finish() + else { + val elem = iter.next() + if (iter.hasNext) { + emitting = true + become(emittingState(iter, andThen = Finish)) + ctx.push(elem) + } else + ctx.pushAndFinish(elem) + } + } + + /** + * Java API: Can be used from [[StageState#onPush]] or [[StageState#onPull]] to push more than one + * element downstreams and after that finish (complete downstreams, cancel upstreams). + */ + final def emitAndFinish(iter: java.util.Iterator[Out], ctx: Context[Out]): Directive = { + import scala.collection.JavaConverters._ + emitAndFinish(iter.asScala, ctx) + } + + /** + * Scala API: Can be used from [[#onUpstreamFinish]] to push final elements downstreams + * before completing the stream successfully. Note that if this is used from + * [[#onUpstreamFailure]] the error will be absorbed and the stream will be completed + * successfully. + */ + final def terminationEmit(iter: Iterator[Out], ctx: Context[Out]): TerminationDirective = { + val empty = iter.isEmpty + if (empty && emitting) ctx.absorbTermination() + else if (empty) ctx.finish() + else { + become(emittingState(iter, andThen = Finish)) + ctx.absorbTermination() + } + } + + /** + * Java API: Can be used from [[#onUpstreamFinish]] or [[#onUpstreamFailure]] to push final + * elements downstreams. + */ + final def terminationEmit(iter: java.util.Iterator[Out], ctx: Context[Out]): TerminationDirective = { + import scala.collection.JavaConverters._ + terminationEmit(iter.asScala, ctx) + } + + private def emittingState(iter: Iterator[Out], andThen: AndThen) = new State { + override def onPush(elem: In, ctx: Context[Out]) = throw new IllegalStateException("onPush not allowed in emittingState") + override def onPull(ctx: Context[Out]) = { + if (iter.hasNext) { + val elem = iter.next() + if (iter.hasNext) + ctx.push(elem) + else if (!ctx.isFinishing) { + emitting = false + andThen match { + case Stay ⇒ // ok + case Become(newState) ⇒ become(newState.asInstanceOf[StageState[In, Out]]) + case Finish ⇒ ctx.pushAndFinish(elem) + } + ctx.push(elem) + } else + ctx.pushAndFinish(elem) + } else + throw new IllegalStateException("onPull with empty iterator is not expected in emittingState") + } + } + +} + +/** + * Return type from [[Context]] methods. + */ +sealed trait Directive +sealed trait UpstreamDirective extends Directive +sealed trait DownstreamDirective extends Directive +sealed trait TerminationDirective extends Directive +final class FreeDirective extends UpstreamDirective with DownstreamDirective with TerminationDirective + +/** + * Passed to the callback methods of [[PushPullStage]] and [[StatefulStage]]. + */ +sealed trait Context[Out] { + /** + * Push one element to downstreams. + */ + def push(elem: Out): DownstreamDirective + /** + * Request for more elements from upstreams. + */ + def pull(): UpstreamDirective + /** + * Cancel upstreams and complete downstreams successfully. + */ + def finish(): FreeDirective + /** + * Push one element to downstream immediately followed by + * cancel of upstreams and complete of downstreams. + */ + def pushAndFinish(elem: Out): DownstreamDirective + /** + * Cancel upstreams and complete downstreams with failure. + */ + def fail(cause: Throwable): FreeDirective + /** + * Puts the stage in a finishing state so that + * final elements can be pushed from `onPull`. + */ + def absorbTermination(): TerminationDirective + + /** + * This returns `true` after [[#absorbTermination]] has been used. + */ + def isFinishing: Boolean +} + +/** + * Passed to the callback methods of [[DetachedStage]]. + * + * [[#hold]] stops execution and at the same time putting the stage in a holding state. + * If the stage is in a holding state it contains one absorbed signal, therefore in + * this state the only possible command to call is [[#pushAndPull]] which results in two + * events making the balance right again: 1 hold + 1 external event = 2 external event + */ +trait DetachedContext[Out] extends Context[Out] { + def hold(): FreeDirective + + /** + * This returns `true` when [[#hold]] has been used + * and it is reset to `false` after [[#pushAndPull]]. + */ + def isHolding: Boolean + + def pushAndPull(elem: Out): FreeDirective + +} + +/** + * INTERNAL API + */ +private[akka] trait BoundaryContext extends Context[Any] { + def exit(): FreeDirective +}