diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/WebSocketClientBlueprint.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/WebSocketClientBlueprint.scala index 001932af61..e193877496 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/WebSocketClientBlueprint.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/WebSocketClientBlueprint.scala @@ -8,26 +8,23 @@ import akka.NotUsed import akka.http.scaladsl.model.ws._ import scala.concurrent.{ Future, Promise } - import akka.util.ByteString import akka.event.LoggingAdapter - import akka.stream.stage._ import akka.stream._ import akka.stream.TLSProtocol._ import akka.stream.scaladsl._ - import akka.http.scaladsl.settings.ClientConnectionSettings import akka.http.scaladsl.Http -import akka.http.scaladsl.model.{ HttpResponse, HttpMethods } +import akka.http.scaladsl.model.{ HttpMethods, HttpResponse } import akka.http.scaladsl.model.headers.Host - import akka.http.impl.engine.parsing.HttpMessageParser.StateResult -import akka.http.impl.engine.parsing.ParserOutput.{ RemainingBytes, ResponseStart, NeedMoreData } -import akka.http.impl.engine.parsing.{ ParserOutput, HttpHeaderParser, HttpResponseParser } +import akka.http.impl.engine.parsing.ParserOutput.{ NeedMoreData, RemainingBytes, ResponseStart } +import akka.http.impl.engine.parsing.{ HttpHeaderParser, HttpResponseParser, ParserOutput } import akka.http.impl.engine.rendering.{ HttpRequestRendererFactory, RequestRenderingContext } import akka.http.impl.engine.ws.Handshake.Client.NegotiatedWebSocketSettings import akka.http.impl.util.StreamUtils +import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage object WebSocketClientBlueprint { /** @@ -59,68 +56,70 @@ object WebSocketClientBlueprint { val renderedInitialRequest = HttpRequestRendererFactory.renderStrict(RequestRenderingContext(initialRequest, hostHeader), settings, log) - class UpgradeStage extends StatefulStage[ByteString, ByteString] { - type State = StageState[ByteString, ByteString] + class UpgradeStage extends SimpleLinearGraphStage[ByteString] { - def initial: State = parsingResponse - - def parsingResponse: State = new State { - // a special version of the parser which only parses one message and then reports the remaining data - // if some is available - val parser = new HttpResponseParser(settings.parserSettings, HttpHeaderParser(settings.parserSettings)()) { - var first = true - override def handleInformationalResponses = false - override protected def parseMessage(input: ByteString, offset: Int): StateResult = { - if (first) { - first = false - super.parseMessage(input, offset) - } else { - emit(RemainingBytes(input.drop(offset))) - terminate() + override def createLogic(attributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) with InHandler with OutHandler { + // a special version of the parser which only parses one message and then reports the remaining data + // if some is available + val parser = new HttpResponseParser(settings.parserSettings, HttpHeaderParser(settings.parserSettings)()) { + var first = true + override def handleInformationalResponses = false + override protected def parseMessage(input: ByteString, offset: Int): StateResult = { + if (first) { + first = false + super.parseMessage(input, offset) + } else { + emit(RemainingBytes(input.drop(offset))) + terminate() + } } } - } - parser.setContextForNextResponse(HttpResponseParser.ResponseContext(HttpMethods.GET, None)) + parser.setContextForNextResponse(HttpResponseParser.ResponseContext(HttpMethods.GET, None)) - def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective = { - parser.parseBytes(elem) match { - case NeedMoreData ⇒ ctx.pull() - case ResponseStart(status, protocol, headers, entity, close) ⇒ - val response = HttpResponse(status, headers, protocol = protocol) - Handshake.Client.validateResponse(response, subprotocol.toList, key) match { - case Right(NegotiatedWebSocketSettings(protocol)) ⇒ - result.success(ValidUpgrade(response, protocol)) + override def onPush(): Unit = { + parser.parseBytes(grab(in)) match { + case NeedMoreData ⇒ pull(in) + case ResponseStart(status, protocol, headers, entity, close) ⇒ + val response = HttpResponse(status, headers, protocol = protocol) + Handshake.Client.validateResponse(response, subprotocol.toList, key) match { + case Right(NegotiatedWebSocketSettings(protocol)) ⇒ + result.success(ValidUpgrade(response, protocol)) - become(transparent) - valve.open() + setHandler(in, new InHandler { + override def onPush(): Unit = push(out, grab(in)) + }) + valve.open() - val parseResult = parser.onPull() - require(parseResult == ParserOutput.MessageEnd, s"parseResult should be MessageEnd but was $parseResult") - parser.onPull() match { - case NeedMoreData ⇒ ctx.pull() - case RemainingBytes(bytes) ⇒ ctx.push(bytes) - case other ⇒ - throw new IllegalStateException(s"unexpected element of type ${other.getClass}") - } - case Left(problem) ⇒ - result.success(InvalidUpgradeResponse(response, s"WebSocket server at $uri returned $problem")) - ctx.fail(new IllegalArgumentException(s"WebSocket upgrade did not finish because of '$problem'")) - } - case other ⇒ - throw new IllegalStateException(s"unexpected element of type ${other.getClass}") + val parseResult = parser.onPull() + require(parseResult == ParserOutput.MessageEnd, s"parseResult should be MessageEnd but was $parseResult") + parser.onPull() match { + case NeedMoreData ⇒ pull(in) + case RemainingBytes(bytes) ⇒ push(out, bytes) + case other ⇒ + throw new IllegalStateException(s"unexpected element of type ${other.getClass}") + } + case Left(problem) ⇒ + result.success(InvalidUpgradeResponse(response, s"WebSocket server at $uri returned $problem")) + failStage(new IllegalArgumentException(s"WebSocket upgrade did not finish because of '$problem'")) + } + case other ⇒ + throw new IllegalStateException(s"unexpected element of type ${other.getClass}") + } } - } - } - def transparent: State = new State { - def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective = ctx.push(elem) - } + override def onPull(): Unit = pull(in) + + setHandlers(in, out, this) + } + + override def toString = "UpgradeStage" } BidiFlow.fromGraph(GraphDSL.create() { implicit b ⇒ import GraphDSL.Implicits._ - val networkIn = b.add(Flow[ByteString].transform(() ⇒ new UpgradeStage)) + val networkIn = b.add(Flow[ByteString].via(new UpgradeStage)) val wsIn = b.add(Flow[ByteString]) val handshakeRequestSource = b.add(Source.single(renderedInitialRequest) ++ valve.source)