From 435a3387bfbd997e67147cc17031348836396fce Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Tue, 6 Oct 2015 14:01:37 +0200 Subject: [PATCH 01/10] =htc cleanup of WS infrastructure --- .../http/impl/engine/ws/FrameOutHandler.scala | 2 +- .../http/impl/engine/ws/WSTestSetupBase.scala | 121 ++++++++++++++++++ .../http/impl/engine/ws/WSTestUtils.scala | 15 ++- .../impl/engine/ws/WebsocketServerSpec.scala | 41 +----- 4 files changed, 136 insertions(+), 43 deletions(-) create mode 100644 akka-http-core/src/test/scala/akka/http/impl/engine/ws/WSTestSetupBase.scala diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameOutHandler.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameOutHandler.scala index ed29f99736..e1951d8323 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameOutHandler.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameOutHandler.scala @@ -10,7 +10,7 @@ import scala.concurrent.duration.FiniteDuration import akka.stream.stage._ import akka.http.impl.util.Timestamp -import FrameHandler.{ UserHandlerCompleted, ActivelyCloseWithCode, PeerClosed, DirectAnswer } +import akka.http.impl.engine.ws.FrameHandler._ import Websocket.Tick /** diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WSTestSetupBase.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WSTestSetupBase.scala new file mode 100644 index 0000000000..65c5548f13 --- /dev/null +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WSTestSetupBase.scala @@ -0,0 +1,121 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.impl.engine.ws + +import akka.http.impl.engine.ws.Protocol.Opcode +import akka.http.impl.engine.ws.WSTestUtils._ +import akka.util.ByteString +import org.scalatest.Matchers + +import scala.annotation.tailrec +import scala.util.Random + +trait WSTestSetupBase extends Matchers { + def send(bytes: ByteString): Unit + def expectNextChunk(): ByteString + + def sendWSFrame(opcode: Opcode, + data: ByteString, + fin: Boolean, + mask: Boolean = false, + rsv1: Boolean = false, + rsv2: Boolean = false, + rsv3: Boolean = false): Unit = { + val (theMask, theData) = + if (mask) { + val m = Random.nextInt() + (Some(m), maskedBytes(data, m)._1) + } else (None, data) + send(frameHeader(opcode, data.length, fin, theMask, rsv1, rsv2, rsv3) ++ theData) + } + + def sendWSCloseFrame(closeCode: Int, mask: Boolean = false): Unit = + send(closeFrame(closeCode, mask)) + + def expectWSFrame(opcode: Opcode, + data: ByteString, + fin: Boolean, + mask: Option[Int] = None, + rsv1: Boolean = false, + rsv2: Boolean = false, + rsv3: Boolean = false): Unit = + expectNextChunk() shouldEqual frameHeader(opcode, data.length, fin, mask, rsv1, rsv2, rsv3) ++ data + + def expectWSCloseFrame(closeCode: Int, mask: Boolean = false): Unit = + expectNextChunk() shouldEqual closeFrame(closeCode, mask) + + var inBuffer = ByteString.empty + @tailrec final def expectNetworkData(bytes: Int): ByteString = + if (inBuffer.size >= bytes) { + val res = inBuffer.take(bytes) + inBuffer = inBuffer.drop(bytes) + res + } else { + inBuffer ++= expectNextChunk() + expectNetworkData(bytes) + } + + def expectNetworkData(data: ByteString): Unit = + expectNetworkData(data.size) shouldEqual data + + def expectFrameOnNetwork(opcode: Opcode, data: ByteString, fin: Boolean): Unit = { + expectFrameHeaderOnNetwork(opcode, data.size, fin) + expectNetworkData(data) + } + def expectMaskedFrameOnNetwork(opcode: Opcode, data: ByteString, fin: Boolean): Unit = { + val Some(mask) = expectFrameHeaderOnNetwork(opcode, data.size, fin) + val masked = maskedBytes(data, mask)._1 + expectNetworkData(masked) + } + + def expectMaskedCloseFrame(closeCode: Int): Unit = + expectMaskedFrameOnNetwork(Protocol.Opcode.Close, closeFrameData(closeCode), fin = true) + + /** Returns the mask if any is available */ + def expectFrameHeaderOnNetwork(opcode: Opcode, length: Long, fin: Boolean): Option[Int] = { + val (op, l, f, m) = expectFrameHeaderOnNetwork() + op shouldEqual opcode + l shouldEqual length + f shouldEqual fin + m + } + def expectFrameHeaderOnNetwork(): (Opcode, Long, Boolean, Option[Int]) = { + val header = expectNetworkData(2) + + val fin = (header(0) & Protocol.FIN_MASK) != 0 + val op = header(0) & Protocol.OP_MASK + + val hasMask = (header(1) & Protocol.MASK_MASK) != 0 + val length7 = header(1) & Protocol.LENGTH_MASK + val length = length7 match { + case 126 ⇒ + val length16Bytes = expectNetworkData(2) + (length16Bytes(0) & 0xff) << 8 | (length16Bytes(1) & 0xff) << 0 + case 127 ⇒ + val length64Bytes = expectNetworkData(8) + (length64Bytes(0) & 0xff).toLong << 56 | + (length64Bytes(1) & 0xff).toLong << 48 | + (length64Bytes(2) & 0xff).toLong << 40 | + (length64Bytes(3) & 0xff).toLong << 32 | + (length64Bytes(4) & 0xff).toLong << 24 | + (length64Bytes(5) & 0xff).toLong << 16 | + (length64Bytes(6) & 0xff).toLong << 8 | + (length64Bytes(7) & 0xff).toLong << 0 + case x ⇒ x + } + val mask = + if (hasMask) { + val maskBytes = expectNetworkData(4) + val mask = + (maskBytes(0) & 0xff) << 24 | + (maskBytes(1) & 0xff) << 16 | + (maskBytes(2) & 0xff) << 8 | + (maskBytes(3) & 0xff) << 0 + Some(mask) + } else None + + (Opcode.forCode(op.toByte), length, fin, mask) + } +} diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WSTestUtils.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WSTestUtils.scala index 47f9556f2a..1821beae2b 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WSTestUtils.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WSTestUtils.scala @@ -43,14 +43,19 @@ object WSTestUtils { val lengthByte = lengthByteComponent | maskMask ByteString(opcodeByte.toByte, lengthByte.toByte) ++ lengthBytes ++ maskBytes } - def closeFrame(closeCode: Int, mask: Boolean): ByteString = + def frame(opcode: Opcode, data: ByteString, fin: Boolean, mask: Boolean): ByteString = if (mask) { val mask = Random.nextInt() - frameHeader(Opcode.Close, 2, fin = true, mask = Some(mask)) ++ - maskedBytes(shortBE(closeCode), mask)._1 + frameHeader(opcode, data.size, fin, mask = Some(mask)) ++ + maskedBytes(data, mask)._1 } else - frameHeader(Opcode.Close, 2, fin = true) ++ - shortBE(closeCode) + frameHeader(opcode, data.size, fin, mask = None) ++ data + + def closeFrame(closeCode: Int, mask: Boolean): ByteString = + frame(Opcode.Close, closeFrameData(closeCode), fin = true, mask) + + def closeFrameData(closeCode: Int): ByteString = + shortBE(closeCode) def maskedASCII(str: String, mask: Int): (ByteString, Int) = FrameEventParser.mask(ByteString(str, "ASCII"), mask) diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketServerSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketServerSpec.scala index 5a7a507099..fd5e297970 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketServerSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketServerSpec.scala @@ -55,13 +55,9 @@ class WebsocketServerSpec extends FreeSpec with Matchers with WithMaterializerSp | |""".stripMarginWithNewline("\r\n") - expectWSFrame(Protocol.Opcode.Text, - - ByteString("Message 1"), fin = true) - expectWSFrame(Protocol. - Opcode.Text, ByteString("Message 2"), fin = true) - expectWSFrame( - Protocol.Opcode.Text, ByteString("Message 3"), fin = true) + expectWSFrame(Protocol.Opcode.Text, ByteString("Message 1"), fin = true) + expectWSFrame(Protocol.Opcode.Text, ByteString("Message 2"), fin = true) + expectWSFrame(Protocol.Opcode.Text, ByteString("Message 3"), fin = true) expectWSFrame(Protocol.Opcode.Text, ByteString("Message 4"), fin = true) expectWSFrame(Protocol.Opcode.Text, ByteString("Message 5"), fin = true) expectWSCloseFrame(Protocol.CloseCodes.Regular) @@ -131,42 +127,13 @@ class WebsocketServerSpec extends FreeSpec with Matchers with WithMaterializerSp } } - class TestSetup extends HttpServerTestSetupBase { + class TestSetup extends HttpServerTestSetupBase with WSTestSetupBase { implicit def system = spec.system implicit def materializer = spec.materializer - def sendWSFrame(opcode: Opcode, - data: ByteString, - fin: Boolean, - mask: Boolean = false, - rsv1: Boolean = false, - rsv2: Boolean = false, - rsv3: Boolean = false): Unit = { - val (theMask, theData) = - if (mask) { - val m = Random.nextInt() - (Some(m), maskedBytes(data, m)._1) - } else (None, data) - send(frameHeader(opcode, data.length, fin, theMask, rsv1, rsv2, rsv3) ++ theData) - } - - def sendWSCloseFrame(closeCode: Int, mask: Boolean = false): Unit = - send(closeFrame(closeCode, mask)) - def expectNextChunk(): ByteString = { netOutSub.request(1) netOut.expectNext() } - - def expectWSFrame(opcode: Opcode, - data: ByteString, - fin: Boolean, - mask: Option[Int] = None, - rsv1: Boolean = false, - rsv2: Boolean = false, - rsv3: Boolean = false): Unit = - expectNextChunk() shouldEqual frameHeader(opcode, data.length, fin, mask, rsv1, rsv2, rsv3) ++ data - def expectWSCloseFrame(closeCode: Int, mask: Boolean = false): Unit = - expectNextChunk() shouldEqual closeFrame(closeCode, mask) } } From 5e0caf8fe150ac7209f473f2153c1afd80d1b9cf Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Fri, 9 Oct 2015 17:26:55 +0200 Subject: [PATCH 02/10] +htc allow configuration of random source in Websockets + various cleanups --- .../main/scala/akka/http/ServerSettings.scala | 5 + .../engine/parsing/HttpRequestParser.scala | 4 +- .../impl/engine/parsing/ParserOutput.scala | 2 + .../HttpResponseRendererFactory.scala | 7 +- .../engine/server/HttpServerBluePrint.scala | 14 +- .../akka/http/impl/engine/ws/Handshake.scala | 157 ++++++++++-------- .../akka/http/impl/engine/ws/Masking.scala | 6 +- .../akka/http/impl/engine/ws/Randoms.scala | 15 ++ .../ws/UpgradeToWebsocketLowLevel.scala | 3 - .../UpgradeToWebsocketsResponseHeader.scala | 6 +- .../akka/http/impl/engine/ws/Websocket.scala | 22 ++- .../main/scala/akka/http/scaladsl/Http.scala | 18 +- .../akka/http/scaladsl/model/ws/Message.scala | 4 +- .../http/impl/engine/ws/MessageSpec.scala | 2 +- 14 files changed, 158 insertions(+), 107 deletions(-) create mode 100644 akka-http-core/src/main/scala/akka/http/impl/engine/ws/Randoms.scala diff --git a/akka-http-core/src/main/scala/akka/http/ServerSettings.scala b/akka-http-core/src/main/scala/akka/http/ServerSettings.scala index e1d4a3197a..7d73c46f68 100644 --- a/akka-http-core/src/main/scala/akka/http/ServerSettings.scala +++ b/akka-http-core/src/main/scala/akka/http/ServerSettings.scala @@ -4,6 +4,9 @@ package akka.http +import java.util.Random + +import akka.http.impl.engine.ws.Randoms import com.typesafe.config.Config import scala.language.implicitConversions @@ -31,6 +34,7 @@ final case class ServerSettings( backlog: Int, socketOptions: immutable.Traversable[SocketOption], defaultHostHeader: Host, + websocketRandomFactory: () ⇒ Random, parserSettings: ParserSettings) { require(0 < maxConnections, "max-connections must be > 0") @@ -65,6 +69,7 @@ object ServerSettings extends SettingsCompanion[ServerSettings]("akka.http.serve val info = result.errors.head.withSummary("Configured `default-host-header` is illegal") throw new ConfigurationException(info.formatPretty) }, + Randoms.SecureRandomInstances, // can currently only be overridden from code ParserSettings.fromSubConfig(root, c.getConfig("parsing"))) def apply(optionalSettings: Option[ServerSettings])(implicit actorRefFactory: ActorRefFactory): ServerSettings = diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpRequestParser.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpRequestParser.scala index 17511784c9..60fb35d7c0 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpRequestParser.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpRequestParser.scala @@ -11,7 +11,7 @@ import scala.annotation.tailrec import akka.actor.ActorRef import akka.stream.stage.{ Context, PushPullStage } import akka.stream.scaladsl.Flow -import akka.stream.scaladsl.{ Keep, Source } +import akka.stream.scaladsl.Source import akka.util.ByteString import akka.http.impl.engine.ws.Handshake import akka.http.impl.model.parser.CharacterClasses @@ -129,7 +129,7 @@ private[http] class HttpRequestParser(_settings: ParserSettings, val allHeaders = if (method == HttpMethods.GET) { - Handshake.isWebsocketUpgrade(headers, hostHeaderPresent) match { + Handshake.Server.isWebsocketUpgrade(headers, hostHeaderPresent) match { case Some(upgrade) ⇒ upgrade :: allHeaders0 case None ⇒ allHeaders0 } diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/ParserOutput.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/ParserOutput.scala index 035cca3e19..bafc7e9ffe 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/ParserOutput.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/ParserOutput.scala @@ -56,4 +56,6 @@ private[http] object ParserOutput { case object NeedMoreData extends MessageOutput case object NeedNextRequestMethod extends ResponseOutput + + final case class RemainingBytes(bytes: ByteString) extends ResponseOutput } diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/HttpResponseRendererFactory.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/HttpResponseRendererFactory.scala index 64d676d17a..d6fd66c3d1 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/HttpResponseRendererFactory.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/HttpResponseRendererFactory.scala @@ -5,6 +5,7 @@ package akka.http.impl.engine.rendering import akka.http.impl.engine.ws.{ FrameEvent, UpgradeToWebsocketResponseHeader } +import akka.http.scaladsl.model.ws.Message import scala.annotation.tailrec import akka.event.LoggingAdapter @@ -159,7 +160,7 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser r ~~ connHeader ~~ CrLf headers .collectFirst { case u: UpgradeToWebsocketResponseHeader ⇒ u } - .foreach { header ⇒ closeMode = SwitchToWebsocket(header.handlerFlow) } + .foreach { header ⇒ closeMode = SwitchToWebsocket(header.handler) } } if (mustRenderTransferEncodingChunkedHeader && !transferEncodingSeen) r ~~ `Transfer-Encoding` ~~ ChunkedBytes ~~ CrLf @@ -220,7 +221,7 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser sealed trait CloseMode case object DontClose extends CloseMode case object CloseConnection extends CloseMode - case class SwitchToWebsocket(handlerFlow: Flow[FrameEvent, FrameEvent, Any]) extends CloseMode + case class SwitchToWebsocket(handler: Either[Flow[FrameEvent, FrameEvent, Any], Flow[Message, Message, Any]]) extends CloseMode } /** @@ -237,5 +238,5 @@ private[http] sealed trait ResponseRenderingOutput /** INTERNAL API */ private[http] object ResponseRenderingOutput { private[http] case class HttpData(bytes: ByteString) extends ResponseRenderingOutput - private[http] case class SwitchToWebsocket(httpResponseBytes: ByteString, handlerFlow: Flow[FrameEvent, FrameEvent, Any]) extends ResponseRenderingOutput + private[http] case class SwitchToWebsocket(httpResponseBytes: ByteString, handler: Either[Flow[FrameEvent, FrameEvent, Any], Flow[Message, Message, Any]]) extends ResponseRenderingOutput } diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala index bacd72e1c4..fe8505bd99 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala @@ -5,8 +5,10 @@ package akka.http.impl.engine.server import java.net.InetSocketAddress +import java.util.Random import akka.http.ServerSettings +import akka.http.scaladsl.model.ws.Message import akka.stream.io._ import org.reactivestreams.{ Subscriber, Publisher } import scala.util.control.NonFatal @@ -145,7 +147,7 @@ private[http] object HttpServerBluePrint { // protocol routing val protocolRouter = b.add(new WebsocketSwitchRouter()) - val protocolMerge = b.add(new WebsocketMerge(ws.installHandler)) + val protocolMerge = b.add(new WebsocketMerge(ws.installHandler, settings.websocketRandomFactory)) protocolRouter.out0 ~> http ~> protocolMerge.in0 protocolRouter.out1 ~> websocket ~> protocolMerge.in1 @@ -355,7 +357,7 @@ private[http] object HttpServerBluePrint { } } } - class WebsocketMerge(installHandler: Flow[FrameEvent, FrameEvent, Any] ⇒ Unit) extends FlexiMerge[ByteString, FanInShape2[ResponseRenderingOutput, ByteString, ByteString]](new FanInShape2("websocketMerge"), Attributes.name("websocketMerge")) { + class WebsocketMerge(installHandler: Flow[FrameEvent, FrameEvent, Any] ⇒ Unit, websocketRandomFactory: () ⇒ Random) extends FlexiMerge[ByteString, FanInShape2[ResponseRenderingOutput, ByteString, ByteString]](new FanInShape2("websocketMerge"), Attributes.name("websocketMerge")) { def createMergeLogic(s: FanInShape2[ResponseRenderingOutput, ByteString, ByteString]): MergeLogic[ByteString] = new MergeLogic[ByteString] { var websocketHandlerWasInstalled: Boolean = false @@ -370,7 +372,13 @@ private[http] object HttpServerBluePrint { ctx.emit(bytes); SameState case ResponseRenderingOutput.SwitchToWebsocket(responseBytes, handlerFlow) ⇒ ctx.emit(responseBytes) - installHandler(handlerFlow) + + val frameHandler = handlerFlow match { + case Left(frameHandler) ⇒ frameHandler + case Right(messageHandler) ⇒ + Websocket.stack(serverSide = true, maskingRandomFactory = websocketRandomFactory).join(messageHandler) + } + installHandler(frameHandler) ctx.changeCompletionHandling(defaultCompletionHandling) websocketHandlerWasInstalled = true websocket diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Handshake.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Handshake.scala index abad53d2e7..cc8a5a7f2f 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Handshake.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Handshake.scala @@ -4,16 +4,22 @@ package akka.http.impl.engine.ws -import akka.http.scaladsl.model.headers._ -import akka.http.scaladsl.model.ws.{ Message, UpgradeToWebsocket } -import akka.http.scaladsl.model.{ StatusCodes, HttpResponse, HttpProtocol, HttpHeader } -import akka.parboiled2.util.Base64 -import akka.stream.Materializer -import akka.stream.scaladsl.Flow +import java.util.Random +import scala.collection.immutable import scala.collection.immutable.Seq import scala.reflect.ClassTag +import akka.parboiled2.util.Base64 + +import akka.stream.scaladsl.Flow + +import akka.http.impl.util._ + +import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.model.ws.{ Message, UpgradeToWebsocket } +import akka.http.scaladsl.model._ + /** * Server-side implementation of the Websocket handshake * @@ -22,7 +28,8 @@ import scala.reflect.ClassTag private[http] object Handshake { val CurrentWebsocketVersion = 13 - /* + object Server { + /* From: http://tools.ietf.org/html/rfc6455#section-4.2.1 1. An HTTP/1.1 or higher GET request, including a "Request-URI" @@ -57,69 +64,81 @@ private[http] object Handshake { list of values indicating which extensions the client would like to speak. The interpretation of this header field is discussed in Section 9.1. - */ - def isWebsocketUpgrade(headers: List[HttpHeader], hostHeaderPresent: Boolean): Option[UpgradeToWebsocket] = { - def find[T <: HttpHeader: ClassTag]: Option[T] = - headers.collectFirst { - case t: T ⇒ t - } - - val host = find[Host] - val upgrade = find[Upgrade] - val connection = find[Connection] - val key = find[`Sec-WebSocket-Key`] - val version = find[`Sec-WebSocket-Version`] - val origin = find[Origin] - val protocol = find[`Sec-WebSocket-Protocol`] - val supportedProtocols = protocol.toList.flatMap(_.protocols) - val extensions = find[`Sec-WebSocket-Extensions`] - - def isValidKey(key: String): Boolean = Base64.rfc2045().decode(key).length == 16 - - if (upgrade.exists(_.hasWebsocket) && - connection.exists(_.hasUpgrade) && - version.exists(_.hasVersion(CurrentWebsocketVersion)) && - key.exists(k ⇒ isValidKey(k.key))) { - - val header = new UpgradeToWebsocketLowLevel { - def requestedProtocols: Seq[String] = supportedProtocols - - def handleFrames(handlerFlow: Flow[FrameEvent, FrameEvent, Any], subprotocol: Option[String]): HttpResponse = { - require(subprotocol.forall(chosen ⇒ supportedProtocols.contains(chosen)), - s"Tried to choose invalid subprotocol '$subprotocol' which wasn't offered by the client: [${requestedProtocols.mkString(", ")}]") - buildResponse(key.get, handlerFlow, subprotocol) + */ + def isWebsocketUpgrade(headers: List[HttpHeader], hostHeaderPresent: Boolean): Option[UpgradeToWebsocket] = { + def find[T <: HttpHeader: ClassTag]: Option[T] = + headers.collectFirst { + case t: T ⇒ t } - } - Some(header) - } else None + + // val host = find[Host] + val upgrade = find[Upgrade] + val connection = find[Connection] + val key = find[`Sec-WebSocket-Key`] + val version = find[`Sec-WebSocket-Version`] + // val origin = find[Origin] + val protocol = find[`Sec-WebSocket-Protocol`] + val supportedProtocols = protocol.toList.flatMap(_.protocols) + // FIXME: support extensions + // val extensions = find[`Sec-WebSocket-Extensions`] + + def isValidKey(key: String): Boolean = Base64.rfc2045().decode(key).length == 16 + + if (upgrade.exists(_.hasWebsocket) && + connection.exists(_.hasUpgrade) && + version.exists(_.hasVersion(CurrentWebsocketVersion)) && + key.exists(k ⇒ isValidKey(k.key))) { + + val header = new UpgradeToWebsocketLowLevel { + def requestedProtocols: Seq[String] = supportedProtocols + + def handle(handler: Either[Flow[FrameEvent, FrameEvent, Any], Flow[Message, Message, Any]], subprotocol: Option[String]): HttpResponse = { + require(subprotocol.forall(chosen ⇒ supportedProtocols.contains(chosen)), + s"Tried to choose invalid subprotocol '$subprotocol' which wasn't offered by the client: [${requestedProtocols.mkString(", ")}]") + buildResponse(key.get, handler, subprotocol) + } + + def handleFrames(handlerFlow: Flow[FrameEvent, FrameEvent, Any], subprotocol: Option[String]): HttpResponse = + handle(Left(handlerFlow), subprotocol) + + override def handleMessages(handlerFlow: Flow[Message, Message, Any], subprotocol: Option[String] = None): HttpResponse = + handle(Right(handlerFlow), subprotocol) + } + Some(header) + } else None + } + + /* + From: http://tools.ietf.org/html/rfc6455#section-4.2.2 + + 1. A Status-Line with a 101 response code as per RFC 2616 + [RFC2616]. Such a response could look like "HTTP/1.1 101 + Switching Protocols". + + 2. An |Upgrade| header field with value "websocket" as per RFC + 2616 [RFC2616]. + + 3. A |Connection| header field with value "Upgrade". + + 4. A |Sec-WebSocket-Accept| header field. The value of this + header field is constructed by concatenating /key/, defined + above in step 4 in Section 4.2.2, with the string "258EAFA5- + E914-47DA-95CA-C5AB0DC85B11", taking the SHA-1 hash of this + concatenated value to obtain a 20-byte value and base64- + encoding (see Section 4 of [RFC4648]) this 20-byte hash. + */ + def buildResponse(key: `Sec-WebSocket-Key`, handler: Either[Flow[FrameEvent, FrameEvent, Any], Flow[Message, Message, Any]], subprotocol: Option[String]): HttpResponse = + HttpResponse( + StatusCodes.SwitchingProtocols, + subprotocol.map(p ⇒ `Sec-WebSocket-Protocol`(Seq(p))).toList ::: + List( + UpgradeHeader, + ConnectionUpgradeHeader, + `Sec-WebSocket-Accept`.forKey(key), + UpgradeToWebsocketResponseHeader(handler))) } - /* - From: http://tools.ietf.org/html/rfc6455#section-4.2.2 - - 1. A Status-Line with a 101 response code as per RFC 2616 - [RFC2616]. Such a response could look like "HTTP/1.1 101 - Switching Protocols". - - 2. An |Upgrade| header field with value "websocket" as per RFC - 2616 [RFC2616]. - - 3. A |Connection| header field with value "Upgrade". - - 4. A |Sec-WebSocket-Accept| header field. The value of this - header field is constructed by concatenating /key/, defined - above in step 4 in Section 4.2.2, with the string "258EAFA5- - E914-47DA-95CA-C5AB0DC85B11", taking the SHA-1 hash of this - concatenated value to obtain a 20-byte value and base64- - encoding (see Section 4 of [RFC4648]) this 20-byte hash. - */ - def buildResponse(key: `Sec-WebSocket-Key`, handlerFlow: Flow[FrameEvent, FrameEvent, Any], subprotocol: Option[String]): HttpResponse = - HttpResponse( - StatusCodes.SwitchingProtocols, - subprotocol.map(p ⇒ `Sec-WebSocket-Protocol`(Seq(p))).toList ::: - List( - Upgrade(List(UpgradeProtocol("websocket"))), - Connection(List("upgrade")), - `Sec-WebSocket-Accept`.forKey(key), - UpgradeToWebsocketResponseHeader(handlerFlow))) + val UpgradeHeader = Upgrade(List(UpgradeProtocol("websocket"))) + val ConnectionUpgradeHeader = Connection(List("upgrade")) + val SecWebsocketVersionHeader = `Sec-WebSocket-Version`(Seq(CurrentWebsocketVersion)) } diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Masking.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Masking.scala index 730f782a42..c4be386eb2 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Masking.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Masking.scala @@ -4,10 +4,10 @@ package akka.http.impl.engine.ws -import akka.stream.scaladsl.{ Keep, BidiFlow, Flow } -import akka.stream.stage.{ SyncDirective, Context, StageState, StatefulStage } +import java.util.Random -import scala.util.Random +import akka.stream.scaladsl.{ Keep, BidiFlow, Flow } +import akka.stream.stage.{ SyncDirective, Context, StatefulStage } /** * Implements Websocket Frame masking. diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Randoms.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Randoms.scala new file mode 100644 index 0000000000..735a2e0c72 --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Randoms.scala @@ -0,0 +1,15 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.impl.engine.ws + +import java.security.SecureRandom +import java.util.Random + +object Randoms { + /** A factory that creates SecureRandom instances */ + private[http] case object SecureRandomInstances extends (() ⇒ Random) { + override def apply(): Random = new SecureRandom() + } +} diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/UpgradeToWebsocketLowLevel.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/UpgradeToWebsocketLowLevel.scala index 3f9179f859..771171a9af 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/UpgradeToWebsocketLowLevel.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/UpgradeToWebsocketLowLevel.scala @@ -26,7 +26,4 @@ private[http] abstract class UpgradeToWebsocketLowLevel extends InternalCustomHe * INTERNAL API (for now) */ private[http] def handleFrames(handlerFlow: Flow[FrameEvent, FrameEvent, Any], subprotocol: Option[String] = None): HttpResponse - - override def handleMessages(handlerFlow: Flow[Message, Message, Any], subprotocol: Option[String] = None): HttpResponse = - handleFrames(Websocket.stack(serverSide = true).join(handlerFlow), subprotocol) } diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/UpgradeToWebsocketsResponseHeader.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/UpgradeToWebsocketsResponseHeader.scala index 210863d305..7b37733021 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/UpgradeToWebsocketsResponseHeader.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/UpgradeToWebsocketsResponseHeader.scala @@ -5,12 +5,12 @@ package akka.http.impl.engine.ws import akka.http.scaladsl.model.headers.CustomHeader +import akka.http.scaladsl.model.ws.Message import akka.stream.Materializer import akka.stream.scaladsl.Flow -private[http] case class UpgradeToWebsocketResponseHeader(handlerFlow: Flow[FrameEvent, FrameEvent, Any]) - extends InternalCustomHeader("UpgradeToWebsocketResponseHeader") { -} +private[http] final case class UpgradeToWebsocketResponseHeader(handler: Either[Flow[FrameEvent, FrameEvent, Any], Flow[Message, Message, Any]]) + extends InternalCustomHeader("UpgradeToWebsocketResponseHeader") private[http] abstract class InternalCustomHeader(val name: String) extends CustomHeader { override def suppressRendering: Boolean = true diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Websocket.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Websocket.scala index 6efa2adae0..de1cda24c6 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Websocket.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Websocket.scala @@ -4,7 +4,7 @@ package akka.http.impl.engine.ws -import java.security.SecureRandom +import java.util.Random import akka.util.ByteString @@ -27,6 +27,16 @@ import akka.http.scaladsl.model.ws._ private[http] object Websocket { import FrameHandler._ + /** + * A stack of all the higher WS layers between raw frames and the user API. + */ + def stack(serverSide: Boolean, + maskingRandomFactory: () ⇒ Random, + closeTimeout: FiniteDuration = 3.seconds): BidiFlow[FrameEvent, Message, Message, FrameEvent, Unit] = + masking(serverSide, maskingRandomFactory) atop + frameHandling(serverSide, closeTimeout) atop + messageAPI(serverSide, closeTimeout) + /** The lowest layer that implements the binary protocol */ def framing: BidiFlow[ByteString, FrameEvent, FrameEvent, ByteString, Unit] = BidiFlow.wrap( @@ -35,8 +45,8 @@ private[http] object Websocket { .named("ws-framing") /** The layer that handles masking using the rules defined in the specification */ - def masking(serverSide: Boolean): BidiFlow[FrameEvent, FrameEvent, FrameEvent, FrameEvent, Unit] = - Masking(serverSide, () ⇒ new SecureRandom()) + def masking(serverSide: Boolean, maskingRandomFactory: () ⇒ Random): BidiFlow[FrameEvent, FrameEvent, FrameEvent, FrameEvent, Unit] = + Masking(serverSide, maskingRandomFactory) .named("ws-masking") /** @@ -219,12 +229,6 @@ private[http] object Websocket { }.named("ws-message-api") } - def stack(serverSide: Boolean = true, - closeTimeout: FiniteDuration = 3.seconds): BidiFlow[FrameEvent, Message, Message, FrameEvent, Unit] = - masking(serverSide) atop - frameHandling(serverSide, closeTimeout) atop - messageAPI(serverSide, closeTimeout) - object Tick case object SwitchToWebsocketToken } diff --git a/akka-http-core/src/main/scala/akka/http/scaladsl/Http.scala b/akka-http-core/src/main/scala/akka/http/scaladsl/Http.scala index e231f5013a..081a068034 100644 --- a/akka-http-core/src/main/scala/akka/http/scaladsl/Http.scala +++ b/akka-http-core/src/main/scala/akka/http/scaladsl/Http.scala @@ -568,7 +568,7 @@ object Http extends ExtensionId[HttpExt] with ExtensionIdProvider { * @param localAddress The local address of the endpoint bound by the materialization of the `connections` [[Source]] * */ - case class ServerBinding(localAddress: InetSocketAddress)(private val unbindAction: () ⇒ Future[Unit]) { + final case class ServerBinding(localAddress: InetSocketAddress)(private val unbindAction: () ⇒ Future[Unit]) { /** * Asynchronously triggers the unbinding of the port that was bound by the materialization of the `connections` @@ -582,7 +582,7 @@ object Http extends ExtensionId[HttpExt] with ExtensionIdProvider { /** * Represents one accepted incoming HTTP connection. */ - case class IncomingConnection( + final case class IncomingConnection( localAddress: InetSocketAddress, remoteAddress: InetSocketAddress, flow: Flow[HttpResponse, HttpRequest, Unit]) { @@ -612,12 +612,12 @@ object Http extends ExtensionId[HttpExt] with ExtensionIdProvider { /** * Represents a prospective outgoing HTTP connection. */ - case class OutgoingConnection(localAddress: InetSocketAddress, remoteAddress: InetSocketAddress) + final case class OutgoingConnection(localAddress: InetSocketAddress, remoteAddress: InetSocketAddress) /** * Represents a connection pool to a specific target host and pool configuration. */ - case class HostConnectionPool(setup: HostConnectionPoolSetup)( + final case class HostConnectionPool(setup: HostConnectionPoolSetup)( private[http] val gatewayFuture: Future[PoolGateway]) extends javadsl.HostConnectionPool { // enable test access /** @@ -641,11 +641,11 @@ object Http extends ExtensionId[HttpExt] with ExtensionIdProvider { import scala.collection.JavaConverters._ //# https-context-impl -case class HttpsContext(sslContext: SSLContext, - enabledCipherSuites: Option[immutable.Seq[String]] = None, - enabledProtocols: Option[immutable.Seq[String]] = None, - clientAuth: Option[ClientAuth] = None, - sslParameters: Option[SSLParameters] = None) +final case class HttpsContext(sslContext: SSLContext, + enabledCipherSuites: Option[immutable.Seq[String]] = None, + enabledProtocols: Option[immutable.Seq[String]] = None, + clientAuth: Option[ClientAuth] = None, + sslParameters: Option[SSLParameters] = None) //# extends akka.http.javadsl.HttpsContext { def firstSession = NegotiateNewSession(enabledCipherSuites, enabledProtocols, clientAuth, sslParameters) diff --git a/akka-http-core/src/main/scala/akka/http/scaladsl/model/ws/Message.scala b/akka-http-core/src/main/scala/akka/http/scaladsl/model/ws/Message.scala index f44b294b30..4a38aa64bd 100644 --- a/akka-http-core/src/main/scala/akka/http/scaladsl/model/ws/Message.scala +++ b/akka-http-core/src/main/scala/akka/http/scaladsl/model/ws/Message.scala @@ -16,7 +16,7 @@ sealed trait Message /** * A binary */ -trait TextMessage extends Message { +sealed trait TextMessage extends Message { /** * The contents of this message as a stream. */ @@ -38,7 +38,7 @@ object TextMessage { final private case class Streamed(textStream: Source[String, _]) extends TextMessage } //#message-model -trait BinaryMessage extends Message { +sealed trait BinaryMessage extends Message { /** * The contents of this message as a stream. */ diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala index 7423c6b575..cfeecee154 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala @@ -782,7 +782,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { Source(netIn) .via(printEvent("netIn")) .transform(() ⇒ new FrameEventParser) - .via(Websocket.stack(serverSide, closeTimeout = closeTimeout).join(messageHandler)) + .via(Websocket.stack(serverSide, maskingRandomFactory = Randoms.SecureRandomInstances, closeTimeout = closeTimeout).join(messageHandler)) .via(printEvent("frameRendererIn")) .transform(() ⇒ new FrameEventRenderer) .via(printEvent("frameRendererOut")) From 4cbbb7dbad3a01e4a73dec4dc74af173c5018953 Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Tue, 6 Oct 2015 14:37:06 +0200 Subject: [PATCH 03/10] =htc fix WS masking for empty frames on client side --- .../akka/http/impl/engine/ws/Masking.scala | 9 ++-- .../http/impl/engine/ws/MessageSpec.scala | 45 ++++++++++++++++++- 2 files changed, 46 insertions(+), 8 deletions(-) diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Masking.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Masking.scala index c4be386eb2..a4f805cee9 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Masking.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Masking.scala @@ -51,12 +51,9 @@ private[http] object Masking { def onPush(part: FrameEvent, ctx: Context[FrameEvent]): SyncDirective = part match { case start @ FrameStart(header, data) ⇒ - if (header.length == 0) ctx.push(part) - else { - val mask = extractMask(header) - become(new Running(mask)) - current.onPush(start.copy(header = setNewMask(header, mask)), ctx) - } + val mask = extractMask(header) + become(new Running(mask)) + current.onPush(start.copy(header = setNewMask(header, mask)), ctx) case _: FrameData ⇒ ctx.fail(new IllegalStateException("unexpected FrameData (need FrameStart first)")) } diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala index cfeecee154..3914fdf6fb 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala @@ -134,6 +134,13 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { sub.expectNext(ByteString("def", "ASCII")) sub.expectComplete() } + "unmask masked input on the server side for empty frame" in new ServerTestSetup { + val mask = Random.nextInt() + val header = frameHeader(Opcode.Binary, 0, fin = true, mask = Some(mask)) + + pushInput(header) + expectBinaryMessage(BinaryMessage.Strict(ByteString.empty)) + } } "for text messages" - { "empty message" in new ClientTestSetup { @@ -207,6 +214,13 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { sub.expectNext(ByteString("cdef€", "UTF-8")) sub.expectComplete() } + "unmask masked input on the server side for empty frame" in new ServerTestSetup { + val mask = Random.nextInt() + val header = frameHeader(Opcode.Text, 0, fin = true, mask = Some(mask)) + + pushInput(header) + expectTextMessage(TextMessage.Strict("")) + } } } "render frames from messages" - { @@ -265,6 +279,10 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { sub.sendComplete() expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true) } + "and mask input on the client side for empty frame" in new ClientTestSetup { + pushMessage(BinaryMessage(ByteString.empty)) + expectMaskedFrameOnNetwork(Opcode.Binary, ByteString.empty, fin = true) + } } "for text messages" - { "for a short strict message" in new ServerTestSetup { @@ -347,6 +365,10 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { sub.sendComplete() expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true) } + "and mask input on the client side for empty frame" in new ClientTestSetup { + pushMessage(TextMessage("")) + expectMaskedFrameOnNetwork(Opcode.Text, ByteString.empty, fin = true) + } } } "supply automatic low-level websocket behavior" - { @@ -440,7 +462,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { } "after receiving close frame without close code" in new ServerTestSetup { netInSub.expectRequest() - pushInput(frameHeader(Opcode.Close, 0, fin = true)) + pushInput(frameHeader(Opcode.Close, 0, fin = true, mask = Some(Random.nextInt()))) messageIn.expectComplete() messageOutSub.sendComplete() @@ -479,7 +501,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { netOutSub.request(10) messageInSub.request(10) - pushInput(frameHeader(Protocol.Opcode.Binary, 0, fin = false)) + pushInput(frameHeader(Protocol.Opcode.Binary, 0, fin = false, mask = Some(Random.nextInt()))) val dataSource = expectBinaryMessage().dataStream val inSubscriber = TestSubscriber.manualProbe[ByteString]() dataSource.runWith(Sink(inSubscriber)) @@ -742,10 +764,23 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { pushInput(input) expectProtocolErrorOnNetwork() } + "unmasked input on the server side for empty frame" in new ServerTestSetup { + val input = frameHeader(Opcode.Binary, 0, fin = true) + + pushInput(input) + expectProtocolErrorOnNetwork() + } "masked input on the client side" in new ClientTestSetup { val mask = Random.nextInt() val input = frameHeader(Opcode.Binary, 6, fin = true, mask = Some(mask)) ++ maskedASCII("abcdef", mask)._1 + pushInput(input) + expectProtocolErrorOnNetwork() + } + "masked input on the client side for empty frame" in new ClientTestSetup { + val mask = Random.nextInt() + val input = frameHeader(Opcode.Binary, 0, fin = true, mask = Some(mask)) + pushInput(input) expectProtocolErrorOnNetwork() } @@ -813,9 +848,15 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { def expectBinaryMessage(): BinaryMessage = expectMessage().asInstanceOf[BinaryMessage] + def expectBinaryMessage(message: BinaryMessage): Unit = + expectBinaryMessage() shouldEqual message + def expectTextMessage(): TextMessage = expectMessage().asInstanceOf[TextMessage] + def expectTextMessage(message: TextMessage): Unit = + expectTextMessage() shouldEqual message + var inBuffer = ByteString.empty @tailrec final def expectNetworkData(bytes: Int): ByteString = if (inBuffer.size >= bytes) { From 11e593a1fa847285cd131e7230d709b39c3dad71 Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Fri, 9 Oct 2015 17:37:08 +0200 Subject: [PATCH 04/10] =htc introduce ByteStringSinkProbe and update tests to new infrastructure --- .../impl/engine/server/HttpServerSpec.scala | 175 ++++++++---------- .../server/HttpServerTestSetupBase.scala | 35 ++-- .../impl/engine/ws/ByteStringSinkProbe.scala | 67 +++++++ .../http/impl/engine/ws/WSTestSetupBase.scala | 22 +-- .../impl/engine/ws/WebsocketServerSpec.scala | 25 +-- 5 files changed, 181 insertions(+), 143 deletions(-) create mode 100644 akka-http-core/src/test/scala/akka/http/impl/engine/ws/ByteStringSinkProbe.scala diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerSpec.scala index 51a08a6323..b1b3db4393 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerSpec.scala @@ -32,7 +32,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") send("""GET / HTTP/1.1 |Host: example.com | - |""".stripMarginWithNewline("\r\n")) + |""") expectRequest shouldEqual HttpRequest(uri = "http://example.com/", headers = List(Host("example.com"))) } @@ -42,7 +42,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Host: example.com |Content-Length: 12 | - |""".stripMarginWithNewline("\r\n")) + |""") inside(expectRequest) { case HttpRequest(POST, _, _, HttpEntity.Default(_, 12, data), _) ⇒ @@ -65,10 +65,9 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") send("""GET / HTTP/1.2 |Host: example.com | - |""".stripMarginWithNewline("\r\n")) + |""") - netOutSub.request(1) - wipeDate(netOut.expectNext().utf8String) shouldEqual + expectResponseWithWipedDate( """HTTP/1.1 505 HTTP Version Not Supported |Server: akka-http/test |Date: XXXX @@ -76,7 +75,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Content-Type: text/plain; charset=UTF-8 |Content-Length: 74 | - |The server does not support the HTTP protocol version used in the request.""".stripMarginWithNewline("\r\n") + |The server does not support the HTTP protocol version used in the request.""") } "report an invalid Chunked stream" in new TestSetup { @@ -86,7 +85,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") | |6 |abcdef - |""".stripMarginWithNewline("\r\n")) + |""") inside(expectRequest) { case HttpRequest(POST, _, _, HttpEntity.Chunked(_, data), _) ⇒ @@ -102,11 +101,10 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") error.getMessage shouldEqual "Illegal character 'g' in chunk start" requests.expectComplete() - netOutSub.request(1) - responsesSub.expectRequest() - responsesSub.sendError(error.asInstanceOf[Exception]) + responses.expectRequest() + responses.sendError(error.asInstanceOf[Exception]) - wipeDate(netOut.expectNext().utf8String) shouldEqual + expectResponseWithWipedDate( """HTTP/1.1 400 Bad Request |Server: akka-http/test |Date: XXXX @@ -114,7 +112,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Content-Type: text/plain; charset=UTF-8 |Content-Length: 36 | - |Illegal character 'g' in chunk start""".stripMarginWithNewline("\r\n") + |Illegal character 'g' in chunk start""") } } @@ -123,7 +121,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Host: example.com |Content-Length: 12 | - |abcdefghijkl""".stripMarginWithNewline("\r\n")) + |abcdefghijkl""") expectRequest shouldEqual HttpRequest( @@ -138,7 +136,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Host: example.com |Content-Length: 12 | - |abcdef""".stripMarginWithNewline("\r\n")) + |abcdef""") inside(expectRequest) { case HttpRequest(POST, _, _, HttpEntity.Default(_, 12, data), _) ⇒ @@ -161,7 +159,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") | |6 |abcdef - |""".stripMarginWithNewline("\r\n")) + |""") inside(expectRequest) { case HttpRequest(POST, _, _, HttpEntity.Chunked(_, data), _) ⇒ @@ -182,7 +180,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Host: example.com |Content-Length: 12 | - |abcdefghijkl""".stripMarginWithNewline("\r\n")) + |abcdefghijkl""") expectRequest shouldEqual HttpRequest( @@ -195,7 +193,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Host: example.com |Content-Length: 12 | - |mnopqrstuvwx""".stripMarginWithNewline("\r\n")) + |mnopqrstuvwx""") expectRequest shouldEqual HttpRequest( @@ -210,7 +208,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Host: example.com |Content-Length: 12 | - |abcdef""".stripMarginWithNewline("\r\n")) + |abcdef""") inside(expectRequest) { case HttpRequest(POST, _, _, HttpEntity.Default(_, 12, data), _) ⇒ @@ -232,7 +230,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Host: example.com |Content-Length: 5 | - |abcde""".stripMarginWithNewline("\r\n")) + |abcde""") inside(expectRequest) { case HttpRequest(POST, _, _, HttpEntity.Strict(_, data), _) ⇒ @@ -247,7 +245,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") | |6 |abcdef - |""".stripMarginWithNewline("\r\n")) + |""") inside(expectRequest) { case HttpRequest(POST, _, _, HttpEntity.Chunked(_, data), _) ⇒ @@ -270,7 +268,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Host: example.com |Content-Length: 5 | - |abcde""".stripMarginWithNewline("\r\n")) + |abcde""") inside(expectRequest) { case HttpRequest(POST, _, _, HttpEntity.Strict(_, data), _) ⇒ @@ -283,7 +281,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Host: example.com |Content-Length: 12 | - |abcdef""".stripMarginWithNewline("\r\n")) + |abcdef""") inside(expectRequest) { case HttpRequest(POST, _, _, HttpEntity.Default(_, 12, data), _) ⇒ @@ -306,7 +304,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") | |6 |abcdef - |""".stripMarginWithNewline("\r\n")) + |""") inside(expectRequest) { case HttpRequest(POST, _, _, HttpEntity.Chunked(_, data), _) ⇒ @@ -328,7 +326,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Host: example.com |Content-Length: 12 | - |abcdef""".stripMarginWithNewline("\r\n")) + |abcdef""") inside(expectRequest) { case HttpRequest(POST, _, _, HttpEntity.Default(_, 12, data), _) ⇒ val dataProbe = TestSubscriber.manualProbe[ByteString] @@ -349,7 +347,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") | |6 |abcdef - |""".stripMarginWithNewline("\r\n")) + |""") inside(expectRequest) { case HttpRequest(POST, _, _, HttpEntity.Chunked(_, data), _) ⇒ val dataProbe = TestSubscriber.manualProbe[ChunkStreamPart] @@ -368,7 +366,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") send("""HEAD / HTTP/1.1 |Host: example.com | - |""".stripMarginWithNewline("\r\n")) + |""") expectRequest shouldEqual HttpRequest(GET, uri = "http://example.com/", headers = List(Host("example.com"))) } @@ -377,7 +375,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") send("""HEAD / HTTP/1.1 |Host: example.com | - |""".stripMarginWithNewline("\r\n")) + |""") expectRequest shouldEqual HttpRequest(HEAD, uri = "http://example.com/", headers = List(Host("example.com"))) } @@ -385,19 +383,18 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") send("""HEAD / HTTP/1.1 |Host: example.com | - |""".stripMarginWithNewline("\r\n")) + |""") inside(expectRequest) { case HttpRequest(GET, _, _, _, _) ⇒ - responsesSub.sendNext(HttpResponse(entity = HttpEntity.Strict(ContentTypes.`text/plain`, ByteString("abcd")))) - netOutSub.request(1) - wipeDate(netOut.expectNext().utf8String) shouldEqual + responses.sendNext(HttpResponse(entity = HttpEntity.Strict(ContentTypes.`text/plain`, ByteString("abcd")))) + expectResponseWithWipedDate( """|HTTP/1.1 200 OK |Server: akka-http/test |Date: XXXX |Content-Type: text/plain |Content-Length: 4 | - |""".stripMarginWithNewline("\r\n") + |""") } } @@ -405,22 +402,21 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") send("""HEAD / HTTP/1.1 |Host: example.com | - |""".stripMarginWithNewline("\r\n")) + |""") val data = TestPublisher.manualProbe[ByteString]() inside(expectRequest) { case HttpRequest(GET, _, _, _, _) ⇒ - responsesSub.sendNext(HttpResponse(entity = HttpEntity.Default(ContentTypes.`text/plain`, 4, Source(data)))) - netOutSub.request(1) + responses.sendNext(HttpResponse(entity = HttpEntity.Default(ContentTypes.`text/plain`, 4, Source(data)))) val dataSub = data.expectSubscription() dataSub.expectCancellation() - wipeDate(netOut.expectNext().utf8String) shouldEqual + expectResponseWithWipedDate( """|HTTP/1.1 200 OK |Server: akka-http/test |Date: XXXX |Content-Type: text/plain |Content-Length: 4 | - |""".stripMarginWithNewline("\r\n") + |""") } } @@ -428,46 +424,44 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") send("""HEAD / HTTP/1.1 |Host: example.com | - |""".stripMarginWithNewline("\r\n")) + |""") val data = TestPublisher.manualProbe[ByteString]() inside(expectRequest) { case HttpRequest(GET, _, _, _, _) ⇒ - responsesSub.sendNext(HttpResponse(entity = HttpEntity.CloseDelimited(ContentTypes.`text/plain`, Source(data)))) - netOutSub.request(1) + responses.sendNext(HttpResponse(entity = HttpEntity.CloseDelimited(ContentTypes.`text/plain`, Source(data)))) val dataSub = data.expectSubscription() dataSub.expectCancellation() - wipeDate(netOut.expectNext().utf8String) shouldEqual + expectResponseWithWipedDate( """|HTTP/1.1 200 OK |Server: akka-http/test |Date: XXXX |Content-Type: text/plain | - |""".stripMarginWithNewline("\r\n") + |""") } // No close should happen here since this was a HEAD request - netOut.expectNoMsg(50.millis) + netOut.expectNoBytes(50.millis) } "not emit entities when responding to HEAD requests if transparent-head-requests is enabled (with Chunked)" in new TestSetup { send("""HEAD / HTTP/1.1 |Host: example.com | - |""".stripMarginWithNewline("\r\n")) + |""") val data = TestPublisher.manualProbe[ChunkStreamPart]() inside(expectRequest) { case HttpRequest(GET, _, _, _, _) ⇒ - responsesSub.sendNext(HttpResponse(entity = HttpEntity.Chunked(ContentTypes.`text/plain`, Source(data)))) - netOutSub.request(1) + responses.sendNext(HttpResponse(entity = HttpEntity.Chunked(ContentTypes.`text/plain`, Source(data)))) val dataSub = data.expectSubscription() dataSub.expectCancellation() - wipeDate(netOut.expectNext().utf8String) shouldEqual + expectResponseWithWipedDate( """|HTTP/1.1 200 OK |Server: akka-http/test |Date: XXXX |Transfer-Encoding: chunked |Content-Type: text/plain | - |""".stripMarginWithNewline("\r\n") + |""") } } @@ -476,15 +470,14 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Host: example.com |Connection: close | - |""".stripMarginWithNewline("\r\n")) + |""") val data = TestPublisher.manualProbe[ByteString]() inside(expectRequest) { case HttpRequest(GET, _, _, _, _) ⇒ - responsesSub.sendNext(HttpResponse(entity = CloseDelimited(ContentTypes.`text/plain`, Source(data)))) - netOutSub.request(1) + responses.sendNext(HttpResponse(entity = CloseDelimited(ContentTypes.`text/plain`, Source(data)))) val dataSub = data.expectSubscription() dataSub.expectCancellation() - netOut.expectNext() + netOut.expectBytes(1) } netOut.expectComplete() } @@ -495,34 +488,33 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Expect: 100-continue |Content-Length: 16 | - |""".stripMarginWithNewline("\r\n")) + |""") inside(expectRequest) { case HttpRequest(POST, _, _, Default(ContentType(`application/octet-stream`, None), 16, data), _) ⇒ val dataProbe = TestSubscriber.manualProbe[ByteString] data.to(Sink(dataProbe)).run() val dataSub = dataProbe.expectSubscription() - netOutSub.request(2) - netOut.expectNoMsg(50.millis) + netOut.expectNoBytes(50.millis) dataSub.request(1) // triggers `100 Continue` response - wipeDate(netOut.expectNext().utf8String) shouldEqual + expectResponseWithWipedDate( """HTTP/1.1 100 Continue |Server: akka-http/test |Date: XXXX | - |""".stripMarginWithNewline("\r\n") + |""") dataProbe.expectNoMsg(50.millis) send("0123456789ABCDEF") dataProbe.expectNext(ByteString("0123456789ABCDEF")) dataProbe.expectComplete() - responsesSub.sendNext(HttpResponse(entity = "Yeah")) - wipeDate(netOut.expectNext().utf8String) shouldEqual + responses.sendNext(HttpResponse(entity = "Yeah")) + expectResponseWithWipedDate( """HTTP/1.1 200 OK |Server: akka-http/test |Date: XXXX |Content-Type: text/plain; charset=UTF-8 |Content-Length: 4 | - |Yeah""".stripMarginWithNewline("\r\n") + |Yeah""") } } @@ -532,39 +524,38 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Expect: 100-continue |Transfer-Encoding: chunked | - |""".stripMarginWithNewline("\r\n")) + |""") inside(expectRequest) { case HttpRequest(POST, _, _, Chunked(ContentType(`application/octet-stream`, None), data), _) ⇒ val dataProbe = TestSubscriber.manualProbe[ChunkStreamPart] data.to(Sink(dataProbe)).run() val dataSub = dataProbe.expectSubscription() - netOutSub.request(2) - netOut.expectNoMsg(50.millis) + netOut.expectNoBytes(50.millis) dataSub.request(2) // triggers `100 Continue` response - wipeDate(netOut.expectNext().utf8String) shouldEqual + expectResponseWithWipedDate( """HTTP/1.1 100 Continue |Server: akka-http/test |Date: XXXX | - |""".stripMarginWithNewline("\r\n") + |""") dataProbe.expectNoMsg(50.millis) send("""10 |0123456789ABCDEF |0 | - |""".stripMarginWithNewline("\r\n")) + |""") dataProbe.expectNext(Chunk(ByteString("0123456789ABCDEF"))) dataProbe.expectNext(LastChunk) dataProbe.expectComplete() - responsesSub.sendNext(HttpResponse(entity = "Yeah")) - wipeDate(netOut.expectNext().utf8String) shouldEqual + responses.sendNext(HttpResponse(entity = "Yeah")) + expectResponseWithWipedDate( """HTTP/1.1 200 OK |Server: akka-http/test |Date: XXXX |Content-Type: text/plain; charset=UTF-8 |Content-Length: 4 | - |Yeah""".stripMarginWithNewline("\r\n") + |Yeah""") } } @@ -574,12 +565,11 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Expect: 100-continue |Content-Length: 16 | - |""".stripMarginWithNewline("\r\n")) + |""") inside(expectRequest) { case HttpRequest(POST, _, _, Default(ContentType(`application/octet-stream`, None), 16, data), _) ⇒ - netOutSub.request(1) - responsesSub.sendNext(HttpResponse(entity = "Yeah")) - wipeDate(netOut.expectNext().utf8String) shouldEqual + responses.sendNext(HttpResponse(entity = "Yeah")) + expectResponseWithWipedDate( """HTTP/1.1 200 OK |Server: akka-http/test |Date: XXXX @@ -587,7 +577,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Content-Type: text/plain; charset=UTF-8 |Content-Length: 4 | - |Yeah""".stripMarginWithNewline("\r\n") + |Yeah""") } } @@ -599,18 +589,17 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") expectRequest shouldEqual HttpRequest(uri = "http://example.com/", headers = List(Host("example.com"))) - netOutSub.request(1) - responsesSub.expectRequest() - responsesSub.sendError(new RuntimeException("CRASH BOOM BANG")) + responses.expectRequest() + responses.sendError(new RuntimeException("CRASH BOOM BANG")) - wipeDate(netOut.expectNext().utf8String) shouldEqual + expectResponseWithWipedDate( """HTTP/1.1 500 Internal Server Error |Server: akka-http/test |Date: XXXX |Connection: close |Content-Length: 0 | - |""".stripMarginWithNewline("\r\n") + |""") } "correctly consume and render large requests and responses" in new TestSetup { @@ -618,22 +607,20 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Host: example.com |Content-Length: 100000 | - |""".stripMarginWithNewline("\r\n")) + |""") val HttpRequest(POST, _, _, entity, _) = expectRequest - responsesSub.expectRequest() - responsesSub.sendNext(HttpResponse(entity = entity)) - responsesSub.sendComplete() + responses.sendNext(HttpResponse(entity = entity)) + responses.sendComplete() - netOutSub.request(1) - wipeDate(netOut.expectNext().utf8String) shouldEqual + expectResponseWithWipedDate( """HTTP/1.1 200 OK |Server: akka-http/test |Date: XXXX |Content-Type: application/octet-stream |Content-Length: 100000 | - |""".stripMarginWithNewline("\r\n") + |""") val random = new Random() @tailrec def rec(bytesLeft: Int): Unit = @@ -641,13 +628,12 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") val count = math.min(random.nextInt(1000) + 1, bytesLeft) val data = random.alphanumeric.take(count).mkString send(data) - netOutSub.request(1) - netOut.expectNext().utf8String shouldEqual data + netOut.expectUtf8EncodedString(data) rec(bytesLeft - count) } rec(100000) - netInSub.sendComplete() + netIn.sendComplete() requests.expectComplete() netOut.expectComplete() } @@ -656,7 +642,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") send("""GET //foo HTTP/1.1 |Host: example.com | - |""".stripMarginWithNewline("\r\n")) + |""") expectRequest shouldEqual HttpRequest(uri = "http://example.com//foo", headers = List(Host("example.com"))) } @@ -664,7 +650,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") "use default-host-header for HTTP/1.0 requests" in new TestSetup { send("""GET /abc HTTP/1.0 | - |""".stripMarginWithNewline("\r\n")) + |""") expectRequest shouldEqual HttpRequest(uri = "http://example.com/abc", protocol = HttpProtocols.`HTTP/1.0`) @@ -673,10 +659,9 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") "fail an HTTP/1.0 request with 400 if no default-host-header is set" in new TestSetup { send("""GET /abc HTTP/1.0 | - |""".stripMarginWithNewline("\r\n")) + |""") - netOutSub.request(1) - wipeDate(netOut.expectNext().utf8String) shouldEqual + expectResponseWithWipedDate( """|HTTP/1.1 400 Bad Request |Server: akka-http/test |Date: XXXX @@ -684,7 +669,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") |Content-Type: text/plain; charset=UTF-8 |Content-Length: 41 | - |Request is missing required `Host` header""".stripMarginWithNewline("\r\n") + |Request is missing required `Host` header""") } "support remote-address-header" in new TestSetup { diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerTestSetupBase.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerTestSetupBase.scala index b59598a253..55e646d646 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerTestSetupBase.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerTestSetupBase.scala @@ -6,6 +6,7 @@ package akka.http.impl.engine.server import java.net.InetSocketAddress +import akka.http.impl.engine.ws.ByteStringSinkProbe import akka.stream.io.{ SendBytes, SslTlsOutbound, SessionBytes } import scala.concurrent.duration.FiniteDuration @@ -28,21 +29,21 @@ abstract class HttpServerTestSetupBase { implicit def system: ActorSystem implicit def materializer: Materializer - val requests = TestSubscriber.manualProbe[HttpRequest] - val responses = TestPublisher.manualProbe[HttpResponse]() + val requests = TestSubscriber.probe[HttpRequest] + val responses = TestPublisher.probe[HttpResponse]() def settings = ServerSettings(system).copy(serverHeader = Some(Server(List(ProductVersion("akka-http", "test"))))) def remoteAddress: Option[InetSocketAddress] = None val (netIn, netOut) = { - val netIn = TestPublisher.manualProbe[ByteString]() - val netOut = TestSubscriber.manualProbe[ByteString] + val netIn = TestPublisher.probe[ByteString]() + val netOut = ByteStringSinkProbe() FlowGraph.closed(HttpServerBluePrint(settings, remoteAddress = remoteAddress, log = NoLogging)) { implicit b ⇒ server ⇒ import FlowGraph.Implicits._ Source(netIn) ~> Flow[ByteString].map(SessionBytes(null, _)) ~> server.in2 - server.out1 ~> Flow[SslTlsOutbound].collect { case SendBytes(x) ⇒ x } ~> Sink(netOut) + server.out1 ~> Flow[SslTlsOutbound].collect { case SendBytes(x) ⇒ x } ~> netOut.sink server.out2 ~> Sink(requests) Source(responses) ~> server.in1 }.run() @@ -50,26 +51,26 @@ abstract class HttpServerTestSetupBase { netIn -> netOut } + def expectResponseWithWipedDate(expected: String): Unit = { + val trimmed = expected.stripMarginWithNewline("\r\n") + // XXXX = 4 bytes, ISO Date Time String = 29 bytes => need to request 15 bytes more than expected string + val expectedSize = ByteString(trimmed, "utf8").length + 25 + val received = wipeDate(netOut.expectBytes(expectedSize).utf8String) + assert(received == trimmed, s"Expected request '$trimmed' but got '$received'") + } + def wipeDate(string: String) = string.fastSplit('\n').map { case s if s.startsWith("Date:") ⇒ "Date: XXXX\r" case s ⇒ s }.mkString("\n") - val netInSub = netIn.expectSubscription() - val netOutSub = netOut.expectSubscription() - val requestsSub = requests.expectSubscription() - val responsesSub = responses.expectSubscription() - - def expectRequest: HttpRequest = { - requestsSub.request(1) - requests.expectNext() - } + def expectRequest: HttpRequest = requests.requestNext() def expectNoRequest(max: FiniteDuration): Unit = requests.expectNoMsg(max) def expectNetworkClose(): Unit = netOut.expectComplete() - def send(data: ByteString): Unit = netInSub.sendNext(data) - def send(data: String): Unit = send(ByteString(data, "UTF8")) + def send(data: ByteString): Unit = netIn.sendNext(data) + def send(string: String): Unit = send(ByteString(string.stripMarginWithNewline("\r\n"), "UTF8")) - def closeNetworkInput(): Unit = netInSub.sendComplete() + def closeNetworkInput(): Unit = netIn.sendComplete() } \ No newline at end of file diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/ByteStringSinkProbe.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/ByteStringSinkProbe.scala new file mode 100644 index 0000000000..77906bafc9 --- /dev/null +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/ByteStringSinkProbe.scala @@ -0,0 +1,67 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.impl.engine.ws + +import akka.actor.ActorSystem +import akka.stream.scaladsl.{ Source, Sink } +import akka.stream.testkit.TestSubscriber +import akka.util.ByteString + +import scala.annotation.tailrec +import scala.concurrent.duration.FiniteDuration + +trait ByteStringSinkProbe { + def sink: Sink[ByteString, Unit] + + def expectBytes(length: Int): ByteString + def expectBytes(expected: ByteString): Unit + + def expectUtf8EncodedString(string: String): Unit + + def expectNoBytes(): Unit + def expectNoBytes(timeout: FiniteDuration): Unit + + def expectComplete(): Unit + def expectError(): Throwable + def expectError(cause: Throwable): Unit +} + +object ByteStringSinkProbe { + def apply()(implicit system: ActorSystem): ByteStringSinkProbe = + new ByteStringSinkProbe { + val probe = TestSubscriber.probe[ByteString]() + val sink: Sink[ByteString, Unit] = Sink(probe) + + def expectNoBytes(): Unit = { + probe.ensureSubscription() + probe.expectNoMsg() + } + def expectNoBytes(timeout: FiniteDuration): Unit = { + probe.ensureSubscription() + probe.expectNoMsg(timeout) + } + + var inBuffer = ByteString.empty + @tailrec def expectBytes(length: Int): ByteString = + if (inBuffer.size >= length) { + val res = inBuffer.take(length) + inBuffer = inBuffer.drop(length) + res + } else { + inBuffer ++= probe.requestNext() + expectBytes(length) + } + + def expectBytes(expected: ByteString): Unit = + assert(expectBytes(expected.length) == expected, "expected ") + + def expectUtf8EncodedString(string: String): Unit = + expectBytes(ByteString(string, "utf8")) + + def expectComplete(): Unit = probe.expectComplete() + def expectError(): Throwable = probe.expectError() + def expectError(cause: Throwable): Unit = probe.expectError(cause) + } +} diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WSTestSetupBase.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WSTestSetupBase.scala index 65c5548f13..814999eef1 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WSTestSetupBase.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WSTestSetupBase.scala @@ -14,7 +14,8 @@ import scala.util.Random trait WSTestSetupBase extends Matchers { def send(bytes: ByteString): Unit - def expectNextChunk(): ByteString + def expectBytes(length: Int): ByteString + def expectBytes(bytes: ByteString): Unit def sendWSFrame(opcode: Opcode, data: ByteString, @@ -41,24 +42,13 @@ trait WSTestSetupBase extends Matchers { rsv1: Boolean = false, rsv2: Boolean = false, rsv3: Boolean = false): Unit = - expectNextChunk() shouldEqual frameHeader(opcode, data.length, fin, mask, rsv1, rsv2, rsv3) ++ data + expectBytes(frameHeader(opcode, data.length, fin, mask, rsv1, rsv2, rsv3) ++ data) def expectWSCloseFrame(closeCode: Int, mask: Boolean = false): Unit = - expectNextChunk() shouldEqual closeFrame(closeCode, mask) + expectBytes(closeFrame(closeCode, mask)) - var inBuffer = ByteString.empty - @tailrec final def expectNetworkData(bytes: Int): ByteString = - if (inBuffer.size >= bytes) { - val res = inBuffer.take(bytes) - inBuffer = inBuffer.drop(bytes) - res - } else { - inBuffer ++= expectNextChunk() - expectNetworkData(bytes) - } - - def expectNetworkData(data: ByteString): Unit = - expectNetworkData(data.size) shouldEqual data + def expectNetworkData(length: Int): ByteString = expectBytes(length) + def expectNetworkData(data: ByteString): Unit = expectBytes(data) def expectFrameOnNetwork(opcode: Opcode, data: ByteString, fin: Boolean): Unit = { expectFrameHeaderOnNetwork(opcode, data.size, fin) diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketServerSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketServerSpec.scala index fd5e297970..7b25f9002b 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketServerSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketServerSpec.scala @@ -4,7 +4,6 @@ package akka.http.impl.engine.ws -import akka.http.impl.engine.ws.Protocol.Opcode import akka.http.scaladsl.model.ws._ import akka.stream.scaladsl.{ Keep, Sink, Flow, Source } import akka.stream.testkit.Utils @@ -15,8 +14,6 @@ import akka.http.impl.util._ import akka.http.impl.engine.server.HttpServerTestSetupBase -import scala.util.Random - class WebsocketServerSpec extends FreeSpec with Matchers with WithMaterializerSpec { spec ⇒ import WSTestUtils._ @@ -33,7 +30,7 @@ class WebsocketServerSpec extends FreeSpec with Matchers with WithMaterializerSp |Origin: http://example.com |Sec-WebSocket-Version: 13 | - |""".stripMarginWithNewline("\r\n")) + |""") val request = expectRequest val upgrade = request.header[UpgradeToWebsocket] @@ -43,9 +40,9 @@ class WebsocketServerSpec extends FreeSpec with Matchers with WithMaterializerSp Source(List(1, 2, 3, 4, 5)).map(num ⇒ TextMessage.Strict(s"Message $num")) val handler = Flow.wrap(Sink.ignore, source)(Keep.none) val response = upgrade.get.handleMessages(handler) - responsesSub.sendNext(response) + responses.sendNext(response) - wipeDate(expectNextChunk().utf8String) shouldEqual + expectResponseWithWipedDate( """HTTP/1.1 101 Switching Protocols |Upgrade: websocket |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= @@ -53,7 +50,7 @@ class WebsocketServerSpec extends FreeSpec with Matchers with WithMaterializerSp |Date: XXXX |Connection: upgrade | - |""".stripMarginWithNewline("\r\n") + |""") expectWSFrame(Protocol.Opcode.Text, ByteString("Message 1"), fin = true) expectWSFrame(Protocol.Opcode.Text, ByteString("Message 2"), fin = true) @@ -79,16 +76,16 @@ class WebsocketServerSpec extends FreeSpec with Matchers with WithMaterializerSp |Origin: http://example.com |Sec-WebSocket-Version: 13 | - |""".stripMarginWithNewline("\r\n")) + |""") val request = expectRequest val upgrade = request.header[UpgradeToWebsocket] upgrade.isDefined shouldBe true val response = upgrade.get.handleMessages(Flow[Message]) // simple echoing - responsesSub.sendNext(response) + responses.sendNext(response) - wipeDate(expectNextChunk().utf8String) shouldEqual + expectResponseWithWipedDate( """HTTP/1.1 101 Switching Protocols |Upgrade: websocket |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= @@ -96,7 +93,7 @@ class WebsocketServerSpec extends FreeSpec with Matchers with WithMaterializerSp |Date: XXXX |Connection: upgrade | - |""".stripMarginWithNewline("\r\n") + |""") sendWSFrame(Protocol.Opcode.Text, ByteString("Message 1"), fin = true, mask = true) expectWSFrame(Protocol.Opcode.Text, ByteString("Message 1"), fin = true) @@ -131,9 +128,7 @@ class WebsocketServerSpec extends FreeSpec with Matchers with WithMaterializerSp implicit def system = spec.system implicit def materializer = spec.materializer - def expectNextChunk(): ByteString = { - netOutSub.request(1) - netOut.expectNext() - } + def expectBytes(length: Int): ByteString = netOut.expectBytes(length) + def expectBytes(bytes: ByteString): Unit = netOut.expectBytes(bytes) } } From 870ff2bbdcff75a2e95814932703633791e50c13 Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Fri, 9 Oct 2015 17:38:59 +0200 Subject: [PATCH 05/10] +htc #17275 Websocket client implementation --- .../akka/http/ClientConnectionSettings.scala | 5 + .../akka/http/impl/engine/ws/Handshake.scala | 121 ++++++ .../engine/ws/WebsocketClientBlueprint.scala | 139 +++++++ .../akka/http/impl/util/StreamUtils.scala | 17 + .../main/scala/akka/http/scaladsl/Http.scala | 95 ++++- .../scala/akka/http/scaladsl/model/Uri.scala | 4 +- .../http/scaladsl/model/headers/headers.scala | 1 + .../impl/engine/ws/WebsocketClientSpec.scala | 368 ++++++++++++++++++ 8 files changed, 746 insertions(+), 4 deletions(-) create mode 100644 akka-http-core/src/main/scala/akka/http/impl/engine/ws/WebsocketClientBlueprint.scala create mode 100644 akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketClientSpec.scala diff --git a/akka-http-core/src/main/scala/akka/http/ClientConnectionSettings.scala b/akka-http-core/src/main/scala/akka/http/ClientConnectionSettings.scala index 1718e7a41f..9562d23168 100644 --- a/akka-http-core/src/main/scala/akka/http/ClientConnectionSettings.scala +++ b/akka-http-core/src/main/scala/akka/http/ClientConnectionSettings.scala @@ -4,6 +4,9 @@ package akka.http +import java.util.Random + +import akka.http.impl.engine.ws.Randoms import akka.io.Inet.SocketOption import scala.concurrent.duration.{ Duration, FiniteDuration } @@ -21,6 +24,7 @@ final case class ClientConnectionSettings( connectingTimeout: FiniteDuration, idleTimeout: Duration, requestHeaderSizeHint: Int, + websocketRandomFactory: () ⇒ Random, socketOptions: immutable.Traversable[SocketOption], parserSettings: ParserSettings) { @@ -36,6 +40,7 @@ object ClientConnectionSettings extends SettingsCompanion[ClientConnectionSettin c getFiniteDuration "connecting-timeout", c getPotentiallyInfiniteDuration "idle-timeout", c getIntBytes "request-header-size-hint", + Randoms.SecureRandomInstances, // can currently only be overridden from code SocketOptionSettings.fromSubConfig(root, c.getConfig("socket-options")), ParserSettings.fromSubConfig(root, c.getConfig("parsing"))) } diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Handshake.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Handshake.scala index cc8a5a7f2f..af09a2cf46 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Handshake.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Handshake.scala @@ -138,6 +138,127 @@ private[http] object Handshake { UpgradeToWebsocketResponseHeader(handler))) } + object Client { + case class NegotiatedWebsocketSettings(subprotocol: Option[String]) + + /** + * Builds a WebSocket handshake request. + */ + def buildRequest(uri: Uri, extraHeaders: immutable.Seq[HttpHeader], subprotocols: Seq[String], random: Random): (HttpRequest, `Sec-WebSocket-Key`) = { + val keyBytes = new Array[Byte](16) + random.nextBytes(keyBytes) + val key = `Sec-WebSocket-Key`(Base64.rfc2045().encodeToString(keyBytes, false)) + val protocol = + if (subprotocols.nonEmpty) `Sec-WebSocket-Protocol`(subprotocols) :: Nil + else Nil + //version, protocol, extensions, origin + + val headers = Seq( + UpgradeHeader, + ConnectionUpgradeHeader, + key, + SecWebsocketVersionHeader) ++ protocol ++ extraHeaders + + (HttpRequest(HttpMethods.GET, uri.toRelative, headers), key) + } + + /** + * Tries to validate the HTTP response. Returns either Right(settings) or an error message if + * the response cannot be validated. + */ + def validateResponse(response: HttpResponse, subprotocols: Seq[String], key: `Sec-WebSocket-Key`): Either[String, NegotiatedWebsocketSettings] = { + /* + From http://tools.ietf.org/html/rfc6455#section-4.1 + + 1. If the status code received from the server is not 101, the + client handles the response per HTTP [RFC2616] procedures. In + particular, the client might perform authentication if it + receives a 401 status code; the server might redirect the client + using a 3xx status code (but clients are not required to follow + them), etc. Otherwise, proceed as follows. + + 2. If the response lacks an |Upgrade| header field or the |Upgrade| + header field contains a value that is not an ASCII case- + insensitive match for the value "websocket", the client MUST + _Fail the WebSocket Connection_. + + 3. If the response lacks a |Connection| header field or the + |Connection| header field doesn't contain a token that is an + ASCII case-insensitive match for the value "Upgrade", the client + MUST _Fail the WebSocket Connection_. + + 4. If the response lacks a |Sec-WebSocket-Accept| header field or + the |Sec-WebSocket-Accept| contains a value other than the + base64-encoded SHA-1 of the concatenation of the |Sec-WebSocket- + Key| (as a string, not base64-decoded) with the string "258EAFA5- + E914-47DA-95CA-C5AB0DC85B11" but ignoring any leading and + trailing whitespace, the client MUST _Fail the WebSocket + Connection_. + + 5. If the response includes a |Sec-WebSocket-Extensions| header + field and this header field indicates the use of an extension + that was not present in the client's handshake (the server has + indicated an extension not requested by the client), the client + MUST _Fail the WebSocket Connection_. (The parsing of this + header field to determine which extensions are requested is + discussed in Section 9.1.) + + 6. If the response includes a |Sec-WebSocket-Protocol| header field + and this header field indicates the use of a subprotocol that was + not present in the client's handshake (the server has indicated a + subprotocol not requested by the client), the client MUST _Fail + the WebSocket Connection_. + */ + + trait Expectation extends (HttpResponse ⇒ Option[String]) { outer ⇒ + def &&(other: HttpResponse ⇒ Option[String]): Expectation = + new Expectation { + def apply(v1: HttpResponse): Option[String] = + outer(v1).orElse(other(v1)) + } + } + + def check[T](value: HttpResponse ⇒ T)(condition: T ⇒ Boolean, msg: T ⇒ String): Expectation = + new Expectation { + def apply(resp: HttpResponse): Option[String] = { + val v = value(resp) + if (condition(v)) None + else Some(msg(v)) + } + } + + def compare(candidate: HttpHeader, caseInsensitive: Boolean): Option[HttpHeader] ⇒ Boolean = { + + case Some(`candidate`) if !caseInsensitive ⇒ true + case Some(header) if caseInsensitive && candidate.value.toRootLowerCase == header.value.toRootLowerCase ⇒ true + case _ ⇒ false + } + + def headerExists(candidate: HttpHeader, showExactOther: Boolean = true, caseInsensitive: Boolean = false): Expectation = + check(_.headers.find(_.name == candidate.name))(compare(candidate, caseInsensitive), { + case Some(other) if showExactOther ⇒ s"response that was missing required `$candidate` header. Found `$other` with the wrong value." + case Some(_) ⇒ s"response with invalid `${candidate.name}` header." + case None ⇒ s"response that was missing required `${candidate.name}` header." + }) + + val expectations: Expectation = + check(_.status)(_ == StatusCodes.SwitchingProtocols, "unexpected status code: " + _) && + headerExists(UpgradeHeader, caseInsensitive = true) && + headerExists(ConnectionUpgradeHeader, caseInsensitive = true) && + headerExists(`Sec-WebSocket-Accept`.forKey(key), showExactOther = false) + + expectations(response) match { + case None ⇒ + val subs = response.header[`Sec-WebSocket-Protocol`].flatMap(_.protocols.headOption) + + if (subprotocols.isEmpty && subs.isEmpty) Right(NegotiatedWebsocketSettings(None)) // no specific one selected + else if (subs.nonEmpty && subprotocols.contains(subs.get)) Right(NegotiatedWebsocketSettings(Some(subs.get))) + else Left(s"response that indicated that the given subprotocol was not supported. (client supported: ${subprotocols.mkString(", ")}, server supported: $subs)") + case Some(problem) ⇒ Left(problem) + } + } + } + val UpgradeHeader = Upgrade(List(UpgradeProtocol("websocket"))) val ConnectionUpgradeHeader = Connection(List("upgrade")) val SecWebsocketVersionHeader = `Sec-WebSocket-Version`(Seq(CurrentWebsocketVersion)) diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/WebsocketClientBlueprint.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/WebsocketClientBlueprint.scala new file mode 100644 index 0000000000..b7c2bf36ed --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/WebsocketClientBlueprint.scala @@ -0,0 +1,139 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.impl.engine.ws + +import scala.collection.immutable +import scala.concurrent.{ Future, Promise } + +import akka.util.ByteString +import akka.event.LoggingAdapter + +import akka.stream.stage._ +import akka.stream.BidiShape +import akka.stream.io.{ SessionBytes, SendBytes, SslTlsInbound } +import akka.stream.scaladsl._ + +import akka.http.ClientConnectionSettings +import akka.http.scaladsl.Http +import akka.http.scaladsl.Http.{ InvalidUpgradeResponse, ValidUpgrade, WebsocketUpgradeResponse } +import akka.http.scaladsl.model.{ HttpHeader, HttpResponse, HttpMethods, Uri } +import akka.http.scaladsl.model.headers.Host + +import akka.http.impl.engine.parsing.HttpMessageParser.StateResult +import akka.http.impl.engine.parsing.ParserOutput.{ RemainingBytes, ResponseStart, NeedMoreData } +import akka.http.impl.engine.parsing.{ ParserOutput, HttpHeaderParser, HttpResponseParser } +import akka.http.impl.engine.rendering.{ HttpRequestRendererFactory, RequestRenderingContext } +import akka.http.impl.engine.ws.Handshake.Client.NegotiatedWebsocketSettings +import akka.http.impl.util.StreamUtils + +object WebsocketClientBlueprint { + /** + * Returns a WebsocketClientLayer that can be materialized once. + */ + def apply(uri: Uri, + extraHeaders: immutable.Seq[HttpHeader], + subProtocol: Option[String], + settings: ClientConnectionSettings, + log: LoggingAdapter): Http.WebsocketClientLayer = + (simpleTls.atopMat(handshake(uri, extraHeaders, subProtocol, settings, log))(Keep.right) atop + Websocket.framing atop + Websocket.stack(serverSide = false, maskingRandomFactory = settings.websocketRandomFactory)).reversed + + /** + * A bidi flow that injects and inspects the WS handshake and then goes out of the way. This BidiFlow + * can only be materialized once. + */ + def handshake(uri: Uri, + extraHeaders: immutable.Seq[HttpHeader], + subProtocol: Option[String], + settings: ClientConnectionSettings, + log: LoggingAdapter): BidiFlow[ByteString, ByteString, ByteString, ByteString, Future[WebsocketUpgradeResponse]] = { + val result = Promise[WebsocketUpgradeResponse]() + + val valve = StreamUtils.OneTimeValve() + + val (initialRequest, key) = Handshake.Client.buildRequest(uri, extraHeaders, subProtocol.toList, settings.websocketRandomFactory()) + val hostHeader = Host(uri.authority) + val renderedInitialRequest = + HttpRequestRendererFactory.renderStrict(RequestRenderingContext(initialRequest, hostHeader), settings, log) + + class UpgradeStage extends StatefulStage[ByteString, ByteString] { + type State = StageState[ByteString, ByteString] + + def initial: State = parsingResponse + + def parsingResponse: State = new State { + // a special version of the parser which only parses one message and then reports the remaining data + // if some is available + val parser = new HttpResponseParser(settings.parserSettings, HttpHeaderParser(settings.parserSettings)()) { + var first = true + override protected def parseMessage(input: ByteString, offset: Int): StateResult = { + if (first) { + first = false + super.parseMessage(input, offset) + } else { + emit(RemainingBytes(input.drop(offset))) + terminate() + } + } + } + parser.setRequestMethodForNextResponse(HttpMethods.GET) + + def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective = { + parser.onPush(elem) match { + case NeedMoreData ⇒ ctx.pull() + case ResponseStart(status, protocol, headers, entity, close) ⇒ + val response = HttpResponse(status, headers, protocol = protocol) + Handshake.Client.validateResponse(response, subProtocol.toList, key) match { + case Right(NegotiatedWebsocketSettings(protocol)) ⇒ + result.success(ValidUpgrade(response, protocol)) + + become(transparent) + valve.open() + + val parseResult = parser.onPull() + require(parseResult == ParserOutput.MessageEnd, s"parseResult should be MessageEnd but was $parseResult") + parser.onPull() match { + case NeedMoreData ⇒ ctx.pull() + case RemainingBytes(bytes) ⇒ ctx.push(bytes) + } + case Left(problem) ⇒ + result.success(InvalidUpgradeResponse(response, s"Websocket server at $uri returned $problem")) + ctx.fail(throw new IllegalArgumentException(s"Websocket upgrade did not finish because of '$problem'")) + } + } + } + } + + def transparent: State = new State { + def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective = ctx.push(elem) + } + } + + BidiFlow() { implicit b ⇒ + import FlowGraph.Implicits._ + + val networkIn = b.add(Flow[ByteString].transform(() ⇒ new UpgradeStage)) + val wsIn = b.add(Flow[ByteString]) + + val handshakeRequestSource = b.add(Source.single(renderedInitialRequest) ++ valve.source) + val httpRequestBytesAndThenWSBytes = b.add(Concat[ByteString]()) + + handshakeRequestSource ~> httpRequestBytesAndThenWSBytes + wsIn.outlet ~> httpRequestBytesAndThenWSBytes + + BidiShape( + networkIn.inlet, + networkIn.outlet, + wsIn.inlet, + httpRequestBytesAndThenWSBytes.out) + } mapMaterializedValue (_ ⇒ result.future) + } + + def simpleTls: BidiFlow[SslTlsInbound, ByteString, ByteString, SendBytes, Unit] = + BidiFlow.wrap( + Flow[SslTlsInbound].collect { case SessionBytes(_, bytes) ⇒ bytes }, + Flow[ByteString].map(SendBytes))(Keep.none) +} diff --git a/akka-http-core/src/main/scala/akka/http/impl/util/StreamUtils.scala b/akka-http-core/src/main/scala/akka/http/impl/util/StreamUtils.scala index 793d48c11b..4f057bfc8a 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/util/StreamUtils.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/util/StreamUtils.scala @@ -316,6 +316,23 @@ private[http] object StreamUtils { } Flow[T].transformMaterializing(newForeachStage) } + + /** + * Similar to Source.lazyEmpty but doesn't rely on materialization. Can only be used once. + */ + trait OneTimeValve { + def source[T]: Source[T, Unit] + def open(): Unit + } + object OneTimeValve { + def apply(): OneTimeValve = new OneTimeValve { + val promise = Promise[Unit]() + val _source = Source(promise.future).drop(1) // we are only interested in the completion event + + def source[T]: Source[T, Unit] = _source.asInstanceOf[Source[T, Unit]] // safe, because source won't generate any elements + def open(): Unit = promise.success(()) + } + } } /** diff --git a/akka-http-core/src/main/scala/akka/http/scaladsl/Http.scala b/akka-http-core/src/main/scala/akka/http/scaladsl/Http.scala index 081a068034..13430b812d 100644 --- a/akka-http-core/src/main/scala/akka/http/scaladsl/Http.scala +++ b/akka-http-core/src/main/scala/akka/http/scaladsl/Http.scala @@ -5,8 +5,9 @@ package akka.http.scaladsl import java.net.InetSocketAddress +import java.security.SecureRandom import java.util.concurrent.ConcurrentHashMap -import java.util.{ Collection ⇒ JCollection } +import java.util.{ Collection ⇒ JCollection, Random } import javax.net.ssl.{ SSLContext, SSLParameters } import akka.actor._ @@ -15,13 +16,16 @@ import akka.http._ import akka.http.impl.engine.client._ import akka.http.impl.engine.server._ import akka.http.impl.util.{ ReadTheDocumentationException, Java6Compat, StreamUtils } +import akka.http.impl.engine.ws.WebsocketClientBlueprint import akka.http.scaladsl.model._ import akka.http.scaladsl.model.headers.Host +import akka.http.scaladsl.model.ws.Message import akka.http.scaladsl.util.FastFuture import akka.japi import akka.stream.Materializer import akka.stream.io._ import akka.stream.scaladsl._ +import akka.util.ByteString import com.typesafe.config.Config import scala.collection.immutable @@ -207,11 +211,17 @@ class HttpExt(config: Config)(implicit system: ActorSystem) extends akka.actor.E log: LoggingAdapter): Flow[HttpRequest, HttpResponse, Future[OutgoingConnection]] = { val hostHeader = if (port == (if (httpsContext.isEmpty) 80 else 443)) Host(host) else Host(host, port) val layer = clientLayer(hostHeader, settings, log) + layer.joinMat(_outgoingTlsConnectionLayer(host, port, localAddress, settings, httpsContext, log))(Keep.right) + } + + private def _outgoingTlsConnectionLayer(host: String, port: Int, localAddress: Option[InetSocketAddress], + settings: ClientConnectionSettings, httpsContext: Option[HttpsContext], + log: LoggingAdapter): Flow[SslTlsOutbound, SslTlsInbound, Future[OutgoingConnection]] = { val tlsStage = sslTlsStage(httpsContext, Client, Some(host -> port)) val transportFlow = Tcp().outgoingConnection(new InetSocketAddress(host, port), localAddress, settings.socketOptions, halfClose = true, settings.connectingTimeout, settings.idleTimeout) - layer.atop(tlsStage).joinMat(transportFlow) { (_, tcpConnFuture) ⇒ + tlsStage.joinMat(transportFlow) { (_, tcpConnFuture) ⇒ import system.dispatcher tcpConnFuture map { tcpConn ⇒ OutgoingConnection(tcpConn.localAddress, tcpConn.remoteAddress) } } @@ -406,6 +416,62 @@ class HttpExt(config: Config)(implicit system: ActorSystem) extends akka.actor.E case e: IllegalUriException ⇒ FastFuture.failed(e) } + /** + * Constructs a [[WebsocketClientLayer]] stage using the configured default [[ClientConnectionSettings]], + * configured using the `akka.http.client` config section. + * + * The layer is not reusable and must only be materialized once. + */ + def websocketClientLayer(uri: Uri, + extraHeaders: immutable.Seq[HttpHeader] = Nil, + subprotocol: Option[String] = None, + settings: ClientConnectionSettings = ClientConnectionSettings(system), + log: LoggingAdapter = system.log): Http.WebsocketClientLayer = + WebsocketClientBlueprint(uri, extraHeaders, subprotocol, settings, log) + + /** + * Constructs a flow that once materialized establishes a Websocket connection to the given Uri. + * + * The layer is not reusable and must only be materialized once. + */ + def websocketClientFlow(uri: Uri, + extraHeaders: immutable.Seq[HttpHeader] = Nil, + subprotocol: Option[String] = None, + localAddress: Option[InetSocketAddress] = None, + settings: ClientConnectionSettings = ClientConnectionSettings(system), + httpsContext: Option[HttpsContext] = None, + log: LoggingAdapter = system.log): Flow[Message, Message, Future[WebsocketUpgradeResponse]] = { + require(uri.isAbsolute, s"Websocket request URI must be absolute but was '$uri'") + + val ctx = uri.scheme match { + case "ws" ⇒ None + case "wss" ⇒ effectiveHttpsContext(httpsContext) + case scheme @ _ ⇒ + throw new IllegalArgumentException(s"Illegal URI scheme '$scheme' in '$uri' for Websocket request. " + + s"Websocket requests must use either 'ws' or 'wss'") + } + val host = uri.authority.host.address + val port = uri.effectivePort + + websocketClientLayer(uri, extraHeaders, subprotocol, settings, log) + .joinMat(_outgoingTlsConnectionLayer(host, port, localAddress, settings, ctx, log))(Keep.left) + } + + /** + * Runs a single Websocket conversation given a Uri and a flow that represents the client side of the + * Websocket conversation. + */ + def singleWebsocketRequest[T](uri: Uri, + clientFlow: Flow[Message, Message, T], + extraHeaders: immutable.Seq[HttpHeader] = Nil, + subprotocol: Option[String] = None, + localAddress: Option[InetSocketAddress] = None, + settings: ClientConnectionSettings = ClientConnectionSettings(system), + httpsContext: Option[HttpsContext] = None, + log: LoggingAdapter = system.log)(implicit mat: Materializer): (Future[WebsocketUpgradeResponse], T) = + websocketClientFlow(uri, extraHeaders, subprotocol, localAddress, settings, httpsContext, log) + .joinMat(clientFlow)(Keep.both).run() + /** * Triggers an orderly shutdown of all host connections pools currently maintained by the [[ActorSystem]]. * The returned future is completed when all pools that were live at the time of this method call @@ -562,6 +628,20 @@ object Http extends ExtensionId[HttpExt] with ExtensionIdProvider { type ClientLayer = BidiFlow[HttpRequest, SslTlsOutbound, SslTlsInbound, HttpResponse, Unit] //# + /** + * The type of the client-side Websocket layer as a stand-alone BidiFlow + * that can be put atop the TCP layer to form an HTTP client. + * + * {{{ + * +------+ + * ws.Message ~>| |~> SslTlsOutbound + * | bidi | + * ws.Message <~| |<~ SslTlsInbound + * +------+ + * }}} + */ + type WebsocketClientLayer = BidiFlow[Message, SslTlsOutbound, SslTlsInbound, Message, Future[WebsocketUpgradeResponse]] + /** * Represents a prospective HTTP server binding. * @@ -614,6 +694,17 @@ object Http extends ExtensionId[HttpExt] with ExtensionIdProvider { */ final case class OutgoingConnection(localAddress: InetSocketAddress, remoteAddress: InetSocketAddress) + /** + * Represents the response to a websocket upgrade request. + */ + sealed trait WebsocketUpgradeResponse { + def response: HttpResponse + } + final case class InvalidUpgradeResponse(response: HttpResponse, cause: String) extends WebsocketUpgradeResponse + final case class ValidUpgrade( + response: HttpResponse, + chosenSubprotocol: Option[String]) extends WebsocketUpgradeResponse + /** * Represents a connection pool to a specific target host and pool configuration. */ diff --git a/akka-http-core/src/main/scala/akka/http/scaladsl/model/Uri.scala b/akka-http-core/src/main/scala/akka/http/scaladsl/model/Uri.scala index 2f655125f5..3186783306 100644 --- a/akka-http-core/src/main/scala/akka/http/scaladsl/model/Uri.scala +++ b/akka-http-core/src/main/scala/akka/http/scaladsl/model/Uri.scala @@ -599,8 +599,8 @@ object Uri { } val defaultPorts: Map[String, Int] = - Map("ftp" -> 21, "ssh" -> 22, "telnet" -> 23, "smtp" -> 25, "domain" -> 53, "tftp" -> 69, "http" -> 80, - "pop3" -> 110, "nntp" -> 119, "imap" -> 143, "snmp" -> 161, "ldap" -> 389, "https" -> 443, "imaps" -> 993, + Map("ftp" -> 21, "ssh" -> 22, "telnet" -> 23, "smtp" -> 25, "domain" -> 53, "tftp" -> 69, "http" -> 80, "ws" -> 80, + "pop3" -> 110, "nntp" -> 119, "imap" -> 143, "snmp" -> 161, "ldap" -> 389, "https" -> 443, "wss" -> 443, "imaps" -> 993, "nfs" -> 2049).withDefaultValue(-1) sealed trait ParsingMode diff --git a/akka-http-core/src/main/scala/akka/http/scaladsl/model/headers/headers.scala b/akka-http-core/src/main/scala/akka/http/scaladsl/model/headers/headers.scala index ecbfaf0961..41092abab4 100644 --- a/akka-http-core/src/main/scala/akka/http/scaladsl/model/headers/headers.scala +++ b/akka-http-core/src/main/scala/akka/http/scaladsl/model/headers/headers.scala @@ -102,6 +102,7 @@ sealed abstract case class Expect private () extends ModeledHeader { // http://tools.ietf.org/html/rfc7230#section-5.4 object Host extends ModeledCompanion[Host] { + def apply(authority: Uri.Authority): Host = apply(authority.host, authority.port) def apply(address: InetSocketAddress): Host = apply(address.getHostStringJava6Compatible, address.getPort) def apply(host: String): Host = apply(host, 0) def apply(host: String, port: Int): Host = apply(Uri.Host(host), port) diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketClientSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketClientSpec.scala new file mode 100644 index 0000000000..c7a8629e43 --- /dev/null +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketClientSpec.scala @@ -0,0 +1,368 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.impl.engine.ws + +import java.util.Random + +import akka.http.scaladsl.Http.{ InvalidUpgradeResponse, WebsocketUpgradeResponse } + +import scala.concurrent.duration._ + +import akka.http.ClientConnectionSettings +import akka.http.scaladsl.Http +import akka.http.scaladsl.model.headers.{ ProductVersion, `User-Agent` } +import akka.http.scaladsl.model.ws._ +import akka.http.scaladsl.model.{ HttpResponse, Uri } +import akka.stream.io._ +import akka.stream.scaladsl._ +import akka.stream.testkit.{ TestSubscriber, TestPublisher } +import akka.util.ByteString +import org.scalatest.{ Matchers, FreeSpec } + +import akka.http.impl.util._ + +class WebsocketClientSpec extends FreeSpec with Matchers with WithMaterializerSpec { + "The client-side Websocket implementation should" - { + "establish a websocket connection when the user requests it" in new EstablishedConnectionSetup with ClientEchoes + "establish connection with case insensitive header values" in new TestSetup with ClientEchoes { + expectWireData(UpgradeRequestBytes) + sendWireData("""HTTP/1.1 101 Switching Protocols + |Upgrade: wEbSOckET + |Sec-WebSocket-Accept: ujmZX4KXZqjwy6vi1aQFH5p4Ygk= + |Server: akka-http/test + |Sec-WebSocket-Version: 13 + |Connection: upgrade + | + |""") + + sendWSFrame(Protocol.Opcode.Text, ByteString("Message 1"), fin = true) + expectMaskedFrameOnNetwork(Protocol.Opcode.Text, ByteString("Message 1"), fin = true) + } + "reject invalid handshakes" - { + "other status code" in new TestSetup with ClientEchoes { + expectWireData(UpgradeRequestBytes) + + sendWireData( + """HTTP/1.1 404 Not Found + |Server: akka-http/test + |Content-Length: 0 + | + |""") + + expectNetworkAbort() + expectInvalidUpgradeResponseCause("Websocket server at ws://example.org/ws returned unexpected status code: 404 Not Found") + } + "missing Sec-WebSocket-Accept hash" in new TestSetup with ClientEchoes { + expectWireData(UpgradeRequestBytes) + + sendWireData( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Version: 13 + |Server: akka-http/test + |Connection: upgrade + | + |""") + + expectNetworkAbort() + expectInvalidUpgradeResponseCause("Websocket server at ws://example.org/ws returned response that was missing required `Sec-WebSocket-Accept` header.") + } + "wrong Sec-WebSocket-Accept hash" in new TestSetup with ClientEchoes { + expectWireData(UpgradeRequestBytes) + + sendWireData( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxhZRbK+xOo= + |Sec-WebSocket-Version: 13 + |Server: akka-http/test + |Connection: upgrade + | + |""") + + expectNetworkAbort() + expectInvalidUpgradeResponseCause("Websocket server at ws://example.org/ws returned response with invalid `Sec-WebSocket-Accept` header.") + } + "missing `Upgrade` header" in new TestSetup with ClientEchoes { + expectWireData(UpgradeRequestBytes) + + sendWireData( + """HTTP/1.1 101 Switching Protocols + |Sec-WebSocket-Accept: ujmZX4KXZqjwy6vi1aQFH5p4Ygk= + |Sec-WebSocket-Version: 13 + |Server: akka-http/test + |Connection: upgrade + | + |""") + + expectNetworkAbort() + expectInvalidUpgradeResponseCause("Websocket server at ws://example.org/ws returned response that was missing required `Upgrade` header.") + } + "missing `Connection: upgrade` header" in new TestSetup with ClientEchoes { + expectWireData(UpgradeRequestBytes) + + sendWireData( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: ujmZX4KXZqjwy6vi1aQFH5p4Ygk= + |Sec-WebSocket-Version: 13 + |Server: akka-http/test + | + |""") + + expectNetworkAbort() + expectInvalidUpgradeResponseCause("Websocket server at ws://example.org/ws returned response that was missing required `Connection` header.") + } + } + + "don't send out frames before handshake was finished successfully" in new TestSetup { + def clientImplementation: Flow[Message, Message, Unit] = + Flow.wrap(Sink.ignore, Source.single(TextMessage("fast message")))(Keep.none) + + expectWireData(UpgradeRequestBytes) + expectNoWireData() + + sendWireData(UpgradeResponseBytes) + expectMaskedFrameOnNetwork(Protocol.Opcode.Text, ByteString("fast message"), fin = true) + + expectMaskedCloseFrame(Protocol.CloseCodes.Regular) + sendWSCloseFrame(Protocol.CloseCodes.Regular) + + closeNetworkInput() + expectNetworkClose() + } + "receive first frame in same chunk as HTTP upgrade response" in new TestSetup with ClientProbes { + expectWireData(UpgradeRequestBytes) + + val firstFrame = WSTestUtils.frame(Protocol.Opcode.Text, ByteString("fast"), fin = true, mask = false) + sendWireData(UpgradeResponseBytes ++ firstFrame) + + messagesIn.requestNext(TextMessage("fast")) + } + + "manual scenario client sends first" in new EstablishedConnectionSetup with ClientProbes { + messagesOut.sendNext(TextMessage("Message 1")) + + expectMaskedFrameOnNetwork(Protocol.Opcode.Text, ByteString("Message 1"), fin = true) + + sendWSFrame(Protocol.Opcode.Binary, ByteString("Response"), fin = true, mask = false) + + messagesIn.requestNext(BinaryMessage(ByteString("Response"))) + } + "client echoes scenario" in new EstablishedConnectionSetup with ClientEchoes { + sendWSFrame(Protocol.Opcode.Text, ByteString("Message 1"), fin = true) + expectMaskedFrameOnNetwork(Protocol.Opcode.Text, ByteString("Message 1"), fin = true) + sendWSFrame(Protocol.Opcode.Text, ByteString("Message 2"), fin = true) + expectMaskedFrameOnNetwork(Protocol.Opcode.Text, ByteString("Message 2"), fin = true) + sendWSFrame(Protocol.Opcode.Text, ByteString("Message 3"), fin = true) + expectMaskedFrameOnNetwork(Protocol.Opcode.Text, ByteString("Message 3"), fin = true) + sendWSFrame(Protocol.Opcode.Text, ByteString("Message 4"), fin = true) + expectMaskedFrameOnNetwork(Protocol.Opcode.Text, ByteString("Message 4"), fin = true) + sendWSFrame(Protocol.Opcode.Text, ByteString("Message 5"), fin = true) + expectMaskedFrameOnNetwork(Protocol.Opcode.Text, ByteString("Message 5"), fin = true) + + sendWSCloseFrame(Protocol.CloseCodes.Regular) + expectMaskedCloseFrame(Protocol.CloseCodes.Regular) + + closeNetworkInput() + expectNetworkClose() + } + "support subprotocols" - { + "accept if server supports subprotocol" in new TestSetup with ClientEchoes { + override protected def requestedSubProtocol: Option[String] = Some("v2") + + expectWireData( + """GET /ws HTTP/1.1 + |Upgrade: websocket + |Connection: upgrade + |Sec-WebSocket-Key: YLQguzhR2dR6y5M9vnA5mw== + |Sec-WebSocket-Version: 13 + |Sec-WebSocket-Protocol: v2 + |Host: example.org + |User-Agent: akka-http/test + | + |""") + sendWireData( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: ujmZX4KXZqjwy6vi1aQFH5p4Ygk= + |Sec-WebSocket-Version: 13 + |Server: akka-http/test + |Connection: upgrade + |Sec-WebSocket-Protocol: v2 + | + |""") + + sendWSFrame(Protocol.Opcode.Text, ByteString("Message 1"), fin = true) + expectMaskedFrameOnNetwork(Protocol.Opcode.Text, ByteString("Message 1"), fin = true) + } + "send error on user flow if server doesn't support subprotocol" - { + "if no protocol was selected" in new TestSetup with ClientProbes { + override protected def requestedSubProtocol: Option[String] = Some("v2") + + expectWireData( + """GET /ws HTTP/1.1 + |Upgrade: websocket + |Connection: upgrade + |Sec-WebSocket-Key: YLQguzhR2dR6y5M9vnA5mw== + |Sec-WebSocket-Version: 13 + |Sec-WebSocket-Protocol: v2 + |Host: example.org + |User-Agent: akka-http/test + | + |""") + sendWireData( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: ujmZX4KXZqjwy6vi1aQFH5p4Ygk= + |Sec-WebSocket-Version: 13 + |Server: akka-http/test + |Connection: upgrade + | + |""") + + expectNetworkAbort() + expectInvalidUpgradeResponseCause( + "Websocket server at ws://example.org/ws returned response that indicated that the given subprotocol was not supported. (client supported: v2, server supported: None)") + } + "if different protocol was selected" in new TestSetup with ClientProbes { + override protected def requestedSubProtocol: Option[String] = Some("v2") + + expectWireData( + """GET /ws HTTP/1.1 + |Upgrade: websocket + |Connection: upgrade + |Sec-WebSocket-Key: YLQguzhR2dR6y5M9vnA5mw== + |Sec-WebSocket-Version: 13 + |Sec-WebSocket-Protocol: v2 + |Host: example.org + |User-Agent: akka-http/test + | + |""") + sendWireData( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: ujmZX4KXZqjwy6vi1aQFH5p4Ygk= + |Sec-WebSocket-Protocol: v3 + |Sec-WebSocket-Version: 13 + |Server: akka-http/test + |Connection: upgrade + | + |""") + + expectNetworkAbort() + expectInvalidUpgradeResponseCause( + "Websocket server at ws://example.org/ws returned response that indicated that the given subprotocol was not supported. (client supported: v2, server supported: Some(v3))") + } + } + } + } + + def UpgradeRequestBytes = ByteString { + """GET /ws HTTP/1.1 + |Upgrade: websocket + |Connection: upgrade + |Sec-WebSocket-Key: YLQguzhR2dR6y5M9vnA5mw== + |Sec-WebSocket-Version: 13 + |Host: example.org + |User-Agent: akka-http/test + | + |""".stripMarginWithNewline("\r\n") + } + + def UpgradeResponseBytes = ByteString { + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: ujmZX4KXZqjwy6vi1aQFH5p4Ygk= + |Server: akka-http/test + |Sec-WebSocket-Version: 13 + |Connection: upgrade + | + |""".stripMarginWithNewline("\r\n") + } + + abstract class EstablishedConnectionSetup extends TestSetup { + expectWireData(UpgradeRequestBytes) + sendWireData(UpgradeResponseBytes) + } + + abstract class TestSetup extends WSTestSetupBase { + protected def noMsgTimeout: FiniteDuration = 100.millis + protected def clientImplementation: Flow[Message, Message, Unit] + protected def requestedSubProtocol: Option[String] = None + + val random = new Random(0) + def settings = ClientConnectionSettings(system) + .copy( + userAgentHeader = Some(`User-Agent`(List(ProductVersion("akka-http", "test")))), + websocketRandomFactory = () ⇒ random) + + def targetUri: Uri = "ws://example.org/ws" + + def clientLayer: Http.WebsocketClientLayer = + Http(system).websocketClientLayer(targetUri, subprotocol = requestedSubProtocol, settings = settings) + + val (netOut, netIn, response) = { + val netOut = ByteStringSinkProbe() + val netIn = TestPublisher.probe[ByteString]() + + val graph = + FlowGraph.closed(clientLayer) { implicit b ⇒ + client ⇒ + import FlowGraph.Implicits._ + Source(netIn) ~> Flow[ByteString].map(SessionBytes(null, _)) ~> client.in2 + client.out1 ~> Flow[SslTlsOutbound].collect { case SendBytes(x) ⇒ x } ~> netOut.sink + client.out2 ~> clientImplementation ~> client.in1 + } + + val response = graph.run() + + (netOut, netIn, response) + } + def expectBytes(length: Int): ByteString = netOut.expectBytes(length) + def expectBytes(bytes: ByteString): Unit = netOut.expectBytes(bytes) + + def wipeDate(string: String) = + string.fastSplit('\n').map { + case s if s.startsWith("Date:") ⇒ "Date: XXXX\r" + case s ⇒ s + }.mkString("\n") + + def sendWireData(data: String): Unit = sendWireData(ByteString(data.stripMarginWithNewline("\r\n"), "ASCII")) + def sendWireData(data: ByteString): Unit = netIn.sendNext(data) + + def send(bytes: ByteString): Unit = sendWireData(bytes) + + def expectWireData(s: String) = + netOut.expectUtf8EncodedString(s.stripMarginWithNewline("\r\n")) + def expectWireData(bs: ByteString) = netOut.expectBytes(bs) + def expectNoWireData() = netOut.expectNoBytes(noMsgTimeout) + + def expectNetworkClose(): Unit = netOut.expectComplete() + def expectNetworkAbort(): Unit = netOut.expectError() + def closeNetworkInput(): Unit = netIn.sendComplete() + + def expectResponse(response: WebsocketUpgradeResponse): Unit = + expectInvalidUpgradeResponse() shouldEqual response + def expectInvalidUpgradeResponseCause(expected: String): Unit = + expectInvalidUpgradeResponse().cause shouldEqual expected + + import akka.http.impl.util._ + def expectInvalidUpgradeResponse(): InvalidUpgradeResponse = + response.awaitResult(1.second).asInstanceOf[InvalidUpgradeResponse] + } + + trait ClientEchoes extends TestSetup { + override def clientImplementation: Flow[Message, Message, Unit] = echoServer + def echoServer: Flow[Message, Message, Unit] = Flow[Message] + } + trait ClientProbes extends TestSetup { + lazy val messagesOut = TestPublisher.probe[Message]() + lazy val messagesIn = TestSubscriber.probe[Message]() + + override def clientImplementation: Flow[Message, Message, Unit] = + Flow.wrap(Sink(messagesIn), Source(messagesOut))(Keep.none) + } +} From 08aa9034082b927ba242ceb9bb0208010a727bb8 Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Fri, 9 Oct 2015 17:39:22 +0200 Subject: [PATCH 06/10] =htc Websocket Autobahn Suite test runners and documentation --- .../RunWebsocketAutobahnTestSuite.md | 51 ++++ .../impl/engine/ws/EchoTestClientApp.scala | 68 ++++++ .../impl/engine/ws/WSClientAutobahnTest.scala | 227 ++++++++++++++++++ .../impl/engine/ws/WSServerAutobahnTest.scala | 44 ++++ 4 files changed, 390 insertions(+) create mode 100644 akka-http-core/RunWebsocketAutobahnTestSuite.md create mode 100644 akka-http-core/src/test/scala/akka/http/impl/engine/ws/EchoTestClientApp.scala create mode 100644 akka-http-core/src/test/scala/akka/http/impl/engine/ws/WSClientAutobahnTest.scala create mode 100644 akka-http-core/src/test/scala/akka/http/impl/engine/ws/WSServerAutobahnTest.scala diff --git a/akka-http-core/RunWebsocketAutobahnTestSuite.md b/akka-http-core/RunWebsocketAutobahnTestSuite.md new file mode 100644 index 0000000000..9591870e75 --- /dev/null +++ b/akka-http-core/RunWebsocketAutobahnTestSuite.md @@ -0,0 +1,51 @@ +# Test the client side + +Start up the testsuite with docker: + +``` +docker run -ti --rm=true -p 8080:8080 -p 9001:9001 jrudolph/autobahn-testsuite +``` + +Then in sbt, to run all tests, use + +``` +akka-http-core-experimental/test:run-main akka.http.impl.engine.ws.WSClientAutobahnTest +``` + +or, to run a single test, use + +``` +akka-http-core-experimental/test:run-main akka.http.impl.engine.ws.WSClientAutobahnTest 1.1.1 +``` + +After a run, you can access the results of the run at http://localhost:8080/cwd/reports/clients/index.html. + +You can supply a configuration file for autobahn by mounting a version of `fuzzingserver.json` to `/tmp/fuzzingserver.json` +of the container, e.g. using this docker option: + +``` +-v /fullpath-on-host/my-fuzzingserver-config.json:/tmp/fuzzingserver.json +``` + +# Test the server side + +Start up the test server in sbt: + +``` +akka-http-core-experimental/test:run-main akka.http.impl.engine.ws.WSServerAutobahnTest +``` + +Then, run the test suite with docker: + +``` +docker run -ti --rm=true -v `pwd`/reports:/tmp/server-report jrudolph/autobahn-testsuite-client +``` + +This will put the result report into a `reports` directory in the current working directory on the host. + +You can supply a configuration file for autobahn by mounting a version of `fuzzingclient.json` to `/tmp/fuzzingclient.json` +of the container, e.g. using this docker option: + +``` +-v /fullpath-on-host/my-fuzzingclient-config.json:/tmp/fuzzingclient.json +``` diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/EchoTestClientApp.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/EchoTestClientApp.scala new file mode 100644 index 0000000000..900b25b0bd --- /dev/null +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/EchoTestClientApp.scala @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.impl.engine.ws + +import scala.concurrent.duration._ + +import akka.actor.ActorSystem +import akka.http.scaladsl.Http +import akka.http.scaladsl.model.ws.{ TextMessage, BinaryMessage, Message } +import akka.stream.ActorMaterializer +import akka.stream.scaladsl._ +import akka.util.ByteString + +import scala.concurrent.Future +import scala.util.{ Failure, Success } + +/** + * An example App that runs a quick test against the websocket server at wss://echo.websocket.org + */ +object EchoTestClientApp extends App { + implicit val system = ActorSystem() + import system.dispatcher + implicit val mat = ActorMaterializer() + + def delayedCompletion(delay: FiniteDuration): Source[Nothing, Unit] = + Source.single(1) + .mapAsync(1)(_ ⇒ akka.pattern.after(delay, system.scheduler)(Future(1))) + .drop(1).asInstanceOf[Source[Nothing, Unit]] + + def messages: List[Message] = + List( + TextMessage("Test 1"), + BinaryMessage(ByteString("abc")), + TextMessage("Test 2"), + BinaryMessage(ByteString("def"))) + + def source: Source[Message, Unit] = + Source(messages) ++ delayedCompletion(1.second) // otherwise, we may start closing too soon + + def sink: Sink[Message, Future[Seq[String]]] = + Flow[Message] + .mapAsync(1) { + case tm: TextMessage ⇒ + tm.textStream.runWith(Sink.fold("")(_ + _)).map(str ⇒ s"TextMessage: '$str'") + case bm: BinaryMessage ⇒ + bm.dataStream.runWith(Sink.fold(ByteString.empty)(_ ++ _)).map(bs ⇒ s"BinaryMessage: '${bs.utf8String}'") + } + .grouped(10000) + .toMat(Sink.head)(Keep.right) + + def echoClient = Flow.wrap(sink, source)(Keep.left) + + val (upgrade, res) = Http().singleWebsocketRequest("wss://echo.websocket.org", echoClient) + res onComplete { + case Success(res) ⇒ + println("Run successful. Got these elements:") + res.foreach(println) + system.shutdown() + case Failure(e) ⇒ + println("Run failed.") + e.printStackTrace() + system.shutdown() + } + + system.scheduler.scheduleOnce(10.seconds)(system.shutdown()) +} diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WSClientAutobahnTest.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WSClientAutobahnTest.scala new file mode 100644 index 0000000000..4c05b8f46b --- /dev/null +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WSClientAutobahnTest.scala @@ -0,0 +1,227 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.impl.engine.ws + +import scala.concurrent.{ Promise, Future } +import scala.util.{ Try, Failure, Success } + +import spray.json._ + +import akka.actor.ActorSystem + +import akka.stream.ActorMaterializer +import akka.stream.io.SslTlsPlacebo +import akka.stream.stage.{ TerminationDirective, Context, SyncDirective, PushStage } +import akka.stream.scaladsl._ + +import akka.http.scaladsl.Http +import akka.http.scaladsl.model.Uri +import akka.http.scaladsl.model.ws._ + +object WSClientAutobahnTest extends App { + implicit val system = ActorSystem() + import system.dispatcher + implicit val mat = ActorMaterializer() + + val Agent = "akka-http" + val Parallelism = 4 + + val getCaseCountUri: Uri = + s"ws://localhost:9001/getCaseCount" + + def runCaseUri(caseIndex: Int, agent: String): Uri = + s"ws://localhost:9001/runCase?case=$caseIndex&agent=$agent" + + def getCaseStatusUri(caseIndex: Int, agent: String): Uri = + s"ws://localhost:9001/getCaseStatus?case=$caseIndex&agent=$agent" + + def getCaseInfoUri(caseIndex: Int): Uri = + s"ws://localhost:9001/getCaseInfo?case=$caseIndex" + + def updateReportsUri(agent: String): Uri = + s"ws://localhost:9001/updateReports?agent=$agent" + + def runCase(caseIndex: Int, agent: String = Agent): Future[CaseStatus] = + runWs(runCaseUri(caseIndex, agent), echo).recover { case _ ⇒ () }.flatMap { _ ⇒ + getCaseStatus(caseIndex, agent) + } + + def richRunCase(caseIndex: Int, agent: String = Agent): Future[CaseResult] = { + val info = getCaseInfo(caseIndex) + val startMillis = System.currentTimeMillis() + val status = runCase(caseIndex, agent).map { res ⇒ + val lastedMillis = System.currentTimeMillis() - startMillis + (res, lastedMillis) + } + import Console._ + info.flatMap { i ⇒ + val prefix = f"$YELLOW${i.caseInfo.id}%-7s$RESET - $WHITE${i.caseInfo.description}$RESET ... " + //println(prefix) + + status.onComplete { + case Success((CaseStatus(status), millis)) ⇒ + val color = if (status == "OK") GREEN else RED + println(f"${color}$status%-15s$RESET$millis%5d ms $prefix") + case Failure(e) ⇒ + println(s"$prefix${RED}failed with '${e.getMessage}'$RESET") + } + + status.map(s ⇒ CaseResult(i.caseInfo, s._1)) + } + } + + def getCaseCount(): Future[Int] = + runToSingleText(getCaseCountUri).map(_.toInt) + + def getCaseInfo(caseId: Int): Future[IndexedCaseInfo] = + runToSingleJsonValue[CaseInfo](getCaseInfoUri(caseId)).map(IndexedCaseInfo(caseId, _)) + + def getCaseStatus(caseId: Int, agent: String = Agent): Future[CaseStatus] = + runToSingleJsonValue[CaseStatus](getCaseStatusUri(caseId, agent)) + + def updateReports(agent: String = Agent): Future[Unit] = + runToSingleText(updateReportsUri(agent)).map(_ ⇒ ()) + + /** + * Map from textual case ID (like 1.1.1) to IndexedCaseInfo + * @return + */ + def getCaseMap(): Future[Map[String, IndexedCaseInfo]] = { + val res = + getCaseCount().flatMap { count ⇒ + println(s"Retrieving case info for $count cases...") + Future.traverse(1 to count)(getCaseInfo).map(_.map(e ⇒ e.caseInfo.id -> e).toMap) + } + res.foreach { res ⇒ + println(s"Received info for ${res.size} cases") + } + res + } + + def echo = Flow[Message].viaMat(completionSignal)(Keep.right) + + if (args.size >= 1) { + // run one + val testId = args(0) + println(s"Trying to run test $testId") + getCaseMap().flatMap { map ⇒ + val info = map(testId) + richRunCase(info.index) + }.onComplete { + case Success(res) ⇒ + println(s"Run successfully finished!") + updateReportsAndShutdown() + case Failure(e) ⇒ + println("Run failed with this exception") + e.printStackTrace() + updateReportsAndShutdown() + } + } else { + println("Running complete test suite") + getCaseCount().flatMap { count ⇒ + println(s"Found $count tests.") + Source(1 to count).mapAsyncUnordered(Parallelism)(richRunCase(_)).grouped(count).runWith(Sink.head) + }.map { results ⇒ + val grouped = + results.groupBy(_.status.behavior) + + import Console._ + println(s"${results.size} tests run.") + println() + println(s"${GREEN}OK$RESET: ${grouped.getOrElse("OK", Nil).size}") + val notOk = grouped.filterNot(_._1 == "OK") + notOk.toSeq.sortBy(_._2.size).foreach { + case (status, cases) ⇒ println(s"$RED$status$RESET: ${cases.size}") + } + println() + println("Not OK tests") + println() + results.filterNot(_.status.behavior == "OK").foreach { r ⇒ + println(f"$RED${r.status.behavior}%-20s$RESET $YELLOW${r.info.id}%-7s$RESET - $WHITE${r.info.description}$RESET") + } + + () + } + .onComplete(completion) + } + + def completion[T]: Try[T] ⇒ Unit = { + case Success(res) ⇒ + println(s"Run successfully finished!") + updateReportsAndShutdown() + case Failure(e) ⇒ + println("Run failed with this exception") + e.printStackTrace() + updateReportsAndShutdown() + } + def updateReportsAndShutdown(): Unit = + updateReports().onComplete { res ⇒ + println("Reports should now be accessible at http://localhost:8080/cwd/reports/clients/index.html") + system.shutdown() + } + + import scala.concurrent.duration._ + import system.dispatcher + system.scheduler.scheduleOnce(60.seconds)(system.shutdown()) + + def runWs[T](uri: Uri, clientFlow: Flow[Message, Message, T]): T = + Http().singleWebsocketRequest(uri, clientFlow)._2 + + def completionSignal[T]: Flow[T, T, Future[Unit]] = + Flow[T].transformMaterializing { () ⇒ + val p = Promise[Unit]() + val stage = + new PushStage[T, T] { + def onPush(elem: T, ctx: Context[T]): SyncDirective = ctx.push(elem) + override def onUpstreamFinish(ctx: Context[T]): TerminationDirective = { + p.success(()) + super.onUpstreamFinish(ctx) + } + override def onDownstreamFinish(ctx: Context[T]): TerminationDirective = { + p.success(()) // should this be failure as well? + super.onDownstreamFinish(ctx) + } + override def onUpstreamFailure(cause: Throwable, ctx: Context[T]): TerminationDirective = { + p.failure(cause) + super.onUpstreamFailure(cause, ctx) + } + } + + (stage, p.future) + } + + /** + * The autobahn tests define a weird API where every request must be a Websocket request and + * they will send a single websocket message with the result. Websocket everywhere? Strange, + * but somewhat consistent. + */ + def runToSingleText(uri: Uri): Future[String] = { + val sink = Sink.head[Message] + runWs(uri, Flow.wrap(sink, Source.lazyEmpty[Message])(Keep.left)).flatMap { + case tm: TextMessage ⇒ tm.textStream.runWith(Sink.fold("")(_ + _)) + } + } + def runToSingleJsonValue[T: JsonReader](uri: Uri): Future[T] = + runToSingleText(uri).map(_.parseJson.convertTo[T]) + + case class IndexedCaseInfo(index: Int, caseInfo: CaseInfo) + case class CaseResult(info: CaseInfo, status: CaseStatus) + + // {"behavior": "OK"} + case class CaseStatus(behavior: String) { + def isSuccessful: Boolean = behavior == "OK" + } + object CaseStatus { + import DefaultJsonProtocol._ + implicit def caseStatusFormat: JsonFormat[CaseStatus] = jsonFormat1(CaseStatus.apply) + } + + // {"id": "1.1.1", "description": "Send text message with payload 0."} + case class CaseInfo(id: String, description: String) + object CaseInfo { + import DefaultJsonProtocol._ + implicit def caseInfoFormat: JsonFormat[CaseInfo] = jsonFormat2(CaseInfo.apply) + } +} \ No newline at end of file diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WSServerAutobahnTest.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WSServerAutobahnTest.scala new file mode 100644 index 0000000000..85a3c26dee --- /dev/null +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WSServerAutobahnTest.scala @@ -0,0 +1,44 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.impl.engine.ws + +import scala.concurrent.Await +import scala.concurrent.duration._ + +import akka.actor.ActorSystem +import akka.http.scaladsl.Http +import akka.http.scaladsl.model.HttpMethods._ +import akka.http.scaladsl.model.ws.{ Message, UpgradeToWebsocket } +import akka.http.scaladsl.model._ +import akka.stream.ActorMaterializer +import akka.stream.scaladsl.Flow + +object WSServerAutobahnTest extends App { + implicit val system = ActorSystem("WSServerTest") + implicit val fm = ActorMaterializer() + + try { + val binding = Http().bindAndHandleSync({ + case req @ HttpRequest(GET, Uri.Path("/"), _, _, _) if req.header[UpgradeToWebsocket].isDefined ⇒ + req.header[UpgradeToWebsocket] match { + case Some(upgrade) ⇒ upgrade.handleMessages(echoWebsocketService) // needed for running the autobahn test suite + case None ⇒ HttpResponse(400, entity = "Not a valid websocket request!") + } + case _: HttpRequest ⇒ HttpResponse(404, entity = "Unknown resource!") + }, + interface = "172.17.42.1", // adapt to your docker host IP address if necessary + port = 9001) + + Await.result(binding, 1.second) // throws if binding fails + println("Server online at http://172.17.42.1:9001") + println("Press RETURN to stop...") + Console.readLine() + } finally { + system.shutdown() + } + + def echoWebsocketService: Flow[Message, Message, Unit] = + Flow[Message] // just let message flow directly to the output +} From 00b4eefab5f09a24de386c344537d60674c3e8dc Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Mon, 12 Oct 2015 11:34:24 +0200 Subject: [PATCH 07/10] !htc #17275 encapsulate Websocket request arguments in new `WebsocketRequest` class --- .../engine/ws/WebsocketClientBlueprint.scala | 17 ++++++------ .../main/scala/akka/http/scaladsl/Http.scala | 27 +++++++------------ .../scaladsl/model/ws/WebsocketRequest.scala | 26 ++++++++++++++++++ .../impl/engine/ws/WebsocketClientSpec.scala | 4 ++- 4 files changed, 47 insertions(+), 27 deletions(-) create mode 100644 akka-http-core/src/main/scala/akka/http/scaladsl/model/ws/WebsocketRequest.scala diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/WebsocketClientBlueprint.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/WebsocketClientBlueprint.scala index b7c2bf36ed..79e2dd026a 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/WebsocketClientBlueprint.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/WebsocketClientBlueprint.scala @@ -4,6 +4,8 @@ package akka.http.impl.engine.ws +import akka.http.scaladsl.model.ws.WebsocketRequest + import scala.collection.immutable import scala.concurrent.{ Future, Promise } @@ -32,12 +34,10 @@ object WebsocketClientBlueprint { /** * Returns a WebsocketClientLayer that can be materialized once. */ - def apply(uri: Uri, - extraHeaders: immutable.Seq[HttpHeader], - subProtocol: Option[String], + def apply(request: WebsocketRequest, settings: ClientConnectionSettings, log: LoggingAdapter): Http.WebsocketClientLayer = - (simpleTls.atopMat(handshake(uri, extraHeaders, subProtocol, settings, log))(Keep.right) atop + (simpleTls.atopMat(handshake(request, settings, log))(Keep.right) atop Websocket.framing atop Websocket.stack(serverSide = false, maskingRandomFactory = settings.websocketRandomFactory)).reversed @@ -45,16 +45,15 @@ object WebsocketClientBlueprint { * A bidi flow that injects and inspects the WS handshake and then goes out of the way. This BidiFlow * can only be materialized once. */ - def handshake(uri: Uri, - extraHeaders: immutable.Seq[HttpHeader], - subProtocol: Option[String], + def handshake(request: WebsocketRequest, settings: ClientConnectionSettings, log: LoggingAdapter): BidiFlow[ByteString, ByteString, ByteString, ByteString, Future[WebsocketUpgradeResponse]] = { + import request._ val result = Promise[WebsocketUpgradeResponse]() val valve = StreamUtils.OneTimeValve() - val (initialRequest, key) = Handshake.Client.buildRequest(uri, extraHeaders, subProtocol.toList, settings.websocketRandomFactory()) + val (initialRequest, key) = Handshake.Client.buildRequest(uri, extraHeaders, subprotocol.toList, settings.websocketRandomFactory()) val hostHeader = Host(uri.authority) val renderedInitialRequest = HttpRequestRendererFactory.renderStrict(RequestRenderingContext(initialRequest, hostHeader), settings, log) @@ -86,7 +85,7 @@ object WebsocketClientBlueprint { case NeedMoreData ⇒ ctx.pull() case ResponseStart(status, protocol, headers, entity, close) ⇒ val response = HttpResponse(status, headers, protocol = protocol) - Handshake.Client.validateResponse(response, subProtocol.toList, key) match { + Handshake.Client.validateResponse(response, subprotocol.toList, key) match { case Right(NegotiatedWebsocketSettings(protocol)) ⇒ result.success(ValidUpgrade(response, protocol)) diff --git a/akka-http-core/src/main/scala/akka/http/scaladsl/Http.scala b/akka-http-core/src/main/scala/akka/http/scaladsl/Http.scala index 13430b812d..70117668bb 100644 --- a/akka-http-core/src/main/scala/akka/http/scaladsl/Http.scala +++ b/akka-http-core/src/main/scala/akka/http/scaladsl/Http.scala @@ -19,7 +19,7 @@ import akka.http.impl.util.{ ReadTheDocumentationException, Java6Compat, StreamU import akka.http.impl.engine.ws.WebsocketClientBlueprint import akka.http.scaladsl.model._ import akka.http.scaladsl.model.headers.Host -import akka.http.scaladsl.model.ws.Message +import akka.http.scaladsl.model.ws.{ WebsocketRequest, Message } import akka.http.scaladsl.util.FastFuture import akka.japi import akka.stream.Materializer @@ -422,25 +422,22 @@ class HttpExt(config: Config)(implicit system: ActorSystem) extends akka.actor.E * * The layer is not reusable and must only be materialized once. */ - def websocketClientLayer(uri: Uri, - extraHeaders: immutable.Seq[HttpHeader] = Nil, - subprotocol: Option[String] = None, + def websocketClientLayer(request: WebsocketRequest, settings: ClientConnectionSettings = ClientConnectionSettings(system), log: LoggingAdapter = system.log): Http.WebsocketClientLayer = - WebsocketClientBlueprint(uri, extraHeaders, subprotocol, settings, log) + WebsocketClientBlueprint(request, settings, log) /** * Constructs a flow that once materialized establishes a Websocket connection to the given Uri. * * The layer is not reusable and must only be materialized once. */ - def websocketClientFlow(uri: Uri, - extraHeaders: immutable.Seq[HttpHeader] = Nil, - subprotocol: Option[String] = None, + def websocketClientFlow(request: WebsocketRequest, localAddress: Option[InetSocketAddress] = None, settings: ClientConnectionSettings = ClientConnectionSettings(system), httpsContext: Option[HttpsContext] = None, log: LoggingAdapter = system.log): Flow[Message, Message, Future[WebsocketUpgradeResponse]] = { + import request.uri require(uri.isAbsolute, s"Websocket request URI must be absolute but was '$uri'") val ctx = uri.scheme match { @@ -453,7 +450,7 @@ class HttpExt(config: Config)(implicit system: ActorSystem) extends akka.actor.E val host = uri.authority.host.address val port = uri.effectivePort - websocketClientLayer(uri, extraHeaders, subprotocol, settings, log) + websocketClientLayer(request, settings, log) .joinMat(_outgoingTlsConnectionLayer(host, port, localAddress, settings, ctx, log))(Keep.left) } @@ -461,15 +458,13 @@ class HttpExt(config: Config)(implicit system: ActorSystem) extends akka.actor.E * Runs a single Websocket conversation given a Uri and a flow that represents the client side of the * Websocket conversation. */ - def singleWebsocketRequest[T](uri: Uri, + def singleWebsocketRequest[T](request: WebsocketRequest, clientFlow: Flow[Message, Message, T], - extraHeaders: immutable.Seq[HttpHeader] = Nil, - subprotocol: Option[String] = None, localAddress: Option[InetSocketAddress] = None, settings: ClientConnectionSettings = ClientConnectionSettings(system), httpsContext: Option[HttpsContext] = None, log: LoggingAdapter = system.log)(implicit mat: Materializer): (Future[WebsocketUpgradeResponse], T) = - websocketClientFlow(uri, extraHeaders, subprotocol, localAddress, settings, httpsContext, log) + websocketClientFlow(request, localAddress, settings, httpsContext, log) .joinMat(clientFlow)(Keep.both).run() /** @@ -695,15 +690,13 @@ object Http extends ExtensionId[HttpExt] with ExtensionIdProvider { final case class OutgoingConnection(localAddress: InetSocketAddress, remoteAddress: InetSocketAddress) /** - * Represents the response to a websocket upgrade request. + * Represents the response to a websocket upgrade request. Can either be [[ValidUpgrade]] or [[InvalidUpgradeResponse]]. */ sealed trait WebsocketUpgradeResponse { def response: HttpResponse } + final case class ValidUpgrade(response: HttpResponse, chosenSubprotocol: Option[String]) extends WebsocketUpgradeResponse final case class InvalidUpgradeResponse(response: HttpResponse, cause: String) extends WebsocketUpgradeResponse - final case class ValidUpgrade( - response: HttpResponse, - chosenSubprotocol: Option[String]) extends WebsocketUpgradeResponse /** * Represents a connection pool to a specific target host and pool configuration. diff --git a/akka-http-core/src/main/scala/akka/http/scaladsl/model/ws/WebsocketRequest.scala b/akka-http-core/src/main/scala/akka/http/scaladsl/model/ws/WebsocketRequest.scala new file mode 100644 index 0000000000..e4dd04a02b --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/scaladsl/model/ws/WebsocketRequest.scala @@ -0,0 +1,26 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.scaladsl.model.ws + +import scala.language.implicitConversions + +import scala.collection.immutable + +import akka.http.scaladsl.model.{ HttpHeader, Uri } + +/** + * Represents a Websocket request. + * @param uri The target URI to connect to. + * @param extraHeaders Extra headers to add to the Websocket request. + * @param subprotocol A Websocket subprotocol if required. + */ +final case class WebsocketRequest( + uri: Uri, + extraHeaders: immutable.Seq[HttpHeader] = Nil, + subprotocol: Option[String] = None) +object WebsocketRequest { + implicit def fromTargetUri(uri: Uri): WebsocketRequest = WebsocketRequest(uri) + implicit def fromTargetUriString(uriString: String): WebsocketRequest = WebsocketRequest(uriString) +} \ No newline at end of file diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketClientSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketClientSpec.scala index c7a8629e43..b4937c9c60 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketClientSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketClientSpec.scala @@ -302,7 +302,9 @@ class WebsocketClientSpec extends FreeSpec with Matchers with WithMaterializerSp def targetUri: Uri = "ws://example.org/ws" def clientLayer: Http.WebsocketClientLayer = - Http(system).websocketClientLayer(targetUri, subprotocol = requestedSubProtocol, settings = settings) + Http(system).websocketClientLayer( + WebsocketRequest(targetUri, subprotocol = requestedSubProtocol), + settings = settings) val (netOut, netIn, response) = { val netOut = ByteStringSinkProbe() From d4b5f29c57e9dbb1a1a6bb45af8d061d2cd67aca Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Mon, 12 Oct 2015 14:21:07 +0200 Subject: [PATCH 08/10] +htc #17275 Java-side of Websocket client API --- .../main/scala/akka/http/javadsl/Http.scala | 112 +++++++++++++++++- .../javadsl/model/ws/WebsocketRequest.scala | 61 ++++++++++ .../model/ws/WebsocketUpgradeResponse.scala | 58 +++++++++ .../http/javadsl/WSEchoTestClientApp.java | 83 +++++++++++++ 4 files changed, 313 insertions(+), 1 deletion(-) create mode 100644 akka-http-core/src/main/scala/akka/http/javadsl/model/ws/WebsocketRequest.scala create mode 100644 akka-http-core/src/main/scala/akka/http/javadsl/model/ws/WebsocketUpgradeResponse.scala create mode 100644 akka-http-core/src/test/java/akka/http/javadsl/WSEchoTestClientApp.java diff --git a/akka-http-core/src/main/scala/akka/http/javadsl/Http.scala b/akka-http-core/src/main/scala/akka/http/javadsl/Http.scala index 632c1fddca..869bd6a9b6 100644 --- a/akka-http-core/src/main/scala/akka/http/javadsl/Http.scala +++ b/akka-http-core/src/main/scala/akka/http/javadsl/Http.scala @@ -7,6 +7,8 @@ package akka.http.javadsl import java.lang.{ Iterable ⇒ JIterable } import java.net.InetSocketAddress import akka.http.impl.util.JavaMapping +import akka.http.javadsl.model.ws._ +import akka.stream import akka.stream.io.{ SslTlsInbound, SslTlsOutbound } import scala.language.implicitConversions @@ -496,6 +498,93 @@ class Http(system: ExtendedActorSystem) extends akka.actor.Extension { log: LoggingAdapter, materializer: Materializer): Future[HttpResponse] = delegate.singleRequest(request.asScala, settings, httpsContext, log)(materializer) + /** + * Constructs a Websocket [[BidiFlow]]. + * + * The layer is not reusable and must only be materialized once. + */ + def websocketClientLayer(request: WebsocketRequest): BidiFlow[Message, SslTlsOutbound, SslTlsInbound, Message, Future[WebsocketUpgradeResponse]] = + adaptWsBidiFlow(delegate.websocketClientLayer(request.asScala)) + + /** + * Constructs a Websocket [[BidiFlow]] using the configured default [[ClientConnectionSettings]], + * configured using the `akka.http.client` config section. + * + * The layer is not reusable and must only be materialized once. + */ + def websocketClientLayer(request: WebsocketRequest, + settings: ClientConnectionSettings): BidiFlow[Message, SslTlsOutbound, SslTlsInbound, Message, Future[WebsocketUpgradeResponse]] = + adaptWsBidiFlow(delegate.websocketClientLayer(request.asScala, settings)) + + /** + * Constructs a Websocket [[BidiFlow]] using the configured default [[ClientConnectionSettings]], + * configured using the `akka.http.client` config section. + * + * The layer is not reusable and must only be materialized once. + */ + def websocketClientLayer(request: WebsocketRequest, + settings: ClientConnectionSettings, + log: LoggingAdapter): BidiFlow[Message, SslTlsOutbound, SslTlsInbound, Message, Future[WebsocketUpgradeResponse]] = + adaptWsBidiFlow(delegate.websocketClientLayer(request.asScala, settings, log)) + + /** + * Constructs a flow that once materialized establishes a Websocket connection to the given Uri. + * + * The layer is not reusable and must only be materialized once. + */ + def websocketClientFlow(request: WebsocketRequest): Flow[Message, Message, Future[WebsocketUpgradeResponse]] = + adaptWsFlow { + delegate.websocketClientFlow(request.asScala) + } + + /** + * Constructs a flow that once materialized establishes a Websocket connection to the given Uri. + * + * The layer is not reusable and must only be materialized once. + */ + def websocketClientFlow(request: WebsocketRequest, + localAddress: Option[InetSocketAddress], + settings: ClientConnectionSettings, + httpsContext: Option[HttpsContext], + log: LoggingAdapter): Flow[Message, Message, Future[WebsocketUpgradeResponse]] = + adaptWsFlow { + delegate.websocketClientFlow(request.asScala, localAddress, settings, httpsContext, log) + } + + /** + * Runs a single Websocket conversation given a Uri and a flow that represents the client side of the + * Websocket conversation. + */ + def singleWebsocketRequest[T](request: WebsocketRequest, + clientFlow: Flow[Message, Message, T], + materializer: Materializer): Pair[Future[WebsocketUpgradeResponse], T] = + adaptWsResultTuple { + delegate.singleWebsocketRequest( + request.asScala, + adaptWsFlow[T](clientFlow))(materializer) + } + + /** + * Runs a single Websocket conversation given a Uri and a flow that represents the client side of the + * Websocket conversation. + */ + def singleWebsocketRequest[T](request: WebsocketRequest, + clientFlow: Flow[Message, Message, T], + localAddress: Option[InetSocketAddress], + settings: ClientConnectionSettings, + httpsContext: Option[HttpsContext], + log: LoggingAdapter, + materializer: Materializer): Pair[Future[WebsocketUpgradeResponse], T] = + adaptWsResultTuple { + delegate.singleWebsocketRequest( + request.asScala, + adaptWsFlow[T](clientFlow), + localAddress, + settings, + httpsContext, + log)(materializer) + } + /** * Triggers an orderly shutdown of all host connections pools currently maintained by the [[ActorSystem]]. * The returned future is completed when all pools that were live at the time of this method call @@ -517,7 +606,7 @@ class Http(system: ExtendedActorSystem) extends akka.actor.Extension { def setDefaultClientHttpsContext(context: HttpsContext): Unit = delegate.setDefaultClientHttpsContext(context.asInstanceOf[akka.http.scaladsl.HttpsContext]) - private def adaptTupleFlow[T, Mat](scalaFlow: akka.stream.scaladsl.Flow[(scaladsl.model.HttpRequest, T), (Try[scaladsl.model.HttpResponse], T), Mat]): Flow[Pair[HttpRequest, T], Pair[Try[HttpResponse], T], Mat] = { + private def adaptTupleFlow[T, Mat](scalaFlow: stream.scaladsl.Flow[(scaladsl.model.HttpRequest, T), (Try[scaladsl.model.HttpResponse], T), Mat]): Flow[Pair[HttpRequest, T], Pair[Try[HttpResponse], T], Mat] = { implicit val _ = JavaMapping.identity[T] JavaMapping.toJava(scalaFlow)(JavaMapping.flowMapping[Pair[HttpRequest, T], (scaladsl.model.HttpRequest, T), Pair[Try[HttpResponse], T], (Try[scaladsl.model.HttpResponse], T), Mat]) } @@ -531,4 +620,25 @@ class Http(system: ExtendedActorSystem) extends akka.actor.Extension { new BidiFlow( JavaMapping.adapterBidiFlow[HttpRequest, sm.HttpRequest, sm.HttpResponse, HttpResponse] .atop(clientLayer)) + + private def adaptWsBidiFlow(wsLayer: scaladsl.Http.WebsocketClientLayer): BidiFlow[Message, SslTlsOutbound, SslTlsInbound, Message, Future[WebsocketUpgradeResponse]] = + new BidiFlow( + JavaMapping.adapterBidiFlow[Message, sm.ws.Message, sm.ws.Message, Message] + .atopMat(wsLayer)((_, s) ⇒ adaptWsUpgradeResponse(s))) + + private def adaptWsFlow(wsLayer: stream.scaladsl.Flow[sm.ws.Message, sm.ws.Message, Future[scaladsl.Http.WebsocketUpgradeResponse]]): Flow[Message, Message, Future[WebsocketUpgradeResponse]] = + Flow.adapt(JavaMapping.adapterBidiFlow[Message, sm.ws.Message, sm.ws.Message, Message].joinMat(wsLayer)(Keep.right).mapMaterializedValue(adaptWsUpgradeResponse _)) + + private def adaptWsFlow[Mat](javaFlow: Flow[Message, Message, Mat]): stream.scaladsl.Flow[scaladsl.model.ws.Message, scaladsl.model.ws.Message, Mat] = + stream.scaladsl.Flow[scaladsl.model.ws.Message] + .map(Message.adapt) + .viaMat(javaFlow.asScala)(Keep.right) + .map(_.asScala) + + private def adaptWsResultTuple[T](result: (Future[scaladsl.Http.WebsocketUpgradeResponse], T)): Pair[Future[WebsocketUpgradeResponse], T] = + result match { + case (fut, tMat) ⇒ Pair(adaptWsUpgradeResponse(fut), tMat) + } + private def adaptWsUpgradeResponse(responseFuture: Future[scaladsl.Http.WebsocketUpgradeResponse]): Future[WebsocketUpgradeResponse] = + responseFuture.map(WebsocketUpgradeResponse.adapt)(system.dispatcher) } diff --git a/akka-http-core/src/main/scala/akka/http/javadsl/model/ws/WebsocketRequest.scala b/akka-http-core/src/main/scala/akka/http/javadsl/model/ws/WebsocketRequest.scala new file mode 100644 index 0000000000..d895cb6e56 --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/javadsl/model/ws/WebsocketRequest.scala @@ -0,0 +1,61 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.javadsl.model.ws + +import akka.http.javadsl.model.{ Uri, HttpHeader } +import akka.http.scaladsl.model.ws.{ WebsocketRequest ⇒ ScalaWebsocketRequest } + +/** + * Represents a Websocket request. Use `WebsocketRequest.create` to create a request + * for a target URI and then use `addHeader` or `requestSubprotocol` to set optional + * details. + */ +abstract class WebsocketRequest { + /** + * Return a copy of this request that contains the given additional header. + */ + def addHeader(header: HttpHeader): WebsocketRequest + + /** + * Return a copy of this request that will require that the server uses the + * given Websocket subprotocol. + */ + def requestSubprotocol(subprotocol: String): WebsocketRequest + + def asScala: ScalaWebsocketRequest +} +object WebsocketRequest { + import akka.http.impl.util.JavaMapping.Implicits._ + + /** + * Creates a WebsocketRequest to a target URI. Use the methods on `WebsocketRequest` + * to specify further details. + */ + def create(uri: Uri): WebsocketRequest = + wrap(ScalaWebsocketRequest(uri.asScala)) + + /** + * Creates a WebsocketRequest to a target URI. Use the methods on `WebsocketRequest` + * to specify further details. + */ + def create(uriString: String): WebsocketRequest = + create(Uri.create(uriString)) + + /** + * Wraps a Scala version of WebsocketRequest. + */ + def wrap(scalaRequest: ScalaWebsocketRequest): WebsocketRequest = + new WebsocketRequest { + def addHeader(header: HttpHeader): WebsocketRequest = + transform(s ⇒ s.copy(extraHeaders = s.extraHeaders :+ header.asScala)) + def requestSubprotocol(subprotocol: String): WebsocketRequest = + transform(_.copy(subprotocol = Some(subprotocol))) + + def asScala: ScalaWebsocketRequest = scalaRequest + + def transform(f: ScalaWebsocketRequest ⇒ ScalaWebsocketRequest): WebsocketRequest = + wrap(f(asScala)) + } +} diff --git a/akka-http-core/src/main/scala/akka/http/javadsl/model/ws/WebsocketUpgradeResponse.scala b/akka-http-core/src/main/scala/akka/http/javadsl/model/ws/WebsocketUpgradeResponse.scala new file mode 100644 index 0000000000..f6490d21a0 --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/javadsl/model/ws/WebsocketUpgradeResponse.scala @@ -0,0 +1,58 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.javadsl.model.ws + +import akka.http.javadsl.model.HttpResponse +import akka.http.scaladsl +import akka.http.scaladsl.Http.{ InvalidUpgradeResponse, ValidUpgrade } +import akka.japi.Option + +/** + * Represents an upgrade response for a Websocket upgrade request. Can either be valid, in which + * case the `chosenSubprotocol` method is valid, or if invalid, the `invalidationReason` method + * can be used to find out why the upgrade failed. + */ +trait WebsocketUpgradeResponse { + def isValid: Boolean + + /** + * Returns the response object as received from the server for further inspection. + */ + def response: HttpResponse + + /** + * If valid, returns `Some(subprotocol)` (if any was requested), or `None` if none was + * chosen or offered. + */ + def chosenSubprotocol: Option[String] + + /** + * If invalid, the reason why the server's upgrade response could not be accepted. + */ + def invalidationReason: String +} + +object WebsocketUpgradeResponse { + import akka.http.impl.util.JavaMapping.Implicits._ + def adapt(scalaResponse: scaladsl.Http.WebsocketUpgradeResponse): WebsocketUpgradeResponse = + scalaResponse match { + case ValidUpgrade(response, chosen) ⇒ + new WebsocketUpgradeResponse { + def isValid: Boolean = true + def response: HttpResponse = response + def chosenSubprotocol: Option[String] = chosen.asJava + def invalidationReason: String = + throw new UnsupportedOperationException("invalidationReason must not be called for valid response") + } + case InvalidUpgradeResponse(response, cause) ⇒ + new WebsocketUpgradeResponse { + def isValid: Boolean = false + def response: HttpResponse = response + def chosenSubprotocol: Option[String] = throw new UnsupportedOperationException("chosenSubprotocol must not be called for valid response") + def invalidationReason: String = cause + } + } + +} \ No newline at end of file diff --git a/akka-http-core/src/test/java/akka/http/javadsl/WSEchoTestClientApp.java b/akka-http-core/src/test/java/akka/http/javadsl/WSEchoTestClientApp.java new file mode 100644 index 0000000000..78d0e2cd30 --- /dev/null +++ b/akka-http-core/src/test/java/akka/http/javadsl/WSEchoTestClientApp.java @@ -0,0 +1,83 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.javadsl; + +import akka.actor.ActorSystem; +import akka.dispatch.Futures; +import akka.http.javadsl.model.ws.Message; +import akka.http.javadsl.model.ws.TextMessage; +import akka.http.javadsl.model.ws.WebsocketRequest; +import akka.japi.function.Function; +import akka.stream.ActorMaterializer; +import akka.stream.Materializer; +import akka.stream.javadsl.Flow; +import akka.stream.javadsl.Keep; +import akka.stream.javadsl.Sink; +import akka.stream.javadsl.Source; +import scala.concurrent.Await; +import scala.concurrent.Future; +import scala.concurrent.duration.FiniteDuration; +import scala.runtime.BoxedUnit; + +import java.util.Arrays; +import java.util.List; + +public class WSEchoTestClientApp { + private static final Function messageStringifier = new Function() { + @Override + public String apply(Message msg) throws Exception { + if (msg.isText() && msg.asTextMessage().isStrict()) + return msg.asTextMessage().getStrictText(); + else + throw new IllegalArgumentException("Unexpected message "+msg); + } + }; + + public static void main(String[] args) throws Exception { + ActorSystem system = ActorSystem.create(); + + try { + final Materializer materializer = ActorMaterializer.create(system); + + final Future ignoredMessage = Futures.successful((Message) TextMessage.create("blub")); + final Future delayedCompletion = + akka.pattern.Patterns.after( + FiniteDuration.apply(1, "second"), + system.scheduler(), + system.dispatcher(), + ignoredMessage); + + Source echoSource = + Source.from(Arrays.asList( + TextMessage.create("abc"), + TextMessage.create("def"), + TextMessage.create("ghi") + )).concat(Source.from(delayedCompletion).drop(1)); + + Sink>> echoSink = + Flow.of(Message.class) + .map(messageStringifier) + .grouped(1000) + .toMat(Sink.>head(), Keep.>>right()); + + Flow>> echoClient = + Flow.wrap(echoSink, echoSource, Keep.>, BoxedUnit>left()); + + Future> result = + Http.get(system).singleWebsocketRequest( + WebsocketRequest.create("ws://echo.websocket.org"), + echoClient, + materializer + ).second(); + + List messages = Await.result(result, FiniteDuration.apply(10, "second")); + System.out.println("Collected " + messages.size() + " messages:"); + for (String msg: messages) + System.out.println(msg); + } finally { + system.shutdown(); + } + } +} From f6732f33691d64c14a4da120ffe22ae1fd41c9dd Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Mon, 19 Oct 2015 09:08:36 +0200 Subject: [PATCH 09/10] =htc server side handshake cleanup and clarifications --- .../engine/parsing/HttpRequestParser.scala | 2 +- .../akka/http/impl/engine/ws/Handshake.scala | 92 ++++++++++--------- .../impl/engine/ws/WebsocketServerSpec.scala | 1 + 3 files changed, 52 insertions(+), 43 deletions(-) diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpRequestParser.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpRequestParser.scala index 60fb35d7c0..f69aebea9e 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpRequestParser.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpRequestParser.scala @@ -129,7 +129,7 @@ private[http] class HttpRequestParser(_settings: ParserSettings, val allHeaders = if (method == HttpMethods.GET) { - Handshake.Server.isWebsocketUpgrade(headers, hostHeaderPresent) match { + Handshake.Server.websocketUpgrade(headers, hostHeaderPresent) match { case Some(upgrade) ⇒ upgrade :: allHeaders0 case None ⇒ allHeaders0 } diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Handshake.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Handshake.scala index af09a2cf46..014fca5a24 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Handshake.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Handshake.scala @@ -6,6 +6,8 @@ package akka.http.impl.engine.ws import java.util.Random +import akka.http.impl.engine.parsing.ParserOutput.MessageStartError + import scala.collection.immutable import scala.collection.immutable.Seq import scala.reflect.ClassTag @@ -29,57 +31,64 @@ private[http] object Handshake { val CurrentWebsocketVersion = 13 object Server { - /* - From: http://tools.ietf.org/html/rfc6455#section-4.2.1 - - 1. An HTTP/1.1 or higher GET request, including a "Request-URI" - [RFC2616] that should be interpreted as a /resource name/ - defined in Section 3 (or an absolute HTTP/HTTPS URI containing - the /resource name/). - - 2. A |Host| header field containing the server's authority. - - 3. An |Upgrade| header field containing the value "websocket", - treated as an ASCII case-insensitive value. - - 4. A |Connection| header field that includes the token "Upgrade", - treated as an ASCII case-insensitive value. - - 5. A |Sec-WebSocket-Key| header field with a base64-encoded (see - Section 4 of [RFC4648]) value that, when decoded, is 16 bytes in - length. - - 6. A |Sec-WebSocket-Version| header field, with a value of 13. - - 7. Optionally, an |Origin| header field. This header field is sent - by all browser clients. A connection attempt lacking this - header field SHOULD NOT be interpreted as coming from a browser - client. - - 8. Optionally, a |Sec-WebSocket-Protocol| header field, with a list - of values indicating which protocols the client would like to - speak, ordered by preference. - - 9. Optionally, a |Sec-WebSocket-Extensions| header field, with a - list of values indicating which extensions the client would like - to speak. The interpretation of this header field is discussed - in Section 9.1. - */ - def isWebsocketUpgrade(headers: List[HttpHeader], hostHeaderPresent: Boolean): Option[UpgradeToWebsocket] = { + /** + * Validates a client Websocket handshake. Returns either `Right(UpgradeToWebsocket)` or + * `Left(MessageStartError)`. + * + * From: http://tools.ietf.org/html/rfc6455#section-4.2.1 + * + * 1. An HTTP/1.1 or higher GET request, including a "Request-URI" + * [RFC2616] that should be interpreted as a /resource name/ + * defined in Section 3 (or an absolute HTTP/HTTPS URI containing + * the /resource name/). + * + * 2. A |Host| header field containing the server's authority. + * + * 3. An |Upgrade| header field containing the value "websocket", + * treated as an ASCII case-insensitive value. + * + * 4. A |Connection| header field that includes the token "Upgrade", + * treated as an ASCII case-insensitive value. + * + * 5. A |Sec-WebSocket-Key| header field with a base64-encoded (see + * Section 4 of [RFC4648]) value that, when decoded, is 16 bytes in + * length. + * + * 6. A |Sec-WebSocket-Version| header field, with a value of 13. + * + * 7. Optionally, an |Origin| header field. This header field is sent + * by all browser clients. A connection attempt lacking this + * header field SHOULD NOT be interpreted as coming from a browser + * client. + * + * 8. Optionally, a |Sec-WebSocket-Protocol| header field, with a list + * of values indicating which protocols the client would like to + * speak, ordered by preference. + * + * 9. Optionally, a |Sec-WebSocket-Extensions| header field, with a + * list of values indicating which extensions the client would like + * to speak. The interpretation of this header field is discussed + * in Section 9.1. + */ + def websocketUpgrade(headers: List[HttpHeader], hostHeaderPresent: Boolean): Option[UpgradeToWebsocket] = { def find[T <: HttpHeader: ClassTag]: Option[T] = headers.collectFirst { case t: T ⇒ t } + // Host header is validated in general HTTP logic // val host = find[Host] val upgrade = find[Upgrade] val connection = find[Connection] val key = find[`Sec-WebSocket-Key`] val version = find[`Sec-WebSocket-Version`] + // Origin header is optional and, if required, should be validated + // on higher levels (routing, application logic) // val origin = find[Origin] val protocol = find[`Sec-WebSocket-Protocol`] - val supportedProtocols = protocol.toList.flatMap(_.protocols) - // FIXME: support extensions + val clientSupportedSubprotocols = protocol.toList.flatMap(_.protocols) + // Extension support is optional in WS and currently unsupported. + // FIXME See #18709 // val extensions = find[`Sec-WebSocket-Extensions`] def isValidKey(key: String): Boolean = Base64.rfc2045().decode(key).length == 16 @@ -90,10 +99,10 @@ private[http] object Handshake { key.exists(k ⇒ isValidKey(k.key))) { val header = new UpgradeToWebsocketLowLevel { - def requestedProtocols: Seq[String] = supportedProtocols + def requestedProtocols: Seq[String] = clientSupportedSubprotocols def handle(handler: Either[Flow[FrameEvent, FrameEvent, Any], Flow[Message, Message, Any]], subprotocol: Option[String]): HttpResponse = { - require(subprotocol.forall(chosen ⇒ supportedProtocols.contains(chosen)), + require(subprotocol.forall(chosen ⇒ clientSupportedSubprotocols.contains(chosen)), s"Tried to choose invalid subprotocol '$subprotocol' which wasn't offered by the client: [${requestedProtocols.mkString(", ")}]") buildResponse(key.get, handler, subprotocol) } @@ -228,7 +237,6 @@ private[http] object Handshake { } def compare(candidate: HttpHeader, caseInsensitive: Boolean): Option[HttpHeader] ⇒ Boolean = { - case Some(`candidate`) if !caseInsensitive ⇒ true case Some(header) if caseInsensitive && candidate.value.toRootLowerCase == header.value.toRootLowerCase ⇒ true case _ ⇒ false diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketServerSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketServerSpec.scala index 7b25f9002b..e8b8755d72 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketServerSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketServerSpec.scala @@ -116,6 +116,7 @@ class WebsocketServerSpec extends FreeSpec with Matchers with WithMaterializerSp } "prevent the selection of an unavailable subprotocol" in pending "reject invalid Websocket handshakes" - { + "missing `Upgrade: websocket` header" in pending "missing `Connection: upgrade` header" in pending "missing `Sec-WebSocket-Key header" in pending "`Sec-WebSocket-Key` with wrong amount of base64 encoded data" in pending From ddc8cd804bc9a58e12a39b9ca793c716b0821b56 Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Mon, 19 Oct 2015 09:17:06 +0200 Subject: [PATCH 10/10] =htc move `Sec-WebSocket-Key` creation/validation to header model --- .../akka/http/impl/engine/ws/Handshake.scala | 10 ++-------- .../http/scaladsl/model/headers/headers.scala | 20 ++++++++++++++++--- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Handshake.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Handshake.scala index 014fca5a24..e0909b08f8 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Handshake.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Handshake.scala @@ -6,14 +6,10 @@ package akka.http.impl.engine.ws import java.util.Random -import akka.http.impl.engine.parsing.ParserOutput.MessageStartError - import scala.collection.immutable import scala.collection.immutable.Seq import scala.reflect.ClassTag -import akka.parboiled2.util.Base64 - import akka.stream.scaladsl.Flow import akka.http.impl.util._ @@ -91,12 +87,10 @@ private[http] object Handshake { // FIXME See #18709 // val extensions = find[`Sec-WebSocket-Extensions`] - def isValidKey(key: String): Boolean = Base64.rfc2045().decode(key).length == 16 - if (upgrade.exists(_.hasWebsocket) && connection.exists(_.hasUpgrade) && version.exists(_.hasVersion(CurrentWebsocketVersion)) && - key.exists(k ⇒ isValidKey(k.key))) { + key.exists(k ⇒ k.isValid)) { val header = new UpgradeToWebsocketLowLevel { def requestedProtocols: Seq[String] = clientSupportedSubprotocols @@ -156,7 +150,7 @@ private[http] object Handshake { def buildRequest(uri: Uri, extraHeaders: immutable.Seq[HttpHeader], subprotocols: Seq[String], random: Random): (HttpRequest, `Sec-WebSocket-Key`) = { val keyBytes = new Array[Byte](16) random.nextBytes(keyBytes) - val key = `Sec-WebSocket-Key`(Base64.rfc2045().encodeToString(keyBytes, false)) + val key = `Sec-WebSocket-Key`(keyBytes) val protocol = if (subprotocols.nonEmpty) `Sec-WebSocket-Protocol`(subprotocols) :: Nil else Nil diff --git a/akka-http-core/src/main/scala/akka/http/scaladsl/model/headers/headers.scala b/akka-http-core/src/main/scala/akka/http/scaladsl/model/headers/headers.scala index 41092abab4..46fbb83522 100644 --- a/akka-http-core/src/main/scala/akka/http/scaladsl/model/headers/headers.scala +++ b/akka-http-core/src/main/scala/akka/http/scaladsl/model/headers/headers.scala @@ -8,15 +8,18 @@ import java.lang.Iterable import java.net.InetSocketAddress import java.security.MessageDigest import java.util + +import scala.reflect.ClassTag +import scala.util.Try import scala.annotation.tailrec import scala.collection.immutable + import akka.parboiled2.util.Base64 + import akka.http.impl.util._ import akka.http.javadsl.{ model ⇒ jm } import akka.http.scaladsl.model._ -import scala.reflect.ClassTag - sealed abstract class ModeledCompanion[T: ClassTag] extends Renderable { val name = getClass.getSimpleName.replace("$minus", "-").dropRight(1) // trailing $ val lowercaseName = name.toRootLowerCase @@ -627,7 +630,12 @@ private[http] final case class `Sec-WebSocket-Extensions`(extensions: immutable. /** * INTERNAL API */ -private[http] object `Sec-WebSocket-Key` extends ModeledCompanion[`Sec-WebSocket-Key`] +private[http] object `Sec-WebSocket-Key` extends ModeledCompanion[`Sec-WebSocket-Key`] { + def apply(keyBytes: Array[Byte]): `Sec-WebSocket-Key` = { + require(keyBytes.length == 16, s"Sec-WebSocket-Key keyBytes must have length 16 but had ${keyBytes.length}") + `Sec-WebSocket-Key`(Base64.rfc2045().encodeToString(keyBytes, false)) + } +} /** * INTERNAL API */ @@ -635,6 +643,12 @@ private[http] final case class `Sec-WebSocket-Key`(key: String) extends ModeledH protected[http] def renderValue[R <: Rendering](r: R): r.type = r ~~ key protected def companion = `Sec-WebSocket-Key` + + /** + * Checks if the key value is valid according to the Websocket specification, i.e. + * if the String is a Base64 representation of 16 bytes. + */ + def isValid: Boolean = Try(Base64.rfc2045().decode(key)).toOption.exists(_.length == 16) } // http://tools.ietf.org/html/rfc6455#section-4.3