=htc Convert ChunkTransformer to GraphStage (#20493)

* Convert ChunkTransformer to GraphStage

* Updated onUpstreamFinish of ChunkTransformer class

* Convert chunkTransformer to graph stage

* Convert ChunkTransformer to GraphStage

* Convert ChunkTranformer to GraphStage
This commit is contained in:
poojadshende 2016-07-07 14:26:52 -07:00 committed by Konrad Malawski
parent de18e3fe09
commit 2769b5e1cb

View file

@ -15,6 +15,9 @@ import akka.http.scaladsl.model._
import akka.http.impl.util._ import akka.http.impl.util._
import akka.http.scaladsl.model.HttpEntity.ChunkStreamPart import akka.http.scaladsl.model.HttpEntity.ChunkStreamPart
import akka.stream.stage.{ Context, GraphStage, SyncDirective, TerminationDirective }
import akka.stream._
import akka.stream.scaladsl.{ Sink, Source, Flow, Keep }
/** /**
* INTERNAL API * INTERNAL API
*/ */
@ -53,19 +56,31 @@ private object RenderSupport {
} }
object ChunkTransformer { object ChunkTransformer {
val flow = Flow[ChunkStreamPart].transform(() new ChunkTransformer).named("renderChunks") val flow = Flow.fromGraph(new ChunkTransformer).named("renderChunks")
} }
class ChunkTransformer extends StatefulStage[HttpEntity.ChunkStreamPart, ByteString] { class ChunkTransformer extends GraphStage[FlowShape[HttpEntity.ChunkStreamPart, ByteString]] {
override def initial = new State { val out: Outlet[ByteString] = Outlet("ChunkTransformer.out")
override def onPush(chunk: HttpEntity.ChunkStreamPart, ctx: Context[ByteString]): SyncDirective = { val in: Inlet[HttpEntity.ChunkStreamPart] = Inlet("ChunkTransformer.in")
val bytes = renderChunk(chunk) val shape: FlowShape[HttpEntity.ChunkStreamPart, ByteString] = FlowShape.of(in, out)
if (chunk.isLastChunk) ctx.pushAndFinish(bytes)
else ctx.push(bytes) override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
new GraphStageLogic(shape) with InHandler with OutHandler {
override def onPush(): Unit = {
val chunk = grab(in)
val bytes = renderChunk(chunk)
push(out, bytes)
if (chunk.isLastChunk) completeStage()
}
override def onPull(): Unit = pull(in)
override def onUpstreamFinish(): Unit = {
emit(out, defaultLastChunkBytes)
completeStage()
}
setHandlers(in, out, this)
} }
}
override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective =
terminationEmit(Iterator.single(defaultLastChunkBytes), ctx)
} }
object CheckContentLengthTransformer { object CheckContentLengthTransformer {