+htp #16810 use flushing deflate call on Java > 6
This commit is contained in:
parent
8527e0347e
commit
e540882049
4 changed files with 92 additions and 30 deletions
|
|
@ -28,15 +28,22 @@ private[http] object StreamUtils {
|
|||
/**
|
||||
* Creates a transformer that will call `f` for each incoming ByteString and output its result. After the complete
|
||||
* input has been read it will call `finish` once to determine the final ByteString to post to the output.
|
||||
* Empty ByteStrings are discarded.
|
||||
*/
|
||||
def byteStringTransformer(f: ByteString ⇒ ByteString, finish: () ⇒ ByteString): Stage[ByteString, ByteString] = {
|
||||
new PushPullStage[ByteString, ByteString] {
|
||||
override def onPush(element: ByteString, ctx: Context[ByteString]): SyncDirective =
|
||||
ctx.push(f(element))
|
||||
override def onPush(element: ByteString, ctx: Context[ByteString]): SyncDirective = {
|
||||
val data = f(element)
|
||||
if (data.nonEmpty) ctx.push(data)
|
||||
else ctx.pull()
|
||||
}
|
||||
|
||||
override def onPull(ctx: Context[ByteString]): SyncDirective =
|
||||
if (ctx.isFinishing) ctx.pushAndFinish(finish())
|
||||
else ctx.pull()
|
||||
if (ctx.isFinishing) {
|
||||
val data = finish()
|
||||
if (data.nonEmpty) ctx.pushAndFinish(data)
|
||||
else ctx.finish()
|
||||
} else ctx.pull()
|
||||
|
||||
override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = ctx.absorbTermination()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ import scala.annotation.tailrec
|
|||
import scala.concurrent.duration._
|
||||
import scala.concurrent.Await
|
||||
import scala.concurrent.ExecutionContext.Implicits.global
|
||||
import scala.concurrent.forkjoin.ThreadLocalRandom
|
||||
import scala.util.Random
|
||||
import scala.util.control.NoStackTrace
|
||||
import org.scalatest.{ Inspectors, WordSpec }
|
||||
import akka.util.ByteString
|
||||
|
|
@ -96,8 +98,6 @@ abstract class CoderSpec extends WordSpec with CodecSpecSupport with Inspectors
|
|||
ourDecode(compressed) should equal(inputBytes)
|
||||
}
|
||||
|
||||
extraTests()
|
||||
|
||||
"shouldn't produce huge ByteStrings for some input" in {
|
||||
val array = new Array[Byte](10) // FIXME
|
||||
util.Arrays.fill(array, 1.toByte)
|
||||
|
|
@ -114,6 +114,27 @@ abstract class CoderSpec extends WordSpec with CodecSpecSupport with Inspectors
|
|||
bs.forall(_ == 1) should equal(true)
|
||||
}
|
||||
}
|
||||
|
||||
"be able to decode chunk-by-chunk (depending on input chunks)" in {
|
||||
val minLength = 100
|
||||
val maxLength = 1000
|
||||
val numElements = 1000
|
||||
|
||||
val random = ThreadLocalRandom.current()
|
||||
val sizes = Seq.fill(numElements)(random.nextInt(minLength, maxLength))
|
||||
def createByteString(size: Int): ByteString =
|
||||
ByteString(Array.fill(size)(1.toByte))
|
||||
|
||||
val sizesAfterRoundtrip =
|
||||
Source(() ⇒ sizes.toIterator.map(createByteString))
|
||||
.via(Coder.encoderFlow)
|
||||
.via(Coder.decoderFlow)
|
||||
.runFold(Seq.empty[Int])(_ :+ _.size)
|
||||
|
||||
sizes shouldEqual sizesAfterRoundtrip.awaitResult(1.second)
|
||||
}
|
||||
|
||||
extraTests()
|
||||
}
|
||||
|
||||
def encode(s: String) = ourEncode(ByteString(s, "UTF8"))
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
package akka.http.scaladsl.coding
|
||||
|
||||
import java.lang.reflect.{ Method, InvocationTargetException }
|
||||
import java.util.zip.{ Inflater, Deflater }
|
||||
import akka.stream.stage._
|
||||
import akka.util.{ ByteStringBuilder, ByteString }
|
||||
|
|
@ -13,6 +14,8 @@ import akka.http.impl.util._
|
|||
import akka.http.scaladsl.model._
|
||||
import akka.http.scaladsl.model.headers.HttpEncodings
|
||||
|
||||
import scala.util.control.NonFatal
|
||||
|
||||
class Deflate(val messageFilter: HttpMessage ⇒ Boolean) extends Coder with StreamDecoder {
|
||||
val encoding = HttpEncodings.deflate
|
||||
def newCompressor = new DeflateCompressor
|
||||
|
|
@ -21,6 +24,8 @@ class Deflate(val messageFilter: HttpMessage ⇒ Boolean) extends Coder with Str
|
|||
object Deflate extends Deflate(Encoder.DefaultFilter)
|
||||
|
||||
class DeflateCompressor extends Compressor {
|
||||
import DeflateCompressor._
|
||||
|
||||
protected lazy val deflater = new Deflater(Deflater.BEST_COMPRESSION, false)
|
||||
|
||||
override final def compressAndFlush(input: ByteString): ByteString = {
|
||||
|
|
@ -40,37 +45,16 @@ class DeflateCompressor extends Compressor {
|
|||
protected def compressWithBuffer(input: ByteString, buffer: Array[Byte]): ByteString = {
|
||||
assert(deflater.needsInput())
|
||||
deflater.setInput(input.toArray)
|
||||
drain(buffer)
|
||||
}
|
||||
protected def flushWithBuffer(buffer: Array[Byte]): ByteString = {
|
||||
// trick the deflater into flushing: switch compression level
|
||||
// FIXME: use proper APIs and SYNC_FLUSH when Java 6 support is dropped
|
||||
deflater.deflate(EmptyByteArray, 0, 0)
|
||||
deflater.setLevel(Deflater.NO_COMPRESSION)
|
||||
val res1 = drain(buffer)
|
||||
deflater.setLevel(Deflater.BEST_COMPRESSION)
|
||||
val res2 = drain(buffer)
|
||||
res1 ++ res2
|
||||
drainDeflater(deflater, buffer)
|
||||
}
|
||||
protected def flushWithBuffer(buffer: Array[Byte]): ByteString = DeflateCompressor.flush(deflater, buffer)
|
||||
protected def finishWithBuffer(buffer: Array[Byte]): ByteString = {
|
||||
deflater.finish()
|
||||
val res = drain(buffer)
|
||||
val res = drainDeflater(deflater, buffer)
|
||||
deflater.end()
|
||||
res
|
||||
}
|
||||
|
||||
@tailrec
|
||||
protected final def drain(buffer: Array[Byte], result: ByteStringBuilder = new ByteStringBuilder()): ByteString = {
|
||||
val len = deflater.deflate(buffer)
|
||||
if (len > 0) {
|
||||
result ++= ByteString.fromArray(buffer, 0, len)
|
||||
drain(buffer, result)
|
||||
} else {
|
||||
assert(deflater.needsInput())
|
||||
result.result()
|
||||
}
|
||||
}
|
||||
|
||||
private def newTempBuffer(size: Int = 65536): Array[Byte] =
|
||||
// The default size is somewhat arbitrary, we'd like to guess a better value but Deflater/zlib
|
||||
// is buffering in an unpredictable manner.
|
||||
|
|
@ -82,6 +66,54 @@ class DeflateCompressor extends Compressor {
|
|||
new Array[Byte](size)
|
||||
}
|
||||
|
||||
private[http] object DeflateCompressor {
|
||||
// TODO: remove reflective call once Java 6 support is dropped
|
||||
/**
|
||||
* Compatibility mode: reflectively call deflate(..., flushMode) if available or use a hack otherwise
|
||||
*/
|
||||
private[this] val flushImplementation: (Deflater, Array[Byte]) ⇒ ByteString = {
|
||||
def flushHack(deflater: Deflater, buffer: Array[Byte]): ByteString = {
|
||||
// hack: change compression mode to provoke flushing
|
||||
deflater.deflate(EmptyByteArray, 0, 0)
|
||||
deflater.setLevel(Deflater.NO_COMPRESSION)
|
||||
val res1 = drainDeflater(deflater, buffer)
|
||||
deflater.setLevel(Deflater.BEST_COMPRESSION)
|
||||
val res2 = drainDeflater(deflater, buffer)
|
||||
res1 ++ res2
|
||||
}
|
||||
def reflectiveDeflateWithSyncMode(method: Method, syncFlushConstant: Int)(deflater: Deflater, buffer: Array[Byte]): ByteString =
|
||||
try {
|
||||
val written = method.invoke(deflater, buffer, 0: java.lang.Integer, buffer.length: java.lang.Integer, syncFlushConstant: java.lang.Integer).asInstanceOf[Int]
|
||||
ByteString.fromArray(buffer, 0, written)
|
||||
} catch {
|
||||
case t: InvocationTargetException ⇒ throw t.getTargetException
|
||||
}
|
||||
|
||||
try {
|
||||
val deflateWithFlush = classOf[Deflater].getMethod("deflate", classOf[Array[Byte]], classOf[Int], classOf[Int], classOf[Int])
|
||||
require(deflateWithFlush.getReturnType == classOf[Int])
|
||||
val flushModeSync = classOf[Deflater].getField("SYNC_FLUSH").get(null).asInstanceOf[Int]
|
||||
reflectiveDeflateWithSyncMode(deflateWithFlush, flushModeSync)
|
||||
} catch {
|
||||
case NonFatal(e) ⇒ flushHack
|
||||
}
|
||||
}
|
||||
|
||||
def flush(deflater: Deflater, buffer: Array[Byte]): ByteString = flushImplementation(deflater, buffer)
|
||||
|
||||
@tailrec
|
||||
def drainDeflater(deflater: Deflater, buffer: Array[Byte], result: ByteStringBuilder = new ByteStringBuilder()): ByteString = {
|
||||
val len = deflater.deflate(buffer)
|
||||
if (len > 0) {
|
||||
result ++= ByteString.fromArray(buffer, 0, len)
|
||||
drainDeflater(deflater, buffer, result)
|
||||
} else {
|
||||
assert(deflater.needsInput())
|
||||
result.result()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class DeflateDecompressor(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault) extends DeflateDecompressorBase(maxBytesPerChunk) {
|
||||
protected def createInflater() = new Inflater()
|
||||
|
||||
|
|
|
|||
|
|
@ -26,6 +26,8 @@ trait Encoder {
|
|||
|
||||
def encode(input: ByteString): ByteString = newCompressor.compressAndFinish(input)
|
||||
|
||||
def encoderFlow: Flow[ByteString, ByteString, Unit] = Flow[ByteString].transform(newEncodeTransformer)
|
||||
|
||||
def newCompressor: Compressor
|
||||
|
||||
def newEncodeTransformer(): Stage[ByteString, ByteString] = {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue