+htp #16810 use flushing deflate call on Java > 6

Johannes Rudolph 2015-06-07 11:51:21 +02:00
parent 8527e0347e
commit e540882049
4 changed files with 92 additions and 30 deletions

@@ -28,15 +28,22 @@ private[http] object StreamUtils {
/**
* Creates a transformer that will call `f` for each incoming ByteString and output its result. After the complete
* input has been read it will call `finish` once to determine the final ByteString to post to the output.
* Empty ByteStrings are discarded.
*/
def byteStringTransformer(f: ByteString ⇒ ByteString, finish: () ⇒ ByteString): Stage[ByteString, ByteString] = {
new PushPullStage[ByteString, ByteString] {
override def onPush(element: ByteString, ctx: Context[ByteString]): SyncDirective =
ctx.push(f(element))
override def onPush(element: ByteString, ctx: Context[ByteString]): SyncDirective = {
val data = f(element)
if (data.nonEmpty) ctx.push(data)
else ctx.pull()
}
override def onPull(ctx: Context[ByteString]): SyncDirective =
if (ctx.isFinishing) ctx.pushAndFinish(finish())
else ctx.pull()
if (ctx.isFinishing) {
val data = finish()
if (data.nonEmpty) ctx.pushAndFinish(data)
else ctx.finish()
} else ctx.pull()
override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = ctx.absorbTermination()
}
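
A hedged usage sketch (not part of this commit): the stage above could be wired into a Flow the same way newEncodeTransformer does in the Encoder hunk further down. The compressAndFinish(ByteString.empty) call for the finish step and running this from inside the akka.http package (StreamUtils is private[http]) are assumptions of the sketch, not confirmed by the diff.

import akka.stream.scaladsl.Flow
import akka.util.ByteString
import akka.http.impl.util.StreamUtils
import akka.http.scaladsl.coding.DeflateCompressor

// Sketch: compress and flush every incoming chunk; when upstream completes,
// emit the trailing deflate bytes. Empty intermediate results are discarded
// by the stage shown above instead of being pushed downstream.
def deflateFlowSketch(): Flow[ByteString, ByteString, Unit] =
  Flow[ByteString].transform { () ⇒
    val compressor = new DeflateCompressor
    StreamUtils.byteStringTransformer(
      compressor.compressAndFlush,                         // per-chunk compression with flush
      () ⇒ compressor.compressAndFinish(ByteString.empty)) // final bytes on completion
  }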

@@ -11,6 +11,8 @@ import scala.annotation.tailrec
import scala.concurrent.duration._
import scala.concurrent.Await
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.forkjoin.ThreadLocalRandom
import scala.util.Random
import scala.util.control.NoStackTrace
import org.scalatest.{ Inspectors, WordSpec }
import akka.util.ByteString
@@ -96,8 +98,6 @@ abstract class CoderSpec extends WordSpec with CodecSpecSupport with Inspectors
ourDecode(compressed) should equal(inputBytes)
}
extraTests()
"shouldn't produce huge ByteStrings for some input" in {
val array = new Array[Byte](10) // FIXME
util.Arrays.fill(array, 1.toByte)
@@ -114,6 +114,27 @@ abstract class CoderSpec extends WordSpec with CodecSpecSupport with Inspectors
bs.forall(_ == 1) should equal(true)
}
}
"be able to decode chunk-by-chunk (depending on input chunks)" in {
val minLength = 100
val maxLength = 1000
val numElements = 1000
val random = ThreadLocalRandom.current()
val sizes = Seq.fill(numElements)(random.nextInt(minLength, maxLength))
def createByteString(size: Int): ByteString =
ByteString(Array.fill(size)(1.toByte))
val sizesAfterRoundtrip =
Source(() ⇒ sizes.toIterator.map(createByteString))
.via(Coder.encoderFlow)
.via(Coder.decoderFlow)
.runFold(Seq.empty[Int])(_ :+ _.size)
sizes shouldEqual sizesAfterRoundtrip.awaitResult(1.second)
}
extraTests()
}
def encode(s: String) = ourEncode(ByteString(s, "UTF8"))

@@ -4,6 +4,7 @@
package akka.http.scaladsl.coding
import java.lang.reflect.{ Method, InvocationTargetException }
import java.util.zip.{ Inflater, Deflater }
import akka.stream.stage._
import akka.util.{ ByteStringBuilder, ByteString }
@@ -13,6 +14,8 @@ import akka.http.impl.util._
import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers.HttpEncodings
import scala.util.control.NonFatal
class Deflate(val messageFilter: HttpMessage ⇒ Boolean) extends Coder with StreamDecoder {
val encoding = HttpEncodings.deflate
def newCompressor = new DeflateCompressor
@@ -21,6 +24,8 @@ class Deflate(val messageFilter: HttpMessage ⇒ Boolean) extends Coder with Str
object Deflate extends Deflate(Encoder.DefaultFilter)
class DeflateCompressor extends Compressor {
import DeflateCompressor._
protected lazy val deflater = new Deflater(Deflater.BEST_COMPRESSION, false)
override final def compressAndFlush(input: ByteString): ByteString = {
@@ -40,37 +45,16 @@ class DeflateCompressor extends Compressor {
protected def compressWithBuffer(input: ByteString, buffer: Array[Byte]): ByteString = {
assert(deflater.needsInput())
deflater.setInput(input.toArray)
drain(buffer)
}
protected def flushWithBuffer(buffer: Array[Byte]): ByteString = {
// trick the deflater into flushing: switch compression level
// FIXME: use proper APIs and SYNC_FLUSH when Java 6 support is dropped
deflater.deflate(EmptyByteArray, 0, 0)
deflater.setLevel(Deflater.NO_COMPRESSION)
val res1 = drain(buffer)
deflater.setLevel(Deflater.BEST_COMPRESSION)
val res2 = drain(buffer)
res1 ++ res2
drainDeflater(deflater, buffer)
}
protected def flushWithBuffer(buffer: Array[Byte]): ByteString = DeflateCompressor.flush(deflater, buffer)
protected def finishWithBuffer(buffer: Array[Byte]): ByteString = {
deflater.finish()
val res = drain(buffer)
val res = drainDeflater(deflater, buffer)
deflater.end()
res
}
@tailrec
protected final def drain(buffer: Array[Byte], result: ByteStringBuilder = new ByteStringBuilder()): ByteString = {
val len = deflater.deflate(buffer)
if (len > 0) {
result ++= ByteString.fromArray(buffer, 0, len)
drain(buffer, result)
} else {
assert(deflater.needsInput())
result.result()
}
}
private def newTempBuffer(size: Int = 65536): Array[Byte] =
// The default size is somewhat arbitrary, we'd like to guess a better value but Deflater/zlib
// is buffering in an unpredictable manner.
@@ -82,6 +66,54 @@ class DeflateCompressor {
new Array[Byte](size)
}
private[http] object DeflateCompressor {
// TODO: remove reflective call once Java 6 support is dropped
/**
* Compatibility mode: reflectively call deflate(..., flushMode) if available or use a hack otherwise
*/
private[this] val flushImplementation: (Deflater, Array[Byte]) ⇒ ByteString = {
def flushHack(deflater: Deflater, buffer: Array[Byte]): ByteString = {
// hack: change compression mode to provoke flushing
deflater.deflate(EmptyByteArray, 0, 0)
deflater.setLevel(Deflater.NO_COMPRESSION)
val res1 = drainDeflater(deflater, buffer)
deflater.setLevel(Deflater.BEST_COMPRESSION)
val res2 = drainDeflater(deflater, buffer)
res1 ++ res2
}
def reflectiveDeflateWithSyncMode(method: Method, syncFlushConstant: Int)(deflater: Deflater, buffer: Array[Byte]): ByteString =
try {
val written = method.invoke(deflater, buffer, 0: java.lang.Integer, buffer.length: java.lang.Integer, syncFlushConstant: java.lang.Integer).asInstanceOf[Int]
ByteString.fromArray(buffer, 0, written)
} catch {
case t: InvocationTargetException ⇒ throw t.getTargetException
}
try {
val deflateWithFlush = classOf[Deflater].getMethod("deflate", classOf[Array[Byte]], classOf[Int], classOf[Int], classOf[Int])
require(deflateWithFlush.getReturnType == classOf[Int])
val flushModeSync = classOf[Deflater].getField("SYNC_FLUSH").get(null).asInstanceOf[Int]
reflectiveDeflateWithSyncMode(deflateWithFlush, flushModeSync)
} catch {
case NonFatal(e) ⇒ flushHack
}
}
def flush(deflater: Deflater, buffer: Array[Byte]): ByteString = flushImplementation(deflater, buffer)
@tailrec
def drainDeflater(deflater: Deflater, buffer: Array[Byte], result: ByteStringBuilder = new ByteStringBuilder()): ByteString = {
val len = deflater.deflate(buffer)
if (len > 0) {
result ++= ByteString.fromArray(buffer, 0, len)
drainDeflater(deflater, buffer, result)
} else {
assert(deflater.needsInput())
result.result()
}
}
}
class DeflateDecompressor(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends DeflateDecompressorBase(maxBytesPerChunk) {
protected def createInflater() = new Inflater()
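
On Java 7 and newer the reflective lookup above resolves to the real flushing overload, so the whole compatibility dance is conceptually equivalent to the direct call below. This is a sketch for illustration only, not part of the commit; it assumes the Java 7+ Deflater.SYNC_FLUSH API and a buffer large enough to take the flushed output in a single call.

import java.util.zip.Deflater
import akka.util.ByteString

// Java 7+: SYNC_FLUSH forces all pending input out of the deflater without
// terminating the stream, so a receiver can fully decode what was sent so far.
def flushDirect(deflater: Deflater, buffer: Array[Byte]): ByteString = {
  val written = deflater.deflate(buffer, 0, buffer.length, Deflater.SYNC_FLUSH)
  ByteString.fromArray(buffer, 0, written)
}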

@@ -26,6 +26,8 @@ trait Encoder {
def encode(input: ByteString): ByteString = newCompressor.compressAndFinish(input)
def encoderFlow: Flow[ByteString, ByteString, Unit] = Flow[ByteString].transform(newEncodeTransformer)
def newCompressor: Compressor
def newEncodeTransformer(): Stage[ByteString, ByteString] = {