diff --git a/akka-http-core/src/main/scala/akka/http/engine/ws/Handshake.scala b/akka-http-core/src/main/scala/akka/http/engine/ws/Handshake.scala index e57d07ffde..b6113d7ef5 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/ws/Handshake.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/ws/Handshake.scala @@ -11,6 +11,7 @@ import akka.parboiled2.util.Base64 import akka.stream.FlowMaterializer import akka.stream.scaladsl.Flow +import scala.collection.immutable.Seq import scala.reflect.ClassTag /** @@ -70,6 +71,7 @@ private[http] object Handshake { val version = find[`Sec-WebSocket-Version`] val origin = find[Origin] val protocol = find[`Sec-WebSocket-Protocol`] + val supportedProtocols = protocol.toList.flatMap(_.protocols) val extensions = find[`Sec-WebSocket-Extensions`] def isValidKey(key: String): Boolean = Base64.rfc2045().decode(key).length == 16 @@ -80,8 +82,13 @@ private[http] object Handshake { key.exists(k ⇒ isValidKey(k.key))) { val header = new UpgradeToWebsocketLowLevel { - def handleFrames(handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit mat: FlowMaterializer): HttpResponse = - buildResponse(key.get, handlerFlow) + def requestedProtocols: Seq[String] = supportedProtocols + + def handleFrames(handlerFlow: Flow[FrameEvent, FrameEvent, Any], subprotocol: Option[String])(implicit mat: FlowMaterializer): HttpResponse = { + require(subprotocol.forall(chosen ⇒ supportedProtocols.contains(chosen)), + s"Tried to choose invalid subprotocol '$subprotocol' which wasn't offered by the client: [${requestedProtocols.mkString(", ")}]") + buildResponse(key.get, handlerFlow, subprotocol) + } } Some(header) } else None @@ -106,12 +113,13 @@ private[http] object Handshake { concatenated value to obtain a 20-byte value and base64- encoding (see Section 4 of [RFC4648]) this 20-byte hash. */ - def buildResponse(key: `Sec-WebSocket-Key`, handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit mat: FlowMaterializer): HttpResponse = + def buildResponse(key: `Sec-WebSocket-Key`, handlerFlow: Flow[FrameEvent, FrameEvent, Any], subprotocol: Option[String])(implicit mat: FlowMaterializer): HttpResponse = HttpResponse( StatusCodes.SwitchingProtocols, - List( - Upgrade(List(UpgradeProtocol("websocket"))), - Connection(List("upgrade")), - `Sec-WebSocket-Accept`.forKey(key), - UpgradeToWebsocketResponseHeader(handlerFlow))) + subprotocol.map(p ⇒ `Sec-WebSocket-Protocol`(Seq(p))).toList ::: + List( + Upgrade(List(UpgradeProtocol("websocket"))), + Connection(List("upgrade")), + `Sec-WebSocket-Accept`.forKey(key), + UpgradeToWebsocketResponseHeader(handlerFlow))) } diff --git a/akka-http-core/src/main/scala/akka/http/engine/ws/UpgradeToWebsocketLowLevel.scala b/akka-http-core/src/main/scala/akka/http/engine/ws/UpgradeToWebsocketLowLevel.scala index 1c40e5d489..1408d41be6 100644 --- a/akka-http-core/src/main/scala/akka/http/engine/ws/UpgradeToWebsocketLowLevel.scala +++ b/akka-http-core/src/main/scala/akka/http/engine/ws/UpgradeToWebsocketLowLevel.scala @@ -25,8 +25,8 @@ private[http] abstract class UpgradeToWebsocketLowLevel extends InternalCustomHe * * INTERNAL API (for now) */ - private[http] def handleFrames(handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit mat: FlowMaterializer): HttpResponse + private[http] def handleFrames(handlerFlow: Flow[FrameEvent, FrameEvent, Any], subprotocol: Option[String] = None)(implicit mat: FlowMaterializer): HttpResponse - override def handleMessages(handlerFlow: Flow[Message, Message, Any])(implicit mat: FlowMaterializer): HttpResponse = - handleFrames(Websocket.handleMessages(handlerFlow)) + override def handleMessages(handlerFlow: Flow[Message, Message, Any], subprotocol: Option[String] = None)(implicit mat: FlowMaterializer): HttpResponse = + handleFrames(Websocket.handleMessages(handlerFlow), subprotocol) } diff --git a/akka-http-core/src/main/scala/akka/http/model/ws/UpgradeToWebsocket.scala b/akka-http-core/src/main/scala/akka/http/model/ws/UpgradeToWebsocket.scala index c63b4ea2a4..7325d46835 100644 --- a/akka-http-core/src/main/scala/akka/http/model/ws/UpgradeToWebsocket.scala +++ b/akka-http-core/src/main/scala/akka/http/model/ws/UpgradeToWebsocket.scala @@ -4,6 +4,7 @@ package akka.http.model.ws +import scala.collection.immutable import akka.stream.FlowMaterializer import akka.stream.scaladsl.Flow @@ -13,16 +14,22 @@ import akka.http.model.{ HttpHeader, HttpResponse } * A custom header that will be added to an Websocket upgrade HttpRequest that * enables a request handler to upgrade this connection to a Websocket connection and * registers a Websocket handler. - * - * FIXME: needs to be able to choose subprotocols as possibly agreed on in the websocket handshake */ trait UpgradeToWebsocket extends HttpHeader { + /** + * A sequence of protocols the client accepts. + * + * See http://tools.ietf.org/html/rfc6455#section-1.9 + */ + def requestedProtocols: immutable.Seq[String] + /** * The high-level interface to create a Websocket server based on "messages". * * Returns a response to return in a request handler that will signal the * low-level HTTP implementation to upgrade the connection to Websocket and - * use the supplied handler to handle incoming Websocket messages. + * use the supplied handler to handle incoming Websocket messages. Optionally, + * a subprotocol out of the ones requested by the client can be chosen. */ - def handleMessages(handlerFlow: Flow[Message, Message, Any])(implicit mat: FlowMaterializer): HttpResponse + def handleMessages(handlerFlow: Flow[Message, Message, Any], subprotocol: Option[String] = None)(implicit mat: FlowMaterializer): HttpResponse } diff --git a/akka-http-tests/src/test/scala/akka/http/server/directives/WebsocketDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/server/directives/WebsocketDirectivesSpec.scala index 2de7c9ca95..e05c6ab17e 100644 --- a/akka-http-tests/src/test/scala/akka/http/server/directives/WebsocketDirectivesSpec.scala +++ b/akka-http-tests/src/test/scala/akka/http/server/directives/WebsocketDirectivesSpec.scala @@ -14,6 +14,8 @@ import akka.http.util.Rendering import akka.stream.FlowMaterializer import akka.stream.scaladsl.Flow +import scala.collection.immutable.Seq + class WebsocketDirectivesSpec extends RoutingSpec { "the handleWebsocketMessages directive" should { "handle websocket requests" in { @@ -39,7 +41,9 @@ class WebsocketDirectivesSpec extends RoutingSpec { } def upgradeToWebsocketHeaderMock: UpgradeToWebsocket = new InternalCustomHeader("UpgradeToWebsocketMock") with UpgradeToWebsocket { - def handleMessages(handlerFlow: Flow[Message, Message, Any])(implicit mat: FlowMaterializer): HttpResponse = + def requestedProtocols: Seq[String] = Nil + + def handleMessages(handlerFlow: Flow[Message, Message, Any], subprotocol: Option[String])(implicit mat: FlowMaterializer): HttpResponse = HttpResponse(StatusCodes.SwitchingProtocols) } }