+htp #17726 add support for WS subprotocols to WebsocketDirectives

Also, update WebsocketDirectiveSpecs to the new WS testing infrastructure.
This commit is contained in:
Johannes Rudolph 2015-08-20 16:20:39 +02:00
parent 12edec7073
commit 060ea707c9
4 changed files with 137 additions and 26 deletions

View file

@ -4,42 +4,100 @@
package akka.http.scaladsl.server.directives
import scala.collection.immutable.Seq
import akka.http.impl.engine.ws.InternalCustomHeader
import akka.http.scaladsl.model.headers.{ UpgradeProtocol, Upgrade }
import akka.http.scaladsl.model.{ HttpRequest, StatusCodes, HttpResponse }
import akka.http.scaladsl.model.ws.{ Message, UpgradeToWebsocket }
import akka.http.scaladsl.server.{ Route, RoutingSpec }
import akka.stream.scaladsl.Flow
import akka.util.ByteString
import akka.stream.OverflowStrategy
import akka.stream.scaladsl.{ Source, Sink, Flow }
import akka.http.scaladsl.testkit.WSProbe
import akka.http.scaladsl.model.headers.`Sec-WebSocket-Protocol`
import akka.http.scaladsl.model.StatusCodes
import akka.http.scaladsl.model.ws._
import akka.http.scaladsl.server.{ UnsupportedWebsocketSubprotocolRejection, ExpectedWebsocketRequestRejection, Route, RoutingSpec }
class WebsocketDirectivesSpec extends RoutingSpec {
"the handleWebsocketMessages directive" should {
"handle websocket requests" in {
Get("http://localhost/") ~> Upgrade(List(UpgradeProtocol("websocket"))) ~>
emulateHttpCore ~> Route.seal(handleWebsocketMessages(Flow[Message])) ~>
val wsClient = WSProbe()
WS("http://localhost/", wsClient.flow) ~> websocketRoute ~>
check {
status shouldEqual StatusCodes.SwitchingProtocols
isWebsocketUpgrade shouldEqual true
wsClient.sendMessage("Peter")
wsClient.expectMessage("Hello Peter!")
wsClient.sendMessage(BinaryMessage(ByteString("abcdef")))
// wsClient.expectNoMessage() // will be checked implicitly by next expectation
wsClient.sendMessage("John")
wsClient.expectMessage("Hello John!")
wsClient.sendCompletion()
wsClient.expectCompletion()
}
}
"choose subprotocol from offered ones" in {
val wsClient = WSProbe()
WS("http://localhost/", wsClient.flow, List("other", "echo", "greeter")) ~> websocketMultipleProtocolRoute ~>
check {
expectWebsocketUpgradeWithProtocol { protocol
protocol shouldEqual "echo"
wsClient.sendMessage("Peter")
wsClient.expectMessage("Peter")
wsClient.sendMessage(BinaryMessage(ByteString("abcdef")))
wsClient.expectMessage(ByteString("abcdef"))
wsClient.sendMessage("John")
wsClient.expectMessage("John")
wsClient.sendCompletion()
wsClient.expectCompletion()
}
}
}
"reject websocket requests if no subprotocol matches" in {
WS("http://localhost/", Flow[Message], List("other")) ~> websocketMultipleProtocolRoute ~> check {
rejections.collect {
case UnsupportedWebsocketSubprotocolRejection(p) p
}.toSet shouldEqual Set("greeter", "echo")
}
WS("http://localhost/", Flow[Message], List("other")) ~> Route.seal(websocketMultipleProtocolRoute) ~> check {
status shouldEqual StatusCodes.BadRequest
responseAs[String] shouldEqual "None of the websocket subprotocols offered in the request are supported. Supported are 'echo','greeter'."
header[`Sec-WebSocket-Protocol`].get.protocols.toSet shouldEqual Set("greeter", "echo")
}
}
"reject non-websocket requests" in {
Get("http://localhost/") ~> emulateHttpCore ~> Route.seal(handleWebsocketMessages(Flow[Message])) ~> check {
Get("http://localhost/") ~> websocketRoute ~> check {
rejection shouldEqual ExpectedWebsocketRequestRejection
}
Get("http://localhost/") ~> Route.seal(websocketRoute) ~> check {
status shouldEqual StatusCodes.BadRequest
responseAs[String] shouldEqual "Expected Websocket Upgrade request"
}
}
}
/** Only checks for upgrade header and then adds UpgradeToWebsocket mock header */
def emulateHttpCore(req: HttpRequest): HttpRequest =
req.header[Upgrade] match {
case Some(upgrade) if upgrade.hasWebsocket req.copy(headers = req.headers :+ upgradeToWebsocketHeaderMock)
case _ req
}
def upgradeToWebsocketHeaderMock: UpgradeToWebsocket =
new InternalCustomHeader("UpgradeToWebsocketMock") with UpgradeToWebsocket {
def requestedProtocols: Seq[String] = Nil
def websocketRoute = handleWebsocketMessages(greeter)
def websocketMultipleProtocolRoute =
handleWebsocketMessagesForProtocol(echo, "echo") ~
handleWebsocketMessagesForProtocol(greeter, "greeter")
def handleMessages(handlerFlow: Flow[Message, Message, Any], subprotocol: Option[String]): HttpResponse =
HttpResponse(StatusCodes.SwitchingProtocols)
def greeter: Flow[Message, Message, Any] =
Flow[Message].mapConcat {
case tm: TextMessage TextMessage(Source.single("Hello ") ++ tm.textStream ++ Source.single("!")) :: Nil
case bm: BinaryMessage // ignore binary messages
bm.dataStream.runWith(Sink.ignore)
Nil
}
def echo: Flow[Message, Message, Any] =
Flow[Message]
.buffer(1, OverflowStrategy.backpressure) // needed because a noop flow hasn't any buffer that would start processing
}

View file

@ -168,6 +168,12 @@ case class MissingCookieRejection(cookieName: String) extends Rejection
*/
case object ExpectedWebsocketRequestRejection extends Rejection
/**
* Rejection created when a websocket request was not handled because none of the given subprotocols
* was supported.
*/
case class UnsupportedWebsocketSubprotocolRejection(supportedProtocol: String) extends Rejection
/**
* Rejection created by the `validation` directive as well as for `IllegalArgumentExceptions`
* thrown by domain model constructors (e.g. via `require`).

View file

@ -205,6 +205,12 @@ object RejectionHandler {
complete(BadRequest, "The request's Content-Encoding is not supported. Expected:\n" + supported)
}
.handle { case ExpectedWebsocketRequestRejection complete(BadRequest, "Expected Websocket Upgrade request") }
.handleAll[UnsupportedWebsocketSubprotocolRejection] { rejections
val supported = rejections.map(_.supportedProtocol)
complete(HttpResponse(BadRequest,
entity = s"None of the websocket subprotocols offered in the request are supported. Supported are ${supported.map("'" + _ + "'").mkString(",")}.",
headers = `Sec-WebSocket-Protocol`(supported) :: Nil))
}
.handle { case ValidationRejection(msg, _) complete(BadRequest, msg) }
.handle { case x sys.error("Unhandled rejection: " + x) }
.handleNotFound { complete(NotFound, "The requested resource could not be found.") }

View file

@ -5,20 +5,61 @@
package akka.http.scaladsl.server
package directives
import scala.collection.immutable
import akka.http.scaladsl.model.ws.{ UpgradeToWebsocket, Message }
import akka.stream.scaladsl.Flow
trait WebsocketDirectives {
import RouteDirectives._
import HeaderDirectives._
import BasicDirectives._
/**
* Handles websocket requests with the given handler and rejects other requests with a
* Extract the [[UpgradeToWebsocket]] header if existent. Rejects with an [[ExpectedWebsocketRequestRejection]], otherwise.
*/
def extractUpgradeToWebsocket: Directive1[UpgradeToWebsocket] =
optionalHeaderValueByType[UpgradeToWebsocket]().flatMap {
case Some(upgrade) provide(upgrade)
case None reject(ExpectedWebsocketRequestRejection)
}
/**
* Extract the list of Websocket subprotocols as offered by the client in the [[Sec-Websocket-Protocol]] header if
* this is a Websocket request. Rejects with an [[ExpectedWebsocketRequestRejection]], otherwise.
*/
def extractOfferedWsProtocols: Directive1[immutable.Seq[String]] = extractUpgradeToWebsocket.map(_.requestedProtocols)
/**
* Handles Websocket requests with the given handler and rejects other requests with a
* [[ExpectedWebsocketRequestRejection]].
*/
def handleWebsocketMessages(handler: Flow[Message, Message, Any]): Route =
optionalHeaderValueByType[UpgradeToWebsocket]() {
case Some(upgrade) complete(upgrade.handleMessages(handler))
case None reject(ExpectedWebsocketRequestRejection)
handleWebsocketMessagesForOptionalProtocol(handler, None)
/**
* Handles Websocket requests with the given handler if the given subprotocol is offered in the request and
* rejects other requests with a [[ExpectedWebsocketRequestRejection]] or a [[UnsupportedWebsocketSubprotocolRejection]].
*/
def handleWebsocketMessagesForProtocol(handler: Flow[Message, Message, Any], subprotocol: String): Route =
handleWebsocketMessagesForOptionalProtocol(handler, Some(subprotocol))
/**
* Handles Websocket requests with the given handler and rejects other requests with a
* [[ExpectedWebsocketRequestRejection]].
*
* If the `subprotocol` parameter is None any Websocket request is accepted. If the `subprotocol` parameter is
* `Some(protocol)` a Websocket request is only accepted if the list of subprotocols supported by the client (as
* announced in the Websocket request) contains `protocol`. If the client did not offer the protocol in question
* the request is rejected with a [[UnsupportedWebsocketSubprotocolRejection]] rejection.
*
* To support several subprotocols you may chain several `handleWebsocketMessage` Routes.
*/
def handleWebsocketMessagesForOptionalProtocol(handler: Flow[Message, Message, Any], subprotocol: Option[String]): Route =
extractUpgradeToWebsocket { upgrade
if (subprotocol.forall(sub upgrade.requestedProtocols.exists(_ equalsIgnoreCase sub)))
complete(upgrade.handleMessages(handler, subprotocol))
else
reject(UnsupportedWebsocketSubprotocolRejection(subprotocol.get)) // None.forall == true
}
}