diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/RenderSupport.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/RenderSupport.scala index 23895ada3a..ca4a3ef0a2 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/RenderSupport.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/RenderSupport.scala @@ -15,6 +15,9 @@ import akka.http.scaladsl.model._ import akka.http.impl.util._ import akka.http.scaladsl.model.HttpEntity.ChunkStreamPart +import akka.stream.stage.{ Context, GraphStage, SyncDirective, TerminationDirective } +import akka.stream._ +import akka.stream.scaladsl.{ Sink, Source, Flow, Keep } /** * INTERNAL API */ @@ -53,19 +56,31 @@ private object RenderSupport { } object ChunkTransformer { - val flow = Flow[ChunkStreamPart].transform(() ⇒ new ChunkTransformer).named("renderChunks") + val flow = Flow.fromGraph(new ChunkTransformer).named("renderChunks") } - class ChunkTransformer extends StatefulStage[HttpEntity.ChunkStreamPart, ByteString] { - override def initial = new State { - override def onPush(chunk: HttpEntity.ChunkStreamPart, ctx: Context[ByteString]): SyncDirective = { - val bytes = renderChunk(chunk) - if (chunk.isLastChunk) ctx.pushAndFinish(bytes) - else ctx.push(bytes) + class ChunkTransformer extends GraphStage[FlowShape[HttpEntity.ChunkStreamPart, ByteString]] { + val out: Outlet[ByteString] = Outlet("ChunkTransformer.out") + val in: Inlet[HttpEntity.ChunkStreamPart] = Inlet("ChunkTransformer.in") + val shape: FlowShape[HttpEntity.ChunkStreamPart, ByteString] = FlowShape.of(in, out) + + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) with InHandler with OutHandler { + override def onPush(): Unit = { + val chunk = grab(in) + val bytes = renderChunk(chunk) + push(out, bytes) + if (chunk.isLastChunk) completeStage() + } + + override def onPull(): Unit = pull(in) + + override def onUpstreamFinish(): Unit = { + emit(out, defaultLastChunkBytes) + completeStage() + } + setHandlers(in, out, this) } - } - override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = - terminationEmit(Iterator.single(defaultLastChunkBytes), ctx) } object CheckContentLengthTransformer {