=str #19361 migrating ByteStringParserStage to graph stage
This commit is contained in:
parent
3fc332d2c9
commit
07c0da36f2
14 changed files with 184 additions and 267 deletions
|
|
@ -46,7 +46,7 @@ private[http] object FrameEventParser extends ByteStringParser[FrameEvent] {
|
||||||
}
|
}
|
||||||
|
|
||||||
object ReadFrameHeader extends Step {
|
object ReadFrameHeader extends Step {
|
||||||
override def parse(reader: ByteReader): (FrameEvent, Step) = {
|
override def parse(reader: ByteReader): ParseResult[FrameEvent] = {
|
||||||
import Protocol._
|
import Protocol._
|
||||||
|
|
||||||
val flagsAndOp = reader.readByte()
|
val flagsAndOp = reader.readByte()
|
||||||
|
|
@ -83,23 +83,25 @@ private[http] object FrameEventParser extends ByteStringParser[FrameEvent] {
|
||||||
|
|
||||||
val takeNow = (header.length min reader.remainingSize).toInt
|
val takeNow = (header.length min reader.remainingSize).toInt
|
||||||
val thisFrameData = reader.take(takeNow)
|
val thisFrameData = reader.take(takeNow)
|
||||||
|
val noMoreData = thisFrameData.length == length
|
||||||
|
|
||||||
val nextState =
|
val nextState =
|
||||||
if (thisFrameData.length == length) ReadFrameHeader
|
if (noMoreData) ReadFrameHeader
|
||||||
else new ReadData(length - thisFrameData.length)
|
else new ReadData(length - thisFrameData.length)
|
||||||
|
|
||||||
(FrameStart(header, thisFrameData.compact), nextState)
|
ParseResult(Some(FrameStart(header, thisFrameData.compact)), nextState, true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class ReadData(_remaining: Long) extends Step {
|
class ReadData(_remaining: Long) extends Step {
|
||||||
|
override def canWorkWithPartialData = true
|
||||||
var remaining = _remaining
|
var remaining = _remaining
|
||||||
override def parse(reader: ByteReader): (FrameEvent, Step) =
|
override def parse(reader: ByteReader): ParseResult[FrameEvent] =
|
||||||
if (reader.remainingSize < remaining) {
|
if (reader.remainingSize < remaining) {
|
||||||
remaining -= reader.remainingSize
|
remaining -= reader.remainingSize
|
||||||
(FrameData(reader.takeAll(), lastPart = false), this)
|
ParseResult(Some(FrameData(reader.takeAll(), lastPart = false)), this, true)
|
||||||
} else {
|
} else {
|
||||||
(FrameData(reader.take(remaining.toInt), lastPart = true), ReadFrameHeader)
|
ParseResult(Some(FrameData(reader.take(remaining.toInt), lastPart = true)), ReadFrameHeader, true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,51 +0,0 @@
|
||||||
/*
|
|
||||||
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
|
|
||||||
*/
|
|
||||||
|
|
||||||
package akka.http.impl.util
|
|
||||||
|
|
||||||
import scala.util.control.NoStackTrace
|
|
||||||
import akka.util.ByteString
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A helper class to read from a ByteString statefully.
|
|
||||||
*
|
|
||||||
* INTERNAL API
|
|
||||||
*/
|
|
||||||
private[akka] class ByteReader(input: ByteString) {
|
|
||||||
import ByteReader.NeedMoreData
|
|
||||||
|
|
||||||
private[this] var off = 0
|
|
||||||
|
|
||||||
def hasRemaining: Boolean = off < input.size
|
|
||||||
|
|
||||||
def currentOffset: Int = off
|
|
||||||
def remainingData: ByteString = input.drop(off)
|
|
||||||
def fromStartToHere: ByteString = input.take(off)
|
|
||||||
|
|
||||||
def readByte(): Int =
|
|
||||||
if (off < input.length) {
|
|
||||||
val x = input(off)
|
|
||||||
off += 1
|
|
||||||
x & 0xFF
|
|
||||||
} else throw NeedMoreData
|
|
||||||
def readShortLE(): Int = readByte() | (readByte() << 8)
|
|
||||||
def readIntLE(): Int = readShortLE() | (readShortLE() << 16)
|
|
||||||
def readLongLE(): Long = (readIntLE() & 0xffffffffL) | ((readIntLE() & 0xffffffffL) << 32)
|
|
||||||
|
|
||||||
def readShortBE(): Int = (readByte() << 8) | readByte()
|
|
||||||
def readIntBE(): Int = (readShortBE() << 16) | readShortBE()
|
|
||||||
def readLongBE(): Long = ((readIntBE() & 0xffffffffL) << 32) | (readIntBE() & 0xffffffffL)
|
|
||||||
|
|
||||||
def skip(numBytes: Int): Unit =
|
|
||||||
if (off + numBytes <= input.length) off += numBytes
|
|
||||||
else throw NeedMoreData
|
|
||||||
def skipZeroTerminatedString(): Unit = while (readByte() != 0) {}
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
* INTERNAL API
|
|
||||||
*/
|
|
||||||
private[akka] object ByteReader {
|
|
||||||
val NeedMoreData = new Exception with NoStackTrace
|
|
||||||
}
|
|
||||||
|
|
@ -1,59 +0,0 @@
|
||||||
/*
|
|
||||||
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
|
|
||||||
*/
|
|
||||||
|
|
||||||
package akka.http.impl.util
|
|
||||||
|
|
||||||
import akka.stream.stage.{ Context, StatefulStage }
|
|
||||||
import akka.util.ByteString
|
|
||||||
import akka.stream.stage.SyncDirective
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A helper class for writing parsers from ByteStrings.
|
|
||||||
*
|
|
||||||
* FIXME: move to akka.stream.io, https://github.com/akka/akka/issues/16529
|
|
||||||
*
|
|
||||||
* INTERNAL API
|
|
||||||
*/
|
|
||||||
private[akka] abstract class ByteStringParserStage[Out] extends StatefulStage[ByteString, Out] {
|
|
||||||
protected def onTruncation(ctx: Context[Out]): SyncDirective
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Derive a stage from [[IntermediateState]] and then call `pull(ctx)` instead of
|
|
||||||
* `ctx.pull()` to have truncation errors reported.
|
|
||||||
*/
|
|
||||||
abstract class IntermediateState extends State {
|
|
||||||
override def onPull(ctx: Context[Out]): SyncDirective = pull(ctx)
|
|
||||||
def pull(ctx: Context[Out]): SyncDirective =
|
|
||||||
if (ctx.isFinishing) onTruncation(ctx)
|
|
||||||
else ctx.pull()
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A stage that tries to read from a side-effecting [[ByteReader]]. If a buffer underrun
|
|
||||||
* occurs the previous data is saved and the reading process is restarted from the beginning
|
|
||||||
* once more data was received.
|
|
||||||
*
|
|
||||||
* As [[read]] may be called several times for the same prefix of data, make sure not to
|
|
||||||
* manipulate any state during reading from the ByteReader.
|
|
||||||
*/
|
|
||||||
private[akka] trait ByteReadingState extends IntermediateState {
|
|
||||||
def read(reader: ByteReader, ctx: Context[Out]): SyncDirective
|
|
||||||
|
|
||||||
def onPush(data: ByteString, ctx: Context[Out]): SyncDirective =
|
|
||||||
try {
|
|
||||||
val reader = new ByteReader(data)
|
|
||||||
read(reader, ctx)
|
|
||||||
} catch {
|
|
||||||
case ByteReader.NeedMoreData ⇒
|
|
||||||
become(TryAgain(data, this))
|
|
||||||
pull(ctx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
private case class TryAgain(previousData: ByteString, byteReadingState: ByteReadingState) extends IntermediateState {
|
|
||||||
def onPush(data: ByteString, ctx: Context[Out]): SyncDirective = {
|
|
||||||
become(byteReadingState)
|
|
||||||
byteReadingState.onPush(previousData ++ data, ctx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -10,6 +10,7 @@ import akka.NotUsed
|
||||||
import akka.http.scaladsl.model.RequestEntity
|
import akka.http.scaladsl.model.RequestEntity
|
||||||
import akka.stream._
|
import akka.stream._
|
||||||
import akka.stream.impl.StreamLayout.Module
|
import akka.stream.impl.StreamLayout.Module
|
||||||
|
import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage
|
||||||
import akka.stream.impl.{ PublisherSink, SinkModule, SourceModule }
|
import akka.stream.impl.{ PublisherSink, SinkModule, SourceModule }
|
||||||
import akka.stream.scaladsl._
|
import akka.stream.scaladsl._
|
||||||
import akka.stream.stage._
|
import akka.stream.stage._
|
||||||
|
|
@ -114,40 +115,46 @@ private[http] object StreamUtils {
|
||||||
Flow[ByteString].transform(() ⇒ transformer).named("sliceBytes")
|
Flow[ByteString].transform(() ⇒ transformer).named("sliceBytes")
|
||||||
}
|
}
|
||||||
|
|
||||||
def limitByteChunksStage(maxBytesPerChunk: Int): PushPullStage[ByteString, ByteString] =
|
def limitByteChunksStage(maxBytesPerChunk: Int): GraphStage[FlowShape[ByteString, ByteString]] =
|
||||||
new StatefulStage[ByteString, ByteString] {
|
new SimpleLinearGraphStage[ByteString] {
|
||||||
def initial = WaitingForData
|
override def initialAttributes = Attributes.name("limitByteChunksStage")
|
||||||
|
var remaining = ByteString.empty
|
||||||
|
|
||||||
case object WaitingForData extends State {
|
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) {
|
||||||
def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective =
|
|
||||||
if (elem.size <= maxBytesPerChunk) ctx.push(elem)
|
|
||||||
else {
|
|
||||||
become(DeliveringData(elem.drop(maxBytesPerChunk)))
|
|
||||||
ctx.push(elem.take(maxBytesPerChunk))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
case class DeliveringData(remaining: ByteString) extends State {
|
def splitAndPush(elem: ByteString): Unit = {
|
||||||
def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective =
|
|
||||||
throw new IllegalStateException("Not expecting data")
|
|
||||||
|
|
||||||
override def onPull(ctx: Context[ByteString]): SyncDirective = {
|
|
||||||
val toPush = remaining.take(maxBytesPerChunk)
|
val toPush = remaining.take(maxBytesPerChunk)
|
||||||
val toKeep = remaining.drop(maxBytesPerChunk)
|
val toKeep = remaining.drop(maxBytesPerChunk)
|
||||||
|
push(out, toPush)
|
||||||
|
remaining = toKeep
|
||||||
|
}
|
||||||
|
setHandlers(in, out, WaitingForData)
|
||||||
|
|
||||||
become {
|
case object WaitingForData extends InHandler with OutHandler {
|
||||||
if (toKeep.isEmpty) WaitingForData
|
override def onPush(): Unit = {
|
||||||
else DeliveringData(toKeep)
|
val elem = grab(in)
|
||||||
|
if (elem.size <= maxBytesPerChunk) push(out, elem)
|
||||||
|
else {
|
||||||
|
splitAndPush(elem)
|
||||||
|
setHandlers(in, out, DeliveringData)
|
||||||
}
|
}
|
||||||
if (ctx.isFinishing) ctx.pushAndFinish(toPush)
|
|
||||||
else ctx.push(toPush)
|
|
||||||
}
|
}
|
||||||
|
override def onPull(): Unit = pull(in)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective =
|
case object DeliveringData extends InHandler() with OutHandler {
|
||||||
current match {
|
var finishing = false
|
||||||
case WaitingForData ⇒ ctx.finish()
|
override def onPush(): Unit = throw new IllegalStateException("Not expecting data")
|
||||||
case _: DeliveringData ⇒ ctx.absorbTermination()
|
override def onPull(): Unit = {
|
||||||
|
splitAndPush(remaining)
|
||||||
|
if (remaining.isEmpty) {
|
||||||
|
if (finishing) completeStage() else setHandlers(in, out, WaitingForData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
override def onUpstreamFinish(): Unit = if (remaining.isEmpty) completeStage() else finishing = true
|
||||||
|
}
|
||||||
|
|
||||||
|
override def toString = "limitByteChunksStage"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,10 +4,13 @@
|
||||||
|
|
||||||
package akka.http.scaladsl.coding
|
package akka.http.scaladsl.coding
|
||||||
|
|
||||||
|
import akka.stream.{ Attributes, FlowShape }
|
||||||
|
import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage
|
||||||
|
|
||||||
import scala.concurrent.duration._
|
import scala.concurrent.duration._
|
||||||
import org.scalatest.WordSpec
|
import org.scalatest.WordSpec
|
||||||
import akka.util.ByteString
|
import akka.util.ByteString
|
||||||
import akka.stream.stage.{ SyncDirective, Context, PushStage, Stage }
|
import akka.stream.stage._
|
||||||
import akka.http.scaladsl.model._
|
import akka.http.scaladsl.model._
|
||||||
import akka.http.impl.util._
|
import akka.http.impl.util._
|
||||||
import headers._
|
import headers._
|
||||||
|
|
@ -34,10 +37,17 @@ class DecoderSpec extends WordSpec with CodecSpecSupport {
|
||||||
case object DummyDecoder extends StreamDecoder {
|
case object DummyDecoder extends StreamDecoder {
|
||||||
val encoding = HttpEncodings.compress
|
val encoding = HttpEncodings.compress
|
||||||
|
|
||||||
def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ Stage[ByteString, ByteString] =
|
override def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ GraphStage[FlowShape[ByteString, ByteString]] =
|
||||||
() ⇒ new PushStage[ByteString, ByteString] {
|
() ⇒ new SimpleLinearGraphStage[ByteString] {
|
||||||
def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective =
|
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) {
|
||||||
ctx.push(elem ++ ByteString("compressed"))
|
setHandler(in, new InHandler {
|
||||||
|
override def onPush(): Unit = push(out, grab(in) ++ ByteString("compressed"))
|
||||||
|
})
|
||||||
|
setHandler(out, new OutHandler {
|
||||||
|
override def onPull(): Unit = pull(in)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,6 @@ class GzipSpec extends CoderSpec {
|
||||||
}
|
}
|
||||||
"throw an error if compressed data is just missing the trailer at the end" in {
|
"throw an error if compressed data is just missing the trailer at the end" in {
|
||||||
def brokenCompress(payload: String) = Gzip.newCompressor.compress(ByteString(payload, "UTF-8"))
|
def brokenCompress(payload: String) = Gzip.newCompressor.compress(ByteString(payload, "UTF-8"))
|
||||||
|
|
||||||
val ex = the[RuntimeException] thrownBy ourDecode(brokenCompress("abcdefghijkl"))
|
val ex = the[RuntimeException] thrownBy ourDecode(brokenCompress("abcdefghijkl"))
|
||||||
ex.getCause.getMessage should equal("Truncated GZIP stream")
|
ex.getCause.getMessage should equal("Truncated GZIP stream")
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,8 @@ package akka.http.scaladsl.coding
|
||||||
|
|
||||||
import akka.NotUsed
|
import akka.NotUsed
|
||||||
import akka.http.scaladsl.model._
|
import akka.http.scaladsl.model._
|
||||||
import akka.stream.Materializer
|
import akka.stream.{ FlowShape, Materializer }
|
||||||
import akka.stream.stage.Stage
|
import akka.stream.stage.{ GraphStage, Stage }
|
||||||
import akka.util.ByteString
|
import akka.util.ByteString
|
||||||
import headers.HttpEncoding
|
import headers.HttpEncoding
|
||||||
import akka.stream.scaladsl.{ Sink, Source, Flow }
|
import akka.stream.scaladsl.{ Sink, Source, Flow }
|
||||||
|
|
@ -37,7 +37,7 @@ object Decoder {
|
||||||
|
|
||||||
/** A decoder that is implemented in terms of a [[Stage]] */
|
/** A decoder that is implemented in terms of a [[Stage]] */
|
||||||
trait StreamDecoder extends Decoder { outer ⇒
|
trait StreamDecoder extends Decoder { outer ⇒
|
||||||
protected def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ Stage[ByteString, ByteString]
|
protected def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ GraphStage[FlowShape[ByteString, ByteString]]
|
||||||
|
|
||||||
def maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault
|
def maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault
|
||||||
def withMaxBytesPerChunk(newMaxBytesPerChunk: Int): Decoder =
|
def withMaxBytesPerChunk(newMaxBytesPerChunk: Int): Decoder =
|
||||||
|
|
@ -45,11 +45,11 @@ trait StreamDecoder extends Decoder { outer ⇒
|
||||||
def encoding: HttpEncoding = outer.encoding
|
def encoding: HttpEncoding = outer.encoding
|
||||||
override def maxBytesPerChunk: Int = newMaxBytesPerChunk
|
override def maxBytesPerChunk: Int = newMaxBytesPerChunk
|
||||||
|
|
||||||
def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ Stage[ByteString, ByteString] =
|
def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ GraphStage[FlowShape[ByteString, ByteString]] =
|
||||||
outer.newDecompressorStage(maxBytesPerChunk)
|
outer.newDecompressorStage(maxBytesPerChunk)
|
||||||
}
|
}
|
||||||
|
|
||||||
def decoderFlow: Flow[ByteString, ByteString, NotUsed] =
|
def decoderFlow: Flow[ByteString, ByteString, NotUsed] =
|
||||||
Flow[ByteString].transform(newDecompressorStage(maxBytesPerChunk))
|
Flow.fromGraph(newDecompressorStage(maxBytesPerChunk)())
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,11 +5,12 @@
|
||||||
package akka.http.scaladsl.coding
|
package akka.http.scaladsl.coding
|
||||||
|
|
||||||
import java.util.zip.{ Inflater, Deflater }
|
import java.util.zip.{ Inflater, Deflater }
|
||||||
import akka.stream.stage._
|
import akka.stream.Attributes
|
||||||
|
import akka.stream.io.ByteStringParser
|
||||||
|
import akka.stream.io.ByteStringParser.{ ParseResult, ParseStep }
|
||||||
import akka.util.{ ByteStringBuilder, ByteString }
|
import akka.util.{ ByteStringBuilder, ByteString }
|
||||||
|
|
||||||
import scala.annotation.tailrec
|
import scala.annotation.tailrec
|
||||||
import akka.http.impl.util._
|
|
||||||
import akka.http.scaladsl.model._
|
import akka.http.scaladsl.model._
|
||||||
import akka.http.scaladsl.model.headers.HttpEncodings
|
import akka.http.scaladsl.model.headers.HttpEncodings
|
||||||
|
|
||||||
|
|
@ -86,56 +87,49 @@ private[http] object DeflateCompressor {
|
||||||
}
|
}
|
||||||
|
|
||||||
class DeflateDecompressor(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends DeflateDecompressorBase(maxBytesPerChunk) {
|
class DeflateDecompressor(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends DeflateDecompressorBase(maxBytesPerChunk) {
|
||||||
protected def createInflater() = new Inflater()
|
|
||||||
|
|
||||||
def initial: State = StartInflate
|
override def createLogic(attr: Attributes) = new DecompressorParsingLogic {
|
||||||
def afterInflate: State = StartInflate
|
override val inflater: Inflater = new Inflater()
|
||||||
|
|
||||||
protected def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit = {}
|
override val inflateState = new Inflate(true) {
|
||||||
protected def onTruncation(ctx: Context[ByteString]): SyncDirective = ctx.finish()
|
override def onTruncation(): Unit = completeStage()
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract class DeflateDecompressorBase(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends ByteStringParserStage[ByteString] {
|
override def afterInflate = inflateState
|
||||||
protected def createInflater(): Inflater
|
override def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit = {}
|
||||||
val inflater = createInflater()
|
|
||||||
|
|
||||||
protected def afterInflate: State
|
startWith(inflateState)
|
||||||
protected def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit
|
|
||||||
|
|
||||||
/** Start inflating */
|
|
||||||
case object StartInflate extends IntermediateState {
|
|
||||||
def onPush(data: ByteString, ctx: Context[ByteString]): SyncDirective = {
|
|
||||||
require(inflater.needsInput())
|
|
||||||
inflater.setInput(data.toArray)
|
|
||||||
|
|
||||||
becomeWithRemaining(Inflate()(data), ByteString.empty, ctx)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Inflate */
|
abstract class DeflateDecompressorBase(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault)
|
||||||
case class Inflate()(data: ByteString) extends IntermediateState {
|
extends ByteStringParser[ByteString] {
|
||||||
override def onPull(ctx: Context[ByteString]): SyncDirective = {
|
|
||||||
|
abstract class DecompressorParsingLogic extends ParsingLogic {
|
||||||
|
val inflater: Inflater
|
||||||
|
def afterInflate: ParseStep[ByteString]
|
||||||
|
def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit
|
||||||
|
val inflateState: Inflate
|
||||||
|
|
||||||
|
abstract class Inflate(noPostProcessing: Boolean) extends ParseStep[ByteString] {
|
||||||
|
override def canWorkWithPartialData = true
|
||||||
|
override def parse(reader: ByteStringParser.ByteReader): ParseResult[ByteString] = {
|
||||||
|
inflater.setInput(reader.remainingData.toArray)
|
||||||
|
|
||||||
val buffer = new Array[Byte](maxBytesPerChunk)
|
val buffer = new Array[Byte](maxBytesPerChunk)
|
||||||
val read = inflater.inflate(buffer)
|
val read = inflater.inflate(buffer)
|
||||||
|
|
||||||
|
reader.skip(reader.remainingSize - inflater.getRemaining)
|
||||||
|
|
||||||
if (read > 0) {
|
if (read > 0) {
|
||||||
afterBytesRead(buffer, 0, read)
|
afterBytesRead(buffer, 0, read)
|
||||||
ctx.push(ByteString.fromArray(buffer, 0, read))
|
val next = if (inflater.finished()) afterInflate else this
|
||||||
|
ParseResult(Some(ByteString.fromArray(buffer, 0, read)), next, noPostProcessing)
|
||||||
} else {
|
} else {
|
||||||
val remaining = data.takeRight(inflater.getRemaining)
|
if (inflater.finished()) ParseResult(None, afterInflate, noPostProcessing)
|
||||||
val next =
|
else throw ByteStringParser.NeedMoreData
|
||||||
if (inflater.finished()) afterInflate
|
}
|
||||||
else StartInflate
|
|
||||||
|
|
||||||
becomeWithRemaining(next, remaining, ctx)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective =
|
|
||||||
throw new IllegalStateException("Don't expect a new Element")
|
|
||||||
}
|
|
||||||
|
|
||||||
def becomeWithRemaining(next: State, remaining: ByteString, ctx: Context[ByteString]) = {
|
|
||||||
become(next)
|
|
||||||
if (remaining.isEmpty) current.onPull(ctx)
|
|
||||||
else current.onPush(remaining, ctx)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,14 +4,15 @@
|
||||||
|
|
||||||
package akka.http.scaladsl.coding
|
package akka.http.scaladsl.coding
|
||||||
|
|
||||||
import akka.util.ByteString
|
import java.util.zip.{ CRC32, Deflater, Inflater, ZipException }
|
||||||
import akka.stream.stage._
|
|
||||||
|
|
||||||
import akka.http.impl.util.ByteReader
|
|
||||||
import java.util.zip.{ Inflater, CRC32, ZipException, Deflater }
|
|
||||||
|
|
||||||
|
import akka.http.impl.engine.ws.{ ProtocolException, FrameEvent }
|
||||||
import akka.http.scaladsl.model._
|
import akka.http.scaladsl.model._
|
||||||
import headers.HttpEncodings
|
import akka.http.scaladsl.model.headers.HttpEncodings
|
||||||
|
import akka.stream.Attributes
|
||||||
|
import akka.stream.io.ByteStringParser
|
||||||
|
import akka.stream.io.ByteStringParser.{ ParseResult, ParseStep }
|
||||||
|
import akka.util.ByteString
|
||||||
|
|
||||||
class Gzip(val messageFilter: HttpMessage ⇒ Boolean) extends Coder with StreamDecoder {
|
class Gzip(val messageFilter: HttpMessage ⇒ Boolean) extends Coder with StreamDecoder {
|
||||||
val encoding = HttpEncodings.gzip
|
val encoding = HttpEncodings.gzip
|
||||||
|
|
@ -60,30 +61,22 @@ class GzipCompressor extends DeflateCompressor {
|
||||||
}
|
}
|
||||||
|
|
||||||
class GzipDecompressor(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends DeflateDecompressorBase(maxBytesPerChunk) {
|
class GzipDecompressor(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends DeflateDecompressorBase(maxBytesPerChunk) {
|
||||||
protected def createInflater(): Inflater = new Inflater(true)
|
override def createLogic(attr: Attributes) = new DecompressorParsingLogic {
|
||||||
|
override val inflater: Inflater = new Inflater(true)
|
||||||
|
override def afterInflate: ParseStep[ByteString] = ReadTrailer
|
||||||
|
override def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit =
|
||||||
|
crc32.update(buffer, offset, length)
|
||||||
|
|
||||||
def initial: State = Initial
|
trait Step extends ParseStep[ByteString] {
|
||||||
|
override def onTruncation(): Unit = failStage(new ZipException("Truncated GZIP stream"))
|
||||||
/** No bytes were received yet */
|
|
||||||
case object Initial extends State {
|
|
||||||
def onPush(data: ByteString, ctx: Context[ByteString]): SyncDirective =
|
|
||||||
if (data.isEmpty) ctx.pull()
|
|
||||||
else becomeWithRemaining(ReadHeaders, data, ctx)
|
|
||||||
|
|
||||||
override def onPull(ctx: Context[ByteString]): SyncDirective =
|
|
||||||
if (ctx.isFinishing) {
|
|
||||||
ctx.finish()
|
|
||||||
} else super.onPull(ctx)
|
|
||||||
}
|
}
|
||||||
|
override val inflateState = new Inflate(false) with Step
|
||||||
var crc32: CRC32 = new CRC32
|
startWith(ReadHeaders)
|
||||||
protected def afterInflate: State = ReadTrailer
|
|
||||||
|
|
||||||
/** Reading the header bytes */
|
/** Reading the header bytes */
|
||||||
case object ReadHeaders extends ByteReadingState {
|
case object ReadHeaders extends Step {
|
||||||
def read(reader: ByteReader, ctx: Context[ByteString]): SyncDirective = {
|
override def parse(reader: ByteStringParser.ByteReader): ParseResult[ByteString] = {
|
||||||
import reader._
|
import reader._
|
||||||
|
|
||||||
if (readByte() != 0x1F || readByte() != 0x8B) fail("Not in GZIP format") // check magic header
|
if (readByte() != 0x1F || readByte() != 0x8B) fail("Not in GZIP format") // check magic header
|
||||||
if (readByte() != 8) fail("Unsupported GZIP compression method") // check compression method
|
if (readByte() != 8) fail("Unsupported GZIP compression method") // check compression method
|
||||||
val flags = readByte()
|
val flags = readByte()
|
||||||
|
|
@ -95,36 +88,28 @@ class GzipDecompressor(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault)
|
||||||
|
|
||||||
inflater.reset()
|
inflater.reset()
|
||||||
crc32.reset()
|
crc32.reset()
|
||||||
becomeWithRemaining(StartInflate, remainingData, ctx)
|
ParseResult(None, inflateState, false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
var crc32: CRC32 = new CRC32
|
||||||
protected def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit =
|
private def fail(msg: String) = throw new ZipException(msg)
|
||||||
crc32.update(buffer, offset, length)
|
|
||||||
|
|
||||||
/** Reading the trailer */
|
/** Reading the trailer */
|
||||||
case object ReadTrailer extends ByteReadingState {
|
case object ReadTrailer extends Step {
|
||||||
def read(reader: ByteReader, ctx: Context[ByteString]): SyncDirective = {
|
override def parse(reader: ByteStringParser.ByteReader): ParseResult[ByteString] = {
|
||||||
import reader._
|
import reader._
|
||||||
|
|
||||||
if (readIntLE() != crc32.getValue.toInt) fail("Corrupt data (CRC32 checksum error)")
|
if (readIntLE() != crc32.getValue.toInt) fail("Corrupt data (CRC32 checksum error)")
|
||||||
if (readIntLE() != inflater.getBytesWritten.toInt /* truncated to 32bit */ ) fail("Corrupt GZIP trailer ISIZE")
|
if (readIntLE() != inflater.getBytesWritten.toInt /* truncated to 32bit */ )
|
||||||
|
fail("Corrupt GZIP trailer ISIZE")
|
||||||
becomeWithRemaining(Initial, remainingData, ctx)
|
ParseResult(None, ReadHeaders, true)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = ctx.absorbTermination()
|
|
||||||
|
|
||||||
private def crc16(data: ByteString) = {
|
private def crc16(data: ByteString) = {
|
||||||
val crc = new CRC32
|
val crc = new CRC32
|
||||||
crc.update(data.toArray)
|
crc.update(data.toArray)
|
||||||
crc.getValue.toInt & 0xFFFF
|
crc.getValue.toInt & 0xFFFF
|
||||||
}
|
}
|
||||||
|
|
||||||
override protected def onTruncation(ctx: Context[ByteString]): SyncDirective = ctx.fail(new ZipException("Truncated GZIP stream"))
|
|
||||||
|
|
||||||
private def fail(msg: String) = throw new ZipException(msg)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/** INTERNAL API */
|
/** INTERNAL API */
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,8 @@ package akka.http.scaladsl.coding
|
||||||
|
|
||||||
import akka.http.scaladsl.model._
|
import akka.http.scaladsl.model._
|
||||||
import akka.http.impl.util.StreamUtils
|
import akka.http.impl.util.StreamUtils
|
||||||
import akka.stream.stage.Stage
|
import akka.stream.FlowShape
|
||||||
|
import akka.stream.stage.{ GraphStage, Stage }
|
||||||
import akka.util.ByteString
|
import akka.util.ByteString
|
||||||
import headers.HttpEncodings
|
import headers.HttpEncodings
|
||||||
|
|
||||||
|
|
@ -25,7 +26,7 @@ object NoCoding extends Coder with StreamDecoder {
|
||||||
|
|
||||||
def newCompressor = NoCodingCompressor
|
def newCompressor = NoCodingCompressor
|
||||||
|
|
||||||
def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ Stage[ByteString, ByteString] =
|
def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ GraphStage[FlowShape[ByteString, ByteString]] =
|
||||||
() ⇒ StreamUtils.limitByteChunksStage(maxBytesPerChunk)
|
() ⇒ StreamUtils.limitByteChunksStage(maxBytesPerChunk)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,7 @@ object GraphStages {
|
||||||
/**
|
/**
|
||||||
* INERNAL API
|
* INERNAL API
|
||||||
*/
|
*/
|
||||||
private[stream] abstract class SimpleLinearGraphStage[T] extends GraphStage[FlowShape[T, T]] {
|
private[akka] abstract class SimpleLinearGraphStage[T] extends GraphStage[FlowShape[T, T]] {
|
||||||
val in = Inlet[T](Logging.simpleName(this) + ".in")
|
val in = Inlet[T](Logging.simpleName(this) + ".in")
|
||||||
val out = Outlet[T](Logging.simpleName(this) + ".out")
|
val out = Outlet[T](Logging.simpleName(this) + ".out")
|
||||||
override val shape = FlowShape(in, out)
|
override val shape = FlowShape(in, out)
|
||||||
|
|
|
||||||
|
|
@ -19,30 +19,35 @@ abstract class ByteStringParser[T] extends GraphStage[FlowShape[ByteString, T]]
|
||||||
final override val shape = FlowShape(bytesIn, objOut)
|
final override val shape = FlowShape(bytesIn, objOut)
|
||||||
|
|
||||||
class ParsingLogic extends GraphStageLogic(shape) {
|
class ParsingLogic extends GraphStageLogic(shape) {
|
||||||
|
var pullOnParserRequest = false
|
||||||
override def preStart(): Unit = pull(bytesIn)
|
override def preStart(): Unit = pull(bytesIn)
|
||||||
setHandler(objOut, eagerTerminateOutput)
|
setHandler(objOut, eagerTerminateOutput)
|
||||||
|
|
||||||
private var buffer = ByteString.empty
|
private var buffer = ByteString.empty
|
||||||
private var current: ParseStep[T] = FinishedParser
|
private var current: ParseStep[T] = FinishedParser
|
||||||
|
private var acceptUpstreamFinish: Boolean = true
|
||||||
|
|
||||||
final protected def startWith(step: ParseStep[T]): Unit = current = step
|
final protected def startWith(step: ParseStep[T]): Unit = current = step
|
||||||
|
|
||||||
@tailrec private def doParse(): Unit =
|
@tailrec private def doParse(): Unit =
|
||||||
if (buffer.nonEmpty) {
|
if (buffer.nonEmpty) {
|
||||||
val cont = try {
|
|
||||||
val reader = new ByteReader(buffer)
|
val reader = new ByteReader(buffer)
|
||||||
val (elem, next) = current.parse(reader)
|
val cont = try {
|
||||||
emit(objOut, elem)
|
val parseResult = current.parse(reader)
|
||||||
if (next == FinishedParser) {
|
acceptUpstreamFinish = parseResult.acceptUpstreamFinish
|
||||||
|
parseResult.result.map(emit(objOut, _))
|
||||||
|
if (parseResult.nextStep == FinishedParser) {
|
||||||
completeStage()
|
completeStage()
|
||||||
false
|
false
|
||||||
} else {
|
} else {
|
||||||
buffer = reader.remainingData
|
buffer = reader.remainingData
|
||||||
current = next
|
current = parseResult.nextStep
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
} catch {
|
} catch {
|
||||||
case NeedMoreData ⇒
|
case NeedMoreData ⇒
|
||||||
|
acceptUpstreamFinish = false
|
||||||
|
if (current.canWorkWithPartialData) buffer = reader.remainingData
|
||||||
pull(bytesIn)
|
pull(bytesIn)
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
|
|
@ -51,11 +56,12 @@ abstract class ByteStringParser[T] extends GraphStage[FlowShape[ByteString, T]]
|
||||||
|
|
||||||
setHandler(bytesIn, new InHandler {
|
setHandler(bytesIn, new InHandler {
|
||||||
override def onPush(): Unit = {
|
override def onPush(): Unit = {
|
||||||
|
pullOnParserRequest = false
|
||||||
buffer ++= grab(bytesIn)
|
buffer ++= grab(bytesIn)
|
||||||
doParse()
|
doParse()
|
||||||
}
|
}
|
||||||
override def onUpstreamFinish(): Unit =
|
override def onUpstreamFinish(): Unit =
|
||||||
if (buffer.isEmpty) completeStage()
|
if (buffer.isEmpty && acceptUpstreamFinish) completeStage()
|
||||||
else current.onTruncation()
|
else current.onTruncation()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
@ -63,13 +69,28 @@ abstract class ByteStringParser[T] extends GraphStage[FlowShape[ByteString, T]]
|
||||||
|
|
||||||
object ByteStringParser {
|
object ByteStringParser {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param result - parser can return some element for downstream or return None if no element was generated
|
||||||
|
* @param nextStep - next parser
|
||||||
|
* @param acceptUpstreamFinish - if true - stream will complete when received `onUpstreamFinish`, if "false"
|
||||||
|
* - onTruncation will be called
|
||||||
|
*/
|
||||||
|
case class ParseResult[+T](result: Option[T],
|
||||||
|
nextStep: ParseStep[T],
|
||||||
|
acceptUpstreamFinish: Boolean = true)
|
||||||
|
|
||||||
trait ParseStep[+T] {
|
trait ParseStep[+T] {
|
||||||
def parse(reader: ByteReader): (T, ParseStep[T])
|
/**
|
||||||
|
* Must return true when NeedMoreData will clean buffer. If returns false - next pulled
|
||||||
|
* data will be appended to existing data in buffer
|
||||||
|
*/
|
||||||
|
def canWorkWithPartialData: Boolean = false
|
||||||
|
def parse(reader: ByteReader): ParseResult[T]
|
||||||
def onTruncation(): Unit = throw new IllegalStateException("truncated data in ByteStringParser")
|
def onTruncation(): Unit = throw new IllegalStateException("truncated data in ByteStringParser")
|
||||||
}
|
}
|
||||||
|
|
||||||
object FinishedParser extends ParseStep[Nothing] {
|
object FinishedParser extends ParseStep[Nothing] {
|
||||||
def parse(reader: ByteReader) =
|
override def parse(reader: ByteReader) =
|
||||||
throw new IllegalStateException("no initial parser installed: you must use startWith(...)")
|
throw new IllegalStateException("no initial parser installed: you must use startWith(...)")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -83,6 +104,7 @@ object ByteStringParser {
|
||||||
def remainingSize: Int = input.size - off
|
def remainingSize: Int = input.size - off
|
||||||
|
|
||||||
def currentOffset: Int = off
|
def currentOffset: Int = off
|
||||||
|
|
||||||
def remainingData: ByteString = input.drop(off)
|
def remainingData: ByteString = input.drop(off)
|
||||||
def fromStartToHere: ByteString = input.take(off)
|
def fromStartToHere: ByteString = input.take(off)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -257,7 +257,6 @@ object Zip {
|
||||||
* '''Cancels when''' any downstream cancels
|
* '''Cancels when''' any downstream cancels
|
||||||
*/
|
*/
|
||||||
object Unzip {
|
object Unzip {
|
||||||
import akka.japi.function.Function
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a new `Unzip` stage with the specified output types.
|
* Creates a new `Unzip` stage with the specified output types.
|
||||||
|
|
|
||||||
|
|
@ -300,6 +300,14 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
|
||||||
if (_interpreter != null) _interpreter.setHandler(conn(in), handler)
|
if (_interpreter != null) _interpreter.setHandler(conn(in), handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Assign callbacks for linear stage for both [[Inlet]] and [[Outlet]]
|
||||||
|
*/
|
||||||
|
final protected def setHandlers(in: Inlet[_], out: Outlet[_], handler: InHandler with OutHandler): Unit = {
|
||||||
|
setHandler(in, handler)
|
||||||
|
setHandler(out, handler)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Retrieves the current callback for the events on the given [[Inlet]]
|
* Retrieves the current callback for the events on the given [[Inlet]]
|
||||||
*/
|
*/
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue