From e0361ece66076ad0ca0f99abd0e169e58cef07bd Mon Sep 17 00:00:00 2001 From: Roland Kuhn Date: Wed, 13 Jan 2016 17:52:46 +0100 Subject: [PATCH 1/5] #19398 fix stream leak in ProtocolSwitchStage also fix potential NPE in TCP streams when failed or canceled early for an outgoing connection --- .../engine/server/HttpServerBluePrint.scala | 26 ++++- .../impl/engine/ws/BypassRouterSpec.scala | 50 --------- .../engine/ws/WebsocketIntegrationSpec.scala | 101 ++++++++++++++++++ .../akka/stream/impl/fusing/GraphStages.scala | 83 ++++++++++++++ .../scala/akka/stream/impl/io/TcpStages.scala | 6 +- 5 files changed, 209 insertions(+), 57 deletions(-) delete mode 100644 akka-http-core/src/test/scala/akka/http/impl/engine/ws/BypassRouterSpec.scala create mode 100644 akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketIntegrationSpec.scala diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala index 2216cb5b43..c0337777bb 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala @@ -371,6 +371,11 @@ private[http] object HttpServerBluePrint { def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) { import akka.http.impl.engine.rendering.ResponseRenderingOutput._ + /* + * These handlers are in charge until a switch command comes in, then they + * are replaced. + */ + setHandler(fromHttp, new InHandler { override def onPush(): Unit = grab(fromHttp) match { @@ -381,21 +386,22 @@ private[http] object HttpServerBluePrint { cancel(fromHttp) switchToWebsocket(handlerFlow) } + override def onUpstreamFinish(): Unit = complete(toNet) + override def onUpstreamFailure(ex: Throwable): Unit = fail(toNet, ex) }) setHandler(toNet, new OutHandler { override def onPull(): Unit = pull(fromHttp) + override def onDownstreamFinish(): Unit = completeStage() }) setHandler(fromNet, new InHandler { - def onPush(): Unit = push(toHttp, grab(fromNet)) - - // propagate error but don't close stage yet to prevent fromHttp/fromWs being cancelled - // too eagerly + override def onPush(): Unit = push(toHttp, grab(fromNet)) + override def onUpstreamFinish(): Unit = complete(toHttp) override def onUpstreamFailure(ex: Throwable): Unit = fail(toHttp, ex) }) setHandler(toHttp, new OutHandler { override def onPull(): Unit = pull(fromNet) - override def onDownstreamFinish(): Unit = () + override def onDownstreamFinish(): Unit = cancel(fromNet) }) private var activeTimers = 0 @@ -438,13 +444,22 @@ private[http] object HttpServerBluePrint { sinkIn.setHandler(new InHandler { override def onPush(): Unit = push(toNet, sinkIn.grab()) + override def onUpstreamFinish(): Unit = complete(toNet) + override def onUpstreamFailure(ex: Throwable): Unit = fail(toNet, ex) }) setHandler(toNet, new OutHandler { override def onPull(): Unit = sinkIn.pull() + override def onDownstreamFinish(): Unit = { + completeStage() + sinkIn.cancel() + sourceOut.complete() + } }) setHandler(fromNet, new InHandler { override def onPush(): Unit = sourceOut.push(grab(fromNet).bytes) + override def onUpstreamFinish(): Unit = sourceOut.complete() + override def onUpstreamFailure(ex: Throwable): Unit = sourceOut.fail(ex) }) sourceOut.setHandler(new OutHandler { override def onPull(): Unit = { @@ -454,6 +469,7 @@ private[http] object HttpServerBluePrint { override def onPull(): Unit = if (!hasBeenPulled(fromNet)) pull(fromNet) }) } + override def onDownstreamFinish(): Unit = cancel(fromNet) }) Websocket.framing.join(frameHandler).runWith(sourceOut.source, sinkIn.sink)(subFusingMaterializer) diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/BypassRouterSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/BypassRouterSpec.scala deleted file mode 100644 index 89be18a5da..0000000000 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/BypassRouterSpec.scala +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright (C) 2015 Typesafe Inc. - */ -package akka.http.impl.engine.ws - -import akka.stream.testkit.AkkaSpec -import scala.concurrent.Await -import com.typesafe.config.ConfigFactory -import com.typesafe.config.Config -import akka.actor.ActorSystem -import akka.http.scaladsl.model.HttpRequest -import akka.http.scaladsl.model.ws._ -import akka.http.scaladsl._ -import akka.stream.scaladsl._ -import akka.stream._ -import scala.concurrent.duration._ -import org.scalatest.concurrent.ScalaFutures -import org.scalactic.ConversionCheckedTripleEquals -import akka.stream.testkit.Utils - -class BypassRouterSpec extends AkkaSpec("akka.stream.materializer.debug.fuzzing-mode = off") with ScalaFutures with ConversionCheckedTripleEquals { - - implicit val patience = PatienceConfig(3.seconds) - import system.dispatcher - implicit val materializer = ActorMaterializer() - - "BypassRouter" must { - - "work without double pull-ing some ports" in Utils.assertAllStagesStopped { - val bindingFuture = Http().bindAndHandleSync({ - case HttpRequest(_, _, headers, _, _) ⇒ - val upgrade = headers.collectFirst { case u: UpgradeToWebsocket ⇒ u }.get - upgrade.handleMessages(Flow.apply, None) - }, interface = "localhost", port = 8080) - val binding = Await.result(bindingFuture, 3.seconds) - - val N = 100 - val (response, count) = Http().singleWebsocketRequest( - WebsocketRequest("ws://127.0.0.1:8080"), - Flow.fromSinkAndSourceMat( - Sink.fold(0)((n, _: Message) ⇒ n + 1), - Source.repeat(TextMessage("hello")).take(N))(Keep.left)) - - count.futureValue should ===(N) - binding.unbind() - } - - } - -} diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketIntegrationSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketIntegrationSpec.scala new file mode 100644 index 0000000000..f389f88c72 --- /dev/null +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketIntegrationSpec.scala @@ -0,0 +1,101 @@ +/** + * Copyright (C) 2015 Typesafe Inc. + */ +package akka.http.impl.engine.ws + +import scala.concurrent.Await +import scala.concurrent.duration.DurationInt +import org.scalactic.ConversionCheckedTripleEquals +import org.scalatest.concurrent.ScalaFutures +import org.scalatest.time.Span.convertDurationToSpan +import akka.http.scaladsl.Http +import akka.http.scaladsl.model.HttpRequest +import akka.http.scaladsl.model.Uri.apply +import akka.http.scaladsl.model.ws._ +import akka.stream._ +import akka.stream.scaladsl._ +import akka.stream.testkit._ +import akka.stream.scaladsl.GraphDSL.Implicits._ +import org.scalatest.concurrent.Eventually +import akka.stream.io.SslTlsPlacebo +import java.net.InetSocketAddress +import akka.stream.impl.fusing.GraphStages +import akka.util.ByteString + +class WebsocketIntegrationSpec extends AkkaSpec("akka.stream.materializer.debug.fuzzing-mode=off") + with ScalaFutures with ConversionCheckedTripleEquals with Eventually { + + implicit val patience = PatienceConfig(3.seconds) + import system.dispatcher + implicit val materializer = ActorMaterializer() + + "A Websocket server" must { + + "echo 100 elements and then shut down without error" in Utils.assertAllStagesStopped { + val bindingFuture = Http().bindAndHandleSync({ + case HttpRequest(_, _, headers, _, _) ⇒ + val upgrade = headers.collectFirst { case u: UpgradeToWebsocket ⇒ u }.get + upgrade.handleMessages(Flow.apply, None) + }, interface = "localhost", port = 8080) + val binding = Await.result(bindingFuture, 3.seconds) + + val N = 100 + val (response, count) = Http().singleWebsocketRequest( + WebsocketRequest("ws://127.0.0.1:8080"), + Flow.fromSinkAndSourceMat( + Sink.fold(0)((n, _: Message) ⇒ n + 1), + Source.repeat(TextMessage("hello")).take(N))(Keep.left)) + + count.futureValue should ===(N) + binding.unbind() + } + + "send back 100 elements and then terminate without error even when not ordinarily closed" in Utils.assertAllStagesStopped { + val N = 100 + + val handler = Flow.fromGraph(GraphDSL.create() { implicit b ⇒ + val merge = b.add(Merge[Int](2)) + + // convert to int so we can connect to merge + val mapMsgToInt = b.add(Flow[Message].map(_ ⇒ -1)) + val mapIntToMsg = b.add(Flow[Int].map(x ⇒ TextMessage.Strict(s"Sending: $x"))) + + // source we want to use to send message to the connected websocket sink + val rangeSource = b.add(Source(1 to N)) + + mapMsgToInt ~> merge // this part of the merge will never provide msgs + rangeSource ~> merge ~> mapIntToMsg + + FlowShape(mapMsgToInt.in, mapIntToMsg.out) + }) + + val bindingFuture = Http().bindAndHandleSync({ + case HttpRequest(_, _, headers, _, _) ⇒ + val upgrade = headers.collectFirst { case u: UpgradeToWebsocket ⇒ u }.get + upgrade.handleMessages(handler, None) + }, interface = "localhost", port = 8080) + val binding = Await.result(bindingFuture, 3.seconds) + + @volatile var messages = 0 + val (breaker, completion) = + Source.maybe + .viaMat { + Http().websocketClientLayer(WebsocketRequest("ws://localhost:8080")) + .atop(SslTlsPlacebo.forScala) + // the resource leak of #19398 existed only for severed websocket connections + .atopMat(GraphStages.bidiBreaker[ByteString, ByteString])(Keep.right) + .join(Tcp().outgoingConnection(new InetSocketAddress("localhost", 8080), halfClose = true)) + }(Keep.right) + .toMat(Sink.foreach(_ ⇒ messages += 1))(Keep.both) + .run() + eventually(messages should ===(N)) + // breaker should have been fulfilled long ago + breaker.value.get.get.complete() + completion.futureValue + + binding.unbind() + } + + } + +} diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala index ea70aaf273..359a46c7ef 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala @@ -105,6 +105,89 @@ object GraphStages { private val _detacher = new Detacher[Any] def detacher[T]: GraphStage[FlowShape[T, T]] = _detacher.asInstanceOf[GraphStage[FlowShape[T, T]]] + final class Breaker(callback: Breaker.Operation ⇒ Unit) { + import Breaker._ + def complete(): Unit = callback(Complete) + def fail(ex: Throwable): Unit = callback(Fail(ex)) + } + + object Breaker extends GraphStageWithMaterializedValue[FlowShape[Any, Any], Future[Breaker]] { + sealed trait Operation + case object Complete extends Operation + case class Fail(ex: Throwable) extends Operation + + override val initialAttributes = Attributes.name("breaker") + override val shape = FlowShape(Inlet[Any]("breaker.in"), Outlet[Any]("breaker.out")) + + override def createLogicAndMaterializedValue(attr: Attributes) = { + val promise = Promise[Breaker] + + val logic = new GraphStageLogic(shape) { + + passAlong(shape.in, shape.out) + setHandler(shape.out, eagerTerminateOutput) + + override def preStart(): Unit = { + pull(shape.in) + promise.success(new Breaker(getAsyncCallback[Operation] { + case Complete ⇒ completeStage() + case Fail(ex) ⇒ failStage(ex) + }.invoke)) + } + } + + (logic, promise.future) + } + } + + def breaker[T]: Graph[FlowShape[T, T], Future[Breaker]] = Breaker.asInstanceOf[Graph[FlowShape[T, T], Future[Breaker]]] + + object BidiBreaker extends GraphStageWithMaterializedValue[BidiShape[Any, Any, Any, Any], Future[Breaker]] { + import Breaker._ + + override val initialAttributes = Attributes.name("breaker") + override val shape = BidiShape( + Inlet[Any]("breaker.in1"), Outlet[Any]("breaker.out1"), + Inlet[Any]("breaker.in2"), Outlet[Any]("breaker.out2")) + + override def createLogicAndMaterializedValue(attr: Attributes) = { + val promise = Promise[Breaker] + + val logic = new GraphStageLogic(shape) { + + setHandler(shape.in1, new InHandler { + override def onPush(): Unit = push(shape.out1, grab(shape.in1)) + override def onUpstreamFinish(): Unit = complete(shape.out1) + override def onUpstreamFailure(ex: Throwable): Unit = fail(shape.out1, ex) + }) + setHandler(shape.in2, new InHandler { + override def onPush(): Unit = push(shape.out2, grab(shape.in2)) + override def onUpstreamFinish(): Unit = complete(shape.out2) + override def onUpstreamFailure(ex: Throwable): Unit = fail(shape.out2, ex) + }) + setHandler(shape.out1, new OutHandler { + override def onPull(): Unit = pull(shape.in1) + override def onDownstreamFinish(): Unit = cancel(shape.in1) + }) + setHandler(shape.out2, new OutHandler { + override def onPull(): Unit = pull(shape.in2) + override def onDownstreamFinish(): Unit = cancel(shape.in2) + }) + + override def preStart(): Unit = { + promise.success(new Breaker(getAsyncCallback[Operation] { + case Complete ⇒ completeStage() + case Fail(ex) ⇒ failStage(ex) + }.invoke)) + } + } + + (logic, promise.future) + } + } + + def bidiBreaker[T1, T2]: Graph[BidiShape[T1, T1, T2, T2], Future[Breaker]] = BidiBreaker.asInstanceOf[Graph[BidiShape[T1, T1, T2, T2], Future[Breaker]]] + private object TickSource { class TickSourceCancellable(cancelled: AtomicBoolean) extends Cancellable { private val cancelPromise = Promise[Unit]() diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala b/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala index 18b324244a..97b562beae 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala @@ -261,11 +261,13 @@ private[stream] object TcpConnectionStage { // (or half-close is turned off) if (isClosed(bytesOut) || !role.halfClose) connection ! Close // We still read, so we only close the write side - else connection ! ConfirmedClose + else if (connection != null) connection ! ConfirmedClose + else completeStage() } override def onUpstreamFailure(ex: Throwable): Unit = { - connection ! Abort + if (connection != null) connection ! Abort + else failStage(ex) } }) From 8d6bacab2f142be79bb93d0e1c580fe11352035a Mon Sep 17 00:00:00 2001 From: Roland Kuhn Date: Sat, 16 Jan 2016 20:05:09 +0100 Subject: [PATCH 2/5] #19451 use random port in WebsocketIntegrationSpec --- .../impl/engine/ws/WebsocketIntegrationSpec.scala | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketIntegrationSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketIntegrationSpec.scala index f389f88c72..59219ff313 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketIntegrationSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketIntegrationSpec.scala @@ -32,16 +32,18 @@ class WebsocketIntegrationSpec extends AkkaSpec("akka.stream.materializer.debug. "A Websocket server" must { "echo 100 elements and then shut down without error" in Utils.assertAllStagesStopped { + val bindingFuture = Http().bindAndHandleSync({ case HttpRequest(_, _, headers, _, _) ⇒ val upgrade = headers.collectFirst { case u: UpgradeToWebsocket ⇒ u }.get upgrade.handleMessages(Flow.apply, None) - }, interface = "localhost", port = 8080) + }, interface = "localhost", port = 0) val binding = Await.result(bindingFuture, 3.seconds) + val myPort = binding.localAddress.getPort val N = 100 val (response, count) = Http().singleWebsocketRequest( - WebsocketRequest("ws://127.0.0.1:8080"), + WebsocketRequest("ws://127.0.0.1:" + myPort), Flow.fromSinkAndSourceMat( Sink.fold(0)((n, _: Message) ⇒ n + 1), Source.repeat(TextMessage("hello")).take(N))(Keep.left)) @@ -73,18 +75,19 @@ class WebsocketIntegrationSpec extends AkkaSpec("akka.stream.materializer.debug. case HttpRequest(_, _, headers, _, _) ⇒ val upgrade = headers.collectFirst { case u: UpgradeToWebsocket ⇒ u }.get upgrade.handleMessages(handler, None) - }, interface = "localhost", port = 8080) + }, interface = "localhost", port = 0) val binding = Await.result(bindingFuture, 3.seconds) + val myPort = binding.localAddress.getPort @volatile var messages = 0 val (breaker, completion) = Source.maybe .viaMat { - Http().websocketClientLayer(WebsocketRequest("ws://localhost:8080")) + Http().websocketClientLayer(WebsocketRequest("ws://localhost:" + myPort)) .atop(SslTlsPlacebo.forScala) // the resource leak of #19398 existed only for severed websocket connections .atopMat(GraphStages.bidiBreaker[ByteString, ByteString])(Keep.right) - .join(Tcp().outgoingConnection(new InetSocketAddress("localhost", 8080), halfClose = true)) + .join(Tcp().outgoingConnection(new InetSocketAddress("localhost", myPort), halfClose = true)) }(Keep.right) .toMat(Sink.foreach(_ ⇒ messages += 1))(Keep.both) .run() From 5523e1d70f6b103322a2b42f60ecab7759af84e8 Mon Sep 17 00:00:00 2001 From: Roland Kuhn Date: Mon, 18 Jan 2016 20:20:18 +0100 Subject: [PATCH 3/5] #19467 avoid pulling closed port in ProtocolSwitch --- .../engine/server/HttpServerBluePrint.scala | 78 +++++++++++-------- .../akka/http/impl/engine/ws/Websocket.scala | 11 ++- .../engine/ws/WebsocketIntegrationSpec.scala | 72 ++++++++++++++++- .../akka/stream/impl/fusing/GraphStages.scala | 26 ++++++- 4 files changed, 147 insertions(+), 40 deletions(-) diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala index c0337777bb..970e0db865 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala @@ -433,46 +433,60 @@ private[http] object HttpServerBluePrint { case Right(messageHandler) ⇒ Websocket.stack(serverSide = true, maskingRandomFactory = settings.websocketRandomFactory, log = log).join(messageHandler) } + val sinkIn = new SubSinkInlet[ByteString]("FrameSink") - val sourceOut = new SubSourceOutlet[ByteString]("FrameSource") - - val timeoutKey = SubscriptionTimeout(() ⇒ { - sourceOut.timeout(timeout) - if (sourceOut.isClosed) completeStage() - }) - addTimeout(timeoutKey) - sinkIn.setHandler(new InHandler { override def onPush(): Unit = push(toNet, sinkIn.grab()) override def onUpstreamFinish(): Unit = complete(toNet) override def onUpstreamFailure(ex: Throwable): Unit = fail(toNet, ex) }) - setHandler(toNet, new OutHandler { - override def onPull(): Unit = sinkIn.pull() - override def onDownstreamFinish(): Unit = { - completeStage() - sinkIn.cancel() - sourceOut.complete() - } - }) - setHandler(fromNet, new InHandler { - override def onPush(): Unit = sourceOut.push(grab(fromNet).bytes) - override def onUpstreamFinish(): Unit = sourceOut.complete() - override def onUpstreamFailure(ex: Throwable): Unit = sourceOut.fail(ex) - }) - sourceOut.setHandler(new OutHandler { - override def onPull(): Unit = { - if (!hasBeenPulled(fromNet)) pull(fromNet) - cancelTimeout(timeoutKey) - sourceOut.setHandler(new OutHandler { - override def onPull(): Unit = if (!hasBeenPulled(fromNet)) pull(fromNet) - }) - } - override def onDownstreamFinish(): Unit = cancel(fromNet) - }) + if (isClosed(fromNet)) { + setHandler(toNet, new OutHandler { + override def onPull(): Unit = sinkIn.pull() + override def onDownstreamFinish(): Unit = { + completeStage() + sinkIn.cancel() + } + }) + Websocket.framing.join(frameHandler).runWith(Source.empty, sinkIn.sink)(subFusingMaterializer) + } else { + val sourceOut = new SubSourceOutlet[ByteString]("FrameSource") - Websocket.framing.join(frameHandler).runWith(sourceOut.source, sinkIn.sink)(subFusingMaterializer) + val timeoutKey = SubscriptionTimeout(() ⇒ { + sourceOut.timeout(timeout) + if (sourceOut.isClosed) completeStage() + }) + addTimeout(timeoutKey) + + setHandler(toNet, new OutHandler { + override def onPull(): Unit = sinkIn.pull() + override def onDownstreamFinish(): Unit = { + completeStage() + sinkIn.cancel() + sourceOut.complete() + } + }) + + setHandler(fromNet, new InHandler { + override def onPush(): Unit = sourceOut.push(grab(fromNet).bytes) + override def onUpstreamFinish(): Unit = sourceOut.complete() + override def onUpstreamFailure(ex: Throwable): Unit = sourceOut.fail(ex) + }) + sourceOut.setHandler(new OutHandler { + override def onPull(): Unit = { + if (!hasBeenPulled(fromNet)) pull(fromNet) + cancelTimeout(timeoutKey) + sourceOut.setHandler(new OutHandler { + override def onPull(): Unit = if (!hasBeenPulled(fromNet)) pull(fromNet) + override def onDownstreamFinish(): Unit = cancel(fromNet) + }) + } + override def onDownstreamFinish(): Unit = cancel(fromNet) + }) + + Websocket.framing.join(frameHandler).runWith(sourceOut.source, sinkIn.sink)(subFusingMaterializer) + } } } } diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Websocket.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Websocket.scala index a2d734471a..c98bac652f 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Websocket.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/Websocket.scala @@ -195,14 +195,19 @@ private[http] object Websocket { def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) { - passAlong(bypass, out, doFinish = true, doFail = true) - passAlong(user, out, doFinish = false, doFail = false) + class PassAlong[T <: AnyRef](from: Inlet[T]) extends InHandler with (() ⇒ Unit) { + override def apply(): Unit = tryPull(from) + override def onPush(): Unit = emit(out, grab(from), this) + override def onUpstreamFinish(): Unit = + if (isClosed(bypass) && isClosed(user)) completeStage() + } + setHandler(bypass, new PassAlong(bypass)) + setHandler(user, new PassAlong(user)) passAlong(tick, out, doFinish = false, doFail = false) setHandler(out, eagerTerminateOutput) override def preStart(): Unit = { - super.preStart() pull(bypass) pull(user) pull(tick) diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketIntegrationSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketIntegrationSpec.scala index 59219ff313..184ed34cf4 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketIntegrationSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketIntegrationSpec.scala @@ -21,6 +21,9 @@ import akka.stream.io.SslTlsPlacebo import java.net.InetSocketAddress import akka.stream.impl.fusing.GraphStages import akka.util.ByteString +import akka.http.scaladsl.model.StatusCodes +import akka.stream.testkit.scaladsl.TestSink +import scala.concurrent.Future class WebsocketIntegrationSpec extends AkkaSpec("akka.stream.materializer.debug.fuzzing-mode=off") with ScalaFutures with ConversionCheckedTripleEquals with Eventually { @@ -31,6 +34,73 @@ class WebsocketIntegrationSpec extends AkkaSpec("akka.stream.materializer.debug. "A Websocket server" must { + "not reset the connection when no data are flowing" in Utils.assertAllStagesStopped { + val source = TestPublisher.probe[Message]() + val bindingFuture = Http().bindAndHandleSync({ + case HttpRequest(_, _, headers, _, _) ⇒ + val upgrade = headers.collectFirst { case u: UpgradeToWebsocket ⇒ u }.get + upgrade.handleMessages(Flow.fromSinkAndSource(Sink.ignore, Source.fromPublisher(source)), None) + }, interface = "localhost", port = 0) + val binding = Await.result(bindingFuture, 3.seconds) + val myPort = binding.localAddress.getPort + + val (response, sink) = Http().singleWebsocketRequest( + WebsocketRequest("ws://127.0.0.1:" + myPort), + Flow.fromSinkAndSourceMat(TestSink.probe[Message], Source.empty)(Keep.left)) + + response.futureValue.response.status.isSuccess should ===(true) + sink + .request(10) + .expectNoMsg(500.millis) + + source + .sendNext(TextMessage("hello")) + .sendComplete() + sink + .expectNext(TextMessage("hello")) + .expectComplete() + + binding.unbind() + } + + "not reset the connection when no data are flowing and the connection is closed from the client" in Utils.assertAllStagesStopped { + val source = TestPublisher.probe[Message]() + val bindingFuture = Http().bindAndHandleSync({ + case HttpRequest(_, _, headers, _, _) ⇒ + val upgrade = headers.collectFirst { case u: UpgradeToWebsocket ⇒ u }.get + upgrade.handleMessages(Flow.fromSinkAndSource(Sink.ignore, Source.fromPublisher(source)), None) + }, interface = "localhost", port = 0) + val binding = Await.result(bindingFuture, 3.seconds) + val myPort = binding.localAddress.getPort + + val ((response, breaker), sink) = + Source.empty + .viaMat { + Http().websocketClientLayer(WebsocketRequest("ws://localhost:" + myPort)) + .atop(SslTlsPlacebo.forScala) + .joinMat(Flow.fromGraph(GraphStages.breaker[ByteString]).via( + Tcp().outgoingConnection(new InetSocketAddress("localhost", myPort), halfClose = true)))(Keep.both) + }(Keep.right) + .toMat(TestSink.probe[Message])(Keep.both) + .run() + + response.futureValue.response.status.isSuccess should ===(true) + sink + .request(10) + .expectNoMsg(1500.millis) + + breaker.value.get.get.complete() + + source + .sendNext(TextMessage("hello")) + .sendComplete() + sink + .expectNext(TextMessage("hello")) + .expectComplete() + + binding.unbind() + } + "echo 100 elements and then shut down without error" in Utils.assertAllStagesStopped { val bindingFuture = Http().bindAndHandleSync({ @@ -93,7 +163,7 @@ class WebsocketIntegrationSpec extends AkkaSpec("akka.stream.materializer.debug. .run() eventually(messages should ===(N)) // breaker should have been fulfilled long ago - breaker.value.get.get.complete() + breaker.value.get.get.completeAndCancel() completion.futureValue binding.unbind() diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala index 359a46c7ef..3680ab13ce 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala @@ -108,13 +108,19 @@ object GraphStages { final class Breaker(callback: Breaker.Operation ⇒ Unit) { import Breaker._ def complete(): Unit = callback(Complete) + def cancel(): Unit = callback(Cancel) def fail(ex: Throwable): Unit = callback(Fail(ex)) + def completeAndCancel(): Unit = callback(CompleteAndCancel) + def failAndCancel(ex: Throwable): Unit = callback(FailAndCancel(ex)) } object Breaker extends GraphStageWithMaterializedValue[FlowShape[Any, Any], Future[Breaker]] { sealed trait Operation case object Complete extends Operation + case object Cancel extends Operation case class Fail(ex: Throwable) extends Operation + case object CompleteAndCancel extends Operation + case class FailAndCancel(ex: Throwable) extends Operation override val initialAttributes = Attributes.name("breaker") override val shape = FlowShape(Inlet[Any]("breaker.in"), Outlet[Any]("breaker.out")) @@ -130,8 +136,11 @@ object GraphStages { override def preStart(): Unit = { pull(shape.in) promise.success(new Breaker(getAsyncCallback[Operation] { - case Complete ⇒ completeStage() - case Fail(ex) ⇒ failStage(ex) + case Complete ⇒ complete(shape.out) + case Cancel ⇒ cancel(shape.in) + case Fail(ex) ⇒ fail(shape.out, ex) + case CompleteAndCancel ⇒ completeStage() + case FailAndCancel(ex) ⇒ failStage(ex) }.invoke)) } } @@ -176,8 +185,17 @@ object GraphStages { override def preStart(): Unit = { promise.success(new Breaker(getAsyncCallback[Operation] { - case Complete ⇒ completeStage() - case Fail(ex) ⇒ failStage(ex) + case Complete ⇒ + complete(shape.out1) + complete(shape.out2) + case Cancel ⇒ + cancel(shape.in1) + cancel(shape.in2) + case Fail(ex) ⇒ + fail(shape.out1, ex) + fail(shape.out2, ex) + case CompleteAndCancel ⇒ completeStage() + case FailAndCancel(ex) ⇒ failStage(ex) }.invoke)) } } From dd388d838b9e3d0cb522d60df58e8beece0ebec7 Mon Sep 17 00:00:00 2001 From: Roland Kuhn Date: Mon, 18 Jan 2016 20:44:12 +0100 Subject: [PATCH 4/5] fix HttpResponseRenderer termination - upstream termination arriving after a pull() would lead to deadlock - downstream pull that goes to SubSink would lead to deadlock if followed by SubSink completion --- .../HttpResponseRendererFactory.scala | 35 +++++++++++++------ 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/HttpResponseRendererFactory.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/HttpResponseRendererFactory.scala index 7863046e8b..693bd18931 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/HttpResponseRendererFactory.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/rendering/HttpResponseRendererFactory.scala @@ -62,13 +62,13 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { - private[this] var closeMode: CloseMode = DontClose // signals what to do after the current response - private[this] def close: Boolean = closeMode != DontClose - private[this] def closeIf(cond: Boolean): Unit = - if (cond) closeMode = CloseConnection + var closeMode: CloseMode = DontClose // signals what to do after the current response + def close: Boolean = closeMode != DontClose + def closeIf(cond: Boolean): Unit = if (cond) closeMode = CloseConnection + var transferring = false setHandler(in, new InHandler { - def onPush(): Unit = + override def onPush(): Unit = render(grab(in)) match { case Strict(outElement) ⇒ push(out, outElement) @@ -76,23 +76,36 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser case Streamed(outStream) ⇒ transfer(outStream) } - override def onUpstreamFinish(): Unit = closeMode = CloseConnection + override def onUpstreamFinish(): Unit = + if (transferring) closeMode = CloseConnection + else completeStage() }) val waitForDemandHandler = new OutHandler { - def onPull(): Unit = if (close) completeStage() else pull(in) + def onPull(): Unit = pull(in) } setHandler(out, waitForDemandHandler) def transfer(outStream: Source[ResponseRenderingOutput, Any]): Unit = { + transferring = true val sinkIn = new SubSinkInlet[ResponseRenderingOutput]("RenderingSink") sinkIn.setHandler(new InHandler { - def onPush(): Unit = push(out, sinkIn.grab()) - override def onUpstreamFinish(): Unit = if (close) completeStage() else setHandler(out, waitForDemandHandler) + override def onPush(): Unit = push(out, sinkIn.grab()) + override def onUpstreamFinish(): Unit = + if (close) completeStage() + else { + transferring = false + setHandler(out, waitForDemandHandler) + if (isAvailable(out)) pull(in) + } }) setHandler(out, new OutHandler { - def onPull(): Unit = sinkIn.pull() + override def onPull(): Unit = sinkIn.pull() + override def onDownstreamFinish(): Unit = { + completeStage() + sinkIn.cancel() + } }) sinkIn.pull() - Source.fromGraph(outStream).runWith(sinkIn.sink)(interpreter.subFusingMaterializer) + outStream.runWith(sinkIn.sink)(interpreter.subFusingMaterializer) } def render(ctx: ResponseRenderingContext): StrictOrStreamed = { From 2cca0788ed2cce85fd3816fe5dcd395be92d2ec3 Mon Sep 17 00:00:00 2001 From: Roland Kuhn Date: Mon, 18 Jan 2016 21:11:10 +0100 Subject: [PATCH 5/5] #19503 fix closed-pull in ResponseParsingMerge also assert all stages stopped in HighLevelOutgoingConnectionSpec --- .../client/OutgoingConnectionBlueprint.scala | 28 ++++++++----------- .../HighLevelOutgoingConnectionSpec.scala | 25 +++++++++++------ 2 files changed, 27 insertions(+), 26 deletions(-) diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala index a597aec773..0c10c3975f 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala @@ -167,7 +167,6 @@ private[http] object OutgoingConnectionBlueprint { // each connection uses a single (private) response parser instance for all its responses // which builds a cache of all header instances seen on that connection val parser = rootParser.createShallowCopy() - var methodBypassCompleted = false var waitingForMethod = true setHandler(methodBypassInput, new InHandler { @@ -179,7 +178,6 @@ private[http] object OutgoingConnectionBlueprint { } override def onUpstreamFinish(): Unit = if (waitingForMethod) completeStage() - else methodBypassCompleted = true }) setHandler(dataInput, new InHandler { @@ -201,17 +199,16 @@ private[http] object OutgoingConnectionBlueprint { setHandler(out, eagerTerminateOutput) - val getNextMethod = () ⇒ - if (methodBypassCompleted) completeStage() - else { - pull(methodBypassInput) - waitingForMethod = true - } + val getNextMethod = () ⇒ { + waitingForMethod = true + if (isClosed(methodBypassInput)) completeStage() + else pull(methodBypassInput) + } val getNextData = () ⇒ { waitingForMethod = false - if (!isClosed(dataInput)) pull(dataInput) - else completeStage() + if (isClosed(dataInput)) completeStage() + else pull(dataInput) } @tailrec def drainParser(current: ResponseOutput, b: ListBuffer[ResponseOutput] = ListBuffer.empty): Unit = { @@ -219,13 +216,10 @@ private[http] object OutgoingConnectionBlueprint { if (output.nonEmpty) emit(out, output, andThen) else andThen() current match { - case NeedNextRequestMethod ⇒ - e(b.result(), getNextMethod) - case StreamEnd ⇒ - e(b.result(), () ⇒ completeStage()) - case NeedMoreData ⇒ - e(b.result(), getNextData) - case x ⇒ drainParser(parser.onPull(), b += x) + case NeedNextRequestMethod ⇒ e(b.result(), getNextMethod) + case StreamEnd ⇒ e(b.result(), () ⇒ completeStage()) + case NeedMoreData ⇒ e(b.result(), getNextData) + case x ⇒ drainParser(parser.onPull(), b += x) } } diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/client/HighLevelOutgoingConnectionSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/client/HighLevelOutgoingConnectionSpec.scala index 32c659eb2d..1004a7ae48 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/client/HighLevelOutgoingConnectionSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/client/HighLevelOutgoingConnectionSpec.scala @@ -11,16 +11,19 @@ import akka.stream.scaladsl._ import akka.stream.testkit.AkkaSpec import akka.http.scaladsl.{ Http, TestUtils } import akka.http.scaladsl.model._ +import akka.stream.testkit.Utils +import org.scalatest.concurrent.ScalaFutures -class HighLevelOutgoingConnectionSpec extends AkkaSpec { +class HighLevelOutgoingConnectionSpec extends AkkaSpec with ScalaFutures { implicit val materializer = ActorMaterializer() + implicit val patience = PatienceConfig(1.second) "The connection-level client implementation" should { - "be able to handle 100 pipelined requests across one connection" in { + "be able to handle 100 pipelined requests across one connection" in Utils.assertAllStagesStopped { val (_, serverHostName, serverPort) = TestUtils.temporaryServerHostnameAndPort() - Http().bindAndHandleSync(r ⇒ HttpResponse(entity = r.uri.toString.reverse.takeWhile(Character.isDigit).reverse), + val binding = Http().bindAndHandleSync(r ⇒ HttpResponse(entity = r.uri.toString.reverse.takeWhile(Character.isDigit).reverse), serverHostName, serverPort) val N = 100 @@ -32,13 +35,14 @@ class HighLevelOutgoingConnectionSpec extends AkkaSpec { .map { r ⇒ val s = r.data.utf8String; log.debug(s); s.toInt } .runFold(0)(_ + _) - Await.result(result, 10.seconds) shouldEqual N * (N + 1) / 2 + result.futureValue(PatienceConfig(10.seconds)) shouldEqual N * (N + 1) / 2 + binding.futureValue.unbind() } - "be able to handle 100 pipelined requests across 4 connections (client-flow is reusable)" in { + "be able to handle 100 pipelined requests across 4 connections (client-flow is reusable)" in Utils.assertAllStagesStopped { val (_, serverHostName, serverPort) = TestUtils.temporaryServerHostnameAndPort() - Http().bindAndHandleSync(r ⇒ HttpResponse(entity = r.uri.toString.reverse.takeWhile(Character.isDigit).reverse), + val binding = Http().bindAndHandleSync(r ⇒ HttpResponse(entity = r.uri.toString.reverse.takeWhile(Character.isDigit).reverse), serverHostName, serverPort) val connFlow = Http().outgoingConnection(serverHostName, serverPort) @@ -64,12 +68,14 @@ class HighLevelOutgoingConnectionSpec extends AkkaSpec { .map { r ⇒ val s = r.data.utf8String; log.debug(s); s.toInt } .runFold(0)(_ + _) - Await.result(result, 10.seconds) shouldEqual C * N * (N + 1) / 2 + result.futureValue(PatienceConfig(10.seconds)) shouldEqual C * N * (N + 1) / 2 + binding.futureValue.unbind() } - "catch response stream truncation" in { + "catch response stream truncation" in Utils.assertAllStagesStopped { val (_, serverHostName, serverPort) = TestUtils.temporaryServerHostnameAndPort() - Http().bindAndHandleSync({ + + val binding = Http().bindAndHandleSync({ case HttpRequest(_, Uri.Path("/b"), _, _, _) ⇒ HttpResponse(headers = List(headers.Connection("close"))) case _ ⇒ HttpResponse() }, serverHostName, serverPort) @@ -81,6 +87,7 @@ class HighLevelOutgoingConnectionSpec extends AkkaSpec { .runWith(Sink.head) a[One2OneBidiFlow.OutputTruncationException.type] should be thrownBy Await.result(x, 1.second) + binding.futureValue.unbind() } } }