diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/HttpResponseRendererFactory.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/HttpResponseRendererFactory.scala index e59f5d9568..4101f87366 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/HttpResponseRendererFactory.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/HttpResponseRendererFactory.scala @@ -4,13 +4,12 @@ package akka.http.impl.engine.rendering -import akka.http.impl.engine.ws.{ WebsocketSwitch, UpgradeToWebsocketResponseHeader, Handshake } +import akka.http.impl.engine.ws.{ FrameEvent, UpgradeToWebsocketResponseHeader } import scala.annotation.tailrec import akka.event.LoggingAdapter import akka.util.ByteString -import akka.stream.OperationAttributes._ -import akka.stream.scaladsl.Source +import akka.stream.scaladsl.{ Flow, Source } import akka.stream.stage._ import akka.http.scaladsl.model._ import akka.http.impl.util._ @@ -23,8 +22,7 @@ import headers._ */ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Server], responseHeaderSizeHint: Int, - log: LoggingAdapter, - websocketSwitch: Option[WebsocketSwitch] = None) { + log: LoggingAdapter) { private val renderDefaultServerHeader: Rendering ⇒ Unit = serverHeader match { @@ -54,14 +52,17 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser def newRenderer: HttpResponseRenderer = new HttpResponseRenderer - final class HttpResponseRenderer extends PushStage[ResponseRenderingContext, Source[ByteString, Any]] { + final class HttpResponseRenderer extends PushStage[ResponseRenderingContext, Source[ResponseRenderingOutput, Any]] { - private[this] var close = false // signals whether the connection is to be closed after the current response + private[this] var closeMode: CloseMode = DontClose // signals what to do after the current response + private[this] def close: Boolean = closeMode != DontClose + private[this] def closeIf(cond: Boolean): Unit = + if (cond) closeMode = CloseConnection // need this for testing private[http] def isComplete = close - override def onPush(ctx: ResponseRenderingContext, opCtx: Context[Source[ByteString, Any]]): SyncDirective = { + override def onPush(ctx: ResponseRenderingContext, opCtx: Context[Source[ResponseRenderingOutput, Any]]): SyncDirective = { val r = new ByteStringRendering(responseHeaderSizeHint) import ctx.response._ @@ -133,7 +134,7 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser if (!dateSeen) r ~~ dateHeader // Do we close the connection after this response? - close = + closeIf { // if we are prohibited to keep-alive by the spec alwaysClose || // if the client wants to close and we don't override @@ -143,6 +144,7 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser case `HTTP/1.1` ⇒ (connHeader ne null) && connHeader.hasClose case `HTTP/1.0` ⇒ if (connHeader eq null) ctx.requestProtocol == `HTTP/1.1` else !connHeader.hasKeepAlive }) + } // Do we render an explicit Connection header? val renderConnectionHeader = @@ -152,10 +154,11 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser if (renderConnectionHeader) r ~~ Connection ~~ (if (close) CloseBytes else KeepAliveBytes) ~~ CrLf - else if (connHeader != null && connHeader.hasUpgrade && websocketSwitch.isDefined) { + else if (connHeader != null && connHeader.hasUpgrade) { r ~~ connHeader ~~ CrLf - val websocketHeader = headers.collectFirst { case u: UpgradeToWebsocketResponseHeader ⇒ u } - websocketHeader.foreach(header ⇒ websocketSwitch.get.switchToWebsocket(header.handlerFlow)(header.mat)) + headers + .collectFirst { case u: UpgradeToWebsocketResponseHeader ⇒ u } + .foreach { header ⇒ closeMode = SwitchToWebsocket(header.handlerFlow) } } if (mustRenderTransferEncodingChunkedHeader && !transferEncodingSeen) r ~~ `Transfer-Encoding` ~~ ChunkedBytes ~~ CrLf @@ -164,17 +167,24 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser def renderContentLengthHeader(contentLength: Long) = if (status.allowsEntity) r ~~ `Content-Length` ~~ contentLength ~~ CrLf else r - def byteStrings(entityBytes: ⇒ Source[ByteString, Any]): Source[ByteString, Any] = - renderByteStrings(r, entityBytes, skipEntity = noEntity) + def byteStrings(entityBytes: ⇒ Source[ByteString, Any]): Source[ResponseRenderingOutput, Any] = + renderByteStrings(r, entityBytes, skipEntity = noEntity).map(ResponseRenderingOutput.HttpData(_)) - def completeResponseRendering(entity: ResponseEntity): Source[ByteString, Any] = + def completeResponseRendering(entity: ResponseEntity): Source[ResponseRenderingOutput, Any] = entity match { case HttpEntity.Strict(_, data) ⇒ renderHeaders(headers.toList) renderEntityContentType(r, entity) renderContentLengthHeader(data.length) ~~ CrLf - val entityBytes = if (noEntity) ByteString.empty else data - Source.single(r.get ++ entityBytes) + + if (!noEntity) r ~~ data + + Source.single { + closeMode match { + case SwitchToWebsocket(handler) ⇒ ResponseRenderingOutput.SwitchToWebsocket(r.get, handler) + case _ ⇒ ResponseRenderingOutput.HttpData(r.get) + } + } case HttpEntity.Default(_, contentLength, data) ⇒ renderHeaders(headers.toList) @@ -205,6 +215,11 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser opCtx.push(result) } } + + sealed trait CloseMode + case object DontClose extends CloseMode + case object CloseConnection extends CloseMode + case class SwitchToWebsocket(handlerFlow: Flow[FrameEvent, FrameEvent, Any]) extends CloseMode } /** @@ -215,3 +230,11 @@ private[http] final case class ResponseRenderingContext( requestMethod: HttpMethod = HttpMethods.GET, requestProtocol: HttpProtocol = HttpProtocols.`HTTP/1.1`, closeRequested: Boolean = false) + +/** INTERNAL API */ +private[http] sealed trait ResponseRenderingOutput +/** INTERNAL API */ +private[http] object ResponseRenderingOutput { + private[http] case class HttpData(bytes: ByteString) extends ResponseRenderingOutput + private[http] case class SwitchToWebsocket(httpResponseBytes: ByteString, handlerFlow: Flow[FrameEvent, FrameEvent, Any]) extends ResponseRenderingOutput +} diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala index f8a74d7407..07a89583e4 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala @@ -13,10 +13,10 @@ import akka.actor.{ ActorRef, Props } import akka.stream._ import akka.stream.scaladsl._ import akka.stream.stage.PushPullStage -import akka.stream.scaladsl.FlexiMerge.{ ReadAny, MergeLogic } +import akka.stream.scaladsl.FlexiMerge.{ Read, ReadAny, MergeLogic } import akka.stream.scaladsl.FlexiRoute.{ DemandFrom, RouteLogic } import akka.http.impl.engine.parsing._ -import akka.http.impl.engine.rendering.{ ResponseRenderingContext, HttpResponseRendererFactory } +import akka.http.impl.engine.rendering.{ ResponseRenderingOutput, ResponseRenderingContext, HttpResponseRendererFactory } import akka.http.impl.engine.TokenSourceActor import akka.http.scaladsl.model._ import akka.http.impl.util._ @@ -42,8 +42,8 @@ private[http] object HttpServerBluePrint { logParsingError(info withSummaryPrepended "Illegal request header", log, parserSettings.errorLoggingVerbosity) }) - val ws = websocketPipeline - val responseRendererFactory = new HttpResponseRendererFactory(serverHeader, responseHeaderSizeHint, log, Some(ws)) + val ws = websocketSetup + val responseRendererFactory = new HttpResponseRendererFactory(serverHeader, responseHeaderSizeHint, log) @volatile var oneHundredContinueRef: Option[ActorRef] = None // FIXME: unnecessary after fixing #16168 val oneHundredContinueSource = StreamUtils.oneTimeSource(Source.actorPublisher[OneHundredContinue.type] { @@ -88,7 +88,7 @@ private[http] object HttpServerBluePrint { .via(Flow[ResponseRenderingContext].transform(() ⇒ new ErrorsTo500ResponseRecovery(log)).named("recover")) // FIXME: simplify after #16394 is closed .via(Flow[ResponseRenderingContext].transform(() ⇒ responseRendererFactory.newRenderer).named("renderer")) .flatten(FlattenStrategy.concat) - .via(Flow[ByteString].transform(() ⇒ errorLogger(log, "Outgoing response stream error")).named("errorLogger")) + .via(Flow[ResponseRenderingOutput].transform(() ⇒ errorLogger(log, "Outgoing response stream error")).named("errorLogger")) FlowGraph.partial(requestParsingFlow, rendererPipeline, oneHundredContinueSource)((_, _, _) ⇒ ()) { implicit b ⇒ (requestParsing, renderer, oneHundreds) ⇒ @@ -107,36 +107,41 @@ private[http] object HttpServerBluePrint { bypassFanout.out(1) ~> bypass ~> bypassInput oneHundreds ~> bypassOneHundredContinueInput - val http = FlowShape(requestParsing.inlet, renderer.outlet) + + val switchTokenBroadcast = b.add(Broadcast[ResponseRenderingOutput](2)) + renderer.outlet ~> switchTokenBroadcast + val switchSource: Outlet[SwitchToWebsocketToken.type] = + (switchTokenBroadcast ~> + Flow[ResponseRenderingOutput] + .collect { + case _: ResponseRenderingOutput.SwitchToWebsocket ⇒ SwitchToWebsocketToken + }).outlet + + val http = FlowShape(requestParsing.inlet, switchTokenBroadcast.outlet) // Websocket pipeline - val websocket = b.add(ws.flow) + val websocket = b.add(ws.websocketFlow) // protocol routing val protocolRouter = b.add(new WebsocketSwitchRouter()) - val protocolMerge = b.add(new WebsocketMerge) + val protocolMerge = b.add(new WebsocketMerge(ws.installHandler)) protocolRouter.out0 ~> http ~> protocolMerge.in0 protocolRouter.out1 ~> websocket ~> protocolMerge.in1 // protocol switching - val wsSwitchTokenMerge = b.add(new StreamUtils.EagerCloseMerge2[AnyRef]("protocolSwitchWsTokenMerge")) - val switchTokenBroadcast = b.add(Broadcast[SwitchToWebsocketToken.type](2)) - ws.switchSource ~> switchTokenBroadcast.in - switchTokenBroadcast.out(0) ~> wsSwitchTokenMerge.in1 - wsSwitchTokenMerge.out /*~> printEvent[AnyRef]("netIn")*/ ~> protocolRouter.in - switchTokenBroadcast.out(1) ~> protocolMerge.in2 - val netIn = wsSwitchTokenMerge.in0 + val wsSwitchTokenMerge = b.add(new CloseIfFirstClosesMerge2[AnyRef]("protocolSwitchWsTokenMerge")) + // feed back switch signal to the protocol router + switchSource ~> wsSwitchTokenMerge.in1 + wsSwitchTokenMerge.out ~> protocolRouter.in - val netOutPrint = b.add( /*printEvent[ByteString]("netOut")*/ Flow[ByteString]) - protocolMerge.out ~> netOutPrint.inlet - val netOut = netOutPrint.outlet + val netIn = wsSwitchTokenMerge.in0 + val netOut = protocolMerge.out BidiShape[HttpResponse, ByteString, ByteString, HttpRequest]( bypassApplicationInput, netOut, netIn, - requestsIn) } } @@ -273,29 +278,11 @@ private[http] object HttpServerBluePrint { } } - case class WebsocketSetup( - flow: Flow[ByteString, ByteString, Any], - publisherKey: StreamUtils.ReadableCell[Publisher[FrameEvent]], - subscriberKey: StreamUtils.ReadableCell[Subscriber[FrameEvent]], - switchSource: Source[SwitchToWebsocketToken.type, Any]) extends WebsocketSwitch { - @volatile var switchToWebsocketRef: Option[ActorRef] = None - - def switchToWebsocket(handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit mat: FlowMaterializer): Unit = { - // 1. fill processing hole in the websocket pipeline with user-provided handler - Source(publisherKey.value) - .via(handlerFlow) - .to(Sink(subscriberKey.value)) - .run() - - // 1. and 2. could be racy in which case incoming data could arrive because of 2. before - // the pipeline in 1. has been established. The `PublisherSink`, should then, however, backpressure - // until the subscriber has connected (i.e. 1. has run). - - // 2. flip the switch - switchToWebsocketRef.get ! TokenSourceActor.Trigger - } + trait WebsocketSetup { + def websocketFlow: Flow[ByteString, ByteString, Any] + def installHandler(handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit mat: FlowMaterializer): Unit } - def websocketPipeline: WebsocketSetup = { + def websocketSetup: WebsocketSetup = { val sinkCell = new StreamUtils.OneTimeWriteCell[Publisher[FrameEvent]] val sourceCell = new StreamUtils.OneTimeWriteCell[Subscriber[FrameEvent]] @@ -308,16 +295,15 @@ private[http] object HttpServerBluePrint { .via(Flow.wrap(sink, source)((_, _) ⇒ ())) .transform(() ⇒ new FrameEventRenderer) - lazy val setup = WebsocketSetup(flow, sinkCell, sourceCell, switchToWebsocketSource) - lazy val switchToWebsocketSource: Source[SwitchToWebsocketToken.type, ActorRef] = - Source.actorPublisher[SwitchToWebsocketToken.type] { - Props { - val actor = new TokenSourceActor(SwitchToWebsocketToken) - setup.switchToWebsocketRef = Some(actor.context.self) - actor - } - } - setup + new WebsocketSetup { + def websocketFlow: Flow[ByteString, ByteString, Any] = flow + + def installHandler(handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit mat: FlowMaterializer): Unit = + Source(sinkCell.value) + .via(handlerFlow) + .to(Sink(sourceCell.value)) + .run() + } } class WebsocketSwitchRouter extends FlexiRoute[AnyRef, FanOutShape2[AnyRef, ByteString, ByteString]](new FanOutShape2("websocketSplit"), OperationAttributes.name("websocketSplit")) { @@ -345,36 +331,44 @@ private[http] object HttpServerBluePrint { } } } - class WebsocketMerge extends FlexiMerge[ByteString, FanInShape3[ByteString, ByteString, SwitchToWebsocketToken.type, ByteString]](new FanInShape3("websocketMerge"), OperationAttributes.name("websocketMerge")) { - def createMergeLogic(s: FanInShape3[ByteString, ByteString, SwitchToWebsocketToken.type, ByteString]): MergeLogic[ByteString] = + class WebsocketMerge(installHandler: Flow[FrameEvent, FrameEvent, Any] ⇒ Unit) extends FlexiMerge[ByteString, FanInShape2[ResponseRenderingOutput, ByteString, ByteString]](new FanInShape2("websocketMerge"), OperationAttributes.name("websocketMerge")) { + def createMergeLogic(s: FanInShape2[ResponseRenderingOutput, ByteString, ByteString]): MergeLogic[ByteString] = new MergeLogic[ByteString] { def httpIn = s.in0 def wsIn = s.in1 - def tokenIn = s.in2 def initialState: State[_] = http - def http: State[_] = State[AnyRef](ReadAny(httpIn.asInstanceOf[Inlet[AnyRef]], tokenIn.asInstanceOf[Inlet[AnyRef]])) { (ctx, in, element) ⇒ + def http: State[_] = State[ResponseRenderingOutput](Read(httpIn)) { (ctx, in, element) ⇒ element match { - case b: ByteString ⇒ - ctx.emit(b); SameState - case SwitchToWebsocketToken ⇒ - ctx.changeCompletionHandling(closeWhenInCloses(wsIn)) - websockets + case ResponseRenderingOutput.HttpData(bytes) ⇒ + ctx.emit(bytes); SameState + case ResponseRenderingOutput.SwitchToWebsocket(responseBytes, handlerFlow) ⇒ + ctx.emit(responseBytes) + installHandler(handlerFlow) + websocket } } - def websockets: State[_] = State[ByteString](ReadAny(httpIn /* otherwise we won't read the websocket upgrade response */ , wsIn)) { (ctx, _, element) ⇒ - ctx.emit(element) + def websocket: State[_] = State[ByteString](Read(wsIn)) { (ctx, in, bytes) ⇒ + ctx.emit(bytes) SameState } + } + } + /** A merge for two streams that just forwards all elements and closes the connection when the first input closes. */ + class CloseIfFirstClosesMerge2[T](name: String) extends FlexiMerge[T, FanInShape2[T, T, T]](new FanInShape2(name), OperationAttributes.name(name)) { + def createMergeLogic(s: FanInShape2[T, T, T]): MergeLogic[T] = + new MergeLogic[T] { + def initialState: State[T] = State[T](ReadAny(s.in0, s.in1)) { + case (ctx, port, in) ⇒ ctx.emit(in); SameState + } - def closeWhenInCloses(in: Inlet[_]): CompletionHandling = - defaultCompletionHandling.copy(onUpstreamFinish = { (ctx, closingIn) ⇒ - if (closingIn == in) ctx.finish() - SameState - }) - - override def initialCompletionHandling: CompletionHandling = closeWhenInCloses(httpIn) + override def initialCompletionHandling: CompletionHandling = + defaultCompletionHandling.copy( + onUpstreamFinish = { (ctx, in) ⇒ + if (in == s.in0) ctx.finish() + SameState + }) } } } diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/WebsocketSwitch.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/WebsocketSwitch.scala deleted file mode 100644 index ade0243316..0000000000 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/WebsocketSwitch.scala +++ /dev/null @@ -1,13 +0,0 @@ -/* - * Copyright (C) 2009-2015 Typesafe Inc. - */ - -package akka.http.impl.engine.ws - -import akka.stream.FlowMaterializer -import akka.stream.scaladsl.Flow - -/** Internal interface between the handshake and the stream setup to evoke the switch to the websocket protocol */ -private[http] trait WebsocketSwitch { - def switchToWebsocket(handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit mat: FlowMaterializer): Unit -} diff --git a/akka-http-core/src/main/scala/akka/http/impl/util/package.scala b/akka-http-core/src/main/scala/akka/http/impl/util/package.scala index 92141e2c5a..6bb5d4ac6a 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/util/package.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/util/package.scala @@ -98,10 +98,10 @@ package object util { } } - private[http] def errorLogger(log: LoggingAdapter, msg: String): PushStage[ByteString, ByteString] = - new PushStage[ByteString, ByteString] { - override def onPush(element: ByteString, ctx: Context[ByteString]): SyncDirective = ctx.push(element) - override def onUpstreamFailure(cause: Throwable, ctx: Context[ByteString]): TerminationDirective = { + private[http] def errorLogger[T](log: LoggingAdapter, msg: String): PushStage[T, T] = + new PushStage[T, T] { + override def onPush(element: T, ctx: Context[T]): SyncDirective = ctx.push(element) + override def onUpstreamFailure(cause: Throwable, ctx: Context[T]): TerminationDirective = { log.error(cause, msg) super.onUpstreamFailure(cause, ctx) } diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/rendering/ResponseRendererSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/rendering/ResponseRendererSpec.scala index af50b7170a..d2f49ada56 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/rendering/ResponseRendererSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/rendering/ResponseRendererSpec.scala @@ -539,7 +539,7 @@ class ResponseRendererSpec extends FreeSpec with Matchers with BeforeAndAfterAll override def afterAll() = system.shutdown() class TestSetup(val serverHeader: Option[Server] = Some(Server("akka-http/1.0.0"))) - extends HttpResponseRendererFactory(serverHeader, responseHeaderSizeHint = 64, NoLogging, None) { + extends HttpResponseRendererFactory(serverHeader, responseHeaderSizeHint = 64, NoLogging) { def renderTo(expected: String): Matcher[HttpResponse] = renderTo(expected, close = false) compose (ResponseRenderingContext(_)) @@ -547,10 +547,15 @@ class ResponseRendererSpec extends FreeSpec with Matchers with BeforeAndAfterAll def renderTo(expected: String, close: Boolean): Matcher[ResponseRenderingContext] = equal(expected.stripMarginWithNewline("\r\n") -> close).matcher[(String, Boolean)] compose { ctx ⇒ val renderer = newRenderer - val byteStringSource = Await.result(Source.single(ctx) + val rendererOutputSource = Await.result(Source.single(ctx) .transform(() ⇒ renderer).named("renderer") .runWith(Sink.head), 1.second) - val future = byteStringSource.grouped(1000).runWith(Sink.head).map(_.reduceLeft(_ ++ _).utf8String) + val future = + rendererOutputSource.grouped(1000).map( + _.map { + case ResponseRenderingOutput.HttpData(bytes) ⇒ bytes + case _: ResponseRenderingOutput.SwitchToWebsocket ⇒ throw new IllegalStateException("Didn't expect websocket response") + }).runWith(Sink.head).map(_.reduceLeft(_ ++ _).utf8String) Await.result(future, 250.millis) -> renderer.isComplete } diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerSpec.scala index bcbc4f8d5d..589d480fa1 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerSpec.scala @@ -11,7 +11,6 @@ import scala.util.Random import scala.annotation.tailrec import scala.concurrent.duration._ import org.scalatest.Inside -import akka.event.NoLogging import akka.util.ByteString import akka.stream.scaladsl._ import akka.stream.{ FlowMaterializer, ActorFlowMaterializer } @@ -662,7 +661,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF") } } class TestSetup extends HttpServerTestSetupBase { - implicit def system: ActorSystem = spec.system - implicit def materializer: FlowMaterializer = spec.materializer + implicit def system = spec.system + implicit def materializer = spec.materializer } } diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerTestSetupBase.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerTestSetupBase.scala index 91e5f19c1e..5e6b3e70b8 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerTestSetupBase.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/server/HttpServerTestSetupBase.scala @@ -61,6 +61,7 @@ abstract class HttpServerTestSetupBase { requests.expectNext() } def expectNoRequest(max: FiniteDuration): Unit = requests.expectNoMsg(max) + def expectNetworkClose(): Unit = netOut.expectComplete() def send(data: ByteString): Unit = netInSub.sendNext(data) def send(data: String): Unit = send(ByteString(data, "UTF8")) diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala index 5826516493..6f529fd7a3 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/MessageSpec.scala @@ -16,6 +16,8 @@ import akka.http.scaladsl.model.ws._ import Protocol.Opcode class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { + import WSTestUtils._ + "The Websocket implementation should" - { "collect messages from frames" - { "for binary messages" - { @@ -902,68 +904,6 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { netOut.expectNoMsg(100.millis) } - def frameHeader( - opcode: Opcode, - length: Long, - fin: Boolean, - mask: Option[Int] = None, - rsv1: Boolean = false, - rsv2: Boolean = false, - rsv3: Boolean = false): ByteString = { - def set(should: Boolean, mask: Int): Int = - if (should) mask else 0 - - val flags = - set(fin, Protocol.FIN_MASK) | - set(rsv1, Protocol.RSV1_MASK) | - set(rsv2, Protocol.RSV2_MASK) | - set(rsv3, Protocol.RSV3_MASK) - - val opcodeByte = opcode.code | flags - - require(length >= 0) - val (lengthByteComponent, lengthBytes) = - if (length < 126) (length.toByte, ByteString.empty) - else if (length < 65536) (126.toByte, shortBE(length.toInt)) - else throw new IllegalArgumentException("Only lengths < 65536 allowed in test") - - val maskMask = if (mask.isDefined) Protocol.MASK_MASK else 0 - val maskBytes = mask match { - case Some(mask) ⇒ intBE(mask) - case None ⇒ ByteString.empty - } - val lengthByte = lengthByteComponent | maskMask - ByteString(opcodeByte.toByte, lengthByte.toByte) ++ lengthBytes ++ maskBytes - } - def closeFrame(closeCode: Int, mask: Boolean): ByteString = - if (mask) { - val mask = Random.nextInt() - frameHeader(Opcode.Close, 2, fin = true, mask = Some(mask)) ++ - maskedBytes(shortBE(closeCode), mask)._1 - } else - frameHeader(Opcode.Close, 2, fin = true) ++ - shortBE(closeCode) - - def maskedASCII(str: String, mask: Int): (ByteString, Int) = - FrameEventParser.mask(ByteString(str, "ASCII"), mask) - def maskedUTF8(str: String, mask: Int): (ByteString, Int) = - FrameEventParser.mask(ByteString(str, "UTF-8"), mask) - def maskedBytes(bytes: ByteString, mask: Int): (ByteString, Int) = - FrameEventParser.mask(bytes, mask) - - def shortBE(value: Int): ByteString = { - require(value >= 0 && value < 65536, s"Value wasn't in short range: $value") - ByteString( - ((value >> 8) & 0xff).toByte, - ((value >> 0) & 0xff).toByte) - } - def intBE(value: Int): ByteString = - ByteString( - ((value >> 24) & 0xff).toByte, - ((value >> 16) & 0xff).toByte, - ((value >> 8) & 0xff).toByte, - ((value >> 0) & 0xff).toByte) - val trace = false // set to `true` for debugging purposes def printEvent[T](marker: String): Flow[T, T, Unit] = if (trace) akka.http.impl.util.printEvent(marker) diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WSTestUtils.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WSTestUtils.scala new file mode 100644 index 0000000000..47f9556f2a --- /dev/null +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WSTestUtils.scala @@ -0,0 +1,74 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.impl.engine.ws + +import akka.http.impl.engine.ws.Protocol.Opcode +import akka.util.ByteString + +import scala.util.Random + +object WSTestUtils { + def frameHeader( + opcode: Opcode, + length: Long, + fin: Boolean, + mask: Option[Int] = None, + rsv1: Boolean = false, + rsv2: Boolean = false, + rsv3: Boolean = false): ByteString = { + def set(should: Boolean, mask: Int): Int = + if (should) mask else 0 + + val flags = + set(fin, Protocol.FIN_MASK) | + set(rsv1, Protocol.RSV1_MASK) | + set(rsv2, Protocol.RSV2_MASK) | + set(rsv3, Protocol.RSV3_MASK) + + val opcodeByte = opcode.code | flags + + require(length >= 0) + val (lengthByteComponent, lengthBytes) = + if (length < 126) (length.toByte, ByteString.empty) + else if (length < 65536) (126.toByte, shortBE(length.toInt)) + else throw new IllegalArgumentException("Only lengths < 65536 allowed in test") + + val maskMask = if (mask.isDefined) Protocol.MASK_MASK else 0 + val maskBytes = mask match { + case Some(mask) ⇒ intBE(mask) + case None ⇒ ByteString.empty + } + val lengthByte = lengthByteComponent | maskMask + ByteString(opcodeByte.toByte, lengthByte.toByte) ++ lengthBytes ++ maskBytes + } + def closeFrame(closeCode: Int, mask: Boolean): ByteString = + if (mask) { + val mask = Random.nextInt() + frameHeader(Opcode.Close, 2, fin = true, mask = Some(mask)) ++ + maskedBytes(shortBE(closeCode), mask)._1 + } else + frameHeader(Opcode.Close, 2, fin = true) ++ + shortBE(closeCode) + + def maskedASCII(str: String, mask: Int): (ByteString, Int) = + FrameEventParser.mask(ByteString(str, "ASCII"), mask) + def maskedUTF8(str: String, mask: Int): (ByteString, Int) = + FrameEventParser.mask(ByteString(str, "UTF-8"), mask) + def maskedBytes(bytes: ByteString, mask: Int): (ByteString, Int) = + FrameEventParser.mask(bytes, mask) + + def shortBE(value: Int): ByteString = { + require(value >= 0 && value < 65536, s"Value wasn't in short range: $value") + ByteString( + ((value >> 8) & 0xff).toByte, + ((value >> 0) & 0xff).toByte) + } + def intBE(value: Int): ByteString = + ByteString( + ((value >> 24) & 0xff).toByte, + ((value >> 16) & 0xff).toByte, + ((value >> 8) & 0xff).toByte, + ((value >> 0) & 0xff).toByte) +} diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketServerSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketServerSpec.scala new file mode 100644 index 0000000000..b0ddf8125e --- /dev/null +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketServerSpec.scala @@ -0,0 +1,160 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.impl.engine.ws + +import akka.http.impl.engine.ws.Protocol.Opcode +import akka.http.scaladsl.model.ws._ +import akka.stream.scaladsl.{ Sink, Flow, Source } +import akka.util.ByteString +import org.scalatest.{ Matchers, FreeSpec } + +import akka.http.impl.util._ + +import akka.http.impl.engine.server.HttpServerTestSetupBase + +import scala.util.Random + +class WebsocketServerSpec extends FreeSpec with Matchers with WithMaterializerSpec { spec ⇒ + import WSTestUtils._ + + "The server-side Websocket integration should" - { + "establish a websocket connection when the user requests it" - { + "when user handler instantly tries to send messages" in new TestSetup { + send( + """GET /chat HTTP/1.1 + |Host: server.example.com + |Upgrade: websocket + |Connection: Upgrade + |Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ== + |Origin: http://example.com + |Sec-WebSocket-Version: 13 + | + |""". + stripMarginWithNewline("\r\n")) + val request = expectRequest + val upgrade = request.header[UpgradeToWebsocket] + upgrade.isDefined shouldBe true + + val source = + Source(List(1, 2, 3, 4, 5)).map(num ⇒ TextMessage.Strict(s"Message $num")) + val handler = Flow.wrap(Sink.ignore, source)((_, _) ⇒ ()) + val response = upgrade.get.handleMessages(handler) + responsesSub.sendNext(response) + + wipeDate(expectNextChunk().utf8String) shouldEqual + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Server: akka-http/test + |Date: XXXX + |Connection: upgrade + | + |""". + stripMarginWithNewline("\r\n") + + expectWSFrame(Protocol.Opcode.Text, ByteString("Message 1"), fin = true) + expectWSFrame(Protocol.Opcode.Text, ByteString("Message 2"), fin = true) + expectWSFrame(Protocol.Opcode.Text, ByteString("Message 3"), fin = true) + expectWSFrame(Protocol.Opcode.Text, ByteString("Message 4"), fin = true) + expectWSFrame(Protocol.Opcode.Text, ByteString("Message 5"), fin = true) + expectWSCloseFrame(Protocol.CloseCodes.Regular) + } + "for echoing user handler" in new TestSetup { + send( + """GET /echo HTTP/1.1 + |Host: server.example.com + |Upgrade: websocket + |Connection: Upgrade + |Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ== + |Origin: http://example.com + |Sec-WebSocket-Version: 13 + | + |""". + stripMarginWithNewline("\r\n")) + val request = expectRequest + val upgrade = request.header[UpgradeToWebsocket] + upgrade.isDefined shouldBe true + + val response = upgrade.get.handleMessages(Flow[Message]) // simple echoing + responsesSub.sendNext(response) + + wipeDate(expectNextChunk().utf8String) shouldEqual + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Server: akka-http/test + |Date: XXXX + |Connection: upgrade + | + |""". + stripMarginWithNewline("\r\n") + + sendWSFrame(Protocol.Opcode.Text, ByteString("Message 1"), fin = true, mask = true) + expectWSFrame(Protocol.Opcode.Text, ByteString("Message 1"), fin = true) + sendWSFrame(Protocol.Opcode.Text, ByteString("Message 2"), fin = true, mask = true) + expectWSFrame(Protocol.Opcode.Text, ByteString("Message 2"), fin = true) + sendWSFrame(Protocol.Opcode.Text, ByteString("Message 3"), fin = true, mask = true) + expectWSFrame(Protocol.Opcode.Text, ByteString("Message 3"), fin = true) + sendWSFrame(Protocol.Opcode.Text, ByteString("Message 4"), fin = true, mask = true) + expectWSFrame(Protocol.Opcode.Text, ByteString("Message 4"), fin = true) + sendWSFrame(Protocol.Opcode.Text, ByteString("Message 5"), fin = true, mask = true) + expectWSFrame(Protocol.Opcode.Text, ByteString("Message 5"), fin = true) + + sendWSCloseFrame(Protocol.CloseCodes.Regular, mask = true) + expectWSCloseFrame(Protocol.CloseCodes.Regular) + + closeNetworkInput() + expectNetworkClose() + } + } + "prevent the selection of an unavailable subprotocol" in pending + "reject invalid Websocket handshakes" - { + "missing `Connection: upgrade` header" in pending + "missing `Sec-WebSocket-Key header" in pending + "`Sec-WebSocket-Key` with wrong amount of base64 encoded data" in pending + "missing `Sec-WebSocket-Version` header" in pending + "unsupported `Sec-WebSocket-Version`" in pending + } + } + + class TestSetup extends HttpServerTestSetupBase { + implicit def system = spec.system + implicit def materializer = spec.materializer + + def sendWSFrame(opcode: Opcode, + data: ByteString, + fin: Boolean, + mask: Boolean = false, + rsv1: Boolean = false, + rsv2: Boolean = false, + rsv3: Boolean = false): Unit = { + val (theMask, theData) = + if (mask) { + val m = Random.nextInt() + (Some(m), maskedBytes(data, m)._1) + } else (None, data) + send(frameHeader(opcode, data.length, fin, theMask, rsv1, rsv2, rsv3) ++ theData) + } + + def sendWSCloseFrame(closeCode: Int, mask: Boolean = false): Unit = + send(closeFrame(closeCode, mask)) + + def expectNextChunk(): ByteString = { + netOutSub.request(1) + netOut.expectNext() + } + + def expectWSFrame(opcode: Opcode, + data: ByteString, + fin: Boolean, + mask: Option[Int] = None, + rsv1: Boolean = false, + rsv2: Boolean = false, + rsv3: Boolean = false): Unit = + expectNextChunk() shouldEqual frameHeader(opcode, data.length, fin, mask, rsv1, rsv2, rsv3) ++ data + def expectWSCloseFrame(closeCode: Int, mask: Boolean = false): Unit = + expectNextChunk() shouldEqual closeFrame(closeCode, mask) + } +}