=htc #19361 Convert UpgradeStage to GraphStage (#20835)

This commit is contained in:
Daniel Moran 2016-06-28 02:58:12 -07:00 committed by Konrad Malawski
parent adc931a493
commit e00a86271a

View file

@ -8,26 +8,23 @@ import akka.NotUsed
import akka.http.scaladsl.model.ws._ import akka.http.scaladsl.model.ws._
import scala.concurrent.{ Future, Promise } import scala.concurrent.{ Future, Promise }
import akka.util.ByteString import akka.util.ByteString
import akka.event.LoggingAdapter import akka.event.LoggingAdapter
import akka.stream.stage._ import akka.stream.stage._
import akka.stream._ import akka.stream._
import akka.stream.TLSProtocol._ import akka.stream.TLSProtocol._
import akka.stream.scaladsl._ import akka.stream.scaladsl._
import akka.http.scaladsl.settings.ClientConnectionSettings import akka.http.scaladsl.settings.ClientConnectionSettings
import akka.http.scaladsl.Http import akka.http.scaladsl.Http
import akka.http.scaladsl.model.{ HttpResponse, HttpMethods } import akka.http.scaladsl.model.{ HttpMethods, HttpResponse }
import akka.http.scaladsl.model.headers.Host import akka.http.scaladsl.model.headers.Host
import akka.http.impl.engine.parsing.HttpMessageParser.StateResult import akka.http.impl.engine.parsing.HttpMessageParser.StateResult
import akka.http.impl.engine.parsing.ParserOutput.{ RemainingBytes, ResponseStart, NeedMoreData } import akka.http.impl.engine.parsing.ParserOutput.{ NeedMoreData, RemainingBytes, ResponseStart }
import akka.http.impl.engine.parsing.{ ParserOutput, HttpHeaderParser, HttpResponseParser } import akka.http.impl.engine.parsing.{ HttpHeaderParser, HttpResponseParser, ParserOutput }
import akka.http.impl.engine.rendering.{ HttpRequestRendererFactory, RequestRenderingContext } import akka.http.impl.engine.rendering.{ HttpRequestRendererFactory, RequestRenderingContext }
import akka.http.impl.engine.ws.Handshake.Client.NegotiatedWebSocketSettings import akka.http.impl.engine.ws.Handshake.Client.NegotiatedWebSocketSettings
import akka.http.impl.util.StreamUtils import akka.http.impl.util.StreamUtils
import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage
object WebSocketClientBlueprint { object WebSocketClientBlueprint {
/** /**
@ -59,68 +56,70 @@ object WebSocketClientBlueprint {
val renderedInitialRequest = val renderedInitialRequest =
HttpRequestRendererFactory.renderStrict(RequestRenderingContext(initialRequest, hostHeader), settings, log) HttpRequestRendererFactory.renderStrict(RequestRenderingContext(initialRequest, hostHeader), settings, log)
class UpgradeStage extends StatefulStage[ByteString, ByteString] { class UpgradeStage extends SimpleLinearGraphStage[ByteString] {
type State = StageState[ByteString, ByteString]
def initial: State = parsingResponse override def createLogic(attributes: Attributes): GraphStageLogic =
new GraphStageLogic(shape) with InHandler with OutHandler {
def parsingResponse: State = new State { // a special version of the parser which only parses one message and then reports the remaining data
// a special version of the parser which only parses one message and then reports the remaining data // if some is available
// if some is available val parser = new HttpResponseParser(settings.parserSettings, HttpHeaderParser(settings.parserSettings)()) {
val parser = new HttpResponseParser(settings.parserSettings, HttpHeaderParser(settings.parserSettings)()) { var first = true
var first = true override def handleInformationalResponses = false
override def handleInformationalResponses = false override protected def parseMessage(input: ByteString, offset: Int): StateResult = {
override protected def parseMessage(input: ByteString, offset: Int): StateResult = { if (first) {
if (first) { first = false
first = false super.parseMessage(input, offset)
super.parseMessage(input, offset) } else {
} else { emit(RemainingBytes(input.drop(offset)))
emit(RemainingBytes(input.drop(offset))) terminate()
terminate() }
} }
} }
} parser.setContextForNextResponse(HttpResponseParser.ResponseContext(HttpMethods.GET, None))
parser.setContextForNextResponse(HttpResponseParser.ResponseContext(HttpMethods.GET, None))
def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective = { override def onPush(): Unit = {
parser.parseBytes(elem) match { parser.parseBytes(grab(in)) match {
case NeedMoreData ctx.pull() case NeedMoreData pull(in)
case ResponseStart(status, protocol, headers, entity, close) case ResponseStart(status, protocol, headers, entity, close)
val response = HttpResponse(status, headers, protocol = protocol) val response = HttpResponse(status, headers, protocol = protocol)
Handshake.Client.validateResponse(response, subprotocol.toList, key) match { Handshake.Client.validateResponse(response, subprotocol.toList, key) match {
case Right(NegotiatedWebSocketSettings(protocol)) case Right(NegotiatedWebSocketSettings(protocol))
result.success(ValidUpgrade(response, protocol)) result.success(ValidUpgrade(response, protocol))
become(transparent) setHandler(in, new InHandler {
valve.open() override def onPush(): Unit = push(out, grab(in))
})
valve.open()
val parseResult = parser.onPull() val parseResult = parser.onPull()
require(parseResult == ParserOutput.MessageEnd, s"parseResult should be MessageEnd but was $parseResult") require(parseResult == ParserOutput.MessageEnd, s"parseResult should be MessageEnd but was $parseResult")
parser.onPull() match { parser.onPull() match {
case NeedMoreData ctx.pull() case NeedMoreData pull(in)
case RemainingBytes(bytes) ctx.push(bytes) case RemainingBytes(bytes) push(out, bytes)
case other case other
throw new IllegalStateException(s"unexpected element of type ${other.getClass}") throw new IllegalStateException(s"unexpected element of type ${other.getClass}")
} }
case Left(problem) case Left(problem)
result.success(InvalidUpgradeResponse(response, s"WebSocket server at $uri returned $problem")) result.success(InvalidUpgradeResponse(response, s"WebSocket server at $uri returned $problem"))
ctx.fail(new IllegalArgumentException(s"WebSocket upgrade did not finish because of '$problem'")) failStage(new IllegalArgumentException(s"WebSocket upgrade did not finish because of '$problem'"))
} }
case other case other
throw new IllegalStateException(s"unexpected element of type ${other.getClass}") throw new IllegalStateException(s"unexpected element of type ${other.getClass}")
}
} }
}
}
def transparent: State = new State { override def onPull(): Unit = pull(in)
def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective = ctx.push(elem)
} setHandlers(in, out, this)
}
override def toString = "UpgradeStage"
} }
BidiFlow.fromGraph(GraphDSL.create() { implicit b BidiFlow.fromGraph(GraphDSL.create() { implicit b
import GraphDSL.Implicits._ import GraphDSL.Implicits._
val networkIn = b.add(Flow[ByteString].transform(() new UpgradeStage)) val networkIn = b.add(Flow[ByteString].via(new UpgradeStage))
val wsIn = b.add(Flow[ByteString]) val wsIn = b.add(Flow[ByteString])
val handshakeRequestSource = b.add(Source.single(renderedInitialRequest) ++ valve.source) val handshakeRequestSource = b.add(Source.single(renderedInitialRequest) ++ valve.source)