HTTP: Optimise response rendering and response parsing (#21046)

* +htp,ben HttpBlueprint benchmark

* +act bring OptionVal from artery-dev
+htc raw benchmarks

* =htc request parsing benchmark

* +htc,ben add benchmark for longer (raw) json response

* =htc optimise renderHeaders, less Option allocs

* -htc remove FastClock, not quite worth it
This commit is contained in:
Konrad Malawski 2016-07-28 22:32:40 +02:00 committed by GitHub
parent 9372087464
commit bb701d1725
15 changed files with 526 additions and 135 deletions

View file

@ -382,7 +382,7 @@ object ByteString {
} else throw new IndexOutOfBoundsException(idx.toString) } else throw new IndexOutOfBoundsException(idx.toString)
} }
// Avoid `iterator` in performance sensitive code, call ops directly on ByteString instead /** Avoid `iterator` in performance sensitive code, call ops directly on ByteString instead */
override def iterator: ByteIterator.MultiByteArrayIterator = override def iterator: ByteIterator.MultiByteArrayIterator =
ByteIterator.MultiByteArrayIterator(bytestrings.toStream map { _.iterator }) ByteIterator.MultiByteArrayIterator(bytestrings.toStream map { _.iterator })

View file

@ -0,0 +1,69 @@
/**
* Copyright (C) 2016 Lightbend Inc. <http://www.lightbend.com>
*/
package akka.util
/**
* INTERNAL API
*/
private[akka] object OptionVal {
def apply[A >: Null](x: A): OptionVal[A] = new OptionVal(x)
object Some {
def apply[A >: Null](x: A): OptionVal[A] = new OptionVal(x)
def unapply[A >: Null](x: OptionVal[A]): OptionVal[A] = x
}
/**
* Represents non-existent values, `null` values.
*/
val None = new OptionVal[Null](null)
}
/**
* INTERNAL API
* Represents optional values similar to `scala.Option`, but
* as a value class to avoid allocations.
*
* Note that it can be used in pattern matching without allocations
* because it has name based extractor using methods `isEmpty` and `get`.
* See http://hseeberger.github.io/blog/2013/10/04/name-based-extractors-in-scala-2-dot-11/
*/
private[akka] final class OptionVal[+A >: Null](val x: A) extends AnyVal {
/**
* Returns true if the option is `OptionVal.None`, false otherwise.
*/
def isEmpty: Boolean =
x == null
/**
* Returns true if the option is `OptionVal.None`, false otherwise.
*/
def isDefined: Boolean = !isEmpty
/**
* Returns the option's value if the option is nonempty, otherwise
* return `default`.
*/
def getOrElse[B >: A](default: B): B =
if (x == null) default else x
/**
* Returns the option's value if it is nonempty, or `null` if it is empty.
*/
def orNull[A1 >: A](implicit ev: Null <:< A1): A1 = this getOrElse ev(null)
/**
* Returns the option's value.
* @note The option must be nonEmpty.
* @throws java.util.NoSuchElementException if the option is empty.
*/
def get: A =
if (x == null) throw new NoSuchElementException("OptionVal.None.get")
else x
override def toString: String =
if (x == null) "None" else s"Some($x)"
}

View file

@ -110,30 +110,13 @@ class HttpBlueprintBenchmark {
Flow.fromSinkAndSource(Sink.cancelled, Source.empty) Flow.fromSinkAndSource(Sink.cancelled, Source.empty)
@Benchmark @Benchmark
@OperationsPerInvocation(100 * 1000) @OperationsPerInvocation(100000)
def run_10000_reqs(blackhole: Blackhole) = { def run_10000_reqs() = {
val n = 100 * 1000 val n = 100000
val latch = new CountDownLatch(n) val latch = new CountDownLatch(n)
val replyCountdown = reply map { x => val replyCountdown = reply map { x =>
latch.countDown() latch.countDown()
blackhole.consume(x)
x
}
server(n).joinMat(replyCountdown)(Keep.right).run()(materializer)
latch.await()
}
@Benchmark
@OperationsPerInvocation(10 * 1000)
def run_1000_reqs(blackhole: Blackhole) = {
val n = 10 * 1000
val latch = new CountDownLatch(n)
val replyCountdown = reply map { x =>
latch.countDown()
blackhole.consume(x)
x x
} }
server(n).joinMat(replyCountdown)(Keep.right).run()(materializer) server(n).joinMat(replyCountdown)(Keep.right).run()(materializer)
@ -142,3 +125,4 @@ class HttpBlueprintBenchmark {
} }
} }

View file

@ -3,25 +3,24 @@
*/ */
package akka.http package akka.http
import java.util.concurrent.{ CountDownLatch, TimeUnit } import java.util.concurrent.TimeUnit
import javax.net.ssl.SSLContext import javax.net.ssl.SSLContext
import akka.Done import akka.Done
import akka.actor.ActorSystem import akka.actor.ActorSystem
import akka.http.impl.engine.parsing.ParserOutput.RequestOutput import akka.event.NoLogging
import akka.http.impl.engine.parsing.{ HttpHeaderParser, HttpMessageParser, HttpRequestParser } import akka.http.impl.engine.parsing.{ HttpHeaderParser, HttpRequestParser }
import akka.http.scaladsl.settings.ParserSettings import akka.http.scaladsl.settings.ParserSettings
import akka.event.NoLogging import akka.event.NoLogging
import akka.stream.ActorMaterializer
import akka.stream.TLSProtocol.SessionBytes import akka.stream.TLSProtocol.SessionBytes
import akka.stream.scaladsl.RunnableGraph import akka.stream.scaladsl._
import akka.stream.{ ActorMaterializer, Attributes }
import akka.stream.scaladsl.{ Flow, Keep, Sink, Source }
import akka.util.ByteString import akka.util.ByteString
import org.openjdk.jmh.annotations.{ OperationsPerInvocation, _ } import org.openjdk.jmh.annotations.{ OperationsPerInvocation, _ }
import org.openjdk.jmh.infra.Blackhole import org.openjdk.jmh.infra.Blackhole
import scala.concurrent.{ Await, Future }
import scala.concurrent.duration._ import scala.concurrent.duration._
import scala.concurrent.{ Await, Future }
@State(Scope.Benchmark) @State(Scope.Benchmark)
@OutputTimeUnit(TimeUnit.SECONDS) @OutputTimeUnit(TimeUnit.SECONDS)
@ -29,48 +28,70 @@ import scala.concurrent.duration._
class HttpRequestParsingBenchmark { class HttpRequestParsingBenchmark {
implicit val system: ActorSystem = ActorSystem("HttpRequestParsingBenchmark") implicit val system: ActorSystem = ActorSystem("HttpRequestParsingBenchmark")
implicit val materializer = ActorMaterializer() implicit val materializer = ActorMaterializer()(system)
val parserSettings = ParserSettings(system) val parserSettings = ParserSettings(system)
val parser = new HttpRequestParser(parserSettings, false, HttpHeaderParser(parserSettings, NoLogging)()) val parser = new HttpRequestParser(parserSettings, false, HttpHeaderParser(parserSettings, NoLogging)())
val dummySession = SSLContext.getDefault.createSSLEngine.getSession val dummySession = SSLContext.getDefault.createSSLEngine.getSession
val requestBytes = SessionBytes(
@Param(Array("small", "large"))
var req: String = ""
def request = req match {
case "small" => requestBytesSmall
case "large" => requestBytesLarge
}
val requestBytesSmall: SessionBytes = SessionBytes(
dummySession, dummySession,
ByteString( ByteString(
"GET / HTTP/1.1\r\n" + """|GET / HTTP/1.1
"Accept: */*\r\n" + |Accept: */*
"Accept-Encoding: gzip, deflate\r\n" + |Accept-Encoding: gzip, deflate
"Connection: keep-alive\r\n" + |Connection: keep-alive
"Host: example.com\r\n" + |Host: example.com
"User-Agent: HTTPie/0.9.3\r\n" + |User-Agent: HTTPie/0.9.3
"\r\n" |
|""".stripMargin.replaceAll("\n", "\r\n")
) )
) )
val requestBytesLarge: SessionBytes = SessionBytes(
dummySession,
ByteString(
"""|GET /json HTTP/1.1
|Host: server
|User-Agent: Mozilla/5.0 (X11; Linux x86_64) Gecko/20130501 Firefox/30.0 AppleWebKit/600.00 Chrome/30.0.0000.0 Trident/10.0 Safari/600.00
|Cookie: uid=12345678901234567890; __utma=1.1234567890.1234567890.1234567890.1234567890.12; wd=2560x1600
|Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
|Accept-Language: en-US,en;q=0.5
|Connection: keep-alive
|
|""".stripMargin.replaceAll("\n", "\r\n")
)
)
/*
// before:
[info] Benchmark (req) Mode Cnt Score Error Units
[info] HttpRequestParsingBenchmark.parse_10000_requests small thrpt 20 358 982.157 ± 93745.863 ops/s
[info] HttpRequestParsingBenchmark.parse_10000_requests large thrpt 20 388 335.666 ± 16990.715 ops/s
// after:
[info] HttpRequestParsingBenchmark.parse_10000_requests_val small thrpt 20 623 975.879 ± 6191.897 ops/s
[info] HttpRequestParsingBenchmark.parse_10000_requests_val large thrpt 20 507 460.283 ± 4735.843 ops/s
*/
val httpMessageParser = Flow.fromGraph(parser) val httpMessageParser = Flow.fromGraph(parser)
def flow(n: Int): RunnableGraph[Future[Done]] = def flow(bytes: SessionBytes, n: Int): RunnableGraph[Future[Done]] =
Source.repeat(requestBytes).take(n) Source.repeat(request).take(n)
.via(httpMessageParser) .via(httpMessageParser)
.toMat(Sink.ignore)(Keep.right) .toMat(Sink.ignore)(Keep.right)
@Benchmark @Benchmark
@OperationsPerInvocation(10000) @OperationsPerInvocation(10000)
def parse_10000_single_requests(blackhole: Blackhole): Unit = { def parse_10000_requests_val(blackhole: Blackhole): Unit = {
val done = flow(10000).run() val done = flow(requestBytesSmall, 10000).run()
Await.ready(done, 32.days)
}
@Benchmark
@OperationsPerInvocation(1000)
def parse_1000_single_requests(blackhole: Blackhole): Unit = {
val done = flow(1000).run()
Await.ready(done, 32.days)
}
@Benchmark
@OperationsPerInvocation(100)
def parse_100_single_requests(blackhole: Blackhole): Unit = {
val done = flow(100).run()
Await.ready(done, 32.days) Await.ready(done, 32.days)
} }

View file

@ -0,0 +1,250 @@
/**
* Copyright (C) 2015-2016 Lightbend Inc. <http://www.lightbend.com>
*/
package akka.http
import java.util.concurrent.{ CountDownLatch, TimeUnit }
import akka.NotUsed
import akka.actor.ActorSystem
import akka.event.NoLogging
import akka.http.impl.engine.rendering.ResponseRenderingOutput.HttpData
import akka.http.impl.engine.rendering.{ HttpResponseRendererFactory, ResponseRenderingContext, ResponseRenderingOutput }
import akka.http.scaladsl.Http
import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers.Server
import akka.http.scaladsl.unmarshalling.Unmarshal
import akka.stream._
import akka.stream.scaladsl._
import akka.stream.stage.{ GraphStageLogic, GraphStageWithMaterializedValue, InHandler }
import akka.util.ByteString
import com.typesafe.config.ConfigFactory
import org.openjdk.jmh.annotations._
import org.openjdk.jmh.infra.Blackhole
import scala.concurrent.duration._
import scala.concurrent.{ Await, Future }
import scala.util.Try
@State(Scope.Benchmark)
@OutputTimeUnit(TimeUnit.SECONDS)
@BenchmarkMode(Array(Mode.Throughput))
class HttpResponseRenderingBenchmark extends HttpResponseRendererFactory(
serverHeader = Some(Server("Akka HTTP 2.4.x")),
responseHeaderSizeHint = 64,
log = NoLogging
) {
val config = ConfigFactory.parseString(
"""
akka {
loglevel = "ERROR"
}""".stripMargin
).withFallback(ConfigFactory.load())
implicit val system = ActorSystem("HttpResponseRenderingBenchmark", config)
implicit val materializer = ActorMaterializer()
import system.dispatcher
val requestRendered = ByteString(
"GET / HTTP/1.1\r\n" +
"Accept: */*\r\n" +
"Accept-Encoding: gzip, deflate\r\n" +
"Connection: keep-alive\r\n" +
"Host: example.com\r\n" +
"User-Agent: HTTPie/0.9.3\r\n" +
"\r\n"
)
def TCPPlacebo(requests: Int): Flow[ByteString, ByteString, NotUsed] =
Flow.fromSinkAndSource(
Flow[ByteString].takeWhile(it => !(it.utf8String contains "Connection: close")) to Sink.ignore,
Source.repeat(requestRendered).take(requests)
)
def TlsPlacebo = TLSPlacebo()
val requestRendering: Flow[HttpRequest, String, NotUsed] =
Http()
.clientLayer(headers.Host("blah.com"))
.atop(TlsPlacebo)
.join {
Flow[ByteString].map { x
val response = s"HTTP/1.1 200 OK\r\nContent-Length: ${x.size}\r\n\r\n"
ByteString(response) ++ x
}
}
.mapAsync(1)(response => Unmarshal(response).to[String])
def renderResponse: Future[String] = Source.single(HttpRequest(uri = "/foo"))
.via(requestRendering)
.runWith(Sink.head)
var request: HttpRequest = _
var pool: Flow[(HttpRequest, Int), (Try[HttpResponse], Int), _] = _
@TearDown
def shutdown(): Unit = {
Await.ready(Http().shutdownAllConnectionPools(), 1.second)
Await.result(system.terminate(), 5.seconds)
}
/*
[info] Benchmark Mode Cnt Score Error Units
[info] HttpResponseRenderingBenchmark.header_date_val thrpt 20 2 704 169 260 029.906 ± 234456086114.237 ops/s
// def, normal time
[info] HttpResponseRenderingBenchmark.header_date_def thrpt 20 178 297 625 609.638 ± 7429280865.659 ops/s
[info] HttpResponseRenderingBenchmark.response_ok_simple_val thrpt 20 1 258 119.673 ± 58399.454 ops/s
[info] HttpResponseRenderingBenchmark.response_ok_simple_def thrpt 20 687 576.928 ± 94813.618 ops/s
// clock nanos
[info] HttpResponseRenderingBenchmark.response_ok_simple_clock thrpt 20 1 676 438.649 ± 33976.590 ops/s
[info] HttpResponseRenderingBenchmark.response_ok_simple_clock thrpt 40 1 199 462.263 ± 222226.304 ops/s
// ------
// before optimisig collectFirst
[info] HttpResponseRenderingBenchmark.json_response thrpt 20 1 782 572.845 ± 16572.625 ops/s
[info] HttpResponseRenderingBenchmark.simple_response thrpt 20 1 611 802.216 ± 19557.151 ops/s
// after removing collectFirst and Option from renderHeaders
// not much of a difference, but hey, less Option allocs
[info] HttpResponseRenderingBenchmark.json_response thrpt 20 1 785 152.896 ± 15210.299 ops/s
[info] HttpResponseRenderingBenchmark.simple_response thrpt 20 1 783 800.184 ± 14938.415 ops/s
// -----
// baseline for this optimisation is the above results (after collectFirst).
// after introducing pre-rendered ContentType headers:
normal clock
[info] HttpResponseRenderingBenchmark.json_long_raw_response thrpt 20 1738558.895 ± 159612.661 ops/s
[info] HttpResponseRenderingBenchmark.json_response thrpt 20 1714176.824 ± 100011.642 ops/s
"fast clock"
[info] HttpResponseRenderingBenchmark.json_long_raw_response thrpt 20 1 528 632.480 ± 44934.827 ops/s
[info] HttpResponseRenderingBenchmark.json_response thrpt 20 1 517 383.792 ± 28256.716 ops/s
*/
/**
* HTTP/1.1 200 OK
* Server: Akka HTTP 2.4.x
* Date: Tue, 26 Jul 2016 15:26:53 GMT
* Content-Type: text/plain; charset=UTF-8
* Content-Length: 6
*
* ENTITY
*/
val simpleResponse =
ResponseRenderingContext(
response = HttpResponse(
200,
headers = Nil,
entity = HttpEntity("ENTITY")
),
requestMethod = HttpMethods.GET
)
/**
* HTTP/1.1 200 OK
* Server: Akka HTTP 2.4.x
* Date: Tue, 26 Jul 2016 15:26:53 GMT
* Content-Type: application/json
* Content-Length: 27
*
* {"message":"Hello, World!"}
*/
val jsonResponse =
ResponseRenderingContext(
response = HttpResponse(
200,
headers = Nil,
entity = HttpEntity(ContentTypes.`application/json`, """{"message":"Hello, World!"}""")
),
requestMethod = HttpMethods.GET
)
/**
* HTTP/1.1 200 OK
* Server: Akka HTTP 2.4.x
* Date: Tue, 26 Jul 2016 15:26:53 GMT
* Content-Type: application/json
* Content-Length: 315
*
* [{"id":4174,"randomNumber":331},{"id":51,"randomNumber":6544},{"id":4462,"randomNumber":952},{"id":2221,"randomNumber":532},{"id":9276,"randomNumber":3097},{"id":3056,"randomNumber":7293},{"id":6964,"randomNumber":620},{"id":675,"randomNumber":6601},{"id":8414,"randomNumber":6569},{"id":2753,"randomNumber":4065}]
*/
val jsonLongRawResponse =
ResponseRenderingContext(
response = HttpResponse(
200,
headers = Nil,
entity = HttpEntity(ContentTypes.`application/json`, """[{"id":4174,"randomNumber":331},{"id":51,"randomNumber":6544},{"id":4462,"randomNumber":952},{"id":2221,"randomNumber":532},{"id":9276,"randomNumber":3097},{"id":3056,"randomNumber":7293},{"id":6964,"randomNumber":620},{"id":675,"randomNumber":6601},{"id":8414,"randomNumber":6569},{"id":2753,"randomNumber":4065}]""")
),
requestMethod = HttpMethods.GET
)
@Benchmark
@Threads(8)
@OperationsPerInvocation(100 * 1000)
def simple_response(blackhole: Blackhole): Unit =
renderToImpl(simpleResponse, blackhole, n = 100 * 1000).await()
@Benchmark
@OperationsPerInvocation(100 * 1000)
def json_response(blackhole: Blackhole): Unit =
renderToImpl(jsonResponse, blackhole, n = 100 * 1000).await()
/*
Difference between 27 and 315 bytes long JSON is:
[info] Benchmark Mode Cnt Score Error Units
[info] HttpResponseRenderingBenchmark.json_long_raw_response thrpt 20 1 932 331.049 ± 64125.621 ops/s
[info] HttpResponseRenderingBenchmark.json_response thrpt 20 1 973 232.941 ± 18568.314 ops/s
*/
@Benchmark
@OperationsPerInvocation(100 * 1000)
def json_long_raw_response(blackhole: Blackhole): Unit =
renderToImpl(jsonLongRawResponse, blackhole, n = 100 * 1000).await()
class JitSafeLatch[A](blackhole: Blackhole, n: Int) extends GraphStageWithMaterializedValue[SinkShape[A], CountDownLatch] {
val in = Inlet[A]("JitSafeLatch.in")
override val shape = SinkShape(in)
override def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, CountDownLatch) = {
val latch = new CountDownLatch(n)
val logic = new GraphStageLogic(shape) with InHandler {
override def preStart(): Unit = pull(in)
override def onPush(): Unit = {
if (blackhole ne null) blackhole.consume(grab(in))
latch.countDown()
pull(in)
}
setHandler(in, this)
}
(logic, latch)
}
}
def renderToImpl(ctx: ResponseRenderingContext, blackhole: Blackhole, n: Int)(implicit mat: Materializer): CountDownLatch = {
val latch =
(Source.repeat(ctx).take(n) ++ Source.maybe[ResponseRenderingContext]) // never send upstream completion
.via(renderer.named("renderer"))
.runWith(new JitSafeLatch[ResponseRenderingOutput](blackhole, n))
latch
}
// TODO benchmark with stable override
override def currentTimeMillis(): Long = System.currentTimeMillis()
// override def currentTimeMillis(): Long = System.currentTimeMillis() // DateTime(2011, 8, 25, 9, 10, 29).clicks // provide a stable date for testing
}

View file

@ -8,7 +8,7 @@ import java.lang.{ StringBuilder ⇒ JStringBuilder }
import scala.annotation.{ switch, tailrec } import scala.annotation.{ switch, tailrec }
import akka.http.scaladsl.settings.ParserSettings import akka.http.scaladsl.settings.ParserSettings
import akka.util.ByteString import akka.util.{ ByteString, OptionVal }
import akka.http.impl.engine.ws.Handshake import akka.http.impl.engine.ws.Handshake
import akka.http.impl.model.parser.CharacterClasses import akka.http.impl.model.parser.CharacterClasses
import akka.http.scaladsl.model._ import akka.http.scaladsl.model._
@ -160,8 +160,8 @@ private[http] final class HttpRequestParser(
val allHeaders = val allHeaders =
if (method == HttpMethods.GET) { if (method == HttpMethods.GET) {
Handshake.Server.websocketUpgrade(headers, hostHeaderPresent) match { Handshake.Server.websocketUpgrade(headers, hostHeaderPresent) match {
case Some(upgrade) upgrade :: allHeaders0 case OptionVal.Some(upgrade) upgrade :: allHeaders0
case None allHeaders0 case OptionVal.None allHeaders0
} }
} else allHeaders0 } else allHeaders0

View file

@ -7,11 +7,11 @@ package akka.http.impl.engine.rendering
import akka.NotUsed import akka.NotUsed
import akka.http.impl.engine.ws.{ FrameEvent, UpgradeToWebSocketResponseHeader } import akka.http.impl.engine.ws.{ FrameEvent, UpgradeToWebSocketResponseHeader }
import akka.http.scaladsl.model.ws.Message import akka.http.scaladsl.model.ws.Message
import akka.stream.{ Outlet, Inlet, Attributes, FlowShape, Graph } import akka.stream.{ Attributes, FlowShape, Graph, Inlet, Outlet }
import scala.annotation.tailrec import scala.annotation.tailrec
import akka.event.LoggingAdapter import akka.event.LoggingAdapter
import akka.util.ByteString import akka.util.{ ByteString, OptionVal }
import akka.stream.scaladsl.{ Flow, Source } import akka.stream.scaladsl.{ Flow, Source }
import akka.stream.stage._ import akka.stream.stage._
import akka.http.scaladsl.model._ import akka.http.scaladsl.model._
@ -20,6 +20,8 @@ import RenderSupport._
import HttpProtocols._ import HttpProtocols._
import headers._ import headers._
import scala.concurrent.duration._
/** /**
* INTERNAL API * INTERNAL API
*/ */
@ -129,9 +131,17 @@ private[http] class HttpResponseRendererFactory(
@tailrec def renderHeaders(remaining: List[HttpHeader], alwaysClose: Boolean = false, @tailrec def renderHeaders(remaining: List[HttpHeader], alwaysClose: Boolean = false,
connHeader: Connection = null, serverSeen: Boolean = false, connHeader: Connection = null, serverSeen: Boolean = false,
transferEncodingSeen: Boolean = false, dateSeen: Boolean = false): Unit = transferEncodingSeen: Boolean = false, dateSeen: Boolean = false): Unit = {
remaining match { remaining match {
case head :: tail head match { case head :: tail head match {
case x: Server
render(x)
renderHeaders(tail, alwaysClose, connHeader, serverSeen = true, transferEncodingSeen, dateSeen)
case x: Date
render(x)
renderHeaders(tail, alwaysClose, connHeader, serverSeen, transferEncodingSeen, dateSeen = true)
case x: `Content-Length` case x: `Content-Length`
suppressionWarning(log, x, "explicit `Content-Length` header is not allowed. Use the appropriate HttpEntity subtype.") suppressionWarning(log, x, "explicit `Content-Length` header is not allowed. Use the appropriate HttpEntity subtype.")
renderHeaders(tail, alwaysClose, connHeader, serverSeen, transferEncodingSeen, dateSeen) renderHeaders(tail, alwaysClose, connHeader, serverSeen, transferEncodingSeen, dateSeen)
@ -140,10 +150,6 @@ private[http] class HttpResponseRendererFactory(
suppressionWarning(log, x, "explicit `Content-Type` header is not allowed. Set `HttpResponse.entity.contentType` instead.") suppressionWarning(log, x, "explicit `Content-Type` header is not allowed. Set `HttpResponse.entity.contentType` instead.")
renderHeaders(tail, alwaysClose, connHeader, serverSeen, transferEncodingSeen, dateSeen) renderHeaders(tail, alwaysClose, connHeader, serverSeen, transferEncodingSeen, dateSeen)
case x: Date
render(x)
renderHeaders(tail, alwaysClose, connHeader, serverSeen, transferEncodingSeen, dateSeen = true)
case x: `Transfer-Encoding` case x: `Transfer-Encoding`
x.withChunkedPeeled match { x.withChunkedPeeled match {
case None case None
@ -159,10 +165,6 @@ private[http] class HttpResponseRendererFactory(
val connectionHeader = if (connHeader eq null) x else Connection(x.tokens ++ connHeader.tokens) val connectionHeader = if (connHeader eq null) x else Connection(x.tokens ++ connHeader.tokens)
renderHeaders(tail, alwaysClose, connectionHeader, serverSeen, transferEncodingSeen, dateSeen) renderHeaders(tail, alwaysClose, connectionHeader, serverSeen, transferEncodingSeen, dateSeen)
case x: Server
render(x)
renderHeaders(tail, alwaysClose, connHeader, serverSeen = true, transferEncodingSeen, dateSeen)
case x: CustomHeader case x: CustomHeader
if (x.renderInResponses) render(x) if (x.renderInResponses) render(x)
renderHeaders(tail, alwaysClose, connHeader, serverSeen, transferEncodingSeen, dateSeen) renderHeaders(tail, alwaysClose, connHeader, serverSeen, transferEncodingSeen, dateSeen)
@ -205,13 +207,15 @@ private[http] class HttpResponseRendererFactory(
r ~~ Connection ~~ (if (close) CloseBytes else KeepAliveBytes) ~~ CrLf r ~~ Connection ~~ (if (close) CloseBytes else KeepAliveBytes) ~~ CrLf
else if (connHeader != null && connHeader.hasUpgrade) { else if (connHeader != null && connHeader.hasUpgrade) {
r ~~ connHeader ~~ CrLf r ~~ connHeader ~~ CrLf
headers HttpHeader.fastFind(classOf[UpgradeToWebSocketResponseHeader], headers) match {
.collectFirst { case u: UpgradeToWebSocketResponseHeader u } case OptionVal.Some(header) closeMode = SwitchToWebSocket(header.handler)
.foreach { header closeMode = SwitchToWebSocket(header.handler) } case _ // nothing to do here...
}
} }
if (mustRenderTransferEncodingChunkedHeader && !transferEncodingSeen) if (mustRenderTransferEncodingChunkedHeader && !transferEncodingSeen)
r ~~ `Transfer-Encoding` ~~ ChunkedBytes ~~ CrLf r ~~ `Transfer-Encoding` ~~ ChunkedBytes ~~ CrLf
} }
}
def renderContentLengthHeader(contentLength: Long) = def renderContentLengthHeader(contentLength: Long) =
if (status.allowsEntity) r ~~ `Content-Length` ~~ contentLength ~~ CrLf else r if (status.allowsEntity) r ~~ `Content-Length` ~~ contentLength ~~ CrLf else r
@ -219,7 +223,7 @@ private[http] class HttpResponseRendererFactory(
def byteStrings(entityBytes: Source[ByteString, Any]): Source[ResponseRenderingOutput, Any] = def byteStrings(entityBytes: Source[ByteString, Any]): Source[ResponseRenderingOutput, Any] =
renderByteStrings(r, entityBytes, skipEntity = noEntity).map(ResponseRenderingOutput.HttpData(_)) renderByteStrings(r, entityBytes, skipEntity = noEntity).map(ResponseRenderingOutput.HttpData(_))
def completeResponseRendering(entity: ResponseEntity): StrictOrStreamed = @tailrec def completeResponseRendering(entity: ResponseEntity): StrictOrStreamed =
entity match { entity match {
case HttpEntity.Strict(_, data) case HttpEntity.Strict(_, data)
renderHeaders(headers.toList) renderHeaders(headers.toList)

View file

@ -28,6 +28,20 @@ private object RenderSupport {
val KeepAliveBytes = "Keep-Alive".asciiBytes val KeepAliveBytes = "Keep-Alive".asciiBytes
val CloseBytes = "close".asciiBytes val CloseBytes = "close".asciiBytes
private[this] final val PreRenderedContentTypes = {
val m = new java.util.HashMap[ContentType, Array[Byte]](16)
def preRenderContentType(ct: ContentType) =
m.put(ct, (new ByteArrayRendering(32) ~~ headers.`Content-Type` ~~ ct ~~ CrLf).get)
import ContentTypes._
preRenderContentType(`application/json`)
preRenderContentType(`text/plain(UTF-8)`)
preRenderContentType(`text/xml(UTF-8)`)
preRenderContentType(`text/html(UTF-8)`)
preRenderContentType(`text/csv(UTF-8)`)
m
}
def CrLf = Rendering.CrLf def CrLf = Rendering.CrLf
implicit val trailerRenderer = Renderer.genericSeqRenderer[Renderable, HttpHeader](CrLf, Rendering.Empty) implicit val trailerRenderer = Renderer.genericSeqRenderer[Renderable, HttpHeader](CrLf, Rendering.Empty)
@ -42,9 +56,14 @@ private object RenderSupport {
}) })
} }
def renderEntityContentType(r: Rendering, entity: HttpEntity) = def renderEntityContentType(r: Rendering, entity: HttpEntity) = {
if (entity.contentType != ContentTypes.NoContentType) r ~~ headers.`Content-Type` ~~ entity.contentType ~~ CrLf val ct = entity.contentType
else r if (ct != ContentTypes.NoContentType) {
val preRendered = PreRenderedContentTypes.get(ct)
if (preRendered ne null) r ~~ preRendered // re-use pre-rendered
else r ~~ headers.`Content-Type` ~~ ct ~~ CrLf // render ad-hoc
} else r // don't render
}
def renderByteStrings(r: ByteStringRendering, entityBytes: Source[ByteString, Any], def renderByteStrings(r: ByteStringRendering, entityBytes: Source[ByteString, Any],
skipEntity: Boolean = false): Source[ByteString, Any] = { skipEntity: Boolean = false): Source[ByteString, Any] = {

View file

@ -5,6 +5,7 @@
package akka.http.impl.engine.ws package akka.http.impl.engine.ws
import java.util.Random import java.util.Random
import scala.collection.immutable import scala.collection.immutable
import scala.collection.immutable.Seq import scala.collection.immutable.Seq
import scala.reflect.ClassTag import scala.reflect.ClassTag
@ -12,7 +13,8 @@ import akka.http.impl.util._
import akka.http.scaladsl.model.headers._ import akka.http.scaladsl.model.headers._
import akka.http.scaladsl.model.ws.{ Message, UpgradeToWebSocket } import akka.http.scaladsl.model.ws.{ Message, UpgradeToWebSocket }
import akka.http.scaladsl.model._ import akka.http.scaladsl.model._
import akka.stream.{ Graph, FlowShape } import akka.stream.{ FlowShape, Graph }
import akka.util.OptionVal
/** /**
* Server-side implementation of the WebSocket handshake * Server-side implementation of the WebSocket handshake
@ -62,50 +64,63 @@ private[http] object Handshake {
* to speak. The interpretation of this header field is discussed * to speak. The interpretation of this header field is discussed
* in Section 9.1. * in Section 9.1.
*/ */
def websocketUpgrade(headers: List[HttpHeader], hostHeaderPresent: Boolean): Option[UpgradeToWebSocket] = { def websocketUpgrade(headers: List[HttpHeader], hostHeaderPresent: Boolean): OptionVal[UpgradeToWebSocket] = {
def find[T <: HttpHeader: ClassTag]: Option[T] =
headers.collectFirst { // notes on Headers that re REQUIRE to be present here:
case t: T t // - Host header is validated in general HTTP logic
// - Origin header is optional and, if required, should be validated
// on higher levels (routing, application logic)
//
// TODO See #18709 Extension support is optional in WS and currently unsupported.
//
// these are not needed directly, we verify their presence and correctness only:
// - Upgrade
// - Connection
// - `Sec-WebSocket-Version`
def hasAllRequiredWebsocketUpgradeHeaders: Boolean = {
// single-pass through the headers list while collecting all needed requirements
// this way we avoid scanning the requirements list 3 times (as we would with collect/find)
val it = headers.iterator
var requirementsMet = 0
val targetRequirements = 3
while (it.hasNext && (requirementsMet != targetRequirements)) it.next() match {
case u: Upgrade if (u.hasWebSocket) requirementsMet += 1
case c: Connection if (c.hasUpgrade) requirementsMet += 1
case v: `Sec-WebSocket-Version` if (v.hasVersion(CurrentWebSocketVersion)) requirementsMet += 1
case _ // continue...
} }
requirementsMet == targetRequirements
}
// Host header is validated in general HTTP logic if (hasAllRequiredWebsocketUpgradeHeaders) {
// val host = find[Host] val key = HttpHeader.fastFind(classOf[`Sec-WebSocket-Key`], headers)
val upgrade = find[Upgrade] if (key.isDefined && key.get.isValid) {
val connection = find[Connection] val protocol = HttpHeader.fastFind(classOf[`Sec-WebSocket-Protocol`], headers)
val key = find[`Sec-WebSocket-Key`]
val version = find[`Sec-WebSocket-Version`]
// Origin header is optional and, if required, should be validated
// on higher levels (routing, application logic)
// val origin = find[Origin]
val protocol = find[`Sec-WebSocket-Protocol`]
val clientSupportedSubprotocols = protocol.toList.flatMap(_.protocols)
// Extension support is optional in WS and currently unsupported.
// TODO See #18709
// val extensions = find[`Sec-WebSocket-Extensions`]
if (upgrade.exists(_.hasWebSocket) && val clientSupportedSubprotocols = protocol match {
connection.exists(_.hasUpgrade) && case OptionVal.Some(p) p.protocols
version.exists(_.hasVersion(CurrentWebSocketVersion)) && case _ Nil
key.exists(k k.isValid)) {
val header = new UpgradeToWebSocketLowLevel {
def requestedProtocols: Seq[String] = clientSupportedSubprotocols
def handle(handler: Either[Graph[FlowShape[FrameEvent, FrameEvent], Any], Graph[FlowShape[Message, Message], Any]], subprotocol: Option[String]): HttpResponse = {
require(
subprotocol.forall(chosen clientSupportedSubprotocols.contains(chosen)),
s"Tried to choose invalid subprotocol '$subprotocol' which wasn't offered by the client: [${requestedProtocols.mkString(", ")}]")
buildResponse(key.get, handler, subprotocol)
} }
def handleFrames(handlerFlow: Graph[FlowShape[FrameEvent, FrameEvent], Any], subprotocol: Option[String]): HttpResponse = val header = new UpgradeToWebSocketLowLevel {
handle(Left(handlerFlow), subprotocol) def requestedProtocols: Seq[String] = clientSupportedSubprotocols
override def handleMessages(handlerFlow: Graph[FlowShape[Message, Message], Any], subprotocol: Option[String] = None): HttpResponse = def handle(handler: Either[Graph[FlowShape[FrameEvent, FrameEvent], Any], Graph[FlowShape[Message, Message], Any]], subprotocol: Option[String]): HttpResponse = {
handle(Right(handlerFlow), subprotocol) require(
} subprotocol.forall(chosen clientSupportedSubprotocols.contains(chosen)),
Some(header) s"Tried to choose invalid subprotocol '$subprotocol' which wasn't offered by the client: [${requestedProtocols.mkString(", ")}]")
} else None buildResponse(key.get, handler, subprotocol)
}
def handleFrames(handlerFlow: Graph[FlowShape[FrameEvent, FrameEvent], Any], subprotocol: Option[String]): HttpResponse =
handle(Left(handlerFlow), subprotocol)
override def handleMessages(handlerFlow: Graph[FlowShape[Message, Message], Any], subprotocol: Option[String] = None): HttpResponse =
handle(Right(handlerFlow), subprotocol)
}
OptionVal.Some(header)
} else OptionVal.None
} else OptionVal.None
} }
/* /*

View file

@ -4,14 +4,19 @@
package akka.http.scaladsl.model package akka.http.scaladsl.model
import akka.http.scaladsl.settings.ParserSettings import java.nio.charset.StandardCharsets
import scala.util.{ Success, Failure } import scala.util.{ Failure, Success }
import akka.parboiled2.ParseError import akka.parboiled2.{ ParseError, ParserInput }
import akka.http.impl.util.ToStringRenderable import akka.http.impl.util.ToStringRenderable
import akka.http.impl.model.parser.{ CharacterClasses, HeaderParser } import akka.http.impl.model.parser.{ CharacterClasses, HeaderParser }
import akka.http.javadsl.{ model jm } import akka.http.javadsl.{ model jm }
import akka.http.scaladsl.model.headers._ import akka.http.scaladsl.model.headers._
import akka.parboiled2.ParserInput.DefaultParserInput
import akka.util.{ ByteString, OptionVal }
import scala.annotation.tailrec
import scala.collection.immutable
/** /**
* The model of an HTTP header. In its most basic form headers are simple name-value pairs. Header names * The model of an HTTP header. In its most basic form headers are simple name-value pairs. Header names
@ -80,6 +85,16 @@ object HttpHeader {
} }
} else ParsingResult.Error(ErrorInfo(s"Illegal HTTP header name", name)) } else ParsingResult.Error(ErrorInfo(s"Illegal HTTP header name", name))
/** INTERNAL API */
private[akka] def fastFind[T >: Null <: jm.HttpHeader](clazz: Class[T], headers: immutable.Seq[HttpHeader]): OptionVal[T] = {
val it = headers.iterator
while (it.hasNext) it.next() match {
case h if clazz.isInstance(h) return OptionVal.Some[T](h.asInstanceOf[T])
case _ // continue ...
}
OptionVal.None.asInstanceOf[OptionVal[T]]
}
sealed trait ParsingResult { sealed trait ParsingResult {
def errors: List[ErrorInfo] def errors: List[ErrorInfo]
} }

View file

@ -19,7 +19,7 @@ import scala.reflect.{ ClassTag, classTag }
import akka.Done import akka.Done
import akka.parboiled2.CharUtils import akka.parboiled2.CharUtils
import akka.stream.Materializer import akka.stream.Materializer
import akka.util.{ ByteString, HashCode } import akka.util.{ ByteString, HashCode, OptionVal }
import akka.http.impl.util._ import akka.http.impl.util._
import akka.http.javadsl.{ model jm } import akka.http.javadsl.{ model jm }
import akka.http.scaladsl.util.FastFuture._ import akka.http.scaladsl.util.FastFuture._
@ -102,14 +102,13 @@ sealed trait HttpMessage extends jm.HttpMessage {
} }
/** Returns the first header of the given type if there is one */ /** Returns the first header of the given type if there is one */
def header[T <: jm.HttpHeader: ClassTag]: Option[T] = { def header[T >: Null <: jm.HttpHeader: ClassTag]: Option[T] = {
val erasure = classTag[T].runtimeClass val clazz = classTag[T].runtimeClass.asInstanceOf[Class[T]]
headers.find(erasure.isInstance).asInstanceOf[Option[T]] match { HttpHeader.fastFind[T](clazz, headers) match {
case header: Some[T] header case OptionVal.Some(h) Some(h)
case _ if erasure == classOf[`Content-Type`] Some(entity.contentType).asInstanceOf[Option[T]] case _ if clazz == classOf[`Content-Type`] Some(entity.contentType).asInstanceOf[Option[T]]
case _ None case _ None
} }
} }
/** /**
@ -145,7 +144,11 @@ sealed trait HttpMessage extends jm.HttpMessage {
/** Java API */ /** Java API */
def getHeaders: JIterable[jm.HttpHeader] = (headers: immutable.Seq[jm.HttpHeader]).asJava def getHeaders: JIterable[jm.HttpHeader] = (headers: immutable.Seq[jm.HttpHeader]).asJava
/** Java API */ /** Java API */
def getHeader[T <: jm.HttpHeader](headerClass: Class[T]): Optional[T] = header(ClassTag(headerClass)).asJava def getHeader[T <: jm.HttpHeader](headerClass: Class[T]): Optional[T] =
HttpHeader.fastFind[jm.HttpHeader](headerClass.asInstanceOf[Class[jm.HttpHeader]], headers) match {
case OptionVal.Some(h) Optional.of(h.asInstanceOf[T])
case _ Optional.empty()
}
/** Java API */ /** Java API */
def getHeader(headerName: String): Optional[jm.HttpHeader] = { def getHeader(headerName: String): Optional[jm.HttpHeader] = {
val lowerCased = headerName.toRootLowerCase val lowerCased = headerName.toRootLowerCase
@ -322,14 +325,22 @@ object HttpRequest {
* include a valid [[akka.http.scaladsl.model.headers.Host]] header or if URI authority and [[akka.http.scaladsl.model.headers.Host]] header don't match. * include a valid [[akka.http.scaladsl.model.headers.Host]] header or if URI authority and [[akka.http.scaladsl.model.headers.Host]] header don't match.
*/ */
def effectiveUri(uri: Uri, headers: immutable.Seq[HttpHeader], securedConnection: Boolean, defaultHostHeader: Host): Uri = { def effectiveUri(uri: Uri, headers: immutable.Seq[HttpHeader], securedConnection: Boolean, defaultHostHeader: Host): Uri = {
val hostHeader = headers.collectFirst { case x: Host x } def findHost(headers: immutable.Seq[HttpHeader]): OptionVal[Host] = {
val it = headers.iterator
while (it.hasNext) it.next() match {
case h: Host return OptionVal.Some(h)
case _ // continue ...
}
OptionVal.None
}
val hostHeader: OptionVal[Host] = findHost(headers)
if (uri.isRelative) { if (uri.isRelative) {
def fail(detail: String) = def fail(detail: String) =
throw IllegalUriException(s"Cannot establish effective URI of request to `$uri`, request has a relative URI and $detail") throw IllegalUriException(s"Cannot establish effective URI of request to `$uri`, request has a relative URI and $detail")
val Host(host, port) = hostHeader match { val Host(host, port) = hostHeader match {
case None if (defaultHostHeader.isEmpty) fail("is missing a `Host` header") else defaultHostHeader case OptionVal.None if (defaultHostHeader.isEmpty) fail("is missing a `Host` header") else defaultHostHeader
case Some(x) if x.isEmpty if (defaultHostHeader.isEmpty) fail("an empty `Host` header") else defaultHostHeader case OptionVal.Some(x) if x.isEmpty if (defaultHostHeader.isEmpty) fail("an empty `Host` header") else defaultHostHeader
case Some(x) x case OptionVal.Some(x) x
} }
uri.toEffectiveHttpRequestUri(host, port, securedConnection) uri.toEffectiveHttpRequestUri(host, port, securedConnection)
} else // http://tools.ietf.org/html/rfc7230#section-5.4 } else // http://tools.ietf.org/html/rfc7230#section-5.4

View file

@ -3,7 +3,7 @@
*/ */
package akka.http.impl.engine.ws package akka.http.impl.engine.ws
import scala.concurrent.{ Await, Promise } import scala.concurrent.{ Await, Future, Promise }
import scala.concurrent.duration.DurationInt import scala.concurrent.duration.DurationInt
import org.scalactic.ConversionCheckedTripleEquals import org.scalactic.ConversionCheckedTripleEquals
import org.scalatest.concurrent.ScalaFutures import org.scalatest.concurrent.ScalaFutures

View file

@ -103,7 +103,7 @@ abstract class TestRouteResult(_result: RouteResult, awaitAtMost: FiniteDuration
/** /**
* Returns the first header of the response which is of the given class. * Returns the first header of the response which is of the given class.
*/ */
def header[T <: HttpHeader](clazz: Class[T]): T = def header[T >: Null <: HttpHeader](clazz: Class[T]): T =
response.header(ClassTag(clazz)) response.header(ClassTag(clazz))
.getOrElse(doFail(s"Expected header of type ${clazz.getSimpleName} but wasn't found.")) .getOrElse(doFail(s"Expected header of type ${clazz.getSimpleName} but wasn't found."))

View file

@ -75,7 +75,7 @@ trait RouteTest extends RequestBuilding with WSTestRequestBuilding with RouteTes
def charsetOption: Option[HttpCharset] = contentType.charsetOption def charsetOption: Option[HttpCharset] = contentType.charsetOption
def charset: HttpCharset = charsetOption getOrElse sys.error("Binary entity does not have charset") def charset: HttpCharset = charsetOption getOrElse sys.error("Binary entity does not have charset")
def headers: immutable.Seq[HttpHeader] = response.headers def headers: immutable.Seq[HttpHeader] = response.headers
def header[T <: HttpHeader: ClassTag]: Option[T] = response.header[T] def header[T >: Null <: HttpHeader: ClassTag]: Option[T] = response.header[T](implicitly[ClassTag[T]])
def header(name: String): Option[HttpHeader] = response.headers.find(_.is(name.toLowerCase)) def header(name: String): Option[HttpHeader] = response.headers.find(_.is(name.toLowerCase))
def status: StatusCode = response.status def status: StatusCode = response.status

View file

@ -18,6 +18,9 @@ package akka.parboiled2
import scala.annotation.tailrec import scala.annotation.tailrec
import java.nio.ByteBuffer import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import akka.parboiled2.ParserInput.DefaultParserInput
trait ParserInput { trait ParserInput {
/** /**
@ -109,4 +112,4 @@ object ParserInput {
def sliceString(start: Int, end: Int) = new String(chars, start, end - start) def sliceString(start: Int, end: Int) = new String(chars, start, end - start)
def sliceCharArray(start: Int, end: Int) = java.util.Arrays.copyOfRange(chars, start, end) def sliceCharArray(start: Int, end: Int) = java.util.Arrays.copyOfRange(chars, start, end)
} }
} }