=htc #17322 fix race condition between websocket http response and first websocket message
The fix was to make the switch atomic instead of relying on side-channels which introduced an ordering dependency in the protocol merge. Also, the switch token itself is now fed back from the response renderer output to the protocol switch leading to a simplified design.
This commit is contained in:
parent
444a0deef5
commit
c283dd45ae
10 changed files with 354 additions and 171 deletions
|
|
@ -4,13 +4,12 @@
|
|||
|
||||
package akka.http.impl.engine.rendering
|
||||
|
||||
import akka.http.impl.engine.ws.{ WebsocketSwitch, UpgradeToWebsocketResponseHeader, Handshake }
|
||||
import akka.http.impl.engine.ws.{ FrameEvent, UpgradeToWebsocketResponseHeader }
|
||||
|
||||
import scala.annotation.tailrec
|
||||
import akka.event.LoggingAdapter
|
||||
import akka.util.ByteString
|
||||
import akka.stream.OperationAttributes._
|
||||
import akka.stream.scaladsl.Source
|
||||
import akka.stream.scaladsl.{ Flow, Source }
|
||||
import akka.stream.stage._
|
||||
import akka.http.scaladsl.model._
|
||||
import akka.http.impl.util._
|
||||
|
|
@ -23,8 +22,7 @@ import headers._
|
|||
*/
|
||||
private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Server],
|
||||
responseHeaderSizeHint: Int,
|
||||
log: LoggingAdapter,
|
||||
websocketSwitch: Option[WebsocketSwitch] = None) {
|
||||
log: LoggingAdapter) {
|
||||
|
||||
private val renderDefaultServerHeader: Rendering ⇒ Unit =
|
||||
serverHeader match {
|
||||
|
|
@ -54,14 +52,17 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser
|
|||
|
||||
def newRenderer: HttpResponseRenderer = new HttpResponseRenderer
|
||||
|
||||
final class HttpResponseRenderer extends PushStage[ResponseRenderingContext, Source[ByteString, Any]] {
|
||||
final class HttpResponseRenderer extends PushStage[ResponseRenderingContext, Source[ResponseRenderingOutput, Any]] {
|
||||
|
||||
private[this] var close = false // signals whether the connection is to be closed after the current response
|
||||
private[this] var closeMode: CloseMode = DontClose // signals what to do after the current response
|
||||
private[this] def close: Boolean = closeMode != DontClose
|
||||
private[this] def closeIf(cond: Boolean): Unit =
|
||||
if (cond) closeMode = CloseConnection
|
||||
|
||||
// need this for testing
|
||||
private[http] def isComplete = close
|
||||
|
||||
override def onPush(ctx: ResponseRenderingContext, opCtx: Context[Source[ByteString, Any]]): SyncDirective = {
|
||||
override def onPush(ctx: ResponseRenderingContext, opCtx: Context[Source[ResponseRenderingOutput, Any]]): SyncDirective = {
|
||||
val r = new ByteStringRendering(responseHeaderSizeHint)
|
||||
|
||||
import ctx.response._
|
||||
|
|
@ -133,7 +134,7 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser
|
|||
if (!dateSeen) r ~~ dateHeader
|
||||
|
||||
// Do we close the connection after this response?
|
||||
close =
|
||||
closeIf {
|
||||
// if we are prohibited to keep-alive by the spec
|
||||
alwaysClose ||
|
||||
// if the client wants to close and we don't override
|
||||
|
|
@ -143,6 +144,7 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser
|
|||
case `HTTP/1.1` ⇒ (connHeader ne null) && connHeader.hasClose
|
||||
case `HTTP/1.0` ⇒ if (connHeader eq null) ctx.requestProtocol == `HTTP/1.1` else !connHeader.hasKeepAlive
|
||||
})
|
||||
}
|
||||
|
||||
// Do we render an explicit Connection header?
|
||||
val renderConnectionHeader =
|
||||
|
|
@ -152,10 +154,11 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser
|
|||
|
||||
if (renderConnectionHeader)
|
||||
r ~~ Connection ~~ (if (close) CloseBytes else KeepAliveBytes) ~~ CrLf
|
||||
else if (connHeader != null && connHeader.hasUpgrade && websocketSwitch.isDefined) {
|
||||
else if (connHeader != null && connHeader.hasUpgrade) {
|
||||
r ~~ connHeader ~~ CrLf
|
||||
val websocketHeader = headers.collectFirst { case u: UpgradeToWebsocketResponseHeader ⇒ u }
|
||||
websocketHeader.foreach(header ⇒ websocketSwitch.get.switchToWebsocket(header.handlerFlow)(header.mat))
|
||||
headers
|
||||
.collectFirst { case u: UpgradeToWebsocketResponseHeader ⇒ u }
|
||||
.foreach { header ⇒ closeMode = SwitchToWebsocket(header.handlerFlow) }
|
||||
}
|
||||
if (mustRenderTransferEncodingChunkedHeader && !transferEncodingSeen)
|
||||
r ~~ `Transfer-Encoding` ~~ ChunkedBytes ~~ CrLf
|
||||
|
|
@ -164,17 +167,24 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser
|
|||
def renderContentLengthHeader(contentLength: Long) =
|
||||
if (status.allowsEntity) r ~~ `Content-Length` ~~ contentLength ~~ CrLf else r
|
||||
|
||||
def byteStrings(entityBytes: ⇒ Source[ByteString, Any]): Source[ByteString, Any] =
|
||||
renderByteStrings(r, entityBytes, skipEntity = noEntity)
|
||||
def byteStrings(entityBytes: ⇒ Source[ByteString, Any]): Source[ResponseRenderingOutput, Any] =
|
||||
renderByteStrings(r, entityBytes, skipEntity = noEntity).map(ResponseRenderingOutput.HttpData(_))
|
||||
|
||||
def completeResponseRendering(entity: ResponseEntity): Source[ByteString, Any] =
|
||||
def completeResponseRendering(entity: ResponseEntity): Source[ResponseRenderingOutput, Any] =
|
||||
entity match {
|
||||
case HttpEntity.Strict(_, data) ⇒
|
||||
renderHeaders(headers.toList)
|
||||
renderEntityContentType(r, entity)
|
||||
renderContentLengthHeader(data.length) ~~ CrLf
|
||||
val entityBytes = if (noEntity) ByteString.empty else data
|
||||
Source.single(r.get ++ entityBytes)
|
||||
|
||||
if (!noEntity) r ~~ data
|
||||
|
||||
Source.single {
|
||||
closeMode match {
|
||||
case SwitchToWebsocket(handler) ⇒ ResponseRenderingOutput.SwitchToWebsocket(r.get, handler)
|
||||
case _ ⇒ ResponseRenderingOutput.HttpData(r.get)
|
||||
}
|
||||
}
|
||||
|
||||
case HttpEntity.Default(_, contentLength, data) ⇒
|
||||
renderHeaders(headers.toList)
|
||||
|
|
@ -205,6 +215,11 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser
|
|||
opCtx.push(result)
|
||||
}
|
||||
}
|
||||
|
||||
sealed trait CloseMode
|
||||
case object DontClose extends CloseMode
|
||||
case object CloseConnection extends CloseMode
|
||||
case class SwitchToWebsocket(handlerFlow: Flow[FrameEvent, FrameEvent, Any]) extends CloseMode
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -215,3 +230,11 @@ private[http] final case class ResponseRenderingContext(
|
|||
requestMethod: HttpMethod = HttpMethods.GET,
|
||||
requestProtocol: HttpProtocol = HttpProtocols.`HTTP/1.1`,
|
||||
closeRequested: Boolean = false)
|
||||
|
||||
/** INTERNAL API */
|
||||
private[http] sealed trait ResponseRenderingOutput
|
||||
/** INTERNAL API */
|
||||
private[http] object ResponseRenderingOutput {
|
||||
private[http] case class HttpData(bytes: ByteString) extends ResponseRenderingOutput
|
||||
private[http] case class SwitchToWebsocket(httpResponseBytes: ByteString, handlerFlow: Flow[FrameEvent, FrameEvent, Any]) extends ResponseRenderingOutput
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,10 +13,10 @@ import akka.actor.{ ActorRef, Props }
|
|||
import akka.stream._
|
||||
import akka.stream.scaladsl._
|
||||
import akka.stream.stage.PushPullStage
|
||||
import akka.stream.scaladsl.FlexiMerge.{ ReadAny, MergeLogic }
|
||||
import akka.stream.scaladsl.FlexiMerge.{ Read, ReadAny, MergeLogic }
|
||||
import akka.stream.scaladsl.FlexiRoute.{ DemandFrom, RouteLogic }
|
||||
import akka.http.impl.engine.parsing._
|
||||
import akka.http.impl.engine.rendering.{ ResponseRenderingContext, HttpResponseRendererFactory }
|
||||
import akka.http.impl.engine.rendering.{ ResponseRenderingOutput, ResponseRenderingContext, HttpResponseRendererFactory }
|
||||
import akka.http.impl.engine.TokenSourceActor
|
||||
import akka.http.scaladsl.model._
|
||||
import akka.http.impl.util._
|
||||
|
|
@ -42,8 +42,8 @@ private[http] object HttpServerBluePrint {
|
|||
logParsingError(info withSummaryPrepended "Illegal request header", log, parserSettings.errorLoggingVerbosity)
|
||||
})
|
||||
|
||||
val ws = websocketPipeline
|
||||
val responseRendererFactory = new HttpResponseRendererFactory(serverHeader, responseHeaderSizeHint, log, Some(ws))
|
||||
val ws = websocketSetup
|
||||
val responseRendererFactory = new HttpResponseRendererFactory(serverHeader, responseHeaderSizeHint, log)
|
||||
|
||||
@volatile var oneHundredContinueRef: Option[ActorRef] = None // FIXME: unnecessary after fixing #16168
|
||||
val oneHundredContinueSource = StreamUtils.oneTimeSource(Source.actorPublisher[OneHundredContinue.type] {
|
||||
|
|
@ -88,7 +88,7 @@ private[http] object HttpServerBluePrint {
|
|||
.via(Flow[ResponseRenderingContext].transform(() ⇒ new ErrorsTo500ResponseRecovery(log)).named("recover")) // FIXME: simplify after #16394 is closed
|
||||
.via(Flow[ResponseRenderingContext].transform(() ⇒ responseRendererFactory.newRenderer).named("renderer"))
|
||||
.flatten(FlattenStrategy.concat)
|
||||
.via(Flow[ByteString].transform(() ⇒ errorLogger(log, "Outgoing response stream error")).named("errorLogger"))
|
||||
.via(Flow[ResponseRenderingOutput].transform(() ⇒ errorLogger(log, "Outgoing response stream error")).named("errorLogger"))
|
||||
|
||||
FlowGraph.partial(requestParsingFlow, rendererPipeline, oneHundredContinueSource)((_, _, _) ⇒ ()) { implicit b ⇒
|
||||
(requestParsing, renderer, oneHundreds) ⇒
|
||||
|
|
@ -107,36 +107,41 @@ private[http] object HttpServerBluePrint {
|
|||
|
||||
bypassFanout.out(1) ~> bypass ~> bypassInput
|
||||
oneHundreds ~> bypassOneHundredContinueInput
|
||||
val http = FlowShape(requestParsing.inlet, renderer.outlet)
|
||||
|
||||
val switchTokenBroadcast = b.add(Broadcast[ResponseRenderingOutput](2))
|
||||
renderer.outlet ~> switchTokenBroadcast
|
||||
val switchSource: Outlet[SwitchToWebsocketToken.type] =
|
||||
(switchTokenBroadcast ~>
|
||||
Flow[ResponseRenderingOutput]
|
||||
.collect {
|
||||
case _: ResponseRenderingOutput.SwitchToWebsocket ⇒ SwitchToWebsocketToken
|
||||
}).outlet
|
||||
|
||||
val http = FlowShape(requestParsing.inlet, switchTokenBroadcast.outlet)
|
||||
|
||||
// Websocket pipeline
|
||||
val websocket = b.add(ws.flow)
|
||||
val websocket = b.add(ws.websocketFlow)
|
||||
|
||||
// protocol routing
|
||||
val protocolRouter = b.add(new WebsocketSwitchRouter())
|
||||
val protocolMerge = b.add(new WebsocketMerge)
|
||||
val protocolMerge = b.add(new WebsocketMerge(ws.installHandler))
|
||||
|
||||
protocolRouter.out0 ~> http ~> protocolMerge.in0
|
||||
protocolRouter.out1 ~> websocket ~> protocolMerge.in1
|
||||
|
||||
// protocol switching
|
||||
val wsSwitchTokenMerge = b.add(new StreamUtils.EagerCloseMerge2[AnyRef]("protocolSwitchWsTokenMerge"))
|
||||
val switchTokenBroadcast = b.add(Broadcast[SwitchToWebsocketToken.type](2))
|
||||
ws.switchSource ~> switchTokenBroadcast.in
|
||||
switchTokenBroadcast.out(0) ~> wsSwitchTokenMerge.in1
|
||||
wsSwitchTokenMerge.out /*~> printEvent[AnyRef]("netIn")*/ ~> protocolRouter.in
|
||||
switchTokenBroadcast.out(1) ~> protocolMerge.in2
|
||||
val netIn = wsSwitchTokenMerge.in0
|
||||
val wsSwitchTokenMerge = b.add(new CloseIfFirstClosesMerge2[AnyRef]("protocolSwitchWsTokenMerge"))
|
||||
// feed back switch signal to the protocol router
|
||||
switchSource ~> wsSwitchTokenMerge.in1
|
||||
wsSwitchTokenMerge.out ~> protocolRouter.in
|
||||
|
||||
val netOutPrint = b.add( /*printEvent[ByteString]("netOut")*/ Flow[ByteString])
|
||||
protocolMerge.out ~> netOutPrint.inlet
|
||||
val netOut = netOutPrint.outlet
|
||||
val netIn = wsSwitchTokenMerge.in0
|
||||
val netOut = protocolMerge.out
|
||||
|
||||
BidiShape[HttpResponse, ByteString, ByteString, HttpRequest](
|
||||
bypassApplicationInput,
|
||||
netOut,
|
||||
netIn,
|
||||
|
||||
requestsIn)
|
||||
}
|
||||
}
|
||||
|
|
@ -273,29 +278,11 @@ private[http] object HttpServerBluePrint {
|
|||
}
|
||||
}
|
||||
|
||||
case class WebsocketSetup(
|
||||
flow: Flow[ByteString, ByteString, Any],
|
||||
publisherKey: StreamUtils.ReadableCell[Publisher[FrameEvent]],
|
||||
subscriberKey: StreamUtils.ReadableCell[Subscriber[FrameEvent]],
|
||||
switchSource: Source[SwitchToWebsocketToken.type, Any]) extends WebsocketSwitch {
|
||||
@volatile var switchToWebsocketRef: Option[ActorRef] = None
|
||||
|
||||
def switchToWebsocket(handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit mat: FlowMaterializer): Unit = {
|
||||
// 1. fill processing hole in the websocket pipeline with user-provided handler
|
||||
Source(publisherKey.value)
|
||||
.via(handlerFlow)
|
||||
.to(Sink(subscriberKey.value))
|
||||
.run()
|
||||
|
||||
// 1. and 2. could be racy in which case incoming data could arrive because of 2. before
|
||||
// the pipeline in 1. has been established. The `PublisherSink`, should then, however, backpressure
|
||||
// until the subscriber has connected (i.e. 1. has run).
|
||||
|
||||
// 2. flip the switch
|
||||
switchToWebsocketRef.get ! TokenSourceActor.Trigger
|
||||
}
|
||||
trait WebsocketSetup {
|
||||
def websocketFlow: Flow[ByteString, ByteString, Any]
|
||||
def installHandler(handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit mat: FlowMaterializer): Unit
|
||||
}
|
||||
def websocketPipeline: WebsocketSetup = {
|
||||
def websocketSetup: WebsocketSetup = {
|
||||
val sinkCell = new StreamUtils.OneTimeWriteCell[Publisher[FrameEvent]]
|
||||
val sourceCell = new StreamUtils.OneTimeWriteCell[Subscriber[FrameEvent]]
|
||||
|
||||
|
|
@ -308,16 +295,15 @@ private[http] object HttpServerBluePrint {
|
|||
.via(Flow.wrap(sink, source)((_, _) ⇒ ()))
|
||||
.transform(() ⇒ new FrameEventRenderer)
|
||||
|
||||
lazy val setup = WebsocketSetup(flow, sinkCell, sourceCell, switchToWebsocketSource)
|
||||
lazy val switchToWebsocketSource: Source[SwitchToWebsocketToken.type, ActorRef] =
|
||||
Source.actorPublisher[SwitchToWebsocketToken.type] {
|
||||
Props {
|
||||
val actor = new TokenSourceActor(SwitchToWebsocketToken)
|
||||
setup.switchToWebsocketRef = Some(actor.context.self)
|
||||
actor
|
||||
}
|
||||
}
|
||||
setup
|
||||
new WebsocketSetup {
|
||||
def websocketFlow: Flow[ByteString, ByteString, Any] = flow
|
||||
|
||||
def installHandler(handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit mat: FlowMaterializer): Unit =
|
||||
Source(sinkCell.value)
|
||||
.via(handlerFlow)
|
||||
.to(Sink(sourceCell.value))
|
||||
.run()
|
||||
}
|
||||
}
|
||||
class WebsocketSwitchRouter
|
||||
extends FlexiRoute[AnyRef, FanOutShape2[AnyRef, ByteString, ByteString]](new FanOutShape2("websocketSplit"), OperationAttributes.name("websocketSplit")) {
|
||||
|
|
@ -345,36 +331,44 @@ private[http] object HttpServerBluePrint {
|
|||
}
|
||||
}
|
||||
}
|
||||
class WebsocketMerge extends FlexiMerge[ByteString, FanInShape3[ByteString, ByteString, SwitchToWebsocketToken.type, ByteString]](new FanInShape3("websocketMerge"), OperationAttributes.name("websocketMerge")) {
|
||||
def createMergeLogic(s: FanInShape3[ByteString, ByteString, SwitchToWebsocketToken.type, ByteString]): MergeLogic[ByteString] =
|
||||
class WebsocketMerge(installHandler: Flow[FrameEvent, FrameEvent, Any] ⇒ Unit) extends FlexiMerge[ByteString, FanInShape2[ResponseRenderingOutput, ByteString, ByteString]](new FanInShape2("websocketMerge"), OperationAttributes.name("websocketMerge")) {
|
||||
def createMergeLogic(s: FanInShape2[ResponseRenderingOutput, ByteString, ByteString]): MergeLogic[ByteString] =
|
||||
new MergeLogic[ByteString] {
|
||||
def httpIn = s.in0
|
||||
def wsIn = s.in1
|
||||
def tokenIn = s.in2
|
||||
|
||||
def initialState: State[_] = http
|
||||
|
||||
def http: State[_] = State[AnyRef](ReadAny(httpIn.asInstanceOf[Inlet[AnyRef]], tokenIn.asInstanceOf[Inlet[AnyRef]])) { (ctx, in, element) ⇒
|
||||
def http: State[_] = State[ResponseRenderingOutput](Read(httpIn)) { (ctx, in, element) ⇒
|
||||
element match {
|
||||
case b: ByteString ⇒
|
||||
ctx.emit(b); SameState
|
||||
case SwitchToWebsocketToken ⇒
|
||||
ctx.changeCompletionHandling(closeWhenInCloses(wsIn))
|
||||
websockets
|
||||
case ResponseRenderingOutput.HttpData(bytes) ⇒
|
||||
ctx.emit(bytes); SameState
|
||||
case ResponseRenderingOutput.SwitchToWebsocket(responseBytes, handlerFlow) ⇒
|
||||
ctx.emit(responseBytes)
|
||||
installHandler(handlerFlow)
|
||||
websocket
|
||||
}
|
||||
}
|
||||
def websockets: State[_] = State[ByteString](ReadAny(httpIn /* otherwise we won't read the websocket upgrade response */ , wsIn)) { (ctx, _, element) ⇒
|
||||
ctx.emit(element)
|
||||
def websocket: State[_] = State[ByteString](Read(wsIn)) { (ctx, in, bytes) ⇒
|
||||
ctx.emit(bytes)
|
||||
SameState
|
||||
}
|
||||
}
|
||||
}
|
||||
/** A merge for two streams that just forwards all elements and closes the connection when the first input closes. */
|
||||
class CloseIfFirstClosesMerge2[T](name: String) extends FlexiMerge[T, FanInShape2[T, T, T]](new FanInShape2(name), OperationAttributes.name(name)) {
|
||||
def createMergeLogic(s: FanInShape2[T, T, T]): MergeLogic[T] =
|
||||
new MergeLogic[T] {
|
||||
def initialState: State[T] = State[T](ReadAny(s.in0, s.in1)) {
|
||||
case (ctx, port, in) ⇒ ctx.emit(in); SameState
|
||||
}
|
||||
|
||||
def closeWhenInCloses(in: Inlet[_]): CompletionHandling =
|
||||
defaultCompletionHandling.copy(onUpstreamFinish = { (ctx, closingIn) ⇒
|
||||
if (closingIn == in) ctx.finish()
|
||||
SameState
|
||||
})
|
||||
|
||||
override def initialCompletionHandling: CompletionHandling = closeWhenInCloses(httpIn)
|
||||
override def initialCompletionHandling: CompletionHandling =
|
||||
defaultCompletionHandling.copy(
|
||||
onUpstreamFinish = { (ctx, in) ⇒
|
||||
if (in == s.in0) ctx.finish()
|
||||
SameState
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,13 +0,0 @@
|
|||
/*
|
||||
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
|
||||
*/
|
||||
|
||||
package akka.http.impl.engine.ws
|
||||
|
||||
import akka.stream.FlowMaterializer
|
||||
import akka.stream.scaladsl.Flow
|
||||
|
||||
/** Internal interface between the handshake and the stream setup to evoke the switch to the websocket protocol */
|
||||
private[http] trait WebsocketSwitch {
|
||||
def switchToWebsocket(handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit mat: FlowMaterializer): Unit
|
||||
}
|
||||
|
|
@ -98,10 +98,10 @@ package object util {
|
|||
}
|
||||
}
|
||||
|
||||
private[http] def errorLogger(log: LoggingAdapter, msg: String): PushStage[ByteString, ByteString] =
|
||||
new PushStage[ByteString, ByteString] {
|
||||
override def onPush(element: ByteString, ctx: Context[ByteString]): SyncDirective = ctx.push(element)
|
||||
override def onUpstreamFailure(cause: Throwable, ctx: Context[ByteString]): TerminationDirective = {
|
||||
private[http] def errorLogger[T](log: LoggingAdapter, msg: String): PushStage[T, T] =
|
||||
new PushStage[T, T] {
|
||||
override def onPush(element: T, ctx: Context[T]): SyncDirective = ctx.push(element)
|
||||
override def onUpstreamFailure(cause: Throwable, ctx: Context[T]): TerminationDirective = {
|
||||
log.error(cause, msg)
|
||||
super.onUpstreamFailure(cause, ctx)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -539,7 +539,7 @@ class ResponseRendererSpec extends FreeSpec with Matchers with BeforeAndAfterAll
|
|||
override def afterAll() = system.shutdown()
|
||||
|
||||
class TestSetup(val serverHeader: Option[Server] = Some(Server("akka-http/1.0.0")))
|
||||
extends HttpResponseRendererFactory(serverHeader, responseHeaderSizeHint = 64, NoLogging, None) {
|
||||
extends HttpResponseRendererFactory(serverHeader, responseHeaderSizeHint = 64, NoLogging) {
|
||||
|
||||
def renderTo(expected: String): Matcher[HttpResponse] =
|
||||
renderTo(expected, close = false) compose (ResponseRenderingContext(_))
|
||||
|
|
@ -547,10 +547,15 @@ class ResponseRendererSpec extends FreeSpec with Matchers with BeforeAndAfterAll
|
|||
def renderTo(expected: String, close: Boolean): Matcher[ResponseRenderingContext] =
|
||||
equal(expected.stripMarginWithNewline("\r\n") -> close).matcher[(String, Boolean)] compose { ctx ⇒
|
||||
val renderer = newRenderer
|
||||
val byteStringSource = Await.result(Source.single(ctx)
|
||||
val rendererOutputSource = Await.result(Source.single(ctx)
|
||||
.transform(() ⇒ renderer).named("renderer")
|
||||
.runWith(Sink.head), 1.second)
|
||||
val future = byteStringSource.grouped(1000).runWith(Sink.head).map(_.reduceLeft(_ ++ _).utf8String)
|
||||
val future =
|
||||
rendererOutputSource.grouped(1000).map(
|
||||
_.map {
|
||||
case ResponseRenderingOutput.HttpData(bytes) ⇒ bytes
|
||||
case _: ResponseRenderingOutput.SwitchToWebsocket ⇒ throw new IllegalStateException("Didn't expect websocket response")
|
||||
}).runWith(Sink.head).map(_.reduceLeft(_ ++ _).utf8String)
|
||||
Await.result(future, 250.millis) -> renderer.isComplete
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ import scala.util.Random
|
|||
import scala.annotation.tailrec
|
||||
import scala.concurrent.duration._
|
||||
import org.scalatest.Inside
|
||||
import akka.event.NoLogging
|
||||
import akka.util.ByteString
|
||||
import akka.stream.scaladsl._
|
||||
import akka.stream.{ FlowMaterializer, ActorFlowMaterializer }
|
||||
|
|
@ -662,7 +661,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
|
|||
}
|
||||
}
|
||||
class TestSetup extends HttpServerTestSetupBase {
|
||||
implicit def system: ActorSystem = spec.system
|
||||
implicit def materializer: FlowMaterializer = spec.materializer
|
||||
implicit def system = spec.system
|
||||
implicit def materializer = spec.materializer
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -61,6 +61,7 @@ abstract class HttpServerTestSetupBase {
|
|||
requests.expectNext()
|
||||
}
|
||||
def expectNoRequest(max: FiniteDuration): Unit = requests.expectNoMsg(max)
|
||||
def expectNetworkClose(): Unit = netOut.expectComplete()
|
||||
|
||||
def send(data: ByteString): Unit = netInSub.sendNext(data)
|
||||
def send(data: String): Unit = send(ByteString(data, "UTF8"))
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ import akka.http.scaladsl.model.ws._
|
|||
import Protocol.Opcode
|
||||
|
||||
class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
|
||||
import WSTestUtils._
|
||||
|
||||
"The Websocket implementation should" - {
|
||||
"collect messages from frames" - {
|
||||
"for binary messages" - {
|
||||
|
|
@ -902,68 +904,6 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
|
|||
netOut.expectNoMsg(100.millis)
|
||||
}
|
||||
|
||||
def frameHeader(
|
||||
opcode: Opcode,
|
||||
length: Long,
|
||||
fin: Boolean,
|
||||
mask: Option[Int] = None,
|
||||
rsv1: Boolean = false,
|
||||
rsv2: Boolean = false,
|
||||
rsv3: Boolean = false): ByteString = {
|
||||
def set(should: Boolean, mask: Int): Int =
|
||||
if (should) mask else 0
|
||||
|
||||
val flags =
|
||||
set(fin, Protocol.FIN_MASK) |
|
||||
set(rsv1, Protocol.RSV1_MASK) |
|
||||
set(rsv2, Protocol.RSV2_MASK) |
|
||||
set(rsv3, Protocol.RSV3_MASK)
|
||||
|
||||
val opcodeByte = opcode.code | flags
|
||||
|
||||
require(length >= 0)
|
||||
val (lengthByteComponent, lengthBytes) =
|
||||
if (length < 126) (length.toByte, ByteString.empty)
|
||||
else if (length < 65536) (126.toByte, shortBE(length.toInt))
|
||||
else throw new IllegalArgumentException("Only lengths < 65536 allowed in test")
|
||||
|
||||
val maskMask = if (mask.isDefined) Protocol.MASK_MASK else 0
|
||||
val maskBytes = mask match {
|
||||
case Some(mask) ⇒ intBE(mask)
|
||||
case None ⇒ ByteString.empty
|
||||
}
|
||||
val lengthByte = lengthByteComponent | maskMask
|
||||
ByteString(opcodeByte.toByte, lengthByte.toByte) ++ lengthBytes ++ maskBytes
|
||||
}
|
||||
def closeFrame(closeCode: Int, mask: Boolean): ByteString =
|
||||
if (mask) {
|
||||
val mask = Random.nextInt()
|
||||
frameHeader(Opcode.Close, 2, fin = true, mask = Some(mask)) ++
|
||||
maskedBytes(shortBE(closeCode), mask)._1
|
||||
} else
|
||||
frameHeader(Opcode.Close, 2, fin = true) ++
|
||||
shortBE(closeCode)
|
||||
|
||||
def maskedASCII(str: String, mask: Int): (ByteString, Int) =
|
||||
FrameEventParser.mask(ByteString(str, "ASCII"), mask)
|
||||
def maskedUTF8(str: String, mask: Int): (ByteString, Int) =
|
||||
FrameEventParser.mask(ByteString(str, "UTF-8"), mask)
|
||||
def maskedBytes(bytes: ByteString, mask: Int): (ByteString, Int) =
|
||||
FrameEventParser.mask(bytes, mask)
|
||||
|
||||
def shortBE(value: Int): ByteString = {
|
||||
require(value >= 0 && value < 65536, s"Value wasn't in short range: $value")
|
||||
ByteString(
|
||||
((value >> 8) & 0xff).toByte,
|
||||
((value >> 0) & 0xff).toByte)
|
||||
}
|
||||
def intBE(value: Int): ByteString =
|
||||
ByteString(
|
||||
((value >> 24) & 0xff).toByte,
|
||||
((value >> 16) & 0xff).toByte,
|
||||
((value >> 8) & 0xff).toByte,
|
||||
((value >> 0) & 0xff).toByte)
|
||||
|
||||
val trace = false // set to `true` for debugging purposes
|
||||
def printEvent[T](marker: String): Flow[T, T, Unit] =
|
||||
if (trace) akka.http.impl.util.printEvent(marker)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,74 @@
|
|||
/*
|
||||
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
|
||||
*/
|
||||
|
||||
package akka.http.impl.engine.ws
|
||||
|
||||
import akka.http.impl.engine.ws.Protocol.Opcode
|
||||
import akka.util.ByteString
|
||||
|
||||
import scala.util.Random
|
||||
|
||||
object WSTestUtils {
|
||||
def frameHeader(
|
||||
opcode: Opcode,
|
||||
length: Long,
|
||||
fin: Boolean,
|
||||
mask: Option[Int] = None,
|
||||
rsv1: Boolean = false,
|
||||
rsv2: Boolean = false,
|
||||
rsv3: Boolean = false): ByteString = {
|
||||
def set(should: Boolean, mask: Int): Int =
|
||||
if (should) mask else 0
|
||||
|
||||
val flags =
|
||||
set(fin, Protocol.FIN_MASK) |
|
||||
set(rsv1, Protocol.RSV1_MASK) |
|
||||
set(rsv2, Protocol.RSV2_MASK) |
|
||||
set(rsv3, Protocol.RSV3_MASK)
|
||||
|
||||
val opcodeByte = opcode.code | flags
|
||||
|
||||
require(length >= 0)
|
||||
val (lengthByteComponent, lengthBytes) =
|
||||
if (length < 126) (length.toByte, ByteString.empty)
|
||||
else if (length < 65536) (126.toByte, shortBE(length.toInt))
|
||||
else throw new IllegalArgumentException("Only lengths < 65536 allowed in test")
|
||||
|
||||
val maskMask = if (mask.isDefined) Protocol.MASK_MASK else 0
|
||||
val maskBytes = mask match {
|
||||
case Some(mask) ⇒ intBE(mask)
|
||||
case None ⇒ ByteString.empty
|
||||
}
|
||||
val lengthByte = lengthByteComponent | maskMask
|
||||
ByteString(opcodeByte.toByte, lengthByte.toByte) ++ lengthBytes ++ maskBytes
|
||||
}
|
||||
def closeFrame(closeCode: Int, mask: Boolean): ByteString =
|
||||
if (mask) {
|
||||
val mask = Random.nextInt()
|
||||
frameHeader(Opcode.Close, 2, fin = true, mask = Some(mask)) ++
|
||||
maskedBytes(shortBE(closeCode), mask)._1
|
||||
} else
|
||||
frameHeader(Opcode.Close, 2, fin = true) ++
|
||||
shortBE(closeCode)
|
||||
|
||||
def maskedASCII(str: String, mask: Int): (ByteString, Int) =
|
||||
FrameEventParser.mask(ByteString(str, "ASCII"), mask)
|
||||
def maskedUTF8(str: String, mask: Int): (ByteString, Int) =
|
||||
FrameEventParser.mask(ByteString(str, "UTF-8"), mask)
|
||||
def maskedBytes(bytes: ByteString, mask: Int): (ByteString, Int) =
|
||||
FrameEventParser.mask(bytes, mask)
|
||||
|
||||
def shortBE(value: Int): ByteString = {
|
||||
require(value >= 0 && value < 65536, s"Value wasn't in short range: $value")
|
||||
ByteString(
|
||||
((value >> 8) & 0xff).toByte,
|
||||
((value >> 0) & 0xff).toByte)
|
||||
}
|
||||
def intBE(value: Int): ByteString =
|
||||
ByteString(
|
||||
((value >> 24) & 0xff).toByte,
|
||||
((value >> 16) & 0xff).toByte,
|
||||
((value >> 8) & 0xff).toByte,
|
||||
((value >> 0) & 0xff).toByte)
|
||||
}
|
||||
|
|
@ -0,0 +1,160 @@
|
|||
/*
|
||||
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
|
||||
*/
|
||||
|
||||
package akka.http.impl.engine.ws
|
||||
|
||||
import akka.http.impl.engine.ws.Protocol.Opcode
|
||||
import akka.http.scaladsl.model.ws._
|
||||
import akka.stream.scaladsl.{ Sink, Flow, Source }
|
||||
import akka.util.ByteString
|
||||
import org.scalatest.{ Matchers, FreeSpec }
|
||||
|
||||
import akka.http.impl.util._
|
||||
|
||||
import akka.http.impl.engine.server.HttpServerTestSetupBase
|
||||
|
||||
import scala.util.Random
|
||||
|
||||
class WebsocketServerSpec extends FreeSpec with Matchers with WithMaterializerSpec { spec ⇒
|
||||
import WSTestUtils._
|
||||
|
||||
"The server-side Websocket integration should" - {
|
||||
"establish a websocket connection when the user requests it" - {
|
||||
"when user handler instantly tries to send messages" in new TestSetup {
|
||||
send(
|
||||
"""GET /chat HTTP/1.1
|
||||
|Host: server.example.com
|
||||
|Upgrade: websocket
|
||||
|Connection: Upgrade
|
||||
|Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
|
||||
|Origin: http://example.com
|
||||
|Sec-WebSocket-Version: 13
|
||||
|
|
||||
|""".
|
||||
stripMarginWithNewline("\r\n"))
|
||||
val request = expectRequest
|
||||
val upgrade = request.header[UpgradeToWebsocket]
|
||||
upgrade.isDefined shouldBe true
|
||||
|
||||
val source =
|
||||
Source(List(1, 2, 3, 4, 5)).map(num ⇒ TextMessage.Strict(s"Message $num"))
|
||||
val handler = Flow.wrap(Sink.ignore, source)((_, _) ⇒ ())
|
||||
val response = upgrade.get.handleMessages(handler)
|
||||
responsesSub.sendNext(response)
|
||||
|
||||
wipeDate(expectNextChunk().utf8String) shouldEqual
|
||||
"""HTTP/1.1 101 Switching Protocols
|
||||
|Upgrade: websocket
|
||||
|Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
|
||||
|Server: akka-http/test
|
||||
|Date: XXXX
|
||||
|Connection: upgrade
|
||||
|
|
||||
|""".
|
||||
stripMarginWithNewline("\r\n")
|
||||
|
||||
expectWSFrame(Protocol.Opcode.Text, ByteString("Message 1"), fin = true)
|
||||
expectWSFrame(Protocol.Opcode.Text, ByteString("Message 2"), fin = true)
|
||||
expectWSFrame(Protocol.Opcode.Text, ByteString("Message 3"), fin = true)
|
||||
expectWSFrame(Protocol.Opcode.Text, ByteString("Message 4"), fin = true)
|
||||
expectWSFrame(Protocol.Opcode.Text, ByteString("Message 5"), fin = true)
|
||||
expectWSCloseFrame(Protocol.CloseCodes.Regular)
|
||||
}
|
||||
"for echoing user handler" in new TestSetup {
|
||||
send(
|
||||
"""GET /echo HTTP/1.1
|
||||
|Host: server.example.com
|
||||
|Upgrade: websocket
|
||||
|Connection: Upgrade
|
||||
|Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
|
||||
|Origin: http://example.com
|
||||
|Sec-WebSocket-Version: 13
|
||||
|
|
||||
|""".
|
||||
stripMarginWithNewline("\r\n"))
|
||||
val request = expectRequest
|
||||
val upgrade = request.header[UpgradeToWebsocket]
|
||||
upgrade.isDefined shouldBe true
|
||||
|
||||
val response = upgrade.get.handleMessages(Flow[Message]) // simple echoing
|
||||
responsesSub.sendNext(response)
|
||||
|
||||
wipeDate(expectNextChunk().utf8String) shouldEqual
|
||||
"""HTTP/1.1 101 Switching Protocols
|
||||
|Upgrade: websocket
|
||||
|Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
|
||||
|Server: akka-http/test
|
||||
|Date: XXXX
|
||||
|Connection: upgrade
|
||||
|
|
||||
|""".
|
||||
stripMarginWithNewline("\r\n")
|
||||
|
||||
sendWSFrame(Protocol.Opcode.Text, ByteString("Message 1"), fin = true, mask = true)
|
||||
expectWSFrame(Protocol.Opcode.Text, ByteString("Message 1"), fin = true)
|
||||
sendWSFrame(Protocol.Opcode.Text, ByteString("Message 2"), fin = true, mask = true)
|
||||
expectWSFrame(Protocol.Opcode.Text, ByteString("Message 2"), fin = true)
|
||||
sendWSFrame(Protocol.Opcode.Text, ByteString("Message 3"), fin = true, mask = true)
|
||||
expectWSFrame(Protocol.Opcode.Text, ByteString("Message 3"), fin = true)
|
||||
sendWSFrame(Protocol.Opcode.Text, ByteString("Message 4"), fin = true, mask = true)
|
||||
expectWSFrame(Protocol.Opcode.Text, ByteString("Message 4"), fin = true)
|
||||
sendWSFrame(Protocol.Opcode.Text, ByteString("Message 5"), fin = true, mask = true)
|
||||
expectWSFrame(Protocol.Opcode.Text, ByteString("Message 5"), fin = true)
|
||||
|
||||
sendWSCloseFrame(Protocol.CloseCodes.Regular, mask = true)
|
||||
expectWSCloseFrame(Protocol.CloseCodes.Regular)
|
||||
|
||||
closeNetworkInput()
|
||||
expectNetworkClose()
|
||||
}
|
||||
}
|
||||
"prevent the selection of an unavailable subprotocol" in pending
|
||||
"reject invalid Websocket handshakes" - {
|
||||
"missing `Connection: upgrade` header" in pending
|
||||
"missing `Sec-WebSocket-Key header" in pending
|
||||
"`Sec-WebSocket-Key` with wrong amount of base64 encoded data" in pending
|
||||
"missing `Sec-WebSocket-Version` header" in pending
|
||||
"unsupported `Sec-WebSocket-Version`" in pending
|
||||
}
|
||||
}
|
||||
|
||||
class TestSetup extends HttpServerTestSetupBase {
|
||||
implicit def system = spec.system
|
||||
implicit def materializer = spec.materializer
|
||||
|
||||
def sendWSFrame(opcode: Opcode,
|
||||
data: ByteString,
|
||||
fin: Boolean,
|
||||
mask: Boolean = false,
|
||||
rsv1: Boolean = false,
|
||||
rsv2: Boolean = false,
|
||||
rsv3: Boolean = false): Unit = {
|
||||
val (theMask, theData) =
|
||||
if (mask) {
|
||||
val m = Random.nextInt()
|
||||
(Some(m), maskedBytes(data, m)._1)
|
||||
} else (None, data)
|
||||
send(frameHeader(opcode, data.length, fin, theMask, rsv1, rsv2, rsv3) ++ theData)
|
||||
}
|
||||
|
||||
def sendWSCloseFrame(closeCode: Int, mask: Boolean = false): Unit =
|
||||
send(closeFrame(closeCode, mask))
|
||||
|
||||
def expectNextChunk(): ByteString = {
|
||||
netOutSub.request(1)
|
||||
netOut.expectNext()
|
||||
}
|
||||
|
||||
def expectWSFrame(opcode: Opcode,
|
||||
data: ByteString,
|
||||
fin: Boolean,
|
||||
mask: Option[Int] = None,
|
||||
rsv1: Boolean = false,
|
||||
rsv2: Boolean = false,
|
||||
rsv3: Boolean = false): Unit =
|
||||
expectNextChunk() shouldEqual frameHeader(opcode, data.length, fin, mask, rsv1, rsv2, rsv3) ++ data
|
||||
def expectWSCloseFrame(closeCode: Int, mask: Boolean = false): Unit =
|
||||
expectNextChunk() shouldEqual closeFrame(closeCode, mask)
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue