diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameHandler.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameHandler.scala index 549d3be663..7803b21545 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameHandler.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/FrameHandler.scala @@ -6,9 +6,11 @@ package akka.http.impl.engine.ws import akka.NotUsed import akka.stream.scaladsl.Flow -import akka.stream.stage.{ SyncDirective, Context, StatefulStage } import akka.util.ByteString import Protocol.Opcode +import akka.event.Logging +import akka.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler } +import akka.stream.{ Attributes, FlowShape, Inlet, Outlet } import scala.util.control.NonFatal @@ -21,158 +23,163 @@ import scala.util.control.NonFatal private[http] object FrameHandler { def create(server: Boolean): Flow[FrameEventOrError, Output, NotUsed] = - Flow[FrameEventOrError].transform(() ⇒ new HandlerStage(server)) + Flow[FrameEventOrError].via(new HandlerStage(server)) - private class HandlerStage(server: Boolean) extends StatefulStage[FrameEventOrError, Output] { - type Ctx = Context[Output] - def initial: State = Idle + private class HandlerStage(server: Boolean) extends GraphStage[FlowShape[FrameEventOrError, Output]] { + val in = Inlet[FrameEventOrError](Logging.simpleName(this) + ".in") + val out = Outlet[Output](Logging.simpleName(this) + ".out") + override val shape = FlowShape(in, out) override def toString: String = s"HandlerStage(server=$server)" - private object Idle extends StateWithControlFrameHandling { - def handleRegularFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective = - (start.header.opcode, start.isFullMessage) match { - case (Opcode.Binary, true) ⇒ publishMessagePart(BinaryMessagePart(start.data, last = true)) - case (Opcode.Binary, false) ⇒ becomeAndHandleWith(new CollectingBinaryMessage, start) - case (Opcode.Text, _) ⇒ becomeAndHandleWith(new CollectingTextMessage, start) - case x ⇒ protocolError() + override def createLogic(attributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) with OutHandler { + setHandler(out, this) + setHandler(in, IdleHandler) + + override def onPull(): Unit = pull(in) + + private object IdleHandler extends ControlFrameStartHandler { + def setAndHandleFrameStartWith(newHandler: ControlFrameStartHandler, start: FrameStart): Unit = { + setHandler(in, newHandler) + newHandler.handleFrameStart(start) + } + + override def handleRegularFrameStart(start: FrameStart): Unit = + (start.header.opcode, start.isFullMessage) match { + case (Opcode.Binary, true) ⇒ publishMessagePart(BinaryMessagePart(start.data, last = true)) + case (Opcode.Binary, false) ⇒ setAndHandleFrameStartWith(new BinaryMessagehandler, start) + case (Opcode.Text, _) ⇒ setAndHandleFrameStartWith(new TextMessageHandler, start) + case x ⇒ pushProtocolError() + } } - } - private class CollectingBinaryMessage extends CollectingMessageFrame(Opcode.Binary) { - def createMessagePart(data: ByteString, last: Boolean): MessageDataPart = BinaryMessagePart(data, last) - } - private class CollectingTextMessage extends CollectingMessageFrame(Opcode.Text) { - val decoder = Utf8Decoder.create() + private class BinaryMessagehandler extends MessageHandler(Opcode.Binary) { + override def createMessagePart(data: ByteString, last: Boolean): MessageDataPart = + BinaryMessagePart(data, last) + } - def createMessagePart(data: ByteString, last: Boolean): MessageDataPart = - TextMessagePart(decoder.decode(data, endOfInput = last).get, last) - } + private class TextMessageHandler extends MessageHandler(Opcode.Text) { + val decoder = Utf8Decoder.create() - private abstract class CollectingMessageFrame(expectedOpcode: Opcode) extends StateWithControlFrameHandling { - var expectFirstHeader = true - var finSeen = false - def createMessagePart(data: ByteString, last: Boolean): MessageDataPart + override def createMessagePart(data: ByteString, last: Boolean): MessageDataPart = + TextMessagePart(decoder.decode(data, endOfInput = last).get, last) + } - def handleRegularFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective = { - if ((expectFirstHeader && start.header.opcode == expectedOpcode) // first opcode must be the expected - || start.header.opcode == Opcode.Continuation) { // further ones continuations - expectFirstHeader = false + private abstract class MessageHandler(expectedOpcode: Opcode) extends ControlFrameStartHandler { + var expectFirstHeader = true + var finSeen = false + def createMessagePart(data: ByteString, last: Boolean): MessageDataPart - if (start.header.fin) finSeen = true - publish(start) - } else protocolError() + override def handleRegularFrameStart(start: FrameStart): Unit = { + if ((expectFirstHeader && start.header.opcode == expectedOpcode) // first opcode must be the expected + || start.header.opcode == Opcode.Continuation) { // further ones continuations + expectFirstHeader = false + + if (start.header.fin) finSeen = true + publish(start) + } else pushProtocolError() + } + + override def handleFrameData(data: FrameData): Unit = publish(data) + + def publish(part: FrameEvent): Unit = try { + publishMessagePart(createMessagePart(part.data, last = finSeen && part.lastPart)) + } catch { + case NonFatal(e) ⇒ closeWithCode(Protocol.CloseCodes.InconsistentData) + } + } + + private trait ControlFrameStartHandler extends FrameHandler { + def handleRegularFrameStart(start: FrameStart): Unit + + override def handleFrameStart(start: FrameStart): Unit = start.header match { + case h: FrameHeader if h.mask.isDefined && !server ⇒ pushProtocolError() + case h: FrameHeader if h.rsv1 || h.rsv2 || h.rsv3 ⇒ pushProtocolError() + case FrameHeader(op, _, length, fin, _, _, _) if op.isControl && (length > 125 || !fin) ⇒ pushProtocolError() + case h: FrameHeader if h.opcode.isControl ⇒ + if (start.isFullMessage) handleControlFrame(h.opcode, start.data, this) + else collectControlFrame(start, this) + case _ ⇒ handleRegularFrameStart(start) + } + + override def handleFrameData(data: FrameData): Unit = + throw new IllegalStateException("Expected FrameStart") + } + + private class ControlFrameDataHandler(opcode: Opcode, _data: ByteString, nextHandler: InHandler) extends FrameHandler { + var data = _data + + override def handleFrameData(data: FrameData): Unit = { + this.data ++= data.data + if (data.lastPart) handleControlFrame(opcode, this.data, nextHandler) + else pull(in) + } + + override def handleFrameStart(start: FrameStart): Unit = + throw new IllegalStateException("Expected FrameData") + } + + private trait FrameHandler extends InHandler { + def handleFrameData(data: FrameData): Unit + def handleFrameStart(start: FrameStart): Unit + + def handleControlFrame(opcode: Opcode, data: ByteString, nextHandler: InHandler): Unit = { + setHandler(in, nextHandler) + opcode match { + case Opcode.Ping ⇒ publishDirectResponse(FrameEvent.fullFrame(Opcode.Pong, None, data, fin = true)) + case Opcode.Pong ⇒ + // ignore unsolicited Pong frame + pull(in) + case Opcode.Close ⇒ + setHandler(in, WaitForPeerTcpClose) + push(out, PeerClosed.parse(data)) + case Opcode.Other(o) ⇒ closeWithCode(Protocol.CloseCodes.ProtocolError, "Unsupported opcode") + case other ⇒ failStage( + new IllegalStateException(s"unexpected message of type [${other.getClass.getName}] when expecting ControlFrame") + ) + } + } + + def pushProtocolError(): Unit = closeWithCode(Protocol.CloseCodes.ProtocolError) + + def closeWithCode(closeCode: Int, reason: String = ""): Unit = { + setHandler(in, CloseAfterPeerClosed) + push(out, ActivelyCloseWithCode(Some(closeCode), reason)) + } + + def collectControlFrame(start: FrameStart, nextHandler: InHandler): Unit = { + require(!start.isFullMessage) + setHandler(in, new ControlFrameDataHandler(start.header.opcode, start.data, nextHandler)) + pull(in) + } + + def publishMessagePart(part: MessageDataPart): Unit = + if (part.last) emitMultiple(out, Iterator(part, MessageEnd), () ⇒ setHandler(in, IdleHandler)) + else push(out, part) + + def publishDirectResponse(frame: FrameStart): Unit = push(out, DirectAnswer(frame)) + + override def onPush(): Unit = grab(in) match { + case data: FrameData ⇒ handleFrameData(data) + case start: FrameStart ⇒ handleFrameStart(start) + case FrameError(ex) ⇒ failStage(ex) + } + } + + private object CloseAfterPeerClosed extends InHandler { + override def onPush(): Unit = grab(in) match { + case FrameStart(FrameHeader(Opcode.Close, _, length, _, _, _, _), data) ⇒ + setHandler(in, WaitForPeerTcpClose) + push(out, PeerClosed.parse(data)) + case _ ⇒ pull(in) // ignore all other data + } + } + + private object WaitForPeerTcpClose extends InHandler { + override def onPush(): Unit = pull(in) // ignore + } } - override def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective = publish(data) - - private def publish(part: FrameEvent)(implicit ctx: Ctx): SyncDirective = - try publishMessagePart(createMessagePart(part.data, last = finSeen && part.lastPart)) - catch { - case NonFatal(e) ⇒ closeWithCode(Protocol.CloseCodes.InconsistentData) - } - } - - private class CollectingControlFrame(opcode: Opcode, _data: ByteString, nextState: State) extends InFrameState { - var data = _data - - def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective = { - this.data ++= data.data - if (data.lastPart) handleControlFrame(opcode, this.data, nextState) - else ctx.pull() - } - } - - private def becomeAndHandleWith(newState: State, part: FrameEvent)(implicit ctx: Ctx): SyncDirective = { - become(newState) - current.onPush(part, ctx) - } - - /** Returns a SyncDirective if it handled the message */ - private def validateHeader(header: FrameHeader)(implicit ctx: Ctx): Option[SyncDirective] = header match { - case h: FrameHeader if h.mask.isDefined && !server ⇒ Some(protocolError()) - case h: FrameHeader if h.rsv1 || h.rsv2 || h.rsv3 ⇒ Some(protocolError()) - case FrameHeader(op, _, length, fin, _, _, _) if op.isControl && (length > 125 || !fin) ⇒ Some(protocolError()) - case _ ⇒ None - } - - private def handleControlFrame(opcode: Opcode, data: ByteString, nextState: State)(implicit ctx: Ctx): SyncDirective = { - become(nextState) - opcode match { - case Opcode.Ping ⇒ publishDirectResponse(FrameEvent.fullFrame(Opcode.Pong, None, data, fin = true)) - case Opcode.Pong ⇒ - // ignore unsolicited Pong frame - ctx.pull() - case Opcode.Close ⇒ - become(WaitForPeerTcpClose) - ctx.push(PeerClosed.parse(data)) - case Opcode.Other(o) ⇒ closeWithCode(Protocol.CloseCodes.ProtocolError, "Unsupported opcode") - case other ⇒ ctx.fail(new IllegalStateException(s"unexpected message of type [${other.getClass.getName}] when expecting ControlFrame")) - } - } - private def collectControlFrame(start: FrameStart, nextState: State)(implicit ctx: Ctx): SyncDirective = { - require(!start.isFullMessage) - become(new CollectingControlFrame(start.header.opcode, start.data, nextState)) - ctx.pull() - } - - private def publishMessagePart(part: MessageDataPart)(implicit ctx: Ctx): SyncDirective = - if (part.last) emit(Iterator(part, MessageEnd), ctx, Idle) - else ctx.push(part) - private def publishDirectResponse(frame: FrameStart)(implicit ctx: Ctx): SyncDirective = - ctx.push(DirectAnswer(frame)) - - private def protocolError(reason: String = "")(implicit ctx: Ctx): SyncDirective = - closeWithCode(Protocol.CloseCodes.ProtocolError, reason) - - private def closeWithCode(closeCode: Int, reason: String = "", cause: Throwable = null)(implicit ctx: Ctx): SyncDirective = { - become(CloseAfterPeerClosed) - ctx.push(ActivelyCloseWithCode(Some(closeCode), reason)) - } - - private object CloseAfterPeerClosed extends State { - def onPush(elem: FrameEventOrError, ctx: Context[Output]): SyncDirective = - elem match { - case FrameStart(FrameHeader(Opcode.Close, _, length, _, _, _, _), data) ⇒ - become(WaitForPeerTcpClose) - ctx.push(PeerClosed.parse(data)) - case _ ⇒ ctx.pull() // ignore all other data - } - } - private object WaitForPeerTcpClose extends State { - def onPush(elem: FrameEventOrError, ctx: Context[Output]): SyncDirective = - ctx.pull() // ignore - } - - private abstract class StateWithControlFrameHandling extends BetweenFrameState { - def handleRegularFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective - - def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective = - validateHeader(start.header).getOrElse { - if (start.header.opcode.isControl) - if (start.isFullMessage) handleControlFrame(start.header.opcode, start.data, this) - else collectControlFrame(start, this) - else handleRegularFrameStart(start) - } - } - private abstract class BetweenFrameState extends ImplicitContextState { - def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective = - throw new IllegalStateException("Expected FrameStart") - } - private abstract class InFrameState extends ImplicitContextState { - def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective = - throw new IllegalStateException("Expected FrameData") - } - private abstract class ImplicitContextState extends State { - def handleFrameData(data: FrameData)(implicit ctx: Ctx): SyncDirective - def handleFrameStart(start: FrameStart)(implicit ctx: Ctx): SyncDirective - - def onPush(part: FrameEventOrError, ctx: Ctx): SyncDirective = - part match { - case data: FrameData ⇒ handleFrameData(data)(ctx) - case start: FrameStart ⇒ handleFrameStart(start)(ctx) - case FrameError(ex) ⇒ ctx.fail(ex) - } - } } sealed trait Output