Merge pull request #19913 from vans239/19910-checkContentLength

Replaced PushStage based CheckContentLength with GraphStage #19834
This commit is contained in:
drewhk 2016-03-31 11:55:08 +02:00
commit 982243b49a
2 changed files with 36 additions and 23 deletions

View file

@ -5,9 +5,10 @@
package akka.http.impl.engine.rendering
import akka.parboiled2.CharUtils
import akka.stream.SourceShape
import akka.stream.{Attributes, SourceShape}
import akka.util.ByteString
import akka.event.LoggingAdapter
import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage
import akka.stream.scaladsl._
import akka.stream.stage._
import akka.http.scaladsl.model._
@ -69,28 +70,45 @@ private object RenderSupport {
}
object CheckContentLengthTransformer {
def flow(contentLength: Long) = Flow[ByteString].transform(()
new CheckContentLengthTransformer(contentLength)).named("checkContentLength")
def flow(contentLength: Long) = Flow[ByteString].via(new CheckContentLengthTransformer(contentLength))
}
class CheckContentLengthTransformer(length: Long) extends PushStage[ByteString, ByteString] {
var sent = 0L
final class CheckContentLengthTransformer(length: Long) extends SimpleLinearGraphStage[ByteString] {
override def initialAttributes: Attributes = Attributes.name("CheckContentLength")
override def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective = {
sent += elem.length
if (sent > length)
ctx fail InvalidContentLengthException(s"HTTP message had declared Content-Length $length but entity data stream amounts to more bytes")
ctx.push(elem)
}
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
new GraphStageLogic(shape) with InHandler with OutHandler {
override def toString = s"CheckContentLength(sent=$sent)"
override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = {
if (sent < length)
ctx fail InvalidContentLengthException(s"HTTP message had declared Content-Length $length but entity data stream amounts to ${length - sent} bytes less")
ctx.finish()
}
private var sent = 0L
override def onPush(): Unit = {
val elem = grab(in)
sent += elem.length
if (sent <= length) {
push(out, elem)
} else {
failStage(InvalidContentLengthException(s"HTTP message had declared Content-Length $length but entity data stream amounts to more bytes"))
}
}
override def onUpstreamFinish(): Unit = {
if (sent < length) {
failStage(InvalidContentLengthException(s"HTTP message had declared Content-Length $length but entity data stream amounts to ${length - sent} bytes less"))
} else {
completeStage()
}
}
override def onPull(): Unit = pull(in)
setHandlers(in, out, this)
}
override def toString = "CheckContentLength"
}
private def renderChunk(chunk: HttpEntity.ChunkStreamPart): ByteString = {
import chunk._
val renderedSize = // buffer space required for rendering (without trailer)