parent
1756541dcb
commit
6d2d4d5d25
1 changed files with 149 additions and 142 deletions
|
|
@ -6,9 +6,11 @@ package akka.http.impl.engine.ws
|
|||
|
||||
import akka.NotUsed
|
||||
import akka.stream.scaladsl.Flow
|
||||
import akka.stream.stage.{ SyncDirective, Context, StatefulStage }
|
||||
import akka.util.ByteString
|
||||
import Protocol.Opcode
|
||||
import akka.event.Logging
|
||||
import akka.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler }
|
||||
import akka.stream.{ Attributes, FlowShape, Inlet, Outlet }
|
||||
|
||||
import scala.util.control.NonFatal
|
||||
|
||||
|
|
@ -21,156 +23,161 @@ import scala.util.control.NonFatal
|
|||
private[http] object FrameHandler {
|
||||
|
||||
def create(server: Boolean): Flow[FrameEventOrError, Output, NotUsed] =
|
||||
Flow[FrameEventOrError].transform(() ⇒ new HandlerStage(server))
|
||||
Flow[FrameEventOrError].via(new HandlerStage(server))
|
||||
|
||||
private class HandlerStage(server: Boolean) extends StatefulStage[FrameEventOrError, Output] {
|
||||
type Ctx = Context[Output]
|
||||
def initial: State = Idle
|
||||
private class HandlerStage(server: Boolean) extends GraphStage[FlowShape[FrameEventOrError, Output]] {
|
||||
val in = Inlet[FrameEventOrError](Logging.simpleName(this) + ".in")
|
||||
val out = Outlet[Output](Logging.simpleName(this) + ".out")
|
||||
override val shape = FlowShape(in, out)
|
||||
|
||||
override def toString: String = s"HandlerStage(server=$server)"
|
||||
|
||||
private object Idle extends StateWithControlFrameHandling {
|
||||
def handleRegularFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective =
|
||||
override def createLogic(attributes: Attributes): GraphStageLogic =
|
||||
new GraphStageLogic(shape) with OutHandler {
|
||||
setHandler(out, this)
|
||||
setHandler(in, IdleHandler)
|
||||
|
||||
override def onPull(): Unit = pull(in)
|
||||
|
||||
private object IdleHandler extends ControlFrameStartHandler {
|
||||
def setAndHandleFrameStartWith(newHandler: ControlFrameStartHandler, start: FrameStart): Unit = {
|
||||
setHandler(in, newHandler)
|
||||
newHandler.handleFrameStart(start)
|
||||
}
|
||||
|
||||
override def handleRegularFrameStart(start: FrameStart): Unit =
|
||||
(start.header.opcode, start.isFullMessage) match {
|
||||
case (Opcode.Binary, true) ⇒ publishMessagePart(BinaryMessagePart(start.data, last = true))
|
||||
case (Opcode.Binary, false) ⇒ becomeAndHandleWith(new CollectingBinaryMessage, start)
|
||||
case (Opcode.Text, _) ⇒ becomeAndHandleWith(new CollectingTextMessage, start)
|
||||
case x ⇒ protocolError()
|
||||
case (Opcode.Binary, false) ⇒ setAndHandleFrameStartWith(new BinaryMessagehandler, start)
|
||||
case (Opcode.Text, _) ⇒ setAndHandleFrameStartWith(new TextMessageHandler, start)
|
||||
case x ⇒ pushProtocolError()
|
||||
}
|
||||
}
|
||||
|
||||
private class CollectingBinaryMessage extends CollectingMessageFrame(Opcode.Binary) {
|
||||
def createMessagePart(data: ByteString, last: Boolean): MessageDataPart = BinaryMessagePart(data, last)
|
||||
private class BinaryMessagehandler extends MessageHandler(Opcode.Binary) {
|
||||
override def createMessagePart(data: ByteString, last: Boolean): MessageDataPart =
|
||||
BinaryMessagePart(data, last)
|
||||
}
|
||||
private class CollectingTextMessage extends CollectingMessageFrame(Opcode.Text) {
|
||||
|
||||
private class TextMessageHandler extends MessageHandler(Opcode.Text) {
|
||||
val decoder = Utf8Decoder.create()
|
||||
|
||||
def createMessagePart(data: ByteString, last: Boolean): MessageDataPart =
|
||||
override def createMessagePart(data: ByteString, last: Boolean): MessageDataPart =
|
||||
TextMessagePart(decoder.decode(data, endOfInput = last).get, last)
|
||||
}
|
||||
|
||||
private abstract class CollectingMessageFrame(expectedOpcode: Opcode) extends StateWithControlFrameHandling {
|
||||
private abstract class MessageHandler(expectedOpcode: Opcode) extends ControlFrameStartHandler {
|
||||
var expectFirstHeader = true
|
||||
var finSeen = false
|
||||
def createMessagePart(data: ByteString, last: Boolean): MessageDataPart
|
||||
|
||||
def handleRegularFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective = {
|
||||
override def handleRegularFrameStart(start: FrameStart): Unit = {
|
||||
if ((expectFirstHeader && start.header.opcode == expectedOpcode) // first opcode must be the expected
|
||||
|| start.header.opcode == Opcode.Continuation) { // further ones continuations
|
||||
expectFirstHeader = false
|
||||
|
||||
if (start.header.fin) finSeen = true
|
||||
publish(start)
|
||||
} else protocolError()
|
||||
} else pushProtocolError()
|
||||
}
|
||||
override def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective = publish(data)
|
||||
|
||||
private def publish(part: FrameEvent)(implicit ctx: Ctx): SyncDirective =
|
||||
try publishMessagePart(createMessagePart(part.data, last = finSeen && part.lastPart))
|
||||
catch {
|
||||
override def handleFrameData(data: FrameData): Unit = publish(data)
|
||||
|
||||
def publish(part: FrameEvent): Unit = try {
|
||||
publishMessagePart(createMessagePart(part.data, last = finSeen && part.lastPart))
|
||||
} catch {
|
||||
case NonFatal(e) ⇒ closeWithCode(Protocol.CloseCodes.InconsistentData)
|
||||
}
|
||||
}
|
||||
|
||||
private class CollectingControlFrame(opcode: Opcode, _data: ByteString, nextState: State) extends InFrameState {
|
||||
private trait ControlFrameStartHandler extends FrameHandler {
|
||||
def handleRegularFrameStart(start: FrameStart): Unit
|
||||
|
||||
override def handleFrameStart(start: FrameStart): Unit = start.header match {
|
||||
case h: FrameHeader if h.mask.isDefined && !server ⇒ pushProtocolError()
|
||||
case h: FrameHeader if h.rsv1 || h.rsv2 || h.rsv3 ⇒ pushProtocolError()
|
||||
case FrameHeader(op, _, length, fin, _, _, _) if op.isControl && (length > 125 || !fin) ⇒ pushProtocolError()
|
||||
case h: FrameHeader if h.opcode.isControl ⇒
|
||||
if (start.isFullMessage) handleControlFrame(h.opcode, start.data, this)
|
||||
else collectControlFrame(start, this)
|
||||
case _ ⇒ handleRegularFrameStart(start)
|
||||
}
|
||||
|
||||
override def handleFrameData(data: FrameData): Unit =
|
||||
throw new IllegalStateException("Expected FrameStart")
|
||||
}
|
||||
|
||||
private class ControlFrameDataHandler(opcode: Opcode, _data: ByteString, nextHandler: InHandler) extends FrameHandler {
|
||||
var data = _data
|
||||
|
||||
def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective = {
|
||||
override def handleFrameData(data: FrameData): Unit = {
|
||||
this.data ++= data.data
|
||||
if (data.lastPart) handleControlFrame(opcode, this.data, nextState)
|
||||
else ctx.pull()
|
||||
}
|
||||
if (data.lastPart) handleControlFrame(opcode, this.data, nextHandler)
|
||||
else pull(in)
|
||||
}
|
||||
|
||||
private def becomeAndHandleWith(newState: State, part: FrameEvent)(implicit ctx: Ctx): SyncDirective = {
|
||||
become(newState)
|
||||
current.onPush(part, ctx)
|
||||
override def handleFrameStart(start: FrameStart): Unit =
|
||||
throw new IllegalStateException("Expected FrameData")
|
||||
}
|
||||
|
||||
/** Returns a SyncDirective if it handled the message */
|
||||
private def validateHeader(header: FrameHeader)(implicit ctx: Ctx): Option[SyncDirective] = header match {
|
||||
case h: FrameHeader if h.mask.isDefined && !server ⇒ Some(protocolError())
|
||||
case h: FrameHeader if h.rsv1 || h.rsv2 || h.rsv3 ⇒ Some(protocolError())
|
||||
case FrameHeader(op, _, length, fin, _, _, _) if op.isControl && (length > 125 || !fin) ⇒ Some(protocolError())
|
||||
case _ ⇒ None
|
||||
}
|
||||
private trait FrameHandler extends InHandler {
|
||||
def handleFrameData(data: FrameData): Unit
|
||||
def handleFrameStart(start: FrameStart): Unit
|
||||
|
||||
private def handleControlFrame(opcode: Opcode, data: ByteString, nextState: State)(implicit ctx: Ctx): SyncDirective = {
|
||||
become(nextState)
|
||||
def handleControlFrame(opcode: Opcode, data: ByteString, nextHandler: InHandler): Unit = {
|
||||
setHandler(in, nextHandler)
|
||||
opcode match {
|
||||
case Opcode.Ping ⇒ publishDirectResponse(FrameEvent.fullFrame(Opcode.Pong, None, data, fin = true))
|
||||
case Opcode.Pong ⇒
|
||||
// ignore unsolicited Pong frame
|
||||
ctx.pull()
|
||||
pull(in)
|
||||
case Opcode.Close ⇒
|
||||
become(WaitForPeerTcpClose)
|
||||
ctx.push(PeerClosed.parse(data))
|
||||
setHandler(in, WaitForPeerTcpClose)
|
||||
push(out, PeerClosed.parse(data))
|
||||
case Opcode.Other(o) ⇒ closeWithCode(Protocol.CloseCodes.ProtocolError, "Unsupported opcode")
|
||||
case other ⇒ ctx.fail(new IllegalStateException(s"unexpected message of type [${other.getClass.getName}] when expecting ControlFrame"))
|
||||
case other ⇒ failStage(
|
||||
new IllegalStateException(s"unexpected message of type [${other.getClass.getName}] when expecting ControlFrame")
|
||||
)
|
||||
}
|
||||
}
|
||||
private def collectControlFrame(start: FrameStart, nextState: State)(implicit ctx: Ctx): SyncDirective = {
|
||||
|
||||
def pushProtocolError(): Unit = closeWithCode(Protocol.CloseCodes.ProtocolError)
|
||||
|
||||
def closeWithCode(closeCode: Int, reason: String = ""): Unit = {
|
||||
setHandler(in, CloseAfterPeerClosed)
|
||||
push(out, ActivelyCloseWithCode(Some(closeCode), reason))
|
||||
}
|
||||
|
||||
def collectControlFrame(start: FrameStart, nextHandler: InHandler): Unit = {
|
||||
require(!start.isFullMessage)
|
||||
become(new CollectingControlFrame(start.header.opcode, start.data, nextState))
|
||||
ctx.pull()
|
||||
setHandler(in, new ControlFrameDataHandler(start.header.opcode, start.data, nextHandler))
|
||||
pull(in)
|
||||
}
|
||||
|
||||
private def publishMessagePart(part: MessageDataPart)(implicit ctx: Ctx): SyncDirective =
|
||||
if (part.last) emit(Iterator(part, MessageEnd), ctx, Idle)
|
||||
else ctx.push(part)
|
||||
private def publishDirectResponse(frame: FrameStart)(implicit ctx: Ctx): SyncDirective =
|
||||
ctx.push(DirectAnswer(frame))
|
||||
def publishMessagePart(part: MessageDataPart): Unit =
|
||||
if (part.last) emitMultiple(out, Iterator(part, MessageEnd), () ⇒ setHandler(in, IdleHandler))
|
||||
else push(out, part)
|
||||
|
||||
private def protocolError(reason: String = "")(implicit ctx: Ctx): SyncDirective =
|
||||
closeWithCode(Protocol.CloseCodes.ProtocolError, reason)
|
||||
def publishDirectResponse(frame: FrameStart): Unit = push(out, DirectAnswer(frame))
|
||||
|
||||
private def closeWithCode(closeCode: Int, reason: String = "", cause: Throwable = null)(implicit ctx: Ctx): SyncDirective = {
|
||||
become(CloseAfterPeerClosed)
|
||||
ctx.push(ActivelyCloseWithCode(Some(closeCode), reason))
|
||||
override def onPush(): Unit = grab(in) match {
|
||||
case data: FrameData ⇒ handleFrameData(data)
|
||||
case start: FrameStart ⇒ handleFrameStart(start)
|
||||
case FrameError(ex) ⇒ failStage(ex)
|
||||
}
|
||||
}
|
||||
|
||||
private object CloseAfterPeerClosed extends State {
|
||||
def onPush(elem: FrameEventOrError, ctx: Context[Output]): SyncDirective =
|
||||
elem match {
|
||||
private object CloseAfterPeerClosed extends InHandler {
|
||||
override def onPush(): Unit = grab(in) match {
|
||||
case FrameStart(FrameHeader(Opcode.Close, _, length, _, _, _, _), data) ⇒
|
||||
become(WaitForPeerTcpClose)
|
||||
ctx.push(PeerClosed.parse(data))
|
||||
case _ ⇒ ctx.pull() // ignore all other data
|
||||
setHandler(in, WaitForPeerTcpClose)
|
||||
push(out, PeerClosed.parse(data))
|
||||
case _ ⇒ pull(in) // ignore all other data
|
||||
}
|
||||
}
|
||||
private object WaitForPeerTcpClose extends State {
|
||||
def onPush(elem: FrameEventOrError, ctx: Context[Output]): SyncDirective =
|
||||
ctx.pull() // ignore
|
||||
}
|
||||
|
||||
private abstract class StateWithControlFrameHandling extends BetweenFrameState {
|
||||
def handleRegularFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective
|
||||
|
||||
def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective =
|
||||
validateHeader(start.header).getOrElse {
|
||||
if (start.header.opcode.isControl)
|
||||
if (start.isFullMessage) handleControlFrame(start.header.opcode, start.data, this)
|
||||
else collectControlFrame(start, this)
|
||||
else handleRegularFrameStart(start)
|
||||
}
|
||||
}
|
||||
private abstract class BetweenFrameState extends ImplicitContextState {
|
||||
def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective =
|
||||
throw new IllegalStateException("Expected FrameStart")
|
||||
}
|
||||
private abstract class InFrameState extends ImplicitContextState {
|
||||
def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective =
|
||||
throw new IllegalStateException("Expected FrameData")
|
||||
}
|
||||
private abstract class ImplicitContextState extends State {
|
||||
def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective
|
||||
def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective
|
||||
|
||||
def onPush(part: FrameEventOrError, ctx: Ctx): SyncDirective =
|
||||
part match {
|
||||
case data: FrameData ⇒ handleFrameData(data)(ctx)
|
||||
case start: FrameStart ⇒ handleFrameStart(start)(ctx)
|
||||
case FrameError(ex) ⇒ ctx.fail(ex)
|
||||
private object WaitForPeerTcpClose extends InHandler {
|
||||
override def onPush(): Unit = pull(in) // ignore
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue