parent
adc931a493
commit
e00a86271a
1 changed files with 55 additions and 56 deletions
|
|
@ -8,26 +8,23 @@ import akka.NotUsed
|
||||||
import akka.http.scaladsl.model.ws._
|
import akka.http.scaladsl.model.ws._
|
||||||
|
|
||||||
import scala.concurrent.{ Future, Promise }
|
import scala.concurrent.{ Future, Promise }
|
||||||
|
|
||||||
import akka.util.ByteString
|
import akka.util.ByteString
|
||||||
import akka.event.LoggingAdapter
|
import akka.event.LoggingAdapter
|
||||||
|
|
||||||
import akka.stream.stage._
|
import akka.stream.stage._
|
||||||
import akka.stream._
|
import akka.stream._
|
||||||
import akka.stream.TLSProtocol._
|
import akka.stream.TLSProtocol._
|
||||||
import akka.stream.scaladsl._
|
import akka.stream.scaladsl._
|
||||||
|
|
||||||
import akka.http.scaladsl.settings.ClientConnectionSettings
|
import akka.http.scaladsl.settings.ClientConnectionSettings
|
||||||
import akka.http.scaladsl.Http
|
import akka.http.scaladsl.Http
|
||||||
import akka.http.scaladsl.model.{ HttpResponse, HttpMethods }
|
import akka.http.scaladsl.model.{ HttpMethods, HttpResponse }
|
||||||
import akka.http.scaladsl.model.headers.Host
|
import akka.http.scaladsl.model.headers.Host
|
||||||
|
|
||||||
import akka.http.impl.engine.parsing.HttpMessageParser.StateResult
|
import akka.http.impl.engine.parsing.HttpMessageParser.StateResult
|
||||||
import akka.http.impl.engine.parsing.ParserOutput.{ RemainingBytes, ResponseStart, NeedMoreData }
|
import akka.http.impl.engine.parsing.ParserOutput.{ NeedMoreData, RemainingBytes, ResponseStart }
|
||||||
import akka.http.impl.engine.parsing.{ ParserOutput, HttpHeaderParser, HttpResponseParser }
|
import akka.http.impl.engine.parsing.{ HttpHeaderParser, HttpResponseParser, ParserOutput }
|
||||||
import akka.http.impl.engine.rendering.{ HttpRequestRendererFactory, RequestRenderingContext }
|
import akka.http.impl.engine.rendering.{ HttpRequestRendererFactory, RequestRenderingContext }
|
||||||
import akka.http.impl.engine.ws.Handshake.Client.NegotiatedWebSocketSettings
|
import akka.http.impl.engine.ws.Handshake.Client.NegotiatedWebSocketSettings
|
||||||
import akka.http.impl.util.StreamUtils
|
import akka.http.impl.util.StreamUtils
|
||||||
|
import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage
|
||||||
|
|
||||||
object WebSocketClientBlueprint {
|
object WebSocketClientBlueprint {
|
||||||
/**
|
/**
|
||||||
|
|
@ -59,68 +56,70 @@ object WebSocketClientBlueprint {
|
||||||
val renderedInitialRequest =
|
val renderedInitialRequest =
|
||||||
HttpRequestRendererFactory.renderStrict(RequestRenderingContext(initialRequest, hostHeader), settings, log)
|
HttpRequestRendererFactory.renderStrict(RequestRenderingContext(initialRequest, hostHeader), settings, log)
|
||||||
|
|
||||||
class UpgradeStage extends StatefulStage[ByteString, ByteString] {
|
class UpgradeStage extends SimpleLinearGraphStage[ByteString] {
|
||||||
type State = StageState[ByteString, ByteString]
|
|
||||||
|
|
||||||
def initial: State = parsingResponse
|
override def createLogic(attributes: Attributes): GraphStageLogic =
|
||||||
|
new GraphStageLogic(shape) with InHandler with OutHandler {
|
||||||
def parsingResponse: State = new State {
|
// a special version of the parser which only parses one message and then reports the remaining data
|
||||||
// a special version of the parser which only parses one message and then reports the remaining data
|
// if some is available
|
||||||
// if some is available
|
val parser = new HttpResponseParser(settings.parserSettings, HttpHeaderParser(settings.parserSettings)()) {
|
||||||
val parser = new HttpResponseParser(settings.parserSettings, HttpHeaderParser(settings.parserSettings)()) {
|
var first = true
|
||||||
var first = true
|
override def handleInformationalResponses = false
|
||||||
override def handleInformationalResponses = false
|
override protected def parseMessage(input: ByteString, offset: Int): StateResult = {
|
||||||
override protected def parseMessage(input: ByteString, offset: Int): StateResult = {
|
if (first) {
|
||||||
if (first) {
|
first = false
|
||||||
first = false
|
super.parseMessage(input, offset)
|
||||||
super.parseMessage(input, offset)
|
} else {
|
||||||
} else {
|
emit(RemainingBytes(input.drop(offset)))
|
||||||
emit(RemainingBytes(input.drop(offset)))
|
terminate()
|
||||||
terminate()
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
parser.setContextForNextResponse(HttpResponseParser.ResponseContext(HttpMethods.GET, None))
|
||||||
parser.setContextForNextResponse(HttpResponseParser.ResponseContext(HttpMethods.GET, None))
|
|
||||||
|
|
||||||
def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective = {
|
override def onPush(): Unit = {
|
||||||
parser.parseBytes(elem) match {
|
parser.parseBytes(grab(in)) match {
|
||||||
case NeedMoreData ⇒ ctx.pull()
|
case NeedMoreData ⇒ pull(in)
|
||||||
case ResponseStart(status, protocol, headers, entity, close) ⇒
|
case ResponseStart(status, protocol, headers, entity, close) ⇒
|
||||||
val response = HttpResponse(status, headers, protocol = protocol)
|
val response = HttpResponse(status, headers, protocol = protocol)
|
||||||
Handshake.Client.validateResponse(response, subprotocol.toList, key) match {
|
Handshake.Client.validateResponse(response, subprotocol.toList, key) match {
|
||||||
case Right(NegotiatedWebSocketSettings(protocol)) ⇒
|
case Right(NegotiatedWebSocketSettings(protocol)) ⇒
|
||||||
result.success(ValidUpgrade(response, protocol))
|
result.success(ValidUpgrade(response, protocol))
|
||||||
|
|
||||||
become(transparent)
|
setHandler(in, new InHandler {
|
||||||
valve.open()
|
override def onPush(): Unit = push(out, grab(in))
|
||||||
|
})
|
||||||
|
valve.open()
|
||||||
|
|
||||||
val parseResult = parser.onPull()
|
val parseResult = parser.onPull()
|
||||||
require(parseResult == ParserOutput.MessageEnd, s"parseResult should be MessageEnd but was $parseResult")
|
require(parseResult == ParserOutput.MessageEnd, s"parseResult should be MessageEnd but was $parseResult")
|
||||||
parser.onPull() match {
|
parser.onPull() match {
|
||||||
case NeedMoreData ⇒ ctx.pull()
|
case NeedMoreData ⇒ pull(in)
|
||||||
case RemainingBytes(bytes) ⇒ ctx.push(bytes)
|
case RemainingBytes(bytes) ⇒ push(out, bytes)
|
||||||
case other ⇒
|
case other ⇒
|
||||||
throw new IllegalStateException(s"unexpected element of type ${other.getClass}")
|
throw new IllegalStateException(s"unexpected element of type ${other.getClass}")
|
||||||
}
|
}
|
||||||
case Left(problem) ⇒
|
case Left(problem) ⇒
|
||||||
result.success(InvalidUpgradeResponse(response, s"WebSocket server at $uri returned $problem"))
|
result.success(InvalidUpgradeResponse(response, s"WebSocket server at $uri returned $problem"))
|
||||||
ctx.fail(new IllegalArgumentException(s"WebSocket upgrade did not finish because of '$problem'"))
|
failStage(new IllegalArgumentException(s"WebSocket upgrade did not finish because of '$problem'"))
|
||||||
}
|
}
|
||||||
case other ⇒
|
case other ⇒
|
||||||
throw new IllegalStateException(s"unexpected element of type ${other.getClass}")
|
throw new IllegalStateException(s"unexpected element of type ${other.getClass}")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def transparent: State = new State {
|
override def onPull(): Unit = pull(in)
|
||||||
def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective = ctx.push(elem)
|
|
||||||
}
|
setHandlers(in, out, this)
|
||||||
|
}
|
||||||
|
|
||||||
|
override def toString = "UpgradeStage"
|
||||||
}
|
}
|
||||||
|
|
||||||
BidiFlow.fromGraph(GraphDSL.create() { implicit b ⇒
|
BidiFlow.fromGraph(GraphDSL.create() { implicit b ⇒
|
||||||
import GraphDSL.Implicits._
|
import GraphDSL.Implicits._
|
||||||
|
|
||||||
val networkIn = b.add(Flow[ByteString].transform(() ⇒ new UpgradeStage))
|
val networkIn = b.add(Flow[ByteString].via(new UpgradeStage))
|
||||||
val wsIn = b.add(Flow[ByteString])
|
val wsIn = b.add(Flow[ByteString])
|
||||||
|
|
||||||
val handshakeRequestSource = b.add(Source.single(renderedInitialRequest) ++ valve.source)
|
val handshakeRequestSource = b.add(Source.single(renderedInitialRequest) ++ valve.source)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue