Merge pull request #17248 from spray/w/websockets

Server-side Websocket support
This commit is contained in:
Roland Kuhn 2015-04-23 14:26:25 +02:00
commit e7816b2979
46 changed files with 3651 additions and 30 deletions

View file

@ -40,10 +40,10 @@ class HttpExt(config: Config)(implicit system: ActorSystem) extends akka.actor.E
val connections: Source[StreamTcp.IncomingConnection, Future[StreamTcp.ServerBinding]] =
StreamTcp().bind(endpoint, backlog, options, effectiveSettings.timeouts.idleTimeout)
val layer = serverLayer(effectiveSettings, log)
connections.map {
case StreamTcp.IncomingConnection(localAddress, remoteAddress, flow)
val layer = serverLayer(effectiveSettings, log)
IncomingConnection(localAddress, remoteAddress, layer join flow)
}.mapMaterialized { tcpBindingFuture
import system.dispatcher

View file

@ -5,6 +5,8 @@
package akka.http.engine.parsing
import java.lang.{ StringBuilder JStringBuilder }
import akka.http.engine.ws.Handshake
import scala.annotation.tailrec
import akka.actor.ActorRef
import akka.stream.OperationAttributes._
@ -121,9 +123,18 @@ private[http] class HttpRequestParser(_settings: ParserSettings,
if (hostHeaderPresent || protocol == HttpProtocols.`HTTP/1.0`) {
def emitRequestStart(createEntity: Source[RequestOutput, Unit] RequestEntity,
headers: List[HttpHeader] = headers) = {
val allHeaders =
val allHeaders0 =
if (rawRequestUriHeader) `Raw-Request-URI`(new String(uriBytes, HttpCharsets.`US-ASCII`.nioCharset)) :: headers
else headers
val allHeaders =
if (method == HttpMethods.GET) {
Handshake.isWebsocketUpgrade(headers, hostHeaderPresent) match {
case Some(upgrade) upgrade :: allHeaders0
case None allHeaders0
}
} else allHeaders0
emit(RequestStart(method, uri, protocol, allHeaders, createEntity, expect100continue, closeAfterResponseCompletion))
}

View file

@ -4,6 +4,8 @@
package akka.http.engine.rendering
import akka.http.engine.ws.{ WebsocketSwitch, UpgradeToWebsocketResponseHeader, Handshake }
import scala.annotation.tailrec
import akka.event.LoggingAdapter
import akka.util.ByteString
@ -21,7 +23,8 @@ import headers._
*/
private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Server],
responseHeaderSizeHint: Int,
log: LoggingAdapter) {
log: LoggingAdapter,
websocketSwitch: Option[WebsocketSwitch] = None) {
private val renderDefaultServerHeader: Rendering Unit =
serverHeader match {
@ -149,6 +152,11 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser
if (renderConnectionHeader)
r ~~ Connection ~~ (if (close) CloseBytes else KeepAliveBytes) ~~ CrLf
else if (connHeader != null && connHeader.hasUpgrade && websocketSwitch.isDefined) {
r ~~ connHeader ~~ CrLf
val websocketHeader = headers.collectFirst { case u: UpgradeToWebsocketResponseHeader u }
websocketHeader.foreach(header websocketSwitch.get.switchToWebsocket(header.handlerFlow)(header.mat))
}
if (mustRenderTransferEncodingChunkedHeader && !transferEncodingSeen)
r ~~ `Transfer-Encoding` ~~ ChunkedBytes ~~ CrLf
}

View file

@ -4,6 +4,13 @@
package akka.http.engine.server
import akka.stream.scaladsl.FlexiMerge.{ ReadAny, MergeLogic }
import akka.stream.scaladsl._
import akka.http.engine.ws._
import akka.stream.scaladsl.FlexiRoute.{ DemandFrom, RouteLogic }
import org.reactivestreams.{ Subscriber, Publisher }
import scala.util.control.NonFatal
import akka.util.ByteString
import akka.event.LoggingAdapter
@ -18,6 +25,7 @@ import akka.http.engine.TokenSourceActor
import akka.http.model._
import akka.http.util._
import ParserOutput._
import akka.http.engine.ws.Websocket.{ SwitchToWebsocketToken }
/**
* INTERNAL API
@ -37,7 +45,8 @@ private[http] object HttpServerBluePrint {
logParsingError(info withSummaryPrepended "Illegal request header", log, parserSettings.errorLoggingVerbosity)
})
val responseRendererFactory = new HttpResponseRendererFactory(serverHeader, responseHeaderSizeHint, log)
val ws = websocketPipeline
val responseRendererFactory = new HttpResponseRendererFactory(serverHeader, responseHeaderSizeHint, log, Some(ws))
@volatile var oneHundredContinueRef: Option[ActorRef] = None // FIXME: unnecessary after fixing #16168
val oneHundredContinueSource = Source.actorPublisher[OneHundredContinue.type] {
@ -94,17 +103,43 @@ private[http] object HttpServerBluePrint {
val bypassOneHundredContinueInput = bypassMerge.in1
val bypassApplicationInput = bypassMerge.in2
// HTTP pipeline
requestParsing.outlet ~> bypassFanout.in
bypassMerge.out ~> renderer.inlet
val requestsIn = (bypassFanout.out(0) ~> requestPreparation).outlet
bypassFanout.out(1) ~> bypass ~> bypassInput
oneHundredContinueSource ~> bypassOneHundredContinueInput
val http = FlowShape(requestParsing.inlet, renderer.outlet)
// Websocket pipeline
val websocket = b.add(ws.flow)
// protocol routing
val protocolRouter = b.add(new WebsocketSwitchRouter())
val protocolMerge = b.add(new WebsocketMerge)
protocolRouter.out0 ~> http ~> protocolMerge.in0
protocolRouter.out1 ~> websocket ~> protocolMerge.in1
// protocol switching
val wsSwitchTokenMerge = b.add(new StreamUtils.EagerCloseMerge2[AnyRef]("protocolSwitchWsTokenMerge"))
val switchTokenBroadcast = b.add(Broadcast[SwitchToWebsocketToken.type](2))
ws.switchSource ~> switchTokenBroadcast.in
switchTokenBroadcast.out(0) ~> wsSwitchTokenMerge.in1
wsSwitchTokenMerge.out /*~> printEvent[AnyRef]("netIn")*/ ~> protocolRouter.in
switchTokenBroadcast.out(1) ~> protocolMerge.in2
val netIn = wsSwitchTokenMerge.in0
val netOutPrint = b.add( /*printEvent[ByteString]("netOut")*/ Flow[ByteString])
protocolMerge.out ~> netOutPrint.inlet
val netOut = netOutPrint.outlet
BidiShape[HttpResponse, ByteString, ByteString, HttpRequest](
bypassApplicationInput,
renderer.outlet,
requestParsing.inlet,
netOut,
netIn,
requestsIn)
}
}
@ -240,4 +275,109 @@ private[http] object HttpServerBluePrint {
case _ ctx.fail(error)
}
}
case class WebsocketSetup(
flow: Flow[ByteString, ByteString, Any],
publisherKey: StreamUtils.ReadableCell[Publisher[FrameEvent]],
subscriberKey: StreamUtils.ReadableCell[Subscriber[FrameEvent]],
switchSource: Source[SwitchToWebsocketToken.type, Any]) extends WebsocketSwitch {
@volatile var switchToWebsocketRef: Option[ActorRef] = None
def switchToWebsocket(handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit mat: FlowMaterializer): Unit = {
// 1. fill processing hole in the websocket pipeline with user-provided handler
Source(publisherKey.value)
.via(handlerFlow)
.to(Sink(subscriberKey.value))
.run()
// 1. and 2. could be racy in which case incoming data could arrive because of 2. before
// the pipeline in 1. has been established. The `PublisherSink`, should then, however, backpressure
// until the subscriber has connected (i.e. 1. has run).
// 2. flip the switch
switchToWebsocketRef.get ! TokenSourceActor.Trigger
}
}
def websocketPipeline: WebsocketSetup = {
val sinkCell = new StreamUtils.OneTimeWriteCell[Publisher[FrameEvent]]
val sourceCell = new StreamUtils.OneTimeWriteCell[Subscriber[FrameEvent]]
val sink = StreamUtils.oneTimePublisherSink[FrameEvent](sinkCell, "frameHandler.in")
val source = StreamUtils.oneTimeSubscriberSource[FrameEvent](sourceCell, "frameHandler.out")
val flow =
Flow[ByteString]
.transform[FrameEvent](() new FrameEventParser)
.via(Flow.wrap(sink, source)((_, _) ()))
.transform(() new FrameEventRenderer)
lazy val setup = WebsocketSetup(flow, sinkCell, sourceCell, switchToWebsocketSource)
lazy val switchToWebsocketSource: Source[SwitchToWebsocketToken.type, ActorRef] =
Source.actorPublisher[SwitchToWebsocketToken.type] {
Props {
val actor = new TokenSourceActor(SwitchToWebsocketToken)
setup.switchToWebsocketRef = Some(actor.context.self)
actor
}
}
setup
}
class WebsocketSwitchRouter
extends FlexiRoute[AnyRef, FanOutShape2[AnyRef, ByteString, ByteString]](new FanOutShape2("websocketSplit"), OperationAttributes.name("websocketSplit")) {
override def createRouteLogic(shape: FanOutShape2[AnyRef, ByteString, ByteString]): RouteLogic[AnyRef] =
new RouteLogic[AnyRef] {
def initialState: State[_] = http
def http: State[_] = State[Any](DemandFrom(shape.out0)) { (ctx, _, element)
element match {
case b: ByteString
// route to HTTP processing
ctx.emit(shape.out0)(b)
SameState
case SwitchToWebsocketToken
// switch to websocket protocol
websockets
}
}
def websockets: State[_] = State[Any](DemandFrom(shape.out1)) { (ctx, _, element)
// route to Websocket processing
ctx.emit(shape.out1)(element.asInstanceOf[ByteString])
SameState
}
}
}
class WebsocketMerge extends FlexiMerge[ByteString, FanInShape3[ByteString, ByteString, SwitchToWebsocketToken.type, ByteString]](new FanInShape3("websocketMerge"), OperationAttributes.name("websocketMerge")) {
def createMergeLogic(s: FanInShape3[ByteString, ByteString, SwitchToWebsocketToken.type, ByteString]): MergeLogic[ByteString] =
new MergeLogic[ByteString] {
def httpIn = s.in0
def wsIn = s.in1
def tokenIn = s.in2
def initialState: State[_] = http
def http: State[_] = State[AnyRef](ReadAny(httpIn.asInstanceOf[Inlet[AnyRef]], tokenIn.asInstanceOf[Inlet[AnyRef]])) { (ctx, in, element)
element match {
case b: ByteString
ctx.emit(b); SameState
case SwitchToWebsocketToken
ctx.changeCompletionHandling(closeWhenInCloses(wsIn))
websockets
}
}
def websockets: State[_] = State[ByteString](ReadAny(httpIn /* otherwise we won't read the websocket upgrade response */ , wsIn)) { (ctx, _, element)
ctx.emit(element)
SameState
}
def closeWhenInCloses(in: Inlet[_]): CompletionHandling =
defaultCompletionHandling.copy(onUpstreamFinish = { (ctx, closingIn)
if (closingIn == in) ctx.finish()
SameState
})
override def initialCompletionHandling: CompletionHandling = closeWhenInCloses(httpIn)
}
}
}

View file

@ -0,0 +1,72 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.engine.ws
import akka.http.engine.ws.Protocol.Opcode
import akka.util.ByteString
/**
* The low-level Websocket framing model.
*
* INTERNAL API
*/
private[http] sealed trait FrameEvent {
def data: ByteString
def lastPart: Boolean
def withData(data: ByteString): FrameEvent
}
/**
* Starts a frame. Contains the frame's headers. May contain all the data of the frame if `lastPart == true`. Otherwise,
* following events will be `FrameData` events that contain the remaining data of the frame.
*/
private[http] final case class FrameStart(header: FrameHeader, data: ByteString) extends FrameEvent {
def lastPart: Boolean = data.size == header.length
def withData(data: ByteString): FrameStart = copy(data = data)
def isFullMessage: Boolean = header.fin && header.length == data.length
}
/**
* Frame data that was received after the start of the frame..
*/
private[http] final case class FrameData(data: ByteString, lastPart: Boolean) extends FrameEvent {
def withData(data: ByteString): FrameData = copy(data = data)
}
/** Model of the frame header */
private[http] final case class FrameHeader(opcode: Protocol.Opcode,
mask: Option[Int],
length: Long,
fin: Boolean,
rsv1: Boolean = false,
rsv2: Boolean = false,
rsv3: Boolean = false)
private[http] object FrameEvent {
def empty(opcode: Protocol.Opcode,
fin: Boolean,
rsv1: Boolean = false,
rsv2: Boolean = false,
rsv3: Boolean = false): FrameStart =
fullFrame(opcode, None, ByteString.empty, fin, rsv1, rsv2, rsv3)
def fullFrame(opcode: Protocol.Opcode, mask: Option[Int], data: ByteString,
fin: Boolean,
rsv1: Boolean = false,
rsv2: Boolean = false,
rsv3: Boolean = false): FrameStart =
FrameStart(FrameHeader(opcode, mask, data.length, fin, rsv1, rsv2, rsv3), data)
val emptyLastContinuationFrame: FrameStart =
empty(Protocol.Opcode.Continuation, fin = true)
def closeFrame(closeCode: Int, reason: String = "", mask: Option[Int] = None): FrameStart = {
require(closeCode >= 1000, s"Invalid close code: $closeCode")
val body = ByteString(
((closeCode & 0xff00) >> 8).toByte,
(closeCode & 0xff).toByte) ++ ByteString(reason, "UTF8")
fullFrame(Opcode.Close, mask, FrameEventParser.mask(body, mask), fin = true)
}
}

View file

@ -0,0 +1,162 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.engine.ws
import akka.http.util.{ ByteReader, ByteStringParserStage }
import akka.stream.stage.{ StageState, SyncDirective, Context }
import akka.util.ByteString
import scala.annotation.tailrec
/**
* Streaming parser for the Websocket framing protocol as defined in RFC6455
*
* http://tools.ietf.org/html/rfc6455
*
* 0 1 2 3
* 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
* +-+-+-+-+-------+-+-------------+-------------------------------+
* |F|R|R|R| opcode|M| Payload len | Extended payload length |
* |I|S|S|S| (4) |A| (7) | (16/64) |
* |N|V|V|V| |S| | (if payload len==126/127) |
* | |1|2|3| |K| | |
* +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
* | Extended payload length continued, if payload len == 127 |
* + - - - - - - - - - - - - - - - +-------------------------------+
* | |Masking-key, if MASK set to 1 |
* +-------------------------------+-------------------------------+
* | Masking-key (continued) | Payload Data |
* +-------------------------------- - - - - - - - - - - - - - - - +
* : Payload Data continued ... :
* + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
* | Payload Data continued ... |
* +---------------------------------------------------------------+
*
* INTERNAL API
*/
private[http] class FrameEventParser extends ByteStringParserStage[FrameEvent] {
protected def onTruncation(ctx: Context[FrameEvent]): SyncDirective =
ctx.fail(new ProtocolException("Data truncated"))
def initial: StageState[ByteString, FrameEvent] = ReadFrameHeader
object ReadFrameHeader extends ByteReadingState {
def read(reader: ByteReader, ctx: Context[FrameEvent]): SyncDirective = {
import Protocol._
val flagsAndOp = reader.readByte()
val maskAndLength = reader.readByte()
val flags = flagsAndOp & FLAGS_MASK
val op = flagsAndOp & OP_MASK
val maskBit = (maskAndLength & MASK_MASK) != 0
val length7 = maskAndLength & LENGTH_MASK
val length =
length7 match {
case 126 reader.readShortBE().toLong
case 127 reader.readLongBE()
case x x.toLong
}
if (length < 0) ctx.fail(new ProtocolException("Highest bit of 64bit length was set"))
val mask =
if (maskBit) Some(reader.readIntBE())
else None
def isFlagSet(mask: Int): Boolean = (flags & mask) != 0
val header =
FrameHeader(Opcode.forCode(op.toByte),
mask,
length,
fin = isFlagSet(FIN_MASK),
rsv1 = isFlagSet(RSV1_MASK),
rsv2 = isFlagSet(RSV2_MASK),
rsv3 = isFlagSet(RSV3_MASK))
val data = reader.remainingData
val takeNow = (header.length min Int.MaxValue).toInt
val thisFrameData = data.take(takeNow)
val remaining = data.drop(takeNow)
val nextState =
if (thisFrameData.length == length) ReadFrameHeader
else readData(length - thisFrameData.length)
pushAndBecomeWithRemaining(FrameStart(header, thisFrameData), nextState, remaining, ctx)
}
}
def readData(_remaining: Long): State =
new State {
var remaining = _remaining
def onPush(elem: ByteString, ctx: Context[FrameEvent]): SyncDirective =
if (elem.size < remaining) {
remaining -= elem.size
ctx.push(FrameData(elem, lastPart = false))
} else {
assert(remaining <= Int.MaxValue) // safe because, remaining <= elem.size <= Int.MaxValue
val frameData = elem.take(remaining.toInt)
val remainingData = elem.drop(remaining.toInt)
pushAndBecomeWithRemaining(FrameData(frameData, lastPart = true), ReadFrameHeader, remainingData, ctx)
}
}
def becomeWithRemaining(nextState: State, remainingData: ByteString, ctx: Context[FrameEvent]): SyncDirective = {
become(nextState)
nextState.onPush(remainingData, ctx)
}
def pushAndBecomeWithRemaining(elem: FrameEvent, nextState: State, remainingData: ByteString, ctx: Context[FrameEvent]): SyncDirective =
if (remainingData.isEmpty) {
become(nextState)
ctx.push(elem)
} else {
become(waitForPull(nextState, remainingData))
ctx.push(elem)
}
def waitForPull(nextState: State, remainingData: ByteString): State =
new State {
def onPush(elem: ByteString, ctx: Context[FrameEvent]): SyncDirective =
throw new IllegalStateException("Mustn't push in this state")
override def onPull(ctx: Context[FrameEvent]): SyncDirective = {
become(nextState)
nextState.onPush(remainingData, ctx)
}
}
}
object FrameEventParser {
def mask(bytes: ByteString, _mask: Option[Int]): ByteString =
_mask match {
case Some(m) mask(bytes, m)._1
case None bytes
}
def mask(bytes: ByteString, mask: Int): (ByteString, Int) = {
@tailrec def rec(bytes: Array[Byte], offset: Int, mask: Int): Int =
if (offset >= bytes.length) mask
else {
val newMask = Integer.rotateLeft(mask, 8) // we cycle through the mask in BE order
bytes(offset) = (bytes(offset) ^ (newMask & 0xff)).toByte
rec(bytes, offset + 1, newMask)
}
val buffer = bytes.toArray[Byte]
val newMask = rec(buffer, 0, mask)
(ByteString(buffer), newMask)
}
def parseCloseCode(data: ByteString): Option[Int] =
if (data.length >= 2) {
val code = ((data(0) & 0xff) << 8) | (data(1) & 0xff)
if (Protocol.CloseCodes.isValid(code)) Some(code)
else Some(Protocol.CloseCodes.ProtocolError)
} else if (data.length == 1) Some(Protocol.CloseCodes.ProtocolError) // must be >= length 2 if not empty
else None
}

View file

@ -0,0 +1,100 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.engine.ws
import akka.util.ByteString
import akka.stream.stage.{ TerminationDirective, StatefulStage, SyncDirective, Context }
import scala.annotation.tailrec
/**
* Renders FrameEvents to ByteString.
*
* INTERNAL API
*/
private[http] class FrameEventRenderer extends StatefulStage[FrameEvent, ByteString] {
def initial: State = Idle
object Idle extends State {
def onPush(elem: FrameEvent, ctx: Context[ByteString]): SyncDirective = elem match {
case start @ FrameStart(header, data)
assert(header.length >= data.size)
if (!start.lastPart && header.length > 0) become(renderData(header.length - data.length, this))
ctx.push(renderStart(start))
}
}
def renderData(initialRemaining: Long, nextState: State): State =
new State {
var remaining: Long = initialRemaining
def onPush(elem: FrameEvent, ctx: Context[ByteString]): SyncDirective = elem match {
case FrameData(data, lastPart)
if (data.size > remaining)
throw new IllegalStateException(s"Expected $remaining frame bytes but got ${data.size}")
else if (data.size == remaining) {
if (!lastPart) throw new IllegalStateException(s"Frame data complete but `lastPart` flag not set")
become(nextState)
ctx.push(data)
} else {
remaining -= data.size
ctx.push(data)
}
}
}
def renderStart(start: FrameStart): ByteString = renderHeader(start.header) ++ start.data
def renderHeader(header: FrameHeader): ByteString = {
import Protocol._
val length = header.length
val (lengthBits, extraLengthBytes) = length match {
case x if x < 126 (x.toInt, 0)
case x if x <= 0xFFFF (126, 2)
case _ (127, 8)
}
val maskBytes = if (header.mask.isDefined) 4 else 0
val totalSize = 2 + extraLengthBytes + maskBytes
val data = new Array[Byte](totalSize)
def bool(b: Boolean, mask: Int): Int = if (b) mask else 0
val flags =
bool(header.fin, FIN_MASK) |
bool(header.rsv1, RSV1_MASK) |
bool(header.rsv2, RSV2_MASK) |
bool(header.rsv3, RSV3_MASK)
data(0) = (flags | header.opcode.code).toByte
data(1) = (bool(header.mask.isDefined, MASK_MASK) | lengthBits).toByte
extraLengthBytes match {
case 0
case 2
data(2) = ((length & 0xFF00) >> 8).toByte
data(3) = ((length & 0x00FF) >> 0).toByte
case 8
@tailrec def addLongBytes(l: Long, writtenBytes: Int): Unit =
if (writtenBytes < 8) {
data(2 + writtenBytes) = (l & 0xff).toByte
addLongBytes(java.lang.Long.rotateLeft(l, 8), writtenBytes + 1)
}
addLongBytes(java.lang.Long.rotateLeft(length, 8), 0)
}
val maskOffset = 2 + extraLengthBytes
header.mask.foreach { mask
data(maskOffset + 0) = ((mask & 0xFF000000) >> 24).toByte
data(maskOffset + 1) = ((mask & 0x00FF0000) >> 16).toByte
data(maskOffset + 2) = ((mask & 0x0000FF00) >> 8).toByte
data(maskOffset + 3) = ((mask & 0x000000FF) >> 0).toByte
}
ByteString(data)
}
}

View file

@ -0,0 +1,200 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.engine.ws
import akka.stream.scaladsl.Flow
import akka.stream.stage.{ TerminationDirective, SyncDirective, Context, StatefulStage }
import akka.util.ByteString
import Protocol.Opcode
import scala.util.control.NonFatal
/**
* The frame handler validates frames, multiplexes data to the user handler or to the bypass and
* UTF-8 decodes text frames.
*
* INTERNAL API
*/
private[http] object FrameHandler {
def create(server: Boolean): Flow[FrameEvent, Either[BypassEvent, MessagePart], Unit] =
Flow[FrameEvent].transform(() new HandlerStage(server))
class HandlerStage(server: Boolean) extends StatefulStage[FrameEvent, Either[BypassEvent, MessagePart]] {
type Ctx = Context[Either[BypassEvent, MessagePart]]
def initial: State = Idle
object Idle extends StateWithControlFrameHandling {
def handleRegularFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective =
(start.header.opcode, start.isFullMessage) match {
case (Opcode.Binary, true) publishMessagePart(BinaryMessagePart(start.data, last = true))
case (Opcode.Binary, false) becomeAndHandleWith(new CollectingBinaryMessage, start)
case (Opcode.Text, _) becomeAndHandleWith(new CollectingTextMessage, start)
case x protocolError()
}
}
class CollectingBinaryMessage extends CollectingMessageFrame(Opcode.Binary) {
def createMessagePart(data: ByteString, last: Boolean): MessageDataPart = BinaryMessagePart(data, last)
}
class CollectingTextMessage extends CollectingMessageFrame(Opcode.Text) {
val decoder = Utf8Decoder.create()
def createMessagePart(data: ByteString, last: Boolean): MessageDataPart =
TextMessagePart(decoder.decode(data, endOfInput = last).get, last)
}
abstract class CollectingMessageFrame(expectedOpcode: Opcode) extends StateWithControlFrameHandling {
var expectFirstHeader = true
var finSeen = false
def createMessagePart(data: ByteString, last: Boolean): MessageDataPart
def handleRegularFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective = {
if ((expectFirstHeader && start.header.opcode == expectedOpcode) // first opcode must be the expected
|| start.header.opcode == Opcode.Continuation) { // further ones continuations
expectFirstHeader = false
if (start.header.fin) finSeen = true
publish(start)
} else protocolError()
}
override def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective = publish(data)
private def publish(part: FrameEvent)(implicit ctx: Ctx): SyncDirective =
try publishMessagePart(createMessagePart(part.data, last = finSeen && part.lastPart))
catch {
case NonFatal(e) closeWithCode(Protocol.CloseCodes.InconsistentData)
}
}
class CollectingControlFrame(opcode: Opcode, _data: ByteString, nextState: State) extends InFrameState {
var data = _data
def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective = {
this.data ++= data.data
if (data.lastPart) handleControlFrame(opcode, this.data, nextState)
else ctx.pull()
}
}
object Closed extends State {
def onPush(elem: FrameEvent, ctx: Ctx): SyncDirective =
ctx.pull() // ignore
}
def becomeAndHandleWith(newState: State, part: FrameEvent)(implicit ctx: Ctx): SyncDirective = {
become(newState)
current.onPush(part, ctx)
}
/** Returns a SyncDirective if it handled the message */
def validateHeader(header: FrameHeader)(implicit ctx: Ctx): Option[SyncDirective] = header match {
case h: FrameHeader if h.mask.isDefined && !server Some(protocolError())
case h: FrameHeader if h.rsv1 || h.rsv2 || h.rsv3 Some(protocolError())
case FrameHeader(op, _, length, fin, _, _, _) if op.isControl && (length > 125 || !fin) Some(protocolError())
case _ None
}
def handleControlFrame(opcode: Opcode, data: ByteString, nextState: State)(implicit ctx: Ctx): SyncDirective = {
become(nextState)
opcode match {
case Opcode.Ping publishDirectResponse(FrameEvent.fullFrame(Opcode.Pong, None, data, fin = true))
case Opcode.Pong
// ignore unsolicited Pong frame
ctx.pull()
case Opcode.Close
val closeCode = FrameEventParser.parseCloseCode(data)
emit(Iterator(Left(PeerClosed(closeCode)), Right(PeerClosed(closeCode))), ctx, WaitForPeerTcpClose)
case Opcode.Other(o) closeWithCode(Protocol.CloseCodes.ProtocolError, "Unsupported opcode")
}
}
private def collectControlFrame(start: FrameStart, nextState: State)(implicit ctx: Ctx): SyncDirective = {
assert(!start.isFullMessage)
become(new CollectingControlFrame(start.header.opcode, start.data, nextState))
ctx.pull()
}
private def publishMessagePart(part: MessageDataPart)(implicit ctx: Ctx): SyncDirective =
if (part.last) emit(Iterator(Right(part), Right(MessageEnd)), ctx, Idle)
else ctx.push(Right(part))
private def publishDirectResponse(frame: FrameStart)(implicit ctx: Ctx): SyncDirective =
ctx.push(Left(DirectAnswer(frame)))
private def protocolError(reason: String = "")(implicit ctx: Ctx): SyncDirective =
closeWithCode(Protocol.CloseCodes.ProtocolError, reason)
private def closeWithCode(closeCode: Int, reason: String = "", cause: Throwable = null)(implicit ctx: Ctx): SyncDirective =
emit(
Iterator(
Left(ActivelyCloseWithCode(Some(closeCode), reason)),
Right(ActivelyCloseWithCode(Some(closeCode), reason))), ctx, CloseAfterPeerClosed)
object CloseAfterPeerClosed extends State {
def onPush(elem: FrameEvent, ctx: Context[Either[BypassEvent, MessagePart]]): SyncDirective =
elem match {
case FrameStart(FrameHeader(Opcode.Close, _, length, _, _, _, _), data)
become(WaitForPeerTcpClose)
ctx.push(Left(PeerClosed(FrameEventParser.parseCloseCode(data))))
case _ ctx.pull() // ignore all other data
}
}
object WaitForPeerTcpClose extends State {
def onPush(elem: FrameEvent, ctx: Context[Either[BypassEvent, MessagePart]]): SyncDirective =
ctx.pull() // ignore
}
abstract class StateWithControlFrameHandling extends BetweenFrameState {
def handleRegularFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective
def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective =
validateHeader(start.header).getOrElse {
if (start.header.opcode.isControl)
if (start.isFullMessage) handleControlFrame(start.header.opcode, start.data, this)
else collectControlFrame(start, this)
else handleRegularFrameStart(start)
}
}
abstract class BetweenFrameState extends ImplicitContextState {
def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective =
throw new IllegalStateException("Expected FrameStart")
}
abstract class InFrameState extends ImplicitContextState {
def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective =
throw new IllegalStateException("Expected FrameData")
}
abstract class ImplicitContextState extends State {
def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective
def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective
def onPush(part: FrameEvent, ctx: Ctx): SyncDirective =
part match {
case data: FrameData handleFrameData(data)(ctx)
case start: FrameStart handleFrameStart(start)(ctx)
}
}
}
sealed trait MessagePart {
def isMessageEnd: Boolean
}
sealed trait MessageDataPart extends MessagePart {
def isMessageEnd = false
def last: Boolean
}
final case class TextMessagePart(data: String, last: Boolean) extends MessageDataPart
final case class BinaryMessagePart(data: ByteString, last: Boolean) extends MessageDataPart
case object MessageEnd extends MessagePart {
def isMessageEnd: Boolean = true
}
final case class PeerClosed(code: Option[Int], reason: String = "") extends MessagePart with BypassEvent {
def isMessageEnd: Boolean = true
}
sealed trait BypassEvent
final case class DirectAnswer(frame: FrameStart) extends BypassEvent
final case class ActivelyCloseWithCode(code: Option[Int], reason: String = "") extends MessagePart with BypassEvent {
def isMessageEnd: Boolean = true
}
case object UserHandlerCompleted extends BypassEvent
case class UserHandlerErredOut(cause: Throwable) extends BypassEvent
}

View file

@ -0,0 +1,132 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.engine.ws
import scala.concurrent.duration.FiniteDuration
import akka.stream.stage._
import akka.http.util.Timestamp
import FrameHandler.{ UserHandlerCompleted, ActivelyCloseWithCode, PeerClosed, DirectAnswer }
import Websocket.Tick
/**
* Implements the transport connection close handling at the end of the pipeline.
*
* INTERNAL API
*/
private[http] class FrameOutHandler(serverSide: Boolean, _closeTimeout: FiniteDuration) extends StatefulStage[AnyRef, FrameStart] {
def initial: StageState[AnyRef, FrameStart] = Idle
def closeTimeout: Timestamp = Timestamp.now + _closeTimeout
object Idle extends CompletionHandlingState {
def onPush(elem: AnyRef, ctx: Context[FrameStart]): SyncDirective = elem match {
case start: FrameStart ctx.push(start)
case DirectAnswer(frame) ctx.push(frame)
case PeerClosed(code, reason) if !code.exists(Protocol.CloseCodes.isError)
// let user complete it, FIXME: maybe make configurable? immediately, or timeout
become(new WaitingForUserHandlerClosed(FrameEvent.closeFrame(code.getOrElse(Protocol.CloseCodes.Regular), reason)))
ctx.pull()
case PeerClosed(code, reason)
val closeFrame = FrameEvent.closeFrame(code.getOrElse(Protocol.CloseCodes.Regular), reason)
if (serverSide) ctx.pushAndFinish(closeFrame)
else {
become(new WaitingForTransportClose)
ctx.push(closeFrame)
}
case ActivelyCloseWithCode(code, reason)
val closeFrame = FrameEvent.closeFrame(code.getOrElse(Protocol.CloseCodes.Regular), reason)
become(new WaitingForPeerCloseFrame())
ctx.push(closeFrame)
case UserHandlerCompleted
become(new WaitingForPeerCloseFrame())
ctx.push(FrameEvent.closeFrame(Protocol.CloseCodes.Regular))
case Tick ctx.pull() // ignore
}
def onComplete(ctx: Context[FrameStart]): TerminationDirective = {
become(new SendOutCloseFrameAndComplete(FrameEvent.closeFrame(Protocol.CloseCodes.Regular)))
ctx.absorbTermination()
}
}
/**
* peer has closed, we want to wait for user handler to close as well
*/
class WaitingForUserHandlerClosed(closeFrame: FrameStart) extends CompletionHandlingState {
def onPush(elem: AnyRef, ctx: Context[FrameStart]): SyncDirective = elem match {
case UserHandlerCompleted
if (serverSide) ctx.pushAndFinish(closeFrame)
else {
become(new WaitingForTransportClose())
ctx.push(closeFrame)
}
case start: FrameStart ctx.push(start)
case _ ctx.pull() // ignore
}
def onComplete(ctx: Context[FrameStart]): TerminationDirective =
ctx.fail(new IllegalStateException("Mustn't complete before user has completed"))
}
/**
* we have sent out close frame and wait for peer to sent its close frame
*/
class WaitingForPeerCloseFrame(timeout: Timestamp = closeTimeout) extends CompletionHandlingState {
def onPush(elem: AnyRef, ctx: Context[FrameStart]): SyncDirective = elem match {
case Tick
if (timeout.isPast) ctx.finish()
else ctx.pull()
case PeerClosed(code, reason)
if (serverSide) ctx.finish()
else {
become(new WaitingForTransportClose())
ctx.pull()
}
case _ ctx.pull() // ignore
}
def onComplete(ctx: Context[FrameStart]): TerminationDirective = ctx.finish()
}
/**
* Both side have sent their close frames, server should close the connection first
*/
class WaitingForTransportClose(timeout: Timestamp = closeTimeout) extends CompletionHandlingState {
def onPush(elem: AnyRef, ctx: Context[FrameStart]): SyncDirective = elem match {
case Tick
if (timeout.isPast) ctx.finish()
else ctx.pull()
case _ ctx.pull() // ignore
}
def onComplete(ctx: Context[FrameStart]): TerminationDirective = ctx.finish()
}
/** If upstream has already failed we just wait to be able to deliver our close frame and complete */
class SendOutCloseFrameAndComplete(closeFrame: FrameStart) extends CompletionHandlingState {
def onPush(elem: AnyRef, ctx: Context[FrameStart]): SyncDirective =
ctx.fail(new IllegalStateException("Didn't expect push after completion"))
override def onPull(ctx: Context[FrameStart]): SyncDirective =
ctx.pushAndFinish(closeFrame)
def onComplete(ctx: Context[FrameStart]): TerminationDirective =
ctx.absorbTermination()
}
trait CompletionHandlingState extends State {
def onComplete(ctx: Context[FrameStart]): TerminationDirective
}
override def onUpstreamFinish(ctx: Context[FrameStart]): TerminationDirective =
current.asInstanceOf[CompletionHandlingState].onComplete(ctx)
override def onUpstreamFailure(cause: scala.Throwable, ctx: Context[FrameStart]): TerminationDirective = cause match {
case p: ProtocolException
become(new SendOutCloseFrameAndComplete(FrameEvent.closeFrame(Protocol.CloseCodes.ProtocolError)))
ctx.absorbTermination()
case _ super.onUpstreamFailure(cause, ctx)
}
}

View file

@ -0,0 +1,125 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.engine.ws
import akka.http.model.headers._
import akka.http.model.ws.{ Message, UpgradeToWebsocket }
import akka.http.model.{ StatusCodes, HttpResponse, HttpProtocol, HttpHeader }
import akka.parboiled2.util.Base64
import akka.stream.FlowMaterializer
import akka.stream.scaladsl.Flow
import scala.collection.immutable.Seq
import scala.reflect.ClassTag
/**
* Server-side implementation of the Websocket handshake
*
* INTERNAL API
*/
private[http] object Handshake {
val CurrentWebsocketVersion = 13
/*
From: http://tools.ietf.org/html/rfc6455#section-4.2.1
1. An HTTP/1.1 or higher GET request, including a "Request-URI"
[RFC2616] that should be interpreted as a /resource name/
defined in Section 3 (or an absolute HTTP/HTTPS URI containing
the /resource name/).
2. A |Host| header field containing the server's authority.
3. An |Upgrade| header field containing the value "websocket",
treated as an ASCII case-insensitive value.
4. A |Connection| header field that includes the token "Upgrade",
treated as an ASCII case-insensitive value.
5. A |Sec-WebSocket-Key| header field with a base64-encoded (see
Section 4 of [RFC4648]) value that, when decoded, is 16 bytes in
length.
6. A |Sec-WebSocket-Version| header field, with a value of 13.
7. Optionally, an |Origin| header field. This header field is sent
by all browser clients. A connection attempt lacking this
header field SHOULD NOT be interpreted as coming from a browser
client.
8. Optionally, a |Sec-WebSocket-Protocol| header field, with a list
of values indicating which protocols the client would like to
speak, ordered by preference.
9. Optionally, a |Sec-WebSocket-Extensions| header field, with a
list of values indicating which extensions the client would like
to speak. The interpretation of this header field is discussed
in Section 9.1.
*/
def isWebsocketUpgrade(headers: List[HttpHeader], hostHeaderPresent: Boolean): Option[UpgradeToWebsocket] = {
def find[T <: HttpHeader: ClassTag]: Option[T] =
headers.collectFirst {
case t: T t
}
val host = find[Host]
val upgrade = find[Upgrade]
val connection = find[Connection]
val key = find[`Sec-WebSocket-Key`]
val version = find[`Sec-WebSocket-Version`]
val origin = find[Origin]
val protocol = find[`Sec-WebSocket-Protocol`]
val supportedProtocols = protocol.toList.flatMap(_.protocols)
val extensions = find[`Sec-WebSocket-Extensions`]
def isValidKey(key: String): Boolean = Base64.rfc2045().decode(key).length == 16
if (upgrade.exists(_.hasWebsocket) &&
connection.exists(_.hasUpgrade) &&
version.exists(_.hasVersion(CurrentWebsocketVersion)) &&
key.exists(k isValidKey(k.key))) {
val header = new UpgradeToWebsocketLowLevel {
def requestedProtocols: Seq[String] = supportedProtocols
def handleFrames(handlerFlow: Flow[FrameEvent, FrameEvent, Any], subprotocol: Option[String])(implicit mat: FlowMaterializer): HttpResponse = {
require(subprotocol.forall(chosen supportedProtocols.contains(chosen)),
s"Tried to choose invalid subprotocol '$subprotocol' which wasn't offered by the client: [${requestedProtocols.mkString(", ")}]")
buildResponse(key.get, handlerFlow, subprotocol)
}
}
Some(header)
} else None
}
/*
From: http://tools.ietf.org/html/rfc6455#section-4.2.2
1. A Status-Line with a 101 response code as per RFC 2616
[RFC2616]. Such a response could look like "HTTP/1.1 101
Switching Protocols".
2. An |Upgrade| header field with value "websocket" as per RFC
2616 [RFC2616].
3. A |Connection| header field with value "Upgrade".
4. A |Sec-WebSocket-Accept| header field. The value of this
header field is constructed by concatenating /key/, defined
above in step 4 in Section 4.2.2, with the string "258EAFA5-
E914-47DA-95CA-C5AB0DC85B11", taking the SHA-1 hash of this
concatenated value to obtain a 20-byte value and base64-
encoding (see Section 4 of [RFC4648]) this 20-byte hash.
*/
def buildResponse(key: `Sec-WebSocket-Key`, handlerFlow: Flow[FrameEvent, FrameEvent, Any], subprotocol: Option[String])(implicit mat: FlowMaterializer): HttpResponse =
HttpResponse(
StatusCodes.SwitchingProtocols,
subprotocol.map(p `Sec-WebSocket-Protocol`(Seq(p))).toList :::
List(
Upgrade(List(UpgradeProtocol("websocket"))),
Connection(List("upgrade")),
`Sec-WebSocket-Accept`.forKey(key),
UpgradeToWebsocketResponseHeader(handlerFlow)))
}

View file

@ -0,0 +1,71 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.engine.ws
import akka.stream.scaladsl.Flow
import akka.stream.stage.{ SyncDirective, Context, StageState, StatefulStage }
import scala.util.Random
/**
* Implements Websocket Frame masking.
*
* INTERNAL API
*/
private[http] object Masking {
def maskIf(condition: Boolean, maskRandom: () Random): Flow[FrameEvent, FrameEvent, Unit] =
if (condition) Flow[FrameEvent].transform(() new Masking(maskRandom())) // new random per materialization
else Flow[FrameEvent]
def unmaskIf(condition: Boolean): Flow[FrameEvent, FrameEvent, Unit] =
if (condition) Flow[FrameEvent].transform(() new Unmasking())
else Flow[FrameEvent]
class Masking(random: Random) extends Masker {
def extractMask(header: FrameHeader): Int = random.nextInt()
def setNewMask(header: FrameHeader, mask: Int): FrameHeader = {
if (header.mask.isDefined) throw new ProtocolException("Frame mustn't already be masked")
header.copy(mask = Some(mask))
}
}
class Unmasking extends Masker {
def extractMask(header: FrameHeader): Int = header.mask match {
case Some(mask) mask
case None throw new ProtocolException("Frame wasn't masked")
}
def setNewMask(header: FrameHeader, mask: Int): FrameHeader = header.copy(mask = None)
}
/** Implements both masking and unmasking which is mostly symmetric (because of XOR) */
abstract class Masker extends StatefulStage[FrameEvent, FrameEvent] {
def extractMask(header: FrameHeader): Int
def setNewMask(header: FrameHeader, mask: Int): FrameHeader
def initial: State = Idle
object Idle extends State {
def onPush(part: FrameEvent, ctx: Context[FrameEvent]): SyncDirective =
part match {
case start @ FrameStart(header, data)
if (header.length == 0) ctx.push(part)
else {
val mask = extractMask(header)
become(new Running(mask))
current.onPush(start.copy(header = setNewMask(header, mask)), ctx)
}
}
}
class Running(initialMask: Int) extends State {
var mask = initialMask
def onPush(part: FrameEvent, ctx: Context[FrameEvent]): SyncDirective = {
if (part.lastPart) become(Idle)
val (masked, newMask) = FrameEventParser.mask(part.data, mask)
mask = newMask
ctx.push(part.withData(data = masked))
}
}
}
}

View file

@ -0,0 +1,37 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.engine.ws
import akka.util.ByteString
import akka.stream.scaladsl.{ FlattenStrategy, Source, Flow }
import Protocol.Opcode
import akka.http.model.ws._
/**
* Renders messages to full frames.
*
* INTERNAL API
*/
private[http] object MessageToFrameRenderer {
def create(serverSide: Boolean): Flow[Message, FrameStart, Unit] = {
def strictFrames(opcode: Opcode, data: ByteString): Source[FrameStart, _] =
// FIXME: fragment?
Source.single(FrameEvent.fullFrame(opcode, None, data, fin = true))
def streamedFrames(opcode: Opcode, data: Source[ByteString, _]): Source[FrameStart, _] =
Source.single(FrameEvent.empty(opcode, fin = false)) ++
data.map(FrameEvent.fullFrame(Opcode.Continuation, None, _, fin = false)) ++
Source.single(FrameEvent.emptyLastContinuationFrame)
Flow[Message]
.map {
case BinaryMessage.Strict(data) strictFrames(Opcode.Binary, data)
case BinaryMessage.Streamed(data) streamedFrames(Opcode.Binary, data)
case TextMessage.Strict(text) strictFrames(Opcode.Text, ByteString(text, "UTF-8"))
case TextMessage.Streamed(text) streamedFrames(Opcode.Text, text.transform(() new Utf8Encoder))
}.flatten(FlattenStrategy.Concat())
}
}

View file

@ -0,0 +1,84 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.engine.ws
/**
* Contains Websocket protocol constants
*
* INTERNAL API
*/
private[http] object Protocol {
val FIN_MASK = 0x80
val RSV1_MASK = 0x40
val RSV2_MASK = 0x20
val RSV3_MASK = 0x10
val FLAGS_MASK = 0xF0
val OP_MASK = 0x0F
val MASK_MASK = 0x80
val LENGTH_MASK = 0x7F
sealed trait Opcode {
def code: Byte
def isControl: Boolean
}
object Opcode {
def forCode(code: Byte): Opcode = code match {
case 0x0 Continuation
case 0x1 Text
case 0x2 Binary
case 0x8 Close
case 0x9 Ping
case 0xA Pong
case b if (b & 0xf0) == 0 Other(code)
case _ throw new IllegalArgumentException(f"Opcode must be 4bit long but was 0x$code%02X")
}
sealed abstract class AbstractOpcode private[Opcode] (val code: Byte) extends Opcode {
def isControl: Boolean = (code & 0x8) != 0
}
case object Continuation extends AbstractOpcode(0x0)
case object Text extends AbstractOpcode(0x1)
case object Binary extends AbstractOpcode(0x2)
case object Close extends AbstractOpcode(0x8)
case object Ping extends AbstractOpcode(0x9)
case object Pong extends AbstractOpcode(0xA)
case class Other(override val code: Byte) extends AbstractOpcode(code)
}
/**
* Close status codes as defined at http://tools.ietf.org/html/rfc6455#section-7.4.1
*/
object CloseCodes {
def isError(code: Int): Boolean = !(code == Regular || code == GoingAway)
def isValid(code: Int): Boolean =
((code >= 1000) && (code <= 1003)) ||
(code >= 1007) && (code <= 1011) ||
(code >= 3000) && (code <= 4999)
val Regular = 1000
val GoingAway = 1001
val ProtocolError = 1002
val Unacceptable = 1003
// Reserved = 1004
// NoCodePresent = 1005
val ConnectionAbort = 1006
val InconsistentData = 1007
val PolicyViolated = 1008
val TooBig = 1009
val ClientRejectsExtension = 1010
val UnexpectedCondition = 1011
val TLSHandshakeFailure = 1015
}
}
/** INTERNAL API */
private[http] case class ProtocolException(cause: String) extends RuntimeException(cause)

View file

@ -0,0 +1,32 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.engine.ws
import akka.http.model.HttpResponse
import akka.http.model.ws.{ Message, UpgradeToWebsocket }
import akka.stream.FlowMaterializer
import akka.stream.scaladsl.Flow
/**
* Currently internal API to handle FrameEvents directly.
*
* INTERNAL API
*/
private[http] abstract class UpgradeToWebsocketLowLevel extends InternalCustomHeader("UpgradeToWebsocket") with UpgradeToWebsocket {
/**
* The low-level interface to create Websocket server based on "frames".
* The user needs to handle control frames manually in this case.
*
* Returns a response to return in a request handler that will signal the
* low-level HTTP implementation to upgrade the connection to Websocket and
* use the supplied handler to handle incoming Websocket frames.
*
* INTERNAL API (for now)
*/
private[http] def handleFrames(handlerFlow: Flow[FrameEvent, FrameEvent, Any], subprotocol: Option[String] = None)(implicit mat: FlowMaterializer): HttpResponse
override def handleMessages(handlerFlow: Flow[Message, Message, Any], subprotocol: Option[String] = None)(implicit mat: FlowMaterializer): HttpResponse =
handleFrames(Websocket.handleMessages(handlerFlow), subprotocol)
}

View file

@ -0,0 +1,19 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.engine.ws
import akka.http.model.headers.CustomHeader
import akka.stream.FlowMaterializer
import akka.stream.scaladsl.Flow
private[http] case class UpgradeToWebsocketResponseHeader(handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit val mat: FlowMaterializer)
extends InternalCustomHeader("UpgradeToWebsocketResponseHeader") {
}
private[http] abstract class InternalCustomHeader(val name: String) extends CustomHeader {
override def suppressRendering: Boolean = true
def value(): String = ""
}

View file

@ -0,0 +1,121 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.engine.ws
import akka.util.ByteString
import scala.annotation.tailrec
import scala.util.Try
/**
* A Utf8 -> Utf16 (= Java char) decoder.
*
* This decoder is based on the one of Bjoern Hoehrmann from
*
* http://bjoern.hoehrmann.de/utf-8/decoder/dfa/
*
* which is licensed under this license:
*
* Copyright (c) 2008-2010 Bjoern Hoehrmann <bjoern@hoehrmann.de>
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
* CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* FIXME: reviewers: it may be necessary to distribute this notice in the license file, is it?
*
* INTERNAL API
*/
private[http] object Utf8Decoder extends StreamingCharsetDecoder {
private[this] val Utf8Accept = 0
private[this] val Utf8Reject = 12
val characterClasses =
Array[Byte](
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
8, 8, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
10, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 3, 11, 6, 6, 6, 5, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8)
val states =
Array[Byte](
0, 12, 24, 36, 60, 96, 84, 12, 12, 12, 48, 72, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
12, 0, 12, 12, 12, 12, 12, 0, 12, 0, 12, 12, 12, 24, 12, 12, 12, 12, 12, 24, 12, 24, 12, 12,
12, 12, 12, 12, 12, 12, 12, 24, 12, 12, 12, 12, 12, 24, 12, 12, 12, 12, 12, 12, 12, 24, 12, 12,
12, 12, 12, 12, 12, 12, 12, 36, 12, 36, 12, 12, 12, 36, 12, 12, 12, 12, 12, 36, 12, 36, 12, 12,
12, 36, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12)
def create(): StreamingCharsetDecoderInstance =
new StreamingCharsetDecoderInstance {
var currentCodePoint = 0
var currentState = Utf8Accept
def decode(bytes: ByteString, endOfInput: Boolean): Try[String] = Try {
val result = new StringBuilder(bytes.size)
val length = bytes.size
def step(byte: Int): Unit = {
val chClass = characterClasses(byte)
currentCodePoint =
if (currentState == Utf8Accept) // first byte
(0xff >> chClass) & byte // take as much bits as the characterClass says
else // continuation byte
(0x3f & byte) | (currentCodePoint << 6) // take 6 bits
currentState = states(currentState + chClass)
currentState match {
case Utf8Accept
if (currentCodePoint <= 0xffff)
// fits in single UTF-16 char
result.append(currentCodePoint.toChar)
else {
// create surrogate pair
result.append((0xD7C0 + (currentCodePoint >> 10)).toChar)
result.append((0xDC00 + (currentCodePoint & 0x3FF)).toChar)
}
case Utf8Reject fail("Invalid UTF-8 input")
case _ // valid intermediate state, need more input
}
}
var offset = 0
while (offset < length) {
step(bytes(offset) & 0xff)
offset += 1
}
if (endOfInput && currentState != Utf8Accept) fail("Truncated UTF-8 input")
else
result.toString()
}
def fail(msg: String): Nothing = throw new IllegalArgumentException(msg)
}
}
private[http] trait StreamingCharsetDecoder {
def create(): StreamingCharsetDecoderInstance
def decode(bytes: ByteString): Try[String] = create().decode(bytes, endOfInput = true)
}
private[http] trait StreamingCharsetDecoderInstance {
def decode(bytes: ByteString, endOfInput: Boolean): Try[String]
}

View file

@ -0,0 +1,83 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.engine.ws
import akka.stream.stage._
import akka.util.{ ByteStringBuilder, ByteString }
import scala.annotation.tailrec
/**
* A utf16 (= Java char) to utf8 encoder.
*
* INTERNAL API
*/
private[http] class Utf8Encoder extends PushStage[String, ByteString] {
import Utf8Encoder._
var surrogateValue: Int = 0
def inSurrogatePair: Boolean = surrogateValue != 0
def onPush(input: String, ctx: Context[ByteString]): SyncDirective = {
val builder = new ByteStringBuilder
def b(v: Int): Unit = {
assert((v & 0xff) == v)
builder += v.toByte
}
def step(char: Int): Unit =
if (!inSurrogatePair)
if (char <= Utf8OneByteLimit) builder += char.toByte
else if (char <= Utf8TwoByteLimit) {
b(0xc0 | ((char & 0x7c0) >> 6)) // upper 5 bits
b(0x80 | (char & 0x3f)) // lower 6 bits
} else if (char >= SurrogateFirst && char < SurrogateSecond)
surrogateValue = 0x10000 | ((char ^ SurrogateFirst) << 10)
else if (char >= SurrogateSecond && char < 0xdfff)
throw new IllegalArgumentException(f"Unexpected UTF-16 surrogate continuation")
else if (char <= Utf8ThreeByteLimit) {
b(0xe0 | ((char & 0xf000) >> 12)) // upper 4 bits
b(0x80 | ((char & 0x0fc0) >> 6)) // middle 6 bits
b(0x80 | (char & 0x3f)) // lower 6 bits
} else
throw new IllegalStateException("Char cannot be >= 2^16") // char value was converted from 16bit value
else if (char >= SurrogateSecond && char <= 0xdfff) {
surrogateValue |= (char & 0x3ff)
b(0xf0 | ((surrogateValue & 0x1c0000) >> 18)) // upper 3 bits
b(0x80 | ((surrogateValue & 0x3f000) >> 12)) // first middle 6 bits
b(0x80 | ((surrogateValue & 0x0fc0) >> 6)) // second middle 6 bits
b(0x80 | (surrogateValue & 0x3f)) // lower 6 bits
surrogateValue = 0
} else throw new IllegalArgumentException(f"Expected UTF-16 surrogate continuation")
var offset = 0
while (offset < input.length) {
step(input(offset))
offset += 1
}
if (builder.length > 0) ctx.push(builder.result())
else ctx.pull()
}
override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective =
if (inSurrogatePair) ctx.fail(new IllegalArgumentException("Truncated String input (ends in the middle of surrogate pair)"))
else super.onUpstreamFinish(ctx)
}
/**
* INTERNAL API
*/
private[http] object Utf8Encoder {
val SurrogateFirst = 0xd800
val SurrogateSecond = 0xdc00
val Utf8OneByteLimit = lowerNBitsSet(7)
val Utf8TwoByteLimit = lowerNBitsSet(11)
val Utf8ThreeByteLimit = lowerNBitsSet(16)
def lowerNBitsSet(n: Int): Long = (1L << n) - 1
}

View file

@ -0,0 +1,179 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.engine.ws
import java.security.SecureRandom
import scala.concurrent.duration._
import akka.stream.{ OperationAttributes, FanOutShape2, FanInShape3, Inlet }
import akka.stream.scaladsl._
import akka.stream.stage._
import FlexiRoute.{ DemandFrom, DemandFromAny, RouteLogic }
import FlexiMerge.MergeLogic
import akka.http.util._
import akka.http.model.ws._
/**
* INTERNAL API
*/
private[http] object Websocket {
import FrameHandler._
def handleMessages[T](messageHandler: Flow[Message, Message, T],
serverSide: Boolean = true,
closeTimeout: FiniteDuration = 3.seconds): Flow[FrameEvent, FrameEvent, Unit] = {
/** Completes this branch of the flow if no more messages are expected and converts close codes into errors */
class PrepareForUserHandler extends PushStage[MessagePart, MessagePart] {
def onPush(elem: MessagePart, ctx: Context[MessagePart]): SyncDirective = elem match {
case PeerClosed(code, reason)
if (code.exists(Protocol.CloseCodes.isError)) ctx.fail(new ProtocolException(s"Peer closed connection with code $code"))
else ctx.finish()
case ActivelyCloseWithCode(code, reason)
if (code.exists(Protocol.CloseCodes.isError)) ctx.fail(new ProtocolException(s"Closing connection with error code $code"))
else ctx.finish()
case x ctx.push(x)
}
}
/** Collects user-level API messages from MessageDataParts */
val collectMessage: Flow[Source[MessageDataPart, Unit], Message, Unit] =
Flow[Source[MessageDataPart, Unit]]
.headAndTail
.map {
case (TextMessagePart(text, true), remaining)
TextMessage.Strict(text)
case (first @ TextMessagePart(text, false), remaining)
TextMessage.Streamed(
(Source.single(first) ++ remaining)
.collect {
case t: TextMessagePart if t.data.nonEmpty t.data
})
case (BinaryMessagePart(data, true), remaining)
BinaryMessage.Strict(data)
case (first @ BinaryMessagePart(data, false), remaining)
BinaryMessage.Streamed(
(Source.single(first) ++ remaining)
.collect {
case t: BinaryMessagePart if t.data.nonEmpty t.data
})
}
/** Lifts onComplete and onError into events to be processed in the FlexiMerge */
class LiftCompletions extends StatefulStage[FrameStart, AnyRef] {
def initial: StageState[FrameStart, AnyRef] = SteadyState
object SteadyState extends State {
def onPush(elem: FrameStart, ctx: Context[AnyRef]): SyncDirective = ctx.push(elem)
}
class CompleteWith(last: AnyRef) extends State {
def onPush(elem: FrameStart, ctx: Context[AnyRef]): SyncDirective =
ctx.fail(new IllegalStateException("No push expected"))
override def onPull(ctx: Context[AnyRef]): SyncDirective = ctx.pushAndFinish(last)
}
override def onUpstreamFinish(ctx: Context[AnyRef]): TerminationDirective = {
become(new CompleteWith(UserHandlerCompleted))
ctx.absorbTermination()
}
override def onUpstreamFailure(cause: Throwable, ctx: Context[AnyRef]): TerminationDirective = {
become(new CompleteWith(UserHandlerErredOut(cause)))
ctx.absorbTermination()
}
}
lazy val userFlow =
Flow[MessagePart]
.transform(() new PrepareForUserHandler)
.splitWhen(_.isMessageEnd) // FIXME using splitAfter from #16885 would simplify protocol a lot
.map(_.collect {
case m: MessageDataPart m
})
.via(collectMessage)
.via(messageHandler)
.via(MessageToFrameRenderer.create(serverSide))
.transform(() new LiftCompletions)
/**
* Distributes output from the FrameHandler into bypass and userFlow.
*/
object BypassRouter
extends FlexiRoute[Either[BypassEvent, MessagePart], FanOutShape2[Either[BypassEvent, MessagePart], BypassEvent, MessagePart]](new FanOutShape2("bypassRouter"), OperationAttributes.name("bypassRouter")) {
def createRouteLogic(s: FanOutShape2[Either[BypassEvent, MessagePart], BypassEvent, MessagePart]): RouteLogic[Either[BypassEvent, MessagePart]] =
new RouteLogic[Either[BypassEvent, MessagePart]] {
def initialState: State[_] = State(DemandFromAny(s)) { (ctx, out, ev)
ev match {
case Left(_)
State(DemandFrom(s.out0)) { (ctx, _, ev) // FIXME: #17004
ctx.emit(s.out0)(ev.left.get)
initialState
}
case Right(_)
State(DemandFrom(s.out1)) { (ctx, _, ev)
ctx.emit(s.out1)(ev.right.get)
initialState
}
}
}
override def initialCompletionHandling: CompletionHandling = super.initialCompletionHandling.copy(
onDownstreamFinish = { (ctx, out)
if (out == s.out0) ctx.finish()
SameState
})
}
}
/**
* Merges bypass, user flow and tick source for consumption in the FrameOutHandler.
*/
object BypassMerge extends FlexiMerge[AnyRef, FanInShape3[BypassEvent, AnyRef, Tick.type, AnyRef]](new FanInShape3("bypassMerge"), OperationAttributes.name("bypassMerge")) {
def createMergeLogic(s: FanInShape3[BypassEvent, AnyRef, Tick.type, AnyRef]): MergeLogic[AnyRef] =
new MergeLogic[AnyRef] {
def initialState: State[_] = Idle
lazy val Idle = State[AnyRef](FlexiMerge.ReadAny(s.in0.asInstanceOf[Inlet[AnyRef]], s.in1.asInstanceOf[Inlet[AnyRef]], s.in2.asInstanceOf[Inlet[AnyRef]])) { (ctx, in, elem)
ctx.emit(elem)
SameState
}
override def initialCompletionHandling: CompletionHandling =
CompletionHandling(
onUpstreamFinish = { (ctx, in)
if (in == s.in0) ctx.finish()
SameState
},
onUpstreamFailure = { (ctx, in, cause)
if (in == s.in0) ctx.fail(cause)
SameState
})
}
}
lazy val bypassAndUserHandler: Flow[Either[BypassEvent, MessagePart], AnyRef, Unit] =
Flow(BypassRouter, Source(closeTimeout, closeTimeout, Tick), BypassMerge)((_, _, _) ()) { implicit b
(split, tick, merge)
import FlowGraph.Implicits._
split.out0 ~> merge.in0
split.out1 ~> userFlow ~> merge.in1
tick.outlet ~> merge.in2
(split.in, merge.out)
}
Flow[FrameEvent]
.via(Masking.unmaskIf(serverSide))
.via(FrameHandler.create(server = serverSide))
.mapConcat(x x :: x :: Nil) // FIXME: #17004
.via(bypassAndUserHandler)
.transform(() new FrameOutHandler(serverSide, closeTimeout))
.via(Masking.maskIf(!serverSide, () new SecureRandom()))
}
object Tick
case object SwitchToWebsocketToken
}

View file

@ -0,0 +1,13 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.engine.ws
import akka.stream.FlowMaterializer
import akka.stream.scaladsl.Flow
/** Internal interface between the handshake and the stream setup to evoke the switch to the websocket protocol */
private[http] trait WebsocketSwitch {
def switchToWebsocket(handlerFlow: Flow[FrameEvent, FrameEvent, Any])(implicit mat: FlowMaterializer): Unit
}

View file

@ -0,0 +1,15 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.model.headers
import akka.http.util.{ Rendering, ValueRenderable }
final case class UpgradeProtocol(name: String, version: Option[String] = None) extends ValueRenderable {
def render[R <: Rendering](r: R): r.type = {
r ~~ name
version.foreach(v r ~~ '/' ~~ v)
r
}
}

View file

@ -0,0 +1,24 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.model.headers
import akka.http.util.{ Rendering, ValueRenderable }
import scala.collection.immutable
/**
* A websocket extension as defined in http://tools.ietf.org/html/rfc6455#section-4.3
*/
final case class WebsocketExtension(name: String, params: immutable.Map[String, String] = Map.empty) extends ValueRenderable {
def render[R <: Rendering](r: R): r.type = {
r ~~ name
if (params.nonEmpty)
params.foreach {
case (k, "") r ~~ "; " ~~ k
case (k, v) r ~~ "; " ~~ k ~~ '=' ~~# v
}
r
}
}

View file

@ -7,7 +7,10 @@ package headers
import java.lang.Iterable
import java.net.InetSocketAddress
import java.security.MessageDigest
import java.util
import akka.parboiled2.util.Base64
import scala.annotation.tailrec
import scala.collection.immutable
import akka.http.util._
@ -52,6 +55,7 @@ final case class Connection(tokens: immutable.Seq[String]) extends ModeledHeader
def renderValue[R <: Rendering](r: R): r.type = r ~~ tokens
def hasClose = has("close")
def hasKeepAlive = has("keep-alive")
def hasUpgrade = has("upgrade")
def append(tokens: immutable.Seq[String]) = Connection(this.tokens ++ tokens)
@tailrec private def has(item: String, ix: Int = 0): Boolean =
if (ix < tokens.length)
@ -563,6 +567,103 @@ final case class Referer(uri: Uri) extends japi.headers.Referer with ModeledHead
def getUri: akka.http.model.japi.Uri = uri.asJava
}
/**
* INTERNAL API
*/
// http://tools.ietf.org/html/rfc6455#section-4.3
private[http] object `Sec-WebSocket-Accept` extends ModeledCompanion {
// Defined at http://tools.ietf.org/html/rfc6455#section-4.2.2
val MagicGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
/** Generates the matching accept header for this key */
def forKey(key: `Sec-WebSocket-Key`): `Sec-WebSocket-Accept` = {
val sha1 = MessageDigest.getInstance("sha1")
val salted = key.key + MagicGuid
val hash = sha1.digest(salted.asciiBytes)
val acceptKey = Base64.rfc2045().encodeToString(hash, false)
`Sec-WebSocket-Accept`(acceptKey)
}
}
/**
* INTERNAL API
*/
private[http] final case class `Sec-WebSocket-Accept`(key: String) extends ModeledHeader {
protected[http] def renderValue[R <: Rendering](r: R): r.type = r ~~ key
protected def companion = `Sec-WebSocket-Accept`
}
/**
* INTERNAL API
*/
// http://tools.ietf.org/html/rfc6455#section-4.3
private[http] object `Sec-WebSocket-Extensions` extends ModeledCompanion {
implicit val extensionsRenderer = Renderer.defaultSeqRenderer[WebsocketExtension]
}
/**
* INTERNAL API
*/
private[http] final case class `Sec-WebSocket-Extensions`(extensions: immutable.Seq[WebsocketExtension]) extends ModeledHeader {
require(extensions.nonEmpty, "Sec-WebSocket-Extensions.extensions must not be empty")
import `Sec-WebSocket-Extensions`.extensionsRenderer
protected[http] def renderValue[R <: Rendering](r: R): r.type = r ~~ extensions
protected def companion = `Sec-WebSocket-Extensions`
}
// http://tools.ietf.org/html/rfc6455#section-4.3
/**
* INTERNAL API
*/
private[http] object `Sec-WebSocket-Key` extends ModeledCompanion
/**
* INTERNAL API
*/
private[http] final case class `Sec-WebSocket-Key`(key: String) extends ModeledHeader {
protected[http] def renderValue[R <: Rendering](r: R): r.type = r ~~ key
protected def companion = `Sec-WebSocket-Key`
}
// http://tools.ietf.org/html/rfc6455#section-4.3
/**
* INTERNAL API
*/
private[http] object `Sec-WebSocket-Protocol` extends ModeledCompanion {
implicit val protocolsRenderer = Renderer.defaultSeqRenderer[String]
}
/**
* INTERNAL API
*/
private[http] final case class `Sec-WebSocket-Protocol`(protocols: immutable.Seq[String]) extends ModeledHeader {
require(protocols.nonEmpty, "Sec-WebSocket-Protocol.protocols must not be empty")
import `Sec-WebSocket-Protocol`.protocolsRenderer
protected[http] def renderValue[R <: Rendering](r: R): r.type = r ~~ protocols
protected def companion = `Sec-WebSocket-Protocol`
}
// http://tools.ietf.org/html/rfc6455#section-4.3
/**
* INTERNAL API
*/
private[http] object `Sec-WebSocket-Version` extends ModeledCompanion {
implicit val versionsRenderer = Renderer.defaultSeqRenderer[Int]
}
/**
* INTERNAL API
*/
private[http] final case class `Sec-WebSocket-Version`(versions: immutable.Seq[Int]) extends ModeledHeader {
require(versions.nonEmpty, "Sec-WebSocket-Version.versions must not be empty")
require(versions.forall(v v >= 0 && v <= 255), s"Sec-WebSocket-Version.versions must be in the range 0 <= version <= 255 but were $versions")
import `Sec-WebSocket-Version`.versionsRenderer
protected[http] def renderValue[R <: Rendering](r: R): r.type = r ~~ versions
def hasVersion(versionNumber: Int): Boolean = versions.exists(_ == versionNumber)
protected def companion = `Sec-WebSocket-Version`
}
// http://tools.ietf.org/html/rfc7231#section-7.4.2
object Server extends ModeledCompanion {
def apply(products: String): Server = apply(ProductVersion.parseMultiple(products))
@ -611,6 +712,19 @@ final case class `Transfer-Encoding`(encodings: immutable.Seq[TransferEncoding])
def getEncodings: Iterable[japi.TransferEncoding] = encodings.asJava
}
// http://tools.ietf.org/html/rfc7230#section-6.7
object Upgrade extends ModeledCompanion {
implicit val protocolsRenderer = Renderer.defaultSeqRenderer[UpgradeProtocol]
}
final case class Upgrade(protocols: immutable.Seq[UpgradeProtocol]) extends ModeledHeader {
import Upgrade.protocolsRenderer
protected[http] def renderValue[R <: Rendering](r: R): r.type = r ~~ protocols
protected def companion: ModeledCompanion = Upgrade
def hasWebsocket: Boolean = protocols.exists(_.name equalsIgnoreCase "websocket")
}
// http://tools.ietf.org/html/rfc7231#section-5.5.3
object `User-Agent` extends ModeledCompanion {
def apply(products: String): `User-Agent` = apply(ProductVersion.parseMultiple(products))

View file

@ -26,7 +26,8 @@ private[http] class HeaderParser(val input: ParserInput) extends Parser with Dyn
with IpAddressParsing
with LinkHeader
with SimpleHeaders
with StringBuilding {
with StringBuilding
with WebsocketHeaders {
import CharacterClasses._
// http://www.rfc-editor.org/errata_search.php?rfc=7230 errata id 4189
@ -111,8 +112,14 @@ private[http] object HeaderParser {
"range",
"referer",
"server",
"sec-websocket-accept",
"sec-websocket-extensions",
"sec-websocket-key",
"sec-websocket-protocol",
"sec-websocket-version",
"set-cookie",
"transfer-encoding",
"upgrade",
"user-agent",
"www-authenticate",
"x-forwarded-for")

View file

@ -187,6 +187,15 @@ private[parser] trait SimpleHeaders { this: Parser with CommonRules with CommonA
`cookie-pair` ~ zeroOrMore(ws(';') ~ `cookie-av`) ~ EOI ~> (`Set-Cookie`(_))
}
// http://tools.ietf.org/html/rfc7230#section-6.7
def upgrade = rule {
oneOrMore(protocol).separatedBy(listSep) ~> (Upgrade(_))
}
def protocol = rule {
token ~ optional(ws("/") ~ token) ~> (UpgradeProtocol(_, _))
}
// http://tools.ietf.org/html/rfc7231#section-5.5.3
def `user-agent` = rule { products ~> (`User-Agent`(_)) }

View file

@ -0,0 +1,61 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.model.parser
import akka.http.model.headers._
import akka.parboiled2._
// see grammar at http://tools.ietf.org/html/rfc6455#section-4.3
private[parser] trait WebsocketHeaders { this: Parser with CommonRules with CommonActions
import CharacterClasses._
import Base64Parsing.rfc2045Alphabet
def `sec-websocket-accept` = rule {
`base64-value-non-empty` ~ EOI ~> (`Sec-WebSocket-Accept`(_))
}
def `sec-websocket-extensions` = rule {
oneOrMore(extension).separatedBy(listSep) ~ EOI ~> (`Sec-WebSocket-Extensions`(_))
}
def `sec-websocket-key` = rule {
`base64-value-non-empty` ~ EOI ~> (`Sec-WebSocket-Key`(_))
}
def `sec-websocket-protocol` = rule {
oneOrMore(token).separatedBy(listSep) ~ EOI ~> (`Sec-WebSocket-Protocol`(_))
}
def `sec-websocket-version` = rule {
oneOrMore(version).separatedBy(listSep) ~ EOI ~> (`Sec-WebSocket-Version`(_))
}
private def `base64-value-non-empty` = rule {
capture(oneOrMore(`base64-data`) ~ optional(`base64-padding`) | `base64-padding`)
}
private def `base64-data` = rule { 4.times(`base64-character`) }
private def `base64-padding` = rule {
2.times(`base64-character`) ~ "==" |
3.times(`base64-character`) ~ "="
}
private def `base64-character` = rfc2045Alphabet
private def extension = rule {
`extension-token` ~ zeroOrMore(ws(";") ~ `extension-param`) ~>
((name, params) WebsocketExtension(name, Map(params: _*)))
}
private def `extension-token`: Rule1[String] = token
private def `extension-param`: Rule1[(String, String)] =
rule {
token ~ optional(ws("=") ~ word) ~> ((name: String, value: Option[String]) (name, value.getOrElse("")))
}
private def version = rule {
capture(
NZDIGIT ~ optional(DIGIT ~ optional(DIGIT)) |
DIGIT) ~> (_.toInt)
}
private def NZDIGIT = DIGIT19
}

View file

@ -0,0 +1,29 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.model.ws
import akka.stream.scaladsl.Source
import akka.util.ByteString
/**
* The ADT for Websocket messages. A message can either be binary or a text message. Each of
* those can either be strict or streamed.
*/
sealed trait Message
sealed trait TextMessage extends Message
object TextMessage {
final case class Strict(text: String) extends TextMessage {
override def toString: String = s"TextMessage.Strict($text)"
}
final case class Streamed(textStream: Source[String, _]) extends TextMessage
}
sealed trait BinaryMessage extends Message
object BinaryMessage {
final case class Strict(data: ByteString) extends BinaryMessage {
override def toString: String = s"BinaryMessage.Strict($data)"
}
final case class Streamed(dataStream: Source[ByteString, _]) extends BinaryMessage
}

View file

@ -0,0 +1,35 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.model.ws
import scala.collection.immutable
import akka.stream.FlowMaterializer
import akka.stream.scaladsl.Flow
import akka.http.model.{ HttpHeader, HttpResponse }
/**
* A custom header that will be added to an Websocket upgrade HttpRequest that
* enables a request handler to upgrade this connection to a Websocket connection and
* registers a Websocket handler.
*/
trait UpgradeToWebsocket extends HttpHeader {
/**
* A sequence of protocols the client accepts.
*
* See http://tools.ietf.org/html/rfc6455#section-1.9
*/
def requestedProtocols: immutable.Seq[String]
/**
* The high-level interface to create a Websocket server based on "messages".
*
* Returns a response to return in a request handler that will signal the
* low-level HTTP implementation to upgrade the connection to Websocket and
* use the supplied handler to handle incoming Websocket messages. Optionally,
* a subprotocol out of the ones requested by the client can be chosen.
*/
def handleMessages(handlerFlow: Flow[Message, Message, Any], subprotocol: Option[String] = None)(implicit mat: FlowMaterializer): HttpResponse
}

View file

@ -18,6 +18,8 @@ private[akka] class ByteReader(input: ByteString) {
private[this] var off = 0
def hasRemaining: Boolean = off < input.size
def currentOffset: Int = off
def remainingData: ByteString = input.drop(off)
def fromStartToHere: ByteString = input.take(currentOffset)
@ -28,8 +30,14 @@ private[akka] class ByteReader(input: ByteString) {
off += 1
x.toInt & 0xFF
} else throw NeedMoreData
def readShort(): Int = readByte() | (readByte() << 8)
def readInt(): Int = readShort() | (readShort() << 16)
def readShortLE(): Int = readByte() | (readByte() << 8)
def readIntLE(): Int = readShortLE() | (readShortLE() << 16)
def readLongLE(): Long = (readIntBE() & 0xffffffffL) | ((readIntLE() & 0xffffffffL) << 32)
def readShortBE(): Int = (readByte() << 8) | readByte()
def readIntBE(): Int = (readShortBE() << 16) | readShortBE()
def readLongBE(): Long = ((readIntBE() & 0xffffffffL) << 32) | (readIntBE() & 0xffffffffL)
def skip(numBytes: Int): Unit =
if (off + numBytes <= input.length) off += numBytes
else throw NeedMoreData

View file

@ -84,6 +84,9 @@ private[http] object Renderer {
implicit object CharRenderer extends Renderer[Char] {
def render[R <: Rendering](r: R, value: Char): r.type = r ~~ value
}
implicit object IntRenderer extends Renderer[Int] {
def render[R <: Rendering](r: R, value: Int): r.type = r ~~ value
}
implicit object StringRenderer extends Renderer[String] {
def render[R <: Rendering](r: R, value: String): r.type = r ~~ value
}

View file

@ -5,13 +5,17 @@
package akka.http.util
import java.io.InputStream
import java.util.concurrent.atomic.AtomicBoolean
import org.reactivestreams.Publisher
import java.util.concurrent.atomic.{ AtomicReference, AtomicBoolean }
import akka.stream.impl.StreamLayout.Module
import akka.stream.impl.{ SourceModule, SinkModule, ActorFlowMaterializerImpl, PublisherSink }
import akka.stream.scaladsl.FlexiMerge._
import org.reactivestreams.{ Subscription, Processor, Subscriber, Publisher }
import scala.annotation.unchecked.uncheckedVariance
import scala.collection.immutable
import scala.concurrent.{ ExecutionContext, Future }
import akka.util.ByteString
import akka.http.model.RequestEntity
import akka.stream.{ FlowMaterializer, impl, OperationAttributes, ActorOperationAttributes }
import akka.stream._
import akka.stream.scaladsl._
import akka.stream.stage._
@ -193,6 +197,79 @@ private[http] object StreamUtils {
elem
}
}
def oneTimePublisherSink[In](cell: OneTimeWriteCell[Publisher[In]], name: String): Sink[In, Publisher[In]] =
new Sink[In, Publisher[In]](new OneTimePublisherSink(none, SinkShape(new Inlet(name)), cell))
def oneTimeSubscriberSource[Out](cell: OneTimeWriteCell[Subscriber[Out]], name: String): Source[Out, Subscriber[Out]] =
new Source[Out, Subscriber[Out]](new OneTimeSubscriberSource(none, SourceShape(new Outlet(name)), cell))
/** A copy of PublisherSink that allows access to the publisher through the cell but can only materialized once */
private class OneTimePublisherSink[In](attributes: OperationAttributes, shape: SinkShape[In], cell: OneTimeWriteCell[Publisher[In]])
extends PublisherSink[In](attributes, shape) {
override def create(context: MaterializationContext): (Subscriber[In], Publisher[In]) = {
val results = super.create(context)
cell.set(results._2)
results
}
override protected def newInstance(shape: SinkShape[In]): SinkModule[In, Publisher[In]] =
new OneTimePublisherSink[In](attributes, shape, cell)
override def withAttributes(attr: OperationAttributes): Module =
new OneTimePublisherSink[In](attr, amendShape(attr), cell)
}
/** A copy of SubscriberSource that allows access to the subscriber through the cell but can only materialized once */
private class OneTimeSubscriberSource[Out](val attributes: OperationAttributes, shape: SourceShape[Out], cell: OneTimeWriteCell[Subscriber[Out]])
extends SourceModule[Out, Subscriber[Out]](shape) {
override def create(context: MaterializationContext): (Publisher[Out], Subscriber[Out]) = {
val processor = new Processor[Out, Out] {
@volatile private var subscriber: Subscriber[_ >: Out] = null
override def subscribe(s: Subscriber[_ >: Out]): Unit = subscriber = s
override def onError(t: Throwable): Unit = subscriber.onError(t)
override def onSubscribe(s: Subscription): Unit = subscriber.onSubscribe(s)
override def onComplete(): Unit = subscriber.onComplete()
override def onNext(t: Out): Unit = subscriber.onNext(t)
}
cell.setValue(processor)
(processor, processor)
}
override protected def newInstance(shape: SourceShape[Out]): SourceModule[Out, Subscriber[Out]] =
new OneTimeSubscriberSource[Out](attributes, shape, cell)
override def withAttributes(attr: OperationAttributes): Module =
new OneTimeSubscriberSource[Out](attr, amendShape(attr), cell)
}
trait ReadableCell[+T] {
def value: T
}
/** A one time settable cell */
class OneTimeWriteCell[T <: AnyRef] extends AtomicReference[T] with ReadableCell[T] {
def value: T = {
val value = get()
require(value != null, "Value wasn't set yet")
value
}
def setValue(value: T): Unit =
if (!compareAndSet(null.asInstanceOf[T], value))
throw new IllegalStateException("Value can be only set once.")
}
/** A merge for two streams that just forwards all elements and closes the connection eagerly. */
class EagerCloseMerge2[T](name: String) extends FlexiMerge[T, FanInShape2[T, T, T]](new FanInShape2(name), OperationAttributes.name(name)) {
def createMergeLogic(s: FanInShape2[T, T, T]): MergeLogic[T] =
new MergeLogic[T] {
def initialState: State[T] = State[T](ReadAny(s.in0, s.in1)) {
case (ctx, port, in) ctx.emit(in); SameState
}
override def initialCompletionHandling: CompletionHandling = eagerClose
}
}
}
/**

View file

@ -43,13 +43,21 @@ package object util {
private[http] implicit class SourceWithHeadAndTail[T, Mat](val underlying: Source[Source[T, Any], Mat]) extends AnyVal {
def headAndTail: Source[(T, Source[T, Unit]), Mat] =
underlying.map { _.prefixAndTail(1).map { case (prefix, tail) (prefix.head, tail) } }
underlying.map {
_.prefixAndTail(1)
.filter(_._1.nonEmpty)
.map { case (prefix, tail) (prefix.head, tail) }
}
.flatten(FlattenStrategy.concat)
}
private[http] implicit class FlowWithHeadAndTail[In, Out, Mat](val underlying: Flow[In, Source[Out, Any], Mat]) extends AnyVal {
def headAndTail: Flow[In, (Out, Source[Out, Unit]), Mat] =
underlying.map { _.prefixAndTail(1).map { case (prefix, tail) (prefix.head, tail) } }
underlying.map {
_.prefixAndTail(1)
.filter(_._1.nonEmpty)
.map { case (prefix, tail) (prefix.head, tail) }
}
.flatten(FlattenStrategy.concat)
}

View file

@ -4,12 +4,17 @@
package akka.http
import scala.concurrent.duration._
import akka.actor.ActorSystem
import akka.http.model._
import akka.http.model.ws._
import akka.stream.ActorFlowMaterializer
import akka.stream.scaladsl.{ Source, Flow }
import com.typesafe.config.{ ConfigFactory, Config }
import HttpMethods._
import scala.concurrent.Await
object TestServer extends App {
val testConf: Config = ConfigFactory.parseString("""
akka.loglevel = INFO
@ -18,18 +23,36 @@ object TestServer extends App {
implicit val system = ActorSystem("ServerTest", testConf)
implicit val fm = ActorFlowMaterializer()
val binding = Http().bindAndHandleSync({
case HttpRequest(GET, Uri.Path("/"), _, _, _) index
case HttpRequest(GET, Uri.Path("/ping"), _, _, _) HttpResponse(entity = "PONG!")
case HttpRequest(GET, Uri.Path("/crash"), _, _, _) sys.error("BOOM!")
case _: HttpRequest HttpResponse(404, entity = "Unknown resource!")
}, interface = "localhost", port = 8080)
try {
val binding = Http().bindAndHandleSync({
case req @ HttpRequest(GET, Uri.Path("/"), _, _, _) if req.header[UpgradeToWebsocket].isDefined
req.header[UpgradeToWebsocket] match {
case Some(upgrade) upgrade.handleMessages(echoWebsocketService) // needed for running the autobahn test suite
case None HttpResponse(400, entity = "Not a valid websocket request!")
}
case req @ HttpRequest(GET, Uri.Path("/ws-greeter"), _, _, _)
req.header[UpgradeToWebsocket] match {
case Some(upgrade) upgrade.handleMessages(greeterWebsocketService)
case None HttpResponse(400, entity = "Not a valid websocket request!")
}
case HttpRequest(GET, Uri.Path("/"), _, _, _) index
case HttpRequest(GET, Uri.Path("/ping"), _, _, _) HttpResponse(entity = "PONG!")
case HttpRequest(GET, Uri.Path("/crash"), _, _, _) sys.error("BOOM!")
case req @ HttpRequest(GET, Uri.Path("/ws-greeter"), _, _, _)
req.header[UpgradeToWebsocket] match {
case Some(upgrade) upgrade.handleMessages(greeterWebsocketService)
case None HttpResponse(400, entity = "Not a valid websocket request!")
}
case _: HttpRequest HttpResponse(404, entity = "Unknown resource!")
}, interface = "localhost", port = 9001)
println(s"Server online at http://localhost:8080")
println("Press RETURN to stop...")
Console.readLine()
system.shutdown()
Await.result(binding, 1.second) // throws if binding fails
println("Server online at http://localhost:9001")
println("Press RETURN to stop...")
Console.readLine()
} finally {
system.shutdown()
}
////////////// helpers //////////////
@ -45,4 +68,15 @@ object TestServer extends App {
| </ul>
| </body>
|</html>""".stripMargin))
def echoWebsocketService: Flow[Message, Message, Unit] =
Flow[Message] // just let message flow directly to the output
def greeterWebsocketService: Flow[Message, Message, Unit] =
Flow[Message]
.collect {
case TextMessage.Strict(name) TextMessage.Strict(s"Hello '$name'")
case TextMessage.Streamed(nameStream) TextMessage.Streamed(Source.single("Hello ") ++ nameStream mapMaterialized (_ ()))
// ignore binary messages
}
}

View file

@ -540,7 +540,7 @@ class ResponseRendererSpec extends FreeSpec with Matchers with BeforeAndAfterAll
override def afterAll() = system.shutdown()
class TestSetup(val serverHeader: Option[Server] = Some(Server("akka-http/1.0.0")))
extends HttpResponseRendererFactory(serverHeader, responseHeaderSizeHint = 64, NoLogging) {
extends HttpResponseRendererFactory(serverHeader, responseHeaderSizeHint = 64, NoLogging, None) {
def renderTo(expected: String): Matcher[HttpResponse] =
renderTo(expected, close = false) compose (ResponseRenderingContext(_))

View file

@ -0,0 +1,103 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.engine.ws
import akka.parboiled2
import parboiled2._
import akka.util.ByteString
import scala.annotation.tailrec
import scala.util.{ Try, Failure, Success }
object BitBuilder {
implicit class BitBuilderContext(val ctx: StringContext) {
def b(args: Any*): ByteString = {
val parser = new BitSpecParser(ctx.parts.mkString)
val bits = parser.parseBits()
//println(bits)
//println(bits.get.toByteString.map(_ formatted "%02x").mkString(" "))
bits.get.toByteString
}
}
}
final case class Bits(elements: Seq[Bits.BitElement]) {
def toByteString: ByteString = {
import Bits._
val bits = elements.map(_.bits).sum
require(bits % 8 == 0)
val data = new Array[Byte](bits / 8)
@tailrec def rec(byteIdx: Int, bitIdx: Int, remaining: Seq[Bits.BitElement]): Unit =
if (bitIdx >= 8) rec(byteIdx + 1, bitIdx - 8, remaining)
else remaining match {
case Zero +: rest
// zero by default
rec(byteIdx, bitIdx + 1, rest)
case One +: rest
data(byteIdx) = (data(byteIdx) | (1 << (7 - bitIdx))).toByte
rec(byteIdx, bitIdx + 1, rest)
case Multibit(bits, value) +: rest
val numBits = math.min(8 - bitIdx, bits)
val remainingBits = bits - numBits
val highestNBits = value >> remainingBits
val lowestNBitMask = (~(0xff << numBits) & 0xff)
data(byteIdx) = (data(byteIdx) | (highestNBits & lowestNBitMask)).toByte
if (remainingBits > 0)
rec(byteIdx + 1, 0, Multibit(remainingBits, value) +: rest)
else
rec(byteIdx, bitIdx + numBits, rest)
case Nil
require(bitIdx == 0 && byteIdx == bits / 8)
}
rec(0, 0, elements)
ByteString(data) // this could be ByteString1C
}
}
object Bits {
sealed trait BitElement {
def bits: Int
}
sealed abstract class SingleBit extends BitElement {
def bits: Int = 1
}
case object Zero extends SingleBit
case object One extends SingleBit
case class Multibit(bits: Int, value: Long) extends BitElement
}
class BitSpecParser(val input: ParserInput) extends parboiled2.Parser {
import Bits._
def parseBits(): Try[Bits] =
bits.run() match {
case s: Success[Bits] s
case Failure(e: ParseError) Failure(new RuntimeException(formatError(e, showTraces = true)))
}
def bits: Rule1[Bits] = rule { zeroOrMore(element) ~ EOI ~> (Bits(_)) }
val WSChar = CharPredicate(' ', '\t', '\n')
def ws = rule { zeroOrMore(wsElement) }
def wsElement = rule { WSChar | comment }
def comment =
rule {
'#' ~ zeroOrMore(!'\n' ~ ANY) ~ '\n'
}
def element: Rule1[BitElement] = rule {
zero | one | multi
}
def zero: Rule1[BitElement] = rule { '0' ~ push(Zero) ~ ws }
def one: Rule1[BitElement] = rule { '1' ~ push(One) ~ ws }
def multi: Rule1[Multibit] = rule {
capture(oneOrMore('x' ~ ws)) ~> (_.count(_ == 'x')) ~ '=' ~ value ~ ws ~> Multibit
}
def value: Rule1[Long] = rule {
capture(oneOrMore(CharPredicate.HexDigit)) ~> ((str: String) java.lang.Long.parseLong(str, 16))
}
}

View file

@ -0,0 +1,326 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.engine.ws
import scala.collection.immutable
import scala.concurrent.duration._
import org.scalatest.matchers.Matcher
import org.scalatest.{ FreeSpec, Matchers }
import akka.util.ByteString
import akka.stream.scaladsl.Source
import akka.stream.stage.Stage
import akka.http.util._
import Protocol.Opcode
class FramingSpec extends FreeSpec with Matchers with WithMaterializerSpec {
import BitBuilder._
"The Websocket parser/renderer round-trip should work for" - {
"the frame header" - {
"interpret flags correctly" - {
"FIN" in {
b"""1000 # flags
0000 # opcode
0 # mask?
0000000 # length
""" should parseTo(FrameHeader(Opcode.Continuation, None, 0, fin = true))
}
"RSV1" in {
b"""0100 # flags
0000 # opcode
0 # mask?
0000000 # length
""" should parseTo(FrameHeader(Opcode.Continuation, None, 0, fin = false, rsv1 = true))
}
"RSV2" in {
b"""0010 # flags
0000 # opcode
0 # mask?
0000000 # length
""" should parseTo(FrameHeader(Opcode.Continuation, None, 0, fin = false, rsv2 = true))
}
"RSV3" in {
b"""0001 # flags
0000 # opcode
0 # mask?
0000000 # length
""" should parseTo(FrameHeader(Opcode.Continuation, None, 0, fin = false, rsv3 = true))
}
}
"interpret opcode correctly" - {
"Continuation" in {
b"""0000 # flags
xxxx=0 # opcode
0 # mask?
0000000 # length
""" should parseTo(FrameHeader(Opcode.Continuation, None, 0, fin = false))
}
"Text" in {
b"""0000 # flags
xxxx=1 # opcode
0 # mask?
0000000 # length
""" should parseTo(FrameHeader(Opcode.Text, None, 0, fin = false))
}
"Binary" in {
b"""0000 # flags
xxxx=2 # opcode
0 # mask?
0000000 # length
""" should parseTo(FrameHeader(Opcode.Binary, None, 0, fin = false))
}
"Close" in {
b"""0000 # flags
xxxx=8 # opcode
0 # mask?
0000000 # length
""" should parseTo(FrameHeader(Opcode.Close, None, 0, fin = false))
}
"Ping" in {
b"""0000 # flags
xxxx=9 # opcode
0 # mask?
0000000 # length
""" should parseTo(FrameHeader(Opcode.Ping, None, 0, fin = false))
}
"Pong" in {
b"""0000 # flags
xxxx=a # opcode
0 # mask?
0000000 # length
""" should parseTo(FrameHeader(Opcode.Pong, None, 0, fin = false))
}
"Other" in {
b"""0000 # flags
xxxx=6 # opcode
0 # mask?
0000000 # length
""" should parseTo(FrameHeader(Opcode.Other(6), None, 0, fin = false))
}
}
"read mask correctly" in {
b"""0000 # flags
0000 # opcode
1 # mask?
0000000 # length
xxxxxxxx
xxxxxxxx
xxxxxxxx
xxxxxxxx=a1b2c3d4
""" should parseTo(FrameHeader(Opcode.Continuation, Some(0xa1b2c3d4), 0, fin = false))
}
"read length" - {
"< 126" in {
b"""0000 # flags
xxxx=0 # opcode
0 # mask?
xxxxxxx=5 # length
""" should parseTo(FrameHeader(Opcode.Continuation, None, 5, fin = false))
}
"126" in {
b"""0000 # flags
xxxx=0 # opcode
0 # mask?
xxxxxxx=7e # length
xxxxxxxx
xxxxxxxx=007e # length16
""" should parseTo(FrameHeader(Opcode.Continuation, None, 126, fin = false))
}
"127" in {
b"""0000 # flags
xxxx=0 # opcode
0 # mask?
xxxxxxx=7e # length
xxxxxxxx
xxxxxxxx=007f # length16
""" should parseTo(FrameHeader(Opcode.Continuation, None, 127, fin = false))
}
"127 < length < 65536" in {
b"""0000 # flags
xxxx=0 # opcode
0 # mask?
xxxxxxx=7e # length
xxxxxxxx
xxxxxxxx=d28e # length16
""" should parseTo(FrameHeader(Opcode.Continuation, None, 0xd28e, fin = false))
}
"65535" in {
b"""0000 # flags
xxxx=0 # opcode
0 # mask?
xxxxxxx=7e # length
xxxxxxxx
xxxxxxxx=ffff # length16
""" should parseTo(FrameHeader(Opcode.Continuation, None, 0xffff, fin = false))
}
"65536" in {
b"""0000 # flags
xxxx=0 # opcode
0 # mask?
xxxxxxx=7f # length
xxxxxxxx
xxxxxxxx
xxxxxxxx
xxxxxxxx
xxxxxxxx
xxxxxxxx
xxxxxxxx
xxxxxxxx=0000000000010000 # length64
""" should parseTo(FrameHeader(Opcode.Continuation, None, 0x10000, fin = false))
}
"> 65536" in {
b"""0000 # flags
xxxx=0 # opcode
0 # mask?
xxxxxxx=7f # length
xxxxxxxx
xxxxxxxx
xxxxxxxx
xxxxxxxx
xxxxxxxx
xxxxxxxx
xxxxxxxx
xxxxxxxx=0000000123456789 # length64
""" should parseTo(FrameHeader(Opcode.Continuation, None, 0x123456789L, fin = false))
}
"Long.MaxValue" in {
b"""0000 # flags
xxxx=0 # opcode
0 # mask?
xxxxxxx=7f # length
xxxxxxxx
xxxxxxxx
xxxxxxxx
xxxxxxxx
xxxxxxxx
xxxxxxxx
xxxxxxxx
xxxxxxxx=7fffffffffffffff # length64
""" should parseTo(FrameHeader(Opcode.Continuation, None, Long.MaxValue, fin = false))
}
}
}
"a partial frame" in {
val header =
b"""0000 # flags
xxxx=1 # opcode
0 # mask?
xxxxxxx=5 # length
"""
val data = ByteString("abc")
(header ++ data) should parseTo(
FrameStart(
FrameHeader(Opcode.Text, None, 5, fin = false),
data))
}
"a partial frame of total size > Int.MaxValue" in {
val header =
b"""0000 # flags
xxxx=0 # opcode
0 # mask?
xxxxxxx=7f # length
xxxxxxxx
xxxxxxxx
xxxxxxxx
xxxxxxxx
xxxxxxxx
xxxxxxxx
xxxxxxxx
xxxxxxxx=00000000ffffffff # length64
"""
val data = ByteString("abc", "ASCII")
Seq(header, data) should parseMultipleTo(
FrameStart(FrameHeader(Opcode.Continuation, None, 0xFFFFFFFFL, fin = false), ByteString.empty),
FrameData(data, lastPart = false))
}
"a full frame" in {
val header =
b"""0000 # flags
xxxx=0 # opcode
0 # mask?
xxxxxxx=5 # length
"""
val data = ByteString("abcde")
(header ++ data) should parseTo(
FrameStart(FrameHeader(Opcode.Continuation, None, 5, fin = false), data))
}
"a full frame in chunks" in {
val header =
b"""0000 # flags
xxxx=1 # opcode
0 # mask?
xxxxxxx=5 # length
"""
val data1 = ByteString("abc")
val data2 = ByteString("de")
val expectedHeader = FrameHeader(Opcode.Text, None, 5, fin = false)
Seq(header, data1, data2) should parseMultipleTo(
FrameStart(expectedHeader, ByteString.empty),
FrameData(data1, lastPart = false),
FrameData(data2, lastPart = true))
}
"several frames" in {
val header1 =
b"""0000 # flags
xxxx=0 # opcode
0 # mask?
xxxxxxx=5 # length
"""
val header2 =
b"""0000 # flags
xxxx=0 # opcode
0 # mask?
xxxxxxx=7 # length
"""
val data1 = ByteString("abcde")
val data2 = ByteString("abc")
(header1 ++ data1 ++ header2 ++ data2) should parseTo(
FrameStart(FrameHeader(Opcode.Continuation, None, 5, fin = false), data1),
FrameStart(FrameHeader(Opcode.Continuation, None, 7, fin = false), data2))
}
}
def parseTo(events: FrameEvent*): Matcher[ByteString] =
parseMultipleTo(events: _*).compose(Seq(_))
def parseMultipleTo(events: FrameEvent*): Matcher[Seq[ByteString]] =
equal(events).matcher[Seq[FrameEvent]].compose {
(chunks: Seq[ByteString])
val result = parseToEvents(chunks)
result shouldEqual events
val rendered = renderToByteString(result)
rendered shouldEqual chunks.reduce(_ ++ _)
result
}
def parseToEvents(bytes: Seq[ByteString]): immutable.Seq[FrameEvent] =
Source(bytes.toVector).transform(newParser).runFold(Vector.empty[FrameEvent])(_ :+ _)
.awaitResult(1.second)
def renderToByteString(events: immutable.Seq[FrameEvent]): ByteString =
Source(events).transform(newRenderer).runFold(ByteString.empty)(_ ++ _)
.awaitResult(1.second)
protected def newParser(): Stage[ByteString, FrameEvent] = new FrameEventParser
protected def newRenderer(): Stage[FrameEvent, ByteString] = new FrameEventRenderer
import scala.language.implicitConversions
implicit def headerToEvent(header: FrameHeader): FrameEvent =
FrameStart(header, ByteString.empty)
}

View file

@ -0,0 +1,966 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.engine.ws
import scala.annotation.tailrec
import scala.concurrent.duration._
import akka.stream.FlowShape
import akka.stream.scaladsl._
import akka.stream.testkit.StreamTestKit
import akka.util.ByteString
import org.scalatest.{ Matchers, FreeSpec }
import scala.util.Random
import akka.http.model.ws._
import Protocol.Opcode
class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
"The Websocket implementation should" - {
"collect messages from frames" - {
"for binary messages" - {
"for an empty message" in new ClientTestSetup {
val input = frameHeader(Opcode.Binary, 0, fin = true)
pushInput(input)
expectMessage(BinaryMessage.Strict(ByteString.empty))
}
"for one complete, strict, single frame message" in new ClientTestSetup {
val data = ByteString("abcdef", "ASCII")
val input = frameHeader(Opcode.Binary, 6, fin = true) ++ data
pushInput(input)
expectMessage(BinaryMessage.Strict(data))
}
"for a partial frame" in new ClientTestSetup {
val data1 = ByteString("abc", "ASCII")
val header = frameHeader(Opcode.Binary, 6, fin = true)
pushInput(header ++ data1)
val BinaryMessage.Streamed(dataSource) = expectMessage()
val sub = StreamTestKit.SubscriberProbe[ByteString]
dataSource.runWith(Sink(sub))
val s = sub.expectSubscription()
s.request(2)
sub.expectNext(data1)
}
"for a frame split up into parts" in new ClientTestSetup {
val data1 = ByteString("abc", "ASCII")
val header = frameHeader(Opcode.Binary, 6, fin = true)
pushInput(header)
val BinaryMessage.Streamed(dataSource) = expectMessage()
val sub = StreamTestKit.SubscriberProbe[ByteString]
dataSource.runWith(Sink(sub))
val s = sub.expectSubscription()
s.request(2)
pushInput(data1)
sub.expectNext(data1)
val data2 = ByteString("def", "ASCII")
pushInput(data2)
sub.expectNext(data2)
sub.expectComplete()
}
"for a message split into several frames" in new ClientTestSetup {
val data1 = ByteString("abc", "ASCII")
val header1 = frameHeader(Opcode.Binary, 3, fin = false)
pushInput(header1 ++ data1)
val BinaryMessage.Streamed(dataSource) = expectMessage()
val sub = StreamTestKit.SubscriberProbe[ByteString]
dataSource.runWith(Sink(sub))
val s = sub.expectSubscription()
s.request(2)
sub.expectNext(data1)
val header2 = frameHeader(Opcode.Continuation, 4, fin = true)
val data2 = ByteString("defg", "ASCII")
pushInput(header2 ++ data2)
sub.expectNext(data2)
sub.expectComplete()
}
"for several messages" in new ClientTestSetup {
val data1 = ByteString("abc", "ASCII")
val header1 = frameHeader(Opcode.Binary, 3, fin = false)
pushInput(header1 ++ data1)
val BinaryMessage.Streamed(dataSource) = expectMessage()
val sub = StreamTestKit.SubscriberProbe[ByteString]
dataSource.runWith(Sink(sub))
val s = sub.expectSubscription()
s.request(2)
sub.expectNext(data1)
val header2 = frameHeader(Opcode.Continuation, 4, fin = true)
val header3 = frameHeader(Opcode.Binary, 2, fin = true)
val data2 = ByteString("defg", "ASCII")
val data3 = ByteString("h")
pushInput(header2 ++ data2 ++ header3 ++ data3)
sub.expectNext(data2)
sub.expectComplete()
val BinaryMessage.Streamed(dataSource2) = expectMessage()
val sub2 = StreamTestKit.SubscriberProbe[ByteString]
dataSource2.runWith(Sink(sub2))
val s2 = sub2.expectSubscription()
s2.request(2)
sub2.expectNext(data3)
val data4 = ByteString("i")
pushInput(data4)
sub2.expectNext(data4)
sub2.expectComplete()
}
"unmask masked input on the server side" in new ServerTestSetup {
val mask = Random.nextInt()
val (data, _) = maskedASCII("abcdef", mask)
val data1 = data.take(3)
val data2 = data.drop(3)
val header = frameHeader(Opcode.Binary, 6, fin = true, mask = Some(mask))
pushInput(header ++ data1)
val BinaryMessage.Streamed(dataSource) = expectMessage()
val sub = StreamTestKit.SubscriberProbe[ByteString]
dataSource.runWith(Sink(sub))
val s = sub.expectSubscription()
s.request(2)
sub.expectNext(ByteString("abc", "ASCII"))
pushInput(data2)
sub.expectNext(ByteString("def", "ASCII"))
sub.expectComplete()
}
}
"for text messages" - {
"empty message" in new ClientTestSetup {
val input = frameHeader(Opcode.Text, 0, fin = true)
pushInput(input)
expectMessage(TextMessage.Strict(""))
}
"decode complete, strict frame from utf8" in new ClientTestSetup {
val msg = "äbcdef€\uffff"
val data = ByteString(msg, "UTF-8")
val input = frameHeader(Opcode.Text, data.size, fin = true) ++ data
pushInput(input)
expectMessage(TextMessage.Strict(msg))
}
"decode utf8 as far as possible for partial frame" in new ClientTestSetup {
val msg = "bäcdef€"
val data = ByteString(msg, "UTF-8")
val data0 = data.slice(0, 2)
val data1 = data.slice(2, 5)
val data2 = data.slice(5, data.size)
val input = frameHeader(Opcode.Text, data.size, fin = true) ++ data0
pushInput(input)
val TextMessage.Streamed(parts) = expectMessage()
val sub = StreamTestKit.SubscriberProbe[String]
parts.runWith(Sink(sub))
val s = sub.expectSubscription()
s.request(4)
sub.expectNext("b")
pushInput(data1)
sub.expectNext("äcd")
}
"decode utf8 with code point split across frames" in new ClientTestSetup {
val msg = "äbcdef€"
val data = ByteString(msg, "UTF-8")
val data0 = data.slice(0, 1)
val data1 = data.slice(1, data.size)
val header0 = frameHeader(Opcode.Text, data0.size, fin = false)
pushInput(header0 ++ data0)
val TextMessage.Streamed(parts) = expectMessage()
val sub = StreamTestKit.SubscriberProbe[String]
parts.runWith(Sink(sub))
val s = sub.expectSubscription()
s.request(4)
sub.expectNoMsg(100.millis)
val header1 = frameHeader(Opcode.Continuation, data1.size, fin = true)
pushInput(header1 ++ data1)
sub.expectNext("äbcdef€")
}
"unmask masked input on the server side" in new ServerTestSetup {
val mask = Random.nextInt()
val (data, _) = maskedUTF8("äbcdef€", mask)
val data1 = data.take(3)
val data2 = data.drop(3)
val header = frameHeader(Opcode.Binary, data.size, fin = true, mask = Some(mask))
pushInput(header ++ data1)
val BinaryMessage.Streamed(dataSource) = expectMessage()
val sub = StreamTestKit.SubscriberProbe[ByteString]
dataSource.runWith(Sink(sub))
val s = sub.expectSubscription()
s.request(2)
sub.expectNext(ByteString("äb", "UTF-8"))
pushInput(data2)
sub.expectNext(ByteString("cdef€", "UTF-8"))
sub.expectComplete()
}
}
}
"render frames from messages" - {
"for binary messages" - {
"for a short strict message" in new ServerTestSetup {
val data = ByteString("abcdef", "ASCII")
val msg = BinaryMessage.Strict(data)
netOutSub.request(5)
pushMessage(msg)
expectFrameOnNetwork(Opcode.Binary, data, fin = true)
}
"for a strict message larger than configured maximum frame size" in pending
"for a streamed message" in new ServerTestSetup {
val data = ByteString("abcdefg", "ASCII")
val pub = StreamTestKit.PublisherProbe[ByteString]
val msg = BinaryMessage.Streamed(Source(pub))
netOutSub.request(6)
pushMessage(msg)
val sub = pub.expectSubscription()
expectFrameHeaderOnNetwork(Opcode.Binary, 0, fin = false)
val data1 = data.take(3)
val data2 = data.drop(3)
sub.sendNext(data1)
expectFrameOnNetwork(Opcode.Continuation, data1, fin = false)
sub.sendNext(data2)
expectFrameOnNetwork(Opcode.Continuation, data2, fin = false)
sub.sendComplete()
expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true)
}
"for a streamed message with a chunk being larger than configured maximum frame size" in pending
"and mask input on the client side" in new ClientTestSetup {
val data = ByteString("abcdefg", "ASCII")
val pub = StreamTestKit.PublisherProbe[ByteString]
val msg = BinaryMessage.Streamed(Source(pub))
netOutSub.request(7)
pushMessage(msg)
val sub = pub.expectSubscription()
expectFrameHeaderOnNetwork(Opcode.Binary, 0, fin = false)
val data1 = data.take(3)
val data2 = data.drop(3)
sub.sendNext(data1)
expectMaskedFrameOnNetwork(Opcode.Continuation, data1, fin = false)
sub.sendNext(data2)
expectMaskedFrameOnNetwork(Opcode.Continuation, data2, fin = false)
sub.sendComplete()
expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true)
}
}
"for text messages" - {
"for a short strict message" in new ServerTestSetup {
val text = "äbcdef"
val msg = TextMessage.Strict(text)
netOutSub.request(5)
pushMessage(msg)
expectFrameOnNetwork(Opcode.Text, ByteString(text, "UTF-8"), fin = true)
}
"for a strict message larger than configured maximum frame size" in pending
"for a streamed message" in new ServerTestSetup {
val text = "äbcd€fg"
val pub = StreamTestKit.PublisherProbe[String]
val msg = TextMessage.Streamed(Source(pub))
netOutSub.request(6)
pushMessage(msg)
val sub = pub.expectSubscription()
expectFrameHeaderOnNetwork(Opcode.Text, 0, fin = false)
val text1 = text.take(3)
val text1Bytes = ByteString(text1, "UTF-8")
val text2 = text.drop(3)
val text2Bytes = ByteString(text2, "UTF-8")
sub.sendNext(text1)
expectFrameOnNetwork(Opcode.Continuation, text1Bytes, fin = false)
sub.sendNext(text2)
expectFrameOnNetwork(Opcode.Continuation, text2Bytes, fin = false)
sub.sendComplete()
expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true)
}
"for a streamed message don't convert half surrogate pairs naively" in new ServerTestSetup {
val gclef = "𝄞"
gclef.size shouldEqual 2
// split up the code point
val half1 = gclef.take(1)
val half2 = gclef.drop(1)
println(half1(0).toInt.toHexString)
println(half2(0).toInt.toHexString)
val pub = StreamTestKit.PublisherProbe[String]
val msg = TextMessage.Streamed(Source(pub))
netOutSub.request(6)
pushMessage(msg)
val sub = pub.expectSubscription()
expectFrameHeaderOnNetwork(Opcode.Text, 0, fin = false)
sub.sendNext(half1)
expectNoNetworkData()
sub.sendNext(half2)
expectFrameOnNetwork(Opcode.Continuation, ByteString(gclef, "utf8"), fin = false)
}
"for a streamed message with a chunk being larger than configured maximum frame size" in pending
"and mask input on the client side" in new ClientTestSetup {
val text = "abcdefg"
val pub = StreamTestKit.PublisherProbe[String]
val msg = TextMessage.Streamed(Source(pub))
netOutSub.request(5)
pushMessage(msg)
val sub = pub.expectSubscription()
expectFrameOnNetwork(Opcode.Text, ByteString.empty, fin = false)
val text1 = text.take(3)
val text1Bytes = ByteString(text1, "UTF-8")
val text2 = text.drop(3)
val text2Bytes = ByteString(text2, "UTF-8")
sub.sendNext(text1)
expectMaskedFrameOnNetwork(Opcode.Continuation, text1Bytes, fin = false)
sub.sendNext(text2)
expectMaskedFrameOnNetwork(Opcode.Continuation, text2Bytes, fin = false)
sub.sendComplete()
expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true)
}
}
}
"supply automatic low-level websocket behavior" - {
"respond to ping frames unmasking them on the server side" in new ServerTestSetup {
val mask = Random.nextInt()
val input = frameHeader(Opcode.Ping, 6, fin = true, mask = Some(mask)) ++ maskedASCII("abcdef", mask)._1
pushInput(input)
netOutSub.request(5)
expectFrameOnNetwork(Opcode.Pong, ByteString("abcdef"), fin = true)
}
"respond to ping frames masking them on the client side" in new ClientTestSetup {
val input = frameHeader(Opcode.Ping, 6, fin = true) ++ ByteString("abcdef")
pushInput(input)
netOutSub.request(5)
expectMaskedFrameOnNetwork(Opcode.Pong, ByteString("abcdef"), fin = true)
}
"respond to ping frames interleaved with data frames (without mixing frame data)" in new ServerTestSetup {
// receive multi-frame message
// receive and handle interleaved ping frame
// concurrently send out messages from handler
val mask1 = Random.nextInt()
val input1 = frameHeader(Opcode.Binary, 3, fin = false, mask = Some(mask1)) ++ maskedASCII("123", mask1)._1
pushInput(input1)
val BinaryMessage.Streamed(dataSource) = expectMessage()
val sub = StreamTestKit.SubscriberProbe[ByteString]
dataSource.runWith(Sink(sub))
val s = sub.expectSubscription()
s.request(2)
sub.expectNext(ByteString("123", "ASCII"))
val outPub = StreamTestKit.PublisherProbe[ByteString]
val msg = BinaryMessage.Streamed(Source(outPub))
netOutSub.request(10)
pushMessage(msg)
expectFrameHeaderOnNetwork(Opcode.Binary, 0, fin = false)
val outSub = outPub.expectSubscription()
val outData1 = ByteString("abc", "ASCII")
outSub.sendNext(outData1)
expectFrameOnNetwork(Opcode.Continuation, outData1, fin = false)
val pingMask = Random.nextInt()
val pingData = maskedASCII("pling", pingMask)._1
val pingData0 = pingData.take(3)
val pingData1 = pingData.drop(3)
pushInput(frameHeader(Opcode.Ping, 5, fin = true, mask = Some(pingMask)) ++ pingData0)
expectNoNetworkData()
pushInput(pingData1)
expectFrameOnNetwork(Opcode.Pong, ByteString("pling", "ASCII"), fin = true)
val outData2 = ByteString("def", "ASCII")
outSub.sendNext(outData2)
expectFrameOnNetwork(Opcode.Continuation, outData2, fin = false)
outSub.sendComplete()
expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true)
val mask2 = Random.nextInt()
val input2 = frameHeader(Opcode.Continuation, 3, fin = true, mask = Some(mask2)) ++ maskedASCII("456", mask2)._1
pushInput(input2)
sub.expectNext(ByteString("456", "ASCII"))
sub.expectComplete()
}
"don't respond to unsolicited pong frames" in new ClientTestSetup {
val data = frameHeader(Opcode.Pong, 6, fin = true) ++ ByteString("abcdef")
pushInput(data)
netOutSub.request(5)
expectNoNetworkData()
}
}
"provide close behavior" - {
"after receiving regular close frame when idle (user closes immediately)" in new ServerTestSetup {
netInSub.expectRequest()
netOutSub.request(20)
messageOutSub.request(20)
pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true))
messageIn.expectComplete()
netIn.expectNoMsg(1.second) // especially the cancellation not yet
expectNoNetworkData()
messageOutSub.sendComplete()
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
netOut.expectComplete()
netInSub.expectCancellation()
}
"after receiving close frame without close code" in new ServerTestSetup {
netInSub.expectRequest()
pushInput(frameHeader(Opcode.Close, 0, fin = true))
messageIn.expectComplete()
messageOutSub.sendComplete()
// especially mustn't be Procotol.CloseCodes.NoCodePresent
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
netOut.expectComplete()
netInSub.expectCancellation()
}
"after receiving regular close frame when idle (user still sends some data)" in new ServerTestSetup {
netOutSub.request(20)
messageOutSub.request(20)
pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true))
messageIn.expectComplete()
// sending another message is allowed before closing (inherently racy)
val pub = StreamTestKit.PublisherProbe[ByteString]
val msg = BinaryMessage.Streamed(Source(pub))
pushMessage(msg)
expectFrameOnNetwork(Opcode.Binary, ByteString.empty, fin = false)
val data = ByteString("abc", "ASCII")
val dataSub = pub.expectSubscription()
dataSub.sendNext(data)
expectFrameOnNetwork(Opcode.Continuation, data, fin = false)
dataSub.sendComplete()
expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true)
messageOutSub.sendComplete()
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
netOut.expectComplete()
}
"after receiving regular close frame when fragmented message is still open" in pendingUntilFixed {
new ServerTestSetup {
netOutSub.request(10)
messageInSub.request(10)
pushInput(frameHeader(Protocol.Opcode.Binary, 0, fin = false))
val BinaryMessage.Streamed(dataSource) = messageIn.expectNext()
val inSubscriber = StreamTestKit.SubscriberProbe[ByteString]
dataSource.runWith(Sink(inSubscriber))
val inSub = inSubscriber.expectSubscription()
val outData = ByteString("def", "ASCII")
val mask = Random.nextInt()
pushInput(frameHeader(Protocol.Opcode.Continuation, 3, fin = false, mask = Some(mask)) ++ maskedBytes(outData, mask)._1)
inSub.request(5)
inSubscriber.expectNext(outData)
pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true))
messageIn.expectComplete()
inSubscriber.expectError()
// truncation of open message
// sending another message is allowed before closing (inherently racy)
val pub = StreamTestKit.PublisherProbe[ByteString]
val msg = BinaryMessage.Streamed(Source(pub))
pushMessage(msg)
expectFrameOnNetwork(Opcode.Binary, ByteString.empty, fin = false)
val data = ByteString("abc", "ASCII")
val dataSub = pub.expectSubscription()
dataSub.sendNext(data)
expectFrameOnNetwork(Opcode.Continuation, data, fin = false)
dataSub.sendComplete()
expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true)
messageOutSub.sendComplete()
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
netOut.expectComplete()
}
}
"after receiving error close frame" in pending
"after peer closes connection without sending a close frame" in new ServerTestSetup {
netInSub.expectRequest()
netInSub.sendComplete()
messageIn.expectComplete()
messageOutSub.sendComplete()
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
netOut.expectComplete()
}
"when user handler closes (simple)" in new ServerTestSetup {
messageOutSub.sendComplete()
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
netOut.expectNoMsg(1.second) // wait for peer to close regularly
pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true))
messageIn.expectComplete()
netOut.expectComplete()
netInSub.expectCancellation()
}
"when user handler closes main stream and substream only afterwards" in new ServerTestSetup {
netOutSub.request(10)
messageInSub.request(10)
// send half a message
val pub = StreamTestKit.PublisherProbe[ByteString]
val msg = BinaryMessage.Streamed(Source(pub))
pushMessage(msg)
expectFrameOnNetwork(Opcode.Binary, ByteString.empty, fin = false)
val data = ByteString("abc", "ASCII")
val dataSub = pub.expectSubscription()
dataSub.sendNext(data)
expectFrameOnNetwork(Opcode.Continuation, data, fin = false)
messageOutSub.sendComplete()
expectNoNetworkData() // need to wait for substream to close
dataSub.sendComplete()
expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true)
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
netOut.expectNoMsg(1.second) // wait for peer to close regularly
val mask = Random.nextInt()
pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true))
messageIn.expectComplete()
netOut.expectComplete()
netInSub.expectCancellation()
}
"if user handler fails" in pending
"if peer closes with invalid close frame" - {
"close code outside of the valid range" in new ServerTestSetup {
netInSub.expectRequest()
pushInput(frameHeader(Opcode.Close, 1, mask = Some(Random.nextInt()), fin = true) ++ ByteString("x"))
val error = messageIn.expectError()
expectCloseCodeOnNetwork(Protocol.CloseCodes.ProtocolError)
netOut.expectComplete()
netInSub.expectCancellation()
}
"close data of size 1" in new ServerTestSetup {
netInSub.expectRequest()
pushInput(frameHeader(Opcode.Close, 1, mask = Some(Random.nextInt()), fin = true) ++ ByteString("x"))
val error = messageIn.expectError()
expectCloseCodeOnNetwork(Protocol.CloseCodes.ProtocolError)
netOut.expectComplete()
netInSub.expectCancellation()
}
"reason is no valid utf8 data" in pending
}
"timeout if user handler closes and peer doesn't send a close frame" in new ServerTestSetup {
netInSub.expectRequest()
messageOutSub.sendComplete()
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
netOut.expectComplete()
netInSub.expectCancellation()
}
"timeout after we close after error and peer doesn't send a close frame" in new ServerTestSetup {
netInSub.expectRequest()
pushInput(frameHeader(Opcode.Binary, 0, fin = true, rsv1 = true))
expectProtocolErrorOnNetwork()
messageOutSub.sendComplete()
netOut.expectComplete()
netInSub.expectCancellation()
}
"ignore frames peer sends after close frame" in new ServerTestSetup {
netInSub.expectRequest()
pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true))
messageIn.expectComplete()
pushInput(frameHeader(Opcode.Binary, 0, fin = true))
messageOutSub.sendComplete()
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
netOut.expectComplete()
netInSub.expectCancellation()
}
}
"reject unexpected frames" - {
"reserved bits set" - {
"rsv1" in new ServerTestSetup {
pushInput(frameHeader(Opcode.Binary, 0, fin = true, rsv1 = true))
expectProtocolErrorOnNetwork()
}
"rsv2" in new ServerTestSetup {
pushInput(frameHeader(Opcode.Binary, 0, fin = true, rsv2 = true))
expectProtocolErrorOnNetwork()
}
"rsv3" in new ServerTestSetup {
pushInput(frameHeader(Opcode.Binary, 0, fin = true, rsv3 = true))
expectProtocolErrorOnNetwork()
}
}
"highest bit of 64-bit length is set" in new ServerTestSetup {
import BitBuilder._
val header =
b"""0000 # flags
xxxx=1 # opcode
1 # mask?
xxxxxxx=7f # length
xxxxxxxx
xxxxxxxx
xxxxxxxx
xxxxxxxx=ffffffff
xxxxxxxx
xxxxxxxx
xxxxxxxx
xxxxxxxx=ffffffff # length64
00000000
00000000
00000000
00000000 # empty mask
"""
pushInput(header)
expectProtocolErrorOnNetwork()
}
"control frame bigger than 125 bytes" in new ServerTestSetup {
pushInput(frameHeader(Opcode.Ping, 126, fin = true, mask = Some(0)))
expectProtocolErrorOnNetwork()
}
"fragmented control frame" in new ServerTestSetup {
pushInput(frameHeader(Opcode.Ping, 0, fin = false, mask = Some(0)))
expectProtocolErrorOnNetwork()
}
"unexpected continuation frame" in new ServerTestSetup {
pushInput(frameHeader(Opcode.Continuation, 0, fin = false, mask = Some(0)))
expectProtocolErrorOnNetwork()
}
"unexpected data frame when waiting for continuation" in new ServerTestSetup {
pushInput(frameHeader(Opcode.Binary, 0, fin = false) ++
frameHeader(Opcode.Binary, 0, fin = false))
expectProtocolErrorOnNetwork()
}
"invalid utf8 encoding for single frame message" in new ClientTestSetup {
val data = ByteString(
(128 + 64).toByte, // start two byte sequence
0 // but don't finish it
)
pushInput(frameHeader(Opcode.Text, 2, fin = true) ++ data)
expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData)
}
"invalid utf8 encoding for streamed frame" in new ClientTestSetup {
val data = ByteString(
(128 + 64).toByte, // start two byte sequence
0 // but don't finish it
)
pushInput(frameHeader(Opcode.Text, 0, fin = false) ++
frameHeader(Opcode.Continuation, 2, fin = true) ++
data)
expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData)
}
"truncated utf8 encoding for single frame message" in new ClientTestSetup {
val data = ByteString("€", "UTF-8").take(1) // half a euro
pushInput(frameHeader(Opcode.Text, 1, fin = true) ++ data)
expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData)
}
"truncated utf8 encoding for streamed frame" in new ClientTestSetup {
val data = ByteString("€", "UTF-8").take(1) // half a euro
pushInput(frameHeader(Opcode.Text, 0, fin = false) ++
frameHeader(Opcode.Continuation, 1, fin = true) ++
data)
expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData)
}
"half a surrogate pair in utf8 encoding for a strict frame" in new ClientTestSetup {
val data = ByteString(0xed, 0xa0, 0x80) // not strictly supported by utf-8
pushInput(frameHeader(Opcode.Text, 3, fin = true) ++ data)
expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData)
}
"half a surrogate pair in utf8 encoding for a streamed frame" in new ClientTestSetup {
val data = ByteString(0xed, 0xa0, 0x80) // not strictly supported by utf-8
pushInput(frameHeader(Opcode.Text, 0, fin = false))
pushInput(frameHeader(Opcode.Continuation, 3, fin = true) ++ data)
messageIn.expectError()
expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData)
}
"unmasked input on the server side" in new ServerTestSetup {
val data = ByteString("abcdef", "ASCII")
val input = frameHeader(Opcode.Binary, 6, fin = true) ++ data
pushInput(input)
expectProtocolErrorOnNetwork()
}
"masked input on the client side" in new ClientTestSetup {
val mask = Random.nextInt()
val input = frameHeader(Opcode.Binary, 6, fin = true, mask = Some(mask)) ++ maskedASCII("abcdef", mask)._1
pushInput(input)
expectProtocolErrorOnNetwork()
}
}
"support per-message-compression extension" in pending
}
class ServerTestSetup extends TestSetup {
protected def serverSide: Boolean = true
}
class ClientTestSetup extends TestSetup {
protected def serverSide: Boolean = false
}
abstract class TestSetup {
protected def serverSide: Boolean
protected def closeTimeout: FiniteDuration = 1.second
val netIn = StreamTestKit.PublisherProbe[ByteString]
val netOut = StreamTestKit.SubscriberProbe[ByteString]
val messageIn = StreamTestKit.SubscriberProbe[Message]
val messageOut = StreamTestKit.PublisherProbe[Message]
val messageHandler: Flow[Message, Message, Unit] =
Flow.wrap {
FlowGraph.partial() { implicit b
val in = b.add(Sink(messageIn))
val out = b.add(Source(messageOut))
FlowShape[Message, Message](in, out)
}
}
Source(netIn)
.via(printEvent("netIn"))
.transform(() new FrameEventParser)
.via(Websocket.handleMessages(messageHandler, serverSide, closeTimeout = closeTimeout))
.via(printEvent("frameRendererIn"))
.transform(() new FrameEventRenderer)
.via(printEvent("frameRendererOut"))
.to(Sink(netOut))
.run()
val netInSub = netIn.expectSubscription()
val netOutSub = netOut.expectSubscription()
val messageOutSub = messageOut.expectSubscription()
val messageInSub = messageIn.expectSubscription()
def pushInput(data: ByteString): Unit = {
// TODO: expect/handle request?
netInSub.sendNext(data)
}
def pushMessage(msg: Message): Unit = {
messageOutSub.sendNext(msg)
}
def expectMessage(message: Message): Unit = {
messageInSub.request(1)
messageIn.expectNext(message)
}
def expectMessage(): Message = {
messageInSub.request(1)
messageIn.expectNext()
}
var inBuffer = ByteString.empty
@tailrec final def expectNetworkData(bytes: Int): ByteString =
if (inBuffer.size >= bytes) {
val res = inBuffer.take(bytes)
inBuffer = inBuffer.drop(bytes)
res
} else {
netOutSub.request(1)
inBuffer ++= netOut.expectNext()
expectNetworkData(bytes)
}
def expectNetworkData(data: ByteString): Unit =
expectNetworkData(data.size) shouldEqual data
def expectFrameOnNetwork(opcode: Opcode, data: ByteString, fin: Boolean): Unit = {
expectFrameHeaderOnNetwork(opcode, data.size, fin)
expectNetworkData(data)
}
def expectMaskedFrameOnNetwork(opcode: Opcode, data: ByteString, fin: Boolean): Unit = {
val Some(mask) = expectFrameHeaderOnNetwork(opcode, data.size, fin)
val masked = maskedBytes(data, mask)._1
expectNetworkData(masked)
}
/** Returns the mask if any is available */
def expectFrameHeaderOnNetwork(opcode: Opcode, length: Long, fin: Boolean): Option[Int] = {
val (op, l, f, m) = expectFrameHeaderOnNetwork()
op shouldEqual opcode
l shouldEqual length
f shouldEqual fin
m
}
def expectFrameHeaderOnNetwork(): (Opcode, Long, Boolean, Option[Int]) = {
val header = expectNetworkData(2)
val fin = (header(0) & Protocol.FIN_MASK) != 0
val op = header(0) & Protocol.OP_MASK
val hasMask = (header(1) & Protocol.MASK_MASK) != 0
val length7 = header(1) & Protocol.LENGTH_MASK
val length = length7 match {
case 126
val length16Bytes = expectNetworkData(2)
(length16Bytes(0) & 0xff) << 8 | (length16Bytes(1) & 0xff) << 0
case 127
val length64Bytes = expectNetworkData(8)
(length64Bytes(0) & 0xff).toLong << 56 |
(length64Bytes(1) & 0xff).toLong << 48 |
(length64Bytes(2) & 0xff).toLong << 40 |
(length64Bytes(3) & 0xff).toLong << 32 |
(length64Bytes(4) & 0xff).toLong << 24 |
(length64Bytes(5) & 0xff).toLong << 16 |
(length64Bytes(6) & 0xff).toLong << 8 |
(length64Bytes(7) & 0xff).toLong << 0
case x x
}
val mask =
if (hasMask) {
val maskBytes = expectNetworkData(4)
val mask =
(maskBytes(0) & 0xff) << 24 |
(maskBytes(1) & 0xff) << 16 |
(maskBytes(2) & 0xff) << 8 |
(maskBytes(3) & 0xff) << 0
Some(mask)
} else None
(Opcode.forCode(op.toByte), length, fin, mask)
}
def expectProtocolErrorOnNetwork(): Unit = expectCloseCodeOnNetwork(Protocol.CloseCodes.ProtocolError)
def expectCloseCodeOnNetwork(expectedCode: Int): Unit = {
val (opcode, length, true, mask) = expectFrameHeaderOnNetwork()
opcode shouldEqual Opcode.Close
length should be >= 2.toLong
val rawData = expectNetworkData(length.toInt)
val data = mask match {
case Some(m) FrameEventParser.mask(rawData, m)._1
case None rawData
}
val code = ((data(0) & 0xff) << 8) | ((data(1) & 0xff) << 0)
code shouldEqual expectedCode
}
def expectNoNetworkData(): Unit =
netOut.expectNoMsg(100.millis)
}
def frameHeader(
opcode: Opcode,
length: Long,
fin: Boolean,
mask: Option[Int] = None,
rsv1: Boolean = false,
rsv2: Boolean = false,
rsv3: Boolean = false): ByteString = {
def set(should: Boolean, mask: Int): Int =
if (should) mask else 0
val flags =
set(fin, Protocol.FIN_MASK) |
set(rsv1, Protocol.RSV1_MASK) |
set(rsv2, Protocol.RSV2_MASK) |
set(rsv3, Protocol.RSV3_MASK)
val opcodeByte = opcode.code | flags
require(length >= 0)
val (lengthByteComponent, lengthBytes) =
if (length < 126) (length.toByte, ByteString.empty)
else if (length < 65536) (126.toByte, shortBE(length.toInt))
else throw new IllegalArgumentException("Only lengths < 65536 allowed in test")
val maskMask = if (mask.isDefined) Protocol.MASK_MASK else 0
val maskBytes = mask match {
case Some(mask) intBE(mask)
case None ByteString.empty
}
val lengthByte = lengthByteComponent | maskMask
ByteString(opcodeByte.toByte, lengthByte.toByte) ++ lengthBytes ++ maskBytes
}
def closeFrame(closeCode: Int, mask: Boolean): ByteString =
if (mask) {
val mask = Random.nextInt()
frameHeader(Opcode.Close, 2, fin = true, mask = Some(mask)) ++
maskedBytes(shortBE(closeCode), mask)._1
} else
frameHeader(Opcode.Close, 2, fin = true) ++
shortBE(closeCode)
def maskedASCII(str: String, mask: Int): (ByteString, Int) =
FrameEventParser.mask(ByteString(str, "ASCII"), mask)
def maskedUTF8(str: String, mask: Int): (ByteString, Int) =
FrameEventParser.mask(ByteString(str, "UTF-8"), mask)
def maskedBytes(bytes: ByteString, mask: Int): (ByteString, Int) =
FrameEventParser.mask(bytes, mask)
def shortBE(value: Int): ByteString = {
require(value >= 0 && value < 65536, s"Value wasn't in short range: $value")
ByteString(
((value >> 8) & 0xff).toByte,
((value >> 0) & 0xff).toByte)
}
def intBE(value: Int): ByteString =
ByteString(
((value >> 24) & 0xff).toByte,
((value >> 16) & 0xff).toByte,
((value >> 8) & 0xff).toByte,
((value >> 0) & 0xff).toByte)
val trace = false // set to `true` for debugging purposes
def printEvent[T](marker: String): Flow[T, T, Unit] =
if (trace) akka.http.util.printEvent(marker)
else Flow[T]
}

View file

@ -0,0 +1,60 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.engine.ws
import org.scalacheck.Gen
import scala.concurrent.duration._
import akka.stream.scaladsl.Source
import akka.util.ByteString
import akka.http.util._
import org.scalatest.prop.PropertyChecks
import org.scalatest.{ FreeSpec, Matchers }
class Utf8CodingSpecs extends FreeSpec with Matchers with PropertyChecks with WithMaterializerSpec {
"Utf8 decoding/encoding" - {
"work for all codepoints" in {
def isSurrogate(cp: Int): Boolean =
cp >= Utf8Encoder.SurrogateFirst && cp <= 0xdfff
val cps =
Gen.choose(0, 0x10ffff)
.filter(!isSurrogate(_))
def codePointAsString(cp: Int): String = {
if (cp < 0x10000) new String(Array(cp.toChar))
else {
val part0 = 0xd7c0 + (cp >> 10) // constant has 0x10000 subtracted already
val part1 = 0xdc00 + (cp & 0x3ff)
new String(Array(part0.toChar, part1.toChar))
}
}
forAll(cps) { (cp: Int)
val utf16 = codePointAsString(cp)
decodeUtf8(encodeUtf8(utf16)) === utf16
}
}
}
def encodeUtf8(str: String): ByteString =
Source(str.map(ch new String(Array(ch)))) // chunk in smallest chunks possible
.transform(() new Utf8Encoder)
.runFold(ByteString.empty)(_ ++ _).awaitResult(1.second)
def decodeUtf8(bytes: ByteString): String = {
val builder = new StringBuilder
val decoder = Utf8Decoder.create()
bytes
.map(b ByteString(b)) // chunk in smallest chunks possible
.foreach { bs
builder append decoder.decode(bs, endOfInput = false).get
}
builder append decoder.decode(ByteString.empty, endOfInput = true).get
builder.toString()
}
}

View file

@ -0,0 +1,20 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.engine.ws
import akka.actor.ActorSystem
import akka.stream.ActorFlowMaterializer
import com.typesafe.config.{ ConfigFactory, Config }
import org.scalatest.{ Suite, BeforeAndAfterAll }
trait WithMaterializerSpec extends BeforeAndAfterAll { _: Suite
lazy val testConf: Config = ConfigFactory.parseString("""
akka.event-handlers = ["akka.testkit.TestEventListener"]
akka.loglevel = WARNING""")
implicit lazy val system = ActorSystem(getClass.getSimpleName, testConf)
implicit lazy val materializer = ActorFlowMaterializer()
override def afterAll() = system.shutdown()
}

View file

@ -369,6 +369,44 @@ class HttpHeaderSpec extends FreeSpec with Matchers {
"Range: bytes=0-1, 2-3, -99" =!= Range(ByteRange(0, 1), ByteRange(2, 3), ByteRange.suffix(99))
}
"Sec-WebSocket-Accept" in {
"Sec-WebSocket-Accept: ZGgwOTM0Z2owcmViamRvcGcK" =!= `Sec-WebSocket-Accept`("ZGgwOTM0Z2owcmViamRvcGcK")
}
"Sec-WebSocket-Extensions" in {
"Sec-WebSocket-Extensions: abc" =!=
`Sec-WebSocket-Extensions`(Vector(WebsocketExtension("abc")))
"Sec-WebSocket-Extensions: abc, def" =!=
`Sec-WebSocket-Extensions`(Vector(WebsocketExtension("abc"), WebsocketExtension("def")))
"Sec-WebSocket-Extensions: abc; param=2; use_y, def" =!=
`Sec-WebSocket-Extensions`(Vector(WebsocketExtension("abc", Map("param" -> "2", "use_y" -> "")), WebsocketExtension("def")))
"Sec-WebSocket-Extensions: abc; param=\",xyz\", def" =!=
`Sec-WebSocket-Extensions`(Vector(WebsocketExtension("abc", Map("param" -> ",xyz")), WebsocketExtension("def")))
// real examples from https://tools.ietf.org/html/draft-ietf-hybi-permessage-compression-19
"Sec-WebSocket-Extensions: permessage-deflate" =!=
`Sec-WebSocket-Extensions`(Vector(WebsocketExtension("permessage-deflate")))
"Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits; server_max_window_bits=10" =!=
`Sec-WebSocket-Extensions`(Vector(WebsocketExtension("permessage-deflate", Map("client_max_window_bits" -> "", "server_max_window_bits" -> "10"))))
"Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits; server_max_window_bits=10, permessage-deflate; client_max_window_bits" =!=
`Sec-WebSocket-Extensions`(Vector(
WebsocketExtension("permessage-deflate", Map("client_max_window_bits" -> "", "server_max_window_bits" -> "10")),
WebsocketExtension("permessage-deflate", Map("client_max_window_bits" -> ""))))
}
"Sec-WebSocket-Key" in {
"Sec-WebSocket-Key: c2Zxb3JpbmgyMzA5dGpoMDIzOWdlcm5vZ2luCg==" =!= `Sec-WebSocket-Key`("c2Zxb3JpbmgyMzA5dGpoMDIzOWdlcm5vZ2luCg==")
}
"Sec-WebSocket-Protocol" in {
"Sec-WebSocket-Protocol: chat" =!= `Sec-WebSocket-Protocol`(Vector("chat"))
"Sec-WebSocket-Protocol: chat, superchat" =!= `Sec-WebSocket-Protocol`(Vector("chat", "superchat"))
}
"Sec-WebSocket-Version" in {
"Sec-WebSocket-Version: 25" =!= `Sec-WebSocket-Version`(Vector(25))
"Sec-WebSocket-Version: 13, 8, 7" =!= `Sec-WebSocket-Version`(Vector(13, 8, 7))
"Sec-WebSocket-Version: 255" =!= `Sec-WebSocket-Version`(Vector(255))
"Sec-WebSocket-Version: 0" =!= `Sec-WebSocket-Version`(Vector(0))
}
"Set-Cookie" in {
"Set-Cookie: SID=\"31d4d96e407aad42\"" =!=
`Set-Cookie`(HttpCookie("SID", "31d4d96e407aad42")).renderedTo("SID=31d4d96e407aad42")
@ -445,6 +483,13 @@ class HttpHeaderSpec extends FreeSpec with Matchers {
.renderedTo("PLAY_FLASH=; Expires=Sun, 07 Dec 2014 22:48:47 GMT; Path=/; HttpOnly")
}
"Upgrade" in {
"Upgrade: abc, def" =!= Upgrade(Vector(UpgradeProtocol("abc"), UpgradeProtocol("def")))
"Upgrade: abc, def/38.1" =!= Upgrade(Vector(UpgradeProtocol("abc"), UpgradeProtocol("def", Some("38.1"))))
"Upgrade: websocket" =!= Upgrade(Vector(UpgradeProtocol("websocket")))
}
"User-Agent" in {
"User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_3) AppleWebKit/537.31" =!=
`User-Agent`(ProductVersion("Mozilla", "5.0", "Macintosh; Intel Mac OS X 10_8_3"), ProductVersion("AppleWebKit", "537.31"))

View file

@ -0,0 +1,49 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.server.directives
import akka.http.engine.ws.InternalCustomHeader
import akka.http.model
import akka.http.model.headers.{ Connection, UpgradeProtocol, Upgrade }
import akka.http.model.{ HttpRequest, StatusCodes, HttpResponse }
import akka.http.model.ws.{ Message, UpgradeToWebsocket }
import akka.http.server.{ Route, RoutingSpec }
import akka.http.util.Rendering
import akka.stream.FlowMaterializer
import akka.stream.scaladsl.Flow
import scala.collection.immutable.Seq
class WebsocketDirectivesSpec extends RoutingSpec {
"the handleWebsocketMessages directive" should {
"handle websocket requests" in {
Get("http://localhost/") ~> Upgrade(List(UpgradeProtocol("websocket"))) ~>
emulateHttpCore ~> Route.seal(handleWebsocketMessages(Flow[Message])) ~>
check {
status shouldEqual StatusCodes.SwitchingProtocols
}
}
"reject non-websocket requests" in {
Get("http://localhost/") ~> emulateHttpCore ~> Route.seal(handleWebsocketMessages(Flow[Message])) ~> check {
status shouldEqual StatusCodes.BadRequest
responseAs[String] shouldEqual "Expected Websocket Upgrade request"
}
}
}
/** Only checks for upgrade header and then adds UpgradeToWebsocket mock header */
def emulateHttpCore(req: HttpRequest): HttpRequest =
req.header[Upgrade] match {
case Some(upgrade) if upgrade.hasWebsocket req.copy(headers = req.headers :+ upgradeToWebsocketHeaderMock)
case _ req
}
def upgradeToWebsocketHeaderMock: UpgradeToWebsocket =
new InternalCustomHeader("UpgradeToWebsocketMock") with UpgradeToWebsocket {
def requestedProtocols: Seq[String] = Nil
def handleMessages(handlerFlow: Flow[Message, Message, Any], subprotocol: Option[String])(implicit mat: FlowMaterializer): HttpResponse =
HttpResponse(StatusCodes.SwitchingProtocols)
}
}

View file

@ -88,10 +88,10 @@ class GzipDecompressor(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault)
if (readByte() != 8) fail("Unsupported GZIP compression method") // check compression method
val flags = readByte()
skip(6) // skip MTIME, XFL and OS fields
if ((flags & 4) > 0) skip(readShort()) // skip optional extra fields
if ((flags & 4) > 0) skip(readShortLE()) // skip optional extra fields
if ((flags & 8) > 0) skipZeroTerminatedString() // skip optional file name
if ((flags & 16) > 0) skipZeroTerminatedString() // skip optional file comment
if ((flags & 2) > 0 && crc16(fromStartToHere) != readShort()) fail("Corrupt GZIP header")
if ((flags & 2) > 0 && crc16(fromStartToHere) != readShortLE()) fail("Corrupt GZIP header")
inflater.reset()
crc32.reset()
@ -107,8 +107,8 @@ class GzipDecompressor(maxBytesPerChunk: Int = Decoder.MaxBytesPerChunkDefault)
def read(reader: ByteReader, ctx: Context[ByteString]): SyncDirective = {
import reader._
if (readInt() != crc32.getValue.toInt) fail("Corrupt data (CRC32 checksum error)")
if (readInt() != inflater.getBytesWritten.toInt /* truncated to 32bit */ ) fail("Corrupt GZIP trailer ISIZE")
if (readIntLE() != crc32.getValue.toInt) fail("Corrupt data (CRC32 checksum error)")
if (readIntLE() != inflater.getBytesWritten.toInt /* truncated to 32bit */ ) fail("Corrupt GZIP trailer ISIZE")
becomeWithRemaining(Initial, remainingData, ctx)
}

View file

@ -32,5 +32,6 @@ trait Directives extends RouteConcatenation
with RouteDirectives
with SchemeDirectives
with SecurityDirectives
with WebsocketDirectives
object Directives extends Directives

View file

@ -163,6 +163,11 @@ case object AuthorizationFailedRejection extends Rejection
*/
case class MissingCookieRejection(cookieName: String) extends Rejection
/**
* Rejection created when a websocket request was expected but none was found.
*/
case object ExpectedWebsocketRequestRejection extends Rejection
/**
* Rejection created by the `validation` directive as well as for `IllegalArgumentExceptions`
* thrown by domain model constructors (e.g. via `require`).

View file

@ -206,6 +206,7 @@ object RejectionHandler {
val supported = rejections.map(_.supported.value).mkString(" or ")
complete(BadRequest, "The request's Content-Encoding is not supported. Expected:\n" + supported)
}
.handle { case ExpectedWebsocketRequestRejection complete(BadRequest, "Expected Websocket Upgrade request") }
.handle { case ValidationRejection(msg, _) complete(BadRequest, msg) }
.handle { case x sys.error("Unhandled rejection: " + x) }
.handleNotFound { complete(NotFound, "The requested resource could not be found.") }

View file

@ -0,0 +1,27 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.server
package directives
import akka.http.model.ws.{ UpgradeToWebsocket, Message }
import akka.stream.scaladsl.Flow
trait WebsocketDirectives {
import BasicDirectives._
import RouteDirectives._
import HeaderDirectives._
/**
* Handles websocket requests with the given handler and rejects other requests with a
* [[ExpectedWebsocketRequestRejection]].
*/
def handleWebsocketMessages(handler: Flow[Message, Message, Any]): Route =
extractFlowMaterializer { implicit mat
optionalHeaderValueByType[UpgradeToWebsocket]() {
case Some(upgrade) complete(upgrade.handleMessages(handler))
case None reject(ExpectedWebsocketRequestRejection)
}
}
}

View file

@ -300,6 +300,11 @@ object Flow extends FlowApply {
*/
def wrap[I, O, M](g: Graph[FlowShape[I, O], M]): Flow[I, O, M] = new Flow(g.module)
/**
* Helper to create `Flow` from a pair of sink and source.
*/
def wrap[I, O, M1, M2, M](sink: Sink[I, M1], source: Source[O, M2])(f: (M1, M2) M): Flow[I, O, M] =
Flow(sink, source)(f) { implicit b (in, out) (in.inlet, out.outlet) }
}
/**