diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala index c0337777bb..970e0db865 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala @@ -433,46 +433,60 @@ private[http] object HttpServerBluePrint { case Right(messageHandler) ⇒ Websocket.stack(serverSide = true, maskingRandomFactory = settings.websocketRandomFactory, log = log).join(messageHandler) } + val sinkIn = new SubSinkInlet[ByteString]("FrameSink") - val sourceOut = new SubSourceOutlet[ByteString]("FrameSource") - - val timeoutKey = SubscriptionTimeout(() ⇒ { - sourceOut.timeout(timeout) - if (sourceOut.isClosed) completeStage() - }) - addTimeout(timeoutKey) - sinkIn.setHandler(new InHandler { override def onPush(): Unit = push(toNet, sinkIn.grab()) override def onUpstreamFinish(): Unit = complete(toNet) override def onUpstreamFailure(ex: Throwable): Unit = fail(toNet, ex) }) - setHandler(toNet, new OutHandler { - override def onPull(): Unit = sinkIn.pull() - override def onDownstreamFinish(): Unit = { - completeStage() - sinkIn.cancel() - sourceOut.complete() - } - }) - setHandler(fromNet, new InHandler { - override def onPush(): Unit = sourceOut.push(grab(fromNet).bytes) - override def onUpstreamFinish(): Unit = sourceOut.complete() - override def onUpstreamFailure(ex: Throwable): Unit = sourceOut.fail(ex) - }) - sourceOut.setHandler(new OutHandler { - override def onPull(): Unit = { - if (!hasBeenPulled(fromNet)) pull(fromNet) - cancelTimeout(timeoutKey) - sourceOut.setHandler(new OutHandler { - override def onPull(): Unit = if (!hasBeenPulled(fromNet)) pull(fromNet) - }) - } - override def onDownstreamFinish(): Unit = cancel(fromNet) - }) + if (isClosed(fromNet)) { + setHandler(toNet, new OutHandler { + override def onPull(): Unit = sinkIn.pull() + override def onDownstreamFinish(): Unit = { + completeStage() + sinkIn.cancel() + } + }) + Websocket.framing.join(frameHandler).runWith(Source.empty, sinkIn.sink)(subFusingMaterializer) + } else { + val sourceOut = new SubSourceOutlet[ByteString]("FrameSource") - Websocket.framing.join(frameHandler).runWith(sourceOut.source, sinkIn.sink)(subFusingMaterializer) + val timeoutKey = SubscriptionTimeout(() ⇒ { + sourceOut.timeout(timeout) + if (sourceOut.isClosed) completeStage() + }) + addTimeout(timeoutKey) + + setHandler(toNet, new OutHandler { + override def onPull(): Unit = sinkIn.pull() + override def onDownstreamFinish(): Unit = { + completeStage() + sinkIn.cancel() + sourceOut.complete() + } + }) + + setHandler(fromNet, new InHandler { + override def onPush(): Unit = sourceOut.push(grab(fromNet).bytes) + override def onUpstreamFinish(): Unit = sourceOut.complete() + override def onUpstreamFailure(ex: Throwable): Unit = sourceOut.fail(ex) + }) + sourceOut.setHandler(new OutHandler { + override def onPull(): Unit = { + if (!hasBeenPulled(fromNet)) pull(fromNet) + cancelTimeout(timeoutKey) + sourceOut.setHandler(new OutHandler { + override def onPull(): Unit = if (!hasBeenPulled(fromNet)) pull(fromNet) + override def onDownstreamFinish(): Unit = cancel(fromNet) + }) + } + override def onDownstreamFinish(): Unit = cancel(fromNet) + }) + + Websocket.framing.join(frameHandler).runWith(sourceOut.source, sinkIn.sink)(subFusingMaterializer) + } } } } diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Websocket.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Websocket.scala index a2d734471a..c98bac652f 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Websocket.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Websocket.scala @@ -195,14 +195,19 @@ private[http] object Websocket { def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) { - passAlong(bypass, out, doFinish = true, doFail = true) - passAlong(user, out, doFinish = false, doFail = false) + class PassAlong[T <: AnyRef](from: Inlet[T]) extends InHandler with (() ⇒ Unit) { + override def apply(): Unit = tryPull(from) + override def onPush(): Unit = emit(out, grab(from), this) + override def onUpstreamFinish(): Unit = + if (isClosed(bypass) && isClosed(user)) completeStage() + } + setHandler(bypass, new PassAlong(bypass)) + setHandler(user, new PassAlong(user)) passAlong(tick, out, doFinish = false, doFail = false) setHandler(out, eagerTerminateOutput) override def preStart(): Unit = { - super.preStart() pull(bypass) pull(user) pull(tick) diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketIntegrationSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketIntegrationSpec.scala index 59219ff313..184ed34cf4 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketIntegrationSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketIntegrationSpec.scala @@ -21,6 +21,9 @@ import akka.stream.io.SslTlsPlacebo import java.net.InetSocketAddress import akka.stream.impl.fusing.GraphStages import akka.util.ByteString +import akka.http.scaladsl.model.StatusCodes +import akka.stream.testkit.scaladsl.TestSink +import scala.concurrent.Future class WebsocketIntegrationSpec extends AkkaSpec("akka.stream.materializer.debug.fuzzing-mode=off") with ScalaFutures with ConversionCheckedTripleEquals with Eventually { @@ -31,6 +34,73 @@ class WebsocketIntegrationSpec extends AkkaSpec("akka.stream.materializer.debug. "A Websocket server" must { + "not reset the connection when no data are flowing" in Utils.assertAllStagesStopped { + val source = TestPublisher.probe[Message]() + val bindingFuture = Http().bindAndHandleSync({ + case HttpRequest(_, _, headers, _, _) ⇒ + val upgrade = headers.collectFirst { case u: UpgradeToWebsocket ⇒ u }.get + upgrade.handleMessages(Flow.fromSinkAndSource(Sink.ignore, Source.fromPublisher(source)), None) + }, interface = "localhost", port = 0) + val binding = Await.result(bindingFuture, 3.seconds) + val myPort = binding.localAddress.getPort + + val (response, sink) = Http().singleWebsocketRequest( + WebsocketRequest("ws://127.0.0.1:" + myPort), + Flow.fromSinkAndSourceMat(TestSink.probe[Message], Source.empty)(Keep.left)) + + response.futureValue.response.status.isSuccess should ===(true) + sink + .request(10) + .expectNoMsg(500.millis) + + source + .sendNext(TextMessage("hello")) + .sendComplete() + sink + .expectNext(TextMessage("hello")) + .expectComplete() + + binding.unbind() + } + + "not reset the connection when no data are flowing and the connection is closed from the client" in Utils.assertAllStagesStopped { + val source = TestPublisher.probe[Message]() + val bindingFuture = Http().bindAndHandleSync({ + case HttpRequest(_, _, headers, _, _) ⇒ + val upgrade = headers.collectFirst { case u: UpgradeToWebsocket ⇒ u }.get + upgrade.handleMessages(Flow.fromSinkAndSource(Sink.ignore, Source.fromPublisher(source)), None) + }, interface = "localhost", port = 0) + val binding = Await.result(bindingFuture, 3.seconds) + val myPort = binding.localAddress.getPort + + val ((response, breaker), sink) = + Source.empty + .viaMat { + Http().websocketClientLayer(WebsocketRequest("ws://localhost:" + myPort)) + .atop(SslTlsPlacebo.forScala) + .joinMat(Flow.fromGraph(GraphStages.breaker[ByteString]).via( + Tcp().outgoingConnection(new InetSocketAddress("localhost", myPort), halfClose = true)))(Keep.both) + }(Keep.right) + .toMat(TestSink.probe[Message])(Keep.both) + .run() + + response.futureValue.response.status.isSuccess should ===(true) + sink + .request(10) + .expectNoMsg(1500.millis) + + breaker.value.get.get.complete() + + source + .sendNext(TextMessage("hello")) + .sendComplete() + sink + .expectNext(TextMessage("hello")) + .expectComplete() + + binding.unbind() + } + "echo 100 elements and then shut down without error" in Utils.assertAllStagesStopped { val bindingFuture = Http().bindAndHandleSync({ @@ -93,7 +163,7 @@ class WebsocketIntegrationSpec extends AkkaSpec("akka.stream.materializer.debug. .run() eventually(messages should ===(N)) // breaker should have been fulfilled long ago - breaker.value.get.get.complete() + breaker.value.get.get.completeAndCancel() completion.futureValue binding.unbind() diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala index 359a46c7ef..3680ab13ce 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala @@ -108,13 +108,19 @@ object GraphStages { final class Breaker(callback: Breaker.Operation ⇒ Unit) { import Breaker._ def complete(): Unit = callback(Complete) + def cancel(): Unit = callback(Cancel) def fail(ex: Throwable): Unit = callback(Fail(ex)) + def completeAndCancel(): Unit = callback(CompleteAndCancel) + def failAndCancel(ex: Throwable): Unit = callback(FailAndCancel(ex)) } object Breaker extends GraphStageWithMaterializedValue[FlowShape[Any, Any], Future[Breaker]] { sealed trait Operation case object Complete extends Operation + case object Cancel extends Operation case class Fail(ex: Throwable) extends Operation + case object CompleteAndCancel extends Operation + case class FailAndCancel(ex: Throwable) extends Operation override val initialAttributes = Attributes.name("breaker") override val shape = FlowShape(Inlet[Any]("breaker.in"), Outlet[Any]("breaker.out")) @@ -130,8 +136,11 @@ object GraphStages { override def preStart(): Unit = { pull(shape.in) promise.success(new Breaker(getAsyncCallback[Operation] { - case Complete ⇒ completeStage() - case Fail(ex) ⇒ failStage(ex) + case Complete ⇒ complete(shape.out) + case Cancel ⇒ cancel(shape.in) + case Fail(ex) ⇒ fail(shape.out, ex) + case CompleteAndCancel ⇒ completeStage() + case FailAndCancel(ex) ⇒ failStage(ex) }.invoke)) } } @@ -176,8 +185,17 @@ object GraphStages { override def preStart(): Unit = { promise.success(new Breaker(getAsyncCallback[Operation] { - case Complete ⇒ completeStage() - case Fail(ex) ⇒ failStage(ex) + case Complete ⇒ + complete(shape.out1) + complete(shape.out2) + case Cancel ⇒ + cancel(shape.in1) + cancel(shape.in2) + case Fail(ex) ⇒ + fail(shape.out1, ex) + fail(shape.out2, ex) + case CompleteAndCancel ⇒ completeStage() + case FailAndCancel(ex) ⇒ failStage(ex) }.invoke)) } }