=htc #16887 implement high-level server-side Websocket API
This commit is contained in:
parent
880733eb3d
commit
6dafa445de
6 changed files with 1585 additions and 0 deletions
|
|
@ -0,0 +1,200 @@
|
|||
/*
|
||||
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
|
||||
*/
|
||||
|
||||
package akka.http.engine.ws
|
||||
|
||||
import akka.stream.scaladsl.Flow
|
||||
import akka.stream.stage.{ TerminationDirective, SyncDirective, Context, StatefulStage }
|
||||
import akka.util.ByteString
|
||||
import Protocol.Opcode
|
||||
|
||||
import scala.util.control.NonFatal
|
||||
|
||||
/**
|
||||
* The frame handler validates frames, multiplexes data to the user handler or to the bypass and
|
||||
* UTF-8 decodes text frames.
|
||||
*
|
||||
* INTERNAL API
|
||||
*/
|
||||
private[http] object FrameHandler {
|
||||
def create(server: Boolean): Flow[FrameEvent, Either[BypassEvent, MessagePart], Unit] =
|
||||
Flow[FrameEvent].transform(() ⇒ new HandlerStage(server))
|
||||
|
||||
class HandlerStage(server: Boolean) extends StatefulStage[FrameEvent, Either[BypassEvent, MessagePart]] {
|
||||
type Ctx = Context[Either[BypassEvent, MessagePart]]
|
||||
def initial: State = Idle
|
||||
|
||||
object Idle extends StateWithControlFrameHandling {
|
||||
def handleRegularFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective =
|
||||
(start.header.opcode, start.isFullMessage) match {
|
||||
case (Opcode.Binary, true) ⇒ publishMessagePart(BinaryMessagePart(start.data, last = true))
|
||||
case (Opcode.Binary, false) ⇒ becomeAndHandleWith(new CollectingBinaryMessage, start)
|
||||
case (Opcode.Text, _) ⇒ becomeAndHandleWith(new CollectingTextMessage, start)
|
||||
case x ⇒ protocolError()
|
||||
}
|
||||
}
|
||||
|
||||
class CollectingBinaryMessage extends CollectingMessageFrame(Opcode.Binary) {
|
||||
def createMessagePart(data: ByteString, last: Boolean): MessageDataPart = BinaryMessagePart(data, last)
|
||||
}
|
||||
class CollectingTextMessage extends CollectingMessageFrame(Opcode.Text) {
|
||||
val decoder = Utf8Decoder.create()
|
||||
|
||||
def createMessagePart(data: ByteString, last: Boolean): MessageDataPart =
|
||||
TextMessagePart(decoder.decode(data, endOfInput = last).get, last)
|
||||
}
|
||||
|
||||
abstract class CollectingMessageFrame(expectedOpcode: Opcode) extends StateWithControlFrameHandling {
|
||||
var expectFirstHeader = true
|
||||
var finSeen = false
|
||||
def createMessagePart(data: ByteString, last: Boolean): MessageDataPart
|
||||
|
||||
def handleRegularFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective = {
|
||||
if ((expectFirstHeader && start.header.opcode == expectedOpcode) // first opcode must be the expected
|
||||
|| start.header.opcode == Opcode.Continuation) { // further ones continuations
|
||||
expectFirstHeader = false
|
||||
|
||||
if (start.header.fin) finSeen = true
|
||||
publish(start)
|
||||
} else protocolError()
|
||||
}
|
||||
override def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective = publish(data)
|
||||
|
||||
private def publish(part: FrameEvent)(implicit ctx: Ctx): SyncDirective =
|
||||
try publishMessagePart(createMessagePart(part.data, last = finSeen && part.lastPart))
|
||||
catch {
|
||||
case NonFatal(e) ⇒ closeWithCode(Protocol.CloseCodes.InconsistentData)
|
||||
}
|
||||
}
|
||||
class CollectingControlFrame(opcode: Opcode, _data: ByteString, nextState: State) extends InFrameState {
|
||||
var data = _data
|
||||
|
||||
def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective = {
|
||||
this.data ++= data.data
|
||||
if (data.lastPart) handleControlFrame(opcode, this.data, nextState)
|
||||
else ctx.pull()
|
||||
}
|
||||
}
|
||||
object Closed extends State {
|
||||
def onPush(elem: FrameEvent, ctx: Ctx): SyncDirective =
|
||||
ctx.pull() // ignore
|
||||
}
|
||||
|
||||
def becomeAndHandleWith(newState: State, part: FrameEvent)(implicit ctx: Ctx): SyncDirective = {
|
||||
become(newState)
|
||||
current.onPush(part, ctx)
|
||||
}
|
||||
|
||||
/** Returns a SyncDirective if it handled the message */
|
||||
def validateHeader(header: FrameHeader)(implicit ctx: Ctx): Option[SyncDirective] = header match {
|
||||
case h: FrameHeader if h.mask.isDefined && !server ⇒ Some(protocolError())
|
||||
case h: FrameHeader if h.rsv1 || h.rsv2 || h.rsv3 ⇒ Some(protocolError())
|
||||
case FrameHeader(op, _, length, fin, _, _, _) if op.isControl && (length > 125 || !fin) ⇒ Some(protocolError())
|
||||
case _ ⇒ None
|
||||
}
|
||||
|
||||
def handleControlFrame(opcode: Opcode, data: ByteString, nextState: State)(implicit ctx: Ctx): SyncDirective = {
|
||||
become(nextState)
|
||||
opcode match {
|
||||
case Opcode.Ping ⇒ publishDirectResponse(FrameEvent.fullFrame(Opcode.Pong, None, data, fin = true))
|
||||
case Opcode.Pong ⇒
|
||||
// ignore unsolicited Pong frame
|
||||
ctx.pull()
|
||||
case Opcode.Close ⇒
|
||||
val closeCode = FrameEventParser.parseCloseCode(data)
|
||||
emit(Iterator(Left(PeerClosed(closeCode)), Right(PeerClosed(closeCode))), ctx, WaitForPeerTcpClose)
|
||||
case Opcode.Other(o) ⇒ closeWithCode(Protocol.CloseCodes.ProtocolError, "Unsupported opcode")
|
||||
}
|
||||
}
|
||||
private def collectControlFrame(start: FrameStart, nextState: State)(implicit ctx: Ctx): SyncDirective = {
|
||||
assert(!start.isFullMessage)
|
||||
become(new CollectingControlFrame(start.header.opcode, start.data, nextState))
|
||||
ctx.pull()
|
||||
}
|
||||
|
||||
private def publishMessagePart(part: MessageDataPart)(implicit ctx: Ctx): SyncDirective =
|
||||
if (part.last) emit(Iterator(Right(part), Right(MessageEnd)), ctx, Idle)
|
||||
else ctx.push(Right(part))
|
||||
private def publishDirectResponse(frame: FrameStart)(implicit ctx: Ctx): SyncDirective =
|
||||
ctx.push(Left(DirectAnswer(frame)))
|
||||
|
||||
private def protocolError(reason: String = "")(implicit ctx: Ctx): SyncDirective =
|
||||
closeWithCode(Protocol.CloseCodes.ProtocolError, reason)
|
||||
|
||||
private def closeWithCode(closeCode: Int, reason: String = "", cause: Throwable = null)(implicit ctx: Ctx): SyncDirective =
|
||||
emit(
|
||||
Iterator(
|
||||
Left(ActivelyCloseWithCode(Some(closeCode), reason)),
|
||||
Right(ActivelyCloseWithCode(Some(closeCode), reason))), ctx, CloseAfterPeerClosed)
|
||||
|
||||
object CloseAfterPeerClosed extends State {
|
||||
def onPush(elem: FrameEvent, ctx: Context[Either[BypassEvent, MessagePart]]): SyncDirective =
|
||||
elem match {
|
||||
case FrameStart(FrameHeader(Opcode.Close, _, length, _, _, _, _), data) ⇒
|
||||
become(WaitForPeerTcpClose)
|
||||
ctx.push(Left(PeerClosed(FrameEventParser.parseCloseCode(data))))
|
||||
|
||||
case _ ⇒ ctx.pull() // ignore all other data
|
||||
}
|
||||
}
|
||||
object WaitForPeerTcpClose extends State {
|
||||
def onPush(elem: FrameEvent, ctx: Context[Either[BypassEvent, MessagePart]]): SyncDirective =
|
||||
ctx.pull() // ignore
|
||||
}
|
||||
|
||||
abstract class StateWithControlFrameHandling extends BetweenFrameState {
|
||||
def handleRegularFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective
|
||||
|
||||
def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective =
|
||||
validateHeader(start.header).getOrElse {
|
||||
if (start.header.opcode.isControl)
|
||||
if (start.isFullMessage) handleControlFrame(start.header.opcode, start.data, this)
|
||||
else collectControlFrame(start, this)
|
||||
else handleRegularFrameStart(start)
|
||||
}
|
||||
}
|
||||
abstract class BetweenFrameState extends ImplicitContextState {
|
||||
def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective =
|
||||
throw new IllegalStateException("Expected FrameStart")
|
||||
}
|
||||
abstract class InFrameState extends ImplicitContextState {
|
||||
def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective =
|
||||
throw new IllegalStateException("Expected FrameData")
|
||||
}
|
||||
abstract class ImplicitContextState extends State {
|
||||
def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective
|
||||
def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective
|
||||
|
||||
def onPush(part: FrameEvent, ctx: Ctx): SyncDirective =
|
||||
part match {
|
||||
case data: FrameData ⇒ handleFrameData(data)(ctx)
|
||||
case start: FrameStart ⇒ handleFrameStart(start)(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sealed trait MessagePart {
|
||||
def isMessageEnd: Boolean
|
||||
}
|
||||
sealed trait MessageDataPart extends MessagePart {
|
||||
def isMessageEnd = false
|
||||
def last: Boolean
|
||||
}
|
||||
final case class TextMessagePart(data: String, last: Boolean) extends MessageDataPart
|
||||
final case class BinaryMessagePart(data: ByteString, last: Boolean) extends MessageDataPart
|
||||
case object MessageEnd extends MessagePart {
|
||||
def isMessageEnd: Boolean = true
|
||||
}
|
||||
final case class PeerClosed(code: Option[Int], reason: String = "") extends MessagePart with BypassEvent {
|
||||
def isMessageEnd: Boolean = true
|
||||
}
|
||||
|
||||
sealed trait BypassEvent
|
||||
final case class DirectAnswer(frame: FrameStart) extends BypassEvent
|
||||
final case class ActivelyCloseWithCode(code: Option[Int], reason: String = "") extends MessagePart with BypassEvent {
|
||||
def isMessageEnd: Boolean = true
|
||||
}
|
||||
case object UserHandlerCompleted extends BypassEvent
|
||||
case class UserHandlerErredOut(cause: Throwable) extends BypassEvent
|
||||
}
|
||||
|
|
@ -0,0 +1,132 @@
|
|||
/*
|
||||
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
|
||||
*/
|
||||
|
||||
package akka.http.engine.ws
|
||||
|
||||
import scala.concurrent.duration.FiniteDuration
|
||||
|
||||
import akka.stream.stage._
|
||||
import akka.http.util.Timestamp
|
||||
import FrameHandler.{ UserHandlerCompleted, ActivelyCloseWithCode, PeerClosed, DirectAnswer }
|
||||
import Websocket.Tick
|
||||
|
||||
/**
|
||||
* Implements the transport connection close handling at the end of the pipeline.
|
||||
*
|
||||
* INTERNAL API
|
||||
*/
|
||||
private[http] class FrameOutHandler(serverSide: Boolean, _closeTimeout: FiniteDuration) extends StatefulStage[AnyRef, FrameStart] {
|
||||
def initial: StageState[AnyRef, FrameStart] = Idle
|
||||
def closeTimeout: Timestamp = Timestamp.now + _closeTimeout
|
||||
|
||||
object Idle extends CompletionHandlingState {
|
||||
def onPush(elem: AnyRef, ctx: Context[FrameStart]): SyncDirective = elem match {
|
||||
case start: FrameStart ⇒ ctx.push(start)
|
||||
case DirectAnswer(frame) ⇒ ctx.push(frame)
|
||||
case PeerClosed(code, reason) if !code.exists(Protocol.CloseCodes.isError) ⇒
|
||||
// let user complete it, FIXME: maybe make configurable? immediately, or timeout
|
||||
become(new WaitingForUserHandlerClosed(FrameEvent.closeFrame(code.getOrElse(Protocol.CloseCodes.Regular), reason)))
|
||||
ctx.pull()
|
||||
case PeerClosed(code, reason) ⇒
|
||||
val closeFrame = FrameEvent.closeFrame(code.getOrElse(Protocol.CloseCodes.Regular), reason)
|
||||
if (serverSide) ctx.pushAndFinish(closeFrame)
|
||||
else {
|
||||
become(new WaitingForTransportClose)
|
||||
ctx.push(closeFrame)
|
||||
}
|
||||
case ActivelyCloseWithCode(code, reason) ⇒
|
||||
val closeFrame = FrameEvent.closeFrame(code.getOrElse(Protocol.CloseCodes.Regular), reason)
|
||||
become(new WaitingForPeerCloseFrame())
|
||||
ctx.push(closeFrame)
|
||||
case UserHandlerCompleted ⇒
|
||||
become(new WaitingForPeerCloseFrame())
|
||||
ctx.push(FrameEvent.closeFrame(Protocol.CloseCodes.Regular))
|
||||
case Tick ⇒ ctx.pull() // ignore
|
||||
}
|
||||
|
||||
def onComplete(ctx: Context[FrameStart]): TerminationDirective = {
|
||||
become(new SendOutCloseFrameAndComplete(FrameEvent.closeFrame(Protocol.CloseCodes.Regular)))
|
||||
ctx.absorbTermination()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* peer has closed, we want to wait for user handler to close as well
|
||||
*/
|
||||
class WaitingForUserHandlerClosed(closeFrame: FrameStart) extends CompletionHandlingState {
|
||||
def onPush(elem: AnyRef, ctx: Context[FrameStart]): SyncDirective = elem match {
|
||||
case UserHandlerCompleted ⇒
|
||||
if (serverSide) ctx.pushAndFinish(closeFrame)
|
||||
else {
|
||||
become(new WaitingForTransportClose())
|
||||
ctx.push(closeFrame)
|
||||
}
|
||||
case start: FrameStart ⇒ ctx.push(start)
|
||||
case _ ⇒ ctx.pull() // ignore
|
||||
}
|
||||
|
||||
def onComplete(ctx: Context[FrameStart]): TerminationDirective =
|
||||
ctx.fail(new IllegalStateException("Mustn't complete before user has completed"))
|
||||
}
|
||||
|
||||
/**
|
||||
* we have sent out close frame and wait for peer to sent its close frame
|
||||
*/
|
||||
class WaitingForPeerCloseFrame(timeout: Timestamp = closeTimeout) extends CompletionHandlingState {
|
||||
def onPush(elem: AnyRef, ctx: Context[FrameStart]): SyncDirective = elem match {
|
||||
case Tick ⇒
|
||||
if (timeout.isPast) ctx.finish()
|
||||
else ctx.pull()
|
||||
case PeerClosed(code, reason) ⇒
|
||||
if (serverSide) ctx.finish()
|
||||
else {
|
||||
become(new WaitingForTransportClose())
|
||||
ctx.pull()
|
||||
}
|
||||
case _ ⇒ ctx.pull() // ignore
|
||||
}
|
||||
|
||||
def onComplete(ctx: Context[FrameStart]): TerminationDirective = ctx.finish()
|
||||
}
|
||||
|
||||
/**
|
||||
* Both side have sent their close frames, server should close the connection first
|
||||
*/
|
||||
class WaitingForTransportClose(timeout: Timestamp = closeTimeout) extends CompletionHandlingState {
|
||||
def onPush(elem: AnyRef, ctx: Context[FrameStart]): SyncDirective = elem match {
|
||||
case Tick ⇒
|
||||
if (timeout.isPast) ctx.finish()
|
||||
else ctx.pull()
|
||||
case _ ⇒ ctx.pull() // ignore
|
||||
}
|
||||
|
||||
def onComplete(ctx: Context[FrameStart]): TerminationDirective = ctx.finish()
|
||||
}
|
||||
|
||||
/** If upstream has already failed we just wait to be able to deliver our close frame and complete */
|
||||
class SendOutCloseFrameAndComplete(closeFrame: FrameStart) extends CompletionHandlingState {
|
||||
def onPush(elem: AnyRef, ctx: Context[FrameStart]): SyncDirective =
|
||||
ctx.fail(new IllegalStateException("Didn't expect push after completion"))
|
||||
|
||||
override def onPull(ctx: Context[FrameStart]): SyncDirective =
|
||||
ctx.pushAndFinish(closeFrame)
|
||||
|
||||
def onComplete(ctx: Context[FrameStart]): TerminationDirective =
|
||||
ctx.absorbTermination()
|
||||
}
|
||||
|
||||
trait CompletionHandlingState extends State {
|
||||
def onComplete(ctx: Context[FrameStart]): TerminationDirective
|
||||
}
|
||||
|
||||
override def onUpstreamFinish(ctx: Context[FrameStart]): TerminationDirective =
|
||||
current.asInstanceOf[CompletionHandlingState].onComplete(ctx)
|
||||
|
||||
override def onUpstreamFailure(cause: scala.Throwable, ctx: Context[FrameStart]): TerminationDirective = cause match {
|
||||
case p: ProtocolException ⇒
|
||||
become(new SendOutCloseFrameAndComplete(FrameEvent.closeFrame(Protocol.CloseCodes.ProtocolError)))
|
||||
ctx.absorbTermination()
|
||||
case _ ⇒ super.onUpstreamFailure(cause, ctx)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
/*
|
||||
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
|
||||
*/
|
||||
|
||||
package akka.http.engine.ws
|
||||
|
||||
import akka.stream.scaladsl.Flow
|
||||
import akka.stream.stage.{ SyncDirective, Context, StageState, StatefulStage }
|
||||
|
||||
import scala.util.Random
|
||||
|
||||
/**
|
||||
* Implements Websocket Frame masking.
|
||||
*
|
||||
* INTERNAL API
|
||||
*/
|
||||
private[http] object Masking {
|
||||
def maskIf(condition: Boolean, maskRandom: () ⇒ Random): Flow[FrameEvent, FrameEvent, Unit] =
|
||||
if (condition) Flow[FrameEvent].transform(() ⇒ new Masking(maskRandom())) // new random per materialization
|
||||
else Flow[FrameEvent]
|
||||
def unmaskIf(condition: Boolean): Flow[FrameEvent, FrameEvent, Unit] =
|
||||
if (condition) Flow[FrameEvent].transform(() ⇒ new Unmasking())
|
||||
else Flow[FrameEvent]
|
||||
|
||||
class Masking(random: Random) extends Masker {
|
||||
def extractMask(header: FrameHeader): Int = random.nextInt()
|
||||
def setNewMask(header: FrameHeader, mask: Int): FrameHeader = {
|
||||
if (header.mask.isDefined) throw new ProtocolException("Frame mustn't already be masked")
|
||||
header.copy(mask = Some(mask))
|
||||
}
|
||||
}
|
||||
class Unmasking extends Masker {
|
||||
def extractMask(header: FrameHeader): Int = header.mask match {
|
||||
case Some(mask) ⇒ mask
|
||||
case None ⇒ throw new ProtocolException("Frame wasn't masked")
|
||||
}
|
||||
def setNewMask(header: FrameHeader, mask: Int): FrameHeader = header.copy(mask = None)
|
||||
}
|
||||
|
||||
/** Implements both masking and unmasking which is mostly symmetric (because of XOR) */
|
||||
abstract class Masker extends StatefulStage[FrameEvent, FrameEvent] {
|
||||
def extractMask(header: FrameHeader): Int
|
||||
def setNewMask(header: FrameHeader, mask: Int): FrameHeader
|
||||
|
||||
def initial: State = Idle
|
||||
|
||||
object Idle extends State {
|
||||
def onPush(part: FrameEvent, ctx: Context[FrameEvent]): SyncDirective =
|
||||
part match {
|
||||
case start @ FrameStart(header, data) ⇒
|
||||
if (header.length == 0) ctx.push(part)
|
||||
else {
|
||||
val mask = extractMask(header)
|
||||
become(new Running(mask))
|
||||
current.onPush(start.copy(header = setNewMask(header, mask)), ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
class Running(initialMask: Int) extends State {
|
||||
var mask = initialMask
|
||||
|
||||
def onPush(part: FrameEvent, ctx: Context[FrameEvent]): SyncDirective = {
|
||||
if (part.lastPart) become(Idle)
|
||||
|
||||
val (masked, newMask) = FrameEventParser.mask(part.data, mask)
|
||||
mask = newMask
|
||||
ctx.push(part.withData(data = masked))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
/*
|
||||
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
|
||||
*/
|
||||
|
||||
package akka.http.engine.ws
|
||||
|
||||
import akka.util.ByteString
|
||||
import akka.stream.scaladsl.{ FlattenStrategy, Source, Flow }
|
||||
|
||||
import Protocol.Opcode
|
||||
import akka.http.model.ws._
|
||||
|
||||
/**
|
||||
* Renders messages to full frames.
|
||||
*
|
||||
* INTERNAL API
|
||||
*/
|
||||
private[http] object MessageToFrameRenderer {
|
||||
def create(serverSide: Boolean): Flow[Message, FrameStart, Unit] = {
|
||||
def strictFrames(opcode: Opcode, data: ByteString): Source[FrameStart, _] =
|
||||
// FIXME: fragment?
|
||||
Source.single(FrameEvent.fullFrame(opcode, None, data, fin = true))
|
||||
|
||||
def streamedFrames(opcode: Opcode, data: Source[ByteString, _]): Source[FrameStart, _] =
|
||||
Source.single(FrameEvent.empty(opcode, fin = false)) ++
|
||||
data.map(FrameEvent.fullFrame(Opcode.Continuation, None, _, fin = false)) ++
|
||||
Source.single(FrameEvent.emptyLastContinuationFrame)
|
||||
|
||||
Flow[Message]
|
||||
.map {
|
||||
case BinaryMessage.Strict(data) ⇒ strictFrames(Opcode.Binary, data)
|
||||
case BinaryMessage.Streamed(data) ⇒ streamedFrames(Opcode.Binary, data)
|
||||
case TextMessage.Strict(text) ⇒ strictFrames(Opcode.Text, ByteString(text, "UTF-8"))
|
||||
case TextMessage.Streamed(text) ⇒ streamedFrames(Opcode.Text, text.transform(() ⇒ new Utf8Encoder))
|
||||
}.flatten(FlattenStrategy.Concat())
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,179 @@
|
|||
/*
|
||||
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
|
||||
*/
|
||||
|
||||
package akka.http.engine.ws
|
||||
|
||||
import java.security.SecureRandom
|
||||
|
||||
import scala.concurrent.duration._
|
||||
|
||||
import akka.stream.{ OperationAttributes, FanOutShape2, FanInShape3, Inlet }
|
||||
import akka.stream.scaladsl._
|
||||
import akka.stream.stage._
|
||||
import FlexiRoute.{ DemandFrom, DemandFromAny, RouteLogic }
|
||||
import FlexiMerge.MergeLogic
|
||||
|
||||
import akka.http.util._
|
||||
import akka.http.model.ws._
|
||||
|
||||
/**
|
||||
* INTERNAL API
|
||||
*/
|
||||
private[http] object Websocket {
|
||||
import FrameHandler._
|
||||
|
||||
def handleMessages[T](messageHandler: Flow[Message, Message, T],
|
||||
serverSide: Boolean = true,
|
||||
closeTimeout: FiniteDuration = 3.seconds): Flow[FrameEvent, FrameEvent, Unit] = {
|
||||
/** Completes this branch of the flow if no more messages are expected and converts close codes into errors */
|
||||
class PrepareForUserHandler extends PushStage[MessagePart, MessagePart] {
|
||||
def onPush(elem: MessagePart, ctx: Context[MessagePart]): SyncDirective = elem match {
|
||||
case PeerClosed(code, reason) ⇒
|
||||
if (code.exists(Protocol.CloseCodes.isError)) ctx.fail(new ProtocolException(s"Peer closed connection with code $code"))
|
||||
else ctx.finish()
|
||||
case ActivelyCloseWithCode(code, reason) ⇒
|
||||
if (code.exists(Protocol.CloseCodes.isError)) ctx.fail(new ProtocolException(s"Closing connection with error code $code"))
|
||||
else ctx.finish()
|
||||
case x ⇒ ctx.push(x)
|
||||
}
|
||||
}
|
||||
|
||||
/** Collects user-level API messages from MessageDataParts */
|
||||
val collectMessage: Flow[Source[MessageDataPart, Unit], Message, Unit] =
|
||||
Flow[Source[MessageDataPart, Unit]]
|
||||
.headAndTail
|
||||
.map {
|
||||
case (TextMessagePart(text, true), remaining) ⇒
|
||||
TextMessage.Strict(text)
|
||||
case (first @ TextMessagePart(text, false), remaining) ⇒
|
||||
TextMessage.Streamed(
|
||||
(Source.single(first) ++ remaining)
|
||||
.collect {
|
||||
case t: TextMessagePart if t.data.nonEmpty ⇒ t.data
|
||||
})
|
||||
case (BinaryMessagePart(data, true), remaining) ⇒
|
||||
BinaryMessage.Strict(data)
|
||||
case (first @ BinaryMessagePart(data, false), remaining) ⇒
|
||||
BinaryMessage.Streamed(
|
||||
(Source.single(first) ++ remaining)
|
||||
.collect {
|
||||
case t: BinaryMessagePart if t.data.nonEmpty ⇒ t.data
|
||||
})
|
||||
}
|
||||
|
||||
/** Lifts onComplete and onError into events to be processed in the FlexiMerge */
|
||||
class LiftCompletions extends StatefulStage[FrameStart, AnyRef] {
|
||||
def initial: StageState[FrameStart, AnyRef] = SteadyState
|
||||
|
||||
object SteadyState extends State {
|
||||
def onPush(elem: FrameStart, ctx: Context[AnyRef]): SyncDirective = ctx.push(elem)
|
||||
}
|
||||
class CompleteWith(last: AnyRef) extends State {
|
||||
def onPush(elem: FrameStart, ctx: Context[AnyRef]): SyncDirective =
|
||||
ctx.fail(new IllegalStateException("No push expected"))
|
||||
|
||||
override def onPull(ctx: Context[AnyRef]): SyncDirective = ctx.pushAndFinish(last)
|
||||
}
|
||||
|
||||
override def onUpstreamFinish(ctx: Context[AnyRef]): TerminationDirective = {
|
||||
become(new CompleteWith(UserHandlerCompleted))
|
||||
ctx.absorbTermination()
|
||||
}
|
||||
override def onUpstreamFailure(cause: Throwable, ctx: Context[AnyRef]): TerminationDirective = {
|
||||
become(new CompleteWith(UserHandlerErredOut(cause)))
|
||||
ctx.absorbTermination()
|
||||
}
|
||||
}
|
||||
|
||||
lazy val userFlow =
|
||||
Flow[MessagePart]
|
||||
.transform(() ⇒ new PrepareForUserHandler)
|
||||
.splitWhen(_.isMessageEnd) // FIXME using splitAfter from #16885 would simplify protocol a lot
|
||||
.map(_.collect {
|
||||
case m: MessageDataPart ⇒ m
|
||||
})
|
||||
.via(collectMessage)
|
||||
.via(messageHandler)
|
||||
.via(MessageToFrameRenderer.create(serverSide))
|
||||
.transform(() ⇒ new LiftCompletions)
|
||||
|
||||
/**
|
||||
* Distributes output from the FrameHandler into bypass and userFlow.
|
||||
*/
|
||||
object BypassRouter
|
||||
extends FlexiRoute[Either[BypassEvent, MessagePart], FanOutShape2[Either[BypassEvent, MessagePart], BypassEvent, MessagePart]](new FanOutShape2("bypassRouter"), OperationAttributes.name("bypassRouter")) {
|
||||
def createRouteLogic(s: FanOutShape2[Either[BypassEvent, MessagePart], BypassEvent, MessagePart]): RouteLogic[Either[BypassEvent, MessagePart]] =
|
||||
new RouteLogic[Either[BypassEvent, MessagePart]] {
|
||||
def initialState: State[_] = State(DemandFromAny(s)) { (ctx, out, ev) ⇒
|
||||
ev match {
|
||||
case Left(_) ⇒
|
||||
State(DemandFrom(s.out0)) { (ctx, _, ev) ⇒ // FIXME: #17004
|
||||
ctx.emit(s.out0)(ev.left.get)
|
||||
initialState
|
||||
}
|
||||
case Right(_) ⇒
|
||||
State(DemandFrom(s.out1)) { (ctx, _, ev) ⇒
|
||||
ctx.emit(s.out1)(ev.right.get)
|
||||
initialState
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def initialCompletionHandling: CompletionHandling = super.initialCompletionHandling.copy(
|
||||
onDownstreamFinish = { (ctx, out) ⇒
|
||||
if (out == s.out0) ctx.finish()
|
||||
SameState
|
||||
})
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Merges bypass, user flow and tick source for consumption in the FrameOutHandler.
|
||||
*/
|
||||
object BypassMerge extends FlexiMerge[AnyRef, FanInShape3[BypassEvent, AnyRef, Tick.type, AnyRef]](new FanInShape3("bypassMerge"), OperationAttributes.name("bypassMerge")) {
|
||||
def createMergeLogic(s: FanInShape3[BypassEvent, AnyRef, Tick.type, AnyRef]): MergeLogic[AnyRef] =
|
||||
new MergeLogic[AnyRef] {
|
||||
def initialState: State[_] = Idle
|
||||
|
||||
lazy val Idle = State[AnyRef](FlexiMerge.ReadAny(s.in0.asInstanceOf[Inlet[AnyRef]], s.in1.asInstanceOf[Inlet[AnyRef]], s.in2.asInstanceOf[Inlet[AnyRef]])) { (ctx, in, elem) ⇒
|
||||
ctx.emit(elem)
|
||||
SameState
|
||||
}
|
||||
|
||||
override def initialCompletionHandling: CompletionHandling =
|
||||
CompletionHandling(
|
||||
onUpstreamFinish = { (ctx, in) ⇒
|
||||
if (in == s.in0) ctx.finish()
|
||||
SameState
|
||||
},
|
||||
onUpstreamFailure = { (ctx, in, cause) ⇒
|
||||
if (in == s.in0) ctx.fail(cause)
|
||||
SameState
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
lazy val bypassAndUserHandler: Flow[Either[BypassEvent, MessagePart], AnyRef, Unit] =
|
||||
Flow(BypassRouter, Source(closeTimeout, closeTimeout, Tick), BypassMerge)((_, _, _) ⇒ ()) { implicit b ⇒
|
||||
(split, tick, merge) ⇒
|
||||
import FlowGraph.Implicits._
|
||||
|
||||
split.out0 ~> merge.in0
|
||||
split.out1 ~> userFlow ~> merge.in1
|
||||
tick.outlet ~> merge.in2
|
||||
|
||||
(split.in, merge.out)
|
||||
}
|
||||
|
||||
Flow[FrameEvent]
|
||||
.via(Masking.unmaskIf(serverSide))
|
||||
.via(FrameHandler.create(server = serverSide))
|
||||
.mapConcat(x ⇒ x :: x :: Nil) // FIXME: #17004
|
||||
.via(bypassAndUserHandler)
|
||||
.transform(() ⇒ new FrameOutHandler(serverSide, closeTimeout))
|
||||
.via(Masking.maskIf(!serverSide, () ⇒ new SecureRandom()))
|
||||
}
|
||||
|
||||
object Tick
|
||||
case object SwitchToWebsocketToken
|
||||
}
|
||||
|
|
@ -0,0 +1,966 @@
|
|||
/*
|
||||
* Copyright (C) 2009-2015 Typesafe Inc. <http://www.typesafe.com>
|
||||
*/
|
||||
|
||||
package akka.http.engine.ws
|
||||
|
||||
import scala.annotation.tailrec
|
||||
import scala.concurrent.duration._
|
||||
|
||||
import akka.stream.FlowShape
|
||||
import akka.stream.scaladsl._
|
||||
import akka.stream.testkit.StreamTestKit
|
||||
import akka.util.ByteString
|
||||
import org.scalatest.{ Matchers, FreeSpec }
|
||||
|
||||
import scala.util.Random
|
||||
|
||||
import akka.http.model.ws._
|
||||
import Protocol.Opcode
|
||||
|
||||
class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec {
|
||||
"The Websocket implementation should" - {
|
||||
"collect messages from frames" - {
|
||||
"for binary messages" - {
|
||||
"for an empty message" in new ClientTestSetup {
|
||||
val input = frameHeader(Opcode.Binary, 0, fin = true)
|
||||
|
||||
pushInput(input)
|
||||
expectMessage(BinaryMessage.Strict(ByteString.empty))
|
||||
}
|
||||
"for one complete, strict, single frame message" in new ClientTestSetup {
|
||||
val data = ByteString("abcdef", "ASCII")
|
||||
val input = frameHeader(Opcode.Binary, 6, fin = true) ++ data
|
||||
|
||||
pushInput(input)
|
||||
expectMessage(BinaryMessage.Strict(data))
|
||||
}
|
||||
"for a partial frame" in new ClientTestSetup {
|
||||
val data1 = ByteString("abc", "ASCII")
|
||||
val header = frameHeader(Opcode.Binary, 6, fin = true)
|
||||
|
||||
pushInput(header ++ data1)
|
||||
val BinaryMessage.Streamed(dataSource) = expectMessage()
|
||||
val sub = StreamTestKit.SubscriberProbe[ByteString]
|
||||
dataSource.runWith(Sink(sub))
|
||||
val s = sub.expectSubscription()
|
||||
s.request(2)
|
||||
sub.expectNext(data1)
|
||||
}
|
||||
"for a frame split up into parts" in new ClientTestSetup {
|
||||
val data1 = ByteString("abc", "ASCII")
|
||||
val header = frameHeader(Opcode.Binary, 6, fin = true)
|
||||
|
||||
pushInput(header)
|
||||
val BinaryMessage.Streamed(dataSource) = expectMessage()
|
||||
val sub = StreamTestKit.SubscriberProbe[ByteString]
|
||||
dataSource.runWith(Sink(sub))
|
||||
val s = sub.expectSubscription()
|
||||
s.request(2)
|
||||
pushInput(data1)
|
||||
sub.expectNext(data1)
|
||||
|
||||
val data2 = ByteString("def", "ASCII")
|
||||
pushInput(data2)
|
||||
sub.expectNext(data2)
|
||||
sub.expectComplete()
|
||||
}
|
||||
|
||||
"for a message split into several frames" in new ClientTestSetup {
|
||||
val data1 = ByteString("abc", "ASCII")
|
||||
val header1 = frameHeader(Opcode.Binary, 3, fin = false)
|
||||
|
||||
pushInput(header1 ++ data1)
|
||||
val BinaryMessage.Streamed(dataSource) = expectMessage()
|
||||
val sub = StreamTestKit.SubscriberProbe[ByteString]
|
||||
dataSource.runWith(Sink(sub))
|
||||
val s = sub.expectSubscription()
|
||||
s.request(2)
|
||||
sub.expectNext(data1)
|
||||
|
||||
val header2 = frameHeader(Opcode.Continuation, 4, fin = true)
|
||||
val data2 = ByteString("defg", "ASCII")
|
||||
pushInput(header2 ++ data2)
|
||||
sub.expectNext(data2)
|
||||
sub.expectComplete()
|
||||
}
|
||||
"for several messages" in new ClientTestSetup {
|
||||
val data1 = ByteString("abc", "ASCII")
|
||||
val header1 = frameHeader(Opcode.Binary, 3, fin = false)
|
||||
|
||||
pushInput(header1 ++ data1)
|
||||
val BinaryMessage.Streamed(dataSource) = expectMessage()
|
||||
val sub = StreamTestKit.SubscriberProbe[ByteString]
|
||||
dataSource.runWith(Sink(sub))
|
||||
val s = sub.expectSubscription()
|
||||
s.request(2)
|
||||
sub.expectNext(data1)
|
||||
|
||||
val header2 = frameHeader(Opcode.Continuation, 4, fin = true)
|
||||
val header3 = frameHeader(Opcode.Binary, 2, fin = true)
|
||||
val data2 = ByteString("defg", "ASCII")
|
||||
val data3 = ByteString("h")
|
||||
pushInput(header2 ++ data2 ++ header3 ++ data3)
|
||||
sub.expectNext(data2)
|
||||
sub.expectComplete()
|
||||
|
||||
val BinaryMessage.Streamed(dataSource2) = expectMessage()
|
||||
val sub2 = StreamTestKit.SubscriberProbe[ByteString]
|
||||
dataSource2.runWith(Sink(sub2))
|
||||
val s2 = sub2.expectSubscription()
|
||||
s2.request(2)
|
||||
sub2.expectNext(data3)
|
||||
|
||||
val data4 = ByteString("i")
|
||||
pushInput(data4)
|
||||
sub2.expectNext(data4)
|
||||
sub2.expectComplete()
|
||||
}
|
||||
"unmask masked input on the server side" in new ServerTestSetup {
|
||||
val mask = Random.nextInt()
|
||||
val (data, _) = maskedASCII("abcdef", mask)
|
||||
val data1 = data.take(3)
|
||||
val data2 = data.drop(3)
|
||||
val header = frameHeader(Opcode.Binary, 6, fin = true, mask = Some(mask))
|
||||
|
||||
pushInput(header ++ data1)
|
||||
val BinaryMessage.Streamed(dataSource) = expectMessage()
|
||||
val sub = StreamTestKit.SubscriberProbe[ByteString]
|
||||
dataSource.runWith(Sink(sub))
|
||||
val s = sub.expectSubscription()
|
||||
s.request(2)
|
||||
sub.expectNext(ByteString("abc", "ASCII"))
|
||||
|
||||
pushInput(data2)
|
||||
sub.expectNext(ByteString("def", "ASCII"))
|
||||
sub.expectComplete()
|
||||
}
|
||||
}
|
||||
"for text messages" - {
|
||||
"empty message" in new ClientTestSetup {
|
||||
val input = frameHeader(Opcode.Text, 0, fin = true)
|
||||
|
||||
pushInput(input)
|
||||
expectMessage(TextMessage.Strict(""))
|
||||
}
|
||||
"decode complete, strict frame from utf8" in new ClientTestSetup {
|
||||
val msg = "äbcdef€\uffff"
|
||||
val data = ByteString(msg, "UTF-8")
|
||||
val input = frameHeader(Opcode.Text, data.size, fin = true) ++ data
|
||||
|
||||
pushInput(input)
|
||||
expectMessage(TextMessage.Strict(msg))
|
||||
}
|
||||
"decode utf8 as far as possible for partial frame" in new ClientTestSetup {
|
||||
val msg = "bäcdef€"
|
||||
val data = ByteString(msg, "UTF-8")
|
||||
val data0 = data.slice(0, 2)
|
||||
val data1 = data.slice(2, 5)
|
||||
val data2 = data.slice(5, data.size)
|
||||
val input = frameHeader(Opcode.Text, data.size, fin = true) ++ data0
|
||||
|
||||
pushInput(input)
|
||||
val TextMessage.Streamed(parts) = expectMessage()
|
||||
val sub = StreamTestKit.SubscriberProbe[String]
|
||||
parts.runWith(Sink(sub))
|
||||
val s = sub.expectSubscription()
|
||||
s.request(4)
|
||||
sub.expectNext("b")
|
||||
|
||||
pushInput(data1)
|
||||
sub.expectNext("äcd")
|
||||
}
|
||||
"decode utf8 with code point split across frames" in new ClientTestSetup {
|
||||
val msg = "äbcdef€"
|
||||
val data = ByteString(msg, "UTF-8")
|
||||
val data0 = data.slice(0, 1)
|
||||
val data1 = data.slice(1, data.size)
|
||||
val header0 = frameHeader(Opcode.Text, data0.size, fin = false)
|
||||
|
||||
pushInput(header0 ++ data0)
|
||||
val TextMessage.Streamed(parts) = expectMessage()
|
||||
val sub = StreamTestKit.SubscriberProbe[String]
|
||||
parts.runWith(Sink(sub))
|
||||
val s = sub.expectSubscription()
|
||||
s.request(4)
|
||||
sub.expectNoMsg(100.millis)
|
||||
|
||||
val header1 = frameHeader(Opcode.Continuation, data1.size, fin = true)
|
||||
pushInput(header1 ++ data1)
|
||||
sub.expectNext("äbcdef€")
|
||||
}
|
||||
"unmask masked input on the server side" in new ServerTestSetup {
|
||||
val mask = Random.nextInt()
|
||||
val (data, _) = maskedUTF8("äbcdef€", mask)
|
||||
val data1 = data.take(3)
|
||||
val data2 = data.drop(3)
|
||||
val header = frameHeader(Opcode.Binary, data.size, fin = true, mask = Some(mask))
|
||||
|
||||
pushInput(header ++ data1)
|
||||
val BinaryMessage.Streamed(dataSource) = expectMessage()
|
||||
val sub = StreamTestKit.SubscriberProbe[ByteString]
|
||||
dataSource.runWith(Sink(sub))
|
||||
val s = sub.expectSubscription()
|
||||
s.request(2)
|
||||
sub.expectNext(ByteString("äb", "UTF-8"))
|
||||
|
||||
pushInput(data2)
|
||||
sub.expectNext(ByteString("cdef€", "UTF-8"))
|
||||
sub.expectComplete()
|
||||
}
|
||||
}
|
||||
}
|
||||
"render frames from messages" - {
|
||||
"for binary messages" - {
|
||||
"for a short strict message" in new ServerTestSetup {
|
||||
val data = ByteString("abcdef", "ASCII")
|
||||
val msg = BinaryMessage.Strict(data)
|
||||
netOutSub.request(5)
|
||||
pushMessage(msg)
|
||||
|
||||
expectFrameOnNetwork(Opcode.Binary, data, fin = true)
|
||||
}
|
||||
"for a strict message larger than configured maximum frame size" in pending
|
||||
"for a streamed message" in new ServerTestSetup {
|
||||
val data = ByteString("abcdefg", "ASCII")
|
||||
val pub = StreamTestKit.PublisherProbe[ByteString]
|
||||
val msg = BinaryMessage.Streamed(Source(pub))
|
||||
netOutSub.request(6)
|
||||
pushMessage(msg)
|
||||
val sub = pub.expectSubscription()
|
||||
|
||||
expectFrameHeaderOnNetwork(Opcode.Binary, 0, fin = false)
|
||||
|
||||
val data1 = data.take(3)
|
||||
val data2 = data.drop(3)
|
||||
|
||||
sub.sendNext(data1)
|
||||
expectFrameOnNetwork(Opcode.Continuation, data1, fin = false)
|
||||
|
||||
sub.sendNext(data2)
|
||||
expectFrameOnNetwork(Opcode.Continuation, data2, fin = false)
|
||||
|
||||
sub.sendComplete()
|
||||
expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true)
|
||||
}
|
||||
"for a streamed message with a chunk being larger than configured maximum frame size" in pending
|
||||
"and mask input on the client side" in new ClientTestSetup {
|
||||
val data = ByteString("abcdefg", "ASCII")
|
||||
val pub = StreamTestKit.PublisherProbe[ByteString]
|
||||
val msg = BinaryMessage.Streamed(Source(pub))
|
||||
netOutSub.request(7)
|
||||
pushMessage(msg)
|
||||
val sub = pub.expectSubscription()
|
||||
|
||||
expectFrameHeaderOnNetwork(Opcode.Binary, 0, fin = false)
|
||||
|
||||
val data1 = data.take(3)
|
||||
val data2 = data.drop(3)
|
||||
|
||||
sub.sendNext(data1)
|
||||
expectMaskedFrameOnNetwork(Opcode.Continuation, data1, fin = false)
|
||||
|
||||
sub.sendNext(data2)
|
||||
expectMaskedFrameOnNetwork(Opcode.Continuation, data2, fin = false)
|
||||
|
||||
sub.sendComplete()
|
||||
expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true)
|
||||
}
|
||||
}
|
||||
"for text messages" - {
|
||||
"for a short strict message" in new ServerTestSetup {
|
||||
val text = "äbcdef"
|
||||
val msg = TextMessage.Strict(text)
|
||||
netOutSub.request(5)
|
||||
pushMessage(msg)
|
||||
|
||||
expectFrameOnNetwork(Opcode.Text, ByteString(text, "UTF-8"), fin = true)
|
||||
}
|
||||
"for a strict message larger than configured maximum frame size" in pending
|
||||
"for a streamed message" in new ServerTestSetup {
|
||||
val text = "äbcd€fg"
|
||||
val pub = StreamTestKit.PublisherProbe[String]
|
||||
val msg = TextMessage.Streamed(Source(pub))
|
||||
netOutSub.request(6)
|
||||
pushMessage(msg)
|
||||
val sub = pub.expectSubscription()
|
||||
|
||||
expectFrameHeaderOnNetwork(Opcode.Text, 0, fin = false)
|
||||
|
||||
val text1 = text.take(3)
|
||||
val text1Bytes = ByteString(text1, "UTF-8")
|
||||
val text2 = text.drop(3)
|
||||
val text2Bytes = ByteString(text2, "UTF-8")
|
||||
|
||||
sub.sendNext(text1)
|
||||
expectFrameOnNetwork(Opcode.Continuation, text1Bytes, fin = false)
|
||||
|
||||
sub.sendNext(text2)
|
||||
expectFrameOnNetwork(Opcode.Continuation, text2Bytes, fin = false)
|
||||
|
||||
sub.sendComplete()
|
||||
expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true)
|
||||
}
|
||||
"for a streamed message don't convert half surrogate pairs naively" in new ServerTestSetup {
|
||||
val gclef = "𝄞"
|
||||
gclef.size shouldEqual 2
|
||||
|
||||
// split up the code point
|
||||
val half1 = gclef.take(1)
|
||||
val half2 = gclef.drop(1)
|
||||
println(half1(0).toInt.toHexString)
|
||||
println(half2(0).toInt.toHexString)
|
||||
|
||||
val pub = StreamTestKit.PublisherProbe[String]
|
||||
val msg = TextMessage.Streamed(Source(pub))
|
||||
netOutSub.request(6)
|
||||
|
||||
pushMessage(msg)
|
||||
val sub = pub.expectSubscription()
|
||||
|
||||
expectFrameHeaderOnNetwork(Opcode.Text, 0, fin = false)
|
||||
sub.sendNext(half1)
|
||||
|
||||
expectNoNetworkData()
|
||||
sub.sendNext(half2)
|
||||
expectFrameOnNetwork(Opcode.Continuation, ByteString(gclef, "utf8"), fin = false)
|
||||
}
|
||||
"for a streamed message with a chunk being larger than configured maximum frame size" in pending
|
||||
"and mask input on the client side" in new ClientTestSetup {
|
||||
val text = "abcdefg"
|
||||
val pub = StreamTestKit.PublisherProbe[String]
|
||||
val msg = TextMessage.Streamed(Source(pub))
|
||||
netOutSub.request(5)
|
||||
pushMessage(msg)
|
||||
val sub = pub.expectSubscription()
|
||||
|
||||
expectFrameOnNetwork(Opcode.Text, ByteString.empty, fin = false)
|
||||
|
||||
val text1 = text.take(3)
|
||||
val text1Bytes = ByteString(text1, "UTF-8")
|
||||
val text2 = text.drop(3)
|
||||
val text2Bytes = ByteString(text2, "UTF-8")
|
||||
|
||||
sub.sendNext(text1)
|
||||
expectMaskedFrameOnNetwork(Opcode.Continuation, text1Bytes, fin = false)
|
||||
|
||||
sub.sendNext(text2)
|
||||
expectMaskedFrameOnNetwork(Opcode.Continuation, text2Bytes, fin = false)
|
||||
|
||||
sub.sendComplete()
|
||||
expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true)
|
||||
}
|
||||
}
|
||||
}
|
||||
"supply automatic low-level websocket behavior" - {
|
||||
"respond to ping frames unmasking them on the server side" in new ServerTestSetup {
|
||||
val mask = Random.nextInt()
|
||||
val input = frameHeader(Opcode.Ping, 6, fin = true, mask = Some(mask)) ++ maskedASCII("abcdef", mask)._1
|
||||
|
||||
pushInput(input)
|
||||
netOutSub.request(5)
|
||||
expectFrameOnNetwork(Opcode.Pong, ByteString("abcdef"), fin = true)
|
||||
}
|
||||
"respond to ping frames masking them on the client side" in new ClientTestSetup {
|
||||
val input = frameHeader(Opcode.Ping, 6, fin = true) ++ ByteString("abcdef")
|
||||
|
||||
pushInput(input)
|
||||
netOutSub.request(5)
|
||||
expectMaskedFrameOnNetwork(Opcode.Pong, ByteString("abcdef"), fin = true)
|
||||
}
|
||||
"respond to ping frames interleaved with data frames (without mixing frame data)" in new ServerTestSetup {
|
||||
// receive multi-frame message
|
||||
// receive and handle interleaved ping frame
|
||||
// concurrently send out messages from handler
|
||||
val mask1 = Random.nextInt()
|
||||
val input1 = frameHeader(Opcode.Binary, 3, fin = false, mask = Some(mask1)) ++ maskedASCII("123", mask1)._1
|
||||
pushInput(input1)
|
||||
|
||||
val BinaryMessage.Streamed(dataSource) = expectMessage()
|
||||
val sub = StreamTestKit.SubscriberProbe[ByteString]
|
||||
dataSource.runWith(Sink(sub))
|
||||
val s = sub.expectSubscription()
|
||||
s.request(2)
|
||||
sub.expectNext(ByteString("123", "ASCII"))
|
||||
|
||||
val outPub = StreamTestKit.PublisherProbe[ByteString]
|
||||
val msg = BinaryMessage.Streamed(Source(outPub))
|
||||
netOutSub.request(10)
|
||||
pushMessage(msg)
|
||||
|
||||
expectFrameHeaderOnNetwork(Opcode.Binary, 0, fin = false)
|
||||
|
||||
val outSub = outPub.expectSubscription()
|
||||
val outData1 = ByteString("abc", "ASCII")
|
||||
outSub.sendNext(outData1)
|
||||
expectFrameOnNetwork(Opcode.Continuation, outData1, fin = false)
|
||||
|
||||
val pingMask = Random.nextInt()
|
||||
val pingData = maskedASCII("pling", pingMask)._1
|
||||
val pingData0 = pingData.take(3)
|
||||
val pingData1 = pingData.drop(3)
|
||||
pushInput(frameHeader(Opcode.Ping, 5, fin = true, mask = Some(pingMask)) ++ pingData0)
|
||||
expectNoNetworkData()
|
||||
pushInput(pingData1)
|
||||
expectFrameOnNetwork(Opcode.Pong, ByteString("pling", "ASCII"), fin = true)
|
||||
|
||||
val outData2 = ByteString("def", "ASCII")
|
||||
outSub.sendNext(outData2)
|
||||
expectFrameOnNetwork(Opcode.Continuation, outData2, fin = false)
|
||||
|
||||
outSub.sendComplete()
|
||||
expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true)
|
||||
|
||||
val mask2 = Random.nextInt()
|
||||
val input2 = frameHeader(Opcode.Continuation, 3, fin = true, mask = Some(mask2)) ++ maskedASCII("456", mask2)._1
|
||||
pushInput(input2)
|
||||
sub.expectNext(ByteString("456", "ASCII"))
|
||||
sub.expectComplete()
|
||||
}
|
||||
"don't respond to unsolicited pong frames" in new ClientTestSetup {
|
||||
val data = frameHeader(Opcode.Pong, 6, fin = true) ++ ByteString("abcdef")
|
||||
pushInput(data)
|
||||
netOutSub.request(5)
|
||||
expectNoNetworkData()
|
||||
}
|
||||
}
|
||||
"provide close behavior" - {
|
||||
"after receiving regular close frame when idle (user closes immediately)" in new ServerTestSetup {
|
||||
netInSub.expectRequest()
|
||||
netOutSub.request(20)
|
||||
messageOutSub.request(20)
|
||||
|
||||
pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true))
|
||||
messageIn.expectComplete()
|
||||
|
||||
netIn.expectNoMsg(1.second) // especially the cancellation not yet
|
||||
expectNoNetworkData()
|
||||
messageOutSub.sendComplete()
|
||||
|
||||
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
|
||||
netOut.expectComplete()
|
||||
netInSub.expectCancellation()
|
||||
}
|
||||
"after receiving close frame without close code" in new ServerTestSetup {
|
||||
netInSub.expectRequest()
|
||||
pushInput(frameHeader(Opcode.Close, 0, fin = true))
|
||||
messageIn.expectComplete()
|
||||
|
||||
messageOutSub.sendComplete()
|
||||
// especially mustn't be Procotol.CloseCodes.NoCodePresent
|
||||
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
|
||||
netOut.expectComplete()
|
||||
netInSub.expectCancellation()
|
||||
}
|
||||
"after receiving regular close frame when idle (user still sends some data)" in new ServerTestSetup {
|
||||
netOutSub.request(20)
|
||||
messageOutSub.request(20)
|
||||
|
||||
pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true))
|
||||
messageIn.expectComplete()
|
||||
|
||||
// sending another message is allowed before closing (inherently racy)
|
||||
val pub = StreamTestKit.PublisherProbe[ByteString]
|
||||
val msg = BinaryMessage.Streamed(Source(pub))
|
||||
pushMessage(msg)
|
||||
expectFrameOnNetwork(Opcode.Binary, ByteString.empty, fin = false)
|
||||
|
||||
val data = ByteString("abc", "ASCII")
|
||||
val dataSub = pub.expectSubscription()
|
||||
dataSub.sendNext(data)
|
||||
expectFrameOnNetwork(Opcode.Continuation, data, fin = false)
|
||||
|
||||
dataSub.sendComplete()
|
||||
expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true)
|
||||
|
||||
messageOutSub.sendComplete()
|
||||
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
|
||||
netOut.expectComplete()
|
||||
}
|
||||
"after receiving regular close frame when fragmented message is still open" in pendingUntilFixed {
|
||||
new ServerTestSetup {
|
||||
netOutSub.request(10)
|
||||
messageInSub.request(10)
|
||||
|
||||
pushInput(frameHeader(Protocol.Opcode.Binary, 0, fin = false))
|
||||
val BinaryMessage.Streamed(dataSource) = messageIn.expectNext()
|
||||
val inSubscriber = StreamTestKit.SubscriberProbe[ByteString]
|
||||
dataSource.runWith(Sink(inSubscriber))
|
||||
val inSub = inSubscriber.expectSubscription()
|
||||
|
||||
val outData = ByteString("def", "ASCII")
|
||||
val mask = Random.nextInt()
|
||||
pushInput(frameHeader(Protocol.Opcode.Continuation, 3, fin = false, mask = Some(mask)) ++ maskedBytes(outData, mask)._1)
|
||||
inSub.request(5)
|
||||
inSubscriber.expectNext(outData)
|
||||
|
||||
pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true))
|
||||
messageIn.expectComplete()
|
||||
inSubscriber.expectError()
|
||||
// truncation of open message
|
||||
|
||||
// sending another message is allowed before closing (inherently racy)
|
||||
|
||||
val pub = StreamTestKit.PublisherProbe[ByteString]
|
||||
val msg = BinaryMessage.Streamed(Source(pub))
|
||||
pushMessage(msg)
|
||||
expectFrameOnNetwork(Opcode.Binary, ByteString.empty, fin = false)
|
||||
|
||||
val data = ByteString("abc", "ASCII")
|
||||
val dataSub = pub.expectSubscription()
|
||||
dataSub.sendNext(data)
|
||||
expectFrameOnNetwork(Opcode.Continuation, data, fin = false)
|
||||
|
||||
dataSub.sendComplete()
|
||||
expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true)
|
||||
|
||||
messageOutSub.sendComplete()
|
||||
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
|
||||
netOut.expectComplete()
|
||||
}
|
||||
}
|
||||
"after receiving error close frame" in pending
|
||||
"after peer closes connection without sending a close frame" in new ServerTestSetup {
|
||||
netInSub.expectRequest()
|
||||
netInSub.sendComplete()
|
||||
|
||||
messageIn.expectComplete()
|
||||
messageOutSub.sendComplete()
|
||||
|
||||
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
|
||||
netOut.expectComplete()
|
||||
}
|
||||
"when user handler closes (simple)" in new ServerTestSetup {
|
||||
messageOutSub.sendComplete()
|
||||
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
|
||||
|
||||
netOut.expectNoMsg(1.second) // wait for peer to close regularly
|
||||
pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true))
|
||||
|
||||
messageIn.expectComplete()
|
||||
netOut.expectComplete()
|
||||
netInSub.expectCancellation()
|
||||
}
|
||||
"when user handler closes main stream and substream only afterwards" in new ServerTestSetup {
|
||||
netOutSub.request(10)
|
||||
messageInSub.request(10)
|
||||
|
||||
// send half a message
|
||||
val pub = StreamTestKit.PublisherProbe[ByteString]
|
||||
val msg = BinaryMessage.Streamed(Source(pub))
|
||||
pushMessage(msg)
|
||||
expectFrameOnNetwork(Opcode.Binary, ByteString.empty, fin = false)
|
||||
|
||||
val data = ByteString("abc", "ASCII")
|
||||
val dataSub = pub.expectSubscription()
|
||||
dataSub.sendNext(data)
|
||||
expectFrameOnNetwork(Opcode.Continuation, data, fin = false)
|
||||
|
||||
messageOutSub.sendComplete()
|
||||
expectNoNetworkData() // need to wait for substream to close
|
||||
|
||||
dataSub.sendComplete()
|
||||
expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true)
|
||||
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
|
||||
netOut.expectNoMsg(1.second) // wait for peer to close regularly
|
||||
|
||||
val mask = Random.nextInt()
|
||||
pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true))
|
||||
|
||||
messageIn.expectComplete()
|
||||
netOut.expectComplete()
|
||||
netInSub.expectCancellation()
|
||||
}
|
||||
"if user handler fails" in pending
|
||||
"if peer closes with invalid close frame" - {
|
||||
"close code outside of the valid range" in new ServerTestSetup {
|
||||
netInSub.expectRequest()
|
||||
pushInput(frameHeader(Opcode.Close, 1, mask = Some(Random.nextInt()), fin = true) ++ ByteString("x"))
|
||||
|
||||
val error = messageIn.expectError()
|
||||
|
||||
expectCloseCodeOnNetwork(Protocol.CloseCodes.ProtocolError)
|
||||
netOut.expectComplete()
|
||||
netInSub.expectCancellation()
|
||||
}
|
||||
"close data of size 1" in new ServerTestSetup {
|
||||
netInSub.expectRequest()
|
||||
pushInput(frameHeader(Opcode.Close, 1, mask = Some(Random.nextInt()), fin = true) ++ ByteString("x"))
|
||||
|
||||
val error = messageIn.expectError()
|
||||
|
||||
expectCloseCodeOnNetwork(Protocol.CloseCodes.ProtocolError)
|
||||
netOut.expectComplete()
|
||||
netInSub.expectCancellation()
|
||||
}
|
||||
"reason is no valid utf8 data" in pending
|
||||
}
|
||||
"timeout if user handler closes and peer doesn't send a close frame" in new ServerTestSetup {
|
||||
netInSub.expectRequest()
|
||||
messageOutSub.sendComplete()
|
||||
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
|
||||
|
||||
netOut.expectComplete()
|
||||
netInSub.expectCancellation()
|
||||
}
|
||||
"timeout after we close after error and peer doesn't send a close frame" in new ServerTestSetup {
|
||||
netInSub.expectRequest()
|
||||
|
||||
pushInput(frameHeader(Opcode.Binary, 0, fin = true, rsv1 = true))
|
||||
expectProtocolErrorOnNetwork()
|
||||
messageOutSub.sendComplete()
|
||||
|
||||
netOut.expectComplete()
|
||||
netInSub.expectCancellation()
|
||||
}
|
||||
"ignore frames peer sends after close frame" in new ServerTestSetup {
|
||||
netInSub.expectRequest()
|
||||
pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true))
|
||||
|
||||
messageIn.expectComplete()
|
||||
|
||||
pushInput(frameHeader(Opcode.Binary, 0, fin = true))
|
||||
messageOutSub.sendComplete()
|
||||
expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular)
|
||||
|
||||
netOut.expectComplete()
|
||||
netInSub.expectCancellation()
|
||||
}
|
||||
}
|
||||
"reject unexpected frames" - {
|
||||
"reserved bits set" - {
|
||||
"rsv1" in new ServerTestSetup {
|
||||
pushInput(frameHeader(Opcode.Binary, 0, fin = true, rsv1 = true))
|
||||
expectProtocolErrorOnNetwork()
|
||||
}
|
||||
"rsv2" in new ServerTestSetup {
|
||||
pushInput(frameHeader(Opcode.Binary, 0, fin = true, rsv2 = true))
|
||||
expectProtocolErrorOnNetwork()
|
||||
}
|
||||
"rsv3" in new ServerTestSetup {
|
||||
pushInput(frameHeader(Opcode.Binary, 0, fin = true, rsv3 = true))
|
||||
expectProtocolErrorOnNetwork()
|
||||
}
|
||||
}
|
||||
"highest bit of 64-bit length is set" in new ServerTestSetup {
|
||||
import BitBuilder._
|
||||
|
||||
val header =
|
||||
b"""0000 # flags
|
||||
xxxx=1 # opcode
|
||||
1 # mask?
|
||||
xxxxxxx=7f # length
|
||||
xxxxxxxx
|
||||
xxxxxxxx
|
||||
xxxxxxxx
|
||||
xxxxxxxx=ffffffff
|
||||
xxxxxxxx
|
||||
xxxxxxxx
|
||||
xxxxxxxx
|
||||
xxxxxxxx=ffffffff # length64
|
||||
00000000
|
||||
00000000
|
||||
00000000
|
||||
00000000 # empty mask
|
||||
"""
|
||||
|
||||
pushInput(header)
|
||||
expectProtocolErrorOnNetwork()
|
||||
}
|
||||
"control frame bigger than 125 bytes" in new ServerTestSetup {
|
||||
pushInput(frameHeader(Opcode.Ping, 126, fin = true, mask = Some(0)))
|
||||
expectProtocolErrorOnNetwork()
|
||||
}
|
||||
"fragmented control frame" in new ServerTestSetup {
|
||||
pushInput(frameHeader(Opcode.Ping, 0, fin = false, mask = Some(0)))
|
||||
expectProtocolErrorOnNetwork()
|
||||
}
|
||||
"unexpected continuation frame" in new ServerTestSetup {
|
||||
pushInput(frameHeader(Opcode.Continuation, 0, fin = false, mask = Some(0)))
|
||||
expectProtocolErrorOnNetwork()
|
||||
}
|
||||
"unexpected data frame when waiting for continuation" in new ServerTestSetup {
|
||||
pushInput(frameHeader(Opcode.Binary, 0, fin = false) ++
|
||||
frameHeader(Opcode.Binary, 0, fin = false))
|
||||
expectProtocolErrorOnNetwork()
|
||||
}
|
||||
"invalid utf8 encoding for single frame message" in new ClientTestSetup {
|
||||
val data = ByteString(
|
||||
(128 + 64).toByte, // start two byte sequence
|
||||
0 // but don't finish it
|
||||
)
|
||||
|
||||
pushInput(frameHeader(Opcode.Text, 2, fin = true) ++ data)
|
||||
expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData)
|
||||
}
|
||||
"invalid utf8 encoding for streamed frame" in new ClientTestSetup {
|
||||
val data = ByteString(
|
||||
(128 + 64).toByte, // start two byte sequence
|
||||
0 // but don't finish it
|
||||
)
|
||||
|
||||
pushInput(frameHeader(Opcode.Text, 0, fin = false) ++
|
||||
frameHeader(Opcode.Continuation, 2, fin = true) ++
|
||||
data)
|
||||
expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData)
|
||||
}
|
||||
"truncated utf8 encoding for single frame message" in new ClientTestSetup {
|
||||
val data = ByteString("€", "UTF-8").take(1) // half a euro
|
||||
pushInput(frameHeader(Opcode.Text, 1, fin = true) ++ data)
|
||||
expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData)
|
||||
}
|
||||
"truncated utf8 encoding for streamed frame" in new ClientTestSetup {
|
||||
val data = ByteString("€", "UTF-8").take(1) // half a euro
|
||||
pushInput(frameHeader(Opcode.Text, 0, fin = false) ++
|
||||
frameHeader(Opcode.Continuation, 1, fin = true) ++
|
||||
data)
|
||||
expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData)
|
||||
}
|
||||
"half a surrogate pair in utf8 encoding for a strict frame" in new ClientTestSetup {
|
||||
val data = ByteString(0xed, 0xa0, 0x80) // not strictly supported by utf-8
|
||||
pushInput(frameHeader(Opcode.Text, 3, fin = true) ++ data)
|
||||
expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData)
|
||||
}
|
||||
"half a surrogate pair in utf8 encoding for a streamed frame" in new ClientTestSetup {
|
||||
val data = ByteString(0xed, 0xa0, 0x80) // not strictly supported by utf-8
|
||||
pushInput(frameHeader(Opcode.Text, 0, fin = false))
|
||||
pushInput(frameHeader(Opcode.Continuation, 3, fin = true) ++ data)
|
||||
|
||||
messageIn.expectError()
|
||||
|
||||
expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData)
|
||||
}
|
||||
"unmasked input on the server side" in new ServerTestSetup {
|
||||
val data = ByteString("abcdef", "ASCII")
|
||||
val input = frameHeader(Opcode.Binary, 6, fin = true) ++ data
|
||||
|
||||
pushInput(input)
|
||||
expectProtocolErrorOnNetwork()
|
||||
}
|
||||
"masked input on the client side" in new ClientTestSetup {
|
||||
val mask = Random.nextInt()
|
||||
val input = frameHeader(Opcode.Binary, 6, fin = true, mask = Some(mask)) ++ maskedASCII("abcdef", mask)._1
|
||||
|
||||
pushInput(input)
|
||||
expectProtocolErrorOnNetwork()
|
||||
}
|
||||
}
|
||||
"support per-message-compression extension" in pending
|
||||
}
|
||||
|
||||
class ServerTestSetup extends TestSetup {
|
||||
protected def serverSide: Boolean = true
|
||||
}
|
||||
class ClientTestSetup extends TestSetup {
|
||||
protected def serverSide: Boolean = false
|
||||
}
|
||||
abstract class TestSetup {
|
||||
protected def serverSide: Boolean
|
||||
protected def closeTimeout: FiniteDuration = 1.second
|
||||
|
||||
val netIn = StreamTestKit.PublisherProbe[ByteString]
|
||||
val netOut = StreamTestKit.SubscriberProbe[ByteString]
|
||||
|
||||
val messageIn = StreamTestKit.SubscriberProbe[Message]
|
||||
val messageOut = StreamTestKit.PublisherProbe[Message]
|
||||
|
||||
val messageHandler: Flow[Message, Message, Unit] =
|
||||
Flow.wrap {
|
||||
FlowGraph.partial() { implicit b ⇒
|
||||
val in = b.add(Sink(messageIn))
|
||||
val out = b.add(Source(messageOut))
|
||||
|
||||
FlowShape[Message, Message](in, out)
|
||||
}
|
||||
}
|
||||
|
||||
Source(netIn)
|
||||
.via(printEvent("netIn"))
|
||||
.transform(() ⇒ new FrameEventParser)
|
||||
.via(Websocket.handleMessages(messageHandler, serverSide, closeTimeout = closeTimeout))
|
||||
.via(printEvent("frameRendererIn"))
|
||||
.transform(() ⇒ new FrameEventRenderer)
|
||||
.via(printEvent("frameRendererOut"))
|
||||
.to(Sink(netOut))
|
||||
.run()
|
||||
|
||||
val netInSub = netIn.expectSubscription()
|
||||
val netOutSub = netOut.expectSubscription()
|
||||
val messageOutSub = messageOut.expectSubscription()
|
||||
val messageInSub = messageIn.expectSubscription()
|
||||
|
||||
def pushInput(data: ByteString): Unit = {
|
||||
// TODO: expect/handle request?
|
||||
netInSub.sendNext(data)
|
||||
}
|
||||
def pushMessage(msg: Message): Unit = {
|
||||
messageOutSub.sendNext(msg)
|
||||
}
|
||||
|
||||
def expectMessage(message: Message): Unit = {
|
||||
messageInSub.request(1)
|
||||
messageIn.expectNext(message)
|
||||
}
|
||||
def expectMessage(): Message = {
|
||||
messageInSub.request(1)
|
||||
messageIn.expectNext()
|
||||
}
|
||||
|
||||
var inBuffer = ByteString.empty
|
||||
@tailrec final def expectNetworkData(bytes: Int): ByteString =
|
||||
if (inBuffer.size >= bytes) {
|
||||
val res = inBuffer.take(bytes)
|
||||
inBuffer = inBuffer.drop(bytes)
|
||||
res
|
||||
} else {
|
||||
netOutSub.request(1)
|
||||
inBuffer ++= netOut.expectNext()
|
||||
expectNetworkData(bytes)
|
||||
}
|
||||
|
||||
def expectNetworkData(data: ByteString): Unit =
|
||||
expectNetworkData(data.size) shouldEqual data
|
||||
|
||||
def expectFrameOnNetwork(opcode: Opcode, data: ByteString, fin: Boolean): Unit = {
|
||||
expectFrameHeaderOnNetwork(opcode, data.size, fin)
|
||||
expectNetworkData(data)
|
||||
}
|
||||
def expectMaskedFrameOnNetwork(opcode: Opcode, data: ByteString, fin: Boolean): Unit = {
|
||||
val Some(mask) = expectFrameHeaderOnNetwork(opcode, data.size, fin)
|
||||
val masked = maskedBytes(data, mask)._1
|
||||
expectNetworkData(masked)
|
||||
}
|
||||
|
||||
/** Returns the mask if any is available */
|
||||
def expectFrameHeaderOnNetwork(opcode: Opcode, length: Long, fin: Boolean): Option[Int] = {
|
||||
val (op, l, f, m) = expectFrameHeaderOnNetwork()
|
||||
op shouldEqual opcode
|
||||
l shouldEqual length
|
||||
f shouldEqual fin
|
||||
m
|
||||
}
|
||||
def expectFrameHeaderOnNetwork(): (Opcode, Long, Boolean, Option[Int]) = {
|
||||
val header = expectNetworkData(2)
|
||||
|
||||
val fin = (header(0) & Protocol.FIN_MASK) != 0
|
||||
val op = header(0) & Protocol.OP_MASK
|
||||
|
||||
val hasMask = (header(1) & Protocol.MASK_MASK) != 0
|
||||
val length7 = header(1) & Protocol.LENGTH_MASK
|
||||
val length = length7 match {
|
||||
case 126 ⇒
|
||||
val length16Bytes = expectNetworkData(2)
|
||||
(length16Bytes(0) & 0xff) << 8 | (length16Bytes(1) & 0xff) << 0
|
||||
case 127 ⇒
|
||||
val length64Bytes = expectNetworkData(8)
|
||||
(length64Bytes(0) & 0xff).toLong << 56 |
|
||||
(length64Bytes(1) & 0xff).toLong << 48 |
|
||||
(length64Bytes(2) & 0xff).toLong << 40 |
|
||||
(length64Bytes(3) & 0xff).toLong << 32 |
|
||||
(length64Bytes(4) & 0xff).toLong << 24 |
|
||||
(length64Bytes(5) & 0xff).toLong << 16 |
|
||||
(length64Bytes(6) & 0xff).toLong << 8 |
|
||||
(length64Bytes(7) & 0xff).toLong << 0
|
||||
case x ⇒ x
|
||||
}
|
||||
val mask =
|
||||
if (hasMask) {
|
||||
val maskBytes = expectNetworkData(4)
|
||||
val mask =
|
||||
(maskBytes(0) & 0xff) << 24 |
|
||||
(maskBytes(1) & 0xff) << 16 |
|
||||
(maskBytes(2) & 0xff) << 8 |
|
||||
(maskBytes(3) & 0xff) << 0
|
||||
Some(mask)
|
||||
} else None
|
||||
|
||||
(Opcode.forCode(op.toByte), length, fin, mask)
|
||||
}
|
||||
|
||||
def expectProtocolErrorOnNetwork(): Unit = expectCloseCodeOnNetwork(Protocol.CloseCodes.ProtocolError)
|
||||
def expectCloseCodeOnNetwork(expectedCode: Int): Unit = {
|
||||
val (opcode, length, true, mask) = expectFrameHeaderOnNetwork()
|
||||
opcode shouldEqual Opcode.Close
|
||||
length should be >= 2.toLong
|
||||
|
||||
val rawData = expectNetworkData(length.toInt)
|
||||
val data = mask match {
|
||||
case Some(m) ⇒ FrameEventParser.mask(rawData, m)._1
|
||||
case None ⇒ rawData
|
||||
}
|
||||
|
||||
val code = ((data(0) & 0xff) << 8) | ((data(1) & 0xff) << 0)
|
||||
code shouldEqual expectedCode
|
||||
}
|
||||
|
||||
def expectNoNetworkData(): Unit =
|
||||
netOut.expectNoMsg(100.millis)
|
||||
}
|
||||
|
||||
def frameHeader(
|
||||
opcode: Opcode,
|
||||
length: Long,
|
||||
fin: Boolean,
|
||||
mask: Option[Int] = None,
|
||||
rsv1: Boolean = false,
|
||||
rsv2: Boolean = false,
|
||||
rsv3: Boolean = false): ByteString = {
|
||||
def set(should: Boolean, mask: Int): Int =
|
||||
if (should) mask else 0
|
||||
|
||||
val flags =
|
||||
set(fin, Protocol.FIN_MASK) |
|
||||
set(rsv1, Protocol.RSV1_MASK) |
|
||||
set(rsv2, Protocol.RSV2_MASK) |
|
||||
set(rsv3, Protocol.RSV3_MASK)
|
||||
|
||||
val opcodeByte = opcode.code | flags
|
||||
|
||||
require(length >= 0)
|
||||
val (lengthByteComponent, lengthBytes) =
|
||||
if (length < 126) (length.toByte, ByteString.empty)
|
||||
else if (length < 65536) (126.toByte, shortBE(length.toInt))
|
||||
else throw new IllegalArgumentException("Only lengths < 65536 allowed in test")
|
||||
|
||||
val maskMask = if (mask.isDefined) Protocol.MASK_MASK else 0
|
||||
val maskBytes = mask match {
|
||||
case Some(mask) ⇒ intBE(mask)
|
||||
case None ⇒ ByteString.empty
|
||||
}
|
||||
val lengthByte = lengthByteComponent | maskMask
|
||||
ByteString(opcodeByte.toByte, lengthByte.toByte) ++ lengthBytes ++ maskBytes
|
||||
}
|
||||
def closeFrame(closeCode: Int, mask: Boolean): ByteString =
|
||||
if (mask) {
|
||||
val mask = Random.nextInt()
|
||||
frameHeader(Opcode.Close, 2, fin = true, mask = Some(mask)) ++
|
||||
maskedBytes(shortBE(closeCode), mask)._1
|
||||
} else
|
||||
frameHeader(Opcode.Close, 2, fin = true) ++
|
||||
shortBE(closeCode)
|
||||
|
||||
def maskedASCII(str: String, mask: Int): (ByteString, Int) =
|
||||
FrameEventParser.mask(ByteString(str, "ASCII"), mask)
|
||||
def maskedUTF8(str: String, mask: Int): (ByteString, Int) =
|
||||
FrameEventParser.mask(ByteString(str, "UTF-8"), mask)
|
||||
def maskedBytes(bytes: ByteString, mask: Int): (ByteString, Int) =
|
||||
FrameEventParser.mask(bytes, mask)
|
||||
|
||||
def shortBE(value: Int): ByteString = {
|
||||
require(value >= 0 && value < 65536, s"Value wasn't in short range: $value")
|
||||
ByteString(
|
||||
((value >> 8) & 0xff).toByte,
|
||||
((value >> 0) & 0xff).toByte)
|
||||
}
|
||||
def intBE(value: Int): ByteString =
|
||||
ByteString(
|
||||
((value >> 24) & 0xff).toByte,
|
||||
((value >> 16) & 0xff).toByte,
|
||||
((value >> 8) & 0xff).toByte,
|
||||
((value >> 0) & 0xff).toByte)
|
||||
|
||||
val trace = false // set to `true` for debugging purposes
|
||||
def printEvent[T](marker: String): Flow[T, T, Unit] =
|
||||
if (trace) akka.http.util.printEvent(marker)
|
||||
else Flow[T]
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue