=htc #16887 implement high-level server-side Websocket API

This commit is contained in:
Johannes Rudolph 2015-04-21 15:21:22 +02:00
parent 880733eb3d
commit 6dafa445de
6 changed files with 1585 additions and 0 deletions

View file

@ -0,0 +1,200 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.engine.ws
import akka.stream.scaladsl.Flow
import akka.stream.stage.{ TerminationDirective, SyncDirective, Context, StatefulStage }
import akka.util.ByteString
import Protocol.Opcode
import scala.util.control.NonFatal
/**
* The frame handler validates frames, multiplexes data to the user handler or to the bypass and
* UTF-8 decodes text frames.
*
* INTERNAL API
*/
private[http] object FrameHandler {
def create(server: Boolean): Flow[FrameEvent, Either[BypassEvent, MessagePart], Unit] =
Flow[FrameEvent].transform(() new HandlerStage(server))
class HandlerStage(server: Boolean) extends StatefulStage[FrameEvent, Either[BypassEvent, MessagePart]] {
type Ctx = Context[Either[BypassEvent, MessagePart]]
def initial: State = Idle
object Idle extends StateWithControlFrameHandling {
def handleRegularFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective =
(start.header.opcode, start.isFullMessage) match {
case (Opcode.Binary, true) publishMessagePart(BinaryMessagePart(start.data, last = true))
case (Opcode.Binary, false) becomeAndHandleWith(new CollectingBinaryMessage, start)
case (Opcode.Text, _) becomeAndHandleWith(new CollectingTextMessage, start)
case x protocolError()
}
}
class CollectingBinaryMessage extends CollectingMessageFrame(Opcode.Binary) {
def createMessagePart(data: ByteString, last: Boolean): MessageDataPart = BinaryMessagePart(data, last)
}
class CollectingTextMessage extends CollectingMessageFrame(Opcode.Text) {
val decoder = Utf8Decoder.create()
def createMessagePart(data: ByteString, last: Boolean): MessageDataPart =
TextMessagePart(decoder.decode(data, endOfInput = last).get, last)
}
abstract class CollectingMessageFrame(expectedOpcode: Opcode) extends StateWithControlFrameHandling {
var expectFirstHeader = true
var finSeen = false
def createMessagePart(data: ByteString, last: Boolean): MessageDataPart
def handleRegularFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective = {
if ((expectFirstHeader && start.header.opcode == expectedOpcode) // first opcode must be the expected
|| start.header.opcode == Opcode.Continuation) { // further ones continuations
expectFirstHeader = false
if (start.header.fin) finSeen = true
publish(start)
} else protocolError()
}
override def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective = publish(data)
private def publish(part: FrameEvent)(implicit ctx: Ctx): SyncDirective =
try publishMessagePart(createMessagePart(part.data, last = finSeen && part.lastPart))
catch {
case NonFatal(e) closeWithCode(Protocol.CloseCodes.InconsistentData)
}
}
class CollectingControlFrame(opcode: Opcode, _data: ByteString, nextState: State) extends InFrameState {
var data = _data
def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective = {
this.data ++= data.data
if (data.lastPart) handleControlFrame(opcode, this.data, nextState)
else ctx.pull()
}
}
object Closed extends State {
def onPush(elem: FrameEvent, ctx: Ctx): SyncDirective =
ctx.pull() // ignore
}
def becomeAndHandleWith(newState: State, part: FrameEvent)(implicit ctx: Ctx): SyncDirective = {
become(newState)
current.onPush(part, ctx)
}
/** Returns a SyncDirective if it handled the message */
def validateHeader(header: FrameHeader)(implicit ctx: Ctx): Option[SyncDirective] = header match {
case h: FrameHeader if h.mask.isDefined && !server Some(protocolError())
case h: FrameHeader if h.rsv1 || h.rsv2 || h.rsv3 Some(protocolError())
case FrameHeader(op, _, length, fin, _, _, _) if op.isControl && (length > 125 || !fin) Some(protocolError())
case _ None
}
def handleControlFrame(opcode: Opcode, data: ByteString, nextState: State)(implicit ctx: Ctx): SyncDirective = {
become(nextState)
opcode match {
case Opcode.Ping publishDirectResponse(FrameEvent.fullFrame(Opcode.Pong, None, data, fin = true))
case Opcode.Pong
// ignore unsolicited Pong frame
ctx.pull()
case Opcode.Close
val closeCode = FrameEventParser.parseCloseCode(data)
emit(Iterator(Left(PeerClosed(closeCode)), Right(PeerClosed(closeCode))), ctx, WaitForPeerTcpClose)
case Opcode.Other(o) closeWithCode(Protocol.CloseCodes.ProtocolError, "Unsupported opcode")
}
}
private def collectControlFrame(start: FrameStart, nextState: State)(implicit ctx: Ctx): SyncDirective = {
assert(!start.isFullMessage)
become(new CollectingControlFrame(start.header.opcode, start.data, nextState))
ctx.pull()
}
private def publishMessagePart(part: MessageDataPart)(implicit ctx: Ctx): SyncDirective =
if (part.last) emit(Iterator(Right(part), Right(MessageEnd)), ctx, Idle)
else ctx.push(Right(part))
private def publishDirectResponse(frame: FrameStart)(implicit ctx: Ctx): SyncDirective =
ctx.push(Left(DirectAnswer(frame)))
private def protocolError(reason: String = "")(implicit ctx: Ctx): SyncDirective =
closeWithCode(Protocol.CloseCodes.ProtocolError, reason)
private def closeWithCode(closeCode: Int, reason: String = "", cause: Throwable = null)(implicit ctx: Ctx): SyncDirective =
emit(
Iterator(
Left(ActivelyCloseWithCode(Some(closeCode), reason)),
Right(ActivelyCloseWithCode(Some(closeCode), reason))), ctx, CloseAfterPeerClosed)
object CloseAfterPeerClosed extends State {
def onPush(elem: FrameEvent, ctx: Context[Either[BypassEvent, MessagePart]]): SyncDirective =
elem match {
case FrameStart(FrameHeader(Opcode.Close, _, length, _, _, _, _), data)
become(WaitForPeerTcpClose)
ctx.push(Left(PeerClosed(FrameEventParser.parseCloseCode(data))))
case _ ctx.pull() // ignore all other data
}
}
object WaitForPeerTcpClose extends State {
def onPush(elem: FrameEvent, ctx: Context[Either[BypassEvent, MessagePart]]): SyncDirective =
ctx.pull() // ignore
}
abstract class StateWithControlFrameHandling extends BetweenFrameState {
def handleRegularFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective
def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective =
validateHeader(start.header).getOrElse {
if (start.header.opcode.isControl)
if (start.isFullMessage) handleControlFrame(start.header.opcode, start.data, this)
else collectControlFrame(start, this)
else handleRegularFrameStart(start)
}
}
abstract class BetweenFrameState extends ImplicitContextState {
def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective =
throw new IllegalStateException("Expected FrameStart")
}
abstract class InFrameState extends ImplicitContextState {
def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective =
throw new IllegalStateException("Expected FrameData")
}
abstract class ImplicitContextState extends State {
def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective
def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective
def onPush(part: FrameEvent, ctx: Ctx): SyncDirective =
part match {
case data: FrameData handleFrameData(data)(ctx)
case start: FrameStart handleFrameStart(start)(ctx)
}
}
}
sealed trait MessagePart {
def isMessageEnd: Boolean
}
sealed trait MessageDataPart extends MessagePart {
def isMessageEnd = false
def last: Boolean
}
final case class TextMessagePart(data: String, last: Boolean) extends MessageDataPart
final case class BinaryMessagePart(data: ByteString, last: Boolean) extends MessageDataPart
case object MessageEnd extends MessagePart {
def isMessageEnd: Boolean = true
}
final case class PeerClosed(code: Option[Int], reason: String = "") extends MessagePart with BypassEvent {
def isMessageEnd: Boolean = true
}
sealed trait BypassEvent
final case class DirectAnswer(frame: FrameStart) extends BypassEvent
final case class ActivelyCloseWithCode(code: Option[Int], reason: String = "") extends MessagePart with BypassEvent {
def isMessageEnd: Boolean = true
}
case object UserHandlerCompleted extends BypassEvent
case class UserHandlerErredOut(cause: Throwable) extends BypassEvent
}

View file

@ -0,0 +1,132 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.engine.ws
import scala.concurrent.duration.FiniteDuration
import akka.stream.stage._
import akka.http.util.Timestamp
import FrameHandler.{ UserHandlerCompleted, ActivelyCloseWithCode, PeerClosed, DirectAnswer }
import Websocket.Tick
/**
* Implements the transport connection close handling at the end of the pipeline.
*
* INTERNAL API
*/
private[http] class FrameOutHandler(serverSide: Boolean, _closeTimeout: FiniteDuration) extends StatefulStage[AnyRef, FrameStart] {
def initial: StageState[AnyRef, FrameStart] = Idle
def closeTimeout: Timestamp = Timestamp.now + _closeTimeout
object Idle extends CompletionHandlingState {
def onPush(elem: AnyRef, ctx: Context[FrameStart]): SyncDirective = elem match {
case start: FrameStart ctx.push(start)
case DirectAnswer(frame) ctx.push(frame)
case PeerClosed(code, reason) if !code.exists(Protocol.CloseCodes.isError)
// let user complete it, FIXME: maybe make configurable? immediately, or timeout
become(new WaitingForUserHandlerClosed(FrameEvent.closeFrame(code.getOrElse(Protocol.CloseCodes.Regular), reason)))
ctx.pull()
case PeerClosed(code, reason)
val closeFrame = FrameEvent.closeFrame(code.getOrElse(Protocol.CloseCodes.Regular), reason)
if (serverSide) ctx.pushAndFinish(closeFrame)
else {
become(new WaitingForTransportClose)
ctx.push(closeFrame)
}
case ActivelyCloseWithCode(code, reason)
val closeFrame = FrameEvent.closeFrame(code.getOrElse(Protocol.CloseCodes.Regular), reason)
become(new WaitingForPeerCloseFrame())
ctx.push(closeFrame)
case UserHandlerCompleted
become(new WaitingForPeerCloseFrame())
ctx.push(FrameEvent.closeFrame(Protocol.CloseCodes.Regular))
case Tick ctx.pull() // ignore
}
def onComplete(ctx: Context[FrameStart]): TerminationDirective = {
become(new SendOutCloseFrameAndComplete(FrameEvent.closeFrame(Protocol.CloseCodes.Regular)))
ctx.absorbTermination()
}
}
/**
* peer has closed, we want to wait for user handler to close as well
*/
class WaitingForUserHandlerClosed(closeFrame: FrameStart) extends CompletionHandlingState {
def onPush(elem: AnyRef, ctx: Context[FrameStart]): SyncDirective = elem match {
case UserHandlerCompleted
if (serverSide) ctx.pushAndFinish(closeFrame)
else {
become(new WaitingForTransportClose())
ctx.push(closeFrame)
}
case start: FrameStart ctx.push(start)
case _ ctx.pull() // ignore
}
def onComplete(ctx: Context[FrameStart]): TerminationDirective =
ctx.fail(new IllegalStateException("Mustn't complete before user has completed"))
}
/**
* we have sent out close frame and wait for peer to sent its close frame
*/
class WaitingForPeerCloseFrame(timeout: Timestamp = closeTimeout) extends CompletionHandlingState {
def onPush(elem: AnyRef, ctx: Context[FrameStart]): SyncDirective = elem match {
case Tick
if (timeout.isPast) ctx.finish()
else ctx.pull()
case PeerClosed(code, reason)
if (serverSide) ctx.finish()
else {
become(new WaitingForTransportClose())
ctx.pull()
}
case _ ctx.pull() // ignore
}
def onComplete(ctx: Context[FrameStart]): TerminationDirective = ctx.finish()
}
/**
* Both side have sent their close frames, server should close the connection first
*/
class WaitingForTransportClose(timeout: Timestamp = closeTimeout) extends CompletionHandlingState {
def onPush(elem: AnyRef, ctx: Context[FrameStart]): SyncDirective = elem match {
case Tick
if (timeout.isPast) ctx.finish()
else ctx.pull()
case _ ctx.pull() // ignore
}
def onComplete(ctx: Context[FrameStart]): TerminationDirective = ctx.finish()
}
/** If upstream has already failed we just wait to be able to deliver our close frame and complete */
class SendOutCloseFrameAndComplete(closeFrame: FrameStart) extends CompletionHandlingState {
def onPush(elem: AnyRef, ctx: Context[FrameStart]): SyncDirective =
ctx.fail(new IllegalStateException("Didn't expect push after completion"))
override def onPull(ctx: Context[FrameStart]): SyncDirective =
ctx.pushAndFinish(closeFrame)
def onComplete(ctx: Context[FrameStart]): TerminationDirective =
ctx.absorbTermination()
}
trait CompletionHandlingState extends State {
def onComplete(ctx: Context[FrameStart]): TerminationDirective
}
override def onUpstreamFinish(ctx: Context[FrameStart]): TerminationDirective =
current.asInstanceOf[CompletionHandlingState].onComplete(ctx)
override def onUpstreamFailure(cause: scala.Throwable, ctx: Context[FrameStart]): TerminationDirective = cause match {
case p: ProtocolException
become(new SendOutCloseFrameAndComplete(FrameEvent.closeFrame(Protocol.CloseCodes.ProtocolError)))
ctx.absorbTermination()
case _ super.onUpstreamFailure(cause, ctx)
}
}

View file

@ -0,0 +1,71 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.engine.ws
import akka.stream.scaladsl.Flow
import akka.stream.stage.{ SyncDirective, Context, StageState, StatefulStage }
import scala.util.Random
/**
* Implements Websocket Frame masking.
*
* INTERNAL API
*/
private[http] object Masking {
def maskIf(condition: Boolean, maskRandom: () Random): Flow[FrameEvent, FrameEvent, Unit] =
if (condition) Flow[FrameEvent].transform(() new Masking(maskRandom())) // new random per materialization
else Flow[FrameEvent]
def unmaskIf(condition: Boolean): Flow[FrameEvent, FrameEvent, Unit] =
if (condition) Flow[FrameEvent].transform(() new Unmasking())
else Flow[FrameEvent]
class Masking(random: Random) extends Masker {
def extractMask(header: FrameHeader): Int = random.nextInt()
def setNewMask(header: FrameHeader, mask: Int): FrameHeader = {
if (header.mask.isDefined) throw new ProtocolException("Frame mustn't already be masked")
header.copy(mask = Some(mask))
}
}
class Unmasking extends Masker {
def extractMask(header: FrameHeader): Int = header.mask match {
case Some(mask) mask
case None throw new ProtocolException("Frame wasn't masked")
}
def setNewMask(header: FrameHeader, mask: Int): FrameHeader = header.copy(mask = None)
}
/** Implements both masking and unmasking which is mostly symmetric (because of XOR) */
abstract class Masker extends StatefulStage[FrameEvent, FrameEvent] {
def extractMask(header: FrameHeader): Int
def setNewMask(header: FrameHeader, mask: Int): FrameHeader
def initial: State = Idle
object Idle extends State {
def onPush(part: FrameEvent, ctx: Context[FrameEvent]): SyncDirective =
part match {
case start @ FrameStart(header, data)
if (header.length == 0) ctx.push(part)
else {
val mask = extractMask(header)
become(new Running(mask))
current.onPush(start.copy(header = setNewMask(header, mask)), ctx)
}
}
}
class Running(initialMask: Int) extends State {
var mask = initialMask
def onPush(part: FrameEvent, ctx: Context[FrameEvent]): SyncDirective = {
if (part.lastPart) become(Idle)
val (masked, newMask) = FrameEventParser.mask(part.data, mask)
mask = newMask
ctx.push(part.withData(data = masked))
}
}
}
}

View file

@ -0,0 +1,37 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.engine.ws
import akka.util.ByteString
import akka.stream.scaladsl.{ FlattenStrategy, Source, Flow }
import Protocol.Opcode
import akka.http.model.ws._
/**
* Renders messages to full frames.
*
* INTERNAL API
*/
private[http] object MessageToFrameRenderer {
def create(serverSide: Boolean): Flow[Message, FrameStart, Unit] = {
def strictFrames(opcode: Opcode, data: ByteString): Source[FrameStart, _] =
// FIXME: fragment?
Source.single(FrameEvent.fullFrame(opcode, None, data, fin = true))
def streamedFrames(opcode: Opcode, data: Source[ByteString, _]): Source[FrameStart, _] =
Source.single(FrameEvent.empty(opcode, fin = false)) ++
data.map(FrameEvent.fullFrame(Opcode.Continuation, None, _, fin = false)) ++
Source.single(FrameEvent.emptyLastContinuationFrame)
Flow[Message]
.map {
case BinaryMessage.Strict(data) strictFrames(Opcode.Binary, data)
case BinaryMessage.Streamed(data) streamedFrames(Opcode.Binary, data)
case TextMessage.Strict(text) strictFrames(Opcode.Text, ByteString(text, "UTF-8"))
case TextMessage.Streamed(text) streamedFrames(Opcode.Text, text.transform(() new Utf8Encoder))
}.flatten(FlattenStrategy.Concat())
}
}

View file

@ -0,0 +1,179 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.engine.ws
import java.security.SecureRandom
import scala.concurrent.duration._
import akka.stream.{ OperationAttributes, FanOutShape2, FanInShape3, Inlet }
import akka.stream.scaladsl._
import akka.stream.stage._
import FlexiRoute.{ DemandFrom, DemandFromAny, RouteLogic }
import FlexiMerge.MergeLogic
import akka.http.util._
import akka.http.model.ws._
/**
* INTERNAL API
*/
private[http] object Websocket {
import FrameHandler._
def handleMessages[T](messageHandler: Flow[Message, Message, T],
serverSide: Boolean = true,
closeTimeout: FiniteDuration = 3.seconds): Flow[FrameEvent, FrameEvent, Unit] = {
/** Completes this branch of the flow if no more messages are expected and converts close codes into errors */
class PrepareForUserHandler extends PushStage[MessagePart, MessagePart] {
def onPush(elem: MessagePart, ctx: Context[MessagePart]): SyncDirective = elem match {
case PeerClosed(code, reason)
if (code.exists(Protocol.CloseCodes.isError)) ctx.fail(new ProtocolException(s"Peer closed connection with code $code"))
else ctx.finish()
case ActivelyCloseWithCode(code, reason)
if (code.exists(Protocol.CloseCodes.isError)) ctx.fail(new ProtocolException(s"Closing connection with error code $code"))
else ctx.finish()
case x ctx.push(x)
}
}
/** Collects user-level API messages from MessageDataParts */
val collectMessage: Flow[Source[MessageDataPart, Unit], Message, Unit] =
Flow[Source[MessageDataPart, Unit]]
.headAndTail
.map {
case (TextMessagePart(text, true), remaining)
TextMessage.Strict(text)
case (first @ TextMessagePart(text, false), remaining)
TextMessage.Streamed(
(Source.single(first) ++ remaining)
.collect {
case t: TextMessagePart if t.data.nonEmpty t.data
})
case (BinaryMessagePart(data, true), remaining)
BinaryMessage.Strict(data)
case (first @ BinaryMessagePart(data, false), remaining)
BinaryMessage.Streamed(
(Source.single(first) ++ remaining)
.collect {
case t: BinaryMessagePart if t.data.nonEmpty t.data
})
}
/** Lifts onComplete and onError into events to be processed in the FlexiMerge */
class LiftCompletions extends StatefulStage[FrameStart, AnyRef] {
def initial: StageState[FrameStart, AnyRef] = SteadyState
object SteadyState extends State {
def onPush(elem: FrameStart, ctx: Context[AnyRef]): SyncDirective = ctx.push(elem)
}
class CompleteWith(last: AnyRef) extends State {
def onPush(elem: FrameStart, ctx: Context[AnyRef]): SyncDirective =
ctx.fail(new IllegalStateException("No push expected"))
override def onPull(ctx: Context[AnyRef]): SyncDirective = ctx.pushAndFinish(last)
}
override def onUpstreamFinish(ctx: Context[AnyRef]): TerminationDirective = {
become(new CompleteWith(UserHandlerCompleted))
ctx.absorbTermination()
}
override def onUpstreamFailure(cause: Throwable, ctx: Context[AnyRef]): TerminationDirective = {
become(new CompleteWith(UserHandlerErredOut(cause)))
ctx.absorbTermination()
}
}
lazy val userFlow =
Flow[MessagePart]
.transform(() new PrepareForUserHandler)
.splitWhen(_.isMessageEnd) // FIXME using splitAfter from #16885 would simplify protocol a lot
.map(_.collect {
case m: MessageDataPart m
})
.via(collectMessage)
.via(messageHandler)
.via(MessageToFrameRenderer.create(serverSide))
.transform(() new LiftCompletions)
/**
* Distributes output from the FrameHandler into bypass and userFlow.
*/
object BypassRouter
extends FlexiRoute[Either[BypassEvent, MessagePart], FanOutShape2[Either[BypassEvent, MessagePart], BypassEvent, MessagePart]](new FanOutShape2("bypassRouter"), OperationAttributes.name("bypassRouter")) {
def createRouteLogic(s: FanOutShape2[Either[BypassEvent, MessagePart], BypassEvent, MessagePart]): RouteLogic[Either[BypassEvent, MessagePart]] =
new RouteLogic[Either[BypassEvent, MessagePart]] {
def initialState: State[_] = State(DemandFromAny(s)) { (ctx, out, ev)
ev match {
case Left(_)
State(DemandFrom(s.out0)) { (ctx, _, ev) // FIXME: #17004
ctx.emit(s.out0)(ev.left.get)
initialState
}
case Right(_)
State(DemandFrom(s.out1)) { (ctx, _, ev)
ctx.emit(s.out1)(ev.right.get)
initialState
}
}
}
override def initialCompletionHandling: CompletionHandling = super.initialCompletionHandling.copy(
onDownstreamFinish = { (ctx, out)
if (out == s.out0) ctx.finish()
SameState
})
}
}
/**
* Merges bypass, user flow and tick source for consumption in the FrameOutHandler.
*/
object BypassMerge extends FlexiMerge[AnyRef, FanInShape3[BypassEvent, AnyRef, Tick.type, AnyRef]](new FanInShape3("bypassMerge"), OperationAttributes.name("bypassMerge")) {
def createMergeLogic(s: FanInShape3[BypassEvent, AnyRef, Tick.type, AnyRef]): MergeLogic[AnyRef] =
new MergeLogic[AnyRef] {
def initialState: State[_] = Idle
lazy val Idle = State[AnyRef](FlexiMerge.ReadAny(s.in0.asInstanceOf[Inlet[AnyRef]], s.in1.asInstanceOf[Inlet[AnyRef]], s.in2.asInstanceOf[Inlet[AnyRef]])) { (ctx, in, elem)
ctx.emit(elem)
SameState
}
override def initialCompletionHandling: CompletionHandling =
CompletionHandling(
onUpstreamFinish = { (ctx, in)
if (in == s.in0) ctx.finish()
SameState
},
onUpstreamFailure = { (ctx, in, cause)
if (in == s.in0) ctx.fail(cause)
SameState
})
}
}
lazy val bypassAndUserHandler: Flow[Either[BypassEvent, MessagePart], AnyRef, Unit] =
Flow(BypassRouter, Source(closeTimeout, closeTimeout, Tick), BypassMerge)((_, _, _) ()) { implicit b
(split, tick, merge)
import FlowGraph.Implicits._
split.out0 ~> merge.in0
split.out1 ~> userFlow ~> merge.in1
tick.outlet ~> merge.in2
(split.in, merge.out)
}
Flow[FrameEvent]
.via(Masking.unmaskIf(serverSide))
.via(FrameHandler.create(server = serverSide))
.mapConcat(x x :: x :: Nil) // FIXME: #17004
.via(bypassAndUserHandler)
.transform(() new FrameOutHandler(serverSide, closeTimeout))
.via(Masking.maskIf(!serverSide, () new SecureRandom()))
}
object Tick
case object SwitchToWebsocketToken
}

View file

@ -0,0 +1,966 @@
/*
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.engine.ws
import scala.annotation.tailrec
import scala.concurrent.duration._
import akka.stream.FlowShape
import akka.stream.scaladsl._
import akka.stream.testkit.StreamTestKit
import akka.util.ByteString
import org.scalatest.{ Matchers, FreeSpec }
import scala.util.Random
import akka.http.model.ws._
import Protocol.Opcode
class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
"The Websocket implementation should" - {
"collect messages from frames" - {
"for binary messages" - {
"for an empty message" in new ClientTestSetup {
val input = frameHeader(Opcode.Binary, 0, fin = true)
pushInput(input)
expectMessage(BinaryMessage.Strict(ByteString.empty))
}
"for one complete, strict, single frame message" in new ClientTestSetup {
val data = ByteString("abcdef", "ASCII")
val input = frameHeader(Opcode.Binary, 6, fin = true) ++ data
pushInput(input)
expectMessage(BinaryMessage.Strict(data))
}
"for a partial frame" in new ClientTestSetup {
val data1 = ByteString("abc", "ASCII")
val header = frameHeader(Opcode.Binary, 6, fin = true)
pushInput(header ++ data1)
val BinaryMessage.Streamed(dataSource) = expectMessage()
val sub = StreamTestKit.SubscriberProbe[ByteString]
dataSource.runWith(Sink(sub))
val s = sub.expectSubscription()
s.request(2)
sub.expectNext(data1)
}
"for a frame split up into parts" in new ClientTestSetup {
val data1 = ByteString("abc", "ASCII")
val header = frameHeader(Opcode.Binary, 6, fin = true)
pushInput(header)
val BinaryMessage.Streamed(dataSource) = expectMessage()
val sub = StreamTestKit.SubscriberProbe[ByteString]
dataSource.runWith(Sink(sub))
val s = sub.expectSubscription()
s.request(2)
pushInput(data1)
sub.expectNext(data1)
val data2 = ByteString("def", "ASCII")
pushInput(data2)
sub.expectNext(data2)
sub.expectComplete()
}
"for a message split into several frames" in new ClientTestSetup {
val data1 = ByteString("abc", "ASCII")
val header1 = frameHeader(Opcode.Binary, 3, fin = false)
pushInput(header1 ++ data1)
val BinaryMessage.Streamed(dataSource) = expectMessage()
val sub = StreamTestKit.SubscriberProbe[ByteString]
dataSource.runWith(Sink(sub))
val s = sub.expectSubscription()
s.request(2)
sub.expectNext(data1)
val header2 = frameHeader(Opcode.Continuation, 4, fin = true)
val data2 = ByteString("defg", "ASCII")
pushInput(header2 ++ data2)
sub.expectNext(data2)
sub.expectComplete()
}
"for several messages" in new ClientTestSetup {
val data1 = ByteString("abc", "ASCII")
val header1 = frameHeader(Opcode.Binary, 3, fin = false)
pushInput(header1 ++ data1)
val BinaryMessage.Streamed(dataSource) = expectMessage()
val sub = StreamTestKit.SubscriberProbe[ByteString]
dataSource.runWith(Sink(sub))
val s = sub.expectSubscription()
s.request(2)
sub.expectNext(data1)
val header2 = frameHeader(Opcode.Continuation, 4, fin = true)
val header3 = frameHeader(Opcode.Binary, 2, fin = true)
val data2 = ByteString("defg", "ASCII")
val data3 = ByteString("h")
pushInput(header2 ++ data2 ++ header3 ++ data3)
sub.expectNext(data2)
sub.expectComplete()
val BinaryMessage.Streamed(dataSource2) = expectMessage()
val sub2 = StreamTestKit.SubscriberProbe[ByteString]
dataSource2.runWith(Sink(sub2))
val s2 = sub2.expectSubscription()
s2.request(2)
sub2.expectNext(data3)
val data4 = ByteString("i")
pushInput(data4)
sub2.expectNext(data4)
sub2.expectComplete()
}
"unmask masked input on the server side" in new ServerTestSetup {
val mask = Random.nextInt()
val (data, _) = maskedASCII("abcdef", mask)
val data1 = data.take(3)
val data2 = data.drop(3)
val header = frameHeader(Opcode.Binary, 6, fin = true, mask = Some(mask))
pushInput(header ++ data1)
val BinaryMessage.Streamed(dataSource) = expectMessage()
val sub = StreamTestKit.SubscriberProbe[ByteString]
dataSource.runWith(Sink(sub))
val s = sub.expectSubscription()
s.request(2)
sub.expectNext(ByteString("abc", "ASCII"))
pushInput(data2)
sub.expectNext(ByteString("def", "ASCII"))
sub.expectComplete()
}
}
"for text messages" - {
"empty message" in new ClientTestSetup {
val input = frameHeader(Opcode.Text, 0, fin = true)
pushInput(input)
expectMessage(TextMessage.Strict(""))
}
"decode complete, strict frame from utf8" in new ClientTestSetup {
val msg = "äbcdef€\uffff"
val data = ByteString(msg, "UTF-8")
val input = frameHeader(Opcode.Text, data.size, fin = true) ++ data
pushInput(input)
expectMessage(TextMessage.Strict(msg))
}
"decode utf8 as far as possible for partial frame" in new ClientTestSetup {
val msg = "bäcdef€"
val data = ByteString(msg, "UTF-8")
val data0 = data.slice(0, 2)
val data1 = data.slice(2, 5)
val data2 = data.slice(5, data.size)
val input = frameHeader(Opcode.Text, data.size, fin = true) ++ data0
pushInput(input)
val TextMessage.Streamed(parts) = expectMessage()
val sub = StreamTestKit.SubscriberProbe[String]
parts.runWith(Sink(sub))
val s = sub.expectSubscription()
s.request(4)
sub.expectNext("b")
pushInput(data1)
sub.expectNext("äcd")
}
"decode utf8 with code point split across frames" in new ClientTestSetup {
val msg = "äbcdef€"
val data = ByteString(msg, "UTF-8")
val data0 = data.slice(0, 1)
val data1 = data.slice(1, data.size)
val header0 = frameHeader(Opcode.Text, data0.size, fin = false)
pushInput(header0 ++ data0)
val TextMessage.Streamed(parts) = expectMessage()
val sub = StreamTestKit.SubscriberProbe[String]
parts.runWith(Sink(sub))
val s = sub.expectSubscription()
s.request(4)
sub.expectNoMsg(100.millis)
val header1 = frameHeader(Opcode.Continuation, data1.size, fin = true)
pushInput(header1 ++ data1)
sub.expectNext("äbcdef€")
}
"unmask masked input on the server side" in new ServerTestSetup {
val mask = Random.nextInt()
val (data, _) = maskedUTF8("äbcdef€", mask)
val data1 = data.take(3)
val data2 = data.drop(3)
val header = frameHeader(Opcode.Binary, data.size, fin = true, mask = Some(mask))
pushInput(header ++ data1)
val BinaryMessage.Streamed(dataSource) = expectMessage()
val sub = StreamTestKit.SubscriberProbe[ByteString]
dataSource.runWith(Sink(sub))
val s = sub.expectSubscription()
s.request(2)
sub.expectNext(ByteString("äb", "UTF-8"))
pushInput(data2)
sub.expectNext(ByteString("cdef€", "UTF-8"))
sub.expectComplete()
}
}
}
"render frames from messages" - {
"for binary messages" - {
"for a short strict message" in new ServerTestSetup {
val data = ByteString("abcdef", "ASCII")
val msg = BinaryMessage.Strict(data)
netOutSub.request(5)
pushMessage(msg)
expectFrameOnNetwork(Opcode.Binary, data, fin = true)
}
"for a strict message larger than configured maximum frame size" in pending
"for a streamed message" in new ServerTestSetup {
val data = ByteString("abcdefg", "ASCII")
val pub = StreamTestKit.PublisherProbe[ByteString]
val msg = BinaryMessage.Streamed(Source(pub))
netOutSub.request(6)
pushMessage(msg)
val sub = pub.expectSubscription()
expectFrameHeaderOnNetwork(Opcode.Binary, 0, fin = false)
val data1 = data.take(3)
val data2 = data.drop(3)
sub.sendNext(data1)
expectFrameOnNetwork(Opcode.Continuation, data1, fin = false)
sub.sendNext(data2)
expectFrameOnNetwork(Opcode.Continuation, data2, fin = false)
sub.sendComplete()
expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true)
}
"for a streamed message with a chunk being larger than configured maximum frame size" in pending
"and mask input on the client side" in new ClientTestSetup {
val data = ByteString("abcdefg", "ASCII")
val pub = StreamTestKit.PublisherProbe[ByteString]
val msg = BinaryMessage.Streamed(Source(pub))
netOutSub.request(7)
pushMessage(msg)
val sub = pub.expectSubscription()
expectFrameHeaderOnNetwork(Opcode.Binary, 0, fin = false)
val data1 = data.take(3)
val data2 = data.drop(3)
sub.sendNext(data1)
expectMaskedFrameOnNetwork(Opcode.Continuation, data1, fin = false)
sub.sendNext(data2)
expectMaskedFrameOnNetwork(Opcode.Continuation, data2, fin = false)
sub.sendComplete()
expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true)
}
}
"for text messages" - {
"for a short strict message" in new ServerTestSetup {
val text = "äbcdef"
val msg = TextMessage.Strict(text)
netOutSub.request(5)
pushMessage(msg)
expectFrameOnNetwork(Opcode.Text, ByteString(text, "UTF-8"), fin = true)
}
"for a strict message larger than configured maximum frame size" in pending
"for a streamed message" in new ServerTestSetup {
val text = "äbcd€fg"
val pub = StreamTestKit.PublisherProbe[String]
val msg = TextMessage.Streamed(Source(pub))
netOutSub.request(6)
pushMessage(msg)
val sub = pub.expectSubscription()
expectFrameHeaderOnNetwork(Opcode.Text, 0, fin = false)
val text1 = text.take(3)
val text1Bytes = ByteString(text1, "UTF-8")
val text2 = text.drop(3)
val text2Bytes = ByteString(text2, "UTF-8")
sub.sendNext(text1)
expectFrameOnNetwork(Opcode.Continuation, text1Bytes, fin = false)
sub.sendNext(text2)
expectFrameOnNetwork(Opcode.Continuation, text2Bytes, fin = false)
sub.sendComplete()
expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true)
}
"for a streamed message don't convert half surrogate pairs naively" in new ServerTestSetup {
val gclef = "𝄞"
gclef.size shouldEqual 2
// split up the code point
val half1 = gclef.take(1)
val half2 = gclef.drop(1)
println(half1(0).toInt.toHexString)
println(half2(0).toInt.toHexString)
val pub = StreamTestKit.PublisherProbe[String]
val msg = TextMessage.Streamed(Source(pub))
netOutSub.request(6)
pushMessage(msg)
val sub = pub.expectSubscription()
expectFrameHeaderOnNetwork(Opcode.Text, 0, fin = false)
sub.sendNext(half1)
expectNoNetworkData()
sub.sendNext(half2)
expectFrameOnNetwork(Opcode.Continuation, ByteString(gclef, "utf8"), fin = false)
}
"for a streamed message with a chunk being larger than configured maximum frame size" in pending
"and mask input on the client side" in new ClientTestSetup {
val text = "abcdefg"
val pub = StreamTestKit.PublisherProbe[String]
val msg = TextMessage.Streamed(Source(pub))
netOutSub.request(5)
pushMessage(msg)
val sub = pub.expectSubscription()
expectFrameOnNetwork(Opcode.Text, ByteString.empty, fin = false)
val text1 = text.take(3)
val text1Bytes = ByteString(text1, "UTF-8")
val text2 = text.drop(3)
val text2Bytes = ByteString(text2, "UTF-8")
sub.sendNext(text1)
expectMaskedFrameOnNetwork(Opcode.Continuation, text1Bytes, fin = false)
sub.sendNext(text2)
expectMaskedFrameOnNetwork(Opcode.Continuation, text2Bytes, fin = false)
sub.sendComplete()
expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true)
}
}
}
"supply automatic low-level websocket behavior" - {
"respond to ping frames unmasking them on the server side" in new ServerTestSetup {
val mask = Random.nextInt()
val input = frameHeader(Opcode.Ping, 6, fin = true, mask = Some(mask)) ++ maskedASCII("abcdef", mask)._1
pushInput(input)
netOutSub.request(5)
expectFrameOnNetwork(Opcode.Pong, ByteString("abcdef"), fin = true)
}
"respond to ping frames masking them on the client side" in new ClientTestSetup {
val input = frameHeader(Opcode.Ping, 6, fin = true) ++ ByteString("abcdef")
pushInput(input)
netOutSub.request(5)
expectMaskedFrameOnNetwork(Opcode.Pong, ByteString("abcdef"), fin = true)
}
"respond to ping frames interleaved with data frames (without mixing frame data)" in new ServerTestSetup {
// receive multi-frame message
// receive and handle interleaved ping frame
// concurrently send out messages from handler
val mask1 = Random.nextInt()
val input1 = frameHeader(Opcode.Binary, 3, fin = false, mask = Some(mask1)) ++ maskedASCII("123", mask1)._1
pushInput(input1)
val BinaryMessage.Streamed(dataSource) = expectMessage()
val sub = StreamTestKit.SubscriberProbe[ByteString]
dataSource.runWith(Sink(sub))
val s = sub.expectSubscription()
s.request(2)
sub.expectNext(ByteString("123", "ASCII"))
val outPub = StreamTestKit.PublisherProbe[ByteString]
val msg = BinaryMessage.Streamed(Source(outPub))
netOutSub.request(10)
pushMessage(msg)
expectFrameHeaderOnNetwork(Opcode.Binary, 0, fin = false)
val outSub = outPub.expectSubscription()
val outData1 = ByteString("abc", "ASCII")
outSub.sendNext(outData1)
expectFrameOnNetwork(Opcode.Continuation, outData1, fin = false)
val pingMask = Random.nextInt()
val pingData = maskedASCII("pling", pingMask)._1
val pingData0 = pingData.take(3)
val pingData1 = pingData.drop(3)
pushInput(frameHeader(Opcode.Ping, 5, fin = true, mask = Some(pingMask)) ++ pingData0)
expectNoNetworkData()
pushInput(pingData1)
expectFrameOnNetwork(Opcode.Pong, ByteString("pling", "ASCII"), fin = true)
val outData2 = ByteString("def", "ASCII")
outSub.sendNext(outData2)
expectFrameOnNetwork(Opcode.Continuation, outData2, fin = false)
outSub.sendComplete()
expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true)
val mask2 = Random.nextInt()
val input2 = frameHeader(Opcode.Continuation, 3, fin = true, mask = Some(mask2)) ++ maskedASCII("456", mask2)._1
pushInput(input2)
sub.expectNext(ByteString("456", "ASCII"))
sub.expectComplete()
}
"don't respond to unsolicited pong frames" in new ClientTestSetup {
val data = frameHeader(Opcode.Pong, 6, fin = true) ++ ByteString("abcdef")
pushInput(data)
netOutSub.request(5)
expectNoNetworkData()
}
}
"provide close behavior" - {
"after receiving regular close frame when idle (user closes immediately)" in new ServerTestSetup {
netInSub.expectRequest()
netOutSub.request(20)
messageOutSub.request(20)
pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true))
messageIn.expectComplete()
netIn.expectNoMsg(1.second) // especially the cancellation not yet
expectNoNetworkData()
messageOutSub.sendComplete()
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
netOut.expectComplete()
netInSub.expectCancellation()
}
"after receiving close frame without close code" in new ServerTestSetup {
netInSub.expectRequest()
pushInput(frameHeader(Opcode.Close, 0, fin = true))
messageIn.expectComplete()
messageOutSub.sendComplete()
// especially mustn't be Procotol.CloseCodes.NoCodePresent
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
netOut.expectComplete()
netInSub.expectCancellation()
}
"after receiving regular close frame when idle (user still sends some data)" in new ServerTestSetup {
netOutSub.request(20)
messageOutSub.request(20)
pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true))
messageIn.expectComplete()
// sending another message is allowed before closing (inherently racy)
val pub = StreamTestKit.PublisherProbe[ByteString]
val msg = BinaryMessage.Streamed(Source(pub))
pushMessage(msg)
expectFrameOnNetwork(Opcode.Binary, ByteString.empty, fin = false)
val data = ByteString("abc", "ASCII")
val dataSub = pub.expectSubscription()
dataSub.sendNext(data)
expectFrameOnNetwork(Opcode.Continuation, data, fin = false)
dataSub.sendComplete()
expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true)
messageOutSub.sendComplete()
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
netOut.expectComplete()
}
"after receiving regular close frame when fragmented message is still open" in pendingUntilFixed {
new ServerTestSetup {
netOutSub.request(10)
messageInSub.request(10)
pushInput(frameHeader(Protocol.Opcode.Binary, 0, fin = false))
val BinaryMessage.Streamed(dataSource) = messageIn.expectNext()
val inSubscriber = StreamTestKit.SubscriberProbe[ByteString]
dataSource.runWith(Sink(inSubscriber))
val inSub = inSubscriber.expectSubscription()
val outData = ByteString("def", "ASCII")
val mask = Random.nextInt()
pushInput(frameHeader(Protocol.Opcode.Continuation, 3, fin = false, mask = Some(mask)) ++ maskedBytes(outData, mask)._1)
inSub.request(5)
inSubscriber.expectNext(outData)
pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true))
messageIn.expectComplete()
inSubscriber.expectError()
// truncation of open message
// sending another message is allowed before closing (inherently racy)
val pub = StreamTestKit.PublisherProbe[ByteString]
val msg = BinaryMessage.Streamed(Source(pub))
pushMessage(msg)
expectFrameOnNetwork(Opcode.Binary, ByteString.empty, fin = false)
val data = ByteString("abc", "ASCII")
val dataSub = pub.expectSubscription()
dataSub.sendNext(data)
expectFrameOnNetwork(Opcode.Continuation, data, fin = false)
dataSub.sendComplete()
expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true)
messageOutSub.sendComplete()
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
netOut.expectComplete()
}
}
"after receiving error close frame" in pending
"after peer closes connection without sending a close frame" in new ServerTestSetup {
netInSub.expectRequest()
netInSub.sendComplete()
messageIn.expectComplete()
messageOutSub.sendComplete()
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
netOut.expectComplete()
}
"when user handler closes (simple)" in new ServerTestSetup {
messageOutSub.sendComplete()
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
netOut.expectNoMsg(1.second) // wait for peer to close regularly
pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true))
messageIn.expectComplete()
netOut.expectComplete()
netInSub.expectCancellation()
}
"when user handler closes main stream and substream only afterwards" in new ServerTestSetup {
netOutSub.request(10)
messageInSub.request(10)
// send half a message
val pub = StreamTestKit.PublisherProbe[ByteString]
val msg = BinaryMessage.Streamed(Source(pub))
pushMessage(msg)
expectFrameOnNetwork(Opcode.Binary, ByteString.empty, fin = false)
val data = ByteString("abc", "ASCII")
val dataSub = pub.expectSubscription()
dataSub.sendNext(data)
expectFrameOnNetwork(Opcode.Continuation, data, fin = false)
messageOutSub.sendComplete()
expectNoNetworkData() // need to wait for substream to close
dataSub.sendComplete()
expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true)
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
netOut.expectNoMsg(1.second) // wait for peer to close regularly
val mask = Random.nextInt()
pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true))
messageIn.expectComplete()
netOut.expectComplete()
netInSub.expectCancellation()
}
"if user handler fails" in pending
"if peer closes with invalid close frame" - {
"close code outside of the valid range" in new ServerTestSetup {
netInSub.expectRequest()
pushInput(frameHeader(Opcode.Close, 1, mask = Some(Random.nextInt()), fin = true) ++ ByteString("x"))
val error = messageIn.expectError()
expectCloseCodeOnNetwork(Protocol.CloseCodes.ProtocolError)
netOut.expectComplete()
netInSub.expectCancellation()
}
"close data of size 1" in new ServerTestSetup {
netInSub.expectRequest()
pushInput(frameHeader(Opcode.Close, 1, mask = Some(Random.nextInt()), fin = true) ++ ByteString("x"))
val error = messageIn.expectError()
expectCloseCodeOnNetwork(Protocol.CloseCodes.ProtocolError)
netOut.expectComplete()
netInSub.expectCancellation()
}
"reason is no valid utf8 data" in pending
}
"timeout if user handler closes and peer doesn't send a close frame" in new ServerTestSetup {
netInSub.expectRequest()
messageOutSub.sendComplete()
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
netOut.expectComplete()
netInSub.expectCancellation()
}
"timeout after we close after error and peer doesn't send a close frame" in new ServerTestSetup {
netInSub.expectRequest()
pushInput(frameHeader(Opcode.Binary, 0, fin = true, rsv1 = true))
expectProtocolErrorOnNetwork()
messageOutSub.sendComplete()
netOut.expectComplete()
netInSub.expectCancellation()
}
"ignore frames peer sends after close frame" in new ServerTestSetup {
netInSub.expectRequest()
pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true))
messageIn.expectComplete()
pushInput(frameHeader(Opcode.Binary, 0, fin = true))
messageOutSub.sendComplete()
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
netOut.expectComplete()
netInSub.expectCancellation()
}
}
"reject unexpected frames" - {
"reserved bits set" - {
"rsv1" in new ServerTestSetup {
pushInput(frameHeader(Opcode.Binary, 0, fin = true, rsv1 = true))
expectProtocolErrorOnNetwork()
}
"rsv2" in new ServerTestSetup {
pushInput(frameHeader(Opcode.Binary, 0, fin = true, rsv2 = true))
expectProtocolErrorOnNetwork()
}
"rsv3" in new ServerTestSetup {
pushInput(frameHeader(Opcode.Binary, 0, fin = true, rsv3 = true))
expectProtocolErrorOnNetwork()
}
}
"highest bit of 64-bit length is set" in new ServerTestSetup {
import BitBuilder._
val header =
b"""0000 # flags
xxxx=1 # opcode
1 # mask?
xxxxxxx=7f # length
xxxxxxxx
xxxxxxxx
xxxxxxxx
xxxxxxxx=ffffffff
xxxxxxxx
xxxxxxxx
xxxxxxxx
xxxxxxxx=ffffffff # length64
00000000
00000000
00000000
00000000 # empty mask
"""
pushInput(header)
expectProtocolErrorOnNetwork()
}
"control frame bigger than 125 bytes" in new ServerTestSetup {
pushInput(frameHeader(Opcode.Ping, 126, fin = true, mask = Some(0)))
expectProtocolErrorOnNetwork()
}
"fragmented control frame" in new ServerTestSetup {
pushInput(frameHeader(Opcode.Ping, 0, fin = false, mask = Some(0)))
expectProtocolErrorOnNetwork()
}
"unexpected continuation frame" in new ServerTestSetup {
pushInput(frameHeader(Opcode.Continuation, 0, fin = false, mask = Some(0)))
expectProtocolErrorOnNetwork()
}
"unexpected data frame when waiting for continuation" in new ServerTestSetup {
pushInput(frameHeader(Opcode.Binary, 0, fin = false) ++
frameHeader(Opcode.Binary, 0, fin = false))
expectProtocolErrorOnNetwork()
}
"invalid utf8 encoding for single frame message" in new ClientTestSetup {
val data = ByteString(
(128 + 64).toByte, // start two byte sequence
0 // but don't finish it
)
pushInput(frameHeader(Opcode.Text, 2, fin = true) ++ data)
expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData)
}
"invalid utf8 encoding for streamed frame" in new ClientTestSetup {
val data = ByteString(
(128 + 64).toByte, // start two byte sequence
0 // but don't finish it
)
pushInput(frameHeader(Opcode.Text, 0, fin = false) ++
frameHeader(Opcode.Continuation, 2, fin = true) ++
data)
expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData)
}
"truncated utf8 encoding for single frame message" in new ClientTestSetup {
val data = ByteString("€", "UTF-8").take(1) // half a euro
pushInput(frameHeader(Opcode.Text, 1, fin = true) ++ data)
expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData)
}
"truncated utf8 encoding for streamed frame" in new ClientTestSetup {
val data = ByteString("€", "UTF-8").take(1) // half a euro
pushInput(frameHeader(Opcode.Text, 0, fin = false) ++
frameHeader(Opcode.Continuation, 1, fin = true) ++
data)
expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData)
}
"half a surrogate pair in utf8 encoding for a strict frame" in new ClientTestSetup {
val data = ByteString(0xed, 0xa0, 0x80) // not strictly supported by utf-8
pushInput(frameHeader(Opcode.Text, 3, fin = true) ++ data)
expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData)
}
"half a surrogate pair in utf8 encoding for a streamed frame" in new ClientTestSetup {
val data = ByteString(0xed, 0xa0, 0x80) // not strictly supported by utf-8
pushInput(frameHeader(Opcode.Text, 0, fin = false))
pushInput(frameHeader(Opcode.Continuation, 3, fin = true) ++ data)
messageIn.expectError()
expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData)
}
"unmasked input on the server side" in new ServerTestSetup {
val data = ByteString("abcdef", "ASCII")
val input = frameHeader(Opcode.Binary, 6, fin = true) ++ data
pushInput(input)
expectProtocolErrorOnNetwork()
}
"masked input on the client side" in new ClientTestSetup {
val mask = Random.nextInt()
val input = frameHeader(Opcode.Binary, 6, fin = true, mask = Some(mask)) ++ maskedASCII("abcdef", mask)._1
pushInput(input)
expectProtocolErrorOnNetwork()
}
}
"support per-message-compression extension" in pending
}
class ServerTestSetup extends TestSetup {
protected def serverSide: Boolean = true
}
class ClientTestSetup extends TestSetup {
protected def serverSide: Boolean = false
}
abstract class TestSetup {
protected def serverSide: Boolean
protected def closeTimeout: FiniteDuration = 1.second
val netIn = StreamTestKit.PublisherProbe[ByteString]
val netOut = StreamTestKit.SubscriberProbe[ByteString]
val messageIn = StreamTestKit.SubscriberProbe[Message]
val messageOut = StreamTestKit.PublisherProbe[Message]
val messageHandler: Flow[Message, Message, Unit] =
Flow.wrap {
FlowGraph.partial() { implicit b
val in = b.add(Sink(messageIn))
val out = b.add(Source(messageOut))
FlowShape[Message, Message](in, out)
}
}
Source(netIn)
.via(printEvent("netIn"))
.transform(() new FrameEventParser)
.via(Websocket.handleMessages(messageHandler, serverSide, closeTimeout = closeTimeout))
.via(printEvent("frameRendererIn"))
.transform(() new FrameEventRenderer)
.via(printEvent("frameRendererOut"))
.to(Sink(netOut))
.run()
val netInSub = netIn.expectSubscription()
val netOutSub = netOut.expectSubscription()
val messageOutSub = messageOut.expectSubscription()
val messageInSub = messageIn.expectSubscription()
def pushInput(data: ByteString): Unit = {
// TODO: expect/handle request?
netInSub.sendNext(data)
}
def pushMessage(msg: Message): Unit = {
messageOutSub.sendNext(msg)
}
def expectMessage(message: Message): Unit = {
messageInSub.request(1)
messageIn.expectNext(message)
}
def expectMessage(): Message = {
messageInSub.request(1)
messageIn.expectNext()
}
var inBuffer = ByteString.empty
@tailrec final def expectNetworkData(bytes: Int): ByteString =
if (inBuffer.size >= bytes) {
val res = inBuffer.take(bytes)
inBuffer = inBuffer.drop(bytes)
res
} else {
netOutSub.request(1)
inBuffer ++= netOut.expectNext()
expectNetworkData(bytes)
}
def expectNetworkData(data: ByteString): Unit =
expectNetworkData(data.size) shouldEqual data
def expectFrameOnNetwork(opcode: Opcode, data: ByteString, fin: Boolean): Unit = {
expectFrameHeaderOnNetwork(opcode, data.size, fin)
expectNetworkData(data)
}
def expectMaskedFrameOnNetwork(opcode: Opcode, data: ByteString, fin: Boolean): Unit = {
val Some(mask) = expectFrameHeaderOnNetwork(opcode, data.size, fin)
val masked = maskedBytes(data, mask)._1
expectNetworkData(masked)
}
/** Returns the mask if any is available */
def expectFrameHeaderOnNetwork(opcode: Opcode, length: Long, fin: Boolean): Option[Int] = {
val (op, l, f, m) = expectFrameHeaderOnNetwork()
op shouldEqual opcode
l shouldEqual length
f shouldEqual fin
m
}
def expectFrameHeaderOnNetwork(): (Opcode, Long, Boolean, Option[Int]) = {
val header = expectNetworkData(2)
val fin = (header(0) & Protocol.FIN_MASK) != 0
val op = header(0) & Protocol.OP_MASK
val hasMask = (header(1) & Protocol.MASK_MASK) != 0
val length7 = header(1) & Protocol.LENGTH_MASK
val length = length7 match {
case 126
val length16Bytes = expectNetworkData(2)
(length16Bytes(0) & 0xff) << 8 | (length16Bytes(1) & 0xff) << 0
case 127
val length64Bytes = expectNetworkData(8)
(length64Bytes(0) & 0xff).toLong << 56 |
(length64Bytes(1) & 0xff).toLong << 48 |
(length64Bytes(2) & 0xff).toLong << 40 |
(length64Bytes(3) & 0xff).toLong << 32 |
(length64Bytes(4) & 0xff).toLong << 24 |
(length64Bytes(5) & 0xff).toLong << 16 |
(length64Bytes(6) & 0xff).toLong << 8 |
(length64Bytes(7) & 0xff).toLong << 0
case x x
}
val mask =
if (hasMask) {
val maskBytes = expectNetworkData(4)
val mask =
(maskBytes(0) & 0xff) << 24 |
(maskBytes(1) & 0xff) << 16 |
(maskBytes(2) & 0xff) << 8 |
(maskBytes(3) & 0xff) << 0
Some(mask)
} else None
(Opcode.forCode(op.toByte), length, fin, mask)
}
def expectProtocolErrorOnNetwork(): Unit = expectCloseCodeOnNetwork(Protocol.CloseCodes.ProtocolError)
def expectCloseCodeOnNetwork(expectedCode: Int): Unit = {
val (opcode, length, true, mask) = expectFrameHeaderOnNetwork()
opcode shouldEqual Opcode.Close
length should be >= 2.toLong
val rawData = expectNetworkData(length.toInt)
val data = mask match {
case Some(m) FrameEventParser.mask(rawData, m)._1
case None rawData
}
val code = ((data(0) & 0xff) << 8) | ((data(1) & 0xff) << 0)
code shouldEqual expectedCode
}
def expectNoNetworkData(): Unit =
netOut.expectNoMsg(100.millis)
}
def frameHeader(
opcode: Opcode,
length: Long,
fin: Boolean,
mask: Option[Int] = None,
rsv1: Boolean = false,
rsv2: Boolean = false,
rsv3: Boolean = false): ByteString = {
def set(should: Boolean, mask: Int): Int =
if (should) mask else 0
val flags =
set(fin, Protocol.FIN_MASK) |
set(rsv1, Protocol.RSV1_MASK) |
set(rsv2, Protocol.RSV2_MASK) |
set(rsv3, Protocol.RSV3_MASK)
val opcodeByte = opcode.code | flags
require(length >= 0)
val (lengthByteComponent, lengthBytes) =
if (length < 126) (length.toByte, ByteString.empty)
else if (length < 65536) (126.toByte, shortBE(length.toInt))
else throw new IllegalArgumentException("Only lengths < 65536 allowed in test")
val maskMask = if (mask.isDefined) Protocol.MASK_MASK else 0
val maskBytes = mask match {
case Some(mask) intBE(mask)
case None ByteString.empty
}
val lengthByte = lengthByteComponent | maskMask
ByteString(opcodeByte.toByte, lengthByte.toByte) ++ lengthBytes ++ maskBytes
}
def closeFrame(closeCode: Int, mask: Boolean): ByteString =
if (mask) {
val mask = Random.nextInt()
frameHeader(Opcode.Close, 2, fin = true, mask = Some(mask)) ++
maskedBytes(shortBE(closeCode), mask)._1
} else
frameHeader(Opcode.Close, 2, fin = true) ++
shortBE(closeCode)
def maskedASCII(str: String, mask: Int): (ByteString, Int) =
FrameEventParser.mask(ByteString(str, "ASCII"), mask)
def maskedUTF8(str: String, mask: Int): (ByteString, Int) =
FrameEventParser.mask(ByteString(str, "UTF-8"), mask)
def maskedBytes(bytes: ByteString, mask: Int): (ByteString, Int) =
FrameEventParser.mask(bytes, mask)
def shortBE(value: Int): ByteString = {
require(value >= 0 && value < 65536, s"Value wasn't in short range: $value")
ByteString(
((value >> 8) & 0xff).toByte,
((value >> 0) & 0xff).toByte)
}
def intBE(value: Int): ByteString =
ByteString(
((value >> 24) & 0xff).toByte,
((value >> 16) & 0xff).toByte,
((value >> 8) & 0xff).toByte,
((value >> 0) & 0xff).toByte)
val trace = false // set to `true` for debugging purposes
def printEvent[T](marker: String): Flow[T, T, Unit] =
if (trace) akka.http.util.printEvent(marker)
else Flow[T]
}