=htc #17322 fix race condition between websocket http response and first websocket message

The fix was to make the switch atomic instead of relying on side-channels
which introduced an ordering dependency in the protocol merge.

Also, the switch token itself is now fed back from the response renderer
output to the protocol switch leading to a simplified design.
This commit is contained in:
Johannes Rudolph 2015-04-29 12:35:04 +02:00
parent 444a0deef5
commit c283dd45ae
10 changed files with 354 additions and 171 deletions

View file

@ -4,13 +4,12 @@
package akka.http.impl.engine.rendering
import akka.http.impl.engine.ws.{ WebsocketSwitch, UpgradeToWebsocketResponseHeader, Handshake }
import akka.http.impl.engine.ws.{ FrameEvent, UpgradeToWebsocketResponseHeader }
import scala.annotation.tailrec
import akka.event.LoggingAdapter
import akka.util.ByteString
import akka.stream.OperationAttributes._
import akka.stream.scaladsl.Source
import akka.stream.scaladsl.{ Flow, Source }
import akka.stream.stage._
import akka.http.scaladsl.model._
import akka.http.impl.util._
@ -23,8 +22,7 @@ import headers._
*/
private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Server],
responseHeaderSizeHint: Int,
log: LoggingAdapter,
websocketSwitch: Option[WebsocketSwitch] = None) {
log: LoggingAdapter) {
private val renderDefaultServerHeader: Rendering Unit =
serverHeader match {
@ -54,14 +52,17 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser
def newRenderer: HttpResponseRenderer = new HttpResponseRenderer
final class HttpResponseRenderer extends PushStage[ResponseRenderingContext, Source[ByteString, Any]] {
final class HttpResponseRenderer extends PushStage[ResponseRenderingContext, Source[ResponseRenderingOutput, Any]] {
private[this] var close = false // signals whether the connection is to be closed after the current response
private[this] var closeMode: CloseMode = DontClose // signals what to do after the current response
private[this] def close: Boolean = closeMode != DontClose
private[this] def closeIf(cond: Boolean): Unit =
if (cond) closeMode = CloseConnection
// need this for testing
private[http] def isComplete = close
override def onPush(ctx: ResponseRenderingContext, opCtx: Context[Source[ByteString, Any]]): SyncDirective = {
override def onPush(ctx: ResponseRenderingContext, opCtx: Context[Source[ResponseRenderingOutput, Any]]): SyncDirective = {
val r = new ByteStringRendering(responseHeaderSizeHint)
import ctx.response._
@ -133,7 +134,7 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser
if (!dateSeen) r ~~ dateHeader
// Do we close the connection after this response?
close =
closeIf {
// if we are prohibited to keep-alive by the spec
alwaysClose ||
// if the client wants to close and we don't override
@ -143,6 +144,7 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser
case `HTTP/1.1` (connHeader ne null) && connHeader.hasClose
case `HTTP/1.0` if (connHeader eq null) ctx.requestProtocol == `HTTP/1.1` else !connHeader.hasKeepAlive
})
}
// Do we render an explicit Connection header?
val renderConnectionHeader =
@ -152,10 +154,11 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser
if (renderConnectionHeader)
r ~~ Connection ~~ (if (close) CloseBytes else KeepAliveBytes) ~~ CrLf
else if (connHeader != null && connHeader.hasUpgrade && websocketSwitch.isDefined) {
else if (connHeader != null && connHeader.hasUpgrade) {
r ~~ connHeader ~~ CrLf
val websocketHeader = headers.collectFirst { case u: UpgradeToWebsocketResponseHeader u }
websocketHeader.foreach(header websocketSwitch.get.switchToWebsocket(header.handlerFlow)(header.mat))
headers
.collectFirst { case u: UpgradeToWebsocketResponseHeader u }
.foreach { header closeMode = SwitchToWebsocket(header.handlerFlow) }
}
if (mustRenderTransferEncodingChunkedHeader && !transferEncodingSeen)
r ~~ `Transfer-Encoding` ~~ ChunkedBytes ~~ CrLf
@ -164,17 +167,24 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser
def renderContentLengthHeader(contentLength: Long) =
if (status.allowsEntity) r ~~ `Content-Length` ~~ contentLength ~~ CrLf else r
def byteStrings(entityBytes: Source[ByteString, Any]): Source[ByteString, Any] =
renderByteStrings(r, entityBytes, skipEntity = noEntity)
def byteStrings(entityBytes: Source[ByteString, Any]): Source[ResponseRenderingOutput, Any] =
renderByteStrings(r, entityBytes, skipEntity = noEntity).map(ResponseRenderingOutput.HttpData(_))
def completeResponseRendering(entity: ResponseEntity): Source[ByteString, Any] =
def completeResponseRendering(entity: ResponseEntity): Source[ResponseRenderingOutput, Any] =
entity match {
case HttpEntity.Strict(_, data)
renderHeaders(headers.toList)
renderEntityContentType(r, entity)
renderContentLengthHeader(data.length) ~~ CrLf
val entityBytes = if (noEntity) ByteString.empty else data
Source.single(r.get ++ entityBytes)
if (!noEntity) r ~~ data
Source.single {
closeMode match {
case SwitchToWebsocket(handler) ResponseRenderingOutput.SwitchToWebsocket(r.get, handler)
case _ ResponseRenderingOutput.HttpData(r.get)
}
}
case HttpEntity.Default(_, contentLength, data)
renderHeaders(headers.toList)
@ -205,6 +215,11 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser
opCtx.push(result)
}
}
sealed trait CloseMode
case object DontClose extends CloseMode
case object CloseConnection extends CloseMode
case class SwitchToWebsocket(handlerFlow: Flow[FrameEvent, FrameEvent, Any]) extends CloseMode
}
/**
@ -215,3 +230,11 @@ private[http] final case class ResponseRenderingContext(
requestMethod: HttpMethod = HttpMethods.GET,
requestProtocol: HttpProtocol = HttpProtocols.`HTTP/1.1`,
closeRequested: Boolean = false)
/** INTERNAL API */
private[http] sealed trait ResponseRenderingOutput
/** INTERNAL API */
private[http] object ResponseRenderingOutput {
private[http] case class HttpData(bytes: ByteString) extends ResponseRenderingOutput
private[http] case class SwitchToWebsocket(httpResponseBytes: ByteString, handlerFlow: Flow[FrameEvent, FrameEvent, Any]) extends ResponseRenderingOutput
}

View file

@ -13,10 +13,10 @@ import akka.actor.{ ActorRef, Props }
import akka.stream._
import akka.stream.scaladsl._
import akka.stream.stage.PushPullStage
import akka.stream.scaladsl.FlexiMerge.{ ReadAny, MergeLogic }
import akka.stream.scaladsl.FlexiMerge.{ Read, ReadAny, MergeLogic }
import akka.stream.scaladsl.FlexiRoute.{ DemandFrom, RouteLogic }
import akka.http.impl.engine.parsing._
import akka.http.impl.engine.rendering.{ ResponseRenderingContext, HttpResponseRendererFactory }
import akka.http.impl.engine.rendering.{ ResponseRenderingOutput, ResponseRenderingContext, HttpResponseRendererFactory }
import akka.http.impl.engine.TokenSourceActor
import akka.http.scaladsl.model._
import akka.http.impl.util._
@ -42,8 +42,8 @@ private[http] object HttpServerBluePrint {
logParsingError(info withSummaryPrepended "Illegal request header", log, parserSettings.errorLoggingVerbosity)
})
val ws = websocketPipeline
val responseRendererFactory = new HttpResponseRendererFactory(serverHeader, responseHeaderSizeHint, log, Some(ws))
val ws = websocketSetup
val responseRendererFactory = new HttpResponseRendererFactory(serverHeader, responseHeaderSizeHint, log)
@volatile var oneHundredContinueRef: Option[ActorRef] = None // FIXME: unnecessary after fixing #16168
val oneHundredContinueSource = StreamUtils.oneTimeSource(Source.actorPublisher[OneHundredContinue.type] {
@ -88,7 +88,7 @@ private[http] object HttpServerBluePrint {
.via(Flow[ResponseRenderingContext].transform(() new ErrorsTo500ResponseRecovery(log)).named("recover")) // FIXME: simplify after #16394 is closed
.via(Flow[ResponseRenderingContext].transform(() responseRendererFactory.newRenderer).named("renderer"))
.flatten(FlattenStrategy.concat)
.via(Flow[ByteString].transform(() errorLogger(log, "Outgoing response stream error")).named("errorLogger"))
.via(Flow[ResponseRenderingOutput].transform(() errorLogger(log, "Outgoing response stream error")).named("errorLogger"))
FlowGraph.partial(requestParsingFlow, rendererPipeline, oneHundredContinueSource)((_, _, _) ()) { implicit b
(requestParsing, renderer, oneHundreds)
@ -107,36 +107,41 @@ private[http] object HttpServerBluePrint {
bypassFanout.out(1) ~> bypass ~> bypassInput
oneHundreds ~> bypassOneHundredContinueInput
val http = FlowShape(requestParsing.inlet, renderer.outlet)
val switchTokenBroadcast = b.add(Broadcast[ResponseRenderingOutput](2))
renderer.outlet ~> switchTokenBroadcast
val switchSource: Outlet[SwitchToWebsocketToken.type] =
(switchTokenBroadcast ~>
Flow[ResponseRenderingOutput]
.collect {
case _: ResponseRenderingOutput.SwitchToWebsocket SwitchToWebsocketToken
}).outlet
val http = FlowShape(requestParsing.inlet, switchTokenBroadcast.outlet)
// Websocket pipeline
val websocket = b.add(ws.flow)
val websocket = b.add(ws.websocketFlow)
// protocol routing
val protocolRouter = b.add(new WebsocketSwitchRouter())
val protocolMerge = b.add(new WebsocketMerge)
val protocolMerge = b.add(new WebsocketMerge(ws.installHandler))
protocolRouter.out0 ~> http ~> protocolMerge.in0
protocolRouter.out1 ~> websocket ~> protocolMerge.in1
// protocol switching
val wsSwitchTokenMerge = b.add(new StreamUtils.EagerCloseMerge2[AnyRef]("protocolSwitchWsTokenMerge"))
val switchTokenBroadcast = b.add(Broadcast[SwitchToWebsocketToken.type](2))
ws.switchSource ~> switchTokenBroadcast.in
switchTokenBroadcast.out(0) ~> wsSwitchTokenMerge.in1
wsSwitchTokenMerge.out /*~> printEvent[AnyRef]("netIn")*/ ~> protocolRouter.in
switchTokenBroadcast.out(1) ~> protocolMerge.in2
val netIn = wsSwitchTokenMerge.in0
val wsSwitchTokenMerge = b.add(new CloseIfFirstClosesMerge2[AnyRef]("protocolSwitchWsTokenMerge"))
// feed back switch signal to the protocol router
switchSource ~> wsSwitchTokenMerge.in1
wsSwitchTokenMerge.out ~> protocolRouter.in
val netOutPrint = b.add( /*printEvent[ByteString]("netOut")*/ Flow[ByteString])
protocolMerge.out ~> netOutPrint.inlet
val netOut = netOutPrint.outlet
val netIn = wsSwitchTokenMerge.in0
val netOut = protocolMerge.out
BidiShape[HttpResponse, ByteString, ByteString, HttpRequest](
bypassApplicationInput,
netOut,
netIn,
requestsIn)
}
}
@ -273,29 +278,11 @@ private[http] object HttpServerBluePrint {
}
}
case class WebsocketSetup(
flow: Flow[ByteString, ByteString, Any],
publisherKey: StreamUtils.ReadableCell[Publisher[FrameEvent]],
subscriberKey: StreamUtils.ReadableCell[Subscriber[FrameEvent]],
switchSource: Source[SwitchToWebsocketToken.type, Any]) extends WebsocketSwitch {
@volatile var switchToWebsocketRef: Option[ActorRef] = None
def switchToWebsocket(handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit mat: FlowMaterializer): Unit = {
// 1. fill processing hole in the websocket pipeline with user-provided handler
Source(publisherKey.value)
.via(handlerFlow)
.to(Sink(subscriberKey.value))
.run()
// 1. and 2. could be racy in which case incoming data could arrive because of 2. before
// the pipeline in 1. has been established. The `PublisherSink`, should then, however, backpressure
// until the subscriber has connected (i.e. 1. has run).
// 2. flip the switch
switchToWebsocketRef.get ! TokenSourceActor.Trigger
}
trait WebsocketSetup {
def websocketFlow: Flow[ByteString, ByteString, Any]
def installHandler(handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit mat: FlowMaterializer): Unit
}
def websocketPipeline: WebsocketSetup = {
def websocketSetup: WebsocketSetup = {
val sinkCell = new StreamUtils.OneTimeWriteCell[Publisher[FrameEvent]]
val sourceCell = new StreamUtils.OneTimeWriteCell[Subscriber[FrameEvent]]
@ -308,16 +295,15 @@ private[http] object HttpServerBluePrint {
.via(Flow.wrap(sink, source)((_, _) ()))
.transform(() new FrameEventRenderer)
lazy val setup = WebsocketSetup(flow, sinkCell, sourceCell, switchToWebsocketSource)
lazy val switchToWebsocketSource: Source[SwitchToWebsocketToken.type, ActorRef] =
Source.actorPublisher[SwitchToWebsocketToken.type] {
Props {
val actor = new TokenSourceActor(SwitchToWebsocketToken)
setup.switchToWebsocketRef = Some(actor.context.self)
actor
}
}
setup
new WebsocketSetup {
def websocketFlow: Flow[ByteString, ByteString, Any] = flow
def installHandler(handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit mat: FlowMaterializer): Unit =
Source(sinkCell.value)
.via(handlerFlow)
.to(Sink(sourceCell.value))
.run()
}
}
class WebsocketSwitchRouter
extends FlexiRoute[AnyRef, FanOutShape2[AnyRef, ByteString, ByteString]](new FanOutShape2("websocketSplit"), OperationAttributes.name("websocketSplit")) {
@ -345,36 +331,44 @@ private[http] object HttpServerBluePrint {
}
}
}
class WebsocketMerge extends FlexiMerge[ByteString, FanInShape3[ByteString, ByteString, SwitchToWebsocketToken.type, ByteString]](new FanInShape3("websocketMerge"), OperationAttributes.name("websocketMerge")) {
def createMergeLogic(s: FanInShape3[ByteString, ByteString, SwitchToWebsocketToken.type, ByteString]): MergeLogic[ByteString] =
class WebsocketMerge(installHandler: Flow[FrameEvent, FrameEvent, Any] Unit) extends FlexiMerge[ByteString, FanInShape2[ResponseRenderingOutput, ByteString, ByteString]](new FanInShape2("websocketMerge"), OperationAttributes.name("websocketMerge")) {
def createMergeLogic(s: FanInShape2[ResponseRenderingOutput, ByteString, ByteString]): MergeLogic[ByteString] =
new MergeLogic[ByteString] {
def httpIn = s.in0
def wsIn = s.in1
def tokenIn = s.in2
def initialState: State[_] = http
def http: State[_] = State[AnyRef](ReadAny(httpIn.asInstanceOf[Inlet[AnyRef]], tokenIn.asInstanceOf[Inlet[AnyRef]])) { (ctx, in, element)
def http: State[_] = State[ResponseRenderingOutput](Read(httpIn)) { (ctx, in, element)
element match {
case b: ByteString
ctx.emit(b); SameState
case SwitchToWebsocketToken
ctx.changeCompletionHandling(closeWhenInCloses(wsIn))
websockets
case ResponseRenderingOutput.HttpData(bytes)
ctx.emit(bytes); SameState
case ResponseRenderingOutput.SwitchToWebsocket(responseBytes, handlerFlow)
ctx.emit(responseBytes)
installHandler(handlerFlow)
websocket
}
}
def websockets: State[_] = State[ByteString](ReadAny(httpIn /* otherwise we won't read the websocket upgrade response */ , wsIn)) { (ctx, _, element)
ctx.emit(element)
def websocket: State[_] = State[ByteString](Read(wsIn)) { (ctx, in, bytes)
ctx.emit(bytes)
SameState
}
}
}
/** A merge for two streams that just forwards all elements and closes the connection when the first input closes. */
class CloseIfFirstClosesMerge2[T](name: String) extends FlexiMerge[T, FanInShape2[T, T, T]](new FanInShape2(name), OperationAttributes.name(name)) {
def createMergeLogic(s: FanInShape2[T, T, T]): MergeLogic[T] =
new MergeLogic[T] {
def initialState: State[T] = State[T](ReadAny(s.in0, s.in1)) {
case (ctx, port, in) ctx.emit(in); SameState
}
def closeWhenInCloses(in: Inlet[_]): CompletionHandling =
defaultCompletionHandling.copy(onUpstreamFinish = { (ctx, closingIn)
if (closingIn == in) ctx.finish()
SameState
})
override def initialCompletionHandling: CompletionHandling = closeWhenInCloses(httpIn)
override def initialCompletionHandling: CompletionHandling =
defaultCompletionHandling.copy(
onUpstreamFinish = { (ctx, in)
if (in == s.in0) ctx.finish()
SameState
})
}
}
}

View file

@ -1,13 +0,0 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.impl.engine.ws
import akka.stream.FlowMaterializer
import akka.stream.scaladsl.Flow
/** Internal interface between the handshake and the stream setup to evoke the switch to the websocket protocol */
private[http] trait WebsocketSwitch {
def switchToWebsocket(handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit mat: FlowMaterializer): Unit
}

View file

@ -98,10 +98,10 @@ package object util {
}
}
private[http] def errorLogger(log: LoggingAdapter, msg: String): PushStage[ByteString, ByteString] =
new PushStage[ByteString, ByteString] {
override def onPush(element: ByteString, ctx: Context[ByteString]): SyncDirective = ctx.push(element)
override def onUpstreamFailure(cause: Throwable, ctx: Context[ByteString]): TerminationDirective = {
private[http] def errorLogger[T](log: LoggingAdapter, msg: String): PushStage[T, T] =
new PushStage[T, T] {
override def onPush(element: T, ctx: Context[T]): SyncDirective = ctx.push(element)
override def onUpstreamFailure(cause: Throwable, ctx: Context[T]): TerminationDirective = {
log.error(cause, msg)
super.onUpstreamFailure(cause, ctx)
}

View file

@ -539,7 +539,7 @@ class ResponseRendererSpec extends FreeSpec with Matchers with BeforeAndAfterAll
override def afterAll() = system.shutdown()
class TestSetup(val serverHeader: Option[Server] = Some(Server("akka-http/1.0.0")))
extends HttpResponseRendererFactory(serverHeader, responseHeaderSizeHint = 64, NoLogging, None) {
extends HttpResponseRendererFactory(serverHeader, responseHeaderSizeHint = 64, NoLogging) {
def renderTo(expected: String): Matcher[HttpResponse] =
renderTo(expected, close = false) compose (ResponseRenderingContext(_))
@ -547,10 +547,15 @@ class ResponseRendererSpec extends FreeSpec with Matchers with BeforeAndAfterAll
def renderTo(expected: String, close: Boolean): Matcher[ResponseRenderingContext] =
equal(expected.stripMarginWithNewline("\r\n") -> close).matcher[(String, Boolean)] compose { ctx
val renderer = newRenderer
val byteStringSource = Await.result(Source.single(ctx)
val rendererOutputSource = Await.result(Source.single(ctx)
.transform(() renderer).named("renderer")
.runWith(Sink.head), 1.second)
val future = byteStringSource.grouped(1000).runWith(Sink.head).map(_.reduceLeft(_ ++ _).utf8String)
val future =
rendererOutputSource.grouped(1000).map(
_.map {
case ResponseRenderingOutput.HttpData(bytes) bytes
case _: ResponseRenderingOutput.SwitchToWebsocket throw new IllegalStateException("Didn't expect websocket response")
}).runWith(Sink.head).map(_.reduceLeft(_ ++ _).utf8String)
Await.result(future, 250.millis) -> renderer.isComplete
}

View file

@ -11,7 +11,6 @@ import scala.util.Random
import scala.annotation.tailrec
import scala.concurrent.duration._
import org.scalatest.Inside
import akka.event.NoLogging
import akka.util.ByteString
import akka.stream.scaladsl._
import akka.stream.{ FlowMaterializer, ActorFlowMaterializer }
@ -662,7 +661,7 @@ class HttpServerSpec extends AkkaSpec("akka.loggers = []\n akka.loglevel = OFF")
}
}
class TestSetup extends HttpServerTestSetupBase {
implicit def system: ActorSystem = spec.system
implicit def materializer: FlowMaterializer = spec.materializer
implicit def system = spec.system
implicit def materializer = spec.materializer
}
}

View file

@ -61,6 +61,7 @@ abstract class HttpServerTestSetupBase {
requests.expectNext()
}
def expectNoRequest(max: FiniteDuration): Unit = requests.expectNoMsg(max)
def expectNetworkClose(): Unit = netOut.expectComplete()
def send(data: ByteString): Unit = netInSub.sendNext(data)
def send(data: String): Unit = send(ByteString(data, "UTF8"))

View file

@ -16,6 +16,8 @@ import akka.http.scaladsl.model.ws._
import Protocol.Opcode
class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
import WSTestUtils._
"The Websocket implementation should" - {
"collect messages from frames" - {
"for binary messages" - {
@ -902,68 +904,6 @@ class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
netOut.expectNoMsg(100.millis)
}
def frameHeader(
opcode: Opcode,
length: Long,
fin: Boolean,
mask: Option[Int] = None,
rsv1: Boolean = false,
rsv2: Boolean = false,
rsv3: Boolean = false): ByteString = {
def set(should: Boolean, mask: Int): Int =
if (should) mask else 0
val flags =
set(fin, Protocol.FIN_MASK) |
set(rsv1, Protocol.RSV1_MASK) |
set(rsv2, Protocol.RSV2_MASK) |
set(rsv3, Protocol.RSV3_MASK)
val opcodeByte = opcode.code | flags
require(length >= 0)
val (lengthByteComponent, lengthBytes) =
if (length < 126) (length.toByte, ByteString.empty)
else if (length < 65536) (126.toByte, shortBE(length.toInt))
else throw new IllegalArgumentException("Only lengths < 65536 allowed in test")
val maskMask = if (mask.isDefined) Protocol.MASK_MASK else 0
val maskBytes = mask match {
case Some(mask) intBE(mask)
case None ByteString.empty
}
val lengthByte = lengthByteComponent | maskMask
ByteString(opcodeByte.toByte, lengthByte.toByte) ++ lengthBytes ++ maskBytes
}
def closeFrame(closeCode: Int, mask: Boolean): ByteString =
if (mask) {
val mask = Random.nextInt()
frameHeader(Opcode.Close, 2, fin = true, mask = Some(mask)) ++
maskedBytes(shortBE(closeCode), mask)._1
} else
frameHeader(Opcode.Close, 2, fin = true) ++
shortBE(closeCode)
def maskedASCII(str: String, mask: Int): (ByteString, Int) =
FrameEventParser.mask(ByteString(str, "ASCII"), mask)
def maskedUTF8(str: String, mask: Int): (ByteString, Int) =
FrameEventParser.mask(ByteString(str, "UTF-8"), mask)
def maskedBytes(bytes: ByteString, mask: Int): (ByteString, Int) =
FrameEventParser.mask(bytes, mask)
def shortBE(value: Int): ByteString = {
require(value >= 0 && value < 65536, s"Value wasn't in short range: $value")
ByteString(
((value >> 8) & 0xff).toByte,
((value >> 0) & 0xff).toByte)
}
def intBE(value: Int): ByteString =
ByteString(
((value >> 24) & 0xff).toByte,
((value >> 16) & 0xff).toByte,
((value >> 8) & 0xff).toByte,
((value >> 0) & 0xff).toByte)
val trace = false // set to `true` for debugging purposes
def printEvent[T](marker: String): Flow[T, T, Unit] =
if (trace) akka.http.impl.util.printEvent(marker)

View file

@ -0,0 +1,74 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.impl.engine.ws
import akka.http.impl.engine.ws.Protocol.Opcode
import akka.util.ByteString
import scala.util.Random
object WSTestUtils {
def frameHeader(
opcode: Opcode,
length: Long,
fin: Boolean,
mask: Option[Int] = None,
rsv1: Boolean = false,
rsv2: Boolean = false,
rsv3: Boolean = false): ByteString = {
def set(should: Boolean, mask: Int): Int =
if (should) mask else 0
val flags =
set(fin, Protocol.FIN_MASK) |
set(rsv1, Protocol.RSV1_MASK) |
set(rsv2, Protocol.RSV2_MASK) |
set(rsv3, Protocol.RSV3_MASK)
val opcodeByte = opcode.code | flags
require(length >= 0)
val (lengthByteComponent, lengthBytes) =
if (length < 126) (length.toByte, ByteString.empty)
else if (length < 65536) (126.toByte, shortBE(length.toInt))
else throw new IllegalArgumentException("Only lengths < 65536 allowed in test")
val maskMask = if (mask.isDefined) Protocol.MASK_MASK else 0
val maskBytes = mask match {
case Some(mask) intBE(mask)
case None ByteString.empty
}
val lengthByte = lengthByteComponent | maskMask
ByteString(opcodeByte.toByte, lengthByte.toByte) ++ lengthBytes ++ maskBytes
}
def closeFrame(closeCode: Int, mask: Boolean): ByteString =
if (mask) {
val mask = Random.nextInt()
frameHeader(Opcode.Close, 2, fin = true, mask = Some(mask)) ++
maskedBytes(shortBE(closeCode), mask)._1
} else
frameHeader(Opcode.Close, 2, fin = true) ++
shortBE(closeCode)
def maskedASCII(str: String, mask: Int): (ByteString, Int) =
FrameEventParser.mask(ByteString(str, "ASCII"), mask)
def maskedUTF8(str: String, mask: Int): (ByteString, Int) =
FrameEventParser.mask(ByteString(str, "UTF-8"), mask)
def maskedBytes(bytes: ByteString, mask: Int): (ByteString, Int) =
FrameEventParser.mask(bytes, mask)
def shortBE(value: Int): ByteString = {
require(value >= 0 && value < 65536, s"Value wasn't in short range: $value")
ByteString(
((value >> 8) & 0xff).toByte,
((value >> 0) & 0xff).toByte)
}
def intBE(value: Int): ByteString =
ByteString(
((value >> 24) & 0xff).toByte,
((value >> 16) & 0xff).toByte,
((value >> 8) & 0xff).toByte,
((value >> 0) & 0xff).toByte)
}

View file

@ -0,0 +1,160 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.impl.engine.ws
import akka.http.impl.engine.ws.Protocol.Opcode
import akka.http.scaladsl.model.ws._
import akka.stream.scaladsl.{ Sink, Flow, Source }
import akka.util.ByteString
import org.scalatest.{ Matchers, FreeSpec }
import akka.http.impl.util._
import akka.http.impl.engine.server.HttpServerTestSetupBase
import scala.util.Random
class WebsocketServerSpec extends FreeSpec with Matchers with WithMaterializerSpec { spec
import WSTestUtils._
"The server-side Websocket integration should" - {
"establish a websocket connection when the user requests it" - {
"when user handler instantly tries to send messages" in new TestSetup {
send(
"""GET /chat HTTP/1.1
|Host: server.example.com
|Upgrade: websocket
|Connection: Upgrade
|Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
|Origin: http://example.com
|Sec-WebSocket-Version: 13
|
|""".
stripMarginWithNewline("\r\n"))
val request = expectRequest
val upgrade = request.header[UpgradeToWebsocket]
upgrade.isDefined shouldBe true
val source =
Source(List(1, 2, 3, 4, 5)).map(num TextMessage.Strict(s"Message $num"))
val handler = Flow.wrap(Sink.ignore, source)((_, _) ())
val response = upgrade.get.handleMessages(handler)
responsesSub.sendNext(response)
wipeDate(expectNextChunk().utf8String) shouldEqual
"""HTTP/1.1 101 Switching Protocols
|Upgrade: websocket
|Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
|Server: akka-http/test
|Date: XXXX
|Connection: upgrade
|
|""".
stripMarginWithNewline("\r\n")
expectWSFrame(Protocol.Opcode.Text, ByteString("Message 1"), fin = true)
expectWSFrame(Protocol.Opcode.Text, ByteString("Message 2"), fin = true)
expectWSFrame(Protocol.Opcode.Text, ByteString("Message 3"), fin = true)
expectWSFrame(Protocol.Opcode.Text, ByteString("Message 4"), fin = true)
expectWSFrame(Protocol.Opcode.Text, ByteString("Message 5"), fin = true)
expectWSCloseFrame(Protocol.CloseCodes.Regular)
}
"for echoing user handler" in new TestSetup {
send(
"""GET /echo HTTP/1.1
|Host: server.example.com
|Upgrade: websocket
|Connection: Upgrade
|Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
|Origin: http://example.com
|Sec-WebSocket-Version: 13
|
|""".
stripMarginWithNewline("\r\n"))
val request = expectRequest
val upgrade = request.header[UpgradeToWebsocket]
upgrade.isDefined shouldBe true
val response = upgrade.get.handleMessages(Flow[Message]) // simple echoing
responsesSub.sendNext(response)
wipeDate(expectNextChunk().utf8String) shouldEqual
"""HTTP/1.1 101 Switching Protocols
|Upgrade: websocket
|Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
|Server: akka-http/test
|Date: XXXX
|Connection: upgrade
|
|""".
stripMarginWithNewline("\r\n")
sendWSFrame(Protocol.Opcode.Text, ByteString("Message 1"), fin = true, mask = true)
expectWSFrame(Protocol.Opcode.Text, ByteString("Message 1"), fin = true)
sendWSFrame(Protocol.Opcode.Text, ByteString("Message 2"), fin = true, mask = true)
expectWSFrame(Protocol.Opcode.Text, ByteString("Message 2"), fin = true)
sendWSFrame(Protocol.Opcode.Text, ByteString("Message 3"), fin = true, mask = true)
expectWSFrame(Protocol.Opcode.Text, ByteString("Message 3"), fin = true)
sendWSFrame(Protocol.Opcode.Text, ByteString("Message 4"), fin = true, mask = true)
expectWSFrame(Protocol.Opcode.Text, ByteString("Message 4"), fin = true)
sendWSFrame(Protocol.Opcode.Text, ByteString("Message 5"), fin = true, mask = true)
expectWSFrame(Protocol.Opcode.Text, ByteString("Message 5"), fin = true)
sendWSCloseFrame(Protocol.CloseCodes.Regular, mask = true)
expectWSCloseFrame(Protocol.CloseCodes.Regular)
closeNetworkInput()
expectNetworkClose()
}
}
"prevent the selection of an unavailable subprotocol" in pending
"reject invalid Websocket handshakes" - {
"missing `Connection: upgrade` header" in pending
"missing `Sec-WebSocket-Key header" in pending
"`Sec-WebSocket-Key` with wrong amount of base64 encoded data" in pending
"missing `Sec-WebSocket-Version` header" in pending
"unsupported `Sec-WebSocket-Version`" in pending
}
}
class TestSetup extends HttpServerTestSetupBase {
implicit def system = spec.system
implicit def materializer = spec.materializer
def sendWSFrame(opcode: Opcode,
data: ByteString,
fin: Boolean,
mask: Boolean = false,
rsv1: Boolean = false,
rsv2: Boolean = false,
rsv3: Boolean = false): Unit = {
val (theMask, theData) =
if (mask) {
val m = Random.nextInt()
(Some(m), maskedBytes(data, m)._1)
} else (None, data)
send(frameHeader(opcode, data.length, fin, theMask, rsv1, rsv2, rsv3) ++ theData)
}
def sendWSCloseFrame(closeCode: Int, mask: Boolean = false): Unit =
send(closeFrame(closeCode, mask))
def expectNextChunk(): ByteString = {
netOutSub.request(1)
netOut.expectNext()
}
def expectWSFrame(opcode: Opcode,
data: ByteString,
fin: Boolean,
mask: Option[Int] = None,
rsv1: Boolean = false,
rsv2: Boolean = false,
rsv3: Boolean = false): Unit =
expectNextChunk() shouldEqual frameHeader(opcode, data.length, fin, mask, rsv1, rsv2, rsv3) ++ data
def expectWSCloseFrame(closeCode: Int, mask: Boolean = false): Unit =
expectNextChunk() shouldEqual closeFrame(closeCode, mask)
}
}