diff --git a/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/RouteTest.scala b/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/RouteTest.scala index 5ff2a08608..9cadec4b36 100644 --- a/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/RouteTest.scala +++ b/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/RouteTest.scala @@ -17,10 +17,10 @@ import akka.http.scaladsl.util.FastFuture import akka.http.scaladsl.server._ import akka.http.scaladsl.unmarshalling._ import akka.http.scaladsl.model._ -import headers.Host +import akka.http.scaladsl.model.headers.{ Upgrade, `Sec-WebSocket-Protocol`, Host } import FastFuture._ -trait RouteTest extends RequestBuilding with RouteTestResultComponent with MarshallingTestUtils { +trait RouteTest extends RequestBuilding with WSTestRequestBuilding with RouteTestResultComponent with MarshallingTestUtils { this: TestFrameworkInterface ⇒ /** Override to supply a custom ActorSystem */ @@ -88,6 +88,21 @@ trait RouteTest extends RequestBuilding with RouteTestResultComponent with Marsh if (r.size == 1) r.head else failTest("Expected a single rejection but got %s (%s)".format(r.size, r)) } + def isWebsocketUpgrade: Boolean = + status == StatusCodes.SwitchingProtocols && header[Upgrade].exists(_.hasWebsocket) + + /** + * Asserts that the received response is a Websocket upgrade response and the extracts + * the chosen subprotocol and passes it to the handler. + */ + def expectWebsocketUpgradeWithProtocol(body: String ⇒ Unit): Unit = { + if (!isWebsocketUpgrade) failTest("Response was no Websocket Upgrade response") + header[`Sec-WebSocket-Protocol`] match { + case Some(`Sec-WebSocket-Protocol`(Seq(protocol))) ⇒ body(protocol) + case _ ⇒ failTest("No Websocket protocol found in response.") + } + } + /** * A dummy that can be used as `~> runRoute` to run the route but without blocking for the result. * The result of the pipeline is the result that can later be checked with `check`. See the diff --git a/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/WSProbe.scala b/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/WSProbe.scala new file mode 100644 index 0000000000..92ebd379b3 --- /dev/null +++ b/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/WSProbe.scala @@ -0,0 +1,142 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.scaladsl.testkit + +import scala.concurrent.duration._ + +import akka.util.ByteString + +import akka.actor.ActorSystem + +import akka.stream.Materializer +import akka.stream.scaladsl.{ Keep, Source, Sink, Flow } +import akka.stream.testkit.{ TestPublisher, TestSubscriber } + +import akka.http.impl.util._ +import akka.http.scaladsl.model.ws.{ BinaryMessage, TextMessage, Message } + +/** + * A WSProbe is a probe that implements a `Flow[Message, Message, Unit]` for testing + * websocket code. + * + * Requesting elements is handled automatically. + */ +trait WSProbe { + def flow: Flow[Message, Message, Unit] + + /** + * Send the given messages out of the flow. + */ + def sendMessage(message: Message): Unit + + /** + * Send a text message containing the given string out of the flow. + */ + def sendMessage(text: String): Unit + + /** + * Send a binary message containing the given bytes out of the flow. + */ + def sendMessage(bytes: ByteString): Unit + + /** + * Complete the output side of the flow. + */ + def sendCompletion(): Unit + + /** + * Expect a message on the input side of the flow. + */ + def expectMessage(): Message + + /** + * Expect a text message on the input side of the flow and compares its payload with the given one. + * If the received message is streamed its contents are collected and then asserted against the given + * String. + */ + def expectMessage(text: String): Unit + + /** + * Expect a binary message on the input side of the flow and compares its payload with the given one. + * If the received message is streamed its contents are collected and then asserted against the given + * ByteString. + */ + def expectMessage(bytes: ByteString): Unit + + /** + * Expect no message on the input side of the flow. + */ + def expectNoMessage(): Unit + + /** + * Expect no message on the input side of the flow for the given maximum duration. + */ + def expectNoMessage(max: FiniteDuration): Unit + + /** + * Expect completion on the input side of the flow. + */ + def expectCompletion(): Unit + + /** + * The underlying probe for the ingoing side of this probe. Can be used if the methods + * on WSProbe don't allow fine enough control over the message flow. + */ + def inProbe: TestSubscriber.Probe[Message] + + /** + * The underlying probe for the ingoing side of this probe. Can be used if the methods + * on WSProbe don't allow fine enough control over the message flow. + */ + def outProbe: TestPublisher.Probe[Message] +} + +object WSProbe { + /** + * Creates a WSProbe to use in tests against websocket handlers. + * @param maxChunks The maximum number of chunks to collect for streamed messages. + * @param maxChunkCollectionMills The maximum time in milliseconds to collect chunks for streamed messages. + */ + def apply(maxChunks: Int = 1000, maxChunkCollectionMills: Long = 5000)(implicit system: ActorSystem, materializer: Materializer): WSProbe = + new WSProbe { + val subscriber = TestSubscriber.probe[Message]() + val publisher = TestPublisher.probe[Message]() + + def flow: Flow[Message, Message, Unit] = Flow.wrap(Sink(subscriber), Source(publisher))(Keep.none) + + def sendMessage(message: Message): Unit = publisher.sendNext(message) + def sendMessage(text: String): Unit = sendMessage(TextMessage(text)) + def sendMessage(bytes: ByteString): Unit = sendMessage(BinaryMessage(bytes)) + def sendCompletion(): Unit = publisher.sendComplete() + + def expectMessage(): Message = subscriber.requestNext() + def expectMessage(text: String): Unit = expectMessage() match { + case t: TextMessage ⇒ + val collectedMessage = collect(t.textStream)(_ + _) + assert(collectedMessage == text, s"""Expected TextMessage("$text") but got TextMessage("$collectedMessage")""") + case _ ⇒ throw new AssertionError(s"""Expected TextMessage("$text") but got BinaryMessage""") + } + def expectMessage(bytes: ByteString): Unit = expectMessage() match { + case t: BinaryMessage ⇒ + val collectedMessage = collect(t.dataStream)(_ ++ _) + assert(collectedMessage == bytes, s"""Expected BinaryMessage("$bytes") but got BinaryMessage("$collectedMessage")""") + case _ ⇒ throw new AssertionError(s"""Expected BinaryMessage("$bytes") but got TextMessage""") + } + + def expectNoMessage(): Unit = subscriber.expectNoMsg() + def expectNoMessage(max: FiniteDuration): Unit = subscriber.expectNoMsg(max) + + def expectCompletion(): Unit = subscriber.expectComplete() + + def inProbe: TestSubscriber.Probe[Message] = subscriber + def outProbe: TestPublisher.Probe[Message] = publisher + + private def collect[T](stream: Source[T, Any])(reduce: (T, T) ⇒ T): T = + stream.grouped(maxChunks) + .runWith(Sink.head) + .awaitResult(maxChunkCollectionMills.millis) + .reduce(reduce) + } +} \ No newline at end of file diff --git a/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/WSTestRequestBuilding.scala b/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/WSTestRequestBuilding.scala new file mode 100644 index 0000000000..d210435771 --- /dev/null +++ b/akka-http-testkit/src/main/scala/akka/http/scaladsl/testkit/WSTestRequestBuilding.scala @@ -0,0 +1,29 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.scaladsl.testkit + +import akka.http.impl.engine.ws.InternalCustomHeader +import akka.http.scaladsl.model.headers.{ UpgradeProtocol, Upgrade, `Sec-WebSocket-Protocol` } +import akka.http.scaladsl.model.{ StatusCodes, HttpResponse, HttpRequest, Uri } +import akka.http.scaladsl.model.ws.{ UpgradeToWebsocket, Message } +import akka.stream.scaladsl.Flow + +import scala.collection.immutable + +trait WSTestRequestBuilding { self: RouteTest ⇒ + def WS(uri: Uri, clientSideHandler: Flow[Message, Message, Any], subprotocols: Seq[String] = Nil)(): HttpRequest = + HttpRequest(uri = uri) + .addHeader(new InternalCustomHeader("UpgradeToWebsocketTestHeader") with UpgradeToWebsocket { + def requestedProtocols: immutable.Seq[String] = subprotocols.toList + + def handleMessages(handlerFlow: Flow[Message, Message, Any], subprotocol: Option[String]): HttpResponse = { + clientSideHandler.join(handlerFlow).run() + HttpResponse(StatusCodes.SwitchingProtocols, + headers = + Upgrade(UpgradeProtocol("websocket") :: Nil) :: + subprotocol.map(p ⇒ `Sec-WebSocket-Protocol`(p :: Nil)).toList) + } + }) +}