#19467 avoid pulling closed port in ProtocolSwitch
This commit is contained in:
parent
8d6bacab2f
commit
5523e1d70f
4 changed files with 147 additions and 40 deletions
|
|
@ -433,46 +433,60 @@ private[http] object HttpServerBluePrint {
|
|||
case Right(messageHandler) ⇒
|
||||
Websocket.stack(serverSide = true, maskingRandomFactory = settings.websocketRandomFactory, log = log).join(messageHandler)
|
||||
}
|
||||
|
||||
val sinkIn = new SubSinkInlet[ByteString]("FrameSink")
|
||||
val sourceOut = new SubSourceOutlet[ByteString]("FrameSource")
|
||||
|
||||
val timeoutKey = SubscriptionTimeout(() ⇒ {
|
||||
sourceOut.timeout(timeout)
|
||||
if (sourceOut.isClosed) completeStage()
|
||||
})
|
||||
addTimeout(timeoutKey)
|
||||
|
||||
sinkIn.setHandler(new InHandler {
|
||||
override def onPush(): Unit = push(toNet, sinkIn.grab())
|
||||
override def onUpstreamFinish(): Unit = complete(toNet)
|
||||
override def onUpstreamFailure(ex: Throwable): Unit = fail(toNet, ex)
|
||||
})
|
||||
setHandler(toNet, new OutHandler {
|
||||
override def onPull(): Unit = sinkIn.pull()
|
||||
override def onDownstreamFinish(): Unit = {
|
||||
completeStage()
|
||||
sinkIn.cancel()
|
||||
sourceOut.complete()
|
||||
}
|
||||
})
|
||||
|
||||
setHandler(fromNet, new InHandler {
|
||||
override def onPush(): Unit = sourceOut.push(grab(fromNet).bytes)
|
||||
override def onUpstreamFinish(): Unit = sourceOut.complete()
|
||||
override def onUpstreamFailure(ex: Throwable): Unit = sourceOut.fail(ex)
|
||||
})
|
||||
sourceOut.setHandler(new OutHandler {
|
||||
override def onPull(): Unit = {
|
||||
if (!hasBeenPulled(fromNet)) pull(fromNet)
|
||||
cancelTimeout(timeoutKey)
|
||||
sourceOut.setHandler(new OutHandler {
|
||||
override def onPull(): Unit = if (!hasBeenPulled(fromNet)) pull(fromNet)
|
||||
})
|
||||
}
|
||||
override def onDownstreamFinish(): Unit = cancel(fromNet)
|
||||
})
|
||||
if (isClosed(fromNet)) {
|
||||
setHandler(toNet, new OutHandler {
|
||||
override def onPull(): Unit = sinkIn.pull()
|
||||
override def onDownstreamFinish(): Unit = {
|
||||
completeStage()
|
||||
sinkIn.cancel()
|
||||
}
|
||||
})
|
||||
Websocket.framing.join(frameHandler).runWith(Source.empty, sinkIn.sink)(subFusingMaterializer)
|
||||
} else {
|
||||
val sourceOut = new SubSourceOutlet[ByteString]("FrameSource")
|
||||
|
||||
Websocket.framing.join(frameHandler).runWith(sourceOut.source, sinkIn.sink)(subFusingMaterializer)
|
||||
val timeoutKey = SubscriptionTimeout(() ⇒ {
|
||||
sourceOut.timeout(timeout)
|
||||
if (sourceOut.isClosed) completeStage()
|
||||
})
|
||||
addTimeout(timeoutKey)
|
||||
|
||||
setHandler(toNet, new OutHandler {
|
||||
override def onPull(): Unit = sinkIn.pull()
|
||||
override def onDownstreamFinish(): Unit = {
|
||||
completeStage()
|
||||
sinkIn.cancel()
|
||||
sourceOut.complete()
|
||||
}
|
||||
})
|
||||
|
||||
setHandler(fromNet, new InHandler {
|
||||
override def onPush(): Unit = sourceOut.push(grab(fromNet).bytes)
|
||||
override def onUpstreamFinish(): Unit = sourceOut.complete()
|
||||
override def onUpstreamFailure(ex: Throwable): Unit = sourceOut.fail(ex)
|
||||
})
|
||||
sourceOut.setHandler(new OutHandler {
|
||||
override def onPull(): Unit = {
|
||||
if (!hasBeenPulled(fromNet)) pull(fromNet)
|
||||
cancelTimeout(timeoutKey)
|
||||
sourceOut.setHandler(new OutHandler {
|
||||
override def onPull(): Unit = if (!hasBeenPulled(fromNet)) pull(fromNet)
|
||||
override def onDownstreamFinish(): Unit = cancel(fromNet)
|
||||
})
|
||||
}
|
||||
override def onDownstreamFinish(): Unit = cancel(fromNet)
|
||||
})
|
||||
|
||||
Websocket.framing.join(frameHandler).runWith(sourceOut.source, sinkIn.sink)(subFusingMaterializer)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -195,14 +195,19 @@ private[http] object Websocket {
|
|||
|
||||
def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) {
|
||||
|
||||
passAlong(bypass, out, doFinish = true, doFail = true)
|
||||
passAlong(user, out, doFinish = false, doFail = false)
|
||||
class PassAlong[T <: AnyRef](from: Inlet[T]) extends InHandler with (() ⇒ Unit) {
|
||||
override def apply(): Unit = tryPull(from)
|
||||
override def onPush(): Unit = emit(out, grab(from), this)
|
||||
override def onUpstreamFinish(): Unit =
|
||||
if (isClosed(bypass) && isClosed(user)) completeStage()
|
||||
}
|
||||
setHandler(bypass, new PassAlong(bypass))
|
||||
setHandler(user, new PassAlong(user))
|
||||
passAlong(tick, out, doFinish = false, doFail = false)
|
||||
|
||||
setHandler(out, eagerTerminateOutput)
|
||||
|
||||
override def preStart(): Unit = {
|
||||
super.preStart()
|
||||
pull(bypass)
|
||||
pull(user)
|
||||
pull(tick)
|
||||
|
|
|
|||
|
|
@ -21,6 +21,9 @@ import akka.stream.io.SslTlsPlacebo
|
|||
import java.net.InetSocketAddress
|
||||
import akka.stream.impl.fusing.GraphStages
|
||||
import akka.util.ByteString
|
||||
import akka.http.scaladsl.model.StatusCodes
|
||||
import akka.stream.testkit.scaladsl.TestSink
|
||||
import scala.concurrent.Future
|
||||
|
||||
class WebsocketIntegrationSpec extends AkkaSpec("akka.stream.materializer.debug.fuzzing-mode=off")
|
||||
with ScalaFutures with ConversionCheckedTripleEquals with Eventually {
|
||||
|
|
@ -31,6 +34,73 @@ class WebsocketIntegrationSpec extends AkkaSpec("akka.stream.materializer.debug.
|
|||
|
||||
"A Websocket server" must {
|
||||
|
||||
"not reset the connection when no data are flowing" in Utils.assertAllStagesStopped {
|
||||
val source = TestPublisher.probe[Message]()
|
||||
val bindingFuture = Http().bindAndHandleSync({
|
||||
case HttpRequest(_, _, headers, _, _) ⇒
|
||||
val upgrade = headers.collectFirst { case u: UpgradeToWebsocket ⇒ u }.get
|
||||
upgrade.handleMessages(Flow.fromSinkAndSource(Sink.ignore, Source.fromPublisher(source)), None)
|
||||
}, interface = "localhost", port = 0)
|
||||
val binding = Await.result(bindingFuture, 3.seconds)
|
||||
val myPort = binding.localAddress.getPort
|
||||
|
||||
val (response, sink) = Http().singleWebsocketRequest(
|
||||
WebsocketRequest("ws://127.0.0.1:" + myPort),
|
||||
Flow.fromSinkAndSourceMat(TestSink.probe[Message], Source.empty)(Keep.left))
|
||||
|
||||
response.futureValue.response.status.isSuccess should ===(true)
|
||||
sink
|
||||
.request(10)
|
||||
.expectNoMsg(500.millis)
|
||||
|
||||
source
|
||||
.sendNext(TextMessage("hello"))
|
||||
.sendComplete()
|
||||
sink
|
||||
.expectNext(TextMessage("hello"))
|
||||
.expectComplete()
|
||||
|
||||
binding.unbind()
|
||||
}
|
||||
|
||||
"not reset the connection when no data are flowing and the connection is closed from the client" in Utils.assertAllStagesStopped {
|
||||
val source = TestPublisher.probe[Message]()
|
||||
val bindingFuture = Http().bindAndHandleSync({
|
||||
case HttpRequest(_, _, headers, _, _) ⇒
|
||||
val upgrade = headers.collectFirst { case u: UpgradeToWebsocket ⇒ u }.get
|
||||
upgrade.handleMessages(Flow.fromSinkAndSource(Sink.ignore, Source.fromPublisher(source)), None)
|
||||
}, interface = "localhost", port = 0)
|
||||
val binding = Await.result(bindingFuture, 3.seconds)
|
||||
val myPort = binding.localAddress.getPort
|
||||
|
||||
val ((response, breaker), sink) =
|
||||
Source.empty
|
||||
.viaMat {
|
||||
Http().websocketClientLayer(WebsocketRequest("ws://localhost:" + myPort))
|
||||
.atop(SslTlsPlacebo.forScala)
|
||||
.joinMat(Flow.fromGraph(GraphStages.breaker[ByteString]).via(
|
||||
Tcp().outgoingConnection(new InetSocketAddress("localhost", myPort), halfClose = true)))(Keep.both)
|
||||
}(Keep.right)
|
||||
.toMat(TestSink.probe[Message])(Keep.both)
|
||||
.run()
|
||||
|
||||
response.futureValue.response.status.isSuccess should ===(true)
|
||||
sink
|
||||
.request(10)
|
||||
.expectNoMsg(1500.millis)
|
||||
|
||||
breaker.value.get.get.complete()
|
||||
|
||||
source
|
||||
.sendNext(TextMessage("hello"))
|
||||
.sendComplete()
|
||||
sink
|
||||
.expectNext(TextMessage("hello"))
|
||||
.expectComplete()
|
||||
|
||||
binding.unbind()
|
||||
}
|
||||
|
||||
"echo 100 elements and then shut down without error" in Utils.assertAllStagesStopped {
|
||||
|
||||
val bindingFuture = Http().bindAndHandleSync({
|
||||
|
|
@ -93,7 +163,7 @@ class WebsocketIntegrationSpec extends AkkaSpec("akka.stream.materializer.debug.
|
|||
.run()
|
||||
eventually(messages should ===(N))
|
||||
// breaker should have been fulfilled long ago
|
||||
breaker.value.get.get.complete()
|
||||
breaker.value.get.get.completeAndCancel()
|
||||
completion.futureValue
|
||||
|
||||
binding.unbind()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue