diff --git a/akka-http-tests/src/test/scala/akka/http/server/directives/WebsocketDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/server/directives/WebsocketDirectivesSpec.scala new file mode 100644 index 0000000000..2de7c9ca95 --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/server/directives/WebsocketDirectivesSpec.scala @@ -0,0 +1,45 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.server.directives + +import akka.http.engine.ws.InternalCustomHeader +import akka.http.model +import akka.http.model.headers.{ Connection, UpgradeProtocol, Upgrade } +import akka.http.model.{ HttpRequest, StatusCodes, HttpResponse } +import akka.http.model.ws.{ Message, UpgradeToWebsocket } +import akka.http.server.{ Route, RoutingSpec } +import akka.http.util.Rendering +import akka.stream.FlowMaterializer +import akka.stream.scaladsl.Flow + +class WebsocketDirectivesSpec extends RoutingSpec { + "the handleWebsocketMessages directive" should { + "handle websocket requests" in { + Get("http://localhost/") ~> Upgrade(List(UpgradeProtocol("websocket"))) ~> + emulateHttpCore ~> Route.seal(handleWebsocketMessages(Flow[Message])) ~> + check { + status shouldEqual StatusCodes.SwitchingProtocols + } + } + "reject non-websocket requests" in { + Get("http://localhost/") ~> emulateHttpCore ~> Route.seal(handleWebsocketMessages(Flow[Message])) ~> check { + status shouldEqual StatusCodes.BadRequest + responseAs[String] shouldEqual "Expected Websocket Upgrade request" + } + } + } + + /** Only checks for upgrade header and then adds UpgradeToWebsocket mock header */ + def emulateHttpCore(req: HttpRequest): HttpRequest = + req.header[Upgrade] match { + case Some(upgrade) if upgrade.hasWebsocket ⇒ req.copy(headers = req.headers :+ upgradeToWebsocketHeaderMock) + case _ ⇒ req + } + def upgradeToWebsocketHeaderMock: UpgradeToWebsocket = + new InternalCustomHeader("UpgradeToWebsocketMock") with UpgradeToWebsocket { + def handleMessages(handlerFlow: Flow[Message, Message, Any])(implicit mat: FlowMaterializer): HttpResponse = + HttpResponse(StatusCodes.SwitchingProtocols) + } +} diff --git a/akka-http/src/main/scala/akka/http/server/Directives.scala b/akka-http/src/main/scala/akka/http/server/Directives.scala index 5b877d6025..24064c6c0c 100644 --- a/akka-http/src/main/scala/akka/http/server/Directives.scala +++ b/akka-http/src/main/scala/akka/http/server/Directives.scala @@ -32,5 +32,6 @@ trait Directives extends RouteConcatenation with RouteDirectives with SchemeDirectives with SecurityDirectives + with WebsocketDirectives object Directives extends Directives diff --git a/akka-http/src/main/scala/akka/http/server/Rejection.scala b/akka-http/src/main/scala/akka/http/server/Rejection.scala index 636d1854e2..ed89828dc7 100644 --- a/akka-http/src/main/scala/akka/http/server/Rejection.scala +++ b/akka-http/src/main/scala/akka/http/server/Rejection.scala @@ -163,6 +163,11 @@ case object AuthorizationFailedRejection extends Rejection */ case class MissingCookieRejection(cookieName: String) extends Rejection +/** + * Rejection created when a websocket request was expected but none was found. + */ +case object ExpectedWebsocketRequestRejection extends Rejection + /** * Rejection created by the `validation` directive as well as for `IllegalArgumentExceptions` * thrown by domain model constructors (e.g. via `require`). diff --git a/akka-http/src/main/scala/akka/http/server/RejectionHandler.scala b/akka-http/src/main/scala/akka/http/server/RejectionHandler.scala index 828df9b73e..a467ce6b72 100644 --- a/akka-http/src/main/scala/akka/http/server/RejectionHandler.scala +++ b/akka-http/src/main/scala/akka/http/server/RejectionHandler.scala @@ -206,6 +206,7 @@ object RejectionHandler { val supported = rejections.map(_.supported.value).mkString(" or ") complete(BadRequest, "The request's Content-Encoding is not supported. Expected:\n" + supported) } + .handle { case ExpectedWebsocketRequestRejection ⇒ complete(BadRequest, "Expected Websocket Upgrade request") } .handle { case ValidationRejection(msg, _) ⇒ complete(BadRequest, msg) } .handle { case x ⇒ sys.error("Unhandled rejection: " + x) } .handleNotFound { complete(NotFound, "The requested resource could not be found.") } diff --git a/akka-http/src/main/scala/akka/http/server/directives/WebsocketDirectives.scala b/akka-http/src/main/scala/akka/http/server/directives/WebsocketDirectives.scala new file mode 100644 index 0000000000..7871a0f39e --- /dev/null +++ b/akka-http/src/main/scala/akka/http/server/directives/WebsocketDirectives.scala @@ -0,0 +1,27 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.server +package directives + +import akka.http.model.ws.{ UpgradeToWebsocket, Message } +import akka.stream.scaladsl.Flow + +trait WebsocketDirectives { + import BasicDirectives._ + import RouteDirectives._ + import HeaderDirectives._ + + /** + * Handles websocket requests with the given handler and rejects other requests with a + * [[ExpectedWebsocketRequestRejection]]. + */ + def handleWebsocketMessages(handler: Flow[Message, Message, Any]): Route = + extractFlowMaterializer { implicit mat ⇒ + optionalHeaderValueByType[UpgradeToWebsocket]() { + case Some(upgrade) ⇒ complete(upgrade.handleMessages(handler)) + case None ⇒ reject(ExpectedWebsocketRequestRejection) + } + } +}