=str #19361 migrating ByteStringParserStage to graph stage

parent 3fc332d2c9
commit 07c0da36f2

14 changed files with 184 additions and 267 deletions
@@ -46,7 +46,7 @@ private[http] object FrameEventParser extends ByteStringParser[FrameEvent] {
   }

   object ReadFrameHeader extends Step {
-    override def parse(reader: ByteReader): (FrameEvent, Step) = {
+    override def parse(reader: ByteReader): ParseResult[FrameEvent] = {
       import Protocol._

       val flagsAndOp = reader.readByte()

@@ -83,23 +83,25 @@ private[http] object FrameEventParser extends ByteStringParser[FrameEvent] {

       val takeNow = (header.length min reader.remainingSize).toInt
       val thisFrameData = reader.take(takeNow)
+      val noMoreData = thisFrameData.length == length

       val nextState =
-        if (thisFrameData.length == length) ReadFrameHeader
+        if (noMoreData) ReadFrameHeader
         else new ReadData(length - thisFrameData.length)

-      (FrameStart(header, thisFrameData.compact), nextState)
+      ParseResult(Some(FrameStart(header, thisFrameData.compact)), nextState, true)
     }
   }

   class ReadData(_remaining: Long) extends Step {
+    override def canWorkWithPartialData = true
     var remaining = _remaining
-    override def parse(reader: ByteReader): (FrameEvent, Step) =
+    override def parse(reader: ByteReader): ParseResult[FrameEvent] =
       if (reader.remainingSize < remaining) {
         remaining -= reader.remainingSize
-        (FrameData(reader.takeAll(), lastPart = false), this)
+        ParseResult(Some(FrameData(reader.takeAll(), lastPart = false)), this, true)
       } else {
-        (FrameData(reader.take(remaining.toInt), lastPart = true), ReadFrameHeader)
+        ParseResult(Some(FrameData(reader.take(remaining.toInt), lastPart = true)), ReadFrameHeader, true)
       }
   }
 }
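Note: under the new API every step returns a ParseResult instead of a (value, nextStep) tuple, so a step may also emit nothing (result = None) while still switching state. A minimal sketch of that contract, written against the akka.stream.io.ByteStringParser introduced later in this commit; the step names ReadLength/ReadPayload are hypothetical, and take/readIntBE are assumed to signal NeedMoreData on underrun as they do in the hunks above:

    // Sketch only: hypothetical steps for a length-prefixed payload.
    import akka.stream.io.ByteStringParser.{ ByteReader, ParseResult, ParseStep }
    import akka.util.ByteString

    object ReadLength extends ParseStep[ByteString] {
      override def parse(reader: ByteReader): ParseResult[ByteString] = {
        val length = reader.readIntBE()             // throws NeedMoreData on underrun
        ParseResult(None, new ReadPayload(length))  // emit nothing, switch step
      }
    }

    class ReadPayload(length: Int) extends ParseStep[ByteString] {
      override def parse(reader: ByteReader): ParseResult[ByteString] =
        ParseResult(Some(reader.take(length)), ReadLength, acceptUpstreamFinish = true)
    }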
@@ -1,51 +0,0 @@
-/*
- * Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
- */
-
-package akka.http.impl.util
-
-import scala.util.control.NoStackTrace
-import akka.util.ByteString
-
-/**
- * A helper class to read from a ByteString statefully.
- *
- * INTERNAL API
- */
-private[akka] class ByteReader(input: ByteString) {
-  import ByteReader.NeedMoreData
-
-  private[this] var off = 0
-
-  def hasRemaining: Boolean = off < input.size
-
-  def currentOffset: Int = off
-  def remainingData: ByteString = input.drop(off)
-  def fromStartToHere: ByteString = input.take(off)
-
-  def readByte(): Int =
-    if (off < input.length) {
-      val x = input(off)
-      off += 1
-      x & 0xFF
-    } else throw NeedMoreData
-  def readShortLE(): Int = readByte() | (readByte() << 8)
-  def readIntLE(): Int = readShortLE() | (readShortLE() << 16)
-  def readLongLE(): Long = (readIntLE() & 0xffffffffL) | ((readIntLE() & 0xffffffffL) << 32)
-
-  def readShortBE(): Int = (readByte() << 8) | readByte()
-  def readIntBE(): Int = (readShortBE() << 16) | readShortBE()
-  def readLongBE(): Long = ((readIntBE() & 0xffffffffL) << 32) | (readIntBE() & 0xffffffffL)
-
-  def skip(numBytes: Int): Unit =
-    if (off + numBytes <= input.length) off += numBytes
-    else throw NeedMoreData
-  def skipZeroTerminatedString(): Unit = while (readByte() != 0) {}
-}
-
-/*
- * INTERNAL API
- */
-private[akka] object ByteReader {
-  val NeedMoreData = new Exception with NoStackTrace
-}
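Note: the helper deleted here is superseded by the ByteReader nested inside akka.stream.io.ByteStringParser (see the last hunks of this commit). For reference, the byte-order arithmetic of the readers above, illustrated outside the commit:

    // Illustration only: readShortLE() = readByte() | (readByte() << 8), so the
    // byte sequence 0x01 0x02 reads as 0x0201 little-endian and 0x0102 big-endian.
    val b0 = 0x01; val b1 = 0x02
    val le = b0 | (b1 << 8)   // 513  (0x0201)
    val be = (b0 << 8) | b1   // 258  (0x0102)
    assert(le == 0x0201 && be == 0x0102)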
@@ -1,59 +0,0 @@
-/*
- * Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
- */
-
-package akka.http.impl.util
-
-import akka.stream.stage.{ Context, StatefulStage }
-import akka.util.ByteString
-import akka.stream.stage.SyncDirective
-
-/**
- * A helper class for writing parsers from ByteStrings.
- *
- * FIXME: move to akka.stream.io, https://github.com/akka/akka/issues/16529
- *
- * INTERNAL API
- */
-private[akka] abstract class ByteStringParserStage[Out] extends StatefulStage[ByteString, Out] {
-  protected def onTruncation(ctx: Context[Out]): SyncDirective
-
-  /**
-   * Derive a stage from [[IntermediateState]] and then call `pull(ctx)` instead of
-   * `ctx.pull()` to have truncation errors reported.
-   */
-  abstract class IntermediateState extends State {
-    override def onPull(ctx: Context[Out]): SyncDirective = pull(ctx)
-    def pull(ctx: Context[Out]): SyncDirective =
-      if (ctx.isFinishing) onTruncation(ctx)
-      else ctx.pull()
-  }
-
-  /**
-   * A stage that tries to read from a side-effecting [[ByteReader]]. If a buffer underrun
-   * occurs the previous data is saved and the reading process is restarted from the beginning
-   * once more data was received.
-   *
-   * As [[read]] may be called several times for the same prefix of data, make sure not to
-   * manipulate any state during reading from the ByteReader.
-   */
-  private[akka] trait ByteReadingState extends IntermediateState {
-    def read(reader: ByteReader, ctx: Context[Out]): SyncDirective
-
-    def onPush(data: ByteString, ctx: Context[Out]): SyncDirective =
-      try {
-        val reader = new ByteReader(data)
-        read(reader, ctx)
-      } catch {
-        case ByteReader.NeedMoreData ⇒
-          become(TryAgain(data, this))
-          pull(ctx)
-      }
-  }
-  private case class TryAgain(previousData: ByteString, byteReadingState: ByteReadingState) extends IntermediateState {
-    def onPush(data: ByteString, ctx: Context[Out]): SyncDirective = {
-      become(byteReadingState)
-      byteReadingState.onPush(previousData ++ data, ctx)
-    }
-  }
-}
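Note: this StatefulStage helper implemented resumption by stashing already-seen input in a TryAgain state and re-running the whole read once more data arrived, which is why read() had to be free of external side effects. Its replacement keeps a single buffer inside ByteStringParser.ParsingLogic instead. The retry idea in isolation, as a self-contained sketch (illustration only, not part of the commit):

    import akka.util.ByteString
    import scala.util.control.NoStackTrace

    object RetryOutline {
      object NeedMoreData extends Exception with NoStackTrace

      // stand-in for a real parse step: needs at least 4 bytes of input
      def read(data: ByteString): Int =
        if (data.length < 4) throw NeedMoreData
        else data.take(4).foldLeft(0)((acc, b) ⇒ (acc << 8) | (b & 0xFF))

      def parseAcross(chunks: List[ByteString]): Option[Int] = {
        val it = chunks.iterator
        var buffered = ByteString.empty
        var result = Option.empty[Int]
        while (result.isEmpty && it.hasNext) {
          buffered ++= it.next()             // keep old data, append the new chunk
          try result = Some(read(buffered))  // restart the read from the beginning
          catch { case NeedMoreData ⇒ }      // underrun: wait for more input
        }
        result
      }
    }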
@@ -10,6 +10,7 @@ import akka.NotUsed
 import akka.http.scaladsl.model.RequestEntity
 import akka.stream._
 import akka.stream.impl.StreamLayout.Module
+import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage
 import akka.stream.impl.{ PublisherSink, SinkModule, SourceModule }
 import akka.stream.scaladsl._
 import akka.stream.stage._
@@ -114,40 +115,46 @@ private[http] object StreamUtils {
     Flow[ByteString].transform(() ⇒ transformer).named("sliceBytes")
   }

-  def limitByteChunksStage(maxBytesPerChunk: Int): PushPullStage[ByteString, ByteString] =
-    new StatefulStage[ByteString, ByteString] {
-      def initial = WaitingForData
-
-      case object WaitingForData extends State {
-        def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective =
-          if (elem.size <= maxBytesPerChunk) ctx.push(elem)
-          else {
-            become(DeliveringData(elem.drop(maxBytesPerChunk)))
-            ctx.push(elem.take(maxBytesPerChunk))
-          }
-      }
-
-      case class DeliveringData(remaining: ByteString) extends State {
-        def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective =
-          throw new IllegalStateException("Not expecting data")
-
-        override def onPull(ctx: Context[ByteString]): SyncDirective = {
-          val toPush = remaining.take(maxBytesPerChunk)
-          val toKeep = remaining.drop(maxBytesPerChunk)
-
-          become {
-            if (toKeep.isEmpty) WaitingForData
-            else DeliveringData(toKeep)
-          }
-          if (ctx.isFinishing) ctx.pushAndFinish(toPush)
-          else ctx.push(toPush)
-        }
-      }
-
-      override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective =
-        current match {
-          case WaitingForData    ⇒ ctx.finish()
-          case _: DeliveringData ⇒ ctx.absorbTermination()
-        }
-
-      override def toString = "limitByteChunksStage"
-    }
+  def limitByteChunksStage(maxBytesPerChunk: Int): GraphStage[FlowShape[ByteString, ByteString]] =
+    new SimpleLinearGraphStage[ByteString] {
+      override def initialAttributes = Attributes.name("limitByteChunksStage")
+
+      override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) {
+        var remaining = ByteString.empty
+
+        def splitAndPush(elem: ByteString): Unit = {
+          val toPush = elem.take(maxBytesPerChunk)
+          val toKeep = elem.drop(maxBytesPerChunk)
+          push(out, toPush)
+          remaining = toKeep
+        }
+        setHandlers(in, out, WaitingForData)
+
+        case object WaitingForData extends InHandler with OutHandler {
+          override def onPush(): Unit = {
+            val elem = grab(in)
+            if (elem.size <= maxBytesPerChunk) push(out, elem)
+            else {
+              splitAndPush(elem)
+              setHandlers(in, out, DeliveringData)
+            }
+          }
+          override def onPull(): Unit = pull(in)
+        }
+
+        case object DeliveringData extends InHandler with OutHandler {
+          var finishing = false
+          override def onPush(): Unit = throw new IllegalStateException("Not expecting data")
+          override def onPull(): Unit = {
+            splitAndPush(remaining)
+            if (remaining.isEmpty) {
+              if (finishing) completeStage() else setHandlers(in, out, WaitingForData)
+            }
+          }
+          override def onUpstreamFinish(): Unit = if (remaining.isEmpty) completeStage() else finishing = true
+        }
+
+        override def toString = "limitByteChunksStage"
+      }
+    }
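Note: a usage sketch of the migrated stage (illustration only; assumes access to the private[http] StreamUtils and an implicit materializer in scope):

    import akka.stream.scaladsl.{ Flow, Sink, Source }
    import akka.util.ByteString

    val limit2 = Flow.fromGraph(StreamUtils.limitByteChunksStage(maxBytesPerChunk = 2))
    Source.single(ByteString("hello"))
      .via(limit2)
      .runWith(Sink.seq)
    // expected: Seq(ByteString("he"), ByteString("ll"), ByteString("o"))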
@@ -4,10 +4,13 @@

 package akka.http.scaladsl.coding

+import akka.stream.{ Attributes, FlowShape }
+import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage
+
 import scala.concurrent.duration._
 import org.scalatest.WordSpec
 import akka.util.ByteString
-import akka.stream.stage.{ SyncDirective, Context, PushStage, Stage }
+import akka.stream.stage._
 import akka.http.scaladsl.model._
 import akka.http.impl.util._
 import headers._
@@ -34,10 +37,17 @@ class DecoderSpec extends WordSpec with CodecSpecSupport {
   case object DummyDecoder extends StreamDecoder {
     val encoding = HttpEncodings.compress

-    def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ Stage[ByteString, ByteString] =
-      () ⇒ new PushStage[ByteString, ByteString] {
-        def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective =
-          ctx.push(elem ++ ByteString("compressed"))
+    override def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ GraphStage[FlowShape[ByteString, ByteString]] =
+      () ⇒ new SimpleLinearGraphStage[ByteString] {
+        override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) {
+          setHandler(in, new InHandler {
+            override def onPush(): Unit = push(out, grab(in) ++ ByteString("compressed"))
+          })
+          setHandler(out, new OutHandler {
+            override def onPull(): Unit = pull(in)
+          })
+        }
       }
   }

 }
@@ -33,7 +33,6 @@ class GzipSpec extends CoderSpec {
   }
   "throw an error if compressed data is just missing the trailer at the end" in {
     def brokenCompress(payload: String) = Gzip.newCompressor.compress(ByteString(payload, "UTF-8"))
-
     val ex = the[RuntimeException] thrownBy ourDecode(brokenCompress("abcdefghijkl"))
     ex.getCause.getMessage should equal("Truncated GZIP stream")
   }
@@ -6,8 +6,8 @@ package akka.http.scaladsl.coding

 import akka.NotUsed
 import akka.http.scaladsl.model._
-import akka.stream.Materializer
-import akka.stream.stage.Stage
+import akka.stream.{ FlowShape, Materializer }
+import akka.stream.stage.{ GraphStage, Stage }
 import akka.util.ByteString
 import headers.HttpEncoding
 import akka.stream.scaladsl.{ Sink, Source, Flow }
@@ -37,7 +37,7 @@ object Decoder {

 /** A decoder that is implemented in terms of a [[Stage]] */
 trait StreamDecoder extends Decoder { outer ⇒
-  protected def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ Stage[ByteString, ByteString]
+  protected def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ GraphStage[FlowShape[ByteString, ByteString]]

   def maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault
   def withMaxBytesPerChunk(newMaxBytesPerChunk: Int): Decoder =
@@ -45,11 +45,11 @@ trait StreamDecoder extends Decoder { outer ⇒
     def encoding: HttpEncoding = outer.encoding
     override def maxBytesPerChunk: Int = newMaxBytesPerChunk

-    def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ Stage[ByteString, ByteString] =
+    def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ GraphStage[FlowShape[ByteString, ByteString]] =
       outer.newDecompressorStage(maxBytesPerChunk)
   }

   def decoderFlow: Flow[ByteString, ByteString, NotUsed] =
-    Flow[ByteString].transform(newDecompressorStage(maxBytesPerChunk))
+    Flow.fromGraph(newDecompressorStage(maxBytesPerChunk)())

 }
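Note: this is the consumer-side half of the migration: a () ⇒ Stage factory run through transform becomes a GraphStage wrapped once with Flow.fromGraph, since a GraphStage is itself a reusable blueprint. Sketch (illustration only):

    // before: Flow[ByteString].transform(newDecompressorStage(maxBytesPerChunk))
    // after:  apply the factory once, then wrap the resulting GraphStage
    val flow: Flow[ByteString, ByteString, NotUsed] =
      Flow.fromGraph(newDecompressorStage(maxBytesPerChunk)())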
@@ -5,11 +5,12 @@
 package akka.http.scaladsl.coding

 import java.util.zip.{ Inflater, Deflater }
-import akka.stream.stage._
+import akka.stream.Attributes
+import akka.stream.io.ByteStringParser
+import akka.stream.io.ByteStringParser.{ ParseResult, ParseStep }
 import akka.util.{ ByteStringBuilder, ByteString }

 import scala.annotation.tailrec
-import akka.http.impl.util._
 import akka.http.scaladsl.model._
 import akka.http.scaladsl.model.headers.HttpEncodings

@@ -86,56 +87,49 @@ private[http] object DeflateCompressor {
 }

 class DeflateDecompressor(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends DeflateDecompressorBase(maxBytesPerChunk) {
-  protected def createInflater() = new Inflater()
-
-  def initial: State = StartInflate
-  def afterInflate: State = StartInflate
-
-  protected def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit = {}
-  protected def onTruncation(ctx: Context[ByteString]): SyncDirective = ctx.finish()
-}
-
-abstract class DeflateDecompressorBase(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends ByteStringParserStage[ByteString] {
-  protected def createInflater(): Inflater
-  val inflater = createInflater()
-
-  protected def afterInflate: State
-  protected def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit
-
-  /** Start inflating */
-  case object StartInflate extends IntermediateState {
-    def onPush(data: ByteString, ctx: Context[ByteString]): SyncDirective = {
-      require(inflater.needsInput())
-      inflater.setInput(data.toArray)
-
-      becomeWithRemaining(Inflate()(data), ByteString.empty, ctx)
-    }
-  }
-
-  /** Inflate */
-  case class Inflate()(data: ByteString) extends IntermediateState {
-    override def onPull(ctx: Context[ByteString]): SyncDirective = {
-      val buffer = new Array[Byte](maxBytesPerChunk)
-      val read = inflater.inflate(buffer)
-
-      if (read > 0) {
-        afterBytesRead(buffer, 0, read)
-        ctx.push(ByteString.fromArray(buffer, 0, read))
-      } else {
-        val remaining = data.takeRight(inflater.getRemaining)
-        val next =
-          if (inflater.finished()) afterInflate
-          else StartInflate
-
-        becomeWithRemaining(next, remaining, ctx)
-      }
-    }
-    def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective =
-      throw new IllegalStateException("Don't expect a new Element")
-  }
-
-  def becomeWithRemaining(next: State, remaining: ByteString, ctx: Context[ByteString]) = {
-    become(next)
-    if (remaining.isEmpty) current.onPull(ctx)
-    else current.onPush(remaining, ctx)
-  }
-}
+  override def createLogic(attr: Attributes) = new DecompressorParsingLogic {
+    override val inflater: Inflater = new Inflater()
+
+    override val inflateState = new Inflate(true) {
+      override def onTruncation(): Unit = completeStage()
+    }
+
+    override def afterInflate = inflateState
+    override def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit = {}
+
+    startWith(inflateState)
+  }
+}
+
+abstract class DeflateDecompressorBase(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault)
+  extends ByteStringParser[ByteString] {
+
+  abstract class DecompressorParsingLogic extends ParsingLogic {
+    val inflater: Inflater
+    def afterInflate: ParseStep[ByteString]
+    def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit
+    val inflateState: Inflate
+
+    abstract class Inflate(noPostProcessing: Boolean) extends ParseStep[ByteString] {
+      override def canWorkWithPartialData = true
+      override def parse(reader: ByteStringParser.ByteReader): ParseResult[ByteString] = {
+        inflater.setInput(reader.remainingData.toArray)
+
+        val buffer = new Array[Byte](maxBytesPerChunk)
+        val read = inflater.inflate(buffer)
+
+        reader.skip(reader.remainingSize - inflater.getRemaining)
+
+        if (read > 0) {
+          afterBytesRead(buffer, 0, read)
+          val next = if (inflater.finished()) afterInflate else this
+          ParseResult(Some(ByteString.fromArray(buffer, 0, read)), next, noPostProcessing)
+        } else {
+          if (inflater.finished()) ParseResult(None, afterInflate, noPostProcessing)
+          else throw ByteStringParser.NeedMoreData
+        }
+      }
+    }
+  }
+}
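Note: the central bookkeeping above is reader.skip(reader.remainingSize - inflater.getRemaining): after inflate(), java.util.zip.Inflater.getRemaining reports how many of the offered input bytes were not consumed, so the difference is exactly what the reader must drop. The same arithmetic in isolation (illustration only, not part of the commit):

    import java.util.zip.{ Deflater, Inflater }

    val deflater = new Deflater()
    deflater.setInput("hello".getBytes("UTF-8"))
    deflater.finish()
    val compressed = new Array[Byte](64)
    val compressedLen = deflater.deflate(compressed)

    val inflater = new Inflater()
    inflater.setInput(compressed, 0, compressedLen)
    val out = new Array[Byte](2)             // deliberately tiny, like maxBytesPerChunk
    val produced = inflater.inflate(out)     // at most 2 bytes produced per call
    val consumed = compressedLen - inflater.getRemaining
    assert(produced <= 2 && consumed <= compressedLen)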
@@ -4,14 +4,15 @@

 package akka.http.scaladsl.coding

-import akka.util.ByteString
-import akka.stream.stage._
-
-import akka.http.impl.util.ByteReader
-import java.util.zip.{ Inflater, CRC32, ZipException, Deflater }
+import java.util.zip.{ CRC32, Deflater, Inflater, ZipException }

 import akka.http.impl.engine.ws.{ ProtocolException, FrameEvent }
 import akka.http.scaladsl.model._
-import headers.HttpEncodings
+import akka.http.scaladsl.model.headers.HttpEncodings
+import akka.stream.Attributes
+import akka.stream.io.ByteStringParser
+import akka.stream.io.ByteStringParser.{ ParseResult, ParseStep }
+import akka.util.ByteString

 class Gzip(val messageFilter: HttpMessage ⇒ Boolean) extends Coder with StreamDecoder {
   val encoding = HttpEncodings.gzip
@@ -60,30 +61,22 @@ class GzipCompressor extends DeflateCompressor {
 }

 class GzipDecompressor(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends DeflateDecompressorBase(maxBytesPerChunk) {
-  protected def createInflater(): Inflater = new Inflater(true)
-
-  def initial: State = Initial
-
-  /** No bytes were received yet */
-  case object Initial extends State {
-    def onPush(data: ByteString, ctx: Context[ByteString]): SyncDirective =
-      if (data.isEmpty) ctx.pull()
-      else becomeWithRemaining(ReadHeaders, data, ctx)
-
-    override def onPull(ctx: Context[ByteString]): SyncDirective =
-      if (ctx.isFinishing) {
-        ctx.finish()
-      } else super.onPull(ctx)
-  }
-
-  var crc32: CRC32 = new CRC32
-  protected def afterInflate: State = ReadTrailer
-
-  /** Reading the header bytes */
-  case object ReadHeaders extends ByteReadingState {
-    def read(reader: ByteReader, ctx: Context[ByteString]): SyncDirective = {
+  override def createLogic(attr: Attributes) = new DecompressorParsingLogic {
+    override val inflater: Inflater = new Inflater(true)
+    override def afterInflate: ParseStep[ByteString] = ReadTrailer
+    override def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit =
+      crc32.update(buffer, offset, length)
+
+    trait Step extends ParseStep[ByteString] {
+      override def onTruncation(): Unit = failStage(new ZipException("Truncated GZIP stream"))
+    }
+
+    override val inflateState = new Inflate(false) with Step
+    startWith(ReadHeaders)
+
+    /** Reading the header bytes */
+    case object ReadHeaders extends Step {
+      override def parse(reader: ByteStringParser.ByteReader): ParseResult[ByteString] = {
       import reader._

       if (readByte() != 0x1F || readByte() != 0x8B) fail("Not in GZIP format") // check magic header
       if (readByte() != 8) fail("Unsupported GZIP compression method") // check compression method
       val flags = readByte()
@@ -95,36 +88,28 @@ class GzipDecompressor(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault)

       inflater.reset()
       crc32.reset()
-      becomeWithRemaining(StartInflate, remainingData, ctx)
+      ParseResult(None, inflateState, false)
       }
     }

-  protected def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit =
-    crc32.update(buffer, offset, length)
+    var crc32: CRC32 = new CRC32
+    private def fail(msg: String) = throw new ZipException(msg)

     /** Reading the trailer */
-  case object ReadTrailer extends ByteReadingState {
-    def read(reader: ByteReader, ctx: Context[ByteString]): SyncDirective = {
+    case object ReadTrailer extends Step {
+      override def parse(reader: ByteStringParser.ByteReader): ParseResult[ByteString] = {
       import reader._

       if (readIntLE() != crc32.getValue.toInt) fail("Corrupt data (CRC32 checksum error)")
-      if (readIntLE() != inflater.getBytesWritten.toInt /* truncated to 32bit */ ) fail("Corrupt GZIP trailer ISIZE")
-
-      becomeWithRemaining(Initial, remainingData, ctx)
+      if (readIntLE() != inflater.getBytesWritten.toInt /* truncated to 32bit */ )
+        fail("Corrupt GZIP trailer ISIZE")
+      ParseResult(None, ReadHeaders, true)
       }
     }
+  }

-  override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = ctx.absorbTermination()
-
   private def crc16(data: ByteString) = {
     val crc = new CRC32
     crc.update(data.toArray)
     crc.getValue.toInt & 0xFFFF
   }

-  override protected def onTruncation(ctx: Context[ByteString]): SyncDirective = ctx.fail(new ZipException("Truncated GZIP stream"))
-
-  private def fail(msg: String) = throw new ZipException(msg)
 }

 /** INTERNAL API */
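Note: ReadHeaders and ReadTrailer follow the RFC 1952 framing: a header starting with the magic bytes 0x1F 0x8B and compression method 8 (deflate), and an 8-byte trailer holding the CRC32 of the uncompressed data followed by ISIZE (uncompressed length mod 2^32), both little-endian. The trailer arithmetic in isolation (illustration only, not part of the commit):

    import java.util.zip.CRC32

    val payload = "abcdefghijkl".getBytes("UTF-8")
    val crc = new CRC32
    crc.update(payload)
    val expectedCrc = crc.getValue.toInt   // compared against the first readIntLE()
    val expectedIsize = payload.length     // compared against the second readIntLE()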
@@ -6,7 +6,8 @@ package akka.http.scaladsl.coding

 import akka.http.scaladsl.model._
 import akka.http.impl.util.StreamUtils
-import akka.stream.stage.Stage
+import akka.stream.FlowShape
+import akka.stream.stage.{ GraphStage, Stage }
 import akka.util.ByteString
 import headers.HttpEncodings

@@ -25,7 +26,7 @@ object NoCoding extends Coder with StreamDecoder {

   def newCompressor = NoCodingCompressor

-  def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ Stage[ByteString, ByteString] =
+  def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ GraphStage[FlowShape[ByteString, ByteString]] =
     () ⇒ StreamUtils.limitByteChunksStage(maxBytesPerChunk)
 }
@@ -43,7 +43,7 @@ object GraphStages {
   /**
    * INTERNAL API
    */
-  private[stream] abstract class SimpleLinearGraphStage[T] extends GraphStage[FlowShape[T, T]] {
+  private[akka] abstract class SimpleLinearGraphStage[T] extends GraphStage[FlowShape[T, T]] {
     val in = Inlet[T](Logging.simpleName(this) + ".in")
     val out = Outlet[T](Logging.simpleName(this) + ".out")
     override val shape = FlowShape(in, out)
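Note: widening SimpleLinearGraphStage from private[stream] to private[akka] is what lets the akka-http code above reuse it. A minimal subclass, sketched under that assumption (illustration only):

    import akka.stream.Attributes
    import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage
    import akka.stream.stage.{ GraphStageLogic, InHandler, OutHandler }

    private[akka] class Passthrough[T] extends SimpleLinearGraphStage[T] {
      override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
        new GraphStageLogic(shape) {
          setHandler(in, new InHandler { override def onPush(): Unit = push(out, grab(in)) })
          setHandler(out, new OutHandler { override def onPull(): Unit = pull(in) })
        }
    }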
@@ -19,30 +19,35 @@ abstract class ByteStringParser[T] extends GraphStage[FlowShape[ByteString, T]]
   final override val shape = FlowShape(bytesIn, objOut)

   class ParsingLogic extends GraphStageLogic(shape) {
+    var pullOnParserRequest = false
     override def preStart(): Unit = pull(bytesIn)
     setHandler(objOut, eagerTerminateOutput)

     private var buffer = ByteString.empty
     private var current: ParseStep[T] = FinishedParser
+    private var acceptUpstreamFinish: Boolean = true

     final protected def startWith(step: ParseStep[T]): Unit = current = step

     @tailrec private def doParse(): Unit =
       if (buffer.nonEmpty) {
-        val cont = try {
         val reader = new ByteReader(buffer)
-          val (elem, next) = current.parse(reader)
-          emit(objOut, elem)
-          if (next == FinishedParser) {
+        val cont = try {
+          val parseResult = current.parse(reader)
+          acceptUpstreamFinish = parseResult.acceptUpstreamFinish
+          parseResult.result.map(emit(objOut, _))
+          if (parseResult.nextStep == FinishedParser) {
             completeStage()
             false
           } else {
             buffer = reader.remainingData
-            current = next
+            current = parseResult.nextStep
             true
           }
         } catch {
           case NeedMoreData ⇒
+            acceptUpstreamFinish = false
+            if (current.canWorkWithPartialData) buffer = reader.remainingData
             pull(bytesIn)
             false
         }

@@ -51,11 +56,12 @@ abstract class ByteStringParser[T] extends GraphStage[FlowShape[ByteString, T]]

     setHandler(bytesIn, new InHandler {
       override def onPush(): Unit = {
+        pullOnParserRequest = false
         buffer ++= grab(bytesIn)
         doParse()
       }
       override def onUpstreamFinish(): Unit =
-        if (buffer.isEmpty) completeStage()
+        if (buffer.isEmpty && acceptUpstreamFinish) completeStage()
         else current.onTruncation()
     })
   }
@@ -63,13 +69,28 @@ abstract class ByteStringParser[T] extends GraphStage[FlowShape[ByteString, T]]

 object ByteStringParser {

+  /**
+   * @param result the parser can return an element for downstream, or None if no element was generated
+   * @param nextStep the next parser state
+   * @param acceptUpstreamFinish if true, the stream completes when `onUpstreamFinish` is received;
+   *                             if false, `onTruncation` is called instead
+   */
+  case class ParseResult[+T](result: Option[T],
+                             nextStep: ParseStep[T],
+                             acceptUpstreamFinish: Boolean = true)
+
   trait ParseStep[+T] {
-    def parse(reader: ByteReader): (T, ParseStep[T])
+    /**
+     * Must return true if the step can resume from partially consumed input: on NeedMoreData
+     * the buffer is then trimmed to `reader.remainingData`. If it returns false, the next
+     * pulled data is appended to the data already in the buffer and the step is re-run
+     * from the start of the accumulated buffer.
+     */
+    def canWorkWithPartialData: Boolean = false
+    def parse(reader: ByteReader): ParseResult[T]
     def onTruncation(): Unit = throw new IllegalStateException("truncated data in ByteStringParser")
   }

   object FinishedParser extends ParseStep[Nothing] {
-    def parse(reader: ByteReader) =
+    override def parse(reader: ByteReader) =
       throw new IllegalStateException("no initial parser installed: you must use startWith(...)")
   }
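Note: ParseResult, canWorkWithPartialData and startWith together form the authoring surface of the new API. A self-contained sketch (illustration only, not part of the commit) of a parser for frames prefixed with a big-endian Int length, assuming ByteReader, ParseResult and ParseStep are importable from the companion object as done elsewhere in this commit:

    import akka.stream.Attributes
    import akka.stream.io.ByteStringParser
    import akka.stream.io.ByteStringParser.{ ByteReader, ParseResult, ParseStep }
    import akka.util.ByteString

    class FrameParser extends ByteStringParser[ByteString] {
      override def createLogic(inheritedAttributes: Attributes) = new ParsingLogic {
        startWith(ReadFrame)

        object ReadFrame extends ParseStep[ByteString] {
          override def parse(reader: ByteReader): ParseResult[ByteString] = {
            // readIntBE/take signal NeedMoreData on underrun; the logic then
            // re-runs this step once more input has been buffered.
            val length = reader.readIntBE()
            ParseResult(Some(reader.take(length)), ReadFrame)
          }
        }
      }
    }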
@@ -83,6 +104,7 @@ object ByteStringParser {
     def remainingSize: Int = input.size - off

     def currentOffset: Int = off
+
     def remainingData: ByteString = input.drop(off)
     def fromStartToHere: ByteString = input.take(off)

@@ -257,7 +257,6 @@ object Zip {
  * '''Cancels when''' any downstream cancels
  */
 object Unzip {
-  import akka.japi.function.Function

   /**
    * Creates a new `Unzip` stage with the specified output types.
@@ -300,6 +300,14 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount:
     if (_interpreter != null) _interpreter.setHandler(conn(in), handler)
   }

+  /**
+   * Assign callbacks for a linear stage, for both the [[Inlet]] and the [[Outlet]]
+   */
+  final protected def setHandlers(in: Inlet[_], out: Outlet[_], handler: InHandler with OutHandler): Unit = {
+    setHandler(in, handler)
+    setHandler(out, handler)
+  }
+
   /**
    * Retrieves the current callback for the events on the given [[Inlet]]
    */
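Note: setHandlers is the small convenience the new limitByteChunksStage relies on: a single object that is both an InHandler and an OutHandler services both ports of a linear stage. Inside a GraphStageLogic (sketch only; in, out and shape come from the enclosing stage):

    new GraphStageLogic(shape) {
      object PassAlong extends InHandler with OutHandler {
        override def onPush(): Unit = push(out, grab(in))
        override def onPull(): Unit = pull(in)
      }
      setHandlers(in, out, PassAlong)
    }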