Merge pull request #18693 from spray/w/17275-squashed

#17275 Websocket Client implementation
This commit is contained in:
Konrad Malawski 2015-10-19 13:26:09 +02:00
commit b8ea349c3a
36 changed files with 2024 additions and 323 deletions

View file

@ -4,6 +4,9 @@
package akka.http
import java.util.Random
import akka.http.impl.engine.ws.Randoms
import akka.io.Inet.SocketOption
import scala.concurrent.duration.{ Duration, FiniteDuration }
@ -21,6 +24,7 @@ final case class ClientConnectionSettings(
connectingTimeout: FiniteDuration,
idleTimeout: Duration,
requestHeaderSizeHint: Int,
websocketRandomFactory: () Random,
socketOptions: immutable.Traversable[SocketOption],
parserSettings: ParserSettings) {
@ -36,6 +40,7 @@ object ClientConnectionSettings extends SettingsCompanion[ClientConnectionSettin
c getFiniteDuration "connecting-timeout",
c getPotentiallyInfiniteDuration "idle-timeout",
c getIntBytes "request-header-size-hint",
Randoms.SecureRandomInstances, // can currently only be overridden from code
SocketOptionSettings.fromSubConfig(root, c.getConfig("socket-options")),
ParserSettings.fromSubConfig(root, c.getConfig("parsing")))
}

View file

@ -4,6 +4,9 @@
package akka.http
import java.util.Random
import akka.http.impl.engine.ws.Randoms
import com.typesafe.config.Config
import scala.language.implicitConversions
@ -31,6 +34,7 @@ final case class ServerSettings(
backlog: Int,
socketOptions: immutable.Traversable[SocketOption],
defaultHostHeader: Host,
websocketRandomFactory: () Random,
parserSettings: ParserSettings) {
require(0 < maxConnections, "max-connections must be > 0")
@ -65,6 +69,7 @@ object ServerSettings extends SettingsCompanion[ServerSettings]("akka.http.serve
val info = result.errors.head.withSummary("Configured `default-host-header` is illegal")
throw new ConfigurationException(info.formatPretty)
},
Randoms.SecureRandomInstances, // can currently only be overridden from code
ParserSettings.fromSubConfig(root, c.getConfig("parsing")))
def apply(optionalSettings: Option[ServerSettings])(implicit actorRefFactory: ActorRefFactory): ServerSettings =

View file

@ -11,7 +11,7 @@ import scala.annotation.tailrec
import akka.actor.ActorRef
import akka.stream.stage.{ Context, PushPullStage }
import akka.stream.scaladsl.Flow
import akka.stream.scaladsl.{ Keep, Source }
import akka.stream.scaladsl.Source
import akka.util.ByteString
import akka.http.impl.engine.ws.Handshake
import akka.http.impl.model.parser.CharacterClasses
@ -129,7 +129,7 @@ private[http] class HttpRequestParser(_settings: ParserSettings,
val allHeaders =
if (method == HttpMethods.GET) {
Handshake.isWebsocketUpgrade(headers, hostHeaderPresent) match {
Handshake.Server.websocketUpgrade(headers, hostHeaderPresent) match {
case Some(upgrade) upgrade :: allHeaders0
case None allHeaders0
}

View file

@ -56,4 +56,6 @@ private[http] object ParserOutput {
case object NeedMoreData extends MessageOutput
case object NeedNextRequestMethod extends ResponseOutput
final case class RemainingBytes(bytes: ByteString) extends ResponseOutput
}

View file

@ -5,6 +5,7 @@
package akka.http.impl.engine.rendering
import akka.http.impl.engine.ws.{ FrameEvent, UpgradeToWebsocketResponseHeader }
import akka.http.scaladsl.model.ws.Message
import scala.annotation.tailrec
import akka.event.LoggingAdapter
@ -159,7 +160,7 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser
r ~~ connHeader ~~ CrLf
headers
.collectFirst { case u: UpgradeToWebsocketResponseHeader u }
.foreach { header closeMode = SwitchToWebsocket(header.handlerFlow) }
.foreach { header closeMode = SwitchToWebsocket(header.handler) }
}
if (mustRenderTransferEncodingChunkedHeader && !transferEncodingSeen)
r ~~ `Transfer-Encoding` ~~ ChunkedBytes ~~ CrLf
@ -220,7 +221,7 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser
sealed trait CloseMode
case object DontClose extends CloseMode
case object CloseConnection extends CloseMode
case class SwitchToWebsocket(handlerFlow: Flow[FrameEvent, FrameEvent, Any]) extends CloseMode
case class SwitchToWebsocket(handler: Either[Flow[FrameEvent, FrameEvent, Any], Flow[Message, Message, Any]]) extends CloseMode
}
/**
@ -237,5 +238,5 @@ private[http] sealed trait ResponseRenderingOutput
/** INTERNAL API */
private[http] object ResponseRenderingOutput {
private[http] case class HttpData(bytes: ByteString) extends ResponseRenderingOutput
private[http] case class SwitchToWebsocket(httpResponseBytes: ByteString, handlerFlow: Flow[FrameEvent, FrameEvent, Any]) extends ResponseRenderingOutput
private[http] case class SwitchToWebsocket(httpResponseBytes: ByteString, handler: Either[Flow[FrameEvent, FrameEvent, Any], Flow[Message, Message, Any]]) extends ResponseRenderingOutput
}

View file

@ -5,8 +5,10 @@
package akka.http.impl.engine.server
import java.net.InetSocketAddress
import java.util.Random
import akka.http.ServerSettings
import akka.http.scaladsl.model.ws.Message
import akka.stream.io._
import org.reactivestreams.{ Subscriber, Publisher }
import scala.util.control.NonFatal
@ -145,7 +147,7 @@ private[http] object HttpServerBluePrint {
// protocol routing
val protocolRouter = b.add(new WebsocketSwitchRouter())
val protocolMerge = b.add(new WebsocketMerge(ws.installHandler))
val protocolMerge = b.add(new WebsocketMerge(ws.installHandler, settings.websocketRandomFactory))
protocolRouter.out0 ~> http ~> protocolMerge.in0
protocolRouter.out1 ~> websocket ~> protocolMerge.in1
@ -355,7 +357,7 @@ private[http] object HttpServerBluePrint {
}
}
}
class WebsocketMerge(installHandler: Flow[FrameEvent, FrameEvent, Any] Unit) extends FlexiMerge[ByteString, FanInShape2[ResponseRenderingOutput, ByteString, ByteString]](new FanInShape2("websocketMerge"), Attributes.name("websocketMerge")) {
class WebsocketMerge(installHandler: Flow[FrameEvent, FrameEvent, Any] Unit, websocketRandomFactory: () Random) extends FlexiMerge[ByteString, FanInShape2[ResponseRenderingOutput, ByteString, ByteString]](new FanInShape2("websocketMerge"), Attributes.name("websocketMerge")) {
def createMergeLogic(s: FanInShape2[ResponseRenderingOutput, ByteString, ByteString]): MergeLogic[ByteString] =
new MergeLogic[ByteString] {
var websocketHandlerWasInstalled: Boolean = false
@ -370,7 +372,13 @@ private[http] object HttpServerBluePrint {
ctx.emit(bytes); SameState
case ResponseRenderingOutput.SwitchToWebsocket(responseBytes, handlerFlow)
ctx.emit(responseBytes)
installHandler(handlerFlow)
val frameHandler = handlerFlow match {
case Left(frameHandler) frameHandler
case Right(messageHandler)
Websocket.stack(serverSide = true, maskingRandomFactory = websocketRandomFactory).join(messageHandler)
}
installHandler(frameHandler)
ctx.changeCompletionHandling(defaultCompletionHandling)
websocketHandlerWasInstalled = true
websocket

View file

@ -10,7 +10,7 @@ import scala.concurrent.duration.FiniteDuration
import akka.stream.stage._
import akka.http.impl.util.Timestamp
import FrameHandler.{ UserHandlerCompleted, ActivelyCloseWithCode, PeerClosed, DirectAnswer }
import akka.http.impl.engine.ws.FrameHandler._
import Websocket.Tick
/**

View file

@ -4,16 +4,20 @@
package akka.http.impl.engine.ws
import akka.http.scaladsl.model.headers._
import akka.http.scaladsl.model.ws.{ Message, UpgradeToWebsocket }
import akka.http.scaladsl.model.{ StatusCodes, HttpResponse, HttpProtocol, HttpHeader }
import akka.parboiled2.util.Base64
import akka.stream.Materializer
import akka.stream.scaladsl.Flow
import java.util.Random
import scala.collection.immutable
import scala.collection.immutable.Seq
import scala.reflect.ClassTag
import akka.stream.scaladsl.Flow
import akka.http.impl.util._
import akka.http.scaladsl.model.headers._
import akka.http.scaladsl.model.ws.{ Message, UpgradeToWebsocket }
import akka.http.scaladsl.model._
/**
* Server-side implementation of the Websocket handshake
*
@ -22,104 +26,242 @@ import scala.reflect.ClassTag
private[http] object Handshake {
val CurrentWebsocketVersion = 13
/*
From: http://tools.ietf.org/html/rfc6455#section-4.2.1
1. An HTTP/1.1 or higher GET request, including a "Request-URI"
[RFC2616] that should be interpreted as a /resource name/
defined in Section 3 (or an absolute HTTP/HTTPS URI containing
the /resource name/).
2. A |Host| header field containing the server's authority.
3. An |Upgrade| header field containing the value "websocket",
treated as an ASCII case-insensitive value.
4. A |Connection| header field that includes the token "Upgrade",
treated as an ASCII case-insensitive value.
5. A |Sec-WebSocket-Key| header field with a base64-encoded (see
Section 4 of [RFC4648]) value that, when decoded, is 16 bytes in
length.
6. A |Sec-WebSocket-Version| header field, with a value of 13.
7. Optionally, an |Origin| header field. This header field is sent
by all browser clients. A connection attempt lacking this
header field SHOULD NOT be interpreted as coming from a browser
client.
8. Optionally, a |Sec-WebSocket-Protocol| header field, with a list
of values indicating which protocols the client would like to
speak, ordered by preference.
9. Optionally, a |Sec-WebSocket-Extensions| header field, with a
list of values indicating which extensions the client would like
to speak. The interpretation of this header field is discussed
in Section 9.1.
*/
def isWebsocketUpgrade(headers: List[HttpHeader], hostHeaderPresent: Boolean): Option[UpgradeToWebsocket] = {
def find[T <: HttpHeader: ClassTag]: Option[T] =
headers.collectFirst {
case t: T t
}
val host = find[Host]
val upgrade = find[Upgrade]
val connection = find[Connection]
val key = find[`Sec-WebSocket-Key`]
val version = find[`Sec-WebSocket-Version`]
val origin = find[Origin]
val protocol = find[`Sec-WebSocket-Protocol`]
val supportedProtocols = protocol.toList.flatMap(_.protocols)
val extensions = find[`Sec-WebSocket-Extensions`]
def isValidKey(key: String): Boolean = Base64.rfc2045().decode(key).length == 16
if (upgrade.exists(_.hasWebsocket) &&
connection.exists(_.hasUpgrade) &&
version.exists(_.hasVersion(CurrentWebsocketVersion)) &&
key.exists(k isValidKey(k.key))) {
val header = new UpgradeToWebsocketLowLevel {
def requestedProtocols: Seq[String] = supportedProtocols
def handleFrames(handlerFlow: Flow[FrameEvent, FrameEvent, Any], subprotocol: Option[String]): HttpResponse = {
require(subprotocol.forall(chosen supportedProtocols.contains(chosen)),
s"Tried to choose invalid subprotocol '$subprotocol' which wasn't offered by the client: [${requestedProtocols.mkString(", ")}]")
buildResponse(key.get, handlerFlow, subprotocol)
object Server {
/**
* Validates a client Websocket handshake. Returns either `Right(UpgradeToWebsocket)` or
* `Left(MessageStartError)`.
*
* From: http://tools.ietf.org/html/rfc6455#section-4.2.1
*
* 1. An HTTP/1.1 or higher GET request, including a "Request-URI"
* [RFC2616] that should be interpreted as a /resource name/
* defined in Section 3 (or an absolute HTTP/HTTPS URI containing
* the /resource name/).
*
* 2. A |Host| header field containing the server's authority.
*
* 3. An |Upgrade| header field containing the value "websocket",
* treated as an ASCII case-insensitive value.
*
* 4. A |Connection| header field that includes the token "Upgrade",
* treated as an ASCII case-insensitive value.
*
* 5. A |Sec-WebSocket-Key| header field with a base64-encoded (see
* Section 4 of [RFC4648]) value that, when decoded, is 16 bytes in
* length.
*
* 6. A |Sec-WebSocket-Version| header field, with a value of 13.
*
* 7. Optionally, an |Origin| header field. This header field is sent
* by all browser clients. A connection attempt lacking this
* header field SHOULD NOT be interpreted as coming from a browser
* client.
*
* 8. Optionally, a |Sec-WebSocket-Protocol| header field, with a list
* of values indicating which protocols the client would like to
* speak, ordered by preference.
*
* 9. Optionally, a |Sec-WebSocket-Extensions| header field, with a
* list of values indicating which extensions the client would like
* to speak. The interpretation of this header field is discussed
* in Section 9.1.
*/
def websocketUpgrade(headers: List[HttpHeader], hostHeaderPresent: Boolean): Option[UpgradeToWebsocket] = {
def find[T <: HttpHeader: ClassTag]: Option[T] =
headers.collectFirst {
case t: T t
}
}
Some(header)
} else None
// Host header is validated in general HTTP logic
// val host = find[Host]
val upgrade = find[Upgrade]
val connection = find[Connection]
val key = find[`Sec-WebSocket-Key`]
val version = find[`Sec-WebSocket-Version`]
// Origin header is optional and, if required, should be validated
// on higher levels (routing, application logic)
// val origin = find[Origin]
val protocol = find[`Sec-WebSocket-Protocol`]
val clientSupportedSubprotocols = protocol.toList.flatMap(_.protocols)
// Extension support is optional in WS and currently unsupported.
// FIXME See #18709
// val extensions = find[`Sec-WebSocket-Extensions`]
if (upgrade.exists(_.hasWebsocket) &&
connection.exists(_.hasUpgrade) &&
version.exists(_.hasVersion(CurrentWebsocketVersion)) &&
key.exists(k k.isValid)) {
val header = new UpgradeToWebsocketLowLevel {
def requestedProtocols: Seq[String] = clientSupportedSubprotocols
def handle(handler: Either[Flow[FrameEvent, FrameEvent, Any], Flow[Message, Message, Any]], subprotocol: Option[String]): HttpResponse = {
require(subprotocol.forall(chosen clientSupportedSubprotocols.contains(chosen)),
s"Tried to choose invalid subprotocol '$subprotocol' which wasn't offered by the client: [${requestedProtocols.mkString(", ")}]")
buildResponse(key.get, handler, subprotocol)
}
def handleFrames(handlerFlow: Flow[FrameEvent, FrameEvent, Any], subprotocol: Option[String]): HttpResponse =
handle(Left(handlerFlow), subprotocol)
override def handleMessages(handlerFlow: Flow[Message, Message, Any], subprotocol: Option[String] = None): HttpResponse =
handle(Right(handlerFlow), subprotocol)
}
Some(header)
} else None
}
/*
From: http://tools.ietf.org/html/rfc6455#section-4.2.2
1. A Status-Line with a 101 response code as per RFC 2616
[RFC2616]. Such a response could look like "HTTP/1.1 101
Switching Protocols".
2. An |Upgrade| header field with value "websocket" as per RFC
2616 [RFC2616].
3. A |Connection| header field with value "Upgrade".
4. A |Sec-WebSocket-Accept| header field. The value of this
header field is constructed by concatenating /key/, defined
above in step 4 in Section 4.2.2, with the string "258EAFA5-
E914-47DA-95CA-C5AB0DC85B11", taking the SHA-1 hash of this
concatenated value to obtain a 20-byte value and base64-
encoding (see Section 4 of [RFC4648]) this 20-byte hash.
*/
def buildResponse(key: `Sec-WebSocket-Key`, handler: Either[Flow[FrameEvent, FrameEvent, Any], Flow[Message, Message, Any]], subprotocol: Option[String]): HttpResponse =
HttpResponse(
StatusCodes.SwitchingProtocols,
subprotocol.map(p `Sec-WebSocket-Protocol`(Seq(p))).toList :::
List(
UpgradeHeader,
ConnectionUpgradeHeader,
`Sec-WebSocket-Accept`.forKey(key),
UpgradeToWebsocketResponseHeader(handler)))
}
/*
From: http://tools.ietf.org/html/rfc6455#section-4.2.2
object Client {
case class NegotiatedWebsocketSettings(subprotocol: Option[String])
1. A Status-Line with a 101 response code as per RFC 2616
[RFC2616]. Such a response could look like "HTTP/1.1 101
Switching Protocols".
/**
* Builds a WebSocket handshake request.
*/
def buildRequest(uri: Uri, extraHeaders: immutable.Seq[HttpHeader], subprotocols: Seq[String], random: Random): (HttpRequest, `Sec-WebSocket-Key`) = {
val keyBytes = new Array[Byte](16)
random.nextBytes(keyBytes)
val key = `Sec-WebSocket-Key`(keyBytes)
val protocol =
if (subprotocols.nonEmpty) `Sec-WebSocket-Protocol`(subprotocols) :: Nil
else Nil
//version, protocol, extensions, origin
2. An |Upgrade| header field with value "websocket" as per RFC
2616 [RFC2616].
val headers = Seq(
UpgradeHeader,
ConnectionUpgradeHeader,
key,
SecWebsocketVersionHeader) ++ protocol ++ extraHeaders
3. A |Connection| header field with value "Upgrade".
(HttpRequest(HttpMethods.GET, uri.toRelative, headers), key)
}
4. A |Sec-WebSocket-Accept| header field. The value of this
header field is constructed by concatenating /key/, defined
above in step 4 in Section 4.2.2, with the string "258EAFA5-
E914-47DA-95CA-C5AB0DC85B11", taking the SHA-1 hash of this
concatenated value to obtain a 20-byte value and base64-
encoding (see Section 4 of [RFC4648]) this 20-byte hash.
*/
def buildResponse(key: `Sec-WebSocket-Key`, handlerFlow: Flow[FrameEvent, FrameEvent, Any], subprotocol: Option[String]): HttpResponse =
HttpResponse(
StatusCodes.SwitchingProtocols,
subprotocol.map(p `Sec-WebSocket-Protocol`(Seq(p))).toList :::
List(
Upgrade(List(UpgradeProtocol("websocket"))),
Connection(List("upgrade")),
`Sec-WebSocket-Accept`.forKey(key),
UpgradeToWebsocketResponseHeader(handlerFlow)))
/**
* Tries to validate the HTTP response. Returns either Right(settings) or an error message if
* the response cannot be validated.
*/
def validateResponse(response: HttpResponse, subprotocols: Seq[String], key: `Sec-WebSocket-Key`): Either[String, NegotiatedWebsocketSettings] = {
/*
From http://tools.ietf.org/html/rfc6455#section-4.1
1. If the status code received from the server is not 101, the
client handles the response per HTTP [RFC2616] procedures. In
particular, the client might perform authentication if it
receives a 401 status code; the server might redirect the client
using a 3xx status code (but clients are not required to follow
them), etc. Otherwise, proceed as follows.
2. If the response lacks an |Upgrade| header field or the |Upgrade|
header field contains a value that is not an ASCII case-
insensitive match for the value "websocket", the client MUST
_Fail the WebSocket Connection_.
3. If the response lacks a |Connection| header field or the
|Connection| header field doesn't contain a token that is an
ASCII case-insensitive match for the value "Upgrade", the client
MUST _Fail the WebSocket Connection_.
4. If the response lacks a |Sec-WebSocket-Accept| header field or
the |Sec-WebSocket-Accept| contains a value other than the
base64-encoded SHA-1 of the concatenation of the |Sec-WebSocket-
Key| (as a string, not base64-decoded) with the string "258EAFA5-
E914-47DA-95CA-C5AB0DC85B11" but ignoring any leading and
trailing whitespace, the client MUST _Fail the WebSocket
Connection_.
5. If the response includes a |Sec-WebSocket-Extensions| header
field and this header field indicates the use of an extension
that was not present in the client's handshake (the server has
indicated an extension not requested by the client), the client
MUST _Fail the WebSocket Connection_. (The parsing of this
header field to determine which extensions are requested is
discussed in Section 9.1.)
6. If the response includes a |Sec-WebSocket-Protocol| header field
and this header field indicates the use of a subprotocol that was
not present in the client's handshake (the server has indicated a
subprotocol not requested by the client), the client MUST _Fail
the WebSocket Connection_.
*/
trait Expectation extends (HttpResponse Option[String]) { outer
def &&(other: HttpResponse Option[String]): Expectation =
new Expectation {
def apply(v1: HttpResponse): Option[String] =
outer(v1).orElse(other(v1))
}
}
def check[T](value: HttpResponse T)(condition: T Boolean, msg: T String): Expectation =
new Expectation {
def apply(resp: HttpResponse): Option[String] = {
val v = value(resp)
if (condition(v)) None
else Some(msg(v))
}
}
def compare(candidate: HttpHeader, caseInsensitive: Boolean): Option[HttpHeader] Boolean = {
case Some(`candidate`) if !caseInsensitive true
case Some(header) if caseInsensitive && candidate.value.toRootLowerCase == header.value.toRootLowerCase true
case _ false
}
def headerExists(candidate: HttpHeader, showExactOther: Boolean = true, caseInsensitive: Boolean = false): Expectation =
check(_.headers.find(_.name == candidate.name))(compare(candidate, caseInsensitive), {
case Some(other) if showExactOther s"response that was missing required `$candidate` header. Found `$other` with the wrong value."
case Some(_) s"response with invalid `${candidate.name}` header."
case None s"response that was missing required `${candidate.name}` header."
})
val expectations: Expectation =
check(_.status)(_ == StatusCodes.SwitchingProtocols, "unexpected status code: " + _) &&
headerExists(UpgradeHeader, caseInsensitive = true) &&
headerExists(ConnectionUpgradeHeader, caseInsensitive = true) &&
headerExists(`Sec-WebSocket-Accept`.forKey(key), showExactOther = false)
expectations(response) match {
case None
val subs = response.header[`Sec-WebSocket-Protocol`].flatMap(_.protocols.headOption)
if (subprotocols.isEmpty && subs.isEmpty) Right(NegotiatedWebsocketSettings(None)) // no specific one selected
else if (subs.nonEmpty && subprotocols.contains(subs.get)) Right(NegotiatedWebsocketSettings(Some(subs.get)))
else Left(s"response that indicated that the given subprotocol was not supported. (client supported: ${subprotocols.mkString(", ")}, server supported: $subs)")
case Some(problem) Left(problem)
}
}
}
val UpgradeHeader = Upgrade(List(UpgradeProtocol("websocket")))
val ConnectionUpgradeHeader = Connection(List("upgrade"))
val SecWebsocketVersionHeader = `Sec-WebSocket-Version`(Seq(CurrentWebsocketVersion))
}

View file

@ -4,10 +4,10 @@
package akka.http.impl.engine.ws
import akka.stream.scaladsl.{ Keep, BidiFlow, Flow }
import akka.stream.stage.{ SyncDirective, Context, StageState, StatefulStage }
import java.util.Random
import scala.util.Random
import akka.stream.scaladsl.{ Keep, BidiFlow, Flow }
import akka.stream.stage.{ SyncDirective, Context, StatefulStage }
/**
* Implements Websocket Frame masking.
@ -51,12 +51,9 @@ private[http] object Masking {
def onPush(part: FrameEvent, ctx: Context[FrameEvent]): SyncDirective =
part match {
case start @ FrameStart(header, data)
if (header.length == 0) ctx.push(part)
else {
val mask = extractMask(header)
become(new Running(mask))
current.onPush(start.copy(header = setNewMask(header, mask)), ctx)
}
val mask = extractMask(header)
become(new Running(mask))
current.onPush(start.copy(header = setNewMask(header, mask)), ctx)
case _: FrameData
ctx.fail(new IllegalStateException("unexpected FrameData (need FrameStart first)"))
}

View file

@ -0,0 +1,15 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.impl.engine.ws
import java.security.SecureRandom
import java.util.Random
object Randoms {
/** A factory that creates SecureRandom instances */
private[http] case object SecureRandomInstances extends (() Random) {
override def apply(): Random = new SecureRandom()
}
}

View file

@ -26,7 +26,4 @@ private[http] abstract class UpgradeToWebsocketLowLevel extends InternalCustomHe
* INTERNAL API (for now)
*/
private[http] def handleFrames(handlerFlow: Flow[FrameEvent, FrameEvent, Any], subprotocol: Option[String] = None): HttpResponse
override def handleMessages(handlerFlow: Flow[Message, Message, Any], subprotocol: Option[String] = None): HttpResponse =
handleFrames(Websocket.stack(serverSide = true).join(handlerFlow), subprotocol)
}

View file

@ -5,12 +5,12 @@
package akka.http.impl.engine.ws
import akka.http.scaladsl.model.headers.CustomHeader
import akka.http.scaladsl.model.ws.Message
import akka.stream.Materializer
import akka.stream.scaladsl.Flow
private[http] case class UpgradeToWebsocketResponseHeader(handlerFlow: Flow[FrameEvent, FrameEvent, Any])
extends InternalCustomHeader("UpgradeToWebsocketResponseHeader") {
}
private[http] final case class UpgradeToWebsocketResponseHeader(handler: Either[Flow[FrameEvent, FrameEvent, Any], Flow[Message, Message, Any]])
extends InternalCustomHeader("UpgradeToWebsocketResponseHeader")
private[http] abstract class InternalCustomHeader(val name: String) extends CustomHeader {
override def suppressRendering: Boolean = true

View file

@ -4,7 +4,7 @@
package akka.http.impl.engine.ws
import java.security.SecureRandom
import java.util.Random
import akka.util.ByteString
@ -27,6 +27,16 @@ import akka.http.scaladsl.model.ws._
private[http] object Websocket {
import FrameHandler._
/**
* A stack of all the higher WS layers between raw frames and the user API.
*/
def stack(serverSide: Boolean,
maskingRandomFactory: () Random,
closeTimeout: FiniteDuration = 3.seconds): BidiFlow[FrameEvent, Message, Message, FrameEvent, Unit] =
masking(serverSide, maskingRandomFactory) atop
frameHandling(serverSide, closeTimeout) atop
messageAPI(serverSide, closeTimeout)
/** The lowest layer that implements the binary protocol */
def framing: BidiFlow[ByteString, FrameEvent, FrameEvent, ByteString, Unit] =
BidiFlow.wrap(
@ -35,8 +45,8 @@ private[http] object Websocket {
.named("ws-framing")
/** The layer that handles masking using the rules defined in the specification */
def masking(serverSide: Boolean): BidiFlow[FrameEvent, FrameEvent, FrameEvent, FrameEvent, Unit] =
Masking(serverSide, () new SecureRandom())
def masking(serverSide: Boolean, maskingRandomFactory: () Random): BidiFlow[FrameEvent, FrameEvent, FrameEvent, FrameEvent, Unit] =
Masking(serverSide, maskingRandomFactory)
.named("ws-masking")
/**
@ -219,12 +229,6 @@ private[http] object Websocket {
}.named("ws-message-api")
}
def stack(serverSide: Boolean = true,
closeTimeout: FiniteDuration = 3.seconds): BidiFlow[FrameEvent, Message, Message, FrameEvent, Unit] =
masking(serverSide) atop
frameHandling(serverSide, closeTimeout) atop
messageAPI(serverSide, closeTimeout)
object Tick
case object SwitchToWebsocketToken
}

View file

@ -0,0 +1,138 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.impl.engine.ws
import akka.http.scaladsl.model.ws.WebsocketRequest
import scala.collection.immutable
import scala.concurrent.{ Future, Promise }
import akka.util.ByteString
import akka.event.LoggingAdapter
import akka.stream.stage._
import akka.stream.BidiShape
import akka.stream.io.{ SessionBytes, SendBytes, SslTlsInbound }
import akka.stream.scaladsl._
import akka.http.ClientConnectionSettings
import akka.http.scaladsl.Http
import akka.http.scaladsl.Http.{ InvalidUpgradeResponse, ValidUpgrade, WebsocketUpgradeResponse }
import akka.http.scaladsl.model.{ HttpHeader, HttpResponse, HttpMethods, Uri }
import akka.http.scaladsl.model.headers.Host
import akka.http.impl.engine.parsing.HttpMessageParser.StateResult
import akka.http.impl.engine.parsing.ParserOutput.{ RemainingBytes, ResponseStart, NeedMoreData }
import akka.http.impl.engine.parsing.{ ParserOutput, HttpHeaderParser, HttpResponseParser }
import akka.http.impl.engine.rendering.{ HttpRequestRendererFactory, RequestRenderingContext }
import akka.http.impl.engine.ws.Handshake.Client.NegotiatedWebsocketSettings
import akka.http.impl.util.StreamUtils
object WebsocketClientBlueprint {
/**
* Returns a WebsocketClientLayer that can be materialized once.
*/
def apply(request: WebsocketRequest,
settings: ClientConnectionSettings,
log: LoggingAdapter): Http.WebsocketClientLayer =
(simpleTls.atopMat(handshake(request, settings, log))(Keep.right) atop
Websocket.framing atop
Websocket.stack(serverSide = false, maskingRandomFactory = settings.websocketRandomFactory)).reversed
/**
* A bidi flow that injects and inspects the WS handshake and then goes out of the way. This BidiFlow
* can only be materialized once.
*/
def handshake(request: WebsocketRequest,
settings: ClientConnectionSettings,
log: LoggingAdapter): BidiFlow[ByteString, ByteString, ByteString, ByteString, Future[WebsocketUpgradeResponse]] = {
import request._
val result = Promise[WebsocketUpgradeResponse]()
val valve = StreamUtils.OneTimeValve()
val (initialRequest, key) = Handshake.Client.buildRequest(uri, extraHeaders, subprotocol.toList, settings.websocketRandomFactory())
val hostHeader = Host(uri.authority)
val renderedInitialRequest =
HttpRequestRendererFactory.renderStrict(RequestRenderingContext(initialRequest, hostHeader), settings, log)
class UpgradeStage extends StatefulStage[ByteString, ByteString] {
type State = StageState[ByteString, ByteString]
def initial: State = parsingResponse
def parsingResponse: State = new State {
// a special version of the parser which only parses one message and then reports the remaining data
// if some is available
val parser = new HttpResponseParser(settings.parserSettings, HttpHeaderParser(settings.parserSettings)()) {
var first = true
override protected def parseMessage(input: ByteString, offset: Int): StateResult = {
if (first) {
first = false
super.parseMessage(input, offset)
} else {
emit(RemainingBytes(input.drop(offset)))
terminate()
}
}
}
parser.setRequestMethodForNextResponse(HttpMethods.GET)
def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective = {
parser.onPush(elem) match {
case NeedMoreData ctx.pull()
case ResponseStart(status, protocol, headers, entity, close)
val response = HttpResponse(status, headers, protocol = protocol)
Handshake.Client.validateResponse(response, subprotocol.toList, key) match {
case Right(NegotiatedWebsocketSettings(protocol))
result.success(ValidUpgrade(response, protocol))
become(transparent)
valve.open()
val parseResult = parser.onPull()
require(parseResult == ParserOutput.MessageEnd, s"parseResult should be MessageEnd but was $parseResult")
parser.onPull() match {
case NeedMoreData ctx.pull()
case RemainingBytes(bytes) ctx.push(bytes)
}
case Left(problem)
result.success(InvalidUpgradeResponse(response, s"Websocket server at $uri returned $problem"))
ctx.fail(throw new IllegalArgumentException(s"Websocket upgrade did not finish because of '$problem'"))
}
}
}
}
def transparent: State = new State {
def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective = ctx.push(elem)
}
}
BidiFlow() { implicit b
import FlowGraph.Implicits._
val networkIn = b.add(Flow[ByteString].transform(() new UpgradeStage))
val wsIn = b.add(Flow[ByteString])
val handshakeRequestSource = b.add(Source.single(renderedInitialRequest) ++ valve.source)
val httpRequestBytesAndThenWSBytes = b.add(Concat[ByteString]())
handshakeRequestSource ~> httpRequestBytesAndThenWSBytes
wsIn.outlet ~> httpRequestBytesAndThenWSBytes
BidiShape(
networkIn.inlet,
networkIn.outlet,
wsIn.inlet,
httpRequestBytesAndThenWSBytes.out)
} mapMaterializedValue (_ result.future)
}
def simpleTls: BidiFlow[SslTlsInbound, ByteString, ByteString, SendBytes, Unit] =
BidiFlow.wrap(
Flow[SslTlsInbound].collect { case SessionBytes(_, bytes) bytes },
Flow[ByteString].map(SendBytes))(Keep.none)
}

View file

@ -316,6 +316,23 @@ private[http] object StreamUtils {
}
Flow[T].transformMaterializing(newForeachStage)
}
/**
* Similar to Source.lazyEmpty but doesn't rely on materialization. Can only be used once.
*/
trait OneTimeValve {
def source[T]: Source[T, Unit]
def open(): Unit
}
object OneTimeValve {
def apply(): OneTimeValve = new OneTimeValve {
val promise = Promise[Unit]()
val _source = Source(promise.future).drop(1) // we are only interested in the completion event
def source[T]: Source[T, Unit] = _source.asInstanceOf[Source[T, Unit]] // safe, because source won't generate any elements
def open(): Unit = promise.success(())
}
}
}
/**

View file

@ -6,6 +6,8 @@ package akka.http.javadsl
import java.net.InetSocketAddress
import akka.http.impl.util.JavaMapping
import akka.http.javadsl.model.ws._
import akka.stream
import akka.stream.io.{ SslTlsInbound, SslTlsOutbound }
import scala.language.implicitConversions
@ -498,6 +500,93 @@ class Http(system: ExtendedActorSystem) extends akka.actor.Extension {
log: LoggingAdapter, materializer: Materializer): Future[HttpResponse] =
delegate.singleRequest(request.asScala, settings, httpsContext, log)(materializer)
/**
* Constructs a Websocket [[BidiFlow]].
*
* The layer is not reusable and must only be materialized once.
*/
def websocketClientLayer(request: WebsocketRequest): BidiFlow[Message, SslTlsOutbound, SslTlsInbound, Message, Future[WebsocketUpgradeResponse]] =
adaptWsBidiFlow(delegate.websocketClientLayer(request.asScala))
/**
* Constructs a Websocket [[BidiFlow]] using the configured default [[ClientConnectionSettings]],
* configured using the `akka.http.client` config section.
*
* The layer is not reusable and must only be materialized once.
*/
def websocketClientLayer(request: WebsocketRequest,
settings: ClientConnectionSettings): BidiFlow[Message, SslTlsOutbound, SslTlsInbound, Message, Future[WebsocketUpgradeResponse]] =
adaptWsBidiFlow(delegate.websocketClientLayer(request.asScala, settings))
/**
* Constructs a Websocket [[BidiFlow]] using the configured default [[ClientConnectionSettings]],
* configured using the `akka.http.client` config section.
*
* The layer is not reusable and must only be materialized once.
*/
def websocketClientLayer(request: WebsocketRequest,
settings: ClientConnectionSettings,
log: LoggingAdapter): BidiFlow[Message, SslTlsOutbound, SslTlsInbound, Message, Future[WebsocketUpgradeResponse]] =
adaptWsBidiFlow(delegate.websocketClientLayer(request.asScala, settings, log))
/**
* Constructs a flow that once materialized establishes a Websocket connection to the given Uri.
*
* The layer is not reusable and must only be materialized once.
*/
def websocketClientFlow(request: WebsocketRequest): Flow[Message, Message, Future[WebsocketUpgradeResponse]] =
adaptWsFlow {
delegate.websocketClientFlow(request.asScala)
}
/**
* Constructs a flow that once materialized establishes a Websocket connection to the given Uri.
*
* The layer is not reusable and must only be materialized once.
*/
def websocketClientFlow(request: WebsocketRequest,
localAddress: Option[InetSocketAddress],
settings: ClientConnectionSettings,
httpsContext: Option[HttpsContext],
log: LoggingAdapter): Flow[Message, Message, Future[WebsocketUpgradeResponse]] =
adaptWsFlow {
delegate.websocketClientFlow(request.asScala, localAddress, settings, httpsContext, log)
}
/**
* Runs a single Websocket conversation given a Uri and a flow that represents the client side of the
* Websocket conversation.
*/
def singleWebsocketRequest[T](request: WebsocketRequest,
clientFlow: Flow[Message, Message, T],
materializer: Materializer): Pair[Future[WebsocketUpgradeResponse], T] =
adaptWsResultTuple {
delegate.singleWebsocketRequest(
request.asScala,
adaptWsFlow[T](clientFlow))(materializer)
}
/**
* Runs a single Websocket conversation given a Uri and a flow that represents the client side of the
* Websocket conversation.
*/
def singleWebsocketRequest[T](request: WebsocketRequest,
clientFlow: Flow[Message, Message, T],
localAddress: Option[InetSocketAddress],
settings: ClientConnectionSettings,
httpsContext: Option[HttpsContext],
log: LoggingAdapter,
materializer: Materializer): Pair[Future[WebsocketUpgradeResponse], T] =
adaptWsResultTuple {
delegate.singleWebsocketRequest(
request.asScala,
adaptWsFlow[T](clientFlow),
localAddress,
settings,
httpsContext,
log)(materializer)
}
/**
* Triggers an orderly shutdown of all host connections pools currently maintained by the [[ActorSystem]].
* The returned future is completed when all pools that were live at the time of this method call
@ -519,7 +608,7 @@ class Http(system: ExtendedActorSystem) extends akka.actor.Extension {
def setDefaultClientHttpsContext(context: HttpsContext): Unit =
delegate.setDefaultClientHttpsContext(context.asInstanceOf[akka.http.scaladsl.HttpsContext])
private def adaptTupleFlow[T, Mat](scalaFlow: akka.stream.scaladsl.Flow[(scaladsl.model.HttpRequest, T), (Try[scaladsl.model.HttpResponse], T), Mat]): Flow[Pair[HttpRequest, T], Pair[Try[HttpResponse], T], Mat] = {
private def adaptTupleFlow[T, Mat](scalaFlow: stream.scaladsl.Flow[(scaladsl.model.HttpRequest, T), (Try[scaladsl.model.HttpResponse], T), Mat]): Flow[Pair[HttpRequest, T], Pair[Try[HttpResponse], T], Mat] = {
implicit val _ = JavaMapping.identity[T]
JavaMapping.toJava(scalaFlow)(JavaMapping.flowMapping[Pair[HttpRequest, T], (scaladsl.model.HttpRequest, T), Pair[Try[HttpResponse], T], (Try[scaladsl.model.HttpResponse], T), Mat])
}
@ -533,4 +622,25 @@ class Http(system: ExtendedActorSystem) extends akka.actor.Extension {
new BidiFlow(
JavaMapping.adapterBidiFlow[HttpRequest, sm.HttpRequest, sm.HttpResponse, HttpResponse]
.atop(clientLayer))
private def adaptWsBidiFlow(wsLayer: scaladsl.Http.WebsocketClientLayer): BidiFlow[Message, SslTlsOutbound, SslTlsInbound, Message, Future[WebsocketUpgradeResponse]] =
new BidiFlow(
JavaMapping.adapterBidiFlow[Message, sm.ws.Message, sm.ws.Message, Message]
.atopMat(wsLayer)((_, s) adaptWsUpgradeResponse(s)))
private def adaptWsFlow(wsLayer: stream.scaladsl.Flow[sm.ws.Message, sm.ws.Message, Future[scaladsl.Http.WebsocketUpgradeResponse]]): Flow[Message, Message, Future[WebsocketUpgradeResponse]] =
Flow.adapt(JavaMapping.adapterBidiFlow[Message, sm.ws.Message, sm.ws.Message, Message].joinMat(wsLayer)(Keep.right).mapMaterializedValue(adaptWsUpgradeResponse _))
private def adaptWsFlow[Mat](javaFlow: Flow[Message, Message, Mat]): stream.scaladsl.Flow[scaladsl.model.ws.Message, scaladsl.model.ws.Message, Mat] =
stream.scaladsl.Flow[scaladsl.model.ws.Message]
.map(Message.adapt)
.viaMat(javaFlow.asScala)(Keep.right)
.map(_.asScala)
private def adaptWsResultTuple[T](result: (Future[scaladsl.Http.WebsocketUpgradeResponse], T)): Pair[Future[WebsocketUpgradeResponse], T] =
result match {
case (fut, tMat) Pair(adaptWsUpgradeResponse(fut), tMat)
}
private def adaptWsUpgradeResponse(responseFuture: Future[scaladsl.Http.WebsocketUpgradeResponse]): Future[WebsocketUpgradeResponse] =
responseFuture.map(WebsocketUpgradeResponse.adapt)(system.dispatcher)
}

View file

@ -0,0 +1,61 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.javadsl.model.ws
import akka.http.javadsl.model.{ Uri, HttpHeader }
import akka.http.scaladsl.model.ws.{ WebsocketRequest ScalaWebsocketRequest }
/**
* Represents a Websocket request. Use `WebsocketRequest.create` to create a request
* for a target URI and then use `addHeader` or `requestSubprotocol` to set optional
* details.
*/
abstract class WebsocketRequest {
/**
* Return a copy of this request that contains the given additional header.
*/
def addHeader(header: HttpHeader): WebsocketRequest
/**
* Return a copy of this request that will require that the server uses the
* given Websocket subprotocol.
*/
def requestSubprotocol(subprotocol: String): WebsocketRequest
def asScala: ScalaWebsocketRequest
}
object WebsocketRequest {
import akka.http.impl.util.JavaMapping.Implicits._
/**
* Creates a WebsocketRequest to a target URI. Use the methods on `WebsocketRequest`
* to specify further details.
*/
def create(uri: Uri): WebsocketRequest =
wrap(ScalaWebsocketRequest(uri.asScala))
/**
* Creates a WebsocketRequest to a target URI. Use the methods on `WebsocketRequest`
* to specify further details.
*/
def create(uriString: String): WebsocketRequest =
create(Uri.create(uriString))
/**
* Wraps a Scala version of WebsocketRequest.
*/
def wrap(scalaRequest: ScalaWebsocketRequest): WebsocketRequest =
new WebsocketRequest {
def addHeader(header: HttpHeader): WebsocketRequest =
transform(s s.copy(extraHeaders = s.extraHeaders :+ header.asScala))
def requestSubprotocol(subprotocol: String): WebsocketRequest =
transform(_.copy(subprotocol = Some(subprotocol)))
def asScala: ScalaWebsocketRequest = scalaRequest
def transform(f: ScalaWebsocketRequest ScalaWebsocketRequest): WebsocketRequest =
wrap(f(asScala))
}
}

View file

@ -0,0 +1,58 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.javadsl.model.ws
import akka.http.javadsl.model.HttpResponse
import akka.http.scaladsl
import akka.http.scaladsl.Http.{ InvalidUpgradeResponse, ValidUpgrade }
import akka.japi.Option
/**
* Represents an upgrade response for a Websocket upgrade request. Can either be valid, in which
* case the `chosenSubprotocol` method is valid, or if invalid, the `invalidationReason` method
* can be used to find out why the upgrade failed.
*/
trait WebsocketUpgradeResponse {
def isValid: Boolean
/**
* Returns the response object as received from the server for further inspection.
*/
def response: HttpResponse
/**
* If valid, returns `Some(subprotocol)` (if any was requested), or `None` if none was
* chosen or offered.
*/
def chosenSubprotocol: Option[String]
/**
* If invalid, the reason why the server's upgrade response could not be accepted.
*/
def invalidationReason: String
}
object WebsocketUpgradeResponse {
import akka.http.impl.util.JavaMapping.Implicits._
def adapt(scalaResponse: scaladsl.Http.WebsocketUpgradeResponse): WebsocketUpgradeResponse =
scalaResponse match {
case ValidUpgrade(response, chosen)
new WebsocketUpgradeResponse {
def isValid: Boolean = true
def response: HttpResponse = response
def chosenSubprotocol: Option[String] = chosen.asJava
def invalidationReason: String =
throw new UnsupportedOperationException("invalidationReason must not be called for valid response")
}
case InvalidUpgradeResponse(response, cause)
new WebsocketUpgradeResponse {
def isValid: Boolean = false
def response: HttpResponse = response
def chosenSubprotocol: Option[String] = throw new UnsupportedOperationException("chosenSubprotocol must not be called for valid response")
def invalidationReason: String = cause
}
}
}

View file

@ -5,8 +5,9 @@
package akka.http.scaladsl
import java.net.InetSocketAddress
import java.security.SecureRandom
import java.util.concurrent.ConcurrentHashMap
import java.util.{ Collection JCollection }
import java.util.{ Collection JCollection, Random }
import javax.net.ssl.{ SSLContext, SSLParameters }
import akka.actor._
@ -15,13 +16,16 @@ import akka.http._
import akka.http.impl.engine.client._
import akka.http.impl.engine.server._
import akka.http.impl.util.{ ReadTheDocumentationException, Java6Compat, StreamUtils }
import akka.http.impl.engine.ws.WebsocketClientBlueprint
import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers.Host
import akka.http.scaladsl.model.ws.{ WebsocketRequest, Message }
import akka.http.scaladsl.util.FastFuture
import akka.japi
import akka.stream.Materializer
import akka.stream.io._
import akka.stream.scaladsl._
import akka.util.ByteString
import com.typesafe.config.Config
import scala.collection.immutable
@ -207,11 +211,17 @@ class HttpExt(config: Config)(implicit system: ActorSystem) extends akka.actor.E
log: LoggingAdapter): Flow[HttpRequest, HttpResponse, Future[OutgoingConnection]] = {
val hostHeader = if (port == (if (httpsContext.isEmpty) 80 else 443)) Host(host) else Host(host, port)
val layer = clientLayer(hostHeader, settings, log)
layer.joinMat(_outgoingTlsConnectionLayer(host, port, localAddress, settings, httpsContext, log))(Keep.right)
}
private def _outgoingTlsConnectionLayer(host: String, port: Int, localAddress: Option[InetSocketAddress],
settings: ClientConnectionSettings, httpsContext: Option[HttpsContext],
log: LoggingAdapter): Flow[SslTlsOutbound, SslTlsInbound, Future[OutgoingConnection]] = {
val tlsStage = sslTlsStage(httpsContext, Client, Some(host -> port))
val transportFlow = Tcp().outgoingConnection(new InetSocketAddress(host, port), localAddress,
settings.socketOptions, halfClose = true, settings.connectingTimeout, settings.idleTimeout)
layer.atop(tlsStage).joinMat(transportFlow) { (_, tcpConnFuture)
tlsStage.joinMat(transportFlow) { (_, tcpConnFuture)
import system.dispatcher
tcpConnFuture map { tcpConn OutgoingConnection(tcpConn.localAddress, tcpConn.remoteAddress) }
}
@ -406,6 +416,57 @@ class HttpExt(config: Config)(implicit system: ActorSystem) extends akka.actor.E
case e: IllegalUriException FastFuture.failed(e)
}
/**
* Constructs a [[WebsocketClientLayer]] stage using the configured default [[ClientConnectionSettings]],
* configured using the `akka.http.client` config section.
*
* The layer is not reusable and must only be materialized once.
*/
def websocketClientLayer(request: WebsocketRequest,
settings: ClientConnectionSettings = ClientConnectionSettings(system),
log: LoggingAdapter = system.log): Http.WebsocketClientLayer =
WebsocketClientBlueprint(request, settings, log)
/**
* Constructs a flow that once materialized establishes a Websocket connection to the given Uri.
*
* The layer is not reusable and must only be materialized once.
*/
def websocketClientFlow(request: WebsocketRequest,
localAddress: Option[InetSocketAddress] = None,
settings: ClientConnectionSettings = ClientConnectionSettings(system),
httpsContext: Option[HttpsContext] = None,
log: LoggingAdapter = system.log): Flow[Message, Message, Future[WebsocketUpgradeResponse]] = {
import request.uri
require(uri.isAbsolute, s"Websocket request URI must be absolute but was '$uri'")
val ctx = uri.scheme match {
case "ws" None
case "wss" effectiveHttpsContext(httpsContext)
case scheme @ _
throw new IllegalArgumentException(s"Illegal URI scheme '$scheme' in '$uri' for Websocket request. " +
s"Websocket requests must use either 'ws' or 'wss'")
}
val host = uri.authority.host.address
val port = uri.effectivePort
websocketClientLayer(request, settings, log)
.joinMat(_outgoingTlsConnectionLayer(host, port, localAddress, settings, ctx, log))(Keep.left)
}
/**
* Runs a single Websocket conversation given a Uri and a flow that represents the client side of the
* Websocket conversation.
*/
def singleWebsocketRequest[T](request: WebsocketRequest,
clientFlow: Flow[Message, Message, T],
localAddress: Option[InetSocketAddress] = None,
settings: ClientConnectionSettings = ClientConnectionSettings(system),
httpsContext: Option[HttpsContext] = None,
log: LoggingAdapter = system.log)(implicit mat: Materializer): (Future[WebsocketUpgradeResponse], T) =
websocketClientFlow(request, localAddress, settings, httpsContext, log)
.joinMat(clientFlow)(Keep.both).run()
/**
* Triggers an orderly shutdown of all host connections pools currently maintained by the [[ActorSystem]].
* The returned future is completed when all pools that were live at the time of this method call
@ -562,13 +623,27 @@ object Http extends ExtensionId[HttpExt] with ExtensionIdProvider {
type ClientLayer = BidiFlow[HttpRequest, SslTlsOutbound, SslTlsInbound, HttpResponse, Unit]
//#
/**
* The type of the client-side Websocket layer as a stand-alone BidiFlow
* that can be put atop the TCP layer to form an HTTP client.
*
* {{{
* +------+
* ws.Message ~>| |~> SslTlsOutbound
* | bidi |
* ws.Message <~| |<~ SslTlsInbound
* +------+
* }}}
*/
type WebsocketClientLayer = BidiFlow[Message, SslTlsOutbound, SslTlsInbound, Message, Future[WebsocketUpgradeResponse]]
/**
* Represents a prospective HTTP server binding.
*
* @param localAddress The local address of the endpoint bound by the materialization of the `connections` [[Source]]
*
*/
case class ServerBinding(localAddress: InetSocketAddress)(private val unbindAction: () Future[Unit]) {
final case class ServerBinding(localAddress: InetSocketAddress)(private val unbindAction: () Future[Unit]) {
/**
* Asynchronously triggers the unbinding of the port that was bound by the materialization of the `connections`
@ -582,7 +657,7 @@ object Http extends ExtensionId[HttpExt] with ExtensionIdProvider {
/**
* Represents one accepted incoming HTTP connection.
*/
case class IncomingConnection(
final case class IncomingConnection(
localAddress: InetSocketAddress,
remoteAddress: InetSocketAddress,
flow: Flow[HttpResponse, HttpRequest, Unit]) {
@ -612,12 +687,21 @@ object Http extends ExtensionId[HttpExt] with ExtensionIdProvider {
/**
* Represents a prospective outgoing HTTP connection.
*/
case class OutgoingConnection(localAddress: InetSocketAddress, remoteAddress: InetSocketAddress)
final case class OutgoingConnection(localAddress: InetSocketAddress, remoteAddress: InetSocketAddress)
/**
* Represents the response to a websocket upgrade request. Can either be [[ValidUpgrade]] or [[InvalidUpgradeResponse]].
*/
sealed trait WebsocketUpgradeResponse {
def response: HttpResponse
}
final case class ValidUpgrade(response: HttpResponse, chosenSubprotocol: Option[String]) extends WebsocketUpgradeResponse
final case class InvalidUpgradeResponse(response: HttpResponse, cause: String) extends WebsocketUpgradeResponse
/**
* Represents a connection pool to a specific target host and pool configuration.
*/
case class HostConnectionPool(setup: HostConnectionPoolSetup)(
final case class HostConnectionPool(setup: HostConnectionPoolSetup)(
private[http] val gatewayFuture: Future[PoolGateway]) extends javadsl.HostConnectionPool { // enable test access
/**
@ -641,11 +725,11 @@ object Http extends ExtensionId[HttpExt] with ExtensionIdProvider {
import scala.collection.JavaConverters._
//# https-context-impl
case class HttpsContext(sslContext: SSLContext,
enabledCipherSuites: Option[immutable.Seq[String]] = None,
enabledProtocols: Option[immutable.Seq[String]] = None,
clientAuth: Option[ClientAuth] = None,
sslParameters: Option[SSLParameters] = None)
final case class HttpsContext(sslContext: SSLContext,
enabledCipherSuites: Option[immutable.Seq[String]] = None,
enabledProtocols: Option[immutable.Seq[String]] = None,
clientAuth: Option[ClientAuth] = None,
sslParameters: Option[SSLParameters] = None)
//#
extends akka.http.javadsl.HttpsContext {
def firstSession = NegotiateNewSession(enabledCipherSuites, enabledProtocols, clientAuth, sslParameters)

View file

@ -599,8 +599,8 @@ object Uri {
}
val defaultPorts: Map[String, Int] =
Map("ftp" -> 21, "ssh" -> 22, "telnet" -> 23, "smtp" -> 25, "domain" -> 53, "tftp" -> 69, "http" -> 80,
"pop3" -> 110, "nntp" -> 119, "imap" -> 143, "snmp" -> 161, "ldap" -> 389, "https" -> 443, "imaps" -> 993,
Map("ftp" -> 21, "ssh" -> 22, "telnet" -> 23, "smtp" -> 25, "domain" -> 53, "tftp" -> 69, "http" -> 80, "ws" -> 80,
"pop3" -> 110, "nntp" -> 119, "imap" -> 143, "snmp" -> 161, "ldap" -> 389, "https" -> 443, "wss" -> 443, "imaps" -> 993,
"nfs" -> 2049).withDefaultValue(-1)
sealed trait ParsingMode

View file

@ -8,15 +8,18 @@ import java.lang.Iterable
import java.net.InetSocketAddress
import java.security.MessageDigest
import java.util
import scala.reflect.ClassTag
import scala.util.Try
import scala.annotation.tailrec
import scala.collection.immutable
import akka.parboiled2.util.Base64
import akka.http.impl.util._
import akka.http.javadsl.{ model jm }
import akka.http.scaladsl.model._
import scala.reflect.ClassTag
sealed abstract class ModeledCompanion[T: ClassTag] extends Renderable {
val name = getClass.getSimpleName.replace("$minus", "-").dropRight(1) // trailing $
val lowercaseName = name.toRootLowerCase
@ -102,6 +105,7 @@ sealed abstract case class Expect private () extends ModeledHeader {
// http://tools.ietf.org/html/rfc7230#section-5.4
object Host extends ModeledCompanion[Host] {
def apply(authority: Uri.Authority): Host = apply(authority.host, authority.port)
def apply(address: InetSocketAddress): Host = apply(address.getHostStringJava6Compatible, address.getPort)
def apply(host: String): Host = apply(host, 0)
def apply(host: String, port: Int): Host = apply(Uri.Host(host), port)
@ -626,7 +630,12 @@ private[http] final case class `Sec-WebSocket-Extensions`(extensions: immutable.
/**
* INTERNAL API
*/
private[http] object `Sec-WebSocket-Key` extends ModeledCompanion[`Sec-WebSocket-Key`]
private[http] object `Sec-WebSocket-Key` extends ModeledCompanion[`Sec-WebSocket-Key`] {
def apply(keyBytes: Array[Byte]): `Sec-WebSocket-Key` = {
require(keyBytes.length == 16, s"Sec-WebSocket-Key keyBytes must have length 16 but had ${keyBytes.length}")
`Sec-WebSocket-Key`(Base64.rfc2045().encodeToString(keyBytes, false))
}
}
/**
* INTERNAL API
*/
@ -634,6 +643,12 @@ private[http] final case class `Sec-WebSocket-Key`(key: String) extends ModeledH
protected[http] def renderValue[R <: Rendering](r: R): r.type = r ~~ key
protected def companion = `Sec-WebSocket-Key`
/**
* Checks if the key value is valid according to the Websocket specification, i.e.
* if the String is a Base64 representation of 16 bytes.
*/
def isValid: Boolean = Try(Base64.rfc2045().decode(key)).toOption.exists(_.length == 16)
}
// http://tools.ietf.org/html/rfc6455#section-4.3

View file

@ -16,7 +16,7 @@ sealed trait Message
/**
* A binary
*/
trait TextMessage extends Message {
sealed trait TextMessage extends Message {
/**
* The contents of this message as a stream.
*/
@ -38,7 +38,7 @@ object TextMessage {
final private case class Streamed(textStream: Source[String, _]) extends TextMessage
}
//#message-model
trait BinaryMessage extends Message {
sealed trait BinaryMessage extends Message {
/**
* The contents of this message as a stream.
*/

View file

@ -0,0 +1,26 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.scaladsl.model.ws
import scala.language.implicitConversions
import scala.collection.immutable
import akka.http.scaladsl.model.{ HttpHeader, Uri }
/**
* Represents a Websocket request.
* @param uri The target URI to connect to.
* @param extraHeaders Extra headers to add to the Websocket request.
* @param subprotocol A Websocket subprotocol if required.
*/
final case class WebsocketRequest(
uri: Uri,
extraHeaders: immutable.Seq[HttpHeader] = Nil,
subprotocol: Option[String] = None)
object WebsocketRequest {
implicit def fromTargetUri(uri: Uri): WebsocketRequest = WebsocketRequest(uri)
implicit def fromTargetUriString(uriString: String): WebsocketRequest = WebsocketRequest(uriString)
}

View file

@ -0,0 +1,83 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.javadsl;
import akka.actor.ActorSystem;
import akka.dispatch.Futures;
import akka.http.javadsl.model.ws.Message;
import akka.http.javadsl.model.ws.TextMessage;
import akka.http.javadsl.model.ws.WebsocketRequest;
import akka.japi.function.Function;
import akka.stream.ActorMaterializer;
import akka.stream.Materializer;
import akka.stream.javadsl.Flow;
import akka.stream.javadsl.Keep;
import akka.stream.javadsl.Sink;
import akka.stream.javadsl.Source;
import scala.concurrent.Await;
import scala.concurrent.Future;
import scala.concurrent.duration.FiniteDuration;
import scala.runtime.BoxedUnit;
import java.util.Arrays;
import java.util.List;
public class WSEchoTestClientApp {
private static final Function<Message, String> messageStringifier = new Function<Message, String>() {
@Override
public String apply(Message msg) throws Exception {
if (msg.isText() && msg.asTextMessage().isStrict())
return msg.asTextMessage().getStrictText();
else
throw new IllegalArgumentException("Unexpected message "+msg);
}
};
public static void main(String[] args) throws Exception {
ActorSystem system = ActorSystem.create();
try {
final Materializer materializer = ActorMaterializer.create(system);
final Future<Message> ignoredMessage = Futures.successful((Message) TextMessage.create("blub"));
final Future<Message> delayedCompletion =
akka.pattern.Patterns.after(
FiniteDuration.apply(1, "second"),
system.scheduler(),
system.dispatcher(),
ignoredMessage);
Source<Message, BoxedUnit> echoSource =
Source.from(Arrays.<Message>asList(
TextMessage.create("abc"),
TextMessage.create("def"),
TextMessage.create("ghi")
)).concat(Source.from(delayedCompletion).drop(1));
Sink<Message, Future<List<String>>> echoSink =
Flow.of(Message.class)
.map(messageStringifier)
.grouped(1000)
.toMat(Sink.<List<String>>head(), Keep.<BoxedUnit, Future<List<String>>>right());
Flow<Message, Message, Future<List<String>>> echoClient =
Flow.wrap(echoSink, echoSource, Keep.<Future<List<String>>, BoxedUnit>left());
Future<List<String>> result =
Http.get(system).singleWebsocketRequest(
WebsocketRequest.create("ws://echo.websocket.org"),
echoClient,
materializer
).second();
List<String> messages = Await.result(result, FiniteDuration.apply(10, "second"));
System.out.println("Collected " + messages.size() + " messages:");
for (String msg: messages)
System.out.println(msg);
} finally {
system.shutdown();
}
}
}

View file

@ -32,7 +32,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
send("""GET / HTTP/1.1
|Host: example.com
|
|""".stripMarginWithNewline("\r\n"))
|""")
expectRequest shouldEqual HttpRequest(uri = "http://example.com/", headers = List(Host("example.com")))
}
@ -42,7 +42,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
|Host: example.com
|Content-Length: 12
|
|""".stripMarginWithNewline("\r\n"))
|""")
inside(expectRequest) {
case HttpRequest(POST, _, _, HttpEntity.Default(_, 12, data), _)
@ -65,10 +65,9 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
send("""GET / HTTP/1.2
|Host: example.com
|
|""".stripMarginWithNewline("\r\n"))
|""")
netOutSub.request(1)
wipeDate(netOut.expectNext().utf8String) shouldEqual
expectResponseWithWipedDate(
"""HTTP/1.1 505 HTTP Version Not Supported
|Server: akka-http/test
|Date: XXXX
@ -76,7 +75,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
|Content-Type: text/plain; charset=UTF-8
|Content-Length: 74
|
|The server does not support the HTTP protocol version used in the request.""".stripMarginWithNewline("\r\n")
|The server does not support the HTTP protocol version used in the request.""")
}
"report an invalid Chunked stream" in new TestSetup {
@ -86,7 +85,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
|
|6
|abcdef
|""".stripMarginWithNewline("\r\n"))
|""")
inside(expectRequest) {
case HttpRequest(POST, _, _, HttpEntity.Chunked(_, data), _)
@ -102,11 +101,10 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
error.getMessage shouldEqual "Illegal character 'g' in chunk start"
requests.expectComplete()
netOutSub.request(1)
responsesSub.expectRequest()
responsesSub.sendError(error.asInstanceOf[Exception])
responses.expectRequest()
responses.sendError(error.asInstanceOf[Exception])
wipeDate(netOut.expectNext().utf8String) shouldEqual
expectResponseWithWipedDate(
"""HTTP/1.1 400 Bad Request
|Server: akka-http/test
|Date: XXXX
@ -114,7 +112,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
|Content-Type: text/plain; charset=UTF-8
|Content-Length: 36
|
|Illegal character 'g' in chunk start""".stripMarginWithNewline("\r\n")
|Illegal character 'g' in chunk start""")
}
}
@ -123,7 +121,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
|Host: example.com
|Content-Length: 12
|
|abcdefghijkl""".stripMarginWithNewline("\r\n"))
|abcdefghijkl""")
expectRequest shouldEqual
HttpRequest(
@ -138,7 +136,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
|Host: example.com
|Content-Length: 12
|
|abcdef""".stripMarginWithNewline("\r\n"))
|abcdef""")
inside(expectRequest) {
case HttpRequest(POST, _, _, HttpEntity.Default(_, 12, data), _)
@ -161,7 +159,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
|
|6
|abcdef
|""".stripMarginWithNewline("\r\n"))
|""")
inside(expectRequest) {
case HttpRequest(POST, _, _, HttpEntity.Chunked(_, data), _)
@ -182,7 +180,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
|Host: example.com
|Content-Length: 12
|
|abcdefghijkl""".stripMarginWithNewline("\r\n"))
|abcdefghijkl""")
expectRequest shouldEqual
HttpRequest(
@ -195,7 +193,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
|Host: example.com
|Content-Length: 12
|
|mnopqrstuvwx""".stripMarginWithNewline("\r\n"))
|mnopqrstuvwx""")
expectRequest shouldEqual
HttpRequest(
@ -210,7 +208,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
|Host: example.com
|Content-Length: 12
|
|abcdef""".stripMarginWithNewline("\r\n"))
|abcdef""")
inside(expectRequest) {
case HttpRequest(POST, _, _, HttpEntity.Default(_, 12, data), _)
@ -232,7 +230,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
|Host: example.com
|Content-Length: 5
|
|abcde""".stripMarginWithNewline("\r\n"))
|abcde""")
inside(expectRequest) {
case HttpRequest(POST, _, _, HttpEntity.Strict(_, data), _)
@ -247,7 +245,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
|
|6
|abcdef
|""".stripMarginWithNewline("\r\n"))
|""")
inside(expectRequest) {
case HttpRequest(POST, _, _, HttpEntity.Chunked(_, data), _)
@ -270,7 +268,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
|Host: example.com
|Content-Length: 5
|
|abcde""".stripMarginWithNewline("\r\n"))
|abcde""")
inside(expectRequest) {
case HttpRequest(POST, _, _, HttpEntity.Strict(_, data), _)
@ -283,7 +281,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
|Host: example.com
|Content-Length: 12
|
|abcdef""".stripMarginWithNewline("\r\n"))
|abcdef""")
inside(expectRequest) {
case HttpRequest(POST, _, _, HttpEntity.Default(_, 12, data), _)
@ -306,7 +304,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
|
|6
|abcdef
|""".stripMarginWithNewline("\r\n"))
|""")
inside(expectRequest) {
case HttpRequest(POST, _, _, HttpEntity.Chunked(_, data), _)
@ -328,7 +326,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
|Host: example.com
|Content-Length: 12
|
|abcdef""".stripMarginWithNewline("\r\n"))
|abcdef""")
inside(expectRequest) {
case HttpRequest(POST, _, _, HttpEntity.Default(_, 12, data), _)
val dataProbe = TestSubscriber.manualProbe[ByteString]
@ -349,7 +347,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
|
|6
|abcdef
|""".stripMarginWithNewline("\r\n"))
|""")
inside(expectRequest) {
case HttpRequest(POST, _, _, HttpEntity.Chunked(_, data), _)
val dataProbe = TestSubscriber.manualProbe[ChunkStreamPart]
@ -368,7 +366,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
send("""HEAD / HTTP/1.1
|Host: example.com
|
|""".stripMarginWithNewline("\r\n"))
|""")
expectRequest shouldEqual HttpRequest(GET, uri = "http://example.com/", headers = List(Host("example.com")))
}
@ -377,7 +375,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
send("""HEAD / HTTP/1.1
|Host: example.com
|
|""".stripMarginWithNewline("\r\n"))
|""")
expectRequest shouldEqual HttpRequest(HEAD, uri = "http://example.com/", headers = List(Host("example.com")))
}
@ -385,19 +383,18 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
send("""HEAD / HTTP/1.1
|Host: example.com
|
|""".stripMarginWithNewline("\r\n"))
|""")
inside(expectRequest) {
case HttpRequest(GET, _, _, _, _)
responsesSub.sendNext(HttpResponse(entity = HttpEntity.Strict(ContentTypes.`text/plain`, ByteString("abcd"))))
netOutSub.request(1)
wipeDate(netOut.expectNext().utf8String) shouldEqual
responses.sendNext(HttpResponse(entity = HttpEntity.Strict(ContentTypes.`text/plain`, ByteString("abcd"))))
expectResponseWithWipedDate(
"""|HTTP/1.1 200 OK
|Server: akka-http/test
|Date: XXXX
|Content-Type: text/plain
|Content-Length: 4
|
|""".stripMarginWithNewline("\r\n")
|""")
}
}
@ -405,22 +402,21 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
send("""HEAD / HTTP/1.1
|Host: example.com
|
|""".stripMarginWithNewline("\r\n"))
|""")
val data = TestPublisher.manualProbe[ByteString]()
inside(expectRequest) {
case HttpRequest(GET, _, _, _, _)
responsesSub.sendNext(HttpResponse(entity = HttpEntity.Default(ContentTypes.`text/plain`, 4, Source(data))))
netOutSub.request(1)
responses.sendNext(HttpResponse(entity = HttpEntity.Default(ContentTypes.`text/plain`, 4, Source(data))))
val dataSub = data.expectSubscription()
dataSub.expectCancellation()
wipeDate(netOut.expectNext().utf8String) shouldEqual
expectResponseWithWipedDate(
"""|HTTP/1.1 200 OK
|Server: akka-http/test
|Date: XXXX
|Content-Type: text/plain
|Content-Length: 4
|
|""".stripMarginWithNewline("\r\n")
|""")
}
}
@ -428,46 +424,44 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
send("""HEAD / HTTP/1.1
|Host: example.com
|
|""".stripMarginWithNewline("\r\n"))
|""")
val data = TestPublisher.manualProbe[ByteString]()
inside(expectRequest) {
case HttpRequest(GET, _, _, _, _)
responsesSub.sendNext(HttpResponse(entity = HttpEntity.CloseDelimited(ContentTypes.`text/plain`, Source(data))))
netOutSub.request(1)
responses.sendNext(HttpResponse(entity = HttpEntity.CloseDelimited(ContentTypes.`text/plain`, Source(data))))
val dataSub = data.expectSubscription()
dataSub.expectCancellation()
wipeDate(netOut.expectNext().utf8String) shouldEqual
expectResponseWithWipedDate(
"""|HTTP/1.1 200 OK
|Server: akka-http/test
|Date: XXXX
|Content-Type: text/plain
|
|""".stripMarginWithNewline("\r\n")
|""")
}
// No close should happen here since this was a HEAD request
netOut.expectNoMsg(50.millis)
netOut.expectNoBytes(50.millis)
}
"not emit entities when responding to HEAD requests if transparent-head-requests is enabled (with Chunked)" in new TestSetup {
send("""HEAD / HTTP/1.1
|Host: example.com
|
|""".stripMarginWithNewline("\r\n"))
|""")
val data = TestPublisher.manualProbe[ChunkStreamPart]()
inside(expectRequest) {
case HttpRequest(GET, _, _, _, _)
responsesSub.sendNext(HttpResponse(entity = HttpEntity.Chunked(ContentTypes.`text/plain`, Source(data))))
netOutSub.request(1)
responses.sendNext(HttpResponse(entity = HttpEntity.Chunked(ContentTypes.`text/plain`, Source(data))))
val dataSub = data.expectSubscription()
dataSub.expectCancellation()
wipeDate(netOut.expectNext().utf8String) shouldEqual
expectResponseWithWipedDate(
"""|HTTP/1.1 200 OK
|Server: akka-http/test
|Date: XXXX
|Transfer-Encoding: chunked
|Content-Type: text/plain
|
|""".stripMarginWithNewline("\r\n")
|""")
}
}
@ -476,15 +470,14 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
|Host: example.com
|Connection: close
|
|""".stripMarginWithNewline("\r\n"))
|""")
val data = TestPublisher.manualProbe[ByteString]()
inside(expectRequest) {
case HttpRequest(GET, _, _, _, _)
responsesSub.sendNext(HttpResponse(entity = CloseDelimited(ContentTypes.`text/plain`, Source(data))))
netOutSub.request(1)
responses.sendNext(HttpResponse(entity = CloseDelimited(ContentTypes.`text/plain`, Source(data))))
val dataSub = data.expectSubscription()
dataSub.expectCancellation()
netOut.expectNext()
netOut.expectBytes(1)
}
netOut.expectComplete()
}
@ -495,34 +488,33 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
|Expect: 100-continue
|Content-Length: 16
|
|""".stripMarginWithNewline("\r\n"))
|""")
inside(expectRequest) {
case HttpRequest(POST, _, _, Default(ContentType(`application/octet-stream`, None), 16, data), _)
val dataProbe = TestSubscriber.manualProbe[ByteString]
data.to(Sink(dataProbe)).run()
val dataSub = dataProbe.expectSubscription()
netOutSub.request(2)
netOut.expectNoMsg(50.millis)
netOut.expectNoBytes(50.millis)
dataSub.request(1) // triggers `100 Continue` response
wipeDate(netOut.expectNext().utf8String) shouldEqual
expectResponseWithWipedDate(
"""HTTP/1.1 100 Continue
|Server: akka-http/test
|Date: XXXX
|
|""".stripMarginWithNewline("\r\n")
|""")
dataProbe.expectNoMsg(50.millis)
send("0123456789ABCDEF")
dataProbe.expectNext(ByteString("0123456789ABCDEF"))
dataProbe.expectComplete()
responsesSub.sendNext(HttpResponse(entity = "Yeah"))
wipeDate(netOut.expectNext().utf8String) shouldEqual
responses.sendNext(HttpResponse(entity = "Yeah"))
expectResponseWithWipedDate(
"""HTTP/1.1 200 OK
|Server: akka-http/test
|Date: XXXX
|Content-Type: text/plain; charset=UTF-8
|Content-Length: 4
|
|Yeah""".stripMarginWithNewline("\r\n")
|Yeah""")
}
}
@ -532,39 +524,38 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
|Expect: 100-continue
|Transfer-Encoding: chunked
|
|""".stripMarginWithNewline("\r\n"))
|""")
inside(expectRequest) {
case HttpRequest(POST, _, _, Chunked(ContentType(`application/octet-stream`, None), data), _)
val dataProbe = TestSubscriber.manualProbe[ChunkStreamPart]
data.to(Sink(dataProbe)).run()
val dataSub = dataProbe.expectSubscription()
netOutSub.request(2)
netOut.expectNoMsg(50.millis)
netOut.expectNoBytes(50.millis)
dataSub.request(2) // triggers `100 Continue` response
wipeDate(netOut.expectNext().utf8String) shouldEqual
expectResponseWithWipedDate(
"""HTTP/1.1 100 Continue
|Server: akka-http/test
|Date: XXXX
|
|""".stripMarginWithNewline("\r\n")
|""")
dataProbe.expectNoMsg(50.millis)
send("""10
|0123456789ABCDEF
|0
|
|""".stripMarginWithNewline("\r\n"))
|""")
dataProbe.expectNext(Chunk(ByteString("0123456789ABCDEF")))
dataProbe.expectNext(LastChunk)
dataProbe.expectComplete()
responsesSub.sendNext(HttpResponse(entity = "Yeah"))
wipeDate(netOut.expectNext().utf8String) shouldEqual
responses.sendNext(HttpResponse(entity = "Yeah"))
expectResponseWithWipedDate(
"""HTTP/1.1 200 OK
|Server: akka-http/test
|Date: XXXX
|Content-Type: text/plain; charset=UTF-8
|Content-Length: 4
|
|Yeah""".stripMarginWithNewline("\r\n")
|Yeah""")
}
}
@ -574,12 +565,11 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
|Expect: 100-continue
|Content-Length: 16
|
|""".stripMarginWithNewline("\r\n"))
|""")
inside(expectRequest) {
case HttpRequest(POST, _, _, Default(ContentType(`application/octet-stream`, None), 16, data), _)
netOutSub.request(1)
responsesSub.sendNext(HttpResponse(entity = "Yeah"))
wipeDate(netOut.expectNext().utf8String) shouldEqual
responses.sendNext(HttpResponse(entity = "Yeah"))
expectResponseWithWipedDate(
"""HTTP/1.1 200 OK
|Server: akka-http/test
|Date: XXXX
@ -587,7 +577,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
|Content-Type: text/plain; charset=UTF-8
|Content-Length: 4
|
|Yeah""".stripMarginWithNewline("\r\n")
|Yeah""")
}
}
@ -599,18 +589,17 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
expectRequest shouldEqual HttpRequest(uri = "http://example.com/", headers = List(Host("example.com")))
netOutSub.request(1)
responsesSub.expectRequest()
responsesSub.sendError(new RuntimeException("CRASH BOOM BANG"))
responses.expectRequest()
responses.sendError(new RuntimeException("CRASH BOOM BANG"))
wipeDate(netOut.expectNext().utf8String) shouldEqual
expectResponseWithWipedDate(
"""HTTP/1.1 500 Internal Server Error
|Server: akka-http/test
|Date: XXXX
|Connection: close
|Content-Length: 0
|
|""".stripMarginWithNewline("\r\n")
|""")
}
"correctly consume and render large requests and responses" in new TestSetup {
@ -618,22 +607,20 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
|Host: example.com
|Content-Length: 100000
|
|""".stripMarginWithNewline("\r\n"))
|""")
val HttpRequest(POST, _, _, entity, _) = expectRequest
responsesSub.expectRequest()
responsesSub.sendNext(HttpResponse(entity = entity))
responsesSub.sendComplete()
responses.sendNext(HttpResponse(entity = entity))
responses.sendComplete()
netOutSub.request(1)
wipeDate(netOut.expectNext().utf8String) shouldEqual
expectResponseWithWipedDate(
"""HTTP/1.1 200 OK
|Server: akka-http/test
|Date: XXXX
|Content-Type: application/octet-stream
|Content-Length: 100000
|
|""".stripMarginWithNewline("\r\n")
|""")
val random = new Random()
@tailrec def rec(bytesLeft: Int): Unit =
@ -641,13 +628,12 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
val count = math.min(random.nextInt(1000) + 1, bytesLeft)
val data = random.alphanumeric.take(count).mkString
send(data)
netOutSub.request(1)
netOut.expectNext().utf8String shouldEqual data
netOut.expectUtf8EncodedString(data)
rec(bytesLeft - count)
}
rec(100000)
netInSub.sendComplete()
netIn.sendComplete()
requests.expectComplete()
netOut.expectComplete()
}
@ -656,7 +642,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
send("""GET //foo HTTP/1.1
|Host: example.com
|
|""".stripMarginWithNewline("\r\n"))
|""")
expectRequest shouldEqual HttpRequest(uri = "http://example.com//foo", headers = List(Host("example.com")))
}
@ -664,7 +650,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
"use default-host-header for HTTP/1.0 requests" in new TestSetup {
send("""GET /abc HTTP/1.0
|
|""".stripMarginWithNewline("\r\n"))
|""")
expectRequest shouldEqual HttpRequest(uri = "http://example.com/abc", protocol = HttpProtocols.`HTTP/1.0`)
@ -673,10 +659,9 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
"fail an HTTP/1.0 request with 400 if no default-host-header is set" in new TestSetup {
send("""GET /abc HTTP/1.0
|
|""".stripMarginWithNewline("\r\n"))
|""")
netOutSub.request(1)
wipeDate(netOut.expectNext().utf8String) shouldEqual
expectResponseWithWipedDate(
"""|HTTP/1.1 400 Bad Request
|Server: akka-http/test
|Date: XXXX
@ -684,7 +669,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
|Content-Type: text/plain; charset=UTF-8
|Content-Length: 41
|
|Request is missing required `Host` header""".stripMarginWithNewline("\r\n")
|Request is missing required `Host` header""")
}
"support remote-address-header" in new TestSetup {

View file

@ -6,6 +6,7 @@ package akka.http.impl.engine.server
import java.net.InetSocketAddress
import akka.http.impl.engine.ws.ByteStringSinkProbe
import akka.stream.io.{ SendBytes, SslTlsOutbound, SessionBytes }
import scala.concurrent.duration.FiniteDuration
@ -28,21 +29,21 @@ abstract class HttpServerTestSetupBase {
implicit def system: ActorSystem
implicit def materializer: Materializer
val requests = TestSubscriber.manualProbe[HttpRequest]
val responses = TestPublisher.manualProbe[HttpResponse]()
val requests = TestSubscriber.probe[HttpRequest]
val responses = TestPublisher.probe[HttpResponse]()
def settings = ServerSettings(system).copy(serverHeader = Some(Server(List(ProductVersion("akka-http", "test")))))
def remoteAddress: Option[InetSocketAddress] = None
val (netIn, netOut) = {
val netIn = TestPublisher.manualProbe[ByteString]()
val netOut = TestSubscriber.manualProbe[ByteString]
val netIn = TestPublisher.probe[ByteString]()
val netOut = ByteStringSinkProbe()
FlowGraph.closed(HttpServerBluePrint(settings, remoteAddress = remoteAddress, log = NoLogging)) { implicit b
server
import FlowGraph.Implicits._
Source(netIn) ~> Flow[ByteString].map(SessionBytes(null, _)) ~> server.in2
server.out1 ~> Flow[SslTlsOutbound].collect { case SendBytes(x) x } ~> Sink(netOut)
server.out1 ~> Flow[SslTlsOutbound].collect { case SendBytes(x) x } ~> netOut.sink
server.out2 ~> Sink(requests)
Source(responses) ~> server.in1
}.run()
@ -50,26 +51,26 @@ abstract class HttpServerTestSetupBase {
netIn -> netOut
}
def expectResponseWithWipedDate(expected: String): Unit = {
val trimmed = expected.stripMarginWithNewline("\r\n")
// XXXX = 4 bytes, ISO Date Time String = 29 bytes => need to request 15 bytes more than expected string
val expectedSize = ByteString(trimmed, "utf8").length + 25
val received = wipeDate(netOut.expectBytes(expectedSize).utf8String)
assert(received == trimmed, s"Expected request '$trimmed' but got '$received'")
}
def wipeDate(string: String) =
string.fastSplit('\n').map {
case s if s.startsWith("Date:") "Date: XXXX\r"
case s s
}.mkString("\n")
val netInSub = netIn.expectSubscription()
val netOutSub = netOut.expectSubscription()
val requestsSub = requests.expectSubscription()
val responsesSub = responses.expectSubscription()
def expectRequest: HttpRequest = {
requestsSub.request(1)
requests.expectNext()
}
def expectRequest: HttpRequest = requests.requestNext()
def expectNoRequest(max: FiniteDuration): Unit = requests.expectNoMsg(max)
def expectNetworkClose(): Unit = netOut.expectComplete()
def send(data: ByteString): Unit = netInSub.sendNext(data)
def send(data: String): Unit = send(ByteString(data, "UTF8"))
def send(data: ByteString): Unit = netIn.sendNext(data)
def send(string: String): Unit = send(ByteString(string.stripMarginWithNewline("\r\n"), "UTF8"))
def closeNetworkInput(): Unit = netInSub.sendComplete()
def closeNetworkInput(): Unit = netIn.sendComplete()
}

View file

@ -0,0 +1,67 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.impl.engine.ws
import akka.actor.ActorSystem
import akka.stream.scaladsl.{ Source, Sink }
import akka.stream.testkit.TestSubscriber
import akka.util.ByteString
import scala.annotation.tailrec
import scala.concurrent.duration.FiniteDuration
trait ByteStringSinkProbe {
def sink: Sink[ByteString, Unit]
def expectBytes(length: Int): ByteString
def expectBytes(expected: ByteString): Unit
def expectUtf8EncodedString(string: String): Unit
def expectNoBytes(): Unit
def expectNoBytes(timeout: FiniteDuration): Unit
def expectComplete(): Unit
def expectError(): Throwable
def expectError(cause: Throwable): Unit
}
object ByteStringSinkProbe {
def apply()(implicit system: ActorSystem): ByteStringSinkProbe =
new ByteStringSinkProbe {
val probe = TestSubscriber.probe[ByteString]()
val sink: Sink[ByteString, Unit] = Sink(probe)
def expectNoBytes(): Unit = {
probe.ensureSubscription()
probe.expectNoMsg()
}
def expectNoBytes(timeout: FiniteDuration): Unit = {
probe.ensureSubscription()
probe.expectNoMsg(timeout)
}
var inBuffer = ByteString.empty
@tailrec def expectBytes(length: Int): ByteString =
if (inBuffer.size >= length) {
val res = inBuffer.take(length)
inBuffer = inBuffer.drop(length)
res
} else {
inBuffer ++= probe.requestNext()
expectBytes(length)
}
def expectBytes(expected: ByteString): Unit =
assert(expectBytes(expected.length) == expected, "expected ")
def expectUtf8EncodedString(string: String): Unit =
expectBytes(ByteString(string, "utf8"))
def expectComplete(): Unit = probe.expectComplete()
def expectError(): Throwable = probe.expectError()
def expectError(cause: Throwable): Unit = probe.expectError(cause)
}
}

View file

@ -0,0 +1,68 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.impl.engine.ws
import scala.concurrent.duration._
import akka.actor.ActorSystem
import akka.http.scaladsl.Http
import akka.http.scaladsl.model.ws.{ TextMessage, BinaryMessage, Message }
import akka.stream.ActorMaterializer
import akka.stream.scaladsl._
import akka.util.ByteString
import scala.concurrent.Future
import scala.util.{ Failure, Success }
/**
* An example App that runs a quick test against the websocket server at wss://echo.websocket.org
*/
object EchoTestClientApp extends App {
implicit val system = ActorSystem()
import system.dispatcher
implicit val mat = ActorMaterializer()
def delayedCompletion(delay: FiniteDuration): Source[Nothing, Unit] =
Source.single(1)
.mapAsync(1)(_ akka.pattern.after(delay, system.scheduler)(Future(1)))
.drop(1).asInstanceOf[Source[Nothing, Unit]]
def messages: List[Message] =
List(
TextMessage("Test 1"),
BinaryMessage(ByteString("abc")),
TextMessage("Test 2"),
BinaryMessage(ByteString("def")))
def source: Source[Message, Unit] =
Source(messages) ++ delayedCompletion(1.second) // otherwise, we may start closing too soon
def sink: Sink[Message, Future[Seq[String]]] =
Flow[Message]
.mapAsync(1) {
case tm: TextMessage
tm.textStream.runWith(Sink.fold("")(_ + _)).map(str s"TextMessage: '$str'")
case bm: BinaryMessage
bm.dataStream.runWith(Sink.fold(ByteString.empty)(_ ++ _)).map(bs s"BinaryMessage: '${bs.utf8String}'")
}
.grouped(10000)
.toMat(Sink.head)(Keep.right)
def echoClient = Flow.wrap(sink, source)(Keep.left)
val (upgrade, res) = Http().singleWebsocketRequest("wss://echo.websocket.org", echoClient)
res onComplete {
case Success(res)
println("Run successful. Got these elements:")
res.foreach(println)
system.shutdown()
case Failure(e)
println("Run failed.")
e.printStackTrace()
system.shutdown()
}
system.scheduler.scheduleOnce(10.seconds)(system.shutdown())
}

View file

@ -134,6 +134,13 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
sub.expectNext(ByteString("def", "ASCII"))
sub.expectComplete()
}
"unmask masked input on the server side for empty frame" in new ServerTestSetup {
val mask = Random.nextInt()
val header = frameHeader(Opcode.Binary, 0, fin = true, mask = Some(mask))
pushInput(header)
expectBinaryMessage(BinaryMessage.Strict(ByteString.empty))
}
}
"for text messages" - {
"empty message" in new ClientTestSetup {
@ -207,6 +214,13 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
sub.expectNext(ByteString("cdef€", "UTF-8"))
sub.expectComplete()
}
"unmask masked input on the server side for empty frame" in new ServerTestSetup {
val mask = Random.nextInt()
val header = frameHeader(Opcode.Text, 0, fin = true, mask = Some(mask))
pushInput(header)
expectTextMessage(TextMessage.Strict(""))
}
}
}
"render frames from messages" - {
@ -265,6 +279,10 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
sub.sendComplete()
expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true)
}
"and mask input on the client side for empty frame" in new ClientTestSetup {
pushMessage(BinaryMessage(ByteString.empty))
expectMaskedFrameOnNetwork(Opcode.Binary, ByteString.empty, fin = true)
}
}
"for text messages" - {
"for a short strict message" in new ServerTestSetup {
@ -347,6 +365,10 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
sub.sendComplete()
expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true)
}
"and mask input on the client side for empty frame" in new ClientTestSetup {
pushMessage(TextMessage(""))
expectMaskedFrameOnNetwork(Opcode.Text, ByteString.empty, fin = true)
}
}
}
"supply automatic low-level websocket behavior" - {
@ -440,7 +462,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
}
"after receiving close frame without close code" in new ServerTestSetup {
netInSub.expectRequest()
pushInput(frameHeader(Opcode.Close, 0, fin = true))
pushInput(frameHeader(Opcode.Close, 0, fin = true, mask = Some(Random.nextInt())))
messageIn.expectComplete()
messageOutSub.sendComplete()
@ -479,7 +501,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
netOutSub.request(10)
messageInSub.request(10)
pushInput(frameHeader(Protocol.Opcode.Binary, 0, fin = false))
pushInput(frameHeader(Protocol.Opcode.Binary, 0, fin = false, mask = Some(Random.nextInt())))
val dataSource = expectBinaryMessage().dataStream
val inSubscriber = TestSubscriber.manualProbe[ByteString]()
dataSource.runWith(Sink(inSubscriber))
@ -742,10 +764,23 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
pushInput(input)
expectProtocolErrorOnNetwork()
}
"unmasked input on the server side for empty frame" in new ServerTestSetup {
val input = frameHeader(Opcode.Binary, 0, fin = true)
pushInput(input)
expectProtocolErrorOnNetwork()
}
"masked input on the client side" in new ClientTestSetup {
val mask = Random.nextInt()
val input = frameHeader(Opcode.Binary, 6, fin = true, mask = Some(mask)) ++ maskedASCII("abcdef", mask)._1
pushInput(input)
expectProtocolErrorOnNetwork()
}
"masked input on the client side for empty frame" in new ClientTestSetup {
val mask = Random.nextInt()
val input = frameHeader(Opcode.Binary, 0, fin = true, mask = Some(mask))
pushInput(input)
expectProtocolErrorOnNetwork()
}
@ -782,7 +817,7 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
Source(netIn)
.via(printEvent("netIn"))
.transform(() new FrameEventParser)
.via(Websocket.stack(serverSide, closeTimeout = closeTimeout).join(messageHandler))
.via(Websocket.stack(serverSide, maskingRandomFactory = Randoms.SecureRandomInstances, closeTimeout = closeTimeout).join(messageHandler))
.via(printEvent("frameRendererIn"))
.transform(() new FrameEventRenderer)
.via(printEvent("frameRendererOut"))
@ -813,9 +848,15 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
def expectBinaryMessage(): BinaryMessage =
expectMessage().asInstanceOf[BinaryMessage]
def expectBinaryMessage(message: BinaryMessage): Unit =
expectBinaryMessage() shouldEqual message
def expectTextMessage(): TextMessage =
expectMessage().asInstanceOf[TextMessage]
def expectTextMessage(message: TextMessage): Unit =
expectTextMessage() shouldEqual message
var inBuffer = ByteString.empty
@tailrec final def expectNetworkData(bytes: Int): ByteString =
if (inBuffer.size >= bytes) {

View file

@ -0,0 +1,227 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.impl.engine.ws
import scala.concurrent.{ Promise, Future }
import scala.util.{ Try, Failure, Success }
import spray.json._
import akka.actor.ActorSystem
import akka.stream.ActorMaterializer
import akka.stream.io.SslTlsPlacebo
import akka.stream.stage.{ TerminationDirective, Context, SyncDirective, PushStage }
import akka.stream.scaladsl._
import akka.http.scaladsl.Http
import akka.http.scaladsl.model.Uri
import akka.http.scaladsl.model.ws._
object WSClientAutobahnTest extends App {
implicit val system = ActorSystem()
import system.dispatcher
implicit val mat = ActorMaterializer()
val Agent = "akka-http"
val Parallelism = 4
val getCaseCountUri: Uri =
s"ws://localhost:9001/getCaseCount"
def runCaseUri(caseIndex: Int, agent: String): Uri =
s"ws://localhost:9001/runCase?case=$caseIndex&agent=$agent"
def getCaseStatusUri(caseIndex: Int, agent: String): Uri =
s"ws://localhost:9001/getCaseStatus?case=$caseIndex&agent=$agent"
def getCaseInfoUri(caseIndex: Int): Uri =
s"ws://localhost:9001/getCaseInfo?case=$caseIndex"
def updateReportsUri(agent: String): Uri =
s"ws://localhost:9001/updateReports?agent=$agent"
def runCase(caseIndex: Int, agent: String = Agent): Future[CaseStatus] =
runWs(runCaseUri(caseIndex, agent), echo).recover { case _ () }.flatMap { _
getCaseStatus(caseIndex, agent)
}
def richRunCase(caseIndex: Int, agent: String = Agent): Future[CaseResult] = {
val info = getCaseInfo(caseIndex)
val startMillis = System.currentTimeMillis()
val status = runCase(caseIndex, agent).map { res
val lastedMillis = System.currentTimeMillis() - startMillis
(res, lastedMillis)
}
import Console._
info.flatMap { i
val prefix = f"$YELLOW${i.caseInfo.id}%-7s$RESET - $WHITE${i.caseInfo.description}$RESET ... "
//println(prefix)
status.onComplete {
case Success((CaseStatus(status), millis))
val color = if (status == "OK") GREEN else RED
println(f"${color}$status%-15s$RESET$millis%5d ms $prefix")
case Failure(e)
println(s"$prefix${RED}failed with '${e.getMessage}'$RESET")
}
status.map(s CaseResult(i.caseInfo, s._1))
}
}
def getCaseCount(): Future[Int] =
runToSingleText(getCaseCountUri).map(_.toInt)
def getCaseInfo(caseId: Int): Future[IndexedCaseInfo] =
runToSingleJsonValue[CaseInfo](getCaseInfoUri(caseId)).map(IndexedCaseInfo(caseId, _))
def getCaseStatus(caseId: Int, agent: String = Agent): Future[CaseStatus] =
runToSingleJsonValue[CaseStatus](getCaseStatusUri(caseId, agent))
def updateReports(agent: String = Agent): Future[Unit] =
runToSingleText(updateReportsUri(agent)).map(_ ())
/**
* Map from textual case ID (like 1.1.1) to IndexedCaseInfo
* @return
*/
def getCaseMap(): Future[Map[String, IndexedCaseInfo]] = {
val res =
getCaseCount().flatMap { count
println(s"Retrieving case info for $count cases...")
Future.traverse(1 to count)(getCaseInfo).map(_.map(e e.caseInfo.id -> e).toMap)
}
res.foreach { res
println(s"Received info for ${res.size} cases")
}
res
}
def echo = Flow[Message].viaMat(completionSignal)(Keep.right)
if (args.size >= 1) {
// run one
val testId = args(0)
println(s"Trying to run test $testId")
getCaseMap().flatMap { map
val info = map(testId)
richRunCase(info.index)
}.onComplete {
case Success(res)
println(s"Run successfully finished!")
updateReportsAndShutdown()
case Failure(e)
println("Run failed with this exception")
e.printStackTrace()
updateReportsAndShutdown()
}
} else {
println("Running complete test suite")
getCaseCount().flatMap { count
println(s"Found $count tests.")
Source(1 to count).mapAsyncUnordered(Parallelism)(richRunCase(_)).grouped(count).runWith(Sink.head)
}.map { results
val grouped =
results.groupBy(_.status.behavior)
import Console._
println(s"${results.size} tests run.")
println()
println(s"${GREEN}OK$RESET: ${grouped.getOrElse("OK", Nil).size}")
val notOk = grouped.filterNot(_._1 == "OK")
notOk.toSeq.sortBy(_._2.size).foreach {
case (status, cases) println(s"$RED$status$RESET: ${cases.size}")
}
println()
println("Not OK tests")
println()
results.filterNot(_.status.behavior == "OK").foreach { r
println(f"$RED${r.status.behavior}%-20s$RESET $YELLOW${r.info.id}%-7s$RESET - $WHITE${r.info.description}$RESET")
}
()
}
.onComplete(completion)
}
def completion[T]: Try[T] Unit = {
case Success(res)
println(s"Run successfully finished!")
updateReportsAndShutdown()
case Failure(e)
println("Run failed with this exception")
e.printStackTrace()
updateReportsAndShutdown()
}
def updateReportsAndShutdown(): Unit =
updateReports().onComplete { res
println("Reports should now be accessible at http://localhost:8080/cwd/reports/clients/index.html")
system.shutdown()
}
import scala.concurrent.duration._
import system.dispatcher
system.scheduler.scheduleOnce(60.seconds)(system.shutdown())
def runWs[T](uri: Uri, clientFlow: Flow[Message, Message, T]): T =
Http().singleWebsocketRequest(uri, clientFlow)._2
def completionSignal[T]: Flow[T, T, Future[Unit]] =
Flow[T].transformMaterializing { ()
val p = Promise[Unit]()
val stage =
new PushStage[T, T] {
def onPush(elem: T, ctx: Context[T]): SyncDirective = ctx.push(elem)
override def onUpstreamFinish(ctx: Context[T]): TerminationDirective = {
p.success(())
super.onUpstreamFinish(ctx)
}
override def onDownstreamFinish(ctx: Context[T]): TerminationDirective = {
p.success(()) // should this be failure as well?
super.onDownstreamFinish(ctx)
}
override def onUpstreamFailure(cause: Throwable, ctx: Context[T]): TerminationDirective = {
p.failure(cause)
super.onUpstreamFailure(cause, ctx)
}
}
(stage, p.future)
}
/**
* The autobahn tests define a weird API where every request must be a Websocket request and
* they will send a single websocket message with the result. Websocket everywhere? Strange,
* but somewhat consistent.
*/
def runToSingleText(uri: Uri): Future[String] = {
val sink = Sink.head[Message]
runWs(uri, Flow.wrap(sink, Source.lazyEmpty[Message])(Keep.left)).flatMap {
case tm: TextMessage tm.textStream.runWith(Sink.fold("")(_ + _))
}
}
def runToSingleJsonValue[T: JsonReader](uri: Uri): Future[T] =
runToSingleText(uri).map(_.parseJson.convertTo[T])
case class IndexedCaseInfo(index: Int, caseInfo: CaseInfo)
case class CaseResult(info: CaseInfo, status: CaseStatus)
// {"behavior": "OK"}
case class CaseStatus(behavior: String) {
def isSuccessful: Boolean = behavior == "OK"
}
object CaseStatus {
import DefaultJsonProtocol._
implicit def caseStatusFormat: JsonFormat[CaseStatus] = jsonFormat1(CaseStatus.apply)
}
// {"id": "1.1.1", "description": "Send text message with payload 0."}
case class CaseInfo(id: String, description: String)
object CaseInfo {
import DefaultJsonProtocol._
implicit def caseInfoFormat: JsonFormat[CaseInfo] = jsonFormat2(CaseInfo.apply)
}
}

View file

@ -0,0 +1,44 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.impl.engine.ws
import scala.concurrent.Await
import scala.concurrent.duration._
import akka.actor.ActorSystem
import akka.http.scaladsl.Http
import akka.http.scaladsl.model.HttpMethods._
import akka.http.scaladsl.model.ws.{ Message, UpgradeToWebsocket }
import akka.http.scaladsl.model._
import akka.stream.ActorMaterializer
import akka.stream.scaladsl.Flow
object WSServerAutobahnTest extends App {
implicit val system = ActorSystem("WSServerTest")
implicit val fm = ActorMaterializer()
try {
val binding = Http().bindAndHandleSync({
case req @ HttpRequest(GET, Uri.Path("/"), _, _, _) if req.header[UpgradeToWebsocket].isDefined
req.header[UpgradeToWebsocket] match {
case Some(upgrade) upgrade.handleMessages(echoWebsocketService) // needed for running the autobahn test suite
case None HttpResponse(400, entity = "Not a valid websocket request!")
}
case _: HttpRequest HttpResponse(404, entity = "Unknown resource!")
},
interface = "172.17.42.1", // adapt to your docker host IP address if necessary
port = 9001)
Await.result(binding, 1.second) // throws if binding fails
println("Server online at http://172.17.42.1:9001")
println("Press RETURN to stop...")
Console.readLine()
} finally {
system.shutdown()
}
def echoWebsocketService: Flow[Message, Message, Unit] =
Flow[Message] // just let message flow directly to the output
}

View file

@ -0,0 +1,111 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.impl.engine.ws
import akka.http.impl.engine.ws.Protocol.Opcode
import akka.http.impl.engine.ws.WSTestUtils._
import akka.util.ByteString
import org.scalatest.Matchers
import scala.annotation.tailrec
import scala.util.Random
trait WSTestSetupBase extends Matchers {
def send(bytes: ByteString): Unit
def expectBytes(length: Int): ByteString
def expectBytes(bytes: ByteString): Unit
def sendWSFrame(opcode: Opcode,
data: ByteString,
fin: Boolean,
mask: Boolean = false,
rsv1: Boolean = false,
rsv2: Boolean = false,
rsv3: Boolean = false): Unit = {
val (theMask, theData) =
if (mask) {
val m = Random.nextInt()
(Some(m), maskedBytes(data, m)._1)
} else (None, data)
send(frameHeader(opcode, data.length, fin, theMask, rsv1, rsv2, rsv3) ++ theData)
}
def sendWSCloseFrame(closeCode: Int, mask: Boolean = false): Unit =
send(closeFrame(closeCode, mask))
def expectWSFrame(opcode: Opcode,
data: ByteString,
fin: Boolean,
mask: Option[Int] = None,
rsv1: Boolean = false,
rsv2: Boolean = false,
rsv3: Boolean = false): Unit =
expectBytes(frameHeader(opcode, data.length, fin, mask, rsv1, rsv2, rsv3) ++ data)
def expectWSCloseFrame(closeCode: Int, mask: Boolean = false): Unit =
expectBytes(closeFrame(closeCode, mask))
def expectNetworkData(length: Int): ByteString = expectBytes(length)
def expectNetworkData(data: ByteString): Unit = expectBytes(data)
def expectFrameOnNetwork(opcode: Opcode, data: ByteString, fin: Boolean): Unit = {
expectFrameHeaderOnNetwork(opcode, data.size, fin)
expectNetworkData(data)
}
def expectMaskedFrameOnNetwork(opcode: Opcode, data: ByteString, fin: Boolean): Unit = {
val Some(mask) = expectFrameHeaderOnNetwork(opcode, data.size, fin)
val masked = maskedBytes(data, mask)._1
expectNetworkData(masked)
}
def expectMaskedCloseFrame(closeCode: Int): Unit =
expectMaskedFrameOnNetwork(Protocol.Opcode.Close, closeFrameData(closeCode), fin = true)
/** Returns the mask if any is available */
def expectFrameHeaderOnNetwork(opcode: Opcode, length: Long, fin: Boolean): Option[Int] = {
val (op, l, f, m) = expectFrameHeaderOnNetwork()
op shouldEqual opcode
l shouldEqual length
f shouldEqual fin
m
}
def expectFrameHeaderOnNetwork(): (Opcode, Long, Boolean, Option[Int]) = {
val header = expectNetworkData(2)
val fin = (header(0) & Protocol.FIN_MASK) != 0
val op = header(0) & Protocol.OP_MASK
val hasMask = (header(1) & Protocol.MASK_MASK) != 0
val length7 = header(1) & Protocol.LENGTH_MASK
val length = length7 match {
case 126
val length16Bytes = expectNetworkData(2)
(length16Bytes(0) & 0xff) << 8 | (length16Bytes(1) & 0xff) << 0
case 127
val length64Bytes = expectNetworkData(8)
(length64Bytes(0) & 0xff).toLong << 56 |
(length64Bytes(1) & 0xff).toLong << 48 |
(length64Bytes(2) & 0xff).toLong << 40 |
(length64Bytes(3) & 0xff).toLong << 32 |
(length64Bytes(4) & 0xff).toLong << 24 |
(length64Bytes(5) & 0xff).toLong << 16 |
(length64Bytes(6) & 0xff).toLong << 8 |
(length64Bytes(7) & 0xff).toLong << 0
case x x
}
val mask =
if (hasMask) {
val maskBytes = expectNetworkData(4)
val mask =
(maskBytes(0) & 0xff) << 24 |
(maskBytes(1) & 0xff) << 16 |
(maskBytes(2) & 0xff) << 8 |
(maskBytes(3) & 0xff) << 0
Some(mask)
} else None
(Opcode.forCode(op.toByte), length, fin, mask)
}
}

View file

@ -43,14 +43,19 @@ object WSTestUtils {
val lengthByte = lengthByteComponent | maskMask
ByteString(opcodeByte.toByte, lengthByte.toByte) ++ lengthBytes ++ maskBytes
}
def closeFrame(closeCode: Int, mask: Boolean): ByteString =
def frame(opcode: Opcode, data: ByteString, fin: Boolean, mask: Boolean): ByteString =
if (mask) {
val mask = Random.nextInt()
frameHeader(Opcode.Close, 2, fin = true, mask = Some(mask)) ++
maskedBytes(shortBE(closeCode), mask)._1
frameHeader(opcode, data.size, fin, mask = Some(mask)) ++
maskedBytes(data, mask)._1
} else
frameHeader(Opcode.Close, 2, fin = true) ++
shortBE(closeCode)
frameHeader(opcode, data.size, fin, mask = None) ++ data
def closeFrame(closeCode: Int, mask: Boolean): ByteString =
frame(Opcode.Close, closeFrameData(closeCode), fin = true, mask)
def closeFrameData(closeCode: Int): ByteString =
shortBE(closeCode)
def maskedASCII(str: String, mask: Int): (ByteString, Int) =
FrameEventParser.mask(ByteString(str, "ASCII"), mask)

View file

@ -0,0 +1,370 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.impl.engine.ws
import java.util.Random
import akka.http.scaladsl.Http.{ InvalidUpgradeResponse, WebsocketUpgradeResponse }
import scala.concurrent.duration._
import akka.http.ClientConnectionSettings
import akka.http.scaladsl.Http
import akka.http.scaladsl.model.headers.{ ProductVersion, `User-Agent` }
import akka.http.scaladsl.model.ws._
import akka.http.scaladsl.model.{ HttpResponse, Uri }
import akka.stream.io._
import akka.stream.scaladsl._
import akka.stream.testkit.{ TestSubscriber, TestPublisher }
import akka.util.ByteString
import org.scalatest.{ Matchers, FreeSpec }
import akka.http.impl.util._
class WebsocketClientSpec extends FreeSpec with Matchers with WithMaterializerSpec {
"The client-side Websocket implementation should" - {
"establish a websocket connection when the user requests it" in new EstablishedConnectionSetup with ClientEchoes
"establish connection with case insensitive header values" in new TestSetup with ClientEchoes {
expectWireData(UpgradeRequestBytes)
sendWireData("""HTTP/1.1 101 Switching Protocols
|Upgrade: wEbSOckET
|Sec-WebSocket-Accept: ujmZX4KXZqjwy6vi1aQFH5p4Ygk=
|Server: akka-http/test
|Sec-WebSocket-Version: 13
|Connection: upgrade
|
|""")
sendWSFrame(Protocol.Opcode.Text, ByteString("Message 1"), fin = true)
expectMaskedFrameOnNetwork(Protocol.Opcode.Text, ByteString("Message 1"), fin = true)
}
"reject invalid handshakes" - {
"other status code" in new TestSetup with ClientEchoes {
expectWireData(UpgradeRequestBytes)
sendWireData(
"""HTTP/1.1 404 Not Found
|Server: akka-http/test
|Content-Length: 0
|
|""")
expectNetworkAbort()
expectInvalidUpgradeResponseCause("Websocket server at ws://example.org/ws returned unexpected status code: 404 Not Found")
}
"missing Sec-WebSocket-Accept hash" in new TestSetup with ClientEchoes {
expectWireData(UpgradeRequestBytes)
sendWireData(
"""HTTP/1.1 101 Switching Protocols
|Upgrade: websocket
|Sec-WebSocket-Version: 13
|Server: akka-http/test
|Connection: upgrade
|
|""")
expectNetworkAbort()
expectInvalidUpgradeResponseCause("Websocket server at ws://example.org/ws returned response that was missing required `Sec-WebSocket-Accept` header.")
}
"wrong Sec-WebSocket-Accept hash" in new TestSetup with ClientEchoes {
expectWireData(UpgradeRequestBytes)
sendWireData(
"""HTTP/1.1 101 Switching Protocols
|Upgrade: websocket
|Sec-WebSocket-Accept: s3pPLMBiTxhZRbK+xOo=
|Sec-WebSocket-Version: 13
|Server: akka-http/test
|Connection: upgrade
|
|""")
expectNetworkAbort()
expectInvalidUpgradeResponseCause("Websocket server at ws://example.org/ws returned response with invalid `Sec-WebSocket-Accept` header.")
}
"missing `Upgrade` header" in new TestSetup with ClientEchoes {
expectWireData(UpgradeRequestBytes)
sendWireData(
"""HTTP/1.1 101 Switching Protocols
|Sec-WebSocket-Accept: ujmZX4KXZqjwy6vi1aQFH5p4Ygk=
|Sec-WebSocket-Version: 13
|Server: akka-http/test
|Connection: upgrade
|
|""")
expectNetworkAbort()
expectInvalidUpgradeResponseCause("Websocket server at ws://example.org/ws returned response that was missing required `Upgrade` header.")
}
"missing `Connection: upgrade` header" in new TestSetup with ClientEchoes {
expectWireData(UpgradeRequestBytes)
sendWireData(
"""HTTP/1.1 101 Switching Protocols
|Upgrade: websocket
|Sec-WebSocket-Accept: ujmZX4KXZqjwy6vi1aQFH5p4Ygk=
|Sec-WebSocket-Version: 13
|Server: akka-http/test
|
|""")
expectNetworkAbort()
expectInvalidUpgradeResponseCause("Websocket server at ws://example.org/ws returned response that was missing required `Connection` header.")
}
}
"don't send out frames before handshake was finished successfully" in new TestSetup {
def clientImplementation: Flow[Message, Message, Unit] =
Flow.wrap(Sink.ignore, Source.single(TextMessage("fast message")))(Keep.none)
expectWireData(UpgradeRequestBytes)
expectNoWireData()
sendWireData(UpgradeResponseBytes)
expectMaskedFrameOnNetwork(Protocol.Opcode.Text, ByteString("fast message"), fin = true)
expectMaskedCloseFrame(Protocol.CloseCodes.Regular)
sendWSCloseFrame(Protocol.CloseCodes.Regular)
closeNetworkInput()
expectNetworkClose()
}
"receive first frame in same chunk as HTTP upgrade response" in new TestSetup with ClientProbes {
expectWireData(UpgradeRequestBytes)
val firstFrame = WSTestUtils.frame(Protocol.Opcode.Text, ByteString("fast"), fin = true, mask = false)
sendWireData(UpgradeResponseBytes ++ firstFrame)
messagesIn.requestNext(TextMessage("fast"))
}
"manual scenario client sends first" in new EstablishedConnectionSetup with ClientProbes {
messagesOut.sendNext(TextMessage("Message 1"))
expectMaskedFrameOnNetwork(Protocol.Opcode.Text, ByteString("Message 1"), fin = true)
sendWSFrame(Protocol.Opcode.Binary, ByteString("Response"), fin = true, mask = false)
messagesIn.requestNext(BinaryMessage(ByteString("Response")))
}
"client echoes scenario" in new EstablishedConnectionSetup with ClientEchoes {
sendWSFrame(Protocol.Opcode.Text, ByteString("Message 1"), fin = true)
expectMaskedFrameOnNetwork(Protocol.Opcode.Text, ByteString("Message 1"), fin = true)
sendWSFrame(Protocol.Opcode.Text, ByteString("Message 2"), fin = true)
expectMaskedFrameOnNetwork(Protocol.Opcode.Text, ByteString("Message 2"), fin = true)
sendWSFrame(Protocol.Opcode.Text, ByteString("Message 3"), fin = true)
expectMaskedFrameOnNetwork(Protocol.Opcode.Text, ByteString("Message 3"), fin = true)
sendWSFrame(Protocol.Opcode.Text, ByteString("Message 4"), fin = true)
expectMaskedFrameOnNetwork(Protocol.Opcode.Text, ByteString("Message 4"), fin = true)
sendWSFrame(Protocol.Opcode.Text, ByteString("Message 5"), fin = true)
expectMaskedFrameOnNetwork(Protocol.Opcode.Text, ByteString("Message 5"), fin = true)
sendWSCloseFrame(Protocol.CloseCodes.Regular)
expectMaskedCloseFrame(Protocol.CloseCodes.Regular)
closeNetworkInput()
expectNetworkClose()
}
"support subprotocols" - {
"accept if server supports subprotocol" in new TestSetup with ClientEchoes {
override protected def requestedSubProtocol: Option[String] = Some("v2")
expectWireData(
"""GET /ws HTTP/1.1
|Upgrade: websocket
|Connection: upgrade
|Sec-WebSocket-Key: YLQguzhR2dR6y5M9vnA5mw==
|Sec-WebSocket-Version: 13
|Sec-WebSocket-Protocol: v2
|Host: example.org
|User-Agent: akka-http/test
|
|""")
sendWireData(
"""HTTP/1.1 101 Switching Protocols
|Upgrade: websocket
|Sec-WebSocket-Accept: ujmZX4KXZqjwy6vi1aQFH5p4Ygk=
|Sec-WebSocket-Version: 13
|Server: akka-http/test
|Connection: upgrade
|Sec-WebSocket-Protocol: v2
|
|""")
sendWSFrame(Protocol.Opcode.Text, ByteString("Message 1"), fin = true)
expectMaskedFrameOnNetwork(Protocol.Opcode.Text, ByteString("Message 1"), fin = true)
}
"send error on user flow if server doesn't support subprotocol" - {
"if no protocol was selected" in new TestSetup with ClientProbes {
override protected def requestedSubProtocol: Option[String] = Some("v2")
expectWireData(
"""GET /ws HTTP/1.1
|Upgrade: websocket
|Connection: upgrade
|Sec-WebSocket-Key: YLQguzhR2dR6y5M9vnA5mw==
|Sec-WebSocket-Version: 13
|Sec-WebSocket-Protocol: v2
|Host: example.org
|User-Agent: akka-http/test
|
|""")
sendWireData(
"""HTTP/1.1 101 Switching Protocols
|Upgrade: websocket
|Sec-WebSocket-Accept: ujmZX4KXZqjwy6vi1aQFH5p4Ygk=
|Sec-WebSocket-Version: 13
|Server: akka-http/test
|Connection: upgrade
|
|""")
expectNetworkAbort()
expectInvalidUpgradeResponseCause(
"Websocket server at ws://example.org/ws returned response that indicated that the given subprotocol was not supported. (client supported: v2, server supported: None)")
}
"if different protocol was selected" in new TestSetup with ClientProbes {
override protected def requestedSubProtocol: Option[String] = Some("v2")
expectWireData(
"""GET /ws HTTP/1.1
|Upgrade: websocket
|Connection: upgrade
|Sec-WebSocket-Key: YLQguzhR2dR6y5M9vnA5mw==
|Sec-WebSocket-Version: 13
|Sec-WebSocket-Protocol: v2
|Host: example.org
|User-Agent: akka-http/test
|
|""")
sendWireData(
"""HTTP/1.1 101 Switching Protocols
|Upgrade: websocket
|Sec-WebSocket-Accept: ujmZX4KXZqjwy6vi1aQFH5p4Ygk=
|Sec-WebSocket-Protocol: v3
|Sec-WebSocket-Version: 13
|Server: akka-http/test
|Connection: upgrade
|
|""")
expectNetworkAbort()
expectInvalidUpgradeResponseCause(
"Websocket server at ws://example.org/ws returned response that indicated that the given subprotocol was not supported. (client supported: v2, server supported: Some(v3))")
}
}
}
}
def UpgradeRequestBytes = ByteString {
"""GET /ws HTTP/1.1
|Upgrade: websocket
|Connection: upgrade
|Sec-WebSocket-Key: YLQguzhR2dR6y5M9vnA5mw==
|Sec-WebSocket-Version: 13
|Host: example.org
|User-Agent: akka-http/test
|
|""".stripMarginWithNewline("\r\n")
}
def UpgradeResponseBytes = ByteString {
"""HTTP/1.1 101 Switching Protocols
|Upgrade: websocket
|Sec-WebSocket-Accept: ujmZX4KXZqjwy6vi1aQFH5p4Ygk=
|Server: akka-http/test
|Sec-WebSocket-Version: 13
|Connection: upgrade
|
|""".stripMarginWithNewline("\r\n")
}
abstract class EstablishedConnectionSetup extends TestSetup {
expectWireData(UpgradeRequestBytes)
sendWireData(UpgradeResponseBytes)
}
abstract class TestSetup extends WSTestSetupBase {
protected def noMsgTimeout: FiniteDuration = 100.millis
protected def clientImplementation: Flow[Message, Message, Unit]
protected def requestedSubProtocol: Option[String] = None
val random = new Random(0)
def settings = ClientConnectionSettings(system)
.copy(
userAgentHeader = Some(`User-Agent`(List(ProductVersion("akka-http", "test")))),
websocketRandomFactory = () random)
def targetUri: Uri = "ws://example.org/ws"
def clientLayer: Http.WebsocketClientLayer =
Http(system).websocketClientLayer(
WebsocketRequest(targetUri, subprotocol = requestedSubProtocol),
settings = settings)
val (netOut, netIn, response) = {
val netOut = ByteStringSinkProbe()
val netIn = TestPublisher.probe[ByteString]()
val graph =
FlowGraph.closed(clientLayer) { implicit b
client
import FlowGraph.Implicits._
Source(netIn) ~> Flow[ByteString].map(SessionBytes(null, _)) ~> client.in2
client.out1 ~> Flow[SslTlsOutbound].collect { case SendBytes(x) x } ~> netOut.sink
client.out2 ~> clientImplementation ~> client.in1
}
val response = graph.run()
(netOut, netIn, response)
}
def expectBytes(length: Int): ByteString = netOut.expectBytes(length)
def expectBytes(bytes: ByteString): Unit = netOut.expectBytes(bytes)
def wipeDate(string: String) =
string.fastSplit('\n').map {
case s if s.startsWith("Date:") "Date: XXXX\r"
case s s
}.mkString("\n")
def sendWireData(data: String): Unit = sendWireData(ByteString(data.stripMarginWithNewline("\r\n"), "ASCII"))
def sendWireData(data: ByteString): Unit = netIn.sendNext(data)
def send(bytes: ByteString): Unit = sendWireData(bytes)
def expectWireData(s: String) =
netOut.expectUtf8EncodedString(s.stripMarginWithNewline("\r\n"))
def expectWireData(bs: ByteString) = netOut.expectBytes(bs)
def expectNoWireData() = netOut.expectNoBytes(noMsgTimeout)
def expectNetworkClose(): Unit = netOut.expectComplete()
def expectNetworkAbort(): Unit = netOut.expectError()
def closeNetworkInput(): Unit = netIn.sendComplete()
def expectResponse(response: WebsocketUpgradeResponse): Unit =
expectInvalidUpgradeResponse() shouldEqual response
def expectInvalidUpgradeResponseCause(expected: String): Unit =
expectInvalidUpgradeResponse().cause shouldEqual expected
import akka.http.impl.util._
def expectInvalidUpgradeResponse(): InvalidUpgradeResponse =
response.awaitResult(1.second).asInstanceOf[InvalidUpgradeResponse]
}
trait ClientEchoes extends TestSetup {
override def clientImplementation: Flow[Message, Message, Unit] = echoServer
def echoServer: Flow[Message, Message, Unit] = Flow[Message]
}
trait ClientProbes extends TestSetup {
lazy val messagesOut = TestPublisher.probe[Message]()
lazy val messagesIn = TestSubscriber.probe[Message]()
override def clientImplementation: Flow[Message, Message, Unit] =
Flow.wrap(Sink(messagesIn), Source(messagesOut))(Keep.none)
}
}

View file

@ -4,7 +4,6 @@
package akka.http.impl.engine.ws
import akka.http.impl.engine.ws.Protocol.Opcode
import akka.http.scaladsl.model.ws._
import akka.stream.scaladsl.{ Keep, Sink, Flow, Source }
import akka.stream.testkit.Utils
@ -15,8 +14,6 @@ import akka.http.impl.util._
import akka.http.impl.engine.server.HttpServerTestSetupBase
import scala.util.Random
class WebsocketServerSpec extends FreeSpec with Matchers with WithMaterializerSpec { spec
import WSTestUtils._
@ -33,7 +30,7 @@ class WebsocketServerSpec extends FreeSpec with Matchers with WithMaterializerSp
|Origin: http://example.com
|Sec-WebSocket-Version: 13
|
|""".stripMarginWithNewline("\r\n"))
|""")
val request = expectRequest
val upgrade = request.header[UpgradeToWebsocket]
@ -43,9 +40,9 @@ class WebsocketServerSpec extends FreeSpec with Matchers with WithMaterializerSp
Source(List(1, 2, 3, 4, 5)).map(num TextMessage.Strict(s"Message $num"))
val handler = Flow.wrap(Sink.ignore, source)(Keep.none)
val response = upgrade.get.handleMessages(handler)
responsesSub.sendNext(response)
responses.sendNext(response)
wipeDate(expectNextChunk().utf8String) shouldEqual
expectResponseWithWipedDate(
"""HTTP/1.1 101 Switching Protocols
|Upgrade: websocket
|Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
@ -53,15 +50,11 @@ class WebsocketServerSpec extends FreeSpec with Matchers with WithMaterializerSp
|Date: XXXX
|Connection: upgrade
|
|""".stripMarginWithNewline("\r\n")
|""")
expectWSFrame(Protocol.Opcode.Text,
ByteString("Message 1"), fin = true)
expectWSFrame(Protocol.
Opcode.Text, ByteString("Message 2"), fin = true)
expectWSFrame(
Protocol.Opcode.Text, ByteString("Message 3"), fin = true)
expectWSFrame(Protocol.Opcode.Text, ByteString("Message 1"), fin = true)
expectWSFrame(Protocol.Opcode.Text, ByteString("Message 2"), fin = true)
expectWSFrame(Protocol.Opcode.Text, ByteString("Message 3"), fin = true)
expectWSFrame(Protocol.Opcode.Text, ByteString("Message 4"), fin = true)
expectWSFrame(Protocol.Opcode.Text, ByteString("Message 5"), fin = true)
expectWSCloseFrame(Protocol.CloseCodes.Regular)
@ -83,16 +76,16 @@ class WebsocketServerSpec extends FreeSpec with Matchers with WithMaterializerSp
|Origin: http://example.com
|Sec-WebSocket-Version: 13
|
|""".stripMarginWithNewline("\r\n"))
|""")
val request = expectRequest
val upgrade = request.header[UpgradeToWebsocket]
upgrade.isDefined shouldBe true
val response = upgrade.get.handleMessages(Flow[Message]) // simple echoing
responsesSub.sendNext(response)
responses.sendNext(response)
wipeDate(expectNextChunk().utf8String) shouldEqual
expectResponseWithWipedDate(
"""HTTP/1.1 101 Switching Protocols
|Upgrade: websocket
|Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
@ -100,7 +93,7 @@ class WebsocketServerSpec extends FreeSpec with Matchers with WithMaterializerSp
|Date: XXXX
|Connection: upgrade
|
|""".stripMarginWithNewline("\r\n")
|""")
sendWSFrame(Protocol.Opcode.Text, ByteString("Message 1"), fin = true, mask = true)
expectWSFrame(Protocol.Opcode.Text, ByteString("Message 1"), fin = true)
@ -123,6 +116,7 @@ class WebsocketServerSpec extends FreeSpec with Matchers with WithMaterializerSp
}
"prevent the selection of an unavailable subprotocol" in pending
"reject invalid Websocket handshakes" - {
"missing `Upgrade: websocket` header" in pending
"missing `Connection: upgrade` header" in pending
"missing `Sec-WebSocket-Key header" in pending
"`Sec-WebSocket-Key` with wrong amount of base64 encoded data" in pending
@ -131,42 +125,11 @@ class WebsocketServerSpec extends FreeSpec with Matchers with WithMaterializerSp
}
}
class TestSetup extends HttpServerTestSetupBase {
class TestSetup extends HttpServerTestSetupBase with WSTestSetupBase {
implicit def system = spec.system
implicit def materializer = spec.materializer
def sendWSFrame(opcode: Opcode,
data: ByteString,
fin: Boolean,
mask: Boolean = false,
rsv1: Boolean = false,
rsv2: Boolean = false,
rsv3: Boolean = false): Unit = {
val (theMask, theData) =
if (mask) {
val m = Random.nextInt()
(Some(m), maskedBytes(data, m)._1)
} else (None, data)
send(frameHeader(opcode, data.length, fin, theMask, rsv1, rsv2, rsv3) ++ theData)
}
def sendWSCloseFrame(closeCode: Int, mask: Boolean = false): Unit =
send(closeFrame(closeCode, mask))
def expectNextChunk(): ByteString = {
netOutSub.request(1)
netOut.expectNext()
}
def expectWSFrame(opcode: Opcode,
data: ByteString,
fin: Boolean,
mask: Option[Int] = None,
rsv1: Boolean = false,
rsv2: Boolean = false,
rsv3: Boolean = false): Unit =
expectNextChunk() shouldEqual frameHeader(opcode, data.length, fin, mask, rsv1, rsv2, rsv3) ++ data
def expectWSCloseFrame(closeCode: Int, mask: Boolean = false): Unit =
expectNextChunk() shouldEqual closeFrame(closeCode, mask)
def expectBytes(length: Int): ByteString = netOut.expectBytes(length)
def expectBytes(bytes: ByteString): Unit = netOut.expectBytes(bytes)
}
}