parent
adc931a493
commit
e00a86271a
1 changed files with 55 additions and 56 deletions
|
|
@ -8,26 +8,23 @@ import akka.NotUsed
|
|||
import akka.http.scaladsl.model.ws._
|
||||
|
||||
import scala.concurrent.{ Future, Promise }
|
||||
|
||||
import akka.util.ByteString
|
||||
import akka.event.LoggingAdapter
|
||||
|
||||
import akka.stream.stage._
|
||||
import akka.stream._
|
||||
import akka.stream.TLSProtocol._
|
||||
import akka.stream.scaladsl._
|
||||
|
||||
import akka.http.scaladsl.settings.ClientConnectionSettings
|
||||
import akka.http.scaladsl.Http
|
||||
import akka.http.scaladsl.model.{ HttpResponse, HttpMethods }
|
||||
import akka.http.scaladsl.model.{ HttpMethods, HttpResponse }
|
||||
import akka.http.scaladsl.model.headers.Host
|
||||
|
||||
import akka.http.impl.engine.parsing.HttpMessageParser.StateResult
|
||||
import akka.http.impl.engine.parsing.ParserOutput.{ RemainingBytes, ResponseStart, NeedMoreData }
|
||||
import akka.http.impl.engine.parsing.{ ParserOutput, HttpHeaderParser, HttpResponseParser }
|
||||
import akka.http.impl.engine.parsing.ParserOutput.{ NeedMoreData, RemainingBytes, ResponseStart }
|
||||
import akka.http.impl.engine.parsing.{ HttpHeaderParser, HttpResponseParser, ParserOutput }
|
||||
import akka.http.impl.engine.rendering.{ HttpRequestRendererFactory, RequestRenderingContext }
|
||||
import akka.http.impl.engine.ws.Handshake.Client.NegotiatedWebSocketSettings
|
||||
import akka.http.impl.util.StreamUtils
|
||||
import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage
|
||||
|
||||
object WebSocketClientBlueprint {
|
||||
/**
|
||||
|
|
@ -59,12 +56,10 @@ object WebSocketClientBlueprint {
|
|||
val renderedInitialRequest =
|
||||
HttpRequestRendererFactory.renderStrict(RequestRenderingContext(initialRequest, hostHeader), settings, log)
|
||||
|
||||
class UpgradeStage extends StatefulStage[ByteString, ByteString] {
|
||||
type State = StageState[ByteString, ByteString]
|
||||
class UpgradeStage extends SimpleLinearGraphStage[ByteString] {
|
||||
|
||||
def initial: State = parsingResponse
|
||||
|
||||
def parsingResponse: State = new State {
|
||||
override def createLogic(attributes: Attributes): GraphStageLogic =
|
||||
new GraphStageLogic(shape) with InHandler with OutHandler {
|
||||
// a special version of the parser which only parses one message and then reports the remaining data
|
||||
// if some is available
|
||||
val parser = new HttpResponseParser(settings.parserSettings, HttpHeaderParser(settings.parserSettings)()) {
|
||||
|
|
@ -82,45 +77,49 @@ object WebSocketClientBlueprint {
|
|||
}
|
||||
parser.setContextForNextResponse(HttpResponseParser.ResponseContext(HttpMethods.GET, None))
|
||||
|
||||
def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective = {
|
||||
parser.parseBytes(elem) match {
|
||||
case NeedMoreData ⇒ ctx.pull()
|
||||
override def onPush(): Unit = {
|
||||
parser.parseBytes(grab(in)) match {
|
||||
case NeedMoreData ⇒ pull(in)
|
||||
case ResponseStart(status, protocol, headers, entity, close) ⇒
|
||||
val response = HttpResponse(status, headers, protocol = protocol)
|
||||
Handshake.Client.validateResponse(response, subprotocol.toList, key) match {
|
||||
case Right(NegotiatedWebSocketSettings(protocol)) ⇒
|
||||
result.success(ValidUpgrade(response, protocol))
|
||||
|
||||
become(transparent)
|
||||
setHandler(in, new InHandler {
|
||||
override def onPush(): Unit = push(out, grab(in))
|
||||
})
|
||||
valve.open()
|
||||
|
||||
val parseResult = parser.onPull()
|
||||
require(parseResult == ParserOutput.MessageEnd, s"parseResult should be MessageEnd but was $parseResult")
|
||||
parser.onPull() match {
|
||||
case NeedMoreData ⇒ ctx.pull()
|
||||
case RemainingBytes(bytes) ⇒ ctx.push(bytes)
|
||||
case NeedMoreData ⇒ pull(in)
|
||||
case RemainingBytes(bytes) ⇒ push(out, bytes)
|
||||
case other ⇒
|
||||
throw new IllegalStateException(s"unexpected element of type ${other.getClass}")
|
||||
}
|
||||
case Left(problem) ⇒
|
||||
result.success(InvalidUpgradeResponse(response, s"WebSocket server at $uri returned $problem"))
|
||||
ctx.fail(new IllegalArgumentException(s"WebSocket upgrade did not finish because of '$problem'"))
|
||||
failStage(new IllegalArgumentException(s"WebSocket upgrade did not finish because of '$problem'"))
|
||||
}
|
||||
case other ⇒
|
||||
throw new IllegalStateException(s"unexpected element of type ${other.getClass}")
|
||||
}
|
||||
}
|
||||
|
||||
override def onPull(): Unit = pull(in)
|
||||
|
||||
setHandlers(in, out, this)
|
||||
}
|
||||
|
||||
def transparent: State = new State {
|
||||
def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective = ctx.push(elem)
|
||||
}
|
||||
override def toString = "UpgradeStage"
|
||||
}
|
||||
|
||||
BidiFlow.fromGraph(GraphDSL.create() { implicit b ⇒
|
||||
import GraphDSL.Implicits._
|
||||
|
||||
val networkIn = b.add(Flow[ByteString].transform(() ⇒ new UpgradeStage))
|
||||
val networkIn = b.add(Flow[ByteString].via(new UpgradeStage))
|
||||
val wsIn = b.add(Flow[ByteString])
|
||||
|
||||
val handshakeRequestSource = b.add(Source.single(renderedInitialRequest) ++ valve.source)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue