+htp #17726 add support for WS subprotocols to WebsocketDirectives
Also, update WebsocketDirectiveSpecs to the new WS testing infrastructure.
This commit is contained in:
parent
12edec7073
commit
060ea707c9
4 changed files with 137 additions and 26 deletions
|
|
@ -4,42 +4,100 @@
|
|||
|
||||
package akka.http.scaladsl.server.directives
|
||||
|
||||
import scala.collection.immutable.Seq
|
||||
import akka.http.impl.engine.ws.InternalCustomHeader
|
||||
import akka.http.scaladsl.model.headers.{ UpgradeProtocol, Upgrade }
|
||||
import akka.http.scaladsl.model.{ HttpRequest, StatusCodes, HttpResponse }
|
||||
import akka.http.scaladsl.model.ws.{ Message, UpgradeToWebsocket }
|
||||
import akka.http.scaladsl.server.{ Route, RoutingSpec }
|
||||
import akka.stream.scaladsl.Flow
|
||||
import akka.util.ByteString
|
||||
|
||||
import akka.stream.OverflowStrategy
|
||||
import akka.stream.scaladsl.{ Source, Sink, Flow }
|
||||
|
||||
import akka.http.scaladsl.testkit.WSProbe
|
||||
|
||||
import akka.http.scaladsl.model.headers.`Sec-WebSocket-Protocol`
|
||||
import akka.http.scaladsl.model.StatusCodes
|
||||
import akka.http.scaladsl.model.ws._
|
||||
import akka.http.scaladsl.server.{ UnsupportedWebsocketSubprotocolRejection, ExpectedWebsocketRequestRejection, Route, RoutingSpec }
|
||||
|
||||
class WebsocketDirectivesSpec extends RoutingSpec {
|
||||
"the handleWebsocketMessages directive" should {
|
||||
"handle websocket requests" in {
|
||||
Get("http://localhost/") ~> Upgrade(List(UpgradeProtocol("websocket"))) ~>
|
||||
emulateHttpCore ~> Route.seal(handleWebsocketMessages(Flow[Message])) ~>
|
||||
val wsClient = WSProbe()
|
||||
|
||||
WS("http://localhost/", wsClient.flow) ~> websocketRoute ~>
|
||||
check {
|
||||
status shouldEqual StatusCodes.SwitchingProtocols
|
||||
isWebsocketUpgrade shouldEqual true
|
||||
wsClient.sendMessage("Peter")
|
||||
wsClient.expectMessage("Hello Peter!")
|
||||
|
||||
wsClient.sendMessage(BinaryMessage(ByteString("abcdef")))
|
||||
// wsClient.expectNoMessage() // will be checked implicitly by next expectation
|
||||
|
||||
wsClient.sendMessage("John")
|
||||
wsClient.expectMessage("Hello John!")
|
||||
|
||||
wsClient.sendCompletion()
|
||||
wsClient.expectCompletion()
|
||||
}
|
||||
}
|
||||
"choose subprotocol from offered ones" in {
|
||||
val wsClient = WSProbe()
|
||||
|
||||
WS("http://localhost/", wsClient.flow, List("other", "echo", "greeter")) ~> websocketMultipleProtocolRoute ~>
|
||||
check {
|
||||
expectWebsocketUpgradeWithProtocol { protocol ⇒
|
||||
protocol shouldEqual "echo"
|
||||
|
||||
wsClient.sendMessage("Peter")
|
||||
wsClient.expectMessage("Peter")
|
||||
|
||||
wsClient.sendMessage(BinaryMessage(ByteString("abcdef")))
|
||||
wsClient.expectMessage(ByteString("abcdef"))
|
||||
|
||||
wsClient.sendMessage("John")
|
||||
wsClient.expectMessage("John")
|
||||
|
||||
wsClient.sendCompletion()
|
||||
wsClient.expectCompletion()
|
||||
}
|
||||
}
|
||||
}
|
||||
"reject websocket requests if no subprotocol matches" in {
|
||||
WS("http://localhost/", Flow[Message], List("other")) ~> websocketMultipleProtocolRoute ~> check {
|
||||
rejections.collect {
|
||||
case UnsupportedWebsocketSubprotocolRejection(p) ⇒ p
|
||||
}.toSet shouldEqual Set("greeter", "echo")
|
||||
}
|
||||
|
||||
WS("http://localhost/", Flow[Message], List("other")) ~> Route.seal(websocketMultipleProtocolRoute) ~> check {
|
||||
status shouldEqual StatusCodes.BadRequest
|
||||
responseAs[String] shouldEqual "None of the websocket subprotocols offered in the request are supported. Supported are 'echo','greeter'."
|
||||
header[`Sec-WebSocket-Protocol`].get.protocols.toSet shouldEqual Set("greeter", "echo")
|
||||
}
|
||||
}
|
||||
"reject non-websocket requests" in {
|
||||
Get("http://localhost/") ~> emulateHttpCore ~> Route.seal(handleWebsocketMessages(Flow[Message])) ~> check {
|
||||
Get("http://localhost/") ~> websocketRoute ~> check {
|
||||
rejection shouldEqual ExpectedWebsocketRequestRejection
|
||||
}
|
||||
|
||||
Get("http://localhost/") ~> Route.seal(websocketRoute) ~> check {
|
||||
status shouldEqual StatusCodes.BadRequest
|
||||
responseAs[String] shouldEqual "Expected Websocket Upgrade request"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Only checks for upgrade header and then adds UpgradeToWebsocket mock header */
|
||||
def emulateHttpCore(req: HttpRequest): HttpRequest =
|
||||
req.header[Upgrade] match {
|
||||
case Some(upgrade) if upgrade.hasWebsocket ⇒ req.copy(headers = req.headers :+ upgradeToWebsocketHeaderMock)
|
||||
case _ ⇒ req
|
||||
}
|
||||
def upgradeToWebsocketHeaderMock: UpgradeToWebsocket =
|
||||
new InternalCustomHeader("UpgradeToWebsocketMock") with UpgradeToWebsocket {
|
||||
def requestedProtocols: Seq[String] = Nil
|
||||
def websocketRoute = handleWebsocketMessages(greeter)
|
||||
def websocketMultipleProtocolRoute =
|
||||
handleWebsocketMessagesForProtocol(echo, "echo") ~
|
||||
handleWebsocketMessagesForProtocol(greeter, "greeter")
|
||||
|
||||
def handleMessages(handlerFlow: Flow[Message, Message, Any], subprotocol: Option[String]): HttpResponse =
|
||||
HttpResponse(StatusCodes.SwitchingProtocols)
|
||||
def greeter: Flow[Message, Message, Any] =
|
||||
Flow[Message].mapConcat {
|
||||
case tm: TextMessage ⇒ TextMessage(Source.single("Hello ") ++ tm.textStream ++ Source.single("!")) :: Nil
|
||||
case bm: BinaryMessage ⇒ // ignore binary messages
|
||||
bm.dataStream.runWith(Sink.ignore)
|
||||
Nil
|
||||
}
|
||||
|
||||
def echo: Flow[Message, Message, Any] =
|
||||
Flow[Message]
|
||||
.buffer(1, OverflowStrategy.backpressure) // needed because a noop flow hasn't any buffer that would start processing
|
||||
}
|
||||
|
|
|
|||
|
|
@ -168,6 +168,12 @@ case class MissingCookieRejection(cookieName: String) extends Rejection
|
|||
*/
|
||||
case object ExpectedWebsocketRequestRejection extends Rejection
|
||||
|
||||
/**
|
||||
* Rejection created when a websocket request was not handled because none of the given subprotocols
|
||||
* was supported.
|
||||
*/
|
||||
case class UnsupportedWebsocketSubprotocolRejection(supportedProtocol: String) extends Rejection
|
||||
|
||||
/**
|
||||
* Rejection created by the `validation` directive as well as for `IllegalArgumentExceptions`
|
||||
* thrown by domain model constructors (e.g. via `require`).
|
||||
|
|
|
|||
|
|
@ -205,6 +205,12 @@ object RejectionHandler {
|
|||
complete(BadRequest, "The request's Content-Encoding is not supported. Expected:\n" + supported)
|
||||
}
|
||||
.handle { case ExpectedWebsocketRequestRejection ⇒ complete(BadRequest, "Expected Websocket Upgrade request") }
|
||||
.handleAll[UnsupportedWebsocketSubprotocolRejection] { rejections ⇒
|
||||
val supported = rejections.map(_.supportedProtocol)
|
||||
complete(HttpResponse(BadRequest,
|
||||
entity = s"None of the websocket subprotocols offered in the request are supported. Supported are ${supported.map("'" + _ + "'").mkString(",")}.",
|
||||
headers = `Sec-WebSocket-Protocol`(supported) :: Nil))
|
||||
}
|
||||
.handle { case ValidationRejection(msg, _) ⇒ complete(BadRequest, msg) }
|
||||
.handle { case x ⇒ sys.error("Unhandled rejection: " + x) }
|
||||
.handleNotFound { complete(NotFound, "The requested resource could not be found.") }
|
||||
|
|
|
|||
|
|
@ -5,20 +5,61 @@
|
|||
package akka.http.scaladsl.server
|
||||
package directives
|
||||
|
||||
import scala.collection.immutable
|
||||
|
||||
import akka.http.scaladsl.model.ws.{ UpgradeToWebsocket, Message }
|
||||
import akka.stream.scaladsl.Flow
|
||||
|
||||
trait WebsocketDirectives {
|
||||
import RouteDirectives._
|
||||
import HeaderDirectives._
|
||||
import BasicDirectives._
|
||||
|
||||
/**
|
||||
* Handles websocket requests with the given handler and rejects other requests with a
|
||||
* Extract the [[UpgradeToWebsocket]] header if existent. Rejects with an [[ExpectedWebsocketRequestRejection]], otherwise.
|
||||
*/
|
||||
def extractUpgradeToWebsocket: Directive1[UpgradeToWebsocket] =
|
||||
optionalHeaderValueByType[UpgradeToWebsocket]().flatMap {
|
||||
case Some(upgrade) ⇒ provide(upgrade)
|
||||
case None ⇒ reject(ExpectedWebsocketRequestRejection)
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract the list of Websocket subprotocols as offered by the client in the [[Sec-Websocket-Protocol]] header if
|
||||
* this is a Websocket request. Rejects with an [[ExpectedWebsocketRequestRejection]], otherwise.
|
||||
*/
|
||||
def extractOfferedWsProtocols: Directive1[immutable.Seq[String]] = extractUpgradeToWebsocket.map(_.requestedProtocols)
|
||||
|
||||
/**
|
||||
* Handles Websocket requests with the given handler and rejects other requests with a
|
||||
* [[ExpectedWebsocketRequestRejection]].
|
||||
*/
|
||||
def handleWebsocketMessages(handler: Flow[Message, Message, Any]): Route =
|
||||
optionalHeaderValueByType[UpgradeToWebsocket]() {
|
||||
case Some(upgrade) ⇒ complete(upgrade.handleMessages(handler))
|
||||
case None ⇒ reject(ExpectedWebsocketRequestRejection)
|
||||
handleWebsocketMessagesForOptionalProtocol(handler, None)
|
||||
|
||||
/**
|
||||
* Handles Websocket requests with the given handler if the given subprotocol is offered in the request and
|
||||
* rejects other requests with a [[ExpectedWebsocketRequestRejection]] or a [[UnsupportedWebsocketSubprotocolRejection]].
|
||||
*/
|
||||
def handleWebsocketMessagesForProtocol(handler: Flow[Message, Message, Any], subprotocol: String): Route =
|
||||
handleWebsocketMessagesForOptionalProtocol(handler, Some(subprotocol))
|
||||
|
||||
/**
|
||||
* Handles Websocket requests with the given handler and rejects other requests with a
|
||||
* [[ExpectedWebsocketRequestRejection]].
|
||||
*
|
||||
* If the `subprotocol` parameter is None any Websocket request is accepted. If the `subprotocol` parameter is
|
||||
* `Some(protocol)` a Websocket request is only accepted if the list of subprotocols supported by the client (as
|
||||
* announced in the Websocket request) contains `protocol`. If the client did not offer the protocol in question
|
||||
* the request is rejected with a [[UnsupportedWebsocketSubprotocolRejection]] rejection.
|
||||
*
|
||||
* To support several subprotocols you may chain several `handleWebsocketMessage` Routes.
|
||||
*/
|
||||
def handleWebsocketMessagesForOptionalProtocol(handler: Flow[Message, Message, Any], subprotocol: Option[String]): Route =
|
||||
extractUpgradeToWebsocket { upgrade ⇒
|
||||
if (subprotocol.forall(sub ⇒ upgrade.requestedProtocols.exists(_ equalsIgnoreCase sub)))
|
||||
complete(upgrade.handleMessages(handler, subprotocol))
|
||||
else
|
||||
reject(UnsupportedWebsocketSubprotocolRejection(subprotocol.get)) // None.forall == true
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue