=str #19361 migrating ByteStringParserStage to graph stage
This commit is contained in:
parent
3fc332d2c9
commit
07c0da36f2
14 changed files with 184 additions and 267 deletions
|
|
@ -4,14 +4,15 @@
|
|||
|
||||
package akka.http.scaladsl.coding
|
||||
|
||||
import akka.util.ByteString
|
||||
import akka.stream.stage._
|
||||
|
||||
import akka.http.impl.util.ByteReader
|
||||
import java.util.zip.{ Inflater, CRC32, ZipException, Deflater }
|
||||
import java.util.zip.{ CRC32, Deflater, Inflater, ZipException }
|
||||
|
||||
import akka.http.impl.engine.ws.{ ProtocolException, FrameEvent }
|
||||
import akka.http.scaladsl.model._
|
||||
import headers.HttpEncodings
|
||||
import akka.http.scaladsl.model.headers.HttpEncodings
|
||||
import akka.stream.Attributes
|
||||
import akka.stream.io.ByteStringParser
|
||||
import akka.stream.io.ByteStringParser.{ ParseResult, ParseStep }
|
||||
import akka.util.ByteString
|
||||
|
||||
class Gzip(val messageFilter: HttpMessage ⇒ Boolean) extends Coder with StreamDecoder {
|
||||
val encoding = HttpEncodings.gzip
|
||||
|
|
@ -60,71 +61,55 @@ class GzipCompressor extends DeflateCompressor {
|
|||
}
|
||||
|
||||
class GzipDecompressor(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends DeflateDecompressorBase(maxBytesPerChunk) {
|
||||
protected def createInflater(): Inflater = new Inflater(true)
|
||||
override def createLogic(attr: Attributes) = new DecompressorParsingLogic {
|
||||
override val inflater: Inflater = new Inflater(true)
|
||||
override def afterInflate: ParseStep[ByteString] = ReadTrailer
|
||||
override def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit =
|
||||
crc32.update(buffer, offset, length)
|
||||
|
||||
def initial: State = Initial
|
||||
trait Step extends ParseStep[ByteString] {
|
||||
override def onTruncation(): Unit = failStage(new ZipException("Truncated GZIP stream"))
|
||||
}
|
||||
override val inflateState = new Inflate(false) with Step
|
||||
startWith(ReadHeaders)
|
||||
|
||||
/** No bytes were received yet */
|
||||
case object Initial extends State {
|
||||
def onPush(data: ByteString, ctx: Context[ByteString]): SyncDirective =
|
||||
if (data.isEmpty) ctx.pull()
|
||||
else becomeWithRemaining(ReadHeaders, data, ctx)
|
||||
/** Reading the header bytes */
|
||||
case object ReadHeaders extends Step {
|
||||
override def parse(reader: ByteStringParser.ByteReader): ParseResult[ByteString] = {
|
||||
import reader._
|
||||
if (readByte() != 0x1F || readByte() != 0x8B) fail("Not in GZIP format") // check magic header
|
||||
if (readByte() != 8) fail("Unsupported GZIP compression method") // check compression method
|
||||
val flags = readByte()
|
||||
skip(6) // skip MTIME, XFL and OS fields
|
||||
if ((flags & 4) > 0) skip(readShortLE()) // skip optional extra fields
|
||||
if ((flags & 8) > 0) skipZeroTerminatedString() // skip optional file name
|
||||
if ((flags & 16) > 0) skipZeroTerminatedString() // skip optional file comment
|
||||
if ((flags & 2) > 0 && crc16(fromStartToHere) != readShortLE()) fail("Corrupt GZIP header")
|
||||
|
||||
override def onPull(ctx: Context[ByteString]): SyncDirective =
|
||||
if (ctx.isFinishing) {
|
||||
ctx.finish()
|
||||
} else super.onPull(ctx)
|
||||
}
|
||||
inflater.reset()
|
||||
crc32.reset()
|
||||
ParseResult(None, inflateState, false)
|
||||
}
|
||||
}
|
||||
var crc32: CRC32 = new CRC32
|
||||
private def fail(msg: String) = throw new ZipException(msg)
|
||||
|
||||
var crc32: CRC32 = new CRC32
|
||||
protected def afterInflate: State = ReadTrailer
|
||||
|
||||
/** Reading the header bytes */
|
||||
case object ReadHeaders extends ByteReadingState {
|
||||
def read(reader: ByteReader, ctx: Context[ByteString]): SyncDirective = {
|
||||
import reader._
|
||||
|
||||
if (readByte() != 0x1F || readByte() != 0x8B) fail("Not in GZIP format") // check magic header
|
||||
if (readByte() != 8) fail("Unsupported GZIP compression method") // check compression method
|
||||
val flags = readByte()
|
||||
skip(6) // skip MTIME, XFL and OS fields
|
||||
if ((flags & 4) > 0) skip(readShortLE()) // skip optional extra fields
|
||||
if ((flags & 8) > 0) skipZeroTerminatedString() // skip optional file name
|
||||
if ((flags & 16) > 0) skipZeroTerminatedString() // skip optional file comment
|
||||
if ((flags & 2) > 0 && crc16(fromStartToHere) != readShortLE()) fail("Corrupt GZIP header")
|
||||
|
||||
inflater.reset()
|
||||
crc32.reset()
|
||||
becomeWithRemaining(StartInflate, remainingData, ctx)
|
||||
/** Reading the trailer */
|
||||
case object ReadTrailer extends Step {
|
||||
override def parse(reader: ByteStringParser.ByteReader): ParseResult[ByteString] = {
|
||||
import reader._
|
||||
if (readIntLE() != crc32.getValue.toInt) fail("Corrupt data (CRC32 checksum error)")
|
||||
if (readIntLE() != inflater.getBytesWritten.toInt /* truncated to 32bit */ )
|
||||
fail("Corrupt GZIP trailer ISIZE")
|
||||
ParseResult(None, ReadHeaders, true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit =
|
||||
crc32.update(buffer, offset, length)
|
||||
|
||||
/** Reading the trailer */
|
||||
case object ReadTrailer extends ByteReadingState {
|
||||
def read(reader: ByteReader, ctx: Context[ByteString]): SyncDirective = {
|
||||
import reader._
|
||||
|
||||
if (readIntLE() != crc32.getValue.toInt) fail("Corrupt data (CRC32 checksum error)")
|
||||
if (readIntLE() != inflater.getBytesWritten.toInt /* truncated to 32bit */ ) fail("Corrupt GZIP trailer ISIZE")
|
||||
|
||||
becomeWithRemaining(Initial, remainingData, ctx)
|
||||
}
|
||||
}
|
||||
|
||||
override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = ctx.absorbTermination()
|
||||
|
||||
private def crc16(data: ByteString) = {
|
||||
val crc = new CRC32
|
||||
crc.update(data.toArray)
|
||||
crc.getValue.toInt & 0xFFFF
|
||||
}
|
||||
|
||||
override protected def onTruncation(ctx: Context[ByteString]): SyncDirective = ctx.fail(new ZipException("Truncated GZIP stream"))
|
||||
|
||||
private def fail(msg: String) = throw new ZipException(msg)
|
||||
}
|
||||
|
||||
/** INTERNAL API */
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue