#19467 avoid pulling closed port in ProtocolSwitch

This commit is contained in:
Roland Kuhn 2016-01-18 20:20:18 +01:00
parent 8d6bacab2f
commit 5523e1d70f
4 changed files with 147 additions and 40 deletions

View file

@ -433,46 +433,60 @@ private[http] object HttpServerBluePrint {
case Right(messageHandler)
Websocket.stack(serverSide = true, maskingRandomFactory = settings.websocketRandomFactory, log = log).join(messageHandler)
}
val sinkIn = new SubSinkInlet[ByteString]("FrameSink")
val sourceOut = new SubSourceOutlet[ByteString]("FrameSource")
val timeoutKey = SubscriptionTimeout(() {
sourceOut.timeout(timeout)
if (sourceOut.isClosed) completeStage()
})
addTimeout(timeoutKey)
sinkIn.setHandler(new InHandler {
override def onPush(): Unit = push(toNet, sinkIn.grab())
override def onUpstreamFinish(): Unit = complete(toNet)
override def onUpstreamFailure(ex: Throwable): Unit = fail(toNet, ex)
})
setHandler(toNet, new OutHandler {
override def onPull(): Unit = sinkIn.pull()
override def onDownstreamFinish(): Unit = {
completeStage()
sinkIn.cancel()
sourceOut.complete()
}
})
setHandler(fromNet, new InHandler {
override def onPush(): Unit = sourceOut.push(grab(fromNet).bytes)
override def onUpstreamFinish(): Unit = sourceOut.complete()
override def onUpstreamFailure(ex: Throwable): Unit = sourceOut.fail(ex)
})
sourceOut.setHandler(new OutHandler {
override def onPull(): Unit = {
if (!hasBeenPulled(fromNet)) pull(fromNet)
cancelTimeout(timeoutKey)
sourceOut.setHandler(new OutHandler {
override def onPull(): Unit = if (!hasBeenPulled(fromNet)) pull(fromNet)
})
}
override def onDownstreamFinish(): Unit = cancel(fromNet)
})
if (isClosed(fromNet)) {
setHandler(toNet, new OutHandler {
override def onPull(): Unit = sinkIn.pull()
override def onDownstreamFinish(): Unit = {
completeStage()
sinkIn.cancel()
}
})
Websocket.framing.join(frameHandler).runWith(Source.empty, sinkIn.sink)(subFusingMaterializer)
} else {
val sourceOut = new SubSourceOutlet[ByteString]("FrameSource")
Websocket.framing.join(frameHandler).runWith(sourceOut.source, sinkIn.sink)(subFusingMaterializer)
val timeoutKey = SubscriptionTimeout(() {
sourceOut.timeout(timeout)
if (sourceOut.isClosed) completeStage()
})
addTimeout(timeoutKey)
setHandler(toNet, new OutHandler {
override def onPull(): Unit = sinkIn.pull()
override def onDownstreamFinish(): Unit = {
completeStage()
sinkIn.cancel()
sourceOut.complete()
}
})
setHandler(fromNet, new InHandler {
override def onPush(): Unit = sourceOut.push(grab(fromNet).bytes)
override def onUpstreamFinish(): Unit = sourceOut.complete()
override def onUpstreamFailure(ex: Throwable): Unit = sourceOut.fail(ex)
})
sourceOut.setHandler(new OutHandler {
override def onPull(): Unit = {
if (!hasBeenPulled(fromNet)) pull(fromNet)
cancelTimeout(timeoutKey)
sourceOut.setHandler(new OutHandler {
override def onPull(): Unit = if (!hasBeenPulled(fromNet)) pull(fromNet)
override def onDownstreamFinish(): Unit = cancel(fromNet)
})
}
override def onDownstreamFinish(): Unit = cancel(fromNet)
})
Websocket.framing.join(frameHandler).runWith(sourceOut.source, sinkIn.sink)(subFusingMaterializer)
}
}
}
}

View file

@ -195,14 +195,19 @@ private[http] object Websocket {
def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) {
passAlong(bypass, out, doFinish = true, doFail = true)
passAlong(user, out, doFinish = false, doFail = false)
class PassAlong[T <: AnyRef](from: Inlet[T]) extends InHandler with (() Unit) {
override def apply(): Unit = tryPull(from)
override def onPush(): Unit = emit(out, grab(from), this)
override def onUpstreamFinish(): Unit =
if (isClosed(bypass) && isClosed(user)) completeStage()
}
setHandler(bypass, new PassAlong(bypass))
setHandler(user, new PassAlong(user))
passAlong(tick, out, doFinish = false, doFail = false)
setHandler(out, eagerTerminateOutput)
override def preStart(): Unit = {
super.preStart()
pull(bypass)
pull(user)
pull(tick)

View file

@ -21,6 +21,9 @@ import akka.stream.io.SslTlsPlacebo
import java.net.InetSocketAddress
import akka.stream.impl.fusing.GraphStages
import akka.util.ByteString
import akka.http.scaladsl.model.StatusCodes
import akka.stream.testkit.scaladsl.TestSink
import scala.concurrent.Future
class WebsocketIntegrationSpec extends AkkaSpec("akka.stream.materializer.debug.fuzzing-mode=off")
with ScalaFutures with ConversionCheckedTripleEquals with Eventually {
@ -31,6 +34,73 @@ class WebsocketIntegrationSpec extends AkkaSpec("akka.stream.materializer.debug.
"A Websocket server" must {
"not reset the connection when no data are flowing" in Utils.assertAllStagesStopped {
val source = TestPublisher.probe[Message]()
val bindingFuture = Http().bindAndHandleSync({
case HttpRequest(_, _, headers, _, _)
val upgrade = headers.collectFirst { case u: UpgradeToWebsocket u }.get
upgrade.handleMessages(Flow.fromSinkAndSource(Sink.ignore, Source.fromPublisher(source)), None)
}, interface = "localhost", port = 0)
val binding = Await.result(bindingFuture, 3.seconds)
val myPort = binding.localAddress.getPort
val (response, sink) = Http().singleWebsocketRequest(
WebsocketRequest("ws://127.0.0.1:" + myPort),
Flow.fromSinkAndSourceMat(TestSink.probe[Message], Source.empty)(Keep.left))
response.futureValue.response.status.isSuccess should ===(true)
sink
.request(10)
.expectNoMsg(500.millis)
source
.sendNext(TextMessage("hello"))
.sendComplete()
sink
.expectNext(TextMessage("hello"))
.expectComplete()
binding.unbind()
}
"not reset the connection when no data are flowing and the connection is closed from the client" in Utils.assertAllStagesStopped {
val source = TestPublisher.probe[Message]()
val bindingFuture = Http().bindAndHandleSync({
case HttpRequest(_, _, headers, _, _)
val upgrade = headers.collectFirst { case u: UpgradeToWebsocket u }.get
upgrade.handleMessages(Flow.fromSinkAndSource(Sink.ignore, Source.fromPublisher(source)), None)
}, interface = "localhost", port = 0)
val binding = Await.result(bindingFuture, 3.seconds)
val myPort = binding.localAddress.getPort
val ((response, breaker), sink) =
Source.empty
.viaMat {
Http().websocketClientLayer(WebsocketRequest("ws://localhost:" + myPort))
.atop(SslTlsPlacebo.forScala)
.joinMat(Flow.fromGraph(GraphStages.breaker[ByteString]).via(
Tcp().outgoingConnection(new InetSocketAddress("localhost", myPort), halfClose = true)))(Keep.both)
}(Keep.right)
.toMat(TestSink.probe[Message])(Keep.both)
.run()
response.futureValue.response.status.isSuccess should ===(true)
sink
.request(10)
.expectNoMsg(1500.millis)
breaker.value.get.get.complete()
source
.sendNext(TextMessage("hello"))
.sendComplete()
sink
.expectNext(TextMessage("hello"))
.expectComplete()
binding.unbind()
}
"echo 100 elements and then shut down without error" in Utils.assertAllStagesStopped {
val bindingFuture = Http().bindAndHandleSync({
@ -93,7 +163,7 @@ class WebsocketIntegrationSpec extends AkkaSpec("akka.stream.materializer.debug.
.run()
eventually(messages should ===(N))
// breaker should have been fulfilled long ago
breaker.value.get.get.complete()
breaker.value.get.get.completeAndCancel()
completion.futureValue
binding.unbind()