parent
1756541dcb
commit
6d2d4d5d25
1 changed files with 149 additions and 142 deletions
|
|
@ -6,9 +6,11 @@ package akka.http.impl.engine.ws
|
||||||
|
|
||||||
import akka.NotUsed
|
import akka.NotUsed
|
||||||
import akka.stream.scaladsl.Flow
|
import akka.stream.scaladsl.Flow
|
||||||
import akka.stream.stage.{ SyncDirective, Context, StatefulStage }
|
|
||||||
import akka.util.ByteString
|
import akka.util.ByteString
|
||||||
import Protocol.Opcode
|
import Protocol.Opcode
|
||||||
|
import akka.event.Logging
|
||||||
|
import akka.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler }
|
||||||
|
import akka.stream.{ Attributes, FlowShape, Inlet, Outlet }
|
||||||
|
|
||||||
import scala.util.control.NonFatal
|
import scala.util.control.NonFatal
|
||||||
|
|
||||||
|
|
@ -21,156 +23,161 @@ import scala.util.control.NonFatal
|
||||||
private[http] object FrameHandler {
|
private[http] object FrameHandler {
|
||||||
|
|
||||||
def create(server: Boolean): Flow[FrameEventOrError, Output, NotUsed] =
|
def create(server: Boolean): Flow[FrameEventOrError, Output, NotUsed] =
|
||||||
Flow[FrameEventOrError].transform(() ⇒ new HandlerStage(server))
|
Flow[FrameEventOrError].via(new HandlerStage(server))
|
||||||
|
|
||||||
private class HandlerStage(server: Boolean) extends StatefulStage[FrameEventOrError, Output] {
|
private class HandlerStage(server: Boolean) extends GraphStage[FlowShape[FrameEventOrError, Output]] {
|
||||||
type Ctx = Context[Output]
|
val in = Inlet[FrameEventOrError](Logging.simpleName(this) + ".in")
|
||||||
def initial: State = Idle
|
val out = Outlet[Output](Logging.simpleName(this) + ".out")
|
||||||
|
override val shape = FlowShape(in, out)
|
||||||
|
|
||||||
override def toString: String = s"HandlerStage(server=$server)"
|
override def toString: String = s"HandlerStage(server=$server)"
|
||||||
|
|
||||||
private object Idle extends StateWithControlFrameHandling {
|
override def createLogic(attributes: Attributes): GraphStageLogic =
|
||||||
def handleRegularFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective =
|
new GraphStageLogic(shape) with OutHandler {
|
||||||
|
setHandler(out, this)
|
||||||
|
setHandler(in, IdleHandler)
|
||||||
|
|
||||||
|
override def onPull(): Unit = pull(in)
|
||||||
|
|
||||||
|
private object IdleHandler extends ControlFrameStartHandler {
|
||||||
|
def setAndHandleFrameStartWith(newHandler: ControlFrameStartHandler, start: FrameStart): Unit = {
|
||||||
|
setHandler(in, newHandler)
|
||||||
|
newHandler.handleFrameStart(start)
|
||||||
|
}
|
||||||
|
|
||||||
|
override def handleRegularFrameStart(start: FrameStart): Unit =
|
||||||
(start.header.opcode, start.isFullMessage) match {
|
(start.header.opcode, start.isFullMessage) match {
|
||||||
case (Opcode.Binary, true) ⇒ publishMessagePart(BinaryMessagePart(start.data, last = true))
|
case (Opcode.Binary, true) ⇒ publishMessagePart(BinaryMessagePart(start.data, last = true))
|
||||||
case (Opcode.Binary, false) ⇒ becomeAndHandleWith(new CollectingBinaryMessage, start)
|
case (Opcode.Binary, false) ⇒ setAndHandleFrameStartWith(new BinaryMessagehandler, start)
|
||||||
case (Opcode.Text, _) ⇒ becomeAndHandleWith(new CollectingTextMessage, start)
|
case (Opcode.Text, _) ⇒ setAndHandleFrameStartWith(new TextMessageHandler, start)
|
||||||
case x ⇒ protocolError()
|
case x ⇒ pushProtocolError()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private class CollectingBinaryMessage extends CollectingMessageFrame(Opcode.Binary) {
|
private class BinaryMessagehandler extends MessageHandler(Opcode.Binary) {
|
||||||
def createMessagePart(data: ByteString, last: Boolean): MessageDataPart = BinaryMessagePart(data, last)
|
override def createMessagePart(data: ByteString, last: Boolean): MessageDataPart =
|
||||||
|
BinaryMessagePart(data, last)
|
||||||
}
|
}
|
||||||
private class CollectingTextMessage extends CollectingMessageFrame(Opcode.Text) {
|
|
||||||
|
private class TextMessageHandler extends MessageHandler(Opcode.Text) {
|
||||||
val decoder = Utf8Decoder.create()
|
val decoder = Utf8Decoder.create()
|
||||||
|
|
||||||
def createMessagePart(data: ByteString, last: Boolean): MessageDataPart =
|
override def createMessagePart(data: ByteString, last: Boolean): MessageDataPart =
|
||||||
TextMessagePart(decoder.decode(data, endOfInput = last).get, last)
|
TextMessagePart(decoder.decode(data, endOfInput = last).get, last)
|
||||||
}
|
}
|
||||||
|
|
||||||
private abstract class CollectingMessageFrame(expectedOpcode: Opcode) extends StateWithControlFrameHandling {
|
private abstract class MessageHandler(expectedOpcode: Opcode) extends ControlFrameStartHandler {
|
||||||
var expectFirstHeader = true
|
var expectFirstHeader = true
|
||||||
var finSeen = false
|
var finSeen = false
|
||||||
def createMessagePart(data: ByteString, last: Boolean): MessageDataPart
|
def createMessagePart(data: ByteString, last: Boolean): MessageDataPart
|
||||||
|
|
||||||
def handleRegularFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective = {
|
override def handleRegularFrameStart(start: FrameStart): Unit = {
|
||||||
if ((expectFirstHeader && start.header.opcode == expectedOpcode) // first opcode must be the expected
|
if ((expectFirstHeader && start.header.opcode == expectedOpcode) // first opcode must be the expected
|
||||||
|| start.header.opcode == Opcode.Continuation) { // further ones continuations
|
|| start.header.opcode == Opcode.Continuation) { // further ones continuations
|
||||||
expectFirstHeader = false
|
expectFirstHeader = false
|
||||||
|
|
||||||
if (start.header.fin) finSeen = true
|
if (start.header.fin) finSeen = true
|
||||||
publish(start)
|
publish(start)
|
||||||
} else protocolError()
|
} else pushProtocolError()
|
||||||
}
|
}
|
||||||
override def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective = publish(data)
|
|
||||||
|
|
||||||
private def publish(part: FrameEvent)(implicit ctx: Ctx): SyncDirective =
|
override def handleFrameData(data: FrameData): Unit = publish(data)
|
||||||
try publishMessagePart(createMessagePart(part.data, last = finSeen && part.lastPart))
|
|
||||||
catch {
|
def publish(part: FrameEvent): Unit = try {
|
||||||
|
publishMessagePart(createMessagePart(part.data, last = finSeen && part.lastPart))
|
||||||
|
} catch {
|
||||||
case NonFatal(e) ⇒ closeWithCode(Protocol.CloseCodes.InconsistentData)
|
case NonFatal(e) ⇒ closeWithCode(Protocol.CloseCodes.InconsistentData)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private class CollectingControlFrame(opcode: Opcode, _data: ByteString, nextState: State) extends InFrameState {
|
private trait ControlFrameStartHandler extends FrameHandler {
|
||||||
|
def handleRegularFrameStart(start: FrameStart): Unit
|
||||||
|
|
||||||
|
override def handleFrameStart(start: FrameStart): Unit = start.header match {
|
||||||
|
case h: FrameHeader if h.mask.isDefined && !server ⇒ pushProtocolError()
|
||||||
|
case h: FrameHeader if h.rsv1 || h.rsv2 || h.rsv3 ⇒ pushProtocolError()
|
||||||
|
case FrameHeader(op, _, length, fin, _, _, _) if op.isControl && (length > 125 || !fin) ⇒ pushProtocolError()
|
||||||
|
case h: FrameHeader if h.opcode.isControl ⇒
|
||||||
|
if (start.isFullMessage) handleControlFrame(h.opcode, start.data, this)
|
||||||
|
else collectControlFrame(start, this)
|
||||||
|
case _ ⇒ handleRegularFrameStart(start)
|
||||||
|
}
|
||||||
|
|
||||||
|
override def handleFrameData(data: FrameData): Unit =
|
||||||
|
throw new IllegalStateException("Expected FrameStart")
|
||||||
|
}
|
||||||
|
|
||||||
|
private class ControlFrameDataHandler(opcode: Opcode, _data: ByteString, nextHandler: InHandler) extends FrameHandler {
|
||||||
var data = _data
|
var data = _data
|
||||||
|
|
||||||
def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective = {
|
override def handleFrameData(data: FrameData): Unit = {
|
||||||
this.data ++= data.data
|
this.data ++= data.data
|
||||||
if (data.lastPart) handleControlFrame(opcode, this.data, nextState)
|
if (data.lastPart) handleControlFrame(opcode, this.data, nextHandler)
|
||||||
else ctx.pull()
|
else pull(in)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private def becomeAndHandleWith(newState: State, part: FrameEvent)(implicit ctx: Ctx): SyncDirective = {
|
override def handleFrameStart(start: FrameStart): Unit =
|
||||||
become(newState)
|
throw new IllegalStateException("Expected FrameData")
|
||||||
current.onPush(part, ctx)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Returns a SyncDirective if it handled the message */
|
private trait FrameHandler extends InHandler {
|
||||||
private def validateHeader(header: FrameHeader)(implicit ctx: Ctx): Option[SyncDirective] = header match {
|
def handleFrameData(data: FrameData): Unit
|
||||||
case h: FrameHeader if h.mask.isDefined && !server ⇒ Some(protocolError())
|
def handleFrameStart(start: FrameStart): Unit
|
||||||
case h: FrameHeader if h.rsv1 || h.rsv2 || h.rsv3 ⇒ Some(protocolError())
|
|
||||||
case FrameHeader(op, _, length, fin, _, _, _) if op.isControl && (length > 125 || !fin) ⇒ Some(protocolError())
|
|
||||||
case _ ⇒ None
|
|
||||||
}
|
|
||||||
|
|
||||||
private def handleControlFrame(opcode: Opcode, data: ByteString, nextState: State)(implicit ctx: Ctx): SyncDirective = {
|
def handleControlFrame(opcode: Opcode, data: ByteString, nextHandler: InHandler): Unit = {
|
||||||
become(nextState)
|
setHandler(in, nextHandler)
|
||||||
opcode match {
|
opcode match {
|
||||||
case Opcode.Ping ⇒ publishDirectResponse(FrameEvent.fullFrame(Opcode.Pong, None, data, fin = true))
|
case Opcode.Ping ⇒ publishDirectResponse(FrameEvent.fullFrame(Opcode.Pong, None, data, fin = true))
|
||||||
case Opcode.Pong ⇒
|
case Opcode.Pong ⇒
|
||||||
// ignore unsolicited Pong frame
|
// ignore unsolicited Pong frame
|
||||||
ctx.pull()
|
pull(in)
|
||||||
case Opcode.Close ⇒
|
case Opcode.Close ⇒
|
||||||
become(WaitForPeerTcpClose)
|
setHandler(in, WaitForPeerTcpClose)
|
||||||
ctx.push(PeerClosed.parse(data))
|
push(out, PeerClosed.parse(data))
|
||||||
case Opcode.Other(o) ⇒ closeWithCode(Protocol.CloseCodes.ProtocolError, "Unsupported opcode")
|
case Opcode.Other(o) ⇒ closeWithCode(Protocol.CloseCodes.ProtocolError, "Unsupported opcode")
|
||||||
case other ⇒ ctx.fail(new IllegalStateException(s"unexpected message of type [${other.getClass.getName}] when expecting ControlFrame"))
|
case other ⇒ failStage(
|
||||||
|
new IllegalStateException(s"unexpected message of type [${other.getClass.getName}] when expecting ControlFrame")
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
private def collectControlFrame(start: FrameStart, nextState: State)(implicit ctx: Ctx): SyncDirective = {
|
|
||||||
|
def pushProtocolError(): Unit = closeWithCode(Protocol.CloseCodes.ProtocolError)
|
||||||
|
|
||||||
|
def closeWithCode(closeCode: Int, reason: String = ""): Unit = {
|
||||||
|
setHandler(in, CloseAfterPeerClosed)
|
||||||
|
push(out, ActivelyCloseWithCode(Some(closeCode), reason))
|
||||||
|
}
|
||||||
|
|
||||||
|
def collectControlFrame(start: FrameStart, nextHandler: InHandler): Unit = {
|
||||||
require(!start.isFullMessage)
|
require(!start.isFullMessage)
|
||||||
become(new CollectingControlFrame(start.header.opcode, start.data, nextState))
|
setHandler(in, new ControlFrameDataHandler(start.header.opcode, start.data, nextHandler))
|
||||||
ctx.pull()
|
pull(in)
|
||||||
}
|
}
|
||||||
|
|
||||||
private def publishMessagePart(part: MessageDataPart)(implicit ctx: Ctx): SyncDirective =
|
def publishMessagePart(part: MessageDataPart): Unit =
|
||||||
if (part.last) emit(Iterator(part, MessageEnd), ctx, Idle)
|
if (part.last) emitMultiple(out, Iterator(part, MessageEnd), () ⇒ setHandler(in, IdleHandler))
|
||||||
else ctx.push(part)
|
else push(out, part)
|
||||||
private def publishDirectResponse(frame: FrameStart)(implicit ctx: Ctx): SyncDirective =
|
|
||||||
ctx.push(DirectAnswer(frame))
|
|
||||||
|
|
||||||
private def protocolError(reason: String = "")(implicit ctx: Ctx): SyncDirective =
|
def publishDirectResponse(frame: FrameStart): Unit = push(out, DirectAnswer(frame))
|
||||||
closeWithCode(Protocol.CloseCodes.ProtocolError, reason)
|
|
||||||
|
|
||||||
private def closeWithCode(closeCode: Int, reason: String = "", cause: Throwable = null)(implicit ctx: Ctx): SyncDirective = {
|
override def onPush(): Unit = grab(in) match {
|
||||||
become(CloseAfterPeerClosed)
|
case data: FrameData ⇒ handleFrameData(data)
|
||||||
ctx.push(ActivelyCloseWithCode(Some(closeCode), reason))
|
case start: FrameStart ⇒ handleFrameStart(start)
|
||||||
|
case FrameError(ex) ⇒ failStage(ex)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private object CloseAfterPeerClosed extends State {
|
private object CloseAfterPeerClosed extends InHandler {
|
||||||
def onPush(elem: FrameEventOrError, ctx: Context[Output]): SyncDirective =
|
override def onPush(): Unit = grab(in) match {
|
||||||
elem match {
|
|
||||||
case FrameStart(FrameHeader(Opcode.Close, _, length, _, _, _, _), data) ⇒
|
case FrameStart(FrameHeader(Opcode.Close, _, length, _, _, _, _), data) ⇒
|
||||||
become(WaitForPeerTcpClose)
|
setHandler(in, WaitForPeerTcpClose)
|
||||||
ctx.push(PeerClosed.parse(data))
|
push(out, PeerClosed.parse(data))
|
||||||
case _ ⇒ ctx.pull() // ignore all other data
|
case _ ⇒ pull(in) // ignore all other data
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
private object WaitForPeerTcpClose extends State {
|
|
||||||
def onPush(elem: FrameEventOrError, ctx: Context[Output]): SyncDirective =
|
|
||||||
ctx.pull() // ignore
|
|
||||||
}
|
|
||||||
|
|
||||||
private abstract class StateWithControlFrameHandling extends BetweenFrameState {
|
private object WaitForPeerTcpClose extends InHandler {
|
||||||
def handleRegularFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective
|
override def onPush(): Unit = pull(in) // ignore
|
||||||
|
|
||||||
def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective =
|
|
||||||
validateHeader(start.header).getOrElse {
|
|
||||||
if (start.header.opcode.isControl)
|
|
||||||
if (start.isFullMessage) handleControlFrame(start.header.opcode, start.data, this)
|
|
||||||
else collectControlFrame(start, this)
|
|
||||||
else handleRegularFrameStart(start)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
private abstract class BetweenFrameState extends ImplicitContextState {
|
|
||||||
def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective =
|
|
||||||
throw new IllegalStateException("Expected FrameStart")
|
|
||||||
}
|
|
||||||
private abstract class InFrameState extends ImplicitContextState {
|
|
||||||
def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective =
|
|
||||||
throw new IllegalStateException("Expected FrameData")
|
|
||||||
}
|
|
||||||
private abstract class ImplicitContextState extends State {
|
|
||||||
def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective
|
|
||||||
def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective
|
|
||||||
|
|
||||||
def onPush(part: FrameEventOrError, ctx: Ctx): SyncDirective =
|
|
||||||
part match {
|
|
||||||
case data: FrameData ⇒ handleFrameData(data)(ctx)
|
|
||||||
case start: FrameStart ⇒ handleFrameStart(start)(ctx)
|
|
||||||
case FrameError(ex) ⇒ ctx.fail(ex)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue