=htc #19361 Convert UpgradeStage to GraphStage (#20835)

This commit is contained in:
Daniel Moran 2016-06-28 02:58:12 -07:00 committed by Konrad Malawski
parent adc931a493
commit e00a86271a

View file

@ -8,26 +8,23 @@ import akka.NotUsed
import akka.http.scaladsl.model.ws._
import scala.concurrent.{ Future, Promise }
import akka.util.ByteString
import akka.event.LoggingAdapter
import akka.stream.stage._
import akka.stream._
import akka.stream.TLSProtocol._
import akka.stream.scaladsl._
import akka.http.scaladsl.settings.ClientConnectionSettings
import akka.http.scaladsl.Http
import akka.http.scaladsl.model.{ HttpResponse, HttpMethods }
import akka.http.scaladsl.model.{ HttpMethods, HttpResponse }
import akka.http.scaladsl.model.headers.Host
import akka.http.impl.engine.parsing.HttpMessageParser.StateResult
import akka.http.impl.engine.parsing.ParserOutput.{ RemainingBytes, ResponseStart, NeedMoreData }
import akka.http.impl.engine.parsing.{ ParserOutput, HttpHeaderParser, HttpResponseParser }
import akka.http.impl.engine.parsing.ParserOutput.{ NeedMoreData, RemainingBytes, ResponseStart }
import akka.http.impl.engine.parsing.{ HttpHeaderParser, HttpResponseParser, ParserOutput }
import akka.http.impl.engine.rendering.{ HttpRequestRendererFactory, RequestRenderingContext }
import akka.http.impl.engine.ws.Handshake.Client.NegotiatedWebSocketSettings
import akka.http.impl.util.StreamUtils
import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage
object WebSocketClientBlueprint {
/**
@ -59,68 +56,70 @@ object WebSocketClientBlueprint {
val renderedInitialRequest =
HttpRequestRendererFactory.renderStrict(RequestRenderingContext(initialRequest, hostHeader), settings, log)
class UpgradeStage extends StatefulStage[ByteString, ByteString] {
type State = StageState[ByteString, ByteString]
class UpgradeStage extends SimpleLinearGraphStage[ByteString] {
def initial: State = parsingResponse
def parsingResponse: State = new State {
// a special version of the parser which only parses one message and then reports the remaining data
// if some is available
val parser = new HttpResponseParser(settings.parserSettings, HttpHeaderParser(settings.parserSettings)()) {
var first = true
override def handleInformationalResponses = false
override protected def parseMessage(input: ByteString, offset: Int): StateResult = {
if (first) {
first = false
super.parseMessage(input, offset)
} else {
emit(RemainingBytes(input.drop(offset)))
terminate()
override def createLogic(attributes: Attributes): GraphStageLogic =
new GraphStageLogic(shape) with InHandler with OutHandler {
// a special version of the parser which only parses one message and then reports the remaining data
// if some is available
val parser = new HttpResponseParser(settings.parserSettings, HttpHeaderParser(settings.parserSettings)()) {
var first = true
override def handleInformationalResponses = false
override protected def parseMessage(input: ByteString, offset: Int): StateResult = {
if (first) {
first = false
super.parseMessage(input, offset)
} else {
emit(RemainingBytes(input.drop(offset)))
terminate()
}
}
}
}
parser.setContextForNextResponse(HttpResponseParser.ResponseContext(HttpMethods.GET, None))
parser.setContextForNextResponse(HttpResponseParser.ResponseContext(HttpMethods.GET, None))
def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective = {
parser.parseBytes(elem) match {
case NeedMoreData ctx.pull()
case ResponseStart(status, protocol, headers, entity, close)
val response = HttpResponse(status, headers, protocol = protocol)
Handshake.Client.validateResponse(response, subprotocol.toList, key) match {
case Right(NegotiatedWebSocketSettings(protocol))
result.success(ValidUpgrade(response, protocol))
override def onPush(): Unit = {
parser.parseBytes(grab(in)) match {
case NeedMoreData pull(in)
case ResponseStart(status, protocol, headers, entity, close)
val response = HttpResponse(status, headers, protocol = protocol)
Handshake.Client.validateResponse(response, subprotocol.toList, key) match {
case Right(NegotiatedWebSocketSettings(protocol))
result.success(ValidUpgrade(response, protocol))
become(transparent)
valve.open()
setHandler(in, new InHandler {
override def onPush(): Unit = push(out, grab(in))
})
valve.open()
val parseResult = parser.onPull()
require(parseResult == ParserOutput.MessageEnd, s"parseResult should be MessageEnd but was $parseResult")
parser.onPull() match {
case NeedMoreData ctx.pull()
case RemainingBytes(bytes) ctx.push(bytes)
case other
throw new IllegalStateException(s"unexpected element of type ${other.getClass}")
}
case Left(problem)
result.success(InvalidUpgradeResponse(response, s"WebSocket server at $uri returned $problem"))
ctx.fail(new IllegalArgumentException(s"WebSocket upgrade did not finish because of '$problem'"))
}
case other
throw new IllegalStateException(s"unexpected element of type ${other.getClass}")
val parseResult = parser.onPull()
require(parseResult == ParserOutput.MessageEnd, s"parseResult should be MessageEnd but was $parseResult")
parser.onPull() match {
case NeedMoreData pull(in)
case RemainingBytes(bytes) push(out, bytes)
case other
throw new IllegalStateException(s"unexpected element of type ${other.getClass}")
}
case Left(problem)
result.success(InvalidUpgradeResponse(response, s"WebSocket server at $uri returned $problem"))
failStage(new IllegalArgumentException(s"WebSocket upgrade did not finish because of '$problem'"))
}
case other
throw new IllegalStateException(s"unexpected element of type ${other.getClass}")
}
}
}
}
def transparent: State = new State {
def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective = ctx.push(elem)
}
override def onPull(): Unit = pull(in)
setHandlers(in, out, this)
}
override def toString = "UpgradeStage"
}
BidiFlow.fromGraph(GraphDSL.create() { implicit b
import GraphDSL.Implicits._
val networkIn = b.add(Flow[ByteString].transform(() new UpgradeStage))
val networkIn = b.add(Flow[ByteString].via(new UpgradeStage))
val wsIn = b.add(Flow[ByteString])
val handshakeRequestSource = b.add(Source.single(renderedInitialRequest) ++ valve.source)