=htc Convert HandlerStage to GraphStage, #19361. (#20855)

This commit is contained in:
Daniel Moran 2016-07-08 07:44:54 -04:00 committed by Konrad Malawski
parent 1756541dcb
commit 6d2d4d5d25

View file

@ -6,9 +6,11 @@ package akka.http.impl.engine.ws
import akka.NotUsed import akka.NotUsed
import akka.stream.scaladsl.Flow import akka.stream.scaladsl.Flow
import akka.stream.stage.{ SyncDirective, Context, StatefulStage }
import akka.util.ByteString import akka.util.ByteString
import Protocol.Opcode import Protocol.Opcode
import akka.event.Logging
import akka.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler }
import akka.stream.{ Attributes, FlowShape, Inlet, Outlet }
import scala.util.control.NonFatal import scala.util.control.NonFatal
@ -21,156 +23,161 @@ import scala.util.control.NonFatal
private[http] object FrameHandler { private[http] object FrameHandler {
def create(server: Boolean): Flow[FrameEventOrError, Output, NotUsed] = def create(server: Boolean): Flow[FrameEventOrError, Output, NotUsed] =
Flow[FrameEventOrError].transform(() new HandlerStage(server)) Flow[FrameEventOrError].via(new HandlerStage(server))
private class HandlerStage(server: Boolean) extends StatefulStage[FrameEventOrError, Output] { private class HandlerStage(server: Boolean) extends GraphStage[FlowShape[FrameEventOrError, Output]] {
type Ctx = Context[Output] val in = Inlet[FrameEventOrError](Logging.simpleName(this) + ".in")
def initial: State = Idle val out = Outlet[Output](Logging.simpleName(this) + ".out")
override val shape = FlowShape(in, out)
override def toString: String = s"HandlerStage(server=$server)" override def toString: String = s"HandlerStage(server=$server)"
private object Idle extends StateWithControlFrameHandling { override def createLogic(attributes: Attributes): GraphStageLogic =
def handleRegularFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective = new GraphStageLogic(shape) with OutHandler {
setHandler(out, this)
setHandler(in, IdleHandler)
override def onPull(): Unit = pull(in)
private object IdleHandler extends ControlFrameStartHandler {
def setAndHandleFrameStartWith(newHandler: ControlFrameStartHandler, start: FrameStart): Unit = {
setHandler(in, newHandler)
newHandler.handleFrameStart(start)
}
override def handleRegularFrameStart(start: FrameStart): Unit =
(start.header.opcode, start.isFullMessage) match { (start.header.opcode, start.isFullMessage) match {
case (Opcode.Binary, true) publishMessagePart(BinaryMessagePart(start.data, last = true)) case (Opcode.Binary, true) publishMessagePart(BinaryMessagePart(start.data, last = true))
case (Opcode.Binary, false) becomeAndHandleWith(new CollectingBinaryMessage, start) case (Opcode.Binary, false) setAndHandleFrameStartWith(new BinaryMessagehandler, start)
case (Opcode.Text, _) becomeAndHandleWith(new CollectingTextMessage, start) case (Opcode.Text, _) setAndHandleFrameStartWith(new TextMessageHandler, start)
case x protocolError() case x pushProtocolError()
} }
} }
private class CollectingBinaryMessage extends CollectingMessageFrame(Opcode.Binary) { private class BinaryMessagehandler extends MessageHandler(Opcode.Binary) {
def createMessagePart(data: ByteString, last: Boolean): MessageDataPart = BinaryMessagePart(data, last) override def createMessagePart(data: ByteString, last: Boolean): MessageDataPart =
BinaryMessagePart(data, last)
} }
private class CollectingTextMessage extends CollectingMessageFrame(Opcode.Text) {
private class TextMessageHandler extends MessageHandler(Opcode.Text) {
val decoder = Utf8Decoder.create() val decoder = Utf8Decoder.create()
def createMessagePart(data: ByteString, last: Boolean): MessageDataPart = override def createMessagePart(data: ByteString, last: Boolean): MessageDataPart =
TextMessagePart(decoder.decode(data, endOfInput = last).get, last) TextMessagePart(decoder.decode(data, endOfInput = last).get, last)
} }
private abstract class CollectingMessageFrame(expectedOpcode: Opcode) extends StateWithControlFrameHandling { private abstract class MessageHandler(expectedOpcode: Opcode) extends ControlFrameStartHandler {
var expectFirstHeader = true var expectFirstHeader = true
var finSeen = false var finSeen = false
def createMessagePart(data: ByteString, last: Boolean): MessageDataPart def createMessagePart(data: ByteString, last: Boolean): MessageDataPart
def handleRegularFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective = { override def handleRegularFrameStart(start: FrameStart): Unit = {
if ((expectFirstHeader && start.header.opcode == expectedOpcode) // first opcode must be the expected if ((expectFirstHeader && start.header.opcode == expectedOpcode) // first opcode must be the expected
|| start.header.opcode == Opcode.Continuation) { // further ones continuations || start.header.opcode == Opcode.Continuation) { // further ones continuations
expectFirstHeader = false expectFirstHeader = false
if (start.header.fin) finSeen = true if (start.header.fin) finSeen = true
publish(start) publish(start)
} else protocolError() } else pushProtocolError()
} }
override def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective = publish(data)
private def publish(part: FrameEvent)(implicit ctx: Ctx): SyncDirective = override def handleFrameData(data: FrameData): Unit = publish(data)
try publishMessagePart(createMessagePart(part.data, last = finSeen && part.lastPart))
catch { def publish(part: FrameEvent): Unit = try {
publishMessagePart(createMessagePart(part.data, last = finSeen && part.lastPart))
} catch {
case NonFatal(e) closeWithCode(Protocol.CloseCodes.InconsistentData) case NonFatal(e) closeWithCode(Protocol.CloseCodes.InconsistentData)
} }
} }
private class CollectingControlFrame(opcode: Opcode, _data: ByteString, nextState: State) extends InFrameState { private trait ControlFrameStartHandler extends FrameHandler {
def handleRegularFrameStart(start: FrameStart): Unit
override def handleFrameStart(start: FrameStart): Unit = start.header match {
case h: FrameHeader if h.mask.isDefined && !server pushProtocolError()
case h: FrameHeader if h.rsv1 || h.rsv2 || h.rsv3 pushProtocolError()
case FrameHeader(op, _, length, fin, _, _, _) if op.isControl && (length > 125 || !fin) pushProtocolError()
case h: FrameHeader if h.opcode.isControl
if (start.isFullMessage) handleControlFrame(h.opcode, start.data, this)
else collectControlFrame(start, this)
case _ handleRegularFrameStart(start)
}
override def handleFrameData(data: FrameData): Unit =
throw new IllegalStateException("Expected FrameStart")
}
private class ControlFrameDataHandler(opcode: Opcode, _data: ByteString, nextHandler: InHandler) extends FrameHandler {
var data = _data var data = _data
def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective = { override def handleFrameData(data: FrameData): Unit = {
this.data ++= data.data this.data ++= data.data
if (data.lastPart) handleControlFrame(opcode, this.data, nextState) if (data.lastPart) handleControlFrame(opcode, this.data, nextHandler)
else ctx.pull() else pull(in)
}
} }
private def becomeAndHandleWith(newState: State, part: FrameEvent)(implicit ctx: Ctx): SyncDirective = { override def handleFrameStart(start: FrameStart): Unit =
become(newState) throw new IllegalStateException("Expected FrameData")
current.onPush(part, ctx)
} }
/** Returns a SyncDirective if it handled the message */ private trait FrameHandler extends InHandler {
private def validateHeader(header: FrameHeader)(implicit ctx: Ctx): Option[SyncDirective] = header match { def handleFrameData(data: FrameData): Unit
case h: FrameHeader if h.mask.isDefined && !server Some(protocolError()) def handleFrameStart(start: FrameStart): Unit
case h: FrameHeader if h.rsv1 || h.rsv2 || h.rsv3 Some(protocolError())
case FrameHeader(op, _, length, fin, _, _, _) if op.isControl && (length > 125 || !fin) Some(protocolError())
case _ None
}
private def handleControlFrame(opcode: Opcode, data: ByteString, nextState: State)(implicit ctx: Ctx): SyncDirective = { def handleControlFrame(opcode: Opcode, data: ByteString, nextHandler: InHandler): Unit = {
become(nextState) setHandler(in, nextHandler)
opcode match { opcode match {
case Opcode.Ping publishDirectResponse(FrameEvent.fullFrame(Opcode.Pong, None, data, fin = true)) case Opcode.Ping publishDirectResponse(FrameEvent.fullFrame(Opcode.Pong, None, data, fin = true))
case Opcode.Pong case Opcode.Pong
// ignore unsolicited Pong frame // ignore unsolicited Pong frame
ctx.pull() pull(in)
case Opcode.Close case Opcode.Close
become(WaitForPeerTcpClose) setHandler(in, WaitForPeerTcpClose)
ctx.push(PeerClosed.parse(data)) push(out, PeerClosed.parse(data))
case Opcode.Other(o) closeWithCode(Protocol.CloseCodes.ProtocolError, "Unsupported opcode") case Opcode.Other(o) closeWithCode(Protocol.CloseCodes.ProtocolError, "Unsupported opcode")
case other ctx.fail(new IllegalStateException(s"unexpected message of type [${other.getClass.getName}] when expecting ControlFrame")) case other failStage(
new IllegalStateException(s"unexpected message of type [${other.getClass.getName}] when expecting ControlFrame")
)
} }
} }
private def collectControlFrame(start: FrameStart, nextState: State)(implicit ctx: Ctx): SyncDirective = {
def pushProtocolError(): Unit = closeWithCode(Protocol.CloseCodes.ProtocolError)
def closeWithCode(closeCode: Int, reason: String = ""): Unit = {
setHandler(in, CloseAfterPeerClosed)
push(out, ActivelyCloseWithCode(Some(closeCode), reason))
}
def collectControlFrame(start: FrameStart, nextHandler: InHandler): Unit = {
require(!start.isFullMessage) require(!start.isFullMessage)
become(new CollectingControlFrame(start.header.opcode, start.data, nextState)) setHandler(in, new ControlFrameDataHandler(start.header.opcode, start.data, nextHandler))
ctx.pull() pull(in)
} }
private def publishMessagePart(part: MessageDataPart)(implicit ctx: Ctx): SyncDirective = def publishMessagePart(part: MessageDataPart): Unit =
if (part.last) emit(Iterator(part, MessageEnd), ctx, Idle) if (part.last) emitMultiple(out, Iterator(part, MessageEnd), () setHandler(in, IdleHandler))
else ctx.push(part) else push(out, part)
private def publishDirectResponse(frame: FrameStart)(implicit ctx: Ctx): SyncDirective =
ctx.push(DirectAnswer(frame))
private def protocolError(reason: String = "")(implicit ctx: Ctx): SyncDirective = def publishDirectResponse(frame: FrameStart): Unit = push(out, DirectAnswer(frame))
closeWithCode(Protocol.CloseCodes.ProtocolError, reason)
private def closeWithCode(closeCode: Int, reason: String = "", cause: Throwable = null)(implicit ctx: Ctx): SyncDirective = { override def onPush(): Unit = grab(in) match {
become(CloseAfterPeerClosed) case data: FrameData handleFrameData(data)
ctx.push(ActivelyCloseWithCode(Some(closeCode), reason)) case start: FrameStart handleFrameStart(start)
case FrameError(ex) failStage(ex)
}
} }
private object CloseAfterPeerClosed extends State { private object CloseAfterPeerClosed extends InHandler {
def onPush(elem: FrameEventOrError, ctx: Context[Output]): SyncDirective = override def onPush(): Unit = grab(in) match {
elem match {
case FrameStart(FrameHeader(Opcode.Close, _, length, _, _, _, _), data) case FrameStart(FrameHeader(Opcode.Close, _, length, _, _, _, _), data)
become(WaitForPeerTcpClose) setHandler(in, WaitForPeerTcpClose)
ctx.push(PeerClosed.parse(data)) push(out, PeerClosed.parse(data))
case _ ctx.pull() // ignore all other data case _ pull(in) // ignore all other data
} }
} }
private object WaitForPeerTcpClose extends State {
def onPush(elem: FrameEventOrError, ctx: Context[Output]): SyncDirective =
ctx.pull() // ignore
}
private abstract class StateWithControlFrameHandling extends BetweenFrameState { private object WaitForPeerTcpClose extends InHandler {
def handleRegularFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective override def onPush(): Unit = pull(in) // ignore
def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective =
validateHeader(start.header).getOrElse {
if (start.header.opcode.isControl)
if (start.isFullMessage) handleControlFrame(start.header.opcode, start.data, this)
else collectControlFrame(start, this)
else handleRegularFrameStart(start)
}
}
private abstract class BetweenFrameState extends ImplicitContextState {
def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective =
throw new IllegalStateException("Expected FrameStart")
}
private abstract class InFrameState extends ImplicitContextState {
def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective =
throw new IllegalStateException("Expected FrameData")
}
private abstract class ImplicitContextState extends State {
def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective
def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective
def onPush(part: FrameEventOrError, ctx: Ctx): SyncDirective =
part match {
case data: FrameData handleFrameData(data)(ctx)
case start: FrameStart handleFrameStart(start)(ctx)
case FrameError(ex) ctx.fail(ex)
} }
} }
} }