From 29d7a041f6eb5a47d835040bc92ef3cd2748593e Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Wed, 10 Dec 2014 16:55:50 +0100 Subject: [PATCH] +htp #16516 rewrite Deflate/GzipDecompressor as StatefulStage to defuse zip bomb Also, the tests have been DRY'd up. --- .../scala/akka/http/util/ByteReader.scala | 44 ++++ .../http/util/ByteStringParserStage.scala | 58 +++++ .../scala/akka/http/util/StreamUtils.scala | 40 +++- .../akka/http/model/HttpEntitySpec.scala | 2 +- .../scala/akka/http/coding/CoderSpec.scala | 154 +++++++++++++ .../scala/akka/http/coding/DecoderSpec.scala | 15 +- .../scala/akka/http/coding/DeflateSpec.scala | 83 ++----- .../scala/akka/http/coding/GzipSpec.scala | 131 ++--------- .../scala/akka/http/coding/NoCodingSpec.scala | 16 ++ .../directives/CodingDirectivesSpec.scala | 2 +- .../main/scala/akka/http/coding/Decoder.scala | 50 +++-- .../main/scala/akka/http/coding/Deflate.scala | 90 +++++--- .../main/scala/akka/http/coding/Encoder.scala | 5 +- .../main/scala/akka/http/coding/Gzip.scala | 204 ++++++------------ .../scala/akka/http/coding/NoCoding.scala | 12 +- 15 files changed, 510 insertions(+), 396 deletions(-) create mode 100644 akka-http-core/src/main/scala/akka/http/util/ByteReader.scala create mode 100644 akka-http-core/src/main/scala/akka/http/util/ByteStringParserStage.scala create mode 100644 akka-http-tests/src/test/scala/akka/http/coding/CoderSpec.scala create mode 100644 akka-http-tests/src/test/scala/akka/http/coding/NoCodingSpec.scala diff --git a/akka-http-core/src/main/scala/akka/http/util/ByteReader.scala b/akka-http-core/src/main/scala/akka/http/util/ByteReader.scala new file mode 100644 index 0000000000..e68b867719 --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/util/ByteReader.scala @@ -0,0 +1,44 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.util + +import akka.util.ByteString + +import scala.util.control.NoStackTrace + +/** + * A helper class to read from a ByteString statefully. + * + * INTERNAL API + */ +private[akka] class ByteReader(input: ByteString) { + import ByteReader.NeedMoreData + + private[this] var off = 0 + + def currentOffset: Int = off + def remainingData: ByteString = input.drop(off) + def fromStartToHere: ByteString = input.take(currentOffset) + + def readByte(): Int = + if (off < input.length) { + val x = input(off) + off += 1 + x.toInt & 0xFF + } else throw NeedMoreData + def readShort(): Int = readByte() | (readByte() << 8) + def readInt(): Int = readShort() | (readShort() << 16) + def skip(numBytes: Int): Unit = + if (off + numBytes <= input.length) off += numBytes + else throw NeedMoreData + def skipZeroTerminatedString(): Unit = while (readByte() != 0) {} +} + +/* +* INTERNAL API +*/ +private[akka] object ByteReader { + val NeedMoreData = new Exception with NoStackTrace +} \ No newline at end of file diff --git a/akka-http-core/src/main/scala/akka/http/util/ByteStringParserStage.scala b/akka-http-core/src/main/scala/akka/http/util/ByteStringParserStage.scala new file mode 100644 index 0000000000..f030b8423b --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/util/ByteStringParserStage.scala @@ -0,0 +1,58 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.util + +import akka.stream.stage.{ Directive, Context, StatefulStage } +import akka.util.ByteString + +/** + * A helper class for writing parsers from ByteStrings. + * + * FIXME: move to akka.stream.io, https://github.com/akka/akka/issues/16529 + * + * INTERNAL API + */ +private[akka] abstract class ByteStringParserStage[Out] extends StatefulStage[ByteString, Out] { + protected def onTruncation(ctx: Context[Out]): Directive + + /** + * Derive a stage from [[IntermediateState]] and then call `pull(ctx)` instead of + * `ctx.pull()` to have truncation errors reported. + */ + abstract class IntermediateState extends State { + override def onPull(ctx: Context[Out]): Directive = pull(ctx) + def pull(ctx: Context[Out]): Directive = + if (ctx.isFinishing) onTruncation(ctx) + else ctx.pull() + } + + /** + * A stage that tries to read from a side-effecting [[ByteReader]]. If a buffer underrun + * occurs the previous data is saved and the reading process is restarted from the beginning + * once more data was received. + * + * As [[read]] may be called several times for the same prefix of data, make sure not to + * manipulate any state during reading from the ByteReader. + */ + trait ByteReadingState extends IntermediateState { + def read(reader: ByteReader, ctx: Context[Out]): Directive + + def onPush(data: ByteString, ctx: Context[Out]): Directive = + try { + val reader = new ByteReader(data) + read(reader, ctx) + } catch { + case ByteReader.NeedMoreData ⇒ + become(TryAgain(data, this)) + pull(ctx) + } + } + case class TryAgain(previousData: ByteString, byteReadingState: ByteReadingState) extends IntermediateState { + def onPush(data: ByteString, ctx: Context[Out]): Directive = { + become(byteReadingState) + byteReadingState.onPush(previousData ++ data, ctx) + } + } +} \ No newline at end of file diff --git a/akka-http-core/src/main/scala/akka/http/util/StreamUtils.scala b/akka-http-core/src/main/scala/akka/http/util/StreamUtils.scala index 10ec792f5e..c3c9b8e4cb 100644 --- a/akka-http-core/src/main/scala/akka/http/util/StreamUtils.scala +++ b/akka-http-core/src/main/scala/akka/http/util/StreamUtils.scala @@ -34,8 +34,8 @@ private[http] object StreamUtils { * Creates a transformer that will call `f` for each incoming ByteString and output its result. After the complete * input has been read it will call `finish` once to determine the final ByteString to post to the output. */ - def byteStringTransformer(f: ByteString ⇒ ByteString, finish: () ⇒ ByteString): Flow[ByteString, ByteString] = { - val transformer = new PushPullStage[ByteString, ByteString] { + def byteStringTransformer(f: ByteString ⇒ ByteString, finish: () ⇒ ByteString): Stage[ByteString, ByteString] = { + new PushPullStage[ByteString, ByteString] { override def onPush(element: ByteString, ctx: Context[ByteString]): Directive = ctx.push(f(element)) @@ -45,7 +45,6 @@ private[http] object StreamUtils { override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = ctx.absorbTermination() } - Flow[ByteString].section(name("transformBytes"))(_.transform(() ⇒ transformer)) } def failedPublisher[T](ex: Throwable): Publisher[T] = @@ -94,6 +93,41 @@ private[http] object StreamUtils { Flow[ByteString].section(name("sliceBytes"))(_.transform(() ⇒ transformer)) } + def limitByteChunksStage(maxBytesPerChunk: Int): Stage[ByteString, ByteString] = + new StatefulStage[ByteString, ByteString] { + def initial = WaitingForData + case object WaitingForData extends State { + def onPush(elem: ByteString, ctx: Context[ByteString]): Directive = + if (elem.size <= maxBytesPerChunk) ctx.push(elem) + else { + become(DeliveringData(elem.drop(maxBytesPerChunk))) + ctx.push(elem.take(maxBytesPerChunk)) + } + } + case class DeliveringData(remaining: ByteString) extends State { + def onPush(elem: ByteString, ctx: Context[ByteString]): Directive = + throw new IllegalStateException("Not expecting data") + + override def onPull(ctx: Context[ByteString]): Directive = { + val toPush = remaining.take(maxBytesPerChunk) + val toKeep = remaining.drop(maxBytesPerChunk) + + become { + if (toKeep.isEmpty) WaitingForData + else DeliveringData(toKeep) + } + if (ctx.isFinishing) ctx.pushAndFinish(toPush) + else ctx.push(toPush) + } + } + + override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = + current match { + case WaitingForData ⇒ ctx.finish() + case _: DeliveringData ⇒ ctx.absorbTermination() + } + } + /** * Applies a sequence of transformers on one source and returns a sequence of sources with the result. The input source * will only be traversed once. diff --git a/akka-http-core/src/test/scala/akka/http/model/HttpEntitySpec.scala b/akka-http-core/src/test/scala/akka/http/model/HttpEntitySpec.scala index 707ec1604b..9a0ce78c77 100644 --- a/akka-http-core/src/test/scala/akka/http/model/HttpEntitySpec.scala +++ b/akka-http-core/src/test/scala/akka/http/model/HttpEntitySpec.scala @@ -121,7 +121,7 @@ class HttpEntitySpec extends FreeSpec with MustMatchers with BeforeAndAfterAll { } def duplicateBytesTransformer(): Flow[ByteString, ByteString] = - StreamUtils.byteStringTransformer(doubleChars, () ⇒ trailer) + Flow[ByteString].transform(() ⇒ StreamUtils.byteStringTransformer(doubleChars, () ⇒ trailer)) def trailer: ByteString = ByteString("--dup") def doubleChars(bs: ByteString): ByteString = ByteString(bs.flatMap(b ⇒ Seq(b, b)): _*) diff --git a/akka-http-tests/src/test/scala/akka/http/coding/CoderSpec.scala b/akka-http-tests/src/test/scala/akka/http/coding/CoderSpec.scala new file mode 100644 index 0000000000..082036308d --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/coding/CoderSpec.scala @@ -0,0 +1,154 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.coding + +import org.scalatest.{ Inspectors, WordSpec } +import scala.annotation.tailrec + +import scala.concurrent.duration._ + +import java.io.{ OutputStream, InputStream, ByteArrayInputStream, ByteArrayOutputStream } +import java.util +import java.util.zip.DataFormatException + +import akka.http.util._ + +import akka.http.model.HttpMethods._ +import akka.http.model.{ HttpEntity, HttpRequest } +import akka.stream.scaladsl.Source +import akka.util.ByteString + +import scala.util.control.NoStackTrace + +abstract class CoderSpec extends WordSpec with CodecSpecSupport with Inspectors { + protected def Coder: Coder with StreamDecoder + protected def newDecodedInputStream(underlying: InputStream): InputStream + protected def newEncodedOutputStream(underlying: OutputStream): OutputStream + + case object AllDataAllowed extends Exception with NoStackTrace + protected def corruptInputMessage: Option[String] + + def extraTests(): Unit = {} + + s"The ${Coder.encoding.value} codec" should { + "properly encode a small string" in { + streamDecode(ourEncode(smallTextBytes)) should readAs(smallText) + } + "properly decode a small string" in { + ourDecode(streamEncode(smallTextBytes)) should readAs(smallText) + } + "properly round-trip encode/decode a small string" in { + ourDecode(ourEncode(smallTextBytes)) should readAs(smallText) + } + "properly encode a large string" in { + streamDecode(ourEncode(largeTextBytes)) should readAs(largeText) + } + "properly decode a large string" in { + ourDecode(streamEncode(largeTextBytes)) should readAs(largeText) + } + "properly round-trip encode/decode a large string" in { + ourDecode(ourEncode(largeTextBytes)) should readAs(largeText) + } + "properly round-trip encode/decode an HttpRequest" in { + val request = HttpRequest(POST, entity = HttpEntity(largeText)) + Coder.decode(Coder.encode(request)) should equal(request) + } + "throw an error on corrupt input" in { + corruptInputMessage foreach { message ⇒ + val ex = the[DataFormatException] thrownBy ourDecode(corruptContent) + ex.getMessage should equal(message) + } + } + "not throw an error if a subsequent block is corrupt" in { + pending // FIXME: should we read as long as possible and only then report an error, that seems somewhat arbitrary + ourDecode(Seq(encode("Hello,"), encode(" dear "), corruptContent).join) should readAs("Hello, dear ") + } + "decompress in very small chunks" in { + val compressed = encode("Hello") + + decodeChunks(Source(Vector(compressed.take(10), compressed.drop(10)))) should readAs("Hello") + } + "support chunked round-trip encoding/decoding" in { + val chunks = largeTextBytes.grouped(512).toVector + val comp = Coder.newCompressor + val compressedChunks = chunks.map { chunk ⇒ comp.compressAndFlush(chunk) } :+ comp.finish() + val uncompressed = Coder.decodeFromIterator(compressedChunks.iterator) + + uncompressed should readAs(largeText) + } + "works for any split in prefix + suffix" in { + val compressed = streamEncode(smallTextBytes) + def tryWithPrefixOfSize(prefixSize: Int): Unit = { + val prefix = compressed.take(prefixSize) + val suffix = compressed.drop(prefixSize) + + decodeChunks(Source(prefix :: suffix :: Nil)) should readAs(smallText) + } + (0 to compressed.size).foreach(tryWithPrefixOfSize) + } + "works for chunked compressed data of sizes just above 1024" in { + val comp = Coder.newCompressor + val inputBytes = ByteString("""{"baseServiceURL":"http://www.acme.com","endpoints":{"assetSearchURL":"/search","showsURL":"/shows","mediaContainerDetailURL":"/container","featuredTapeURL":"/tape","assetDetailURL":"/asset","moviesURL":"/movies","recentlyAddedURL":"/recent","topicsURL":"/topics","scheduleURL":"/schedule"},"urls":{"aboutAweURL":"www.foobar.com"},"channelName":"Cool Stuff","networkId":"netId","slotProfile":"slot_1","brag":{"launchesUntilPrompt":10,"daysUntilPrompt":5,"launchesUntilReminder":5,"daysUntilReminder":2},"feedbackEmailAddress":"feedback@acme.com","feedbackEmailSubject":"Commends from User","splashSponsor":[],"adProvider":{"adProviderProfile":"","adProviderProfileAndroid":"","adProviderNetworkID":0,"adProviderSiteSectionNetworkID":0,"adProviderVideoAssetNetworkID":0,"adProviderSiteSectionCustomID":{},"adProviderServerURL":"","adProviderLiveVideoAssetID":""},"update":[{"forPlatform":"ios","store":{"iTunes":"www.something.com"},"minVer":"1.2.3","notificationVer":"1.2.5"},{"forPlatform":"android","store":{"amazon":"www.something.com","play":"www.something.com"},"minVer":"1.2.3","notificationVer":"1.2.5"}],"tvRatingPolicies":[{"type":"sometype","imageKey":"tv_rating_small","durationMS":15000,"precedence":1},{"type":"someothertype","imageKey":"tv_rating_big","durationMS":15000,"precedence":2}],"exts":{"adConfig":{"globals":{"#{adNetworkID}":"2620","#{ssid}":"usa_tveapp"},"iPad":{"showlist":{"adMobAdUnitID":"/2620/usa_tveapp_ipad/shows","adSize":[{"#{height}":90,"#{width}":728}]},"launch":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_ipad&sz=1x1&t=&c=#{doubleclickrandom}"},"watchwithshowtile":{"adMobAdUnitID":"/2620/usa_tveapp_ipad/watchwithshowtile","adSize":[{"#{height}":120,"#{width}":240}]},"showpage":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_ipad/shows/#{SHOW_NAME}&sz=1x1&t=&c=#{doubleclickrandom}"}},"iPadRetina":{"showlist":{"adMobAdUnitID":"/2620/usa_tveapp_ipad/shows","adSize":[{"#{height}":90,"#{width}":728}]},"launch":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_ipad&sz=1x1&t=&c=#{doubleclickrandom}"},"watchwithshowtile":{"adMobAdUnitID":"/2620/usa_tveapp_ipad/watchwithshowtile","adSize":[{"#{height}":120,"#{width}":240}]},"showpage":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_ipad/shows/#{SHOW_NAME}&sz=1x1&t=&c=#{doubleclickrandom}"}},"iPhone":{"home":{"adMobAdUnitID":"/2620/usa_tveapp_iphone/home","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"showlist":{"adMobAdUnitID":"/2620/usa_tveapp_iphone/shows","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"episodepage":{"adMobAdUnitID":"/2620/usa_tveapp_iphone/shows/#{SHOW_NAME}","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"launch":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_iphone&sz=1x1&t=&c=#{doubleclickrandom}"},"showpage":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_iphone/shows/#{SHOW_NAME}&sz=1x1&t=&c=#{doubleclickrandom}"}},"iPhoneRetina":{"home":{"adMobAdUnitID":"/2620/usa_tveapp_iphone/home","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"showlist":{"adMobAdUnitID":"/2620/usa_tveapp_iphone/shows","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"episodepage":{"adMobAdUnitID":"/2620/usa_tveapp_iphone/shows/#{SHOW_NAME}","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"launch":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_iphone&sz=1x1&t=&c=#{doubleclickrandom}"},"showpage":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_iphone/shows/#{SHOW_NAME}&sz=1x1&t=&c=#{doubleclickrandom}"}},"Tablet":{"home":{"adMobAdUnitID":"/2620/usa_tveapp_androidtab/home","adSize":[{"#{height}":90,"#{width}":728},{"#{height}":50,"#{width}":320},{"#{height}":50,"#{width}":300}]},"showlist":{"adMobAdUnitID":"/2620/usa_tveapp_androidtab/shows","adSize":[{"#{height}":90,"#{width}":728},{"#{height}":50,"#{width}":320},{"#{height}":50,"#{width}":300}]},"episodepage":{"adMobAdUnitID":"/2620/usa_tveapp_androidtab/shows/#{SHOW_NAME}","adSize":[{"#{height}":90,"#{width}":728},{"#{height}":50,"#{width}":320},{"#{height}":50,"#{width}":300}]},"launch":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_androidtab&sz=1x1&t=&c=#{doubleclickrandom}"},"showpage":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_androidtab/shows/#{SHOW_NAME}&sz=1x1&t=&c=#{doubleclickrandom}"}},"TabletHD":{"home":{"adMobAdUnitID":"/2620/usa_tveapp_androidtab/home","adSize":[{"#{height}":90,"#{width}":728},{"#{height}":50,"#{width}":320},{"#{height}":50,"#{width}":300}]},"showlist":{"adMobAdUnitID":"/2620/usa_tveapp_androidtab/shows","adSize":[{"#{height}":90,"#{width}":728},{"#{height}":50,"#{width}":320},{"#{height}":50,"#{width}":300}]},"episodepage":{"adMobAdUnitID":"/2620/usa_tveapp_androidtab/shows/#{SHOW_NAME}","adSize":[{"#{height}":90,"#{width}":728},{"#{height}":50,"#{width}":320},{"#{height}":50,"#{width}":300}]},"launch":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_androidtab&sz=1x1&t=&c=#{doubleclickrandom}"},"showpage":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_androidtab/shows/#{SHOW_NAME}&sz=1x1&t=&c=#{doubleclickrandom}"}},"Phone":{"home":{"adMobAdUnitID":"/2620/usa_tveapp_android/home","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"showlist":{"adMobAdUnitID":"/2620/usa_tveapp_android/shows","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"episodepage":{"adMobAdUnitID":"/2620/usa_tveapp_android/shows/#{SHOW_NAME}","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"launch":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_android&sz=1x1&t=&c=#{doubleclickrandom}"},"showpage":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_android/shows/#{SHOW_NAME}&sz=1x1&t=&c=#{doubleclickrandom}"}},"PhoneHD":{"home":{"adMobAdUnitID":"/2620/usa_tveapp_android/home","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"showlist":{"adMobAdUnitID":"/2620/usa_tveapp_android/shows","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"episodepage":{"adMobAdUnitID":"/2620/usa_tveapp_android/shows/#{SHOW_NAME}","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"launch":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_android&sz=1x1&t=&c=#{doubleclickrandom}"},"showpage":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_android/shows/#{SHOW_NAME}&sz=1x1&t=&c=#{doubleclickrandom}"}}}}}""", "utf8") + val compressed = comp.compressAndFinish(inputBytes) + + ourDecode(compressed) should equal(inputBytes) + } + + extraTests() + + "shouldn't produce huge ByteStrings for some input" in { + val array = new Array[Byte](10) // FIXME + util.Arrays.fill(array, 1.toByte) + val compressed = streamEncode(ByteString(array)) + val limit = 10000 + val resultBs = + Source.singleton(compressed) + .via(Coder.withMaxBytesPerChunk(limit).decoderFlow) + .collectAll + .awaitResult(1.second) + + forAll(resultBs) { bs ⇒ + bs.length should be < limit + bs.forall(_ == 1) should equal(true) + } + } + } + + def encode(s: String) = ourEncode(ByteString(s, "UTF8")) + def ourEncode(bytes: ByteString): ByteString = Coder.encode(bytes) + def ourDecode(bytes: ByteString): ByteString = Coder.decode(bytes) + + lazy val corruptContent = { + val content = encode(largeText).toArray + content(14) = 26.toByte + ByteString(content) + } + + def streamEncode(bytes: ByteString): ByteString = { + val output = new ByteArrayOutputStream() + val gos = newEncodedOutputStream(output); gos.write(bytes.toArray); gos.close() + ByteString(output.toByteArray) + } + + def streamDecode(bytes: ByteString): ByteString = { + val output = new ByteArrayOutputStream() + val input = newDecodedInputStream(new ByteArrayInputStream(bytes.toArray)) + + val buffer = new Array[Byte](500) + @tailrec def copy(from: InputStream, to: OutputStream): Unit = { + val read = from.read(buffer) + if (read >= 0) { + to.write(buffer, 0, read) + copy(from, to) + } + } + + copy(input, output) + ByteString(output.toByteArray) + } + + def decodeChunks(input: Source[ByteString]): ByteString = + input.via(Coder.decoderFlow).join.awaitResult(3.seconds) +} diff --git a/akka-http-tests/src/test/scala/akka/http/coding/DecoderSpec.scala b/akka-http-tests/src/test/scala/akka/http/coding/DecoderSpec.scala index cbbf3a056e..9db4bbc2a2 100644 --- a/akka-http-tests/src/test/scala/akka/http/coding/DecoderSpec.scala +++ b/akka-http-tests/src/test/scala/akka/http/coding/DecoderSpec.scala @@ -4,6 +4,7 @@ package akka.http.coding +import akka.stream.stage.{ Directive, Context, PushStage, Stage } import akka.util.ByteString import org.scalatest.WordSpec import akka.http.model._ @@ -26,15 +27,15 @@ class DecoderSpec extends WordSpec with CodecSpecSupport { } def dummyDecompress(s: String): String = dummyDecompress(ByteString(s, "UTF8")).decodeString("UTF8") - def dummyDecompress(bytes: ByteString): ByteString = DummyDecompressor.decompress(bytes) + def dummyDecompress(bytes: ByteString): ByteString = DummyDecoder.decode(bytes) - case object DummyDecoder extends Decoder { + case object DummyDecoder extends StreamDecoder { val encoding = HttpEncodings.compress - def newDecompressor = DummyDecompressor - } - case object DummyDecompressor extends Decompressor { - def decompress(buffer: ByteString): ByteString = buffer ++ ByteString("compressed") - def finish(): ByteString = ByteString.empty + def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ Stage[ByteString, ByteString] = + () ⇒ new PushStage[ByteString, ByteString] { + def onPush(elem: ByteString, ctx: Context[ByteString]): Directive = + ctx.push(elem ++ ByteString("compressed")) + } } } diff --git a/akka-http-tests/src/test/scala/akka/http/coding/DeflateSpec.scala b/akka-http-tests/src/test/scala/akka/http/coding/DeflateSpec.scala index d0ea56de8f..4a88984360 100644 --- a/akka-http-tests/src/test/scala/akka/http/coding/DeflateSpec.scala +++ b/akka-http-tests/src/test/scala/akka/http/coding/DeflateSpec.scala @@ -5,78 +5,25 @@ package akka.http.coding import akka.util.ByteString -import akka.http.util._ -import org.scalatest.WordSpec -import akka.http.model._ -import HttpMethods.POST -import java.io.ByteArrayOutputStream -import java.util.zip.{ DeflaterOutputStream, InflaterOutputStream } +import java.io.{ InputStream, OutputStream } +import java.util.zip._ -class DeflateSpec extends WordSpec with CodecSpecSupport { +class DeflateSpec extends CoderSpec { + protected def Coder: Coder with StreamDecoder = Deflate - "The Deflate codec" should { - "properly encode a small string" in { - streamInflate(ourDeflate(smallTextBytes)) should readAs(smallText) - } - "properly decode a small string" in { - ourInflate(streamDeflate(smallTextBytes)) should readAs(smallText) - } - "properly round-trip encode/decode a small string" in { - ourInflate(ourDeflate(smallTextBytes)) should readAs(smallText) - } - "properly encode a large string" in { - streamInflate(ourDeflate(largeTextBytes)) should readAs(largeText) - } - "properly decode a large string" in { - ourInflate(streamDeflate(largeTextBytes)) should readAs(largeText) - } - "properly round-trip encode/decode a large string" in { - ourInflate(ourDeflate(largeTextBytes)) should readAs(largeText) - } - "properly round-trip encode/decode an HttpRequest" in { - val request = HttpRequest(POST, entity = HttpEntity(largeText)) - Deflate.decode(Deflate.encode(request)) should equal(request) - } - "provide a better compression ratio than the standard Deflater/Inflater streams" in { - ourDeflate(largeTextBytes).length should be < streamDeflate(largeTextBytes).length - } - "support chunked round-trip encoding/decoding" in { - val chunks = largeTextBytes.grouped(512).toVector - val comp = Deflate.newCompressor - val decomp = Deflate.newDecompressor - val chunks2 = - chunks.map { chunk ⇒ - decomp.decompress(comp.compressAndFlush(chunk)) - } :+ - decomp.decompress(comp.finish()) - chunks2.join should readAs(largeText) - } - "works for any split in prefix + suffix" in { - val compressed = streamDeflate(smallTextBytes) - def tryWithPrefixOfSize(prefixSize: Int): Unit = { - val decomp = Deflate.newDecompressor - val prefix = compressed.take(prefixSize) - val suffix = compressed.drop(prefixSize) + protected def newDecodedInputStream(underlying: InputStream): InputStream = + new InflaterInputStream(underlying) - decomp.decompress(prefix) ++ decomp.decompress(suffix) should readAs(smallText) - } - (0 to compressed.size).foreach(tryWithPrefixOfSize) + protected def newEncodedOutputStream(underlying: OutputStream): OutputStream = + new DeflaterOutputStream(underlying) + + protected def corruptInputMessage: Option[String] = Some("invalid code -- missing end-of-block") + + override def extraTests(): Unit = { + "throw early if header is corrupt" in { + val ex = the[DataFormatException] thrownBy ourDecode(ByteString(0, 1, 2, 3, 4)) + ex.getMessage should equal("incorrect header check") } } - - def ourDeflate(bytes: ByteString): ByteString = Deflate.encode(bytes) - def ourInflate(bytes: ByteString): ByteString = Deflate.decode(bytes) - - def streamDeflate(bytes: ByteString) = { - val output = new ByteArrayOutputStream() - val dos = new DeflaterOutputStream(output); dos.write(bytes.toArray); dos.close() - ByteString(output.toByteArray) - } - - def streamInflate(bytes: ByteString) = { - val output = new ByteArrayOutputStream() - val ios = new InflaterOutputStream(output); ios.write(bytes.toArray); ios.close() - ByteString(output.toByteArray) - } } diff --git a/akka-http-tests/src/test/scala/akka/http/coding/GzipSpec.scala b/akka-http-tests/src/test/scala/akka/http/coding/GzipSpec.scala index 994cfb795a..9f2b5658f3 100644 --- a/akka-http-tests/src/test/scala/akka/http/coding/GzipSpec.scala +++ b/akka-http-tests/src/test/scala/akka/http/coding/GzipSpec.scala @@ -4,135 +4,38 @@ package akka.http.coding -import org.scalatest.WordSpec -import akka.http.model._ import akka.http.util._ -import HttpMethods.POST -import java.io.{ InputStream, OutputStream, ByteArrayInputStream, ByteArrayOutputStream } -import java.util.zip.{ ZipException, DataFormatException, GZIPInputStream, GZIPOutputStream } +import java.io.{ InputStream, OutputStream } +import java.util.zip.{ ZipException, GZIPInputStream, GZIPOutputStream } import akka.util.ByteString -import scala.annotation.tailrec +class GzipSpec extends CoderSpec { + protected def Coder: Coder with StreamDecoder = Gzip -class GzipSpec extends WordSpec with CodecSpecSupport { + protected def newDecodedInputStream(underlying: InputStream): InputStream = + new GZIPInputStream(underlying) - "The Gzip codec" should { - "properly encode a small string" in { - streamGunzip(ourGzip(smallTextBytes)) should readAs(smallText) - } - "properly decode a small string" in { - ourGunzip(streamGzip(smallTextBytes)) should readAs(smallText) - } - "properly round-trip encode/decode a small string" in { - ourGunzip(ourGzip(smallTextBytes)) should readAs(smallText) - } - "properly encode a large string" in { - streamGunzip(ourGzip(largeTextBytes)) should readAs(largeText) - } - "properly decode a large string" in { - ourGunzip(streamGzip(largeTextBytes)) should readAs(largeText) - } - "properly round-trip encode/decode a large string" in { - ourGunzip(ourGzip(largeTextBytes)) should readAs(largeText) - } - "properly round-trip encode/decode an HttpRequest" in { - val request = HttpRequest(POST, entity = HttpEntity(largeText)) - Gzip.decode(Gzip.encode(request)) should equal(request) + protected def newEncodedOutputStream(underlying: OutputStream): OutputStream = + new GZIPOutputStream(underlying) + + protected def corruptInputMessage: Option[String] = Some("invalid code lengths set") + + override def extraTests(): Unit = { + "decode concatenated compressions" in { + ourDecode(Seq(encode("Hello, "), encode("dear "), encode("User!")).join) should readAs("Hello, dear User!") } "provide a better compression ratio than the standard Gzip/Gunzip streams" in { - ourGzip(largeTextBytes).length should be < streamGzip(largeTextBytes).length - } - "properly decode concatenated compressions" in { - ourGunzip(Seq(gzip("Hello,"), gzip(" dear "), gzip("User!")).join) should readAs("Hello, dear User!") - } - "throw an error on corrupt input" in { - val ex = the[DataFormatException] thrownBy ourGunzip(corruptGzipContent) - ex.getMessage should equal("invalid literal/length code") + ourEncode(largeTextBytes).length should be < streamEncode(largeTextBytes).length } "throw an error on truncated input" in { - val ex = the[ZipException] thrownBy ourGunzip(streamGzip(smallTextBytes).dropRight(5)) + val ex = the[ZipException] thrownBy ourDecode(streamEncode(smallTextBytes).dropRight(5)) ex.getMessage should equal("Truncated GZIP stream") } "throw early if header is corrupt" in { - val ex = the[ZipException] thrownBy ourGunzip(ByteString(0, 1, 2, 3, 4)) + val ex = the[ZipException] thrownBy ourDecode(ByteString(0, 1, 2, 3, 4)) ex.getMessage should equal("Not in GZIP format") } - "not throw an error if a subsequent block is corrupt" in { - pending // FIXME: should we read as long as possible and only then report an error, that seems somewhat arbitrary - ourGunzip(Seq(gzip("Hello,"), gzip(" dear "), corruptGzipContent).join) should readAs("Hello, dear ") - } - "decompress in very small chunks" in { - val compressed = gzip("Hello") - val decomp = Gzip.newDecompressor - val result = decomp.decompress(compressed.take(10)) // just the headers - result.size should equal(0) - val data = decomp.decompress(compressed.drop(10)) // the rest - data should readAs("Hello") - } - "support chunked round-trip encoding/decoding" in { - val chunks = largeTextBytes.grouped(512).toVector - val comp = Gzip.newCompressor - val decomp = Gzip.newDecompressor - val chunks2 = - chunks.map { chunk ⇒ decomp.decompress(comp.compressAndFlush(chunk)) } :+ decomp.decompress(comp.finish()) - chunks2.join should readAs(largeText) - } - "works for any split in prefix + suffix" in { - val compressed = streamGzip(smallTextBytes) - def tryWithPrefixOfSize(prefixSize: Int): Unit = { - val decomp = Gzip.newDecompressor - val prefix = compressed.take(prefixSize) - val suffix = compressed.drop(prefixSize) - - decomp.decompress(prefix) ++ decomp.decompress(suffix) should readAs(smallText) - } - (0 to compressed.size).foreach(tryWithPrefixOfSize) - } - "works for chunked compressed data of sizes just above 1024" in { - val comp = new GzipCompressor - val decomp = new GzipDecompressor - - val inputBytes = ByteString("""{"baseServiceURL":"http://www.acme.com","endpoints":{"assetSearchURL":"/search","showsURL":"/shows","mediaContainerDetailURL":"/container","featuredTapeURL":"/tape","assetDetailURL":"/asset","moviesURL":"/movies","recentlyAddedURL":"/recent","topicsURL":"/topics","scheduleURL":"/schedule"},"urls":{"aboutAweURL":"www.foobar.com"},"channelName":"Cool Stuff","networkId":"netId","slotProfile":"slot_1","brag":{"launchesUntilPrompt":10,"daysUntilPrompt":5,"launchesUntilReminder":5,"daysUntilReminder":2},"feedbackEmailAddress":"feedback@acme.com","feedbackEmailSubject":"Commends from User","splashSponsor":[],"adProvider":{"adProviderProfile":"","adProviderProfileAndroid":"","adProviderNetworkID":0,"adProviderSiteSectionNetworkID":0,"adProviderVideoAssetNetworkID":0,"adProviderSiteSectionCustomID":{},"adProviderServerURL":"","adProviderLiveVideoAssetID":""},"update":[{"forPlatform":"ios","store":{"iTunes":"www.something.com"},"minVer":"1.2.3","notificationVer":"1.2.5"},{"forPlatform":"android","store":{"amazon":"www.something.com","play":"www.something.com"},"minVer":"1.2.3","notificationVer":"1.2.5"}],"tvRatingPolicies":[{"type":"sometype","imageKey":"tv_rating_small","durationMS":15000,"precedence":1},{"type":"someothertype","imageKey":"tv_rating_big","durationMS":15000,"precedence":2}],"exts":{"adConfig":{"globals":{"#{adNetworkID}":"2620","#{ssid}":"usa_tveapp"},"iPad":{"showlist":{"adMobAdUnitID":"/2620/usa_tveapp_ipad/shows","adSize":[{"#{height}":90,"#{width}":728}]},"launch":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_ipad&sz=1x1&t=&c=#{doubleclickrandom}"},"watchwithshowtile":{"adMobAdUnitID":"/2620/usa_tveapp_ipad/watchwithshowtile","adSize":[{"#{height}":120,"#{width}":240}]},"showpage":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_ipad/shows/#{SHOW_NAME}&sz=1x1&t=&c=#{doubleclickrandom}"}},"iPadRetina":{"showlist":{"adMobAdUnitID":"/2620/usa_tveapp_ipad/shows","adSize":[{"#{height}":90,"#{width}":728}]},"launch":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_ipad&sz=1x1&t=&c=#{doubleclickrandom}"},"watchwithshowtile":{"adMobAdUnitID":"/2620/usa_tveapp_ipad/watchwithshowtile","adSize":[{"#{height}":120,"#{width}":240}]},"showpage":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_ipad/shows/#{SHOW_NAME}&sz=1x1&t=&c=#{doubleclickrandom}"}},"iPhone":{"home":{"adMobAdUnitID":"/2620/usa_tveapp_iphone/home","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"showlist":{"adMobAdUnitID":"/2620/usa_tveapp_iphone/shows","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"episodepage":{"adMobAdUnitID":"/2620/usa_tveapp_iphone/shows/#{SHOW_NAME}","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"launch":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_iphone&sz=1x1&t=&c=#{doubleclickrandom}"},"showpage":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_iphone/shows/#{SHOW_NAME}&sz=1x1&t=&c=#{doubleclickrandom}"}},"iPhoneRetina":{"home":{"adMobAdUnitID":"/2620/usa_tveapp_iphone/home","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"showlist":{"adMobAdUnitID":"/2620/usa_tveapp_iphone/shows","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"episodepage":{"adMobAdUnitID":"/2620/usa_tveapp_iphone/shows/#{SHOW_NAME}","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"launch":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_iphone&sz=1x1&t=&c=#{doubleclickrandom}"},"showpage":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_iphone/shows/#{SHOW_NAME}&sz=1x1&t=&c=#{doubleclickrandom}"}},"Tablet":{"home":{"adMobAdUnitID":"/2620/usa_tveapp_androidtab/home","adSize":[{"#{height}":90,"#{width}":728},{"#{height}":50,"#{width}":320},{"#{height}":50,"#{width}":300}]},"showlist":{"adMobAdUnitID":"/2620/usa_tveapp_androidtab/shows","adSize":[{"#{height}":90,"#{width}":728},{"#{height}":50,"#{width}":320},{"#{height}":50,"#{width}":300}]},"episodepage":{"adMobAdUnitID":"/2620/usa_tveapp_androidtab/shows/#{SHOW_NAME}","adSize":[{"#{height}":90,"#{width}":728},{"#{height}":50,"#{width}":320},{"#{height}":50,"#{width}":300}]},"launch":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_androidtab&sz=1x1&t=&c=#{doubleclickrandom}"},"showpage":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_androidtab/shows/#{SHOW_NAME}&sz=1x1&t=&c=#{doubleclickrandom}"}},"TabletHD":{"home":{"adMobAdUnitID":"/2620/usa_tveapp_androidtab/home","adSize":[{"#{height}":90,"#{width}":728},{"#{height}":50,"#{width}":320},{"#{height}":50,"#{width}":300}]},"showlist":{"adMobAdUnitID":"/2620/usa_tveapp_androidtab/shows","adSize":[{"#{height}":90,"#{width}":728},{"#{height}":50,"#{width}":320},{"#{height}":50,"#{width}":300}]},"episodepage":{"adMobAdUnitID":"/2620/usa_tveapp_androidtab/shows/#{SHOW_NAME}","adSize":[{"#{height}":90,"#{width}":728},{"#{height}":50,"#{width}":320},{"#{height}":50,"#{width}":300}]},"launch":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_androidtab&sz=1x1&t=&c=#{doubleclickrandom}"},"showpage":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_androidtab/shows/#{SHOW_NAME}&sz=1x1&t=&c=#{doubleclickrandom}"}},"Phone":{"home":{"adMobAdUnitID":"/2620/usa_tveapp_android/home","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"showlist":{"adMobAdUnitID":"/2620/usa_tveapp_android/shows","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"episodepage":{"adMobAdUnitID":"/2620/usa_tveapp_android/shows/#{SHOW_NAME}","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"launch":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_android&sz=1x1&t=&c=#{doubleclickrandom}"},"showpage":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_android/shows/#{SHOW_NAME}&sz=1x1&t=&c=#{doubleclickrandom}"}},"PhoneHD":{"home":{"adMobAdUnitID":"/2620/usa_tveapp_android/home","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"showlist":{"adMobAdUnitID":"/2620/usa_tveapp_android/shows","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"episodepage":{"adMobAdUnitID":"/2620/usa_tveapp_android/shows/#{SHOW_NAME}","adSize":[{"#{height}":50,"#{width}":300},{"#{height}":50,"#{width}":320}]},"launch":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_android&sz=1x1&t=&c=#{doubleclickrandom}"},"showpage":{"doubleClickCallbackURL":"http://pubads.g.doubleclick.net/gampad/ad?iu=/2620/usa_tveapp_android/shows/#{SHOW_NAME}&sz=1x1&t=&c=#{doubleclickrandom}"}}}}}""", "utf8") - val compressed = comp.compressAndFinish(inputBytes) - - val decompressed = decomp.decompress(compressed) - decompressed should equal(inputBytes) - } } - - def gzip(s: String) = ourGzip(ByteString(s, "UTF8")) - def ourGzip(bytes: ByteString): ByteString = Gzip.encode(bytes) - def ourGunzip(bytes: ByteString): ByteString = Gzip.decode(bytes) - - lazy val corruptGzipContent = { - val content = gzip("Hello").toArray - content(14) = 26.toByte - ByteString(content) - } - - def streamGzip(bytes: ByteString): ByteString = { - val output = new ByteArrayOutputStream() - val gos = new GZIPOutputStream(output); gos.write(bytes.toArray); gos.close() - ByteString(output.toByteArray) - } - - def streamGunzip(bytes: ByteString): ByteString = { - val output = new ByteArrayOutputStream() - val input = new GZIPInputStream(new ByteArrayInputStream(bytes.toArray)) - - val buffer = new Array[Byte](500) - @tailrec def copy(from: InputStream, to: OutputStream): Unit = { - val read = from.read(buffer) - if (read >= 0) { - to.write(buffer, 0, read) - copy(from, to) - } - } - - copy(input, output) - ByteString(output.toByteArray) - } - } diff --git a/akka-http-tests/src/test/scala/akka/http/coding/NoCodingSpec.scala b/akka-http-tests/src/test/scala/akka/http/coding/NoCodingSpec.scala new file mode 100644 index 0000000000..139f88c3bc --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/coding/NoCodingSpec.scala @@ -0,0 +1,16 @@ +/* + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.http.coding + +import java.io.{ OutputStream, InputStream } + +class NoCodingSpec extends CoderSpec { + protected def Coder: Coder with StreamDecoder = NoCoding + + protected def corruptInputMessage: Option[String] = None // all input data is valid + + protected def newEncodedOutputStream(underlying: OutputStream): OutputStream = underlying + protected def newDecodedInputStream(underlying: InputStream): InputStream = underlying +} diff --git a/akka-http-tests/src/test/scala/akka/http/server/directives/CodingDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/server/directives/CodingDirectivesSpec.scala index 1d92623e38..8b75201f99 100644 --- a/akka-http-tests/src/test/scala/akka/http/server/directives/CodingDirectivesSpec.scala +++ b/akka-http-tests/src/test/scala/akka/http/server/directives/CodingDirectivesSpec.scala @@ -166,7 +166,7 @@ class CodingDirectivesSpec extends RoutingSpec { response should haveContentEncoding(gzip) chunks.size shouldEqual (11 + 1) // 11 regular + the last one val bytes = chunks.foldLeft(ByteString.empty)(_ ++ _.data) - Gzip.newDecompressor.decompress(bytes) should readAs(text) + Gzip.decode(bytes) should readAs(text) } } } diff --git a/akka-http/src/main/scala/akka/http/coding/Decoder.scala b/akka-http/src/main/scala/akka/http/coding/Decoder.scala index 75c1da3ce2..ee28ccbaba 100644 --- a/akka-http/src/main/scala/akka/http/coding/Decoder.scala +++ b/akka-http/src/main/scala/akka/http/coding/Decoder.scala @@ -6,6 +6,7 @@ package akka.http.coding import akka.http.model._ import akka.http.util.StreamUtils +import akka.stream.stage.Stage import akka.util.ByteString import headers.HttpEncoding import akka.stream.scaladsl.Flow @@ -18,31 +19,38 @@ trait Decoder { decodeData(message).withHeaders(message.headers filterNot Encoder.isContentEncodingHeader) else message.self - def decodeData[T](t: T)(implicit mapper: DataMapper[T]): T = - mapper.transformDataBytes(t, newDecodeTransfomer) + def decodeData[T](t: T)(implicit mapper: DataMapper[T]): T = mapper.transformDataBytes(t, decoderFlow) - def decode(input: ByteString): ByteString = newDecompressor.decompressAndFinish(input) + def maxBytesPerChunk: Int + def withMaxBytesPerChunk(maxBytesPerChunk: Int): Decoder - def newDecompressor: Decompressor - - def newDecodeTransfomer(): Flow[ByteString, ByteString] = { - val decompressor = newDecompressor - - def decodeChunk(bytes: ByteString): ByteString = decompressor.decompress(bytes) - def finish(): ByteString = decompressor.finish() - - StreamUtils.byteStringTransformer(decodeChunk, finish) - } + def decoderFlow: Flow[ByteString, ByteString] + def decode(input: ByteString): ByteString +} +object Decoder { + val MaxBytesPerChunkDefault: Int = 65536 } -/** A stateful object representing ongoing decompression. */ -abstract class Decompressor { - /** Decompress the buffer and return decompressed data. */ - def decompress(input: ByteString): ByteString +/** A decoder that is implemented in terms of a [[Stage]] */ +trait StreamDecoder extends Decoder { outer ⇒ + protected def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ Stage[ByteString, ByteString] - /** Flushes potential remaining data from any internal buffers and may report on truncation errors */ - def finish(): ByteString + def maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault + def withMaxBytesPerChunk(newMaxBytesPerChunk: Int): Decoder = + new StreamDecoder { + def encoding: HttpEncoding = outer.encoding + override def maxBytesPerChunk: Int = newMaxBytesPerChunk - /** Combines decompress and finish */ - def decompressAndFinish(input: ByteString): ByteString = decompress(input) ++ finish() + def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ Stage[ByteString, ByteString] = + outer.newDecompressorStage(maxBytesPerChunk) + } + + def decoderFlow: Flow[ByteString, ByteString] = + Flow[ByteString].transform(newDecompressorStage(maxBytesPerChunk)) + + def decode(input: ByteString): ByteString = decodeWithLimits(input) + def decodeWithLimits(input: ByteString, maxBytesSize: Int = Int.MaxValue, maxIterations: Int = 1000): ByteString = + StreamUtils.runStrict(input, decoderFlow, maxBytesSize, maxIterations).get.get + def decodeFromIterator(input: Iterator[ByteString], maxBytesSize: Int = Int.MaxValue, maxIterations: Int = 1000): ByteString = + StreamUtils.runStrict(input, decoderFlow, maxBytesSize, maxIterations).get.get } diff --git a/akka-http/src/main/scala/akka/http/coding/Deflate.scala b/akka-http/src/main/scala/akka/http/coding/Deflate.scala index 5ef8fdad4c..785cd507e2 100644 --- a/akka-http/src/main/scala/akka/http/coding/Deflate.scala +++ b/akka-http/src/main/scala/akka/http/coding/Deflate.scala @@ -4,27 +4,21 @@ package akka.http.coding -import java.io.OutputStream -import java.util.zip.{ DataFormatException, ZipException, Inflater, Deflater } +import java.util.zip.{ Inflater, Deflater } +import akka.stream.stage._ import akka.util.{ ByteStringBuilder, ByteString } import scala.annotation.tailrec import akka.http.util._ import akka.http.model._ -import headers.HttpEncodings +import akka.http.model.headers.HttpEncodings -class Deflate(val messageFilter: HttpMessage ⇒ Boolean) extends Coder { +class Deflate(val messageFilter: HttpMessage ⇒ Boolean) extends Coder with StreamDecoder { val encoding = HttpEncodings.deflate def newCompressor = new DeflateCompressor - def newDecompressor = new DeflateDecompressor -} - -/** - * An encoder and decoder for the HTTP 'deflate' encoding. - */ -object Deflate extends Deflate(Encoder.DefaultFilter) { - def apply(messageFilter: HttpMessage ⇒ Boolean) = new Deflate(messageFilter) + def newDecompressorStage(maxBytesPerChunk: Int) = () ⇒ new DeflateDecompressor(maxBytesPerChunk) } +object Deflate extends Deflate(Encoder.DefaultFilter) class DeflateCompressor extends Compressor { protected lazy val deflater = new Deflater(Deflater.BEST_COMPRESSION, false) @@ -88,27 +82,57 @@ class DeflateCompressor extends Compressor { new Array[Byte](size) } -class DeflateDecompressor extends Decompressor { - protected lazy val inflater = new Inflater() +class DeflateDecompressor(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends DeflateDecompressorBase(maxBytesPerChunk) { + protected def createInflater() = new Inflater() - def decompress(input: ByteString): ByteString = - try { - inflater.setInput(input.toArray) - drain(new Array[Byte](input.length * 2)) - } catch { - case e: DataFormatException ⇒ - throw new ZipException(e.getMessage.toOption getOrElse "Invalid ZLIB data format") - } + def initial: State = StartInflate + def afterInflate: State = StartInflate - @tailrec protected final def drain(buffer: Array[Byte], result: ByteString = ByteString.empty): ByteString = { - val len = inflater.inflate(buffer) - if (len > 0) drain(buffer, result ++ ByteString.fromArray(buffer, 0, len)) - else if (inflater.needsDictionary) throw new ZipException("ZLIB dictionary missing") - else result - } - - def finish(): ByteString = { - inflater.end() - ByteString.empty - } + protected def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit = {} + protected def onTruncation(ctx: Context[ByteString]): Directive = ctx.finish() } + +abstract class DeflateDecompressorBase(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends ByteStringParserStage[ByteString] { + protected def createInflater(): Inflater + val inflater = createInflater() + + protected def afterInflate: State + protected def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit + + /** Start inflating */ + case object StartInflate extends State { + def onPush(data: ByteString, ctx: Context[ByteString]): Directive = { + require(inflater.needsInput()) + inflater.setInput(data.toArray) + + becomeWithRemaining(Inflate()(data), ByteString.empty, ctx) + } + } + + /** Inflate */ + case class Inflate()(data: ByteString) extends IntermediateState { + override def onPull(ctx: Context[ByteString]): Directive = { + val buffer = new Array[Byte](maxBytesPerChunk) + val read = inflater.inflate(buffer) + if (read > 0) { + afterBytesRead(buffer, 0, read) + ctx.push(ByteString.fromArray(buffer, 0, read)) + } else { + val remaining = data.takeRight(inflater.getRemaining) + val next = + if (inflater.finished()) afterInflate + else StartInflate + + becomeWithRemaining(next, remaining, ctx) + } + } + def onPush(elem: ByteString, ctx: Context[ByteString]): Directive = + throw new IllegalStateException("Don't expect a new Element") + } + + def becomeWithRemaining(next: State, remaining: ByteString, ctx: Context[ByteString]) = { + become(next) + if (remaining.isEmpty) current.onPull(ctx) + else current.onPush(remaining, ctx) + } +} \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/coding/Encoder.scala b/akka-http/src/main/scala/akka/http/coding/Encoder.scala index 9dce5febe4..cbe925f260 100644 --- a/akka-http/src/main/scala/akka/http/coding/Encoder.scala +++ b/akka-http/src/main/scala/akka/http/coding/Encoder.scala @@ -6,6 +6,7 @@ package akka.http.coding import akka.http.model._ import akka.http.util.StreamUtils +import akka.stream.stage.Stage import akka.util.ByteString import headers._ import akka.stream.scaladsl.Flow @@ -21,13 +22,13 @@ trait Encoder { else message.self def encodeData[T](t: T)(implicit mapper: DataMapper[T]): T = - mapper.transformDataBytes(t, newEncodeTransformer) + mapper.transformDataBytes(t, Flow[ByteString].transform(newEncodeTransformer)) def encode(input: ByteString): ByteString = newCompressor.compressAndFinish(input) def newCompressor: Compressor - def newEncodeTransformer(): Flow[ByteString, ByteString] = { + def newEncodeTransformer(): Stage[ByteString, ByteString] = { val compressor = newCompressor def encodeChunk(bytes: ByteString): ByteString = compressor.compressAndFlush(bytes) diff --git a/akka-http/src/main/scala/akka/http/coding/Gzip.scala b/akka-http/src/main/scala/akka/http/coding/Gzip.scala index 782c77793e..7b3dff2d8a 100644 --- a/akka-http/src/main/scala/akka/http/coding/Gzip.scala +++ b/akka-http/src/main/scala/akka/http/coding/Gzip.scala @@ -4,19 +4,19 @@ package akka.http.coding -import java.util.zip.{ Inflater, CRC32, ZipException, Deflater } import akka.util.ByteString +import akka.stream.stage._ + +import akka.http.util.ByteReader +import java.util.zip.{ Inflater, CRC32, ZipException, Deflater } -import scala.annotation.tailrec import akka.http.model._ import headers.HttpEncodings -import scala.util.control.{ NonFatal, NoStackTrace } - -class Gzip(val messageFilter: HttpMessage ⇒ Boolean) extends Coder { +class Gzip(val messageFilter: HttpMessage ⇒ Boolean) extends Coder with StreamDecoder { val encoding = HttpEncodings.gzip def newCompressor = new GzipCompressor - def newDecompressor = new GzipDecompressor + def newDecompressorStage(maxBytesPerChunk: Int) = () ⇒ new GzipDecompressor(maxBytesPerChunk) } /** @@ -59,84 +59,80 @@ class GzipCompressor extends DeflateCompressor { } } -/** A suspendable gzip decompressor */ -class GzipDecompressor extends DeflateDecompressor { - override protected lazy val inflater = new Inflater(true) // disable ZLIB headers - override def decompress(input: ByteString): ByteString = DecompressionStateMachine.run(input) - override def finish(): ByteString = - if (DecompressionStateMachine.isFinished) ByteString.empty - else fail("Truncated GZIP stream") +class GzipDecompressor(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends DeflateDecompressorBase(maxBytesPerChunk) { + protected def createInflater(): Inflater = new Inflater(true) - import GzipDecompressor._ + def initial: State = Initial - object DecompressionStateMachine extends StateMachine { - def isFinished: Boolean = currentState == finished + /** No bytes were received yet */ + case object Initial extends State { + def onPush(data: ByteString, ctx: Context[ByteString]): Directive = + if (data.isEmpty) ctx.pull() + else becomeWithRemaining(ReadHeaders, data, ctx) - def initialState = finished + override def onPull(ctx: Context[ByteString]): Directive = + if (ctx.isFinishing) { + ctx.finish() + } else super.onPull(ctx) + } - private def readHeaders(data: ByteString): Action = - try { - val reader = new ByteReader(data) - import reader._ + var crc32: CRC32 = new CRC32 + protected def afterInflate: State = ReadTrailer - if (readByte() != 0x1F || readByte() != 0x8B) fail("Not in GZIP format") // check magic header - if (readByte() != 8) fail("Unsupported GZIP compression method") // check compression method - val flags = readByte() - skip(6) // skip MTIME, XFL and OS fields - if ((flags & 4) > 0) skip(readShort()) // skip optional extra fields - if ((flags & 8) > 0) while (readByte() != 0) {} // skip optional file name - if ((flags & 16) > 0) while (readByte() != 0) {} // skip optional file comment - if ((flags & 2) > 0 && crc16(data.take(currentOffset)) != readShort()) fail("Corrupt GZIP header") + /** Reading the header bytes */ + case object ReadHeaders extends ByteReadingState { + def read(reader: ByteReader, ctx: Context[ByteString]): Directive = { + import reader._ - ContinueWith(deflate(new CRC32), remainingData) - } catch { - case ByteReader.NeedMoreData ⇒ SuspendAndRetryWithMoreData - } + if (readByte() != 0x1F || readByte() != 0x8B) fail("Not in GZIP format") // check magic header + if (readByte() != 8) fail("Unsupported GZIP compression method") // check compression method + val flags = readByte() + skip(6) // skip MTIME, XFL and OS fields + if ((flags & 4) > 0) skip(readShort()) // skip optional extra fields + if ((flags & 8) > 0) skipZeroTerminatedString() // skip optional file name + if ((flags & 16) > 0) skipZeroTerminatedString() // skip optional file comment + if ((flags & 2) > 0 && crc16(fromStartToHere) != readShort()) fail("Corrupt GZIP header") - private def deflate(checkSum: CRC32)(data: ByteString): Action = { - assert(inflater.needsInput()) - inflater.setInput(data.toArray) - val output = drain(new Array[Byte](data.length * 2)) - checkSum.update(output.toArray) - if (inflater.finished()) EmitAndContinueWith(output, readTrailer(checkSum), data.takeRight(inflater.getRemaining)) - else EmitAndSuspend(output) - } - - private def readTrailer(checkSum: CRC32)(data: ByteString): Action = - try { - val reader = new ByteReader(data) - import reader._ - - if (readInt() != checkSum.getValue.toInt) fail("Corrupt data (CRC32 checksum error)") - if (readInt() != inflater.getBytesWritten.toInt /* truncated to 32bit */ ) fail("Corrupt GZIP trailer ISIZE") - - inflater.reset() - checkSum.reset() - ContinueWith(finished, remainingData) // start over to support multiple concatenated gzip streams - } catch { - case ByteReader.NeedMoreData ⇒ SuspendAndRetryWithMoreData - } - - lazy val finished: ByteString ⇒ Action = - data ⇒ if (data.nonEmpty) ContinueWith(readHeaders, data) else SuspendAndRetryWithMoreData - - private def crc16(data: ByteString) = { - val crc = new CRC32 - crc.update(data.toArray) - crc.getValue.toInt & 0xFFFF + inflater.reset() + crc32.reset() + becomeWithRemaining(StartInflate, remainingData, ctx) } } - private def fail(msg: String) = throw new ZipException(msg) + protected def afterBytesRead(buffer: Array[Byte], offset: Int, length: Int): Unit = + crc32.update(buffer, offset, length) + /** Reading the trailer */ + case object ReadTrailer extends ByteReadingState { + def read(reader: ByteReader, ctx: Context[ByteString]): Directive = { + import reader._ + + if (readInt() != crc32.getValue.toInt) fail("Corrupt data (CRC32 checksum error)") + if (readInt() != inflater.getBytesWritten.toInt /* truncated to 32bit */ ) fail("Corrupt GZIP trailer ISIZE") + + becomeWithRemaining(Initial, remainingData, ctx) + } + } + + override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = ctx.absorbTermination() + + private def crc16(data: ByteString) = { + val crc = new CRC32 + crc.update(data.toArray) + crc.getValue.toInt & 0xFFFF + } + + override protected def onTruncation(ctx: Context[ByteString]): Directive = ctx.fail(new ZipException("Truncated GZIP stream")) + + private def fail(msg: String) = throw new ZipException(msg) } /** INTERNAL API */ private[http] object GzipDecompressor { // RFC 1952: http://tools.ietf.org/html/rfc1952 section 2.2 val Header = ByteString( - 31, // ID1 - -117, // ID2 + 0x1F, // ID1 + 0x8B, // ID2 8, // CM = Deflate 0, // FLG 0, // MTIME 1 @@ -146,76 +142,4 @@ private[http] object GzipDecompressor { 0, // XFL 0 // OS ) - - class ByteReader(input: ByteString) { - import ByteReader.NeedMoreData - - private[this] var off = 0 - - def readByte(): Int = - if (off < input.length) { - val x = input(off) - off += 1 - x.toInt & 0xFF - } else throw NeedMoreData - def readShort(): Int = readByte() | (readByte() << 8) - def readInt(): Int = readShort() | (readShort() << 16) - def skip(numBytes: Int): Unit = - if (off + numBytes <= input.length) off += numBytes - else throw NeedMoreData - def currentOffset: Int = off - def remainingData: ByteString = input.drop(off) - } - object ByteReader { - val NeedMoreData = new Exception with NoStackTrace - } - - /** A simple state machine implementation for suspendable parsing */ - trait StateMachine { - sealed trait Action - /** Cache the current input and suspend to wait for more data */ - case object SuspendAndRetryWithMoreData extends Action - /** Emit some output and suspend in the current state and wait for more data */ - case class EmitAndSuspend(output: ByteString) extends Action - /** Proceed to the nextState and immediately run it with the remainingInput */ - case class ContinueWith(nextState: State, remainingInput: ByteString) extends Action - /** Emit some output and then proceed to the nextState and immediately run it with the remainingInput */ - case class EmitAndContinueWith(output: ByteString, nextState: State, remainingInput: ByteString) extends Action - - type State = ByteString ⇒ Action - def initialState: State - - private[this] var state: State = initialState - def currentState: State = state - - /** Run the state machine with the current input */ - final def run(input: ByteString): ByteString = { - @tailrec def rec(input: ByteString, result: ByteString = ByteString.empty): ByteString = - state(input) match { - case SuspendAndRetryWithMoreData ⇒ - val oldState = state - state = { newData ⇒ - state = oldState - oldState(input ++ newData) - } - result - case EmitAndSuspend(output) ⇒ result ++ output - case ContinueWith(next, remainingInput) ⇒ - state = next - if (remainingInput.nonEmpty) rec(remainingInput, result) - else result - case EmitAndContinueWith(output, next, remainingInput) ⇒ - state = next - rec(remainingInput, result ++ output) - } - try rec(input) - catch { - case NonFatal(e) ⇒ - state = failState - throw e - } - } - - private def failState: State = _ ⇒ throw new IllegalStateException("Trying to reuse failed decompressor.") - } -} +} \ No newline at end of file diff --git a/akka-http/src/main/scala/akka/http/coding/NoCoding.scala b/akka-http/src/main/scala/akka/http/coding/NoCoding.scala index 9893ffcbae..8fd9e22765 100644 --- a/akka-http/src/main/scala/akka/http/coding/NoCoding.scala +++ b/akka-http/src/main/scala/akka/http/coding/NoCoding.scala @@ -5,13 +5,15 @@ package akka.http.coding import akka.http.model._ +import akka.http.util.StreamUtils +import akka.stream.stage.Stage import akka.util.ByteString import headers.HttpEncodings /** * An encoder and decoder for the HTTP 'identity' encoding. */ -object NoCoding extends Coder { +object NoCoding extends Coder with StreamDecoder { val encoding = HttpEncodings.identity override def encode[T <: HttpMessage](message: T)(implicit mapper: DataMapper[T]): T#Self = message.self @@ -22,7 +24,9 @@ object NoCoding extends Coder { val messageFilter: HttpMessage ⇒ Boolean = _ ⇒ false def newCompressor = NoCodingCompressor - def newDecompressor = NoCodingDecompressor + + def newDecompressorStage(maxBytesPerChunk: Int): () ⇒ Stage[ByteString, ByteString] = + () ⇒ StreamUtils.limitByteChunksStage(maxBytesPerChunk) } object NoCodingCompressor extends Compressor { @@ -33,7 +37,3 @@ object NoCodingCompressor extends Compressor { def compressAndFlush(input: ByteString): ByteString = input def compressAndFinish(input: ByteString): ByteString = input } -object NoCodingDecompressor extends Decompressor { - def decompress(input: ByteString): ByteString = input - def finish(): ByteString = ByteString.empty -} \ No newline at end of file