diff --git a/akka-actor/src/main/scala/akka/util/ByteString.scala b/akka-actor/src/main/scala/akka/util/ByteString.scala index 1e782f3c09..cb1dd3aef2 100644 --- a/akka-actor/src/main/scala/akka/util/ByteString.scala +++ b/akka-actor/src/main/scala/akka/util/ByteString.scala @@ -382,7 +382,7 @@ object ByteString { } else throw new IndexOutOfBoundsException(idx.toString) } - // Avoid `iterator` in performance sensitive code, call ops directly on ByteString instead + /** Avoid `iterator` in performance sensitive code, call ops directly on ByteString instead */ override def iterator: ByteIterator.MultiByteArrayIterator = ByteIterator.MultiByteArrayIterator(bytestrings.toStream map { _.iterator }) diff --git a/akka-actor/src/main/scala/akka/util/OptionVal.scala b/akka-actor/src/main/scala/akka/util/OptionVal.scala new file mode 100644 index 0000000000..7238c54233 --- /dev/null +++ b/akka-actor/src/main/scala/akka/util/OptionVal.scala @@ -0,0 +1,69 @@ +/** + * Copyright (C) 2016 Lightbend Inc. + */ +package akka.util + +/** + * INTERNAL API + */ +private[akka] object OptionVal { + + def apply[A >: Null](x: A): OptionVal[A] = new OptionVal(x) + + object Some { + def apply[A >: Null](x: A): OptionVal[A] = new OptionVal(x) + def unapply[A >: Null](x: OptionVal[A]): OptionVal[A] = x + } + + /** + * Represents non-existent values, `null` values. + */ + val None = new OptionVal[Null](null) +} + +/** + * INTERNAL API + * Represents optional values similar to `scala.Option`, but + * as a value class to avoid allocations. + * + * Note that it can be used in pattern matching without allocations + * because it has name based extractor using methods `isEmpty` and `get`. + * See http://hseeberger.github.io/blog/2013/10/04/name-based-extractors-in-scala-2-dot-11/ + */ +private[akka] final class OptionVal[+A >: Null](val x: A) extends AnyVal { + + /** + * Returns true if the option is `OptionVal.None`, false otherwise. + */ + def isEmpty: Boolean = + x == null + + /** + * Returns true if the option is `OptionVal.None`, false otherwise. + */ + def isDefined: Boolean = !isEmpty + + /** + * Returns the option's value if the option is nonempty, otherwise + * return `default`. + */ + def getOrElse[B >: A](default: B): B = + if (x == null) default else x + + /** + * Returns the option's value if it is nonempty, or `null` if it is empty. + */ + def orNull[A1 >: A](implicit ev: Null <:< A1): A1 = this getOrElse ev(null) + + /** + * Returns the option's value. + * @note The option must be nonEmpty. + * @throws java.util.NoSuchElementException if the option is empty. + */ + def get: A = + if (x == null) throw new NoSuchElementException("OptionVal.None.get") + else x + + override def toString: String = + if (x == null) "None" else s"Some($x)" +} diff --git a/akka-bench-jmh/src/main/scala/akka/http/HttpBlueprintBenchmark.scala b/akka-bench-jmh/src/main/scala/akka/http/HttpBlueprintBenchmark.scala index c254ea5692..d9925b8a82 100644 --- a/akka-bench-jmh/src/main/scala/akka/http/HttpBlueprintBenchmark.scala +++ b/akka-bench-jmh/src/main/scala/akka/http/HttpBlueprintBenchmark.scala @@ -110,30 +110,13 @@ class HttpBlueprintBenchmark { Flow.fromSinkAndSource(Sink.cancelled, Source.empty) @Benchmark - @OperationsPerInvocation(100 * 1000) - def run_10000_reqs(blackhole: Blackhole) = { - val n = 100 * 1000 + @OperationsPerInvocation(100000) + def run_10000_reqs() = { + val n = 100000 val latch = new CountDownLatch(n) val replyCountdown = reply map { x => latch.countDown() - blackhole.consume(x) - x - } - server(n).joinMat(replyCountdown)(Keep.right).run()(materializer) - - latch.await() - } - - @Benchmark - @OperationsPerInvocation(10 * 1000) - def run_1000_reqs(blackhole: Blackhole) = { - val n = 10 * 1000 - val latch = new CountDownLatch(n) - - val replyCountdown = reply map { x => - latch.countDown() - blackhole.consume(x) x } server(n).joinMat(replyCountdown)(Keep.right).run()(materializer) @@ -142,3 +125,4 @@ class HttpBlueprintBenchmark { } } + diff --git a/akka-bench-jmh/src/main/scala/akka/http/HttpRequestParsingBenchmark.scala b/akka-bench-jmh/src/main/scala/akka/http/HttpRequestParsingBenchmark.scala index d6f6eb15f8..4233a5115b 100644 --- a/akka-bench-jmh/src/main/scala/akka/http/HttpRequestParsingBenchmark.scala +++ b/akka-bench-jmh/src/main/scala/akka/http/HttpRequestParsingBenchmark.scala @@ -3,25 +3,24 @@ */ package akka.http -import java.util.concurrent.{ CountDownLatch, TimeUnit } +import java.util.concurrent.TimeUnit import javax.net.ssl.SSLContext import akka.Done import akka.actor.ActorSystem -import akka.http.impl.engine.parsing.ParserOutput.RequestOutput -import akka.http.impl.engine.parsing.{ HttpHeaderParser, HttpMessageParser, HttpRequestParser } +import akka.event.NoLogging +import akka.http.impl.engine.parsing.{ HttpHeaderParser, HttpRequestParser } import akka.http.scaladsl.settings.ParserSettings import akka.event.NoLogging +import akka.stream.ActorMaterializer import akka.stream.TLSProtocol.SessionBytes -import akka.stream.scaladsl.RunnableGraph -import akka.stream.{ ActorMaterializer, Attributes } -import akka.stream.scaladsl.{ Flow, Keep, Sink, Source } +import akka.stream.scaladsl._ import akka.util.ByteString import org.openjdk.jmh.annotations.{ OperationsPerInvocation, _ } import org.openjdk.jmh.infra.Blackhole -import scala.concurrent.{ Await, Future } import scala.concurrent.duration._ +import scala.concurrent.{ Await, Future } @State(Scope.Benchmark) @OutputTimeUnit(TimeUnit.SECONDS) @@ -29,48 +28,70 @@ import scala.concurrent.duration._ class HttpRequestParsingBenchmark { implicit val system: ActorSystem = ActorSystem("HttpRequestParsingBenchmark") - implicit val materializer = ActorMaterializer() + implicit val materializer = ActorMaterializer()(system) val parserSettings = ParserSettings(system) val parser = new HttpRequestParser(parserSettings, false, HttpHeaderParser(parserSettings, NoLogging)()) val dummySession = SSLContext.getDefault.createSSLEngine.getSession - val requestBytes = SessionBytes( + + @Param(Array("small", "large")) + var req: String = "" + + def request = req match { + case "small" => requestBytesSmall + case "large" => requestBytesLarge + } + + val requestBytesSmall: SessionBytes = SessionBytes( dummySession, ByteString( - "GET / HTTP/1.1\r\n" + - "Accept: */*\r\n" + - "Accept-Encoding: gzip, deflate\r\n" + - "Connection: keep-alive\r\n" + - "Host: example.com\r\n" + - "User-Agent: HTTPie/0.9.3\r\n" + - "\r\n" + """|GET / HTTP/1.1 + |Accept: */* + |Accept-Encoding: gzip, deflate + |Connection: keep-alive + |Host: example.com + |User-Agent: HTTPie/0.9.3 + | + |""".stripMargin.replaceAll("\n", "\r\n") ) ) + val requestBytesLarge: SessionBytes = SessionBytes( + dummySession, + ByteString( + """|GET /json HTTP/1.1 + |Host: server + |User-Agent: Mozilla/5.0 (X11; Linux x86_64) Gecko/20130501 Firefox/30.0 AppleWebKit/600.00 Chrome/30.0.0000.0 Trident/10.0 Safari/600.00 + |Cookie: uid=12345678901234567890; __utma=1.1234567890.1234567890.1234567890.1234567890.12; wd=2560x1600 + |Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8 + |Accept-Language: en-US,en;q=0.5 + |Connection: keep-alive + | + |""".stripMargin.replaceAll("\n", "\r\n") + ) + ) + + /* + // before: + [info] Benchmark (req) Mode Cnt Score Error Units + [info] HttpRequestParsingBenchmark.parse_10000_requests small thrpt 20 358 982.157 ± 93745.863 ops/s + [info] HttpRequestParsingBenchmark.parse_10000_requests large thrpt 20 388 335.666 ± 16990.715 ops/s + + // after: + [info] HttpRequestParsingBenchmark.parse_10000_requests_val small thrpt 20 623 975.879 ± 6191.897 ops/s + [info] HttpRequestParsingBenchmark.parse_10000_requests_val large thrpt 20 507 460.283 ± 4735.843 ops/s + */ + val httpMessageParser = Flow.fromGraph(parser) - def flow(n: Int): RunnableGraph[Future[Done]] = - Source.repeat(requestBytes).take(n) + def flow(bytes: SessionBytes, n: Int): RunnableGraph[Future[Done]] = + Source.repeat(request).take(n) .via(httpMessageParser) .toMat(Sink.ignore)(Keep.right) @Benchmark @OperationsPerInvocation(10000) - def parse_10000_single_requests(blackhole: Blackhole): Unit = { - val done = flow(10000).run() - Await.ready(done, 32.days) - } - - @Benchmark - @OperationsPerInvocation(1000) - def parse_1000_single_requests(blackhole: Blackhole): Unit = { - val done = flow(1000).run() - Await.ready(done, 32.days) - } - - @Benchmark - @OperationsPerInvocation(100) - def parse_100_single_requests(blackhole: Blackhole): Unit = { - val done = flow(100).run() + def parse_10000_requests_val(blackhole: Blackhole): Unit = { + val done = flow(requestBytesSmall, 10000).run() Await.ready(done, 32.days) } diff --git a/akka-bench-jmh/src/main/scala/akka/http/HttpResponseRenderingBenchmark.scala b/akka-bench-jmh/src/main/scala/akka/http/HttpResponseRenderingBenchmark.scala new file mode 100644 index 0000000000..8e5aa5c10a --- /dev/null +++ b/akka-bench-jmh/src/main/scala/akka/http/HttpResponseRenderingBenchmark.scala @@ -0,0 +1,250 @@ +/** + * Copyright (C) 2015-2016 Lightbend Inc. + */ + +package akka.http + +import java.util.concurrent.{ CountDownLatch, TimeUnit } + +import akka.NotUsed +import akka.actor.ActorSystem +import akka.event.NoLogging +import akka.http.impl.engine.rendering.ResponseRenderingOutput.HttpData +import akka.http.impl.engine.rendering.{ HttpResponseRendererFactory, ResponseRenderingContext, ResponseRenderingOutput } +import akka.http.scaladsl.Http +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.headers.Server +import akka.http.scaladsl.unmarshalling.Unmarshal +import akka.stream._ +import akka.stream.scaladsl._ +import akka.stream.stage.{ GraphStageLogic, GraphStageWithMaterializedValue, InHandler } +import akka.util.ByteString +import com.typesafe.config.ConfigFactory +import org.openjdk.jmh.annotations._ +import org.openjdk.jmh.infra.Blackhole + +import scala.concurrent.duration._ +import scala.concurrent.{ Await, Future } +import scala.util.Try + +@State(Scope.Benchmark) +@OutputTimeUnit(TimeUnit.SECONDS) +@BenchmarkMode(Array(Mode.Throughput)) +class HttpResponseRenderingBenchmark extends HttpResponseRendererFactory( + serverHeader = Some(Server("Akka HTTP 2.4.x")), + responseHeaderSizeHint = 64, + log = NoLogging +) { + + val config = ConfigFactory.parseString( + """ + akka { + loglevel = "ERROR" + }""".stripMargin + ).withFallback(ConfigFactory.load()) + + implicit val system = ActorSystem("HttpResponseRenderingBenchmark", config) + implicit val materializer = ActorMaterializer() + + import system.dispatcher + + val requestRendered = ByteString( + "GET / HTTP/1.1\r\n" + + "Accept: */*\r\n" + + "Accept-Encoding: gzip, deflate\r\n" + + "Connection: keep-alive\r\n" + + "Host: example.com\r\n" + + "User-Agent: HTTPie/0.9.3\r\n" + + "\r\n" + ) + + def TCPPlacebo(requests: Int): Flow[ByteString, ByteString, NotUsed] = + Flow.fromSinkAndSource( + Flow[ByteString].takeWhile(it => !(it.utf8String contains "Connection: close")) to Sink.ignore, + Source.repeat(requestRendered).take(requests) + ) + + def TlsPlacebo = TLSPlacebo() + + val requestRendering: Flow[HttpRequest, String, NotUsed] = + Http() + .clientLayer(headers.Host("blah.com")) + .atop(TlsPlacebo) + .join { + Flow[ByteString].map { x ⇒ + val response = s"HTTP/1.1 200 OK\r\nContent-Length: ${x.size}\r\n\r\n" + ByteString(response) ++ x + } + } + .mapAsync(1)(response => Unmarshal(response).to[String]) + + def renderResponse: Future[String] = Source.single(HttpRequest(uri = "/foo")) + .via(requestRendering) + .runWith(Sink.head) + + var request: HttpRequest = _ + var pool: Flow[(HttpRequest, Int), (Try[HttpResponse], Int), _] = _ + + @TearDown + def shutdown(): Unit = { + Await.ready(Http().shutdownAllConnectionPools(), 1.second) + Await.result(system.terminate(), 5.seconds) + } + + /* + [info] Benchmark Mode Cnt Score Error Units + [info] HttpResponseRenderingBenchmark.header_date_val thrpt 20 2 704 169 260 029.906 ± 234456086114.237 ops/s + + // def, normal time + [info] HttpResponseRenderingBenchmark.header_date_def thrpt 20 178 297 625 609.638 ± 7429280865.659 ops/s + [info] HttpResponseRenderingBenchmark.response_ok_simple_val thrpt 20 1 258 119.673 ± 58399.454 ops/s + [info] HttpResponseRenderingBenchmark.response_ok_simple_def thrpt 20 687 576.928 ± 94813.618 ops/s + + // clock nanos + [info] HttpResponseRenderingBenchmark.response_ok_simple_clock thrpt 20 1 676 438.649 ± 33976.590 ops/s + [info] HttpResponseRenderingBenchmark.response_ok_simple_clock thrpt 40 1 199 462.263 ± 222226.304 ops/s + + // ------ + + // before optimisig collectFirst + [info] HttpResponseRenderingBenchmark.json_response thrpt 20 1 782 572.845 ± 16572.625 ops/s + [info] HttpResponseRenderingBenchmark.simple_response thrpt 20 1 611 802.216 ± 19557.151 ops/s + + // after removing collectFirst and Option from renderHeaders + // not much of a difference, but hey, less Option allocs + [info] HttpResponseRenderingBenchmark.json_response thrpt 20 1 785 152.896 ± 15210.299 ops/s + [info] HttpResponseRenderingBenchmark.simple_response thrpt 20 1 783 800.184 ± 14938.415 ops/s + + // ----- + + // baseline for this optimisation is the above results (after collectFirst). + + // after introducing pre-rendered ContentType headers: + + normal clock + [info] HttpResponseRenderingBenchmark.json_long_raw_response thrpt 20 1738558.895 ± 159612.661 ops/s + [info] HttpResponseRenderingBenchmark.json_response thrpt 20 1714176.824 ± 100011.642 ops/s + + "fast clock" + [info] HttpResponseRenderingBenchmark.json_long_raw_response thrpt 20 1 528 632.480 ± 44934.827 ops/s + [info] HttpResponseRenderingBenchmark.json_response thrpt 20 1 517 383.792 ± 28256.716 ops/s + + */ + + /** + * HTTP/1.1 200 OK + * Server: Akka HTTP 2.4.x + * Date: Tue, 26 Jul 2016 15:26:53 GMT + * Content-Type: text/plain; charset=UTF-8 + * Content-Length: 6 + * + * ENTITY + */ + val simpleResponse = + ResponseRenderingContext( + response = HttpResponse( + 200, + headers = Nil, + entity = HttpEntity("ENTITY") + ), + requestMethod = HttpMethods.GET + ) + + /** + * HTTP/1.1 200 OK + * Server: Akka HTTP 2.4.x + * Date: Tue, 26 Jul 2016 15:26:53 GMT + * Content-Type: application/json + * Content-Length: 27 + * + * {"message":"Hello, World!"} + */ + val jsonResponse = + ResponseRenderingContext( + response = HttpResponse( + 200, + headers = Nil, + entity = HttpEntity(ContentTypes.`application/json`, """{"message":"Hello, World!"}""") + ), + requestMethod = HttpMethods.GET + ) + + /** + * HTTP/1.1 200 OK + * Server: Akka HTTP 2.4.x + * Date: Tue, 26 Jul 2016 15:26:53 GMT + * Content-Type: application/json + * Content-Length: 315 + * + * [{"id":4174,"randomNumber":331},{"id":51,"randomNumber":6544},{"id":4462,"randomNumber":952},{"id":2221,"randomNumber":532},{"id":9276,"randomNumber":3097},{"id":3056,"randomNumber":7293},{"id":6964,"randomNumber":620},{"id":675,"randomNumber":6601},{"id":8414,"randomNumber":6569},{"id":2753,"randomNumber":4065}] + */ + val jsonLongRawResponse = + ResponseRenderingContext( + response = HttpResponse( + 200, + headers = Nil, + entity = HttpEntity(ContentTypes.`application/json`, """[{"id":4174,"randomNumber":331},{"id":51,"randomNumber":6544},{"id":4462,"randomNumber":952},{"id":2221,"randomNumber":532},{"id":9276,"randomNumber":3097},{"id":3056,"randomNumber":7293},{"id":6964,"randomNumber":620},{"id":675,"randomNumber":6601},{"id":8414,"randomNumber":6569},{"id":2753,"randomNumber":4065}]""") + ), + requestMethod = HttpMethods.GET + ) + + @Benchmark + @Threads(8) + @OperationsPerInvocation(100 * 1000) + def simple_response(blackhole: Blackhole): Unit = + renderToImpl(simpleResponse, blackhole, n = 100 * 1000).await() + + @Benchmark + @OperationsPerInvocation(100 * 1000) + def json_response(blackhole: Blackhole): Unit = + renderToImpl(jsonResponse, blackhole, n = 100 * 1000).await() + + /* + Difference between 27 and 315 bytes long JSON is: + + [info] Benchmark Mode Cnt Score Error Units + [info] HttpResponseRenderingBenchmark.json_long_raw_response thrpt 20 1 932 331.049 ± 64125.621 ops/s + [info] HttpResponseRenderingBenchmark.json_response thrpt 20 1 973 232.941 ± 18568.314 ops/s + */ + @Benchmark + @OperationsPerInvocation(100 * 1000) + def json_long_raw_response(blackhole: Blackhole): Unit = + renderToImpl(jsonLongRawResponse, blackhole, n = 100 * 1000).await() + + class JitSafeLatch[A](blackhole: Blackhole, n: Int) extends GraphStageWithMaterializedValue[SinkShape[A], CountDownLatch] { + val in = Inlet[A]("JitSafeLatch.in") + override val shape = SinkShape(in) + + override def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, CountDownLatch) = { + val latch = new CountDownLatch(n) + val logic = new GraphStageLogic(shape) with InHandler { + + override def preStart(): Unit = pull(in) + override def onPush(): Unit = { + if (blackhole ne null) blackhole.consume(grab(in)) + latch.countDown() + pull(in) + } + + setHandler(in, this) + } + + (logic, latch) + } + } + + def renderToImpl(ctx: ResponseRenderingContext, blackhole: Blackhole, n: Int)(implicit mat: Materializer): CountDownLatch = { + val latch = + (Source.repeat(ctx).take(n) ++ Source.maybe[ResponseRenderingContext]) // never send upstream completion + .via(renderer.named("renderer")) + .runWith(new JitSafeLatch[ResponseRenderingOutput](blackhole, n)) + + latch + } + + // TODO benchmark with stable override + override def currentTimeMillis(): Long = System.currentTimeMillis() + // override def currentTimeMillis(): Long = System.currentTimeMillis() // DateTime(2011, 8, 25, 9, 10, 29).clicks // provide a stable date for testing + +} + diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpRequestParser.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpRequestParser.scala index 182bea75f1..c41fe5db1e 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpRequestParser.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpRequestParser.scala @@ -8,7 +8,7 @@ import java.lang.{ StringBuilder ⇒ JStringBuilder } import scala.annotation.{ switch, tailrec } import akka.http.scaladsl.settings.ParserSettings -import akka.util.ByteString +import akka.util.{ ByteString, OptionVal } import akka.http.impl.engine.ws.Handshake import akka.http.impl.model.parser.CharacterClasses import akka.http.scaladsl.model._ @@ -160,8 +160,8 @@ private[http] final class HttpRequestParser( val allHeaders = if (method == HttpMethods.GET) { Handshake.Server.websocketUpgrade(headers, hostHeaderPresent) match { - case Some(upgrade) ⇒ upgrade :: allHeaders0 - case None ⇒ allHeaders0 + case OptionVal.Some(upgrade) ⇒ upgrade :: allHeaders0 + case OptionVal.None ⇒ allHeaders0 } } else allHeaders0 diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/HttpResponseRendererFactory.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/HttpResponseRendererFactory.scala index 8a7ca48e72..3967ac5597 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/HttpResponseRendererFactory.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/HttpResponseRendererFactory.scala @@ -7,11 +7,11 @@ package akka.http.impl.engine.rendering import akka.NotUsed import akka.http.impl.engine.ws.{ FrameEvent, UpgradeToWebSocketResponseHeader } import akka.http.scaladsl.model.ws.Message -import akka.stream.{ Outlet, Inlet, Attributes, FlowShape, Graph } +import akka.stream.{ Attributes, FlowShape, Graph, Inlet, Outlet } import scala.annotation.tailrec import akka.event.LoggingAdapter -import akka.util.ByteString +import akka.util.{ ByteString, OptionVal } import akka.stream.scaladsl.{ Flow, Source } import akka.stream.stage._ import akka.http.scaladsl.model._ @@ -20,6 +20,8 @@ import RenderSupport._ import HttpProtocols._ import headers._ +import scala.concurrent.duration._ + /** * INTERNAL API */ @@ -129,9 +131,17 @@ private[http] class HttpResponseRendererFactory( @tailrec def renderHeaders(remaining: List[HttpHeader], alwaysClose: Boolean = false, connHeader: Connection = null, serverSeen: Boolean = false, - transferEncodingSeen: Boolean = false, dateSeen: Boolean = false): Unit = + transferEncodingSeen: Boolean = false, dateSeen: Boolean = false): Unit = { remaining match { case head :: tail ⇒ head match { + case x: Server ⇒ + render(x) + renderHeaders(tail, alwaysClose, connHeader, serverSeen = true, transferEncodingSeen, dateSeen) + + case x: Date ⇒ + render(x) + renderHeaders(tail, alwaysClose, connHeader, serverSeen, transferEncodingSeen, dateSeen = true) + case x: `Content-Length` ⇒ suppressionWarning(log, x, "explicit `Content-Length` header is not allowed. Use the appropriate HttpEntity subtype.") renderHeaders(tail, alwaysClose, connHeader, serverSeen, transferEncodingSeen, dateSeen) @@ -140,10 +150,6 @@ private[http] class HttpResponseRendererFactory( suppressionWarning(log, x, "explicit `Content-Type` header is not allowed. Set `HttpResponse.entity.contentType` instead.") renderHeaders(tail, alwaysClose, connHeader, serverSeen, transferEncodingSeen, dateSeen) - case x: Date ⇒ - render(x) - renderHeaders(tail, alwaysClose, connHeader, serverSeen, transferEncodingSeen, dateSeen = true) - case x: `Transfer-Encoding` ⇒ x.withChunkedPeeled match { case None ⇒ @@ -159,10 +165,6 @@ private[http] class HttpResponseRendererFactory( val connectionHeader = if (connHeader eq null) x else Connection(x.tokens ++ connHeader.tokens) renderHeaders(tail, alwaysClose, connectionHeader, serverSeen, transferEncodingSeen, dateSeen) - case x: Server ⇒ - render(x) - renderHeaders(tail, alwaysClose, connHeader, serverSeen = true, transferEncodingSeen, dateSeen) - case x: CustomHeader ⇒ if (x.renderInResponses) render(x) renderHeaders(tail, alwaysClose, connHeader, serverSeen, transferEncodingSeen, dateSeen) @@ -205,13 +207,15 @@ private[http] class HttpResponseRendererFactory( r ~~ Connection ~~ (if (close) CloseBytes else KeepAliveBytes) ~~ CrLf else if (connHeader != null && connHeader.hasUpgrade) { r ~~ connHeader ~~ CrLf - headers - .collectFirst { case u: UpgradeToWebSocketResponseHeader ⇒ u } - .foreach { header ⇒ closeMode = SwitchToWebSocket(header.handler) } + HttpHeader.fastFind(classOf[UpgradeToWebSocketResponseHeader], headers) match { + case OptionVal.Some(header) ⇒ closeMode = SwitchToWebSocket(header.handler) + case _ ⇒ // nothing to do here... + } } if (mustRenderTransferEncodingChunkedHeader && !transferEncodingSeen) r ~~ `Transfer-Encoding` ~~ ChunkedBytes ~~ CrLf } + } def renderContentLengthHeader(contentLength: Long) = if (status.allowsEntity) r ~~ `Content-Length` ~~ contentLength ~~ CrLf else r @@ -219,7 +223,7 @@ private[http] class HttpResponseRendererFactory( def byteStrings(entityBytes: ⇒ Source[ByteString, Any]): Source[ResponseRenderingOutput, Any] = renderByteStrings(r, entityBytes, skipEntity = noEntity).map(ResponseRenderingOutput.HttpData(_)) - def completeResponseRendering(entity: ResponseEntity): StrictOrStreamed = + @tailrec def completeResponseRendering(entity: ResponseEntity): StrictOrStreamed = entity match { case HttpEntity.Strict(_, data) ⇒ renderHeaders(headers.toList) diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/RenderSupport.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/RenderSupport.scala index ca4a3ef0a2..0c2ddc5643 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/RenderSupport.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/RenderSupport.scala @@ -28,6 +28,20 @@ private object RenderSupport { val KeepAliveBytes = "Keep-Alive".asciiBytes val CloseBytes = "close".asciiBytes + private[this] final val PreRenderedContentTypes = { + val m = new java.util.HashMap[ContentType, Array[Byte]](16) + def preRenderContentType(ct: ContentType) = + m.put(ct, (new ByteArrayRendering(32) ~~ headers.`Content-Type` ~~ ct ~~ CrLf).get) + + import ContentTypes._ + preRenderContentType(`application/json`) + preRenderContentType(`text/plain(UTF-8)`) + preRenderContentType(`text/xml(UTF-8)`) + preRenderContentType(`text/html(UTF-8)`) + preRenderContentType(`text/csv(UTF-8)`) + m + } + def CrLf = Rendering.CrLf implicit val trailerRenderer = Renderer.genericSeqRenderer[Renderable, HttpHeader](CrLf, Rendering.Empty) @@ -42,9 +56,14 @@ private object RenderSupport { }) } - def renderEntityContentType(r: Rendering, entity: HttpEntity) = - if (entity.contentType != ContentTypes.NoContentType) r ~~ headers.`Content-Type` ~~ entity.contentType ~~ CrLf - else r + def renderEntityContentType(r: Rendering, entity: HttpEntity) = { + val ct = entity.contentType + if (ct != ContentTypes.NoContentType) { + val preRendered = PreRenderedContentTypes.get(ct) + if (preRendered ne null) r ~~ preRendered // re-use pre-rendered + else r ~~ headers.`Content-Type` ~~ ct ~~ CrLf // render ad-hoc + } else r // don't render + } def renderByteStrings(r: ByteStringRendering, entityBytes: ⇒ Source[ByteString, Any], skipEntity: Boolean = false): Source[ByteString, Any] = { diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Handshake.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Handshake.scala index c207b7c27e..dd704f7d2c 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Handshake.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Handshake.scala @@ -5,6 +5,7 @@ package akka.http.impl.engine.ws import java.util.Random + import scala.collection.immutable import scala.collection.immutable.Seq import scala.reflect.ClassTag @@ -12,7 +13,8 @@ import akka.http.impl.util._ import akka.http.scaladsl.model.headers._ import akka.http.scaladsl.model.ws.{ Message, UpgradeToWebSocket } import akka.http.scaladsl.model._ -import akka.stream.{ Graph, FlowShape } +import akka.stream.{ FlowShape, Graph } +import akka.util.OptionVal /** * Server-side implementation of the WebSocket handshake @@ -62,50 +64,63 @@ private[http] object Handshake { * to speak. The interpretation of this header field is discussed * in Section 9.1. */ - def websocketUpgrade(headers: List[HttpHeader], hostHeaderPresent: Boolean): Option[UpgradeToWebSocket] = { - def find[T <: HttpHeader: ClassTag]: Option[T] = - headers.collectFirst { - case t: T ⇒ t + def websocketUpgrade(headers: List[HttpHeader], hostHeaderPresent: Boolean): OptionVal[UpgradeToWebSocket] = { + + // notes on Headers that re REQUIRE to be present here: + // - Host header is validated in general HTTP logic + // - Origin header is optional and, if required, should be validated + // on higher levels (routing, application logic) + // + // TODO See #18709 Extension support is optional in WS and currently unsupported. + // + // these are not needed directly, we verify their presence and correctness only: + // - Upgrade + // - Connection + // - `Sec-WebSocket-Version` + def hasAllRequiredWebsocketUpgradeHeaders: Boolean = { + // single-pass through the headers list while collecting all needed requirements + // this way we avoid scanning the requirements list 3 times (as we would with collect/find) + val it = headers.iterator + var requirementsMet = 0 + val targetRequirements = 3 + while (it.hasNext && (requirementsMet != targetRequirements)) it.next() match { + case u: Upgrade ⇒ if (u.hasWebSocket) requirementsMet += 1 + case c: Connection ⇒ if (c.hasUpgrade) requirementsMet += 1 + case v: `Sec-WebSocket-Version` ⇒ if (v.hasVersion(CurrentWebSocketVersion)) requirementsMet += 1 + case _ ⇒ // continue... } + requirementsMet == targetRequirements + } - // Host header is validated in general HTTP logic - // val host = find[Host] - val upgrade = find[Upgrade] - val connection = find[Connection] - val key = find[`Sec-WebSocket-Key`] - val version = find[`Sec-WebSocket-Version`] - // Origin header is optional and, if required, should be validated - // on higher levels (routing, application logic) - // val origin = find[Origin] - val protocol = find[`Sec-WebSocket-Protocol`] - val clientSupportedSubprotocols = protocol.toList.flatMap(_.protocols) - // Extension support is optional in WS and currently unsupported. - // TODO See #18709 - // val extensions = find[`Sec-WebSocket-Extensions`] + if (hasAllRequiredWebsocketUpgradeHeaders) { + val key = HttpHeader.fastFind(classOf[`Sec-WebSocket-Key`], headers) + if (key.isDefined && key.get.isValid) { + val protocol = HttpHeader.fastFind(classOf[`Sec-WebSocket-Protocol`], headers) - if (upgrade.exists(_.hasWebSocket) && - connection.exists(_.hasUpgrade) && - version.exists(_.hasVersion(CurrentWebSocketVersion)) && - key.exists(k ⇒ k.isValid)) { - - val header = new UpgradeToWebSocketLowLevel { - def requestedProtocols: Seq[String] = clientSupportedSubprotocols - - def handle(handler: Either[Graph[FlowShape[FrameEvent, FrameEvent], Any], Graph[FlowShape[Message, Message], Any]], subprotocol: Option[String]): HttpResponse = { - require( - subprotocol.forall(chosen ⇒ clientSupportedSubprotocols.contains(chosen)), - s"Tried to choose invalid subprotocol '$subprotocol' which wasn't offered by the client: [${requestedProtocols.mkString(", ")}]") - buildResponse(key.get, handler, subprotocol) + val clientSupportedSubprotocols = protocol match { + case OptionVal.Some(p) ⇒ p.protocols + case _ ⇒ Nil } - def handleFrames(handlerFlow: Graph[FlowShape[FrameEvent, FrameEvent], Any], subprotocol: Option[String]): HttpResponse = - handle(Left(handlerFlow), subprotocol) + val header = new UpgradeToWebSocketLowLevel { + def requestedProtocols: Seq[String] = clientSupportedSubprotocols - override def handleMessages(handlerFlow: Graph[FlowShape[Message, Message], Any], subprotocol: Option[String] = None): HttpResponse = - handle(Right(handlerFlow), subprotocol) - } - Some(header) - } else None + def handle(handler: Either[Graph[FlowShape[FrameEvent, FrameEvent], Any], Graph[FlowShape[Message, Message], Any]], subprotocol: Option[String]): HttpResponse = { + require( + subprotocol.forall(chosen ⇒ clientSupportedSubprotocols.contains(chosen)), + s"Tried to choose invalid subprotocol '$subprotocol' which wasn't offered by the client: [${requestedProtocols.mkString(", ")}]") + buildResponse(key.get, handler, subprotocol) + } + + def handleFrames(handlerFlow: Graph[FlowShape[FrameEvent, FrameEvent], Any], subprotocol: Option[String]): HttpResponse = + handle(Left(handlerFlow), subprotocol) + + override def handleMessages(handlerFlow: Graph[FlowShape[Message, Message], Any], subprotocol: Option[String] = None): HttpResponse = + handle(Right(handlerFlow), subprotocol) + } + OptionVal.Some(header) + } else OptionVal.None + } else OptionVal.None } /* diff --git a/akka-http-core/src/main/scala/akka/http/scaladsl/model/HttpHeader.scala b/akka-http-core/src/main/scala/akka/http/scaladsl/model/HttpHeader.scala index 0896fec5c2..ecd3c5341a 100644 --- a/akka-http-core/src/main/scala/akka/http/scaladsl/model/HttpHeader.scala +++ b/akka-http-core/src/main/scala/akka/http/scaladsl/model/HttpHeader.scala @@ -4,14 +4,19 @@ package akka.http.scaladsl.model -import akka.http.scaladsl.settings.ParserSettings +import java.nio.charset.StandardCharsets -import scala.util.{ Success, Failure } -import akka.parboiled2.ParseError +import scala.util.{ Failure, Success } +import akka.parboiled2.{ ParseError, ParserInput } import akka.http.impl.util.ToStringRenderable import akka.http.impl.model.parser.{ CharacterClasses, HeaderParser } import akka.http.javadsl.{ model ⇒ jm } import akka.http.scaladsl.model.headers._ +import akka.parboiled2.ParserInput.DefaultParserInput +import akka.util.{ ByteString, OptionVal } + +import scala.annotation.tailrec +import scala.collection.immutable /** * The model of an HTTP header. In its most basic form headers are simple name-value pairs. Header names @@ -80,6 +85,16 @@ object HttpHeader { } } else ParsingResult.Error(ErrorInfo(s"Illegal HTTP header name", name)) + /** INTERNAL API */ + private[akka] def fastFind[T >: Null <: jm.HttpHeader](clazz: Class[T], headers: immutable.Seq[HttpHeader]): OptionVal[T] = { + val it = headers.iterator + while (it.hasNext) it.next() match { + case h if clazz.isInstance(h) ⇒ return OptionVal.Some[T](h.asInstanceOf[T]) + case _ ⇒ // continue ... + } + OptionVal.None.asInstanceOf[OptionVal[T]] + } + sealed trait ParsingResult { def errors: List[ErrorInfo] } diff --git a/akka-http-core/src/main/scala/akka/http/scaladsl/model/HttpMessage.scala b/akka-http-core/src/main/scala/akka/http/scaladsl/model/HttpMessage.scala index ddc732305a..24601d1487 100644 --- a/akka-http-core/src/main/scala/akka/http/scaladsl/model/HttpMessage.scala +++ b/akka-http-core/src/main/scala/akka/http/scaladsl/model/HttpMessage.scala @@ -19,7 +19,7 @@ import scala.reflect.{ ClassTag, classTag } import akka.Done import akka.parboiled2.CharUtils import akka.stream.Materializer -import akka.util.{ ByteString, HashCode } +import akka.util.{ ByteString, HashCode, OptionVal } import akka.http.impl.util._ import akka.http.javadsl.{ model ⇒ jm } import akka.http.scaladsl.util.FastFuture._ @@ -102,14 +102,13 @@ sealed trait HttpMessage extends jm.HttpMessage { } /** Returns the first header of the given type if there is one */ - def header[T <: jm.HttpHeader: ClassTag]: Option[T] = { - val erasure = classTag[T].runtimeClass - headers.find(erasure.isInstance).asInstanceOf[Option[T]] match { - case header: Some[T] ⇒ header - case _ if erasure == classOf[`Content-Type`] ⇒ Some(entity.contentType).asInstanceOf[Option[T]] - case _ ⇒ None + def header[T >: Null <: jm.HttpHeader: ClassTag]: Option[T] = { + val clazz = classTag[T].runtimeClass.asInstanceOf[Class[T]] + HttpHeader.fastFind[T](clazz, headers) match { + case OptionVal.Some(h) ⇒ Some(h) + case _ if clazz == classOf[`Content-Type`] ⇒ Some(entity.contentType).asInstanceOf[Option[T]] + case _ ⇒ None } - } /** @@ -145,7 +144,11 @@ sealed trait HttpMessage extends jm.HttpMessage { /** Java API */ def getHeaders: JIterable[jm.HttpHeader] = (headers: immutable.Seq[jm.HttpHeader]).asJava /** Java API */ - def getHeader[T <: jm.HttpHeader](headerClass: Class[T]): Optional[T] = header(ClassTag(headerClass)).asJava + def getHeader[T <: jm.HttpHeader](headerClass: Class[T]): Optional[T] = + HttpHeader.fastFind[jm.HttpHeader](headerClass.asInstanceOf[Class[jm.HttpHeader]], headers) match { + case OptionVal.Some(h) ⇒ Optional.of(h.asInstanceOf[T]) + case _ ⇒ Optional.empty() + } /** Java API */ def getHeader(headerName: String): Optional[jm.HttpHeader] = { val lowerCased = headerName.toRootLowerCase @@ -322,14 +325,22 @@ object HttpRequest { * include a valid [[akka.http.scaladsl.model.headers.Host]] header or if URI authority and [[akka.http.scaladsl.model.headers.Host]] header don't match. */ def effectiveUri(uri: Uri, headers: immutable.Seq[HttpHeader], securedConnection: Boolean, defaultHostHeader: Host): Uri = { - val hostHeader = headers.collectFirst { case x: Host ⇒ x } + def findHost(headers: immutable.Seq[HttpHeader]): OptionVal[Host] = { + val it = headers.iterator + while (it.hasNext) it.next() match { + case h: Host ⇒ return OptionVal.Some(h) + case _ ⇒ // continue ... + } + OptionVal.None + } + val hostHeader: OptionVal[Host] = findHost(headers) if (uri.isRelative) { def fail(detail: String) = throw IllegalUriException(s"Cannot establish effective URI of request to `$uri`, request has a relative URI and $detail") val Host(host, port) = hostHeader match { - case None ⇒ if (defaultHostHeader.isEmpty) fail("is missing a `Host` header") else defaultHostHeader - case Some(x) if x.isEmpty ⇒ if (defaultHostHeader.isEmpty) fail("an empty `Host` header") else defaultHostHeader - case Some(x) ⇒ x + case OptionVal.None ⇒ if (defaultHostHeader.isEmpty) fail("is missing a `Host` header") else defaultHostHeader + case OptionVal.Some(x) if x.isEmpty ⇒ if (defaultHostHeader.isEmpty) fail("an empty `Host` header") else defaultHostHeader + case OptionVal.Some(x) ⇒ x } uri.toEffectiveHttpRequestUri(host, port, securedConnection) } else // http://tools.ietf.org/html/rfc7230#section-5.4 diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebSocketIntegrationSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebSocketIntegrationSpec.scala index 3c7c646b32..10241c6a11 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebSocketIntegrationSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebSocketIntegrationSpec.scala @@ -3,7 +3,7 @@ */ package akka.http.impl.engine.ws -import scala.concurrent.{ Await, Promise } +import scala.concurrent.{ Await, Future, Promise } import scala.concurrent.duration.DurationInt import org.scalactic.ConversionCheckedTripleEquals import org.scalatest.concurrent.ScalaFutures diff --git a/akka-http-testkit/src/main/scala/akka/http/javadsl/testkit/TestRouteResult.scala b/akka-http-testkit/src/main/scala/akka/http/javadsl/testkit/TestRouteResult.scala index 91edde58ac..1f27424f61 100644 --- a/akka-http-testkit/src/main/scala/akka/http/javadsl/testkit/TestRouteResult.scala +++ b/akka-http-testkit/src/main/scala/akka/http/javadsl/testkit/TestRouteResult.scala @@ -103,7 +103,7 @@ abstract class TestRouteResult(_result: RouteResult, awaitAtMost: FiniteDuration /** * Returns the first header of the response which is of the given class. */ - def header[T <: HttpHeader](clazz: Class[T]): T = + def header[T >: Null <: HttpHeader](clazz: Class[T]): T = response.header(ClassTag(clazz)) .getOrElse(doFail(s"Expected header of type ${clazz.getSimpleName} but wasn't found.")) diff --git a/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/RouteTest.scala b/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/RouteTest.scala index 2c71fc3538..b7e1c15a36 100644 --- a/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/RouteTest.scala +++ b/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/RouteTest.scala @@ -75,7 +75,7 @@ trait RouteTest extends RequestBuilding with WSTestRequestBuilding with RouteTes def charsetOption: Option[HttpCharset] = contentType.charsetOption def charset: HttpCharset = charsetOption getOrElse sys.error("Binary entity does not have charset") def headers: immutable.Seq[HttpHeader] = response.headers - def header[T <: HttpHeader: ClassTag]: Option[T] = response.header[T] + def header[T >: Null <: HttpHeader: ClassTag]: Option[T] = response.header[T](implicitly[ClassTag[T]]) def header(name: String): Option[HttpHeader] = response.headers.find(_.is(name.toLowerCase)) def status: StatusCode = response.status diff --git a/akka-parsing/src/main/scala/akka/parboiled2/ParserInput.scala b/akka-parsing/src/main/scala/akka/parboiled2/ParserInput.scala index 82f2b73250..445f5c0f67 100644 --- a/akka-parsing/src/main/scala/akka/parboiled2/ParserInput.scala +++ b/akka-parsing/src/main/scala/akka/parboiled2/ParserInput.scala @@ -18,6 +18,9 @@ package akka.parboiled2 import scala.annotation.tailrec import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets + +import akka.parboiled2.ParserInput.DefaultParserInput trait ParserInput { /** @@ -109,4 +112,4 @@ object ParserInput { def sliceString(start: Int, end: Int) = new String(chars, start, end - start) def sliceCharArray(start: Int, end: Int) = java.util.Arrays.copyOfRange(chars, start, end) } -} \ No newline at end of file +}