From 060ea707c9124fe5bc27453a2d08dcd2c93ca298 Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Thu, 20 Aug 2015 16:20:39 +0200 Subject: [PATCH] +htp #17726 add support for WS subprotocols to WebsocketDirectives Also, update WebsocketDirectiveSpecs to the new WS testing infrastructure. --- .../directives/WebsocketDirectivesSpec.scala | 102 ++++++++++++++---- .../akka/http/scaladsl/server/Rejection.scala | 6 ++ .../scaladsl/server/RejectionHandler.scala | 6 ++ .../directives/WebsocketDirectives.scala | 49 ++++++++- 4 files changed, 137 insertions(+), 26 deletions(-) diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/WebsocketDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/WebsocketDirectivesSpec.scala index b5dda7b12a..daa6f4b712 100644 --- a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/WebsocketDirectivesSpec.scala +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/WebsocketDirectivesSpec.scala @@ -4,42 +4,100 @@ package akka.http.scaladsl.server.directives -import scala.collection.immutable.Seq -import akka.http.impl.engine.ws.InternalCustomHeader -import akka.http.scaladsl.model.headers.{ UpgradeProtocol, Upgrade } -import akka.http.scaladsl.model.{ HttpRequest, StatusCodes, HttpResponse } -import akka.http.scaladsl.model.ws.{ Message, UpgradeToWebsocket } -import akka.http.scaladsl.server.{ Route, RoutingSpec } -import akka.stream.scaladsl.Flow +import akka.util.ByteString + +import akka.stream.OverflowStrategy +import akka.stream.scaladsl.{ Source, Sink, Flow } + +import akka.http.scaladsl.testkit.WSProbe + +import akka.http.scaladsl.model.headers.`Sec-WebSocket-Protocol` +import akka.http.scaladsl.model.StatusCodes +import akka.http.scaladsl.model.ws._ +import akka.http.scaladsl.server.{ UnsupportedWebsocketSubprotocolRejection, ExpectedWebsocketRequestRejection, Route, RoutingSpec } class WebsocketDirectivesSpec extends RoutingSpec { "the handleWebsocketMessages directive" should { "handle websocket requests" in { - Get("http://localhost/") ~> Upgrade(List(UpgradeProtocol("websocket"))) ~> - emulateHttpCore ~> Route.seal(handleWebsocketMessages(Flow[Message])) ~> + val wsClient = WSProbe() + + WS("http://localhost/", wsClient.flow) ~> websocketRoute ~> check { - status shouldEqual StatusCodes.SwitchingProtocols + isWebsocketUpgrade shouldEqual true + wsClient.sendMessage("Peter") + wsClient.expectMessage("Hello Peter!") + + wsClient.sendMessage(BinaryMessage(ByteString("abcdef"))) + // wsClient.expectNoMessage() // will be checked implicitly by next expectation + + wsClient.sendMessage("John") + wsClient.expectMessage("Hello John!") + + wsClient.sendCompletion() + wsClient.expectCompletion() } } + "choose subprotocol from offered ones" in { + val wsClient = WSProbe() + + WS("http://localhost/", wsClient.flow, List("other", "echo", "greeter")) ~> websocketMultipleProtocolRoute ~> + check { + expectWebsocketUpgradeWithProtocol { protocol ⇒ + protocol shouldEqual "echo" + + wsClient.sendMessage("Peter") + wsClient.expectMessage("Peter") + + wsClient.sendMessage(BinaryMessage(ByteString("abcdef"))) + wsClient.expectMessage(ByteString("abcdef")) + + wsClient.sendMessage("John") + wsClient.expectMessage("John") + + wsClient.sendCompletion() + wsClient.expectCompletion() + } + } + } + "reject websocket requests if no subprotocol matches" in { + WS("http://localhost/", Flow[Message], List("other")) ~> websocketMultipleProtocolRoute ~> check { + rejections.collect { + case UnsupportedWebsocketSubprotocolRejection(p) ⇒ p + }.toSet shouldEqual Set("greeter", "echo") + } + + WS("http://localhost/", Flow[Message], List("other")) ~> Route.seal(websocketMultipleProtocolRoute) ~> check { + status shouldEqual StatusCodes.BadRequest + responseAs[String] shouldEqual "None of the websocket subprotocols offered in the request are supported. Supported are 'echo','greeter'." + header[`Sec-WebSocket-Protocol`].get.protocols.toSet shouldEqual Set("greeter", "echo") + } + } "reject non-websocket requests" in { - Get("http://localhost/") ~> emulateHttpCore ~> Route.seal(handleWebsocketMessages(Flow[Message])) ~> check { + Get("http://localhost/") ~> websocketRoute ~> check { + rejection shouldEqual ExpectedWebsocketRequestRejection + } + + Get("http://localhost/") ~> Route.seal(websocketRoute) ~> check { status shouldEqual StatusCodes.BadRequest responseAs[String] shouldEqual "Expected Websocket Upgrade request" } } } - /** Only checks for upgrade header and then adds UpgradeToWebsocket mock header */ - def emulateHttpCore(req: HttpRequest): HttpRequest = - req.header[Upgrade] match { - case Some(upgrade) if upgrade.hasWebsocket ⇒ req.copy(headers = req.headers :+ upgradeToWebsocketHeaderMock) - case _ ⇒ req - } - def upgradeToWebsocketHeaderMock: UpgradeToWebsocket = - new InternalCustomHeader("UpgradeToWebsocketMock") with UpgradeToWebsocket { - def requestedProtocols: Seq[String] = Nil + def websocketRoute = handleWebsocketMessages(greeter) + def websocketMultipleProtocolRoute = + handleWebsocketMessagesForProtocol(echo, "echo") ~ + handleWebsocketMessagesForProtocol(greeter, "greeter") - def handleMessages(handlerFlow: Flow[Message, Message, Any], subprotocol: Option[String]): HttpResponse = - HttpResponse(StatusCodes.SwitchingProtocols) + def greeter: Flow[Message, Message, Any] = + Flow[Message].mapConcat { + case tm: TextMessage ⇒ TextMessage(Source.single("Hello ") ++ tm.textStream ++ Source.single("!")) :: Nil + case bm: BinaryMessage ⇒ // ignore binary messages + bm.dataStream.runWith(Sink.ignore) + Nil } + + def echo: Flow[Message, Message, Any] = + Flow[Message] + .buffer(1, OverflowStrategy.backpressure) // needed because a noop flow hasn't any buffer that would start processing } diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/Rejection.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/Rejection.scala index f3e8f46cc3..95db843ece 100644 --- a/akka-http/src/main/scala/akka/http/scaladsl/server/Rejection.scala +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/Rejection.scala @@ -168,6 +168,12 @@ case class MissingCookieRejection(cookieName: String) extends Rejection */ case object ExpectedWebsocketRequestRejection extends Rejection +/** + * Rejection created when a websocket request was not handled because none of the given subprotocols + * was supported. + */ +case class UnsupportedWebsocketSubprotocolRejection(supportedProtocol: String) extends Rejection + /** * Rejection created by the `validation` directive as well as for `IllegalArgumentExceptions` * thrown by domain model constructors (e.g. via `require`). diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/RejectionHandler.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/RejectionHandler.scala index 2019949b7b..dc3d88f871 100644 --- a/akka-http/src/main/scala/akka/http/scaladsl/server/RejectionHandler.scala +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/RejectionHandler.scala @@ -205,6 +205,12 @@ object RejectionHandler { complete(BadRequest, "The request's Content-Encoding is not supported. Expected:\n" + supported) } .handle { case ExpectedWebsocketRequestRejection ⇒ complete(BadRequest, "Expected Websocket Upgrade request") } + .handleAll[UnsupportedWebsocketSubprotocolRejection] { rejections ⇒ + val supported = rejections.map(_.supportedProtocol) + complete(HttpResponse(BadRequest, + entity = s"None of the websocket subprotocols offered in the request are supported. Supported are ${supported.map("'" + _ + "'").mkString(",")}.", + headers = `Sec-WebSocket-Protocol`(supported) :: Nil)) + } .handle { case ValidationRejection(msg, _) ⇒ complete(BadRequest, msg) } .handle { case x ⇒ sys.error("Unhandled rejection: " + x) } .handleNotFound { complete(NotFound, "The requested resource could not be found.") } diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/WebsocketDirectives.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/WebsocketDirectives.scala index 6117947b44..bf527852d2 100644 --- a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/WebsocketDirectives.scala +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/WebsocketDirectives.scala @@ -5,20 +5,61 @@ package akka.http.scaladsl.server package directives +import scala.collection.immutable + import akka.http.scaladsl.model.ws.{ UpgradeToWebsocket, Message } import akka.stream.scaladsl.Flow trait WebsocketDirectives { import RouteDirectives._ import HeaderDirectives._ + import BasicDirectives._ /** - * Handles websocket requests with the given handler and rejects other requests with a + * Extract the [[UpgradeToWebsocket]] header if existent. Rejects with an [[ExpectedWebsocketRequestRejection]], otherwise. + */ + def extractUpgradeToWebsocket: Directive1[UpgradeToWebsocket] = + optionalHeaderValueByType[UpgradeToWebsocket]().flatMap { + case Some(upgrade) ⇒ provide(upgrade) + case None ⇒ reject(ExpectedWebsocketRequestRejection) + } + + /** + * Extract the list of Websocket subprotocols as offered by the client in the [[Sec-Websocket-Protocol]] header if + * this is a Websocket request. Rejects with an [[ExpectedWebsocketRequestRejection]], otherwise. + */ + def extractOfferedWsProtocols: Directive1[immutable.Seq[String]] = extractUpgradeToWebsocket.map(_.requestedProtocols) + + /** + * Handles Websocket requests with the given handler and rejects other requests with a * [[ExpectedWebsocketRequestRejection]]. */ def handleWebsocketMessages(handler: Flow[Message, Message, Any]): Route = - optionalHeaderValueByType[UpgradeToWebsocket]() { - case Some(upgrade) ⇒ complete(upgrade.handleMessages(handler)) - case None ⇒ reject(ExpectedWebsocketRequestRejection) + handleWebsocketMessagesForOptionalProtocol(handler, None) + + /** + * Handles Websocket requests with the given handler if the given subprotocol is offered in the request and + * rejects other requests with a [[ExpectedWebsocketRequestRejection]] or a [[UnsupportedWebsocketSubprotocolRejection]]. + */ + def handleWebsocketMessagesForProtocol(handler: Flow[Message, Message, Any], subprotocol: String): Route = + handleWebsocketMessagesForOptionalProtocol(handler, Some(subprotocol)) + + /** + * Handles Websocket requests with the given handler and rejects other requests with a + * [[ExpectedWebsocketRequestRejection]]. + * + * If the `subprotocol` parameter is None any Websocket request is accepted. If the `subprotocol` parameter is + * `Some(protocol)` a Websocket request is only accepted if the list of subprotocols supported by the client (as + * announced in the Websocket request) contains `protocol`. If the client did not offer the protocol in question + * the request is rejected with a [[UnsupportedWebsocketSubprotocolRejection]] rejection. + * + * To support several subprotocols you may chain several `handleWebsocketMessage` Routes. + */ + def handleWebsocketMessagesForOptionalProtocol(handler: Flow[Message, Message, Any], subprotocol: Option[String]): Route = + extractUpgradeToWebsocket { upgrade ⇒ + if (subprotocol.forall(sub ⇒ upgrade.requestedProtocols.exists(_ equalsIgnoreCase sub))) + complete(upgrade.handleMessages(handler, subprotocol)) + else + reject(UnsupportedWebsocketSubprotocolRejection(subprotocol.get)) // None.forall == true } }