Merge pull request #19488 from akka/wip-19398-forward-port-RK
multiple HTTP logic fixes
This commit is contained in:
commit
4de3f63b93
9 changed files with 397 additions and 121 deletions
|
|
@ -167,7 +167,6 @@ private[http] object OutgoingConnectionBlueprint {
|
|||
// each connection uses a single (private) response parser instance for all its responses
|
||||
// which builds a cache of all header instances seen on that connection
|
||||
val parser = rootParser.createShallowCopy()
|
||||
var methodBypassCompleted = false
|
||||
var waitingForMethod = true
|
||||
|
||||
setHandler(methodBypassInput, new InHandler {
|
||||
|
|
@ -179,7 +178,6 @@ private[http] object OutgoingConnectionBlueprint {
|
|||
}
|
||||
override def onUpstreamFinish(): Unit =
|
||||
if (waitingForMethod) completeStage()
|
||||
else methodBypassCompleted = true
|
||||
})
|
||||
|
||||
setHandler(dataInput, new InHandler {
|
||||
|
|
@ -201,17 +199,16 @@ private[http] object OutgoingConnectionBlueprint {
|
|||
|
||||
setHandler(out, eagerTerminateOutput)
|
||||
|
||||
val getNextMethod = () ⇒
|
||||
if (methodBypassCompleted) completeStage()
|
||||
else {
|
||||
pull(methodBypassInput)
|
||||
waitingForMethod = true
|
||||
}
|
||||
val getNextMethod = () ⇒ {
|
||||
waitingForMethod = true
|
||||
if (isClosed(methodBypassInput)) completeStage()
|
||||
else pull(methodBypassInput)
|
||||
}
|
||||
|
||||
val getNextData = () ⇒ {
|
||||
waitingForMethod = false
|
||||
if (!isClosed(dataInput)) pull(dataInput)
|
||||
else completeStage()
|
||||
if (isClosed(dataInput)) completeStage()
|
||||
else pull(dataInput)
|
||||
}
|
||||
|
||||
@tailrec def drainParser(current: ResponseOutput, b: ListBuffer[ResponseOutput] = ListBuffer.empty): Unit = {
|
||||
|
|
@ -219,13 +216,10 @@ private[http] object OutgoingConnectionBlueprint {
|
|||
if (output.nonEmpty) emit(out, output, andThen)
|
||||
else andThen()
|
||||
current match {
|
||||
case NeedNextRequestMethod ⇒
|
||||
e(b.result(), getNextMethod)
|
||||
case StreamEnd ⇒
|
||||
e(b.result(), () ⇒ completeStage())
|
||||
case NeedMoreData ⇒
|
||||
e(b.result(), getNextData)
|
||||
case x ⇒ drainParser(parser.onPull(), b += x)
|
||||
case NeedNextRequestMethod ⇒ e(b.result(), getNextMethod)
|
||||
case StreamEnd ⇒ e(b.result(), () ⇒ completeStage())
|
||||
case NeedMoreData ⇒ e(b.result(), getNextData)
|
||||
case x ⇒ drainParser(parser.onPull(), b += x)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -62,13 +62,13 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser
|
|||
|
||||
def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
|
||||
new GraphStageLogic(shape) {
|
||||
private[this] var closeMode: CloseMode = DontClose // signals what to do after the current response
|
||||
private[this] def close: Boolean = closeMode != DontClose
|
||||
private[this] def closeIf(cond: Boolean): Unit =
|
||||
if (cond) closeMode = CloseConnection
|
||||
var closeMode: CloseMode = DontClose // signals what to do after the current response
|
||||
def close: Boolean = closeMode != DontClose
|
||||
def closeIf(cond: Boolean): Unit = if (cond) closeMode = CloseConnection
|
||||
var transferring = false
|
||||
|
||||
setHandler(in, new InHandler {
|
||||
def onPush(): Unit =
|
||||
override def onPush(): Unit =
|
||||
render(grab(in)) match {
|
||||
case Strict(outElement) ⇒
|
||||
push(out, outElement)
|
||||
|
|
@ -76,23 +76,36 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser
|
|||
case Streamed(outStream) ⇒ transfer(outStream)
|
||||
}
|
||||
|
||||
override def onUpstreamFinish(): Unit = closeMode = CloseConnection
|
||||
override def onUpstreamFinish(): Unit =
|
||||
if (transferring) closeMode = CloseConnection
|
||||
else completeStage()
|
||||
})
|
||||
val waitForDemandHandler = new OutHandler {
|
||||
def onPull(): Unit = if (close) completeStage() else pull(in)
|
||||
def onPull(): Unit = pull(in)
|
||||
}
|
||||
setHandler(out, waitForDemandHandler)
|
||||
def transfer(outStream: Source[ResponseRenderingOutput, Any]): Unit = {
|
||||
transferring = true
|
||||
val sinkIn = new SubSinkInlet[ResponseRenderingOutput]("RenderingSink")
|
||||
sinkIn.setHandler(new InHandler {
|
||||
def onPush(): Unit = push(out, sinkIn.grab())
|
||||
override def onUpstreamFinish(): Unit = if (close) completeStage() else setHandler(out, waitForDemandHandler)
|
||||
override def onPush(): Unit = push(out, sinkIn.grab())
|
||||
override def onUpstreamFinish(): Unit =
|
||||
if (close) completeStage()
|
||||
else {
|
||||
transferring = false
|
||||
setHandler(out, waitForDemandHandler)
|
||||
if (isAvailable(out)) pull(in)
|
||||
}
|
||||
})
|
||||
setHandler(out, new OutHandler {
|
||||
def onPull(): Unit = sinkIn.pull()
|
||||
override def onPull(): Unit = sinkIn.pull()
|
||||
override def onDownstreamFinish(): Unit = {
|
||||
completeStage()
|
||||
sinkIn.cancel()
|
||||
}
|
||||
})
|
||||
sinkIn.pull()
|
||||
Source.fromGraph(outStream).runWith(sinkIn.sink)(interpreter.subFusingMaterializer)
|
||||
outStream.runWith(sinkIn.sink)(interpreter.subFusingMaterializer)
|
||||
}
|
||||
|
||||
def render(ctx: ResponseRenderingContext): StrictOrStreamed = {
|
||||
|
|
|
|||
|
|
@ -371,6 +371,11 @@ private[http] object HttpServerBluePrint {
|
|||
def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) {
|
||||
import akka.http.impl.engine.rendering.ResponseRenderingOutput._
|
||||
|
||||
/*
|
||||
* These handlers are in charge until a switch command comes in, then they
|
||||
* are replaced.
|
||||
*/
|
||||
|
||||
setHandler(fromHttp, new InHandler {
|
||||
override def onPush(): Unit =
|
||||
grab(fromHttp) match {
|
||||
|
|
@ -381,21 +386,22 @@ private[http] object HttpServerBluePrint {
|
|||
cancel(fromHttp)
|
||||
switchToWebsocket(handlerFlow)
|
||||
}
|
||||
override def onUpstreamFinish(): Unit = complete(toNet)
|
||||
override def onUpstreamFailure(ex: Throwable): Unit = fail(toNet, ex)
|
||||
})
|
||||
setHandler(toNet, new OutHandler {
|
||||
override def onPull(): Unit = pull(fromHttp)
|
||||
override def onDownstreamFinish(): Unit = completeStage()
|
||||
})
|
||||
|
||||
setHandler(fromNet, new InHandler {
|
||||
def onPush(): Unit = push(toHttp, grab(fromNet))
|
||||
|
||||
// propagate error but don't close stage yet to prevent fromHttp/fromWs being cancelled
|
||||
// too eagerly
|
||||
override def onPush(): Unit = push(toHttp, grab(fromNet))
|
||||
override def onUpstreamFinish(): Unit = complete(toHttp)
|
||||
override def onUpstreamFailure(ex: Throwable): Unit = fail(toHttp, ex)
|
||||
})
|
||||
setHandler(toHttp, new OutHandler {
|
||||
override def onPull(): Unit = pull(fromNet)
|
||||
override def onDownstreamFinish(): Unit = ()
|
||||
override def onDownstreamFinish(): Unit = cancel(fromNet)
|
||||
})
|
||||
|
||||
private var activeTimers = 0
|
||||
|
|
@ -427,36 +433,60 @@ private[http] object HttpServerBluePrint {
|
|||
case Right(messageHandler) ⇒
|
||||
Websocket.stack(serverSide = true, maskingRandomFactory = settings.websocketRandomFactory, log = log).join(messageHandler)
|
||||
}
|
||||
|
||||
val sinkIn = new SubSinkInlet[ByteString]("FrameSink")
|
||||
val sourceOut = new SubSourceOutlet[ByteString]("FrameSource")
|
||||
|
||||
val timeoutKey = SubscriptionTimeout(() ⇒ {
|
||||
sourceOut.timeout(timeout)
|
||||
if (sourceOut.isClosed) completeStage()
|
||||
})
|
||||
addTimeout(timeoutKey)
|
||||
|
||||
sinkIn.setHandler(new InHandler {
|
||||
override def onPush(): Unit = push(toNet, sinkIn.grab())
|
||||
})
|
||||
setHandler(toNet, new OutHandler {
|
||||
override def onPull(): Unit = sinkIn.pull()
|
||||
override def onUpstreamFinish(): Unit = complete(toNet)
|
||||
override def onUpstreamFailure(ex: Throwable): Unit = fail(toNet, ex)
|
||||
})
|
||||
|
||||
setHandler(fromNet, new InHandler {
|
||||
override def onPush(): Unit = sourceOut.push(grab(fromNet).bytes)
|
||||
})
|
||||
sourceOut.setHandler(new OutHandler {
|
||||
override def onPull(): Unit = {
|
||||
if (!hasBeenPulled(fromNet)) pull(fromNet)
|
||||
cancelTimeout(timeoutKey)
|
||||
sourceOut.setHandler(new OutHandler {
|
||||
override def onPull(): Unit = if (!hasBeenPulled(fromNet)) pull(fromNet)
|
||||
})
|
||||
}
|
||||
})
|
||||
if (isClosed(fromNet)) {
|
||||
setHandler(toNet, new OutHandler {
|
||||
override def onPull(): Unit = sinkIn.pull()
|
||||
override def onDownstreamFinish(): Unit = {
|
||||
completeStage()
|
||||
sinkIn.cancel()
|
||||
}
|
||||
})
|
||||
Websocket.framing.join(frameHandler).runWith(Source.empty, sinkIn.sink)(subFusingMaterializer)
|
||||
} else {
|
||||
val sourceOut = new SubSourceOutlet[ByteString]("FrameSource")
|
||||
|
||||
Websocket.framing.join(frameHandler).runWith(sourceOut.source, sinkIn.sink)(subFusingMaterializer)
|
||||
val timeoutKey = SubscriptionTimeout(() ⇒ {
|
||||
sourceOut.timeout(timeout)
|
||||
if (sourceOut.isClosed) completeStage()
|
||||
})
|
||||
addTimeout(timeoutKey)
|
||||
|
||||
setHandler(toNet, new OutHandler {
|
||||
override def onPull(): Unit = sinkIn.pull()
|
||||
override def onDownstreamFinish(): Unit = {
|
||||
completeStage()
|
||||
sinkIn.cancel()
|
||||
sourceOut.complete()
|
||||
}
|
||||
})
|
||||
|
||||
setHandler(fromNet, new InHandler {
|
||||
override def onPush(): Unit = sourceOut.push(grab(fromNet).bytes)
|
||||
override def onUpstreamFinish(): Unit = sourceOut.complete()
|
||||
override def onUpstreamFailure(ex: Throwable): Unit = sourceOut.fail(ex)
|
||||
})
|
||||
sourceOut.setHandler(new OutHandler {
|
||||
override def onPull(): Unit = {
|
||||
if (!hasBeenPulled(fromNet)) pull(fromNet)
|
||||
cancelTimeout(timeoutKey)
|
||||
sourceOut.setHandler(new OutHandler {
|
||||
override def onPull(): Unit = if (!hasBeenPulled(fromNet)) pull(fromNet)
|
||||
override def onDownstreamFinish(): Unit = cancel(fromNet)
|
||||
})
|
||||
}
|
||||
override def onDownstreamFinish(): Unit = cancel(fromNet)
|
||||
})
|
||||
|
||||
Websocket.framing.join(frameHandler).runWith(sourceOut.source, sinkIn.sink)(subFusingMaterializer)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -195,14 +195,19 @@ private[http] object Websocket {
|
|||
|
||||
def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) {
|
||||
|
||||
passAlong(bypass, out, doFinish = true, doFail = true)
|
||||
passAlong(user, out, doFinish = false, doFail = false)
|
||||
class PassAlong[T <: AnyRef](from: Inlet[T]) extends InHandler with (() ⇒ Unit) {
|
||||
override def apply(): Unit = tryPull(from)
|
||||
override def onPush(): Unit = emit(out, grab(from), this)
|
||||
override def onUpstreamFinish(): Unit =
|
||||
if (isClosed(bypass) && isClosed(user)) completeStage()
|
||||
}
|
||||
setHandler(bypass, new PassAlong(bypass))
|
||||
setHandler(user, new PassAlong(user))
|
||||
passAlong(tick, out, doFinish = false, doFail = false)
|
||||
|
||||
setHandler(out, eagerTerminateOutput)
|
||||
|
||||
override def preStart(): Unit = {
|
||||
super.preStart()
|
||||
pull(bypass)
|
||||
pull(user)
|
||||
pull(tick)
|
||||
|
|
|
|||
|
|
@ -11,16 +11,19 @@ import akka.stream.scaladsl._
|
|||
import akka.stream.testkit.AkkaSpec
|
||||
import akka.http.scaladsl.{ Http, TestUtils }
|
||||
import akka.http.scaladsl.model._
|
||||
import akka.stream.testkit.Utils
|
||||
import org.scalatest.concurrent.ScalaFutures
|
||||
|
||||
class HighLevelOutgoingConnectionSpec extends AkkaSpec {
|
||||
class HighLevelOutgoingConnectionSpec extends AkkaSpec with ScalaFutures {
|
||||
implicit val materializer = ActorMaterializer()
|
||||
implicit val patience = PatienceConfig(1.second)
|
||||
|
||||
"The connection-level client implementation" should {
|
||||
|
||||
"be able to handle 100 pipelined requests across one connection" in {
|
||||
"be able to handle 100 pipelined requests across one connection" in Utils.assertAllStagesStopped {
|
||||
val (_, serverHostName, serverPort) = TestUtils.temporaryServerHostnameAndPort()
|
||||
|
||||
Http().bindAndHandleSync(r ⇒ HttpResponse(entity = r.uri.toString.reverse.takeWhile(Character.isDigit).reverse),
|
||||
val binding = Http().bindAndHandleSync(r ⇒ HttpResponse(entity = r.uri.toString.reverse.takeWhile(Character.isDigit).reverse),
|
||||
serverHostName, serverPort)
|
||||
|
||||
val N = 100
|
||||
|
|
@ -32,13 +35,14 @@ class HighLevelOutgoingConnectionSpec extends AkkaSpec {
|
|||
.map { r ⇒ val s = r.data.utf8String; log.debug(s); s.toInt }
|
||||
.runFold(0)(_ + _)
|
||||
|
||||
Await.result(result, 10.seconds) shouldEqual N * (N + 1) / 2
|
||||
result.futureValue(PatienceConfig(10.seconds)) shouldEqual N * (N + 1) / 2
|
||||
binding.futureValue.unbind()
|
||||
}
|
||||
|
||||
"be able to handle 100 pipelined requests across 4 connections (client-flow is reusable)" in {
|
||||
"be able to handle 100 pipelined requests across 4 connections (client-flow is reusable)" in Utils.assertAllStagesStopped {
|
||||
val (_, serverHostName, serverPort) = TestUtils.temporaryServerHostnameAndPort()
|
||||
|
||||
Http().bindAndHandleSync(r ⇒ HttpResponse(entity = r.uri.toString.reverse.takeWhile(Character.isDigit).reverse),
|
||||
val binding = Http().bindAndHandleSync(r ⇒ HttpResponse(entity = r.uri.toString.reverse.takeWhile(Character.isDigit).reverse),
|
||||
serverHostName, serverPort)
|
||||
|
||||
val connFlow = Http().outgoingConnection(serverHostName, serverPort)
|
||||
|
|
@ -64,12 +68,14 @@ class HighLevelOutgoingConnectionSpec extends AkkaSpec {
|
|||
.map { r ⇒ val s = r.data.utf8String; log.debug(s); s.toInt }
|
||||
.runFold(0)(_ + _)
|
||||
|
||||
Await.result(result, 10.seconds) shouldEqual C * N * (N + 1) / 2
|
||||
result.futureValue(PatienceConfig(10.seconds)) shouldEqual C * N * (N + 1) / 2
|
||||
binding.futureValue.unbind()
|
||||
}
|
||||
|
||||
"catch response stream truncation" in {
|
||||
"catch response stream truncation" in Utils.assertAllStagesStopped {
|
||||
val (_, serverHostName, serverPort) = TestUtils.temporaryServerHostnameAndPort()
|
||||
Http().bindAndHandleSync({
|
||||
|
||||
val binding = Http().bindAndHandleSync({
|
||||
case HttpRequest(_, Uri.Path("/b"), _, _, _) ⇒ HttpResponse(headers = List(headers.Connection("close")))
|
||||
case _ ⇒ HttpResponse()
|
||||
}, serverHostName, serverPort)
|
||||
|
|
@ -81,6 +87,7 @@ class HighLevelOutgoingConnectionSpec extends AkkaSpec {
|
|||
.runWith(Sink.head)
|
||||
|
||||
a[One2OneBidiFlow.OutputTruncationException.type] should be thrownBy Await.result(x, 1.second)
|
||||
binding.futureValue.unbind()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,50 +0,0 @@
|
|||
/**
|
||||
* Copyright (C) 2015 Typesafe Inc. <http://www.typesafe.com>
|
||||
*/
|
||||
package akka.http.impl.engine.ws
|
||||
|
||||
import akka.stream.testkit.AkkaSpec
|
||||
import scala.concurrent.Await
|
||||
import com.typesafe.config.ConfigFactory
|
||||
import com.typesafe.config.Config
|
||||
import akka.actor.ActorSystem
|
||||
import akka.http.scaladsl.model.HttpRequest
|
||||
import akka.http.scaladsl.model.ws._
|
||||
import akka.http.scaladsl._
|
||||
import akka.stream.scaladsl._
|
||||
import akka.stream._
|
||||
import scala.concurrent.duration._
|
||||
import org.scalatest.concurrent.ScalaFutures
|
||||
import org.scalactic.ConversionCheckedTripleEquals
|
||||
import akka.stream.testkit.Utils
|
||||
|
||||
class BypassRouterSpec extends AkkaSpec("akka.stream.materializer.debug.fuzzing-mode = off") with ScalaFutures with ConversionCheckedTripleEquals {
|
||||
|
||||
implicit val patience = PatienceConfig(3.seconds)
|
||||
import system.dispatcher
|
||||
implicit val materializer = ActorMaterializer()
|
||||
|
||||
"BypassRouter" must {
|
||||
|
||||
"work without double pull-ing some ports" in Utils.assertAllStagesStopped {
|
||||
val bindingFuture = Http().bindAndHandleSync({
|
||||
case HttpRequest(_, _, headers, _, _) ⇒
|
||||
val upgrade = headers.collectFirst { case u: UpgradeToWebsocket ⇒ u }.get
|
||||
upgrade.handleMessages(Flow.apply, None)
|
||||
}, interface = "localhost", port = 8080)
|
||||
val binding = Await.result(bindingFuture, 3.seconds)
|
||||
|
||||
val N = 100
|
||||
val (response, count) = Http().singleWebsocketRequest(
|
||||
WebsocketRequest("ws://127.0.0.1:8080"),
|
||||
Flow.fromSinkAndSourceMat(
|
||||
Sink.fold(0)((n, _: Message) ⇒ n + 1),
|
||||
Source.repeat(TextMessage("hello")).take(N))(Keep.left))
|
||||
|
||||
count.futureValue should ===(N)
|
||||
binding.unbind()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,174 @@
|
|||
/**
|
||||
* Copyright (C) 2015 Typesafe Inc. <http://www.typesafe.com>
|
||||
*/
|
||||
package akka.http.impl.engine.ws
|
||||
|
||||
import scala.concurrent.Await
|
||||
import scala.concurrent.duration.DurationInt
|
||||
import org.scalactic.ConversionCheckedTripleEquals
|
||||
import org.scalatest.concurrent.ScalaFutures
|
||||
import org.scalatest.time.Span.convertDurationToSpan
|
||||
import akka.http.scaladsl.Http
|
||||
import akka.http.scaladsl.model.HttpRequest
|
||||
import akka.http.scaladsl.model.Uri.apply
|
||||
import akka.http.scaladsl.model.ws._
|
||||
import akka.stream._
|
||||
import akka.stream.scaladsl._
|
||||
import akka.stream.testkit._
|
||||
import akka.stream.scaladsl.GraphDSL.Implicits._
|
||||
import org.scalatest.concurrent.Eventually
|
||||
import akka.stream.io.SslTlsPlacebo
|
||||
import java.net.InetSocketAddress
|
||||
import akka.stream.impl.fusing.GraphStages
|
||||
import akka.util.ByteString
|
||||
import akka.http.scaladsl.model.StatusCodes
|
||||
import akka.stream.testkit.scaladsl.TestSink
|
||||
import scala.concurrent.Future
|
||||
|
||||
class WebsocketIntegrationSpec extends AkkaSpec("akka.stream.materializer.debug.fuzzing-mode=off")
|
||||
with ScalaFutures with ConversionCheckedTripleEquals with Eventually {
|
||||
|
||||
implicit val patience = PatienceConfig(3.seconds)
|
||||
import system.dispatcher
|
||||
implicit val materializer = ActorMaterializer()
|
||||
|
||||
"A Websocket server" must {
|
||||
|
||||
"not reset the connection when no data are flowing" in Utils.assertAllStagesStopped {
|
||||
val source = TestPublisher.probe[Message]()
|
||||
val bindingFuture = Http().bindAndHandleSync({
|
||||
case HttpRequest(_, _, headers, _, _) ⇒
|
||||
val upgrade = headers.collectFirst { case u: UpgradeToWebsocket ⇒ u }.get
|
||||
upgrade.handleMessages(Flow.fromSinkAndSource(Sink.ignore, Source.fromPublisher(source)), None)
|
||||
}, interface = "localhost", port = 0)
|
||||
val binding = Await.result(bindingFuture, 3.seconds)
|
||||
val myPort = binding.localAddress.getPort
|
||||
|
||||
val (response, sink) = Http().singleWebsocketRequest(
|
||||
WebsocketRequest("ws://127.0.0.1:" + myPort),
|
||||
Flow.fromSinkAndSourceMat(TestSink.probe[Message], Source.empty)(Keep.left))
|
||||
|
||||
response.futureValue.response.status.isSuccess should ===(true)
|
||||
sink
|
||||
.request(10)
|
||||
.expectNoMsg(500.millis)
|
||||
|
||||
source
|
||||
.sendNext(TextMessage("hello"))
|
||||
.sendComplete()
|
||||
sink
|
||||
.expectNext(TextMessage("hello"))
|
||||
.expectComplete()
|
||||
|
||||
binding.unbind()
|
||||
}
|
||||
|
||||
"not reset the connection when no data are flowing and the connection is closed from the client" in Utils.assertAllStagesStopped {
|
||||
val source = TestPublisher.probe[Message]()
|
||||
val bindingFuture = Http().bindAndHandleSync({
|
||||
case HttpRequest(_, _, headers, _, _) ⇒
|
||||
val upgrade = headers.collectFirst { case u: UpgradeToWebsocket ⇒ u }.get
|
||||
upgrade.handleMessages(Flow.fromSinkAndSource(Sink.ignore, Source.fromPublisher(source)), None)
|
||||
}, interface = "localhost", port = 0)
|
||||
val binding = Await.result(bindingFuture, 3.seconds)
|
||||
val myPort = binding.localAddress.getPort
|
||||
|
||||
val ((response, breaker), sink) =
|
||||
Source.empty
|
||||
.viaMat {
|
||||
Http().websocketClientLayer(WebsocketRequest("ws://localhost:" + myPort))
|
||||
.atop(SslTlsPlacebo.forScala)
|
||||
.joinMat(Flow.fromGraph(GraphStages.breaker[ByteString]).via(
|
||||
Tcp().outgoingConnection(new InetSocketAddress("localhost", myPort), halfClose = true)))(Keep.both)
|
||||
}(Keep.right)
|
||||
.toMat(TestSink.probe[Message])(Keep.both)
|
||||
.run()
|
||||
|
||||
response.futureValue.response.status.isSuccess should ===(true)
|
||||
sink
|
||||
.request(10)
|
||||
.expectNoMsg(1500.millis)
|
||||
|
||||
breaker.value.get.get.complete()
|
||||
|
||||
source
|
||||
.sendNext(TextMessage("hello"))
|
||||
.sendComplete()
|
||||
sink
|
||||
.expectNext(TextMessage("hello"))
|
||||
.expectComplete()
|
||||
|
||||
binding.unbind()
|
||||
}
|
||||
|
||||
"echo 100 elements and then shut down without error" in Utils.assertAllStagesStopped {
|
||||
|
||||
val bindingFuture = Http().bindAndHandleSync({
|
||||
case HttpRequest(_, _, headers, _, _) ⇒
|
||||
val upgrade = headers.collectFirst { case u: UpgradeToWebsocket ⇒ u }.get
|
||||
upgrade.handleMessages(Flow.apply, None)
|
||||
}, interface = "localhost", port = 0)
|
||||
val binding = Await.result(bindingFuture, 3.seconds)
|
||||
val myPort = binding.localAddress.getPort
|
||||
|
||||
val N = 100
|
||||
val (response, count) = Http().singleWebsocketRequest(
|
||||
WebsocketRequest("ws://127.0.0.1:" + myPort),
|
||||
Flow.fromSinkAndSourceMat(
|
||||
Sink.fold(0)((n, _: Message) ⇒ n + 1),
|
||||
Source.repeat(TextMessage("hello")).take(N))(Keep.left))
|
||||
|
||||
count.futureValue should ===(N)
|
||||
binding.unbind()
|
||||
}
|
||||
|
||||
"send back 100 elements and then terminate without error even when not ordinarily closed" in Utils.assertAllStagesStopped {
|
||||
val N = 100
|
||||
|
||||
val handler = Flow.fromGraph(GraphDSL.create() { implicit b ⇒
|
||||
val merge = b.add(Merge[Int](2))
|
||||
|
||||
// convert to int so we can connect to merge
|
||||
val mapMsgToInt = b.add(Flow[Message].map(_ ⇒ -1))
|
||||
val mapIntToMsg = b.add(Flow[Int].map(x ⇒ TextMessage.Strict(s"Sending: $x")))
|
||||
|
||||
// source we want to use to send message to the connected websocket sink
|
||||
val rangeSource = b.add(Source(1 to N))
|
||||
|
||||
mapMsgToInt ~> merge // this part of the merge will never provide msgs
|
||||
rangeSource ~> merge ~> mapIntToMsg
|
||||
|
||||
FlowShape(mapMsgToInt.in, mapIntToMsg.out)
|
||||
})
|
||||
|
||||
val bindingFuture = Http().bindAndHandleSync({
|
||||
case HttpRequest(_, _, headers, _, _) ⇒
|
||||
val upgrade = headers.collectFirst { case u: UpgradeToWebsocket ⇒ u }.get
|
||||
upgrade.handleMessages(handler, None)
|
||||
}, interface = "localhost", port = 0)
|
||||
val binding = Await.result(bindingFuture, 3.seconds)
|
||||
val myPort = binding.localAddress.getPort
|
||||
|
||||
@volatile var messages = 0
|
||||
val (breaker, completion) =
|
||||
Source.maybe
|
||||
.viaMat {
|
||||
Http().websocketClientLayer(WebsocketRequest("ws://localhost:" + myPort))
|
||||
.atop(SslTlsPlacebo.forScala)
|
||||
// the resource leak of #19398 existed only for severed websocket connections
|
||||
.atopMat(GraphStages.bidiBreaker[ByteString, ByteString])(Keep.right)
|
||||
.join(Tcp().outgoingConnection(new InetSocketAddress("localhost", myPort), halfClose = true))
|
||||
}(Keep.right)
|
||||
.toMat(Sink.foreach(_ ⇒ messages += 1))(Keep.both)
|
||||
.run()
|
||||
eventually(messages should ===(N))
|
||||
// breaker should have been fulfilled long ago
|
||||
breaker.value.get.get.completeAndCancel()
|
||||
completion.futureValue
|
||||
|
||||
binding.unbind()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -108,6 +108,107 @@ object GraphStages {
|
|||
private val _detacher = new Detacher[Any]
|
||||
def detacher[T]: GraphStage[FlowShape[T, T]] = _detacher.asInstanceOf[GraphStage[FlowShape[T, T]]]
|
||||
|
||||
final class Breaker(callback: Breaker.Operation ⇒ Unit) {
|
||||
import Breaker._
|
||||
def complete(): Unit = callback(Complete)
|
||||
def cancel(): Unit = callback(Cancel)
|
||||
def fail(ex: Throwable): Unit = callback(Fail(ex))
|
||||
def completeAndCancel(): Unit = callback(CompleteAndCancel)
|
||||
def failAndCancel(ex: Throwable): Unit = callback(FailAndCancel(ex))
|
||||
}
|
||||
|
||||
object Breaker extends GraphStageWithMaterializedValue[FlowShape[Any, Any], Future[Breaker]] {
|
||||
sealed trait Operation
|
||||
case object Complete extends Operation
|
||||
case object Cancel extends Operation
|
||||
case class Fail(ex: Throwable) extends Operation
|
||||
case object CompleteAndCancel extends Operation
|
||||
case class FailAndCancel(ex: Throwable) extends Operation
|
||||
|
||||
override val initialAttributes = Attributes.name("breaker")
|
||||
override val shape = FlowShape(Inlet[Any]("breaker.in"), Outlet[Any]("breaker.out"))
|
||||
|
||||
override def createLogicAndMaterializedValue(attr: Attributes) = {
|
||||
val promise = Promise[Breaker]
|
||||
|
||||
val logic = new GraphStageLogic(shape) {
|
||||
|
||||
passAlong(shape.in, shape.out)
|
||||
setHandler(shape.out, eagerTerminateOutput)
|
||||
|
||||
override def preStart(): Unit = {
|
||||
pull(shape.in)
|
||||
promise.success(new Breaker(getAsyncCallback[Operation] {
|
||||
case Complete ⇒ complete(shape.out)
|
||||
case Cancel ⇒ cancel(shape.in)
|
||||
case Fail(ex) ⇒ fail(shape.out, ex)
|
||||
case CompleteAndCancel ⇒ completeStage()
|
||||
case FailAndCancel(ex) ⇒ failStage(ex)
|
||||
}.invoke))
|
||||
}
|
||||
}
|
||||
|
||||
(logic, promise.future)
|
||||
}
|
||||
}
|
||||
|
||||
def breaker[T]: Graph[FlowShape[T, T], Future[Breaker]] = Breaker.asInstanceOf[Graph[FlowShape[T, T], Future[Breaker]]]
|
||||
|
||||
object BidiBreaker extends GraphStageWithMaterializedValue[BidiShape[Any, Any, Any, Any], Future[Breaker]] {
|
||||
import Breaker._
|
||||
|
||||
override val initialAttributes = Attributes.name("breaker")
|
||||
override val shape = BidiShape(
|
||||
Inlet[Any]("breaker.in1"), Outlet[Any]("breaker.out1"),
|
||||
Inlet[Any]("breaker.in2"), Outlet[Any]("breaker.out2"))
|
||||
|
||||
override def createLogicAndMaterializedValue(attr: Attributes) = {
|
||||
val promise = Promise[Breaker]
|
||||
|
||||
val logic = new GraphStageLogic(shape) {
|
||||
|
||||
setHandler(shape.in1, new InHandler {
|
||||
override def onPush(): Unit = push(shape.out1, grab(shape.in1))
|
||||
override def onUpstreamFinish(): Unit = complete(shape.out1)
|
||||
override def onUpstreamFailure(ex: Throwable): Unit = fail(shape.out1, ex)
|
||||
})
|
||||
setHandler(shape.in2, new InHandler {
|
||||
override def onPush(): Unit = push(shape.out2, grab(shape.in2))
|
||||
override def onUpstreamFinish(): Unit = complete(shape.out2)
|
||||
override def onUpstreamFailure(ex: Throwable): Unit = fail(shape.out2, ex)
|
||||
})
|
||||
setHandler(shape.out1, new OutHandler {
|
||||
override def onPull(): Unit = pull(shape.in1)
|
||||
override def onDownstreamFinish(): Unit = cancel(shape.in1)
|
||||
})
|
||||
setHandler(shape.out2, new OutHandler {
|
||||
override def onPull(): Unit = pull(shape.in2)
|
||||
override def onDownstreamFinish(): Unit = cancel(shape.in2)
|
||||
})
|
||||
|
||||
override def preStart(): Unit = {
|
||||
promise.success(new Breaker(getAsyncCallback[Operation] {
|
||||
case Complete ⇒
|
||||
complete(shape.out1)
|
||||
complete(shape.out2)
|
||||
case Cancel ⇒
|
||||
cancel(shape.in1)
|
||||
cancel(shape.in2)
|
||||
case Fail(ex) ⇒
|
||||
fail(shape.out1, ex)
|
||||
fail(shape.out2, ex)
|
||||
case CompleteAndCancel ⇒ completeStage()
|
||||
case FailAndCancel(ex) ⇒ failStage(ex)
|
||||
}.invoke))
|
||||
}
|
||||
}
|
||||
|
||||
(logic, promise.future)
|
||||
}
|
||||
}
|
||||
|
||||
def bidiBreaker[T1, T2]: Graph[BidiShape[T1, T1, T2, T2], Future[Breaker]] = BidiBreaker.asInstanceOf[Graph[BidiShape[T1, T1, T2, T2], Future[Breaker]]]
|
||||
|
||||
private object TickSource {
|
||||
class TickSourceCancellable(cancelled: AtomicBoolean) extends Cancellable {
|
||||
private val cancelPromise = Promise[Unit]()
|
||||
|
|
|
|||
|
|
@ -261,11 +261,13 @@ private[stream] object TcpConnectionStage {
|
|||
// (or half-close is turned off)
|
||||
if (isClosed(bytesOut) || !role.halfClose) connection ! Close
|
||||
// We still read, so we only close the write side
|
||||
else connection ! ConfirmedClose
|
||||
else if (connection != null) connection ! ConfirmedClose
|
||||
else completeStage()
|
||||
}
|
||||
|
||||
override def onUpstreamFailure(ex: Throwable): Unit = {
|
||||
connection ! Abort
|
||||
if (connection != null) connection ! Abort
|
||||
else failStage(ex)
|
||||
}
|
||||
})
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue