=htc Convert HandlerStage to GraphStage, #19361. (#20855)

This commit is contained in:
Daniel Moran 2016-07-08 07:44:54 -04:00 committed by Konrad Malawski
parent 1756541dcb
commit 6d2d4d5d25

View file

@ -6,9 +6,11 @@ package akka.http.impl.engine.ws
import akka.NotUsed import akka.NotUsed
import akka.stream.scaladsl.Flow import akka.stream.scaladsl.Flow
import akka.stream.stage.{ SyncDirective, Context, StatefulStage }
import akka.util.ByteString import akka.util.ByteString
import Protocol.Opcode import Protocol.Opcode
import akka.event.Logging
import akka.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler }
import akka.stream.{ Attributes, FlowShape, Inlet, Outlet }
import scala.util.control.NonFatal import scala.util.control.NonFatal
@ -21,158 +23,163 @@ import scala.util.control.NonFatal
private[http] object FrameHandler { private[http] object FrameHandler {
def create(server: Boolean): Flow[FrameEventOrError, Output, NotUsed] = def create(server: Boolean): Flow[FrameEventOrError, Output, NotUsed] =
Flow[FrameEventOrError].transform(() new HandlerStage(server)) Flow[FrameEventOrError].via(new HandlerStage(server))
private class HandlerStage(server: Boolean) extends StatefulStage[FrameEventOrError, Output] { private class HandlerStage(server: Boolean) extends GraphStage[FlowShape[FrameEventOrError, Output]] {
type Ctx = Context[Output] val in = Inlet[FrameEventOrError](Logging.simpleName(this) + ".in")
def initial: State = Idle val out = Outlet[Output](Logging.simpleName(this) + ".out")
override val shape = FlowShape(in, out)
override def toString: String = s"HandlerStage(server=$server)" override def toString: String = s"HandlerStage(server=$server)"
private object Idle extends StateWithControlFrameHandling { override def createLogic(attributes: Attributes): GraphStageLogic =
def handleRegularFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective = new GraphStageLogic(shape) with OutHandler {
(start.header.opcode, start.isFullMessage) match { setHandler(out, this)
case (Opcode.Binary, true) publishMessagePart(BinaryMessagePart(start.data, last = true)) setHandler(in, IdleHandler)
case (Opcode.Binary, false) becomeAndHandleWith(new CollectingBinaryMessage, start)
case (Opcode.Text, _) becomeAndHandleWith(new CollectingTextMessage, start) override def onPull(): Unit = pull(in)
case x protocolError()
private object IdleHandler extends ControlFrameStartHandler {
def setAndHandleFrameStartWith(newHandler: ControlFrameStartHandler, start: FrameStart): Unit = {
setHandler(in, newHandler)
newHandler.handleFrameStart(start)
}
override def handleRegularFrameStart(start: FrameStart): Unit =
(start.header.opcode, start.isFullMessage) match {
case (Opcode.Binary, true) publishMessagePart(BinaryMessagePart(start.data, last = true))
case (Opcode.Binary, false) setAndHandleFrameStartWith(new BinaryMessagehandler, start)
case (Opcode.Text, _) setAndHandleFrameStartWith(new TextMessageHandler, start)
case x pushProtocolError()
}
} }
}
private class CollectingBinaryMessage extends CollectingMessageFrame(Opcode.Binary) { private class BinaryMessagehandler extends MessageHandler(Opcode.Binary) {
def createMessagePart(data: ByteString, last: Boolean): MessageDataPart = BinaryMessagePart(data, last) override def createMessagePart(data: ByteString, last: Boolean): MessageDataPart =
} BinaryMessagePart(data, last)
private class CollectingTextMessage extends CollectingMessageFrame(Opcode.Text) { }
val decoder = Utf8Decoder.create()
def createMessagePart(data: ByteString, last: Boolean): MessageDataPart = private class TextMessageHandler extends MessageHandler(Opcode.Text) {
TextMessagePart(decoder.decode(data, endOfInput = last).get, last) val decoder = Utf8Decoder.create()
}
private abstract class CollectingMessageFrame(expectedOpcode: Opcode) extends StateWithControlFrameHandling { override def createMessagePart(data: ByteString, last: Boolean): MessageDataPart =
var expectFirstHeader = true TextMessagePart(decoder.decode(data, endOfInput = last).get, last)
var finSeen = false }
def createMessagePart(data: ByteString, last: Boolean): MessageDataPart
def handleRegularFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective = { private abstract class MessageHandler(expectedOpcode: Opcode) extends ControlFrameStartHandler {
if ((expectFirstHeader && start.header.opcode == expectedOpcode) // first opcode must be the expected var expectFirstHeader = true
|| start.header.opcode == Opcode.Continuation) { // further ones continuations var finSeen = false
expectFirstHeader = false def createMessagePart(data: ByteString, last: Boolean): MessageDataPart
if (start.header.fin) finSeen = true override def handleRegularFrameStart(start: FrameStart): Unit = {
publish(start) if ((expectFirstHeader && start.header.opcode == expectedOpcode) // first opcode must be the expected
} else protocolError() || start.header.opcode == Opcode.Continuation) { // further ones continuations
expectFirstHeader = false
if (start.header.fin) finSeen = true
publish(start)
} else pushProtocolError()
}
override def handleFrameData(data: FrameData): Unit = publish(data)
def publish(part: FrameEvent): Unit = try {
publishMessagePart(createMessagePart(part.data, last = finSeen && part.lastPart))
} catch {
case NonFatal(e) closeWithCode(Protocol.CloseCodes.InconsistentData)
}
}
private trait ControlFrameStartHandler extends FrameHandler {
def handleRegularFrameStart(start: FrameStart): Unit
override def handleFrameStart(start: FrameStart): Unit = start.header match {
case h: FrameHeader if h.mask.isDefined && !server pushProtocolError()
case h: FrameHeader if h.rsv1 || h.rsv2 || h.rsv3 pushProtocolError()
case FrameHeader(op, _, length, fin, _, _, _) if op.isControl && (length > 125 || !fin) pushProtocolError()
case h: FrameHeader if h.opcode.isControl
if (start.isFullMessage) handleControlFrame(h.opcode, start.data, this)
else collectControlFrame(start, this)
case _ handleRegularFrameStart(start)
}
override def handleFrameData(data: FrameData): Unit =
throw new IllegalStateException("Expected FrameStart")
}
private class ControlFrameDataHandler(opcode: Opcode, _data: ByteString, nextHandler: InHandler) extends FrameHandler {
var data = _data
override def handleFrameData(data: FrameData): Unit = {
this.data ++= data.data
if (data.lastPart) handleControlFrame(opcode, this.data, nextHandler)
else pull(in)
}
override def handleFrameStart(start: FrameStart): Unit =
throw new IllegalStateException("Expected FrameData")
}
private trait FrameHandler extends InHandler {
def handleFrameData(data: FrameData): Unit
def handleFrameStart(start: FrameStart): Unit
def handleControlFrame(opcode: Opcode, data: ByteString, nextHandler: InHandler): Unit = {
setHandler(in, nextHandler)
opcode match {
case Opcode.Ping publishDirectResponse(FrameEvent.fullFrame(Opcode.Pong, None, data, fin = true))
case Opcode.Pong
// ignore unsolicited Pong frame
pull(in)
case Opcode.Close
setHandler(in, WaitForPeerTcpClose)
push(out, PeerClosed.parse(data))
case Opcode.Other(o) closeWithCode(Protocol.CloseCodes.ProtocolError, "Unsupported opcode")
case other failStage(
new IllegalStateException(s"unexpected message of type [${other.getClass.getName}] when expecting ControlFrame")
)
}
}
def pushProtocolError(): Unit = closeWithCode(Protocol.CloseCodes.ProtocolError)
def closeWithCode(closeCode: Int, reason: String = ""): Unit = {
setHandler(in, CloseAfterPeerClosed)
push(out, ActivelyCloseWithCode(Some(closeCode), reason))
}
def collectControlFrame(start: FrameStart, nextHandler: InHandler): Unit = {
require(!start.isFullMessage)
setHandler(in, new ControlFrameDataHandler(start.header.opcode, start.data, nextHandler))
pull(in)
}
def publishMessagePart(part: MessageDataPart): Unit =
if (part.last) emitMultiple(out, Iterator(part, MessageEnd), () setHandler(in, IdleHandler))
else push(out, part)
def publishDirectResponse(frame: FrameStart): Unit = push(out, DirectAnswer(frame))
override def onPush(): Unit = grab(in) match {
case data: FrameData handleFrameData(data)
case start: FrameStart handleFrameStart(start)
case FrameError(ex) failStage(ex)
}
}
private object CloseAfterPeerClosed extends InHandler {
override def onPush(): Unit = grab(in) match {
case FrameStart(FrameHeader(Opcode.Close, _, length, _, _, _, _), data)
setHandler(in, WaitForPeerTcpClose)
push(out, PeerClosed.parse(data))
case _ pull(in) // ignore all other data
}
}
private object WaitForPeerTcpClose extends InHandler {
override def onPush(): Unit = pull(in) // ignore
}
} }
override def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective = publish(data)
private def publish(part: FrameEvent)(implicit ctx: Ctx): SyncDirective =
try publishMessagePart(createMessagePart(part.data, last = finSeen && part.lastPart))
catch {
case NonFatal(e) closeWithCode(Protocol.CloseCodes.InconsistentData)
}
}
private class CollectingControlFrame(opcode: Opcode, _data: ByteString, nextState: State) extends InFrameState {
var data = _data
def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective = {
this.data ++= data.data
if (data.lastPart) handleControlFrame(opcode, this.data, nextState)
else ctx.pull()
}
}
private def becomeAndHandleWith(newState: State, part: FrameEvent)(implicit ctx: Ctx): SyncDirective = {
become(newState)
current.onPush(part, ctx)
}
/** Returns a SyncDirective if it handled the message */
private def validateHeader(header: FrameHeader)(implicit ctx: Ctx): Option[SyncDirective] = header match {
case h: FrameHeader if h.mask.isDefined && !server Some(protocolError())
case h: FrameHeader if h.rsv1 || h.rsv2 || h.rsv3 Some(protocolError())
case FrameHeader(op, _, length, fin, _, _, _) if op.isControl && (length > 125 || !fin) Some(protocolError())
case _ None
}
private def handleControlFrame(opcode: Opcode, data: ByteString, nextState: State)(implicit ctx: Ctx): SyncDirective = {
become(nextState)
opcode match {
case Opcode.Ping publishDirectResponse(FrameEvent.fullFrame(Opcode.Pong, None, data, fin = true))
case Opcode.Pong
// ignore unsolicited Pong frame
ctx.pull()
case Opcode.Close
become(WaitForPeerTcpClose)
ctx.push(PeerClosed.parse(data))
case Opcode.Other(o) closeWithCode(Protocol.CloseCodes.ProtocolError, "Unsupported opcode")
case other ctx.fail(new IllegalStateException(s"unexpected message of type [${other.getClass.getName}] when expecting ControlFrame"))
}
}
private def collectControlFrame(start: FrameStart, nextState: State)(implicit ctx: Ctx): SyncDirective = {
require(!start.isFullMessage)
become(new CollectingControlFrame(start.header.opcode, start.data, nextState))
ctx.pull()
}
private def publishMessagePart(part: MessageDataPart)(implicit ctx: Ctx): SyncDirective =
if (part.last) emit(Iterator(part, MessageEnd), ctx, Idle)
else ctx.push(part)
private def publishDirectResponse(frame: FrameStart)(implicit ctx: Ctx): SyncDirective =
ctx.push(DirectAnswer(frame))
private def protocolError(reason: String = "")(implicit ctx: Ctx): SyncDirective =
closeWithCode(Protocol.CloseCodes.ProtocolError, reason)
private def closeWithCode(closeCode: Int, reason: String = "", cause: Throwable = null)(implicit ctx: Ctx): SyncDirective = {
become(CloseAfterPeerClosed)
ctx.push(ActivelyCloseWithCode(Some(closeCode), reason))
}
private object CloseAfterPeerClosed extends State {
def onPush(elem: FrameEventOrError, ctx: Context[Output]): SyncDirective =
elem match {
case FrameStart(FrameHeader(Opcode.Close, _, length, _, _, _, _), data)
become(WaitForPeerTcpClose)
ctx.push(PeerClosed.parse(data))
case _ ctx.pull() // ignore all other data
}
}
private object WaitForPeerTcpClose extends State {
def onPush(elem: FrameEventOrError, ctx: Context[Output]): SyncDirective =
ctx.pull() // ignore
}
private abstract class StateWithControlFrameHandling extends BetweenFrameState {
def handleRegularFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective
def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective =
validateHeader(start.header).getOrElse {
if (start.header.opcode.isControl)
if (start.isFullMessage) handleControlFrame(start.header.opcode, start.data, this)
else collectControlFrame(start, this)
else handleRegularFrameStart(start)
}
}
private abstract class BetweenFrameState extends ImplicitContextState {
def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective =
throw new IllegalStateException("Expected FrameStart")
}
private abstract class InFrameState extends ImplicitContextState {
def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective =
throw new IllegalStateException("Expected FrameData")
}
private abstract class ImplicitContextState extends State {
def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective
def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective
def onPush(part: FrameEventOrError, ctx: Ctx): SyncDirective =
part match {
case data: FrameData handleFrameData(data)(ctx)
case start: FrameStart handleFrameStart(start)(ctx)
case FrameError(ex) ctx.fail(ex)
}
}
} }
sealed trait Output sealed trait Output