From 6dafa445deb7d4645cf65c80a1b35bf871f764ed Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Tue, 21 Apr 2015 15:21:22 +0200 Subject: [PATCH] =htc #16887 implement high-level server-side Websocket API --- .../akka/http/engine/ws/FrameHandler.scala | 200 ++++ .../akka/http/engine/ws/FrameOutHandler.scala | 132 +++ .../scala/akka/http/engine/ws/Masking.scala | 71 ++ .../engine/ws/MessageToFrameRenderer.scala | 37 + .../scala/akka/http/engine/ws/Websocket.scala | 179 ++++ .../akka/http/engine/ws/MessageSpec.scala | 966 ++++++++++++++++++ 6 files changed, 1585 insertions(+) create mode 100644 akka-http-core/src/main/scala/akka/http/engine/ws/FrameHandler.scala create mode 100644 akka-http-core/src/main/scala/akka/http/engine/ws/FrameOutHandler.scala create mode 100644 akka-http-core/src/main/scala/akka/http/engine/ws/Masking.scala create mode 100644 akka-http-core/src/main/scala/akka/http/engine/ws/MessageToFrameRenderer.scala create mode 100644 akka-http-core/src/main/scala/akka/http/engine/ws/Websocket.scala create mode 100644 akka-http-core/src/test/scala/akka/http/engine/ws/MessageSpec.scala diff --git a/akka-http-core/src/main/scala/akka/http/engine/ws/FrameHandler.scala b/akka-http-core/src/main/scala/akka/http/engine/ws/FrameHandler.scala new file mode 100644 index 0000000000..bde765f3d7 --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/engine/ws/FrameHandler.scala @@ -0,0 +1,200 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.engine.ws + +import akka.stream.scaladsl.Flow +import akka.stream.stage.{ TerminationDirective, SyncDirective, Context, StatefulStage } +import akka.util.ByteString +import Protocol.Opcode + +import scala.util.control.NonFatal + +/** + * The frame handler validates frames, multiplexes data to the user handler or to the bypass and + * UTF-8 decodes text frames. + * + * INTERNAL API + */ +private[http] object FrameHandler { + def create(server: Boolean): Flow[FrameEvent, Either[BypassEvent, MessagePart], Unit] = + Flow[FrameEvent].transform(() ⇒ new HandlerStage(server)) + + class HandlerStage(server: Boolean) extends StatefulStage[FrameEvent, Either[BypassEvent, MessagePart]] { + type Ctx = Context[Either[BypassEvent, MessagePart]] + def initial: State = Idle + + object Idle extends StateWithControlFrameHandling { + def handleRegularFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective = + (start.header.opcode, start.isFullMessage) match { + case (Opcode.Binary, true) ⇒ publishMessagePart(BinaryMessagePart(start.data, last = true)) + case (Opcode.Binary, false) ⇒ becomeAndHandleWith(new CollectingBinaryMessage, start) + case (Opcode.Text, _) ⇒ becomeAndHandleWith(new CollectingTextMessage, start) + case x ⇒ protocolError() + } + } + + class CollectingBinaryMessage extends CollectingMessageFrame(Opcode.Binary) { + def createMessagePart(data: ByteString, last: Boolean): MessageDataPart = BinaryMessagePart(data, last) + } + class CollectingTextMessage extends CollectingMessageFrame(Opcode.Text) { + val decoder = Utf8Decoder.create() + + def createMessagePart(data: ByteString, last: Boolean): MessageDataPart = + TextMessagePart(decoder.decode(data, endOfInput = last).get, last) + } + + abstract class CollectingMessageFrame(expectedOpcode: Opcode) extends StateWithControlFrameHandling { + var expectFirstHeader = true + var finSeen = false + def createMessagePart(data: ByteString, last: Boolean): MessageDataPart + + def handleRegularFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective = { + if ((expectFirstHeader && start.header.opcode == expectedOpcode) // first opcode must be the expected + || start.header.opcode == Opcode.Continuation) { // further ones continuations + expectFirstHeader = false + + if (start.header.fin) finSeen = true + publish(start) + } else protocolError() + } + override def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective = publish(data) + + private def publish(part: FrameEvent)(implicit ctx: Ctx): SyncDirective = + try publishMessagePart(createMessagePart(part.data, last = finSeen && part.lastPart)) + catch { + case NonFatal(e) ⇒ closeWithCode(Protocol.CloseCodes.InconsistentData) + } + } + class CollectingControlFrame(opcode: Opcode, _data: ByteString, nextState: State) extends InFrameState { + var data = _data + + def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective = { + this.data ++= data.data + if (data.lastPart) handleControlFrame(opcode, this.data, nextState) + else ctx.pull() + } + } + object Closed extends State { + def onPush(elem: FrameEvent, ctx: Ctx): SyncDirective = + ctx.pull() // ignore + } + + def becomeAndHandleWith(newState: State, part: FrameEvent)(implicit ctx: Ctx): SyncDirective = { + become(newState) + current.onPush(part, ctx) + } + + /** Returns a SyncDirective if it handled the message */ + def validateHeader(header: FrameHeader)(implicit ctx: Ctx): Option[SyncDirective] = header match { + case h: FrameHeader if h.mask.isDefined && !server ⇒ Some(protocolError()) + case h: FrameHeader if h.rsv1 || h.rsv2 || h.rsv3 ⇒ Some(protocolError()) + case FrameHeader(op, _, length, fin, _, _, _) if op.isControl && (length > 125 || !fin) ⇒ Some(protocolError()) + case _ ⇒ None + } + + def handleControlFrame(opcode: Opcode, data: ByteString, nextState: State)(implicit ctx: Ctx): SyncDirective = { + become(nextState) + opcode match { + case Opcode.Ping ⇒ publishDirectResponse(FrameEvent.fullFrame(Opcode.Pong, None, data, fin = true)) + case Opcode.Pong ⇒ + // ignore unsolicited Pong frame + ctx.pull() + case Opcode.Close ⇒ + val closeCode = FrameEventParser.parseCloseCode(data) + emit(Iterator(Left(PeerClosed(closeCode)), Right(PeerClosed(closeCode))), ctx, WaitForPeerTcpClose) + case Opcode.Other(o) ⇒ closeWithCode(Protocol.CloseCodes.ProtocolError, "Unsupported opcode") + } + } + private def collectControlFrame(start: FrameStart, nextState: State)(implicit ctx: Ctx): SyncDirective = { + assert(!start.isFullMessage) + become(new CollectingControlFrame(start.header.opcode, start.data, nextState)) + ctx.pull() + } + + private def publishMessagePart(part: MessageDataPart)(implicit ctx: Ctx): SyncDirective = + if (part.last) emit(Iterator(Right(part), Right(MessageEnd)), ctx, Idle) + else ctx.push(Right(part)) + private def publishDirectResponse(frame: FrameStart)(implicit ctx: Ctx): SyncDirective = + ctx.push(Left(DirectAnswer(frame))) + + private def protocolError(reason: String = "")(implicit ctx: Ctx): SyncDirective = + closeWithCode(Protocol.CloseCodes.ProtocolError, reason) + + private def closeWithCode(closeCode: Int, reason: String = "", cause: Throwable = null)(implicit ctx: Ctx): SyncDirective = + emit( + Iterator( + Left(ActivelyCloseWithCode(Some(closeCode), reason)), + Right(ActivelyCloseWithCode(Some(closeCode), reason))), ctx, CloseAfterPeerClosed) + + object CloseAfterPeerClosed extends State { + def onPush(elem: FrameEvent, ctx: Context[Either[BypassEvent, MessagePart]]): SyncDirective = + elem match { + case FrameStart(FrameHeader(Opcode.Close, _, length, _, _, _, _), data) ⇒ + become(WaitForPeerTcpClose) + ctx.push(Left(PeerClosed(FrameEventParser.parseCloseCode(data)))) + + case _ ⇒ ctx.pull() // ignore all other data + } + } + object WaitForPeerTcpClose extends State { + def onPush(elem: FrameEvent, ctx: Context[Either[BypassEvent, MessagePart]]): SyncDirective = + ctx.pull() // ignore + } + + abstract class StateWithControlFrameHandling extends BetweenFrameState { + def handleRegularFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective + + def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective = + validateHeader(start.header).getOrElse { + if (start.header.opcode.isControl) + if (start.isFullMessage) handleControlFrame(start.header.opcode, start.data, this) + else collectControlFrame(start, this) + else handleRegularFrameStart(start) + } + } + abstract class BetweenFrameState extends ImplicitContextState { + def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective = + throw new IllegalStateException("Expected FrameStart") + } + abstract class InFrameState extends ImplicitContextState { + def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective = + throw new IllegalStateException("Expected FrameData") + } + abstract class ImplicitContextState extends State { + def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective + def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective + + def onPush(part: FrameEvent, ctx: Ctx): SyncDirective = + part match { + case data: FrameData ⇒ handleFrameData(data)(ctx) + case start: FrameStart ⇒ handleFrameStart(start)(ctx) + } + } + } + + sealed trait MessagePart { + def isMessageEnd: Boolean + } + sealed trait MessageDataPart extends MessagePart { + def isMessageEnd = false + def last: Boolean + } + final case class TextMessagePart(data: String, last: Boolean) extends MessageDataPart + final case class BinaryMessagePart(data: ByteString, last: Boolean) extends MessageDataPart + case object MessageEnd extends MessagePart { + def isMessageEnd: Boolean = true + } + final case class PeerClosed(code: Option[Int], reason: String = "") extends MessagePart with BypassEvent { + def isMessageEnd: Boolean = true + } + + sealed trait BypassEvent + final case class DirectAnswer(frame: FrameStart) extends BypassEvent + final case class ActivelyCloseWithCode(code: Option[Int], reason: String = "") extends MessagePart with BypassEvent { + def isMessageEnd: Boolean = true + } + case object UserHandlerCompleted extends BypassEvent + case class UserHandlerErredOut(cause: Throwable) extends BypassEvent +} diff --git a/akka-http-core/src/main/scala/akka/http/engine/ws/FrameOutHandler.scala b/akka-http-core/src/main/scala/akka/http/engine/ws/FrameOutHandler.scala new file mode 100644 index 0000000000..4f560cca33 --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/engine/ws/FrameOutHandler.scala @@ -0,0 +1,132 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.engine.ws + +import scala.concurrent.duration.FiniteDuration + +import akka.stream.stage._ +import akka.http.util.Timestamp +import FrameHandler.{ UserHandlerCompleted, ActivelyCloseWithCode, PeerClosed, DirectAnswer } +import Websocket.Tick + +/** + * Implements the transport connection close handling at the end of the pipeline. + * + * INTERNAL API + */ +private[http] class FrameOutHandler(serverSide: Boolean, _closeTimeout: FiniteDuration) extends StatefulStage[AnyRef, FrameStart] { + def initial: StageState[AnyRef, FrameStart] = Idle + def closeTimeout: Timestamp = Timestamp.now + _closeTimeout + + object Idle extends CompletionHandlingState { + def onPush(elem: AnyRef, ctx: Context[FrameStart]): SyncDirective = elem match { + case start: FrameStart ⇒ ctx.push(start) + case DirectAnswer(frame) ⇒ ctx.push(frame) + case PeerClosed(code, reason) if !code.exists(Protocol.CloseCodes.isError) ⇒ + // let user complete it, FIXME: maybe make configurable? immediately, or timeout + become(new WaitingForUserHandlerClosed(FrameEvent.closeFrame(code.getOrElse(Protocol.CloseCodes.Regular), reason))) + ctx.pull() + case PeerClosed(code, reason) ⇒ + val closeFrame = FrameEvent.closeFrame(code.getOrElse(Protocol.CloseCodes.Regular), reason) + if (serverSide) ctx.pushAndFinish(closeFrame) + else { + become(new WaitingForTransportClose) + ctx.push(closeFrame) + } + case ActivelyCloseWithCode(code, reason) ⇒ + val closeFrame = FrameEvent.closeFrame(code.getOrElse(Protocol.CloseCodes.Regular), reason) + become(new WaitingForPeerCloseFrame()) + ctx.push(closeFrame) + case UserHandlerCompleted ⇒ + become(new WaitingForPeerCloseFrame()) + ctx.push(FrameEvent.closeFrame(Protocol.CloseCodes.Regular)) + case Tick ⇒ ctx.pull() // ignore + } + + def onComplete(ctx: Context[FrameStart]): TerminationDirective = { + become(new SendOutCloseFrameAndComplete(FrameEvent.closeFrame(Protocol.CloseCodes.Regular))) + ctx.absorbTermination() + } + } + + /** + * peer has closed, we want to wait for user handler to close as well + */ + class WaitingForUserHandlerClosed(closeFrame: FrameStart) extends CompletionHandlingState { + def onPush(elem: AnyRef, ctx: Context[FrameStart]): SyncDirective = elem match { + case UserHandlerCompleted ⇒ + if (serverSide) ctx.pushAndFinish(closeFrame) + else { + become(new WaitingForTransportClose()) + ctx.push(closeFrame) + } + case start: FrameStart ⇒ ctx.push(start) + case _ ⇒ ctx.pull() // ignore + } + + def onComplete(ctx: Context[FrameStart]): TerminationDirective = + ctx.fail(new IllegalStateException("Mustn't complete before user has completed")) + } + + /** + * we have sent out close frame and wait for peer to sent its close frame + */ + class WaitingForPeerCloseFrame(timeout: Timestamp = closeTimeout) extends CompletionHandlingState { + def onPush(elem: AnyRef, ctx: Context[FrameStart]): SyncDirective = elem match { + case Tick ⇒ + if (timeout.isPast) ctx.finish() + else ctx.pull() + case PeerClosed(code, reason) ⇒ + if (serverSide) ctx.finish() + else { + become(new WaitingForTransportClose()) + ctx.pull() + } + case _ ⇒ ctx.pull() // ignore + } + + def onComplete(ctx: Context[FrameStart]): TerminationDirective = ctx.finish() + } + + /** + * Both side have sent their close frames, server should close the connection first + */ + class WaitingForTransportClose(timeout: Timestamp = closeTimeout) extends CompletionHandlingState { + def onPush(elem: AnyRef, ctx: Context[FrameStart]): SyncDirective = elem match { + case Tick ⇒ + if (timeout.isPast) ctx.finish() + else ctx.pull() + case _ ⇒ ctx.pull() // ignore + } + + def onComplete(ctx: Context[FrameStart]): TerminationDirective = ctx.finish() + } + + /** If upstream has already failed we just wait to be able to deliver our close frame and complete */ + class SendOutCloseFrameAndComplete(closeFrame: FrameStart) extends CompletionHandlingState { + def onPush(elem: AnyRef, ctx: Context[FrameStart]): SyncDirective = + ctx.fail(new IllegalStateException("Didn't expect push after completion")) + + override def onPull(ctx: Context[FrameStart]): SyncDirective = + ctx.pushAndFinish(closeFrame) + + def onComplete(ctx: Context[FrameStart]): TerminationDirective = + ctx.absorbTermination() + } + + trait CompletionHandlingState extends State { + def onComplete(ctx: Context[FrameStart]): TerminationDirective + } + + override def onUpstreamFinish(ctx: Context[FrameStart]): TerminationDirective = + current.asInstanceOf[CompletionHandlingState].onComplete(ctx) + + override def onUpstreamFailure(cause: scala.Throwable, ctx: Context[FrameStart]): TerminationDirective = cause match { + case p: ProtocolException ⇒ + become(new SendOutCloseFrameAndComplete(FrameEvent.closeFrame(Protocol.CloseCodes.ProtocolError))) + ctx.absorbTermination() + case _ ⇒ super.onUpstreamFailure(cause, ctx) + } +} \ No newline at end of file diff --git a/akka-http-core/src/main/scala/akka/http/engine/ws/Masking.scala b/akka-http-core/src/main/scala/akka/http/engine/ws/Masking.scala new file mode 100644 index 0000000000..dbf7d85338 --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/engine/ws/Masking.scala @@ -0,0 +1,71 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.engine.ws + +import akka.stream.scaladsl.Flow +import akka.stream.stage.{ SyncDirective, Context, StageState, StatefulStage } + +import scala.util.Random + +/** + * Implements Websocket Frame masking. + * + * INTERNAL API + */ +private[http] object Masking { + def maskIf(condition: Boolean, maskRandom: () ⇒ Random): Flow[FrameEvent, FrameEvent, Unit] = + if (condition) Flow[FrameEvent].transform(() ⇒ new Masking(maskRandom())) // new random per materialization + else Flow[FrameEvent] + def unmaskIf(condition: Boolean): Flow[FrameEvent, FrameEvent, Unit] = + if (condition) Flow[FrameEvent].transform(() ⇒ new Unmasking()) + else Flow[FrameEvent] + + class Masking(random: Random) extends Masker { + def extractMask(header: FrameHeader): Int = random.nextInt() + def setNewMask(header: FrameHeader, mask: Int): FrameHeader = { + if (header.mask.isDefined) throw new ProtocolException("Frame mustn't already be masked") + header.copy(mask = Some(mask)) + } + } + class Unmasking extends Masker { + def extractMask(header: FrameHeader): Int = header.mask match { + case Some(mask) ⇒ mask + case None ⇒ throw new ProtocolException("Frame wasn't masked") + } + def setNewMask(header: FrameHeader, mask: Int): FrameHeader = header.copy(mask = None) + } + + /** Implements both masking and unmasking which is mostly symmetric (because of XOR) */ + abstract class Masker extends StatefulStage[FrameEvent, FrameEvent] { + def extractMask(header: FrameHeader): Int + def setNewMask(header: FrameHeader, mask: Int): FrameHeader + + def initial: State = Idle + + object Idle extends State { + def onPush(part: FrameEvent, ctx: Context[FrameEvent]): SyncDirective = + part match { + case start @ FrameStart(header, data) ⇒ + if (header.length == 0) ctx.push(part) + else { + val mask = extractMask(header) + become(new Running(mask)) + current.onPush(start.copy(header = setNewMask(header, mask)), ctx) + } + } + } + class Running(initialMask: Int) extends State { + var mask = initialMask + + def onPush(part: FrameEvent, ctx: Context[FrameEvent]): SyncDirective = { + if (part.lastPart) become(Idle) + + val (masked, newMask) = FrameEventParser.mask(part.data, mask) + mask = newMask + ctx.push(part.withData(data = masked)) + } + } + } +} diff --git a/akka-http-core/src/main/scala/akka/http/engine/ws/MessageToFrameRenderer.scala b/akka-http-core/src/main/scala/akka/http/engine/ws/MessageToFrameRenderer.scala new file mode 100644 index 0000000000..3d3c4c8d6b --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/engine/ws/MessageToFrameRenderer.scala @@ -0,0 +1,37 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.engine.ws + +import akka.util.ByteString +import akka.stream.scaladsl.{ FlattenStrategy, Source, Flow } + +import Protocol.Opcode +import akka.http.model.ws._ + +/** + * Renders messages to full frames. + * + * INTERNAL API + */ +private[http] object MessageToFrameRenderer { + def create(serverSide: Boolean): Flow[Message, FrameStart, Unit] = { + def strictFrames(opcode: Opcode, data: ByteString): Source[FrameStart, _] = + // FIXME: fragment? + Source.single(FrameEvent.fullFrame(opcode, None, data, fin = true)) + + def streamedFrames(opcode: Opcode, data: Source[ByteString, _]): Source[FrameStart, _] = + Source.single(FrameEvent.empty(opcode, fin = false)) ++ + data.map(FrameEvent.fullFrame(Opcode.Continuation, None, _, fin = false)) ++ + Source.single(FrameEvent.emptyLastContinuationFrame) + + Flow[Message] + .map { + case BinaryMessage.Strict(data) ⇒ strictFrames(Opcode.Binary, data) + case BinaryMessage.Streamed(data) ⇒ streamedFrames(Opcode.Binary, data) + case TextMessage.Strict(text) ⇒ strictFrames(Opcode.Text, ByteString(text, "UTF-8")) + case TextMessage.Streamed(text) ⇒ streamedFrames(Opcode.Text, text.transform(() ⇒ new Utf8Encoder)) + }.flatten(FlattenStrategy.Concat()) + } +} diff --git a/akka-http-core/src/main/scala/akka/http/engine/ws/Websocket.scala b/akka-http-core/src/main/scala/akka/http/engine/ws/Websocket.scala new file mode 100644 index 0000000000..e19827d0d1 --- /dev/null +++ b/akka-http-core/src/main/scala/akka/http/engine/ws/Websocket.scala @@ -0,0 +1,179 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.engine.ws + +import java.security.SecureRandom + +import scala.concurrent.duration._ + +import akka.stream.{ OperationAttributes, FanOutShape2, FanInShape3, Inlet } +import akka.stream.scaladsl._ +import akka.stream.stage._ +import FlexiRoute.{ DemandFrom, DemandFromAny, RouteLogic } +import FlexiMerge.MergeLogic + +import akka.http.util._ +import akka.http.model.ws._ + +/** + * INTERNAL API + */ +private[http] object Websocket { + import FrameHandler._ + + def handleMessages[T](messageHandler: Flow[Message, Message, T], + serverSide: Boolean = true, + closeTimeout: FiniteDuration = 3.seconds): Flow[FrameEvent, FrameEvent, Unit] = { + /** Completes this branch of the flow if no more messages are expected and converts close codes into errors */ + class PrepareForUserHandler extends PushStage[MessagePart, MessagePart] { + def onPush(elem: MessagePart, ctx: Context[MessagePart]): SyncDirective = elem match { + case PeerClosed(code, reason) ⇒ + if (code.exists(Protocol.CloseCodes.isError)) ctx.fail(new ProtocolException(s"Peer closed connection with code $code")) + else ctx.finish() + case ActivelyCloseWithCode(code, reason) ⇒ + if (code.exists(Protocol.CloseCodes.isError)) ctx.fail(new ProtocolException(s"Closing connection with error code $code")) + else ctx.finish() + case x ⇒ ctx.push(x) + } + } + + /** Collects user-level API messages from MessageDataParts */ + val collectMessage: Flow[Source[MessageDataPart, Unit], Message, Unit] = + Flow[Source[MessageDataPart, Unit]] + .headAndTail + .map { + case (TextMessagePart(text, true), remaining) ⇒ + TextMessage.Strict(text) + case (first @ TextMessagePart(text, false), remaining) ⇒ + TextMessage.Streamed( + (Source.single(first) ++ remaining) + .collect { + case t: TextMessagePart if t.data.nonEmpty ⇒ t.data + }) + case (BinaryMessagePart(data, true), remaining) ⇒ + BinaryMessage.Strict(data) + case (first @ BinaryMessagePart(data, false), remaining) ⇒ + BinaryMessage.Streamed( + (Source.single(first) ++ remaining) + .collect { + case t: BinaryMessagePart if t.data.nonEmpty ⇒ t.data + }) + } + + /** Lifts onComplete and onError into events to be processed in the FlexiMerge */ + class LiftCompletions extends StatefulStage[FrameStart, AnyRef] { + def initial: StageState[FrameStart, AnyRef] = SteadyState + + object SteadyState extends State { + def onPush(elem: FrameStart, ctx: Context[AnyRef]): SyncDirective = ctx.push(elem) + } + class CompleteWith(last: AnyRef) extends State { + def onPush(elem: FrameStart, ctx: Context[AnyRef]): SyncDirective = + ctx.fail(new IllegalStateException("No push expected")) + + override def onPull(ctx: Context[AnyRef]): SyncDirective = ctx.pushAndFinish(last) + } + + override def onUpstreamFinish(ctx: Context[AnyRef]): TerminationDirective = { + become(new CompleteWith(UserHandlerCompleted)) + ctx.absorbTermination() + } + override def onUpstreamFailure(cause: Throwable, ctx: Context[AnyRef]): TerminationDirective = { + become(new CompleteWith(UserHandlerErredOut(cause))) + ctx.absorbTermination() + } + } + + lazy val userFlow = + Flow[MessagePart] + .transform(() ⇒ new PrepareForUserHandler) + .splitWhen(_.isMessageEnd) // FIXME using splitAfter from #16885 would simplify protocol a lot + .map(_.collect { + case m: MessageDataPart ⇒ m + }) + .via(collectMessage) + .via(messageHandler) + .via(MessageToFrameRenderer.create(serverSide)) + .transform(() ⇒ new LiftCompletions) + + /** + * Distributes output from the FrameHandler into bypass and userFlow. + */ + object BypassRouter + extends FlexiRoute[Either[BypassEvent, MessagePart], FanOutShape2[Either[BypassEvent, MessagePart], BypassEvent, MessagePart]](new FanOutShape2("bypassRouter"), OperationAttributes.name("bypassRouter")) { + def createRouteLogic(s: FanOutShape2[Either[BypassEvent, MessagePart], BypassEvent, MessagePart]): RouteLogic[Either[BypassEvent, MessagePart]] = + new RouteLogic[Either[BypassEvent, MessagePart]] { + def initialState: State[_] = State(DemandFromAny(s)) { (ctx, out, ev) ⇒ + ev match { + case Left(_) ⇒ + State(DemandFrom(s.out0)) { (ctx, _, ev) ⇒ // FIXME: #17004 + ctx.emit(s.out0)(ev.left.get) + initialState + } + case Right(_) ⇒ + State(DemandFrom(s.out1)) { (ctx, _, ev) ⇒ + ctx.emit(s.out1)(ev.right.get) + initialState + } + } + } + + override def initialCompletionHandling: CompletionHandling = super.initialCompletionHandling.copy( + onDownstreamFinish = { (ctx, out) ⇒ + if (out == s.out0) ctx.finish() + SameState + }) + } + } + /** + * Merges bypass, user flow and tick source for consumption in the FrameOutHandler. + */ + object BypassMerge extends FlexiMerge[AnyRef, FanInShape3[BypassEvent, AnyRef, Tick.type, AnyRef]](new FanInShape3("bypassMerge"), OperationAttributes.name("bypassMerge")) { + def createMergeLogic(s: FanInShape3[BypassEvent, AnyRef, Tick.type, AnyRef]): MergeLogic[AnyRef] = + new MergeLogic[AnyRef] { + def initialState: State[_] = Idle + + lazy val Idle = State[AnyRef](FlexiMerge.ReadAny(s.in0.asInstanceOf[Inlet[AnyRef]], s.in1.asInstanceOf[Inlet[AnyRef]], s.in2.asInstanceOf[Inlet[AnyRef]])) { (ctx, in, elem) ⇒ + ctx.emit(elem) + SameState + } + + override def initialCompletionHandling: CompletionHandling = + CompletionHandling( + onUpstreamFinish = { (ctx, in) ⇒ + if (in == s.in0) ctx.finish() + SameState + }, + onUpstreamFailure = { (ctx, in, cause) ⇒ + if (in == s.in0) ctx.fail(cause) + SameState + }) + } + } + + lazy val bypassAndUserHandler: Flow[Either[BypassEvent, MessagePart], AnyRef, Unit] = + Flow(BypassRouter, Source(closeTimeout, closeTimeout, Tick), BypassMerge)((_, _, _) ⇒ ()) { implicit b ⇒ + (split, tick, merge) ⇒ + import FlowGraph.Implicits._ + + split.out0 ~> merge.in0 + split.out1 ~> userFlow ~> merge.in1 + tick.outlet ~> merge.in2 + + (split.in, merge.out) + } + + Flow[FrameEvent] + .via(Masking.unmaskIf(serverSide)) + .via(FrameHandler.create(server = serverSide)) + .mapConcat(x ⇒ x :: x :: Nil) // FIXME: #17004 + .via(bypassAndUserHandler) + .transform(() ⇒ new FrameOutHandler(serverSide, closeTimeout)) + .via(Masking.maskIf(!serverSide, () ⇒ new SecureRandom())) + } + + object Tick + case object SwitchToWebsocketToken +} diff --git a/akka-http-core/src/test/scala/akka/http/engine/ws/MessageSpec.scala b/akka-http-core/src/test/scala/akka/http/engine/ws/MessageSpec.scala new file mode 100644 index 0000000000..bbe635db35 --- /dev/null +++ b/akka-http-core/src/test/scala/akka/http/engine/ws/MessageSpec.scala @@ -0,0 +1,966 @@ +/* + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.http.engine.ws + +import scala.annotation.tailrec +import scala.concurrent.duration._ + +import akka.stream.FlowShape +import akka.stream.scaladsl._ +import akka.stream.testkit.StreamTestKit +import akka.util.ByteString +import org.scalatest.{ Matchers, FreeSpec } + +import scala.util.Random + +import akka.http.model.ws._ +import Protocol.Opcode + +class MessageSpec extends FreeSpec with Matchers with WithMaterializerSpec { + "The Websocket implementation should" - { + "collect messages from frames" - { + "for binary messages" - { + "for an empty message" in new ClientTestSetup { + val input = frameHeader(Opcode.Binary, 0, fin = true) + + pushInput(input) + expectMessage(BinaryMessage.Strict(ByteString.empty)) + } + "for one complete, strict, single frame message" in new ClientTestSetup { + val data = ByteString("abcdef", "ASCII") + val input = frameHeader(Opcode.Binary, 6, fin = true) ++ data + + pushInput(input) + expectMessage(BinaryMessage.Strict(data)) + } + "for a partial frame" in new ClientTestSetup { + val data1 = ByteString("abc", "ASCII") + val header = frameHeader(Opcode.Binary, 6, fin = true) + + pushInput(header ++ data1) + val BinaryMessage.Streamed(dataSource) = expectMessage() + val sub = StreamTestKit.SubscriberProbe[ByteString] + dataSource.runWith(Sink(sub)) + val s = sub.expectSubscription() + s.request(2) + sub.expectNext(data1) + } + "for a frame split up into parts" in new ClientTestSetup { + val data1 = ByteString("abc", "ASCII") + val header = frameHeader(Opcode.Binary, 6, fin = true) + + pushInput(header) + val BinaryMessage.Streamed(dataSource) = expectMessage() + val sub = StreamTestKit.SubscriberProbe[ByteString] + dataSource.runWith(Sink(sub)) + val s = sub.expectSubscription() + s.request(2) + pushInput(data1) + sub.expectNext(data1) + + val data2 = ByteString("def", "ASCII") + pushInput(data2) + sub.expectNext(data2) + sub.expectComplete() + } + + "for a message split into several frames" in new ClientTestSetup { + val data1 = ByteString("abc", "ASCII") + val header1 = frameHeader(Opcode.Binary, 3, fin = false) + + pushInput(header1 ++ data1) + val BinaryMessage.Streamed(dataSource) = expectMessage() + val sub = StreamTestKit.SubscriberProbe[ByteString] + dataSource.runWith(Sink(sub)) + val s = sub.expectSubscription() + s.request(2) + sub.expectNext(data1) + + val header2 = frameHeader(Opcode.Continuation, 4, fin = true) + val data2 = ByteString("defg", "ASCII") + pushInput(header2 ++ data2) + sub.expectNext(data2) + sub.expectComplete() + } + "for several messages" in new ClientTestSetup { + val data1 = ByteString("abc", "ASCII") + val header1 = frameHeader(Opcode.Binary, 3, fin = false) + + pushInput(header1 ++ data1) + val BinaryMessage.Streamed(dataSource) = expectMessage() + val sub = StreamTestKit.SubscriberProbe[ByteString] + dataSource.runWith(Sink(sub)) + val s = sub.expectSubscription() + s.request(2) + sub.expectNext(data1) + + val header2 = frameHeader(Opcode.Continuation, 4, fin = true) + val header3 = frameHeader(Opcode.Binary, 2, fin = true) + val data2 = ByteString("defg", "ASCII") + val data3 = ByteString("h") + pushInput(header2 ++ data2 ++ header3 ++ data3) + sub.expectNext(data2) + sub.expectComplete() + + val BinaryMessage.Streamed(dataSource2) = expectMessage() + val sub2 = StreamTestKit.SubscriberProbe[ByteString] + dataSource2.runWith(Sink(sub2)) + val s2 = sub2.expectSubscription() + s2.request(2) + sub2.expectNext(data3) + + val data4 = ByteString("i") + pushInput(data4) + sub2.expectNext(data4) + sub2.expectComplete() + } + "unmask masked input on the server side" in new ServerTestSetup { + val mask = Random.nextInt() + val (data, _) = maskedASCII("abcdef", mask) + val data1 = data.take(3) + val data2 = data.drop(3) + val header = frameHeader(Opcode.Binary, 6, fin = true, mask = Some(mask)) + + pushInput(header ++ data1) + val BinaryMessage.Streamed(dataSource) = expectMessage() + val sub = StreamTestKit.SubscriberProbe[ByteString] + dataSource.runWith(Sink(sub)) + val s = sub.expectSubscription() + s.request(2) + sub.expectNext(ByteString("abc", "ASCII")) + + pushInput(data2) + sub.expectNext(ByteString("def", "ASCII")) + sub.expectComplete() + } + } + "for text messages" - { + "empty message" in new ClientTestSetup { + val input = frameHeader(Opcode.Text, 0, fin = true) + + pushInput(input) + expectMessage(TextMessage.Strict("")) + } + "decode complete, strict frame from utf8" in new ClientTestSetup { + val msg = "äbcdef€\uffff" + val data = ByteString(msg, "UTF-8") + val input = frameHeader(Opcode.Text, data.size, fin = true) ++ data + + pushInput(input) + expectMessage(TextMessage.Strict(msg)) + } + "decode utf8 as far as possible for partial frame" in new ClientTestSetup { + val msg = "bäcdef€" + val data = ByteString(msg, "UTF-8") + val data0 = data.slice(0, 2) + val data1 = data.slice(2, 5) + val data2 = data.slice(5, data.size) + val input = frameHeader(Opcode.Text, data.size, fin = true) ++ data0 + + pushInput(input) + val TextMessage.Streamed(parts) = expectMessage() + val sub = StreamTestKit.SubscriberProbe[String] + parts.runWith(Sink(sub)) + val s = sub.expectSubscription() + s.request(4) + sub.expectNext("b") + + pushInput(data1) + sub.expectNext("äcd") + } + "decode utf8 with code point split across frames" in new ClientTestSetup { + val msg = "äbcdef€" + val data = ByteString(msg, "UTF-8") + val data0 = data.slice(0, 1) + val data1 = data.slice(1, data.size) + val header0 = frameHeader(Opcode.Text, data0.size, fin = false) + + pushInput(header0 ++ data0) + val TextMessage.Streamed(parts) = expectMessage() + val sub = StreamTestKit.SubscriberProbe[String] + parts.runWith(Sink(sub)) + val s = sub.expectSubscription() + s.request(4) + sub.expectNoMsg(100.millis) + + val header1 = frameHeader(Opcode.Continuation, data1.size, fin = true) + pushInput(header1 ++ data1) + sub.expectNext("äbcdef€") + } + "unmask masked input on the server side" in new ServerTestSetup { + val mask = Random.nextInt() + val (data, _) = maskedUTF8("äbcdef€", mask) + val data1 = data.take(3) + val data2 = data.drop(3) + val header = frameHeader(Opcode.Binary, data.size, fin = true, mask = Some(mask)) + + pushInput(header ++ data1) + val BinaryMessage.Streamed(dataSource) = expectMessage() + val sub = StreamTestKit.SubscriberProbe[ByteString] + dataSource.runWith(Sink(sub)) + val s = sub.expectSubscription() + s.request(2) + sub.expectNext(ByteString("äb", "UTF-8")) + + pushInput(data2) + sub.expectNext(ByteString("cdef€", "UTF-8")) + sub.expectComplete() + } + } + } + "render frames from messages" - { + "for binary messages" - { + "for a short strict message" in new ServerTestSetup { + val data = ByteString("abcdef", "ASCII") + val msg = BinaryMessage.Strict(data) + netOutSub.request(5) + pushMessage(msg) + + expectFrameOnNetwork(Opcode.Binary, data, fin = true) + } + "for a strict message larger than configured maximum frame size" in pending + "for a streamed message" in new ServerTestSetup { + val data = ByteString("abcdefg", "ASCII") + val pub = StreamTestKit.PublisherProbe[ByteString] + val msg = BinaryMessage.Streamed(Source(pub)) + netOutSub.request(6) + pushMessage(msg) + val sub = pub.expectSubscription() + + expectFrameHeaderOnNetwork(Opcode.Binary, 0, fin = false) + + val data1 = data.take(3) + val data2 = data.drop(3) + + sub.sendNext(data1) + expectFrameOnNetwork(Opcode.Continuation, data1, fin = false) + + sub.sendNext(data2) + expectFrameOnNetwork(Opcode.Continuation, data2, fin = false) + + sub.sendComplete() + expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true) + } + "for a streamed message with a chunk being larger than configured maximum frame size" in pending + "and mask input on the client side" in new ClientTestSetup { + val data = ByteString("abcdefg", "ASCII") + val pub = StreamTestKit.PublisherProbe[ByteString] + val msg = BinaryMessage.Streamed(Source(pub)) + netOutSub.request(7) + pushMessage(msg) + val sub = pub.expectSubscription() + + expectFrameHeaderOnNetwork(Opcode.Binary, 0, fin = false) + + val data1 = data.take(3) + val data2 = data.drop(3) + + sub.sendNext(data1) + expectMaskedFrameOnNetwork(Opcode.Continuation, data1, fin = false) + + sub.sendNext(data2) + expectMaskedFrameOnNetwork(Opcode.Continuation, data2, fin = false) + + sub.sendComplete() + expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true) + } + } + "for text messages" - { + "for a short strict message" in new ServerTestSetup { + val text = "äbcdef" + val msg = TextMessage.Strict(text) + netOutSub.request(5) + pushMessage(msg) + + expectFrameOnNetwork(Opcode.Text, ByteString(text, "UTF-8"), fin = true) + } + "for a strict message larger than configured maximum frame size" in pending + "for a streamed message" in new ServerTestSetup { + val text = "äbcd€fg" + val pub = StreamTestKit.PublisherProbe[String] + val msg = TextMessage.Streamed(Source(pub)) + netOutSub.request(6) + pushMessage(msg) + val sub = pub.expectSubscription() + + expectFrameHeaderOnNetwork(Opcode.Text, 0, fin = false) + + val text1 = text.take(3) + val text1Bytes = ByteString(text1, "UTF-8") + val text2 = text.drop(3) + val text2Bytes = ByteString(text2, "UTF-8") + + sub.sendNext(text1) + expectFrameOnNetwork(Opcode.Continuation, text1Bytes, fin = false) + + sub.sendNext(text2) + expectFrameOnNetwork(Opcode.Continuation, text2Bytes, fin = false) + + sub.sendComplete() + expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true) + } + "for a streamed message don't convert half surrogate pairs naively" in new ServerTestSetup { + val gclef = "𝄞" + gclef.size shouldEqual 2 + + // split up the code point + val half1 = gclef.take(1) + val half2 = gclef.drop(1) + println(half1(0).toInt.toHexString) + println(half2(0).toInt.toHexString) + + val pub = StreamTestKit.PublisherProbe[String] + val msg = TextMessage.Streamed(Source(pub)) + netOutSub.request(6) + + pushMessage(msg) + val sub = pub.expectSubscription() + + expectFrameHeaderOnNetwork(Opcode.Text, 0, fin = false) + sub.sendNext(half1) + + expectNoNetworkData() + sub.sendNext(half2) + expectFrameOnNetwork(Opcode.Continuation, ByteString(gclef, "utf8"), fin = false) + } + "for a streamed message with a chunk being larger than configured maximum frame size" in pending + "and mask input on the client side" in new ClientTestSetup { + val text = "abcdefg" + val pub = StreamTestKit.PublisherProbe[String] + val msg = TextMessage.Streamed(Source(pub)) + netOutSub.request(5) + pushMessage(msg) + val sub = pub.expectSubscription() + + expectFrameOnNetwork(Opcode.Text, ByteString.empty, fin = false) + + val text1 = text.take(3) + val text1Bytes = ByteString(text1, "UTF-8") + val text2 = text.drop(3) + val text2Bytes = ByteString(text2, "UTF-8") + + sub.sendNext(text1) + expectMaskedFrameOnNetwork(Opcode.Continuation, text1Bytes, fin = false) + + sub.sendNext(text2) + expectMaskedFrameOnNetwork(Opcode.Continuation, text2Bytes, fin = false) + + sub.sendComplete() + expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true) + } + } + } + "supply automatic low-level websocket behavior" - { + "respond to ping frames unmasking them on the server side" in new ServerTestSetup { + val mask = Random.nextInt() + val input = frameHeader(Opcode.Ping, 6, fin = true, mask = Some(mask)) ++ maskedASCII("abcdef", mask)._1 + + pushInput(input) + netOutSub.request(5) + expectFrameOnNetwork(Opcode.Pong, ByteString("abcdef"), fin = true) + } + "respond to ping frames masking them on the client side" in new ClientTestSetup { + val input = frameHeader(Opcode.Ping, 6, fin = true) ++ ByteString("abcdef") + + pushInput(input) + netOutSub.request(5) + expectMaskedFrameOnNetwork(Opcode.Pong, ByteString("abcdef"), fin = true) + } + "respond to ping frames interleaved with data frames (without mixing frame data)" in new ServerTestSetup { + // receive multi-frame message + // receive and handle interleaved ping frame + // concurrently send out messages from handler + val mask1 = Random.nextInt() + val input1 = frameHeader(Opcode.Binary, 3, fin = false, mask = Some(mask1)) ++ maskedASCII("123", mask1)._1 + pushInput(input1) + + val BinaryMessage.Streamed(dataSource) = expectMessage() + val sub = StreamTestKit.SubscriberProbe[ByteString] + dataSource.runWith(Sink(sub)) + val s = sub.expectSubscription() + s.request(2) + sub.expectNext(ByteString("123", "ASCII")) + + val outPub = StreamTestKit.PublisherProbe[ByteString] + val msg = BinaryMessage.Streamed(Source(outPub)) + netOutSub.request(10) + pushMessage(msg) + + expectFrameHeaderOnNetwork(Opcode.Binary, 0, fin = false) + + val outSub = outPub.expectSubscription() + val outData1 = ByteString("abc", "ASCII") + outSub.sendNext(outData1) + expectFrameOnNetwork(Opcode.Continuation, outData1, fin = false) + + val pingMask = Random.nextInt() + val pingData = maskedASCII("pling", pingMask)._1 + val pingData0 = pingData.take(3) + val pingData1 = pingData.drop(3) + pushInput(frameHeader(Opcode.Ping, 5, fin = true, mask = Some(pingMask)) ++ pingData0) + expectNoNetworkData() + pushInput(pingData1) + expectFrameOnNetwork(Opcode.Pong, ByteString("pling", "ASCII"), fin = true) + + val outData2 = ByteString("def", "ASCII") + outSub.sendNext(outData2) + expectFrameOnNetwork(Opcode.Continuation, outData2, fin = false) + + outSub.sendComplete() + expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true) + + val mask2 = Random.nextInt() + val input2 = frameHeader(Opcode.Continuation, 3, fin = true, mask = Some(mask2)) ++ maskedASCII("456", mask2)._1 + pushInput(input2) + sub.expectNext(ByteString("456", "ASCII")) + sub.expectComplete() + } + "don't respond to unsolicited pong frames" in new ClientTestSetup { + val data = frameHeader(Opcode.Pong, 6, fin = true) ++ ByteString("abcdef") + pushInput(data) + netOutSub.request(5) + expectNoNetworkData() + } + } + "provide close behavior" - { + "after receiving regular close frame when idle (user closes immediately)" in new ServerTestSetup { + netInSub.expectRequest() + netOutSub.request(20) + messageOutSub.request(20) + + pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true)) + messageIn.expectComplete() + + netIn.expectNoMsg(1.second) // especially the cancellation not yet + expectNoNetworkData() + messageOutSub.sendComplete() + + expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular) + netOut.expectComplete() + netInSub.expectCancellation() + } + "after receiving close frame without close code" in new ServerTestSetup { + netInSub.expectRequest() + pushInput(frameHeader(Opcode.Close, 0, fin = true)) + messageIn.expectComplete() + + messageOutSub.sendComplete() + // especially mustn't be Procotol.CloseCodes.NoCodePresent + expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular) + netOut.expectComplete() + netInSub.expectCancellation() + } + "after receiving regular close frame when idle (user still sends some data)" in new ServerTestSetup { + netOutSub.request(20) + messageOutSub.request(20) + + pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true)) + messageIn.expectComplete() + + // sending another message is allowed before closing (inherently racy) + val pub = StreamTestKit.PublisherProbe[ByteString] + val msg = BinaryMessage.Streamed(Source(pub)) + pushMessage(msg) + expectFrameOnNetwork(Opcode.Binary, ByteString.empty, fin = false) + + val data = ByteString("abc", "ASCII") + val dataSub = pub.expectSubscription() + dataSub.sendNext(data) + expectFrameOnNetwork(Opcode.Continuation, data, fin = false) + + dataSub.sendComplete() + expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true) + + messageOutSub.sendComplete() + expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular) + netOut.expectComplete() + } + "after receiving regular close frame when fragmented message is still open" in pendingUntilFixed { + new ServerTestSetup { + netOutSub.request(10) + messageInSub.request(10) + + pushInput(frameHeader(Protocol.Opcode.Binary, 0, fin = false)) + val BinaryMessage.Streamed(dataSource) = messageIn.expectNext() + val inSubscriber = StreamTestKit.SubscriberProbe[ByteString] + dataSource.runWith(Sink(inSubscriber)) + val inSub = inSubscriber.expectSubscription() + + val outData = ByteString("def", "ASCII") + val mask = Random.nextInt() + pushInput(frameHeader(Protocol.Opcode.Continuation, 3, fin = false, mask = Some(mask)) ++ maskedBytes(outData, mask)._1) + inSub.request(5) + inSubscriber.expectNext(outData) + + pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true)) + messageIn.expectComplete() + inSubscriber.expectError() + // truncation of open message + + // sending another message is allowed before closing (inherently racy) + + val pub = StreamTestKit.PublisherProbe[ByteString] + val msg = BinaryMessage.Streamed(Source(pub)) + pushMessage(msg) + expectFrameOnNetwork(Opcode.Binary, ByteString.empty, fin = false) + + val data = ByteString("abc", "ASCII") + val dataSub = pub.expectSubscription() + dataSub.sendNext(data) + expectFrameOnNetwork(Opcode.Continuation, data, fin = false) + + dataSub.sendComplete() + expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true) + + messageOutSub.sendComplete() + expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular) + netOut.expectComplete() + } + } + "after receiving error close frame" in pending + "after peer closes connection without sending a close frame" in new ServerTestSetup { + netInSub.expectRequest() + netInSub.sendComplete() + + messageIn.expectComplete() + messageOutSub.sendComplete() + + expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular) + netOut.expectComplete() + } + "when user handler closes (simple)" in new ServerTestSetup { + messageOutSub.sendComplete() + expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular) + + netOut.expectNoMsg(1.second) // wait for peer to close regularly + pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true)) + + messageIn.expectComplete() + netOut.expectComplete() + netInSub.expectCancellation() + } + "when user handler closes main stream and substream only afterwards" in new ServerTestSetup { + netOutSub.request(10) + messageInSub.request(10) + + // send half a message + val pub = StreamTestKit.PublisherProbe[ByteString] + val msg = BinaryMessage.Streamed(Source(pub)) + pushMessage(msg) + expectFrameOnNetwork(Opcode.Binary, ByteString.empty, fin = false) + + val data = ByteString("abc", "ASCII") + val dataSub = pub.expectSubscription() + dataSub.sendNext(data) + expectFrameOnNetwork(Opcode.Continuation, data, fin = false) + + messageOutSub.sendComplete() + expectNoNetworkData() // need to wait for substream to close + + dataSub.sendComplete() + expectFrameOnNetwork(Opcode.Continuation, ByteString.empty, fin = true) + expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular) + netOut.expectNoMsg(1.second) // wait for peer to close regularly + + val mask = Random.nextInt() + pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true)) + + messageIn.expectComplete() + netOut.expectComplete() + netInSub.expectCancellation() + } + "if user handler fails" in pending + "if peer closes with invalid close frame" - { + "close code outside of the valid range" in new ServerTestSetup { + netInSub.expectRequest() + pushInput(frameHeader(Opcode.Close, 1, mask = Some(Random.nextInt()), fin = true) ++ ByteString("x")) + + val error = messageIn.expectError() + + expectCloseCodeOnNetwork(Protocol.CloseCodes.ProtocolError) + netOut.expectComplete() + netInSub.expectCancellation() + } + "close data of size 1" in new ServerTestSetup { + netInSub.expectRequest() + pushInput(frameHeader(Opcode.Close, 1, mask = Some(Random.nextInt()), fin = true) ++ ByteString("x")) + + val error = messageIn.expectError() + + expectCloseCodeOnNetwork(Protocol.CloseCodes.ProtocolError) + netOut.expectComplete() + netInSub.expectCancellation() + } + "reason is no valid utf8 data" in pending + } + "timeout if user handler closes and peer doesn't send a close frame" in new ServerTestSetup { + netInSub.expectRequest() + messageOutSub.sendComplete() + expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular) + + netOut.expectComplete() + netInSub.expectCancellation() + } + "timeout after we close after error and peer doesn't send a close frame" in new ServerTestSetup { + netInSub.expectRequest() + + pushInput(frameHeader(Opcode.Binary, 0, fin = true, rsv1 = true)) + expectProtocolErrorOnNetwork() + messageOutSub.sendComplete() + + netOut.expectComplete() + netInSub.expectCancellation() + } + "ignore frames peer sends after close frame" in new ServerTestSetup { + netInSub.expectRequest() + pushInput(closeFrame(Protocol.CloseCodes.Regular, mask = true)) + + messageIn.expectComplete() + + pushInput(frameHeader(Opcode.Binary, 0, fin = true)) + messageOutSub.sendComplete() + expectCloseCodeOnNetwork(Protocol.CloseCodes.Regular) + + netOut.expectComplete() + netInSub.expectCancellation() + } + } + "reject unexpected frames" - { + "reserved bits set" - { + "rsv1" in new ServerTestSetup { + pushInput(frameHeader(Opcode.Binary, 0, fin = true, rsv1 = true)) + expectProtocolErrorOnNetwork() + } + "rsv2" in new ServerTestSetup { + pushInput(frameHeader(Opcode.Binary, 0, fin = true, rsv2 = true)) + expectProtocolErrorOnNetwork() + } + "rsv3" in new ServerTestSetup { + pushInput(frameHeader(Opcode.Binary, 0, fin = true, rsv3 = true)) + expectProtocolErrorOnNetwork() + } + } + "highest bit of 64-bit length is set" in new ServerTestSetup { + import BitBuilder._ + + val header = + b"""0000 # flags + xxxx=1 # opcode + 1 # mask? + xxxxxxx=7f # length + xxxxxxxx + xxxxxxxx + xxxxxxxx + xxxxxxxx=ffffffff + xxxxxxxx + xxxxxxxx + xxxxxxxx + xxxxxxxx=ffffffff # length64 + 00000000 + 00000000 + 00000000 + 00000000 # empty mask + """ + + pushInput(header) + expectProtocolErrorOnNetwork() + } + "control frame bigger than 125 bytes" in new ServerTestSetup { + pushInput(frameHeader(Opcode.Ping, 126, fin = true, mask = Some(0))) + expectProtocolErrorOnNetwork() + } + "fragmented control frame" in new ServerTestSetup { + pushInput(frameHeader(Opcode.Ping, 0, fin = false, mask = Some(0))) + expectProtocolErrorOnNetwork() + } + "unexpected continuation frame" in new ServerTestSetup { + pushInput(frameHeader(Opcode.Continuation, 0, fin = false, mask = Some(0))) + expectProtocolErrorOnNetwork() + } + "unexpected data frame when waiting for continuation" in new ServerTestSetup { + pushInput(frameHeader(Opcode.Binary, 0, fin = false) ++ + frameHeader(Opcode.Binary, 0, fin = false)) + expectProtocolErrorOnNetwork() + } + "invalid utf8 encoding for single frame message" in new ClientTestSetup { + val data = ByteString( + (128 + 64).toByte, // start two byte sequence + 0 // but don't finish it + ) + + pushInput(frameHeader(Opcode.Text, 2, fin = true) ++ data) + expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData) + } + "invalid utf8 encoding for streamed frame" in new ClientTestSetup { + val data = ByteString( + (128 + 64).toByte, // start two byte sequence + 0 // but don't finish it + ) + + pushInput(frameHeader(Opcode.Text, 0, fin = false) ++ + frameHeader(Opcode.Continuation, 2, fin = true) ++ + data) + expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData) + } + "truncated utf8 encoding for single frame message" in new ClientTestSetup { + val data = ByteString("€", "UTF-8").take(1) // half a euro + pushInput(frameHeader(Opcode.Text, 1, fin = true) ++ data) + expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData) + } + "truncated utf8 encoding for streamed frame" in new ClientTestSetup { + val data = ByteString("€", "UTF-8").take(1) // half a euro + pushInput(frameHeader(Opcode.Text, 0, fin = false) ++ + frameHeader(Opcode.Continuation, 1, fin = true) ++ + data) + expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData) + } + "half a surrogate pair in utf8 encoding for a strict frame" in new ClientTestSetup { + val data = ByteString(0xed, 0xa0, 0x80) // not strictly supported by utf-8 + pushInput(frameHeader(Opcode.Text, 3, fin = true) ++ data) + expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData) + } + "half a surrogate pair in utf8 encoding for a streamed frame" in new ClientTestSetup { + val data = ByteString(0xed, 0xa0, 0x80) // not strictly supported by utf-8 + pushInput(frameHeader(Opcode.Text, 0, fin = false)) + pushInput(frameHeader(Opcode.Continuation, 3, fin = true) ++ data) + + messageIn.expectError() + + expectCloseCodeOnNetwork(Protocol.CloseCodes.InconsistentData) + } + "unmasked input on the server side" in new ServerTestSetup { + val data = ByteString("abcdef", "ASCII") + val input = frameHeader(Opcode.Binary, 6, fin = true) ++ data + + pushInput(input) + expectProtocolErrorOnNetwork() + } + "masked input on the client side" in new ClientTestSetup { + val mask = Random.nextInt() + val input = frameHeader(Opcode.Binary, 6, fin = true, mask = Some(mask)) ++ maskedASCII("abcdef", mask)._1 + + pushInput(input) + expectProtocolErrorOnNetwork() + } + } + "support per-message-compression extension" in pending + } + + class ServerTestSetup extends TestSetup { + protected def serverSide: Boolean = true + } + class ClientTestSetup extends TestSetup { + protected def serverSide: Boolean = false + } + abstract class TestSetup { + protected def serverSide: Boolean + protected def closeTimeout: FiniteDuration = 1.second + + val netIn = StreamTestKit.PublisherProbe[ByteString] + val netOut = StreamTestKit.SubscriberProbe[ByteString] + + val messageIn = StreamTestKit.SubscriberProbe[Message] + val messageOut = StreamTestKit.PublisherProbe[Message] + + val messageHandler: Flow[Message, Message, Unit] = + Flow.wrap { + FlowGraph.partial() { implicit b ⇒ + val in = b.add(Sink(messageIn)) + val out = b.add(Source(messageOut)) + + FlowShape[Message, Message](in, out) + } + } + + Source(netIn) + .via(printEvent("netIn")) + .transform(() ⇒ new FrameEventParser) + .via(Websocket.handleMessages(messageHandler, serverSide, closeTimeout = closeTimeout)) + .via(printEvent("frameRendererIn")) + .transform(() ⇒ new FrameEventRenderer) + .via(printEvent("frameRendererOut")) + .to(Sink(netOut)) + .run() + + val netInSub = netIn.expectSubscription() + val netOutSub = netOut.expectSubscription() + val messageOutSub = messageOut.expectSubscription() + val messageInSub = messageIn.expectSubscription() + + def pushInput(data: ByteString): Unit = { + // TODO: expect/handle request? + netInSub.sendNext(data) + } + def pushMessage(msg: Message): Unit = { + messageOutSub.sendNext(msg) + } + + def expectMessage(message: Message): Unit = { + messageInSub.request(1) + messageIn.expectNext(message) + } + def expectMessage(): Message = { + messageInSub.request(1) + messageIn.expectNext() + } + + var inBuffer = ByteString.empty + @tailrec final def expectNetworkData(bytes: Int): ByteString = + if (inBuffer.size >= bytes) { + val res = inBuffer.take(bytes) + inBuffer = inBuffer.drop(bytes) + res + } else { + netOutSub.request(1) + inBuffer ++= netOut.expectNext() + expectNetworkData(bytes) + } + + def expectNetworkData(data: ByteString): Unit = + expectNetworkData(data.size) shouldEqual data + + def expectFrameOnNetwork(opcode: Opcode, data: ByteString, fin: Boolean): Unit = { + expectFrameHeaderOnNetwork(opcode, data.size, fin) + expectNetworkData(data) + } + def expectMaskedFrameOnNetwork(opcode: Opcode, data: ByteString, fin: Boolean): Unit = { + val Some(mask) = expectFrameHeaderOnNetwork(opcode, data.size, fin) + val masked = maskedBytes(data, mask)._1 + expectNetworkData(masked) + } + + /** Returns the mask if any is available */ + def expectFrameHeaderOnNetwork(opcode: Opcode, length: Long, fin: Boolean): Option[Int] = { + val (op, l, f, m) = expectFrameHeaderOnNetwork() + op shouldEqual opcode + l shouldEqual length + f shouldEqual fin + m + } + def expectFrameHeaderOnNetwork(): (Opcode, Long, Boolean, Option[Int]) = { + val header = expectNetworkData(2) + + val fin = (header(0) & Protocol.FIN_MASK) != 0 + val op = header(0) & Protocol.OP_MASK + + val hasMask = (header(1) & Protocol.MASK_MASK) != 0 + val length7 = header(1) & Protocol.LENGTH_MASK + val length = length7 match { + case 126 ⇒ + val length16Bytes = expectNetworkData(2) + (length16Bytes(0) & 0xff) << 8 | (length16Bytes(1) & 0xff) << 0 + case 127 ⇒ + val length64Bytes = expectNetworkData(8) + (length64Bytes(0) & 0xff).toLong << 56 | + (length64Bytes(1) & 0xff).toLong << 48 | + (length64Bytes(2) & 0xff).toLong << 40 | + (length64Bytes(3) & 0xff).toLong << 32 | + (length64Bytes(4) & 0xff).toLong << 24 | + (length64Bytes(5) & 0xff).toLong << 16 | + (length64Bytes(6) & 0xff).toLong << 8 | + (length64Bytes(7) & 0xff).toLong << 0 + case x ⇒ x + } + val mask = + if (hasMask) { + val maskBytes = expectNetworkData(4) + val mask = + (maskBytes(0) & 0xff) << 24 | + (maskBytes(1) & 0xff) << 16 | + (maskBytes(2) & 0xff) << 8 | + (maskBytes(3) & 0xff) << 0 + Some(mask) + } else None + + (Opcode.forCode(op.toByte), length, fin, mask) + } + + def expectProtocolErrorOnNetwork(): Unit = expectCloseCodeOnNetwork(Protocol.CloseCodes.ProtocolError) + def expectCloseCodeOnNetwork(expectedCode: Int): Unit = { + val (opcode, length, true, mask) = expectFrameHeaderOnNetwork() + opcode shouldEqual Opcode.Close + length should be >= 2.toLong + + val rawData = expectNetworkData(length.toInt) + val data = mask match { + case Some(m) ⇒ FrameEventParser.mask(rawData, m)._1 + case None ⇒ rawData + } + + val code = ((data(0) & 0xff) << 8) | ((data(1) & 0xff) << 0) + code shouldEqual expectedCode + } + + def expectNoNetworkData(): Unit = + netOut.expectNoMsg(100.millis) + } + + def frameHeader( + opcode: Opcode, + length: Long, + fin: Boolean, + mask: Option[Int] = None, + rsv1: Boolean = false, + rsv2: Boolean = false, + rsv3: Boolean = false): ByteString = { + def set(should: Boolean, mask: Int): Int = + if (should) mask else 0 + + val flags = + set(fin, Protocol.FIN_MASK) | + set(rsv1, Protocol.RSV1_MASK) | + set(rsv2, Protocol.RSV2_MASK) | + set(rsv3, Protocol.RSV3_MASK) + + val opcodeByte = opcode.code | flags + + require(length >= 0) + val (lengthByteComponent, lengthBytes) = + if (length < 126) (length.toByte, ByteString.empty) + else if (length < 65536) (126.toByte, shortBE(length.toInt)) + else throw new IllegalArgumentException("Only lengths < 65536 allowed in test") + + val maskMask = if (mask.isDefined) Protocol.MASK_MASK else 0 + val maskBytes = mask match { + case Some(mask) ⇒ intBE(mask) + case None ⇒ ByteString.empty + } + val lengthByte = lengthByteComponent | maskMask + ByteString(opcodeByte.toByte, lengthByte.toByte) ++ lengthBytes ++ maskBytes + } + def closeFrame(closeCode: Int, mask: Boolean): ByteString = + if (mask) { + val mask = Random.nextInt() + frameHeader(Opcode.Close, 2, fin = true, mask = Some(mask)) ++ + maskedBytes(shortBE(closeCode), mask)._1 + } else + frameHeader(Opcode.Close, 2, fin = true) ++ + shortBE(closeCode) + + def maskedASCII(str: String, mask: Int): (ByteString, Int) = + FrameEventParser.mask(ByteString(str, "ASCII"), mask) + def maskedUTF8(str: String, mask: Int): (ByteString, Int) = + FrameEventParser.mask(ByteString(str, "UTF-8"), mask) + def maskedBytes(bytes: ByteString, mask: Int): (ByteString, Int) = + FrameEventParser.mask(bytes, mask) + + def shortBE(value: Int): ByteString = { + require(value >= 0 && value < 65536, s"Value wasn't in short range: $value") + ByteString( + ((value >> 8) & 0xff).toByte, + ((value >> 0) & 0xff).toByte) + } + def intBE(value: Int): ByteString = + ByteString( + ((value >> 24) & 0xff).toByte, + ((value >> 16) & 0xff).toByte, + ((value >> 8) & 0xff).toByte, + ((value >> 0) & 0xff).toByte) + + val trace = false // set to `true` for debugging purposes + def printEvent[T](marker: String): Flow[T, T, Unit] = + if (trace) akka.http.util.printEvent(marker) + else Flow[T] +}