#19398 fix stream leak in ProtocolSwitchStage

also fix potential NPE in TCP streams when failed or canceled early for
an outgoing connection
This commit is contained in:
Roland Kuhn 2016-01-13 17:52:46 +01:00
parent 8c1350b0d4
commit e0361ece66
5 changed files with 209 additions and 57 deletions

View file

@ -371,6 +371,11 @@ private[http] object HttpServerBluePrint {
def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) { def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) {
import akka.http.impl.engine.rendering.ResponseRenderingOutput._ import akka.http.impl.engine.rendering.ResponseRenderingOutput._
/*
* These handlers are in charge until a switch command comes in, then they
* are replaced.
*/
setHandler(fromHttp, new InHandler { setHandler(fromHttp, new InHandler {
override def onPush(): Unit = override def onPush(): Unit =
grab(fromHttp) match { grab(fromHttp) match {
@ -381,21 +386,22 @@ private[http] object HttpServerBluePrint {
cancel(fromHttp) cancel(fromHttp)
switchToWebsocket(handlerFlow) switchToWebsocket(handlerFlow)
} }
override def onUpstreamFinish(): Unit = complete(toNet)
override def onUpstreamFailure(ex: Throwable): Unit = fail(toNet, ex)
}) })
setHandler(toNet, new OutHandler { setHandler(toNet, new OutHandler {
override def onPull(): Unit = pull(fromHttp) override def onPull(): Unit = pull(fromHttp)
override def onDownstreamFinish(): Unit = completeStage()
}) })
setHandler(fromNet, new InHandler { setHandler(fromNet, new InHandler {
def onPush(): Unit = push(toHttp, grab(fromNet)) override def onPush(): Unit = push(toHttp, grab(fromNet))
override def onUpstreamFinish(): Unit = complete(toHttp)
// propagate error but don't close stage yet to prevent fromHttp/fromWs being cancelled
// too eagerly
override def onUpstreamFailure(ex: Throwable): Unit = fail(toHttp, ex) override def onUpstreamFailure(ex: Throwable): Unit = fail(toHttp, ex)
}) })
setHandler(toHttp, new OutHandler { setHandler(toHttp, new OutHandler {
override def onPull(): Unit = pull(fromNet) override def onPull(): Unit = pull(fromNet)
override def onDownstreamFinish(): Unit = () override def onDownstreamFinish(): Unit = cancel(fromNet)
}) })
private var activeTimers = 0 private var activeTimers = 0
@ -438,13 +444,22 @@ private[http] object HttpServerBluePrint {
sinkIn.setHandler(new InHandler { sinkIn.setHandler(new InHandler {
override def onPush(): Unit = push(toNet, sinkIn.grab()) override def onPush(): Unit = push(toNet, sinkIn.grab())
override def onUpstreamFinish(): Unit = complete(toNet)
override def onUpstreamFailure(ex: Throwable): Unit = fail(toNet, ex)
}) })
setHandler(toNet, new OutHandler { setHandler(toNet, new OutHandler {
override def onPull(): Unit = sinkIn.pull() override def onPull(): Unit = sinkIn.pull()
override def onDownstreamFinish(): Unit = {
completeStage()
sinkIn.cancel()
sourceOut.complete()
}
}) })
setHandler(fromNet, new InHandler { setHandler(fromNet, new InHandler {
override def onPush(): Unit = sourceOut.push(grab(fromNet).bytes) override def onPush(): Unit = sourceOut.push(grab(fromNet).bytes)
override def onUpstreamFinish(): Unit = sourceOut.complete()
override def onUpstreamFailure(ex: Throwable): Unit = sourceOut.fail(ex)
}) })
sourceOut.setHandler(new OutHandler { sourceOut.setHandler(new OutHandler {
override def onPull(): Unit = { override def onPull(): Unit = {
@ -454,6 +469,7 @@ private[http] object HttpServerBluePrint {
override def onPull(): Unit = if (!hasBeenPulled(fromNet)) pull(fromNet) override def onPull(): Unit = if (!hasBeenPulled(fromNet)) pull(fromNet)
}) })
} }
override def onDownstreamFinish(): Unit = cancel(fromNet)
}) })
Websocket.framing.join(frameHandler).runWith(sourceOut.source, sinkIn.sink)(subFusingMaterializer) Websocket.framing.join(frameHandler).runWith(sourceOut.source, sinkIn.sink)(subFusingMaterializer)

View file

@ -1,50 +0,0 @@
/**
* Copyright (C) 2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.impl.engine.ws
import akka.stream.testkit.AkkaSpec
import scala.concurrent.Await
import com.typesafe.config.ConfigFactory
import com.typesafe.config.Config
import akka.actor.ActorSystem
import akka.http.scaladsl.model.HttpRequest
import akka.http.scaladsl.model.ws._
import akka.http.scaladsl._
import akka.stream.scaladsl._
import akka.stream._
import scala.concurrent.duration._
import org.scalatest.concurrent.ScalaFutures
import org.scalactic.ConversionCheckedTripleEquals
import akka.stream.testkit.Utils
class BypassRouterSpec extends AkkaSpec("akka.stream.materializer.debug.fuzzing-mode = off") with ScalaFutures with ConversionCheckedTripleEquals {
implicit val patience = PatienceConfig(3.seconds)
import system.dispatcher
implicit val materializer = ActorMaterializer()
"BypassRouter" must {
"work without double pull-ing some ports" in Utils.assertAllStagesStopped {
val bindingFuture = Http().bindAndHandleSync({
case HttpRequest(_, _, headers, _, _)
val upgrade = headers.collectFirst { case u: UpgradeToWebsocket u }.get
upgrade.handleMessages(Flow.apply, None)
}, interface = "localhost", port = 8080)
val binding = Await.result(bindingFuture, 3.seconds)
val N = 100
val (response, count) = Http().singleWebsocketRequest(
WebsocketRequest("ws://127.0.0.1:8080"),
Flow.fromSinkAndSourceMat(
Sink.fold(0)((n, _: Message) n + 1),
Source.repeat(TextMessage("hello")).take(N))(Keep.left))
count.futureValue should ===(N)
binding.unbind()
}
}
}

View file

@ -0,0 +1,101 @@
/**
* Copyright (C) 2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.impl.engine.ws
import scala.concurrent.Await
import scala.concurrent.duration.DurationInt
import org.scalactic.ConversionCheckedTripleEquals
import org.scalatest.concurrent.ScalaFutures
import org.scalatest.time.Span.convertDurationToSpan
import akka.http.scaladsl.Http
import akka.http.scaladsl.model.HttpRequest
import akka.http.scaladsl.model.Uri.apply
import akka.http.scaladsl.model.ws._
import akka.stream._
import akka.stream.scaladsl._
import akka.stream.testkit._
import akka.stream.scaladsl.GraphDSL.Implicits._
import org.scalatest.concurrent.Eventually
import akka.stream.io.SslTlsPlacebo
import java.net.InetSocketAddress
import akka.stream.impl.fusing.GraphStages
import akka.util.ByteString
class WebsocketIntegrationSpec extends AkkaSpec("akka.stream.materializer.debug.fuzzing-mode=off")
with ScalaFutures with ConversionCheckedTripleEquals with Eventually {
implicit val patience = PatienceConfig(3.seconds)
import system.dispatcher
implicit val materializer = ActorMaterializer()
"A Websocket server" must {
"echo 100 elements and then shut down without error" in Utils.assertAllStagesStopped {
val bindingFuture = Http().bindAndHandleSync({
case HttpRequest(_, _, headers, _, _)
val upgrade = headers.collectFirst { case u: UpgradeToWebsocket u }.get
upgrade.handleMessages(Flow.apply, None)
}, interface = "localhost", port = 8080)
val binding = Await.result(bindingFuture, 3.seconds)
val N = 100
val (response, count) = Http().singleWebsocketRequest(
WebsocketRequest("ws://127.0.0.1:8080"),
Flow.fromSinkAndSourceMat(
Sink.fold(0)((n, _: Message) n + 1),
Source.repeat(TextMessage("hello")).take(N))(Keep.left))
count.futureValue should ===(N)
binding.unbind()
}
"send back 100 elements and then terminate without error even when not ordinarily closed" in Utils.assertAllStagesStopped {
val N = 100
val handler = Flow.fromGraph(GraphDSL.create() { implicit b
val merge = b.add(Merge[Int](2))
// convert to int so we can connect to merge
val mapMsgToInt = b.add(Flow[Message].map(_ -1))
val mapIntToMsg = b.add(Flow[Int].map(x TextMessage.Strict(s"Sending: $x")))
// source we want to use to send message to the connected websocket sink
val rangeSource = b.add(Source(1 to N))
mapMsgToInt ~> merge // this part of the merge will never provide msgs
rangeSource ~> merge ~> mapIntToMsg
FlowShape(mapMsgToInt.in, mapIntToMsg.out)
})
val bindingFuture = Http().bindAndHandleSync({
case HttpRequest(_, _, headers, _, _)
val upgrade = headers.collectFirst { case u: UpgradeToWebsocket u }.get
upgrade.handleMessages(handler, None)
}, interface = "localhost", port = 8080)
val binding = Await.result(bindingFuture, 3.seconds)
@volatile var messages = 0
val (breaker, completion) =
Source.maybe
.viaMat {
Http().websocketClientLayer(WebsocketRequest("ws://localhost:8080"))
.atop(SslTlsPlacebo.forScala)
// the resource leak of #19398 existed only for severed websocket connections
.atopMat(GraphStages.bidiBreaker[ByteString, ByteString])(Keep.right)
.join(Tcp().outgoingConnection(new InetSocketAddress("localhost", 8080), halfClose = true))
}(Keep.right)
.toMat(Sink.foreach(_ messages += 1))(Keep.both)
.run()
eventually(messages should ===(N))
// breaker should have been fulfilled long ago
breaker.value.get.get.complete()
completion.futureValue
binding.unbind()
}
}
}

View file

@ -105,6 +105,89 @@ object GraphStages {
private val _detacher = new Detacher[Any] private val _detacher = new Detacher[Any]
def detacher[T]: GraphStage[FlowShape[T, T]] = _detacher.asInstanceOf[GraphStage[FlowShape[T, T]]] def detacher[T]: GraphStage[FlowShape[T, T]] = _detacher.asInstanceOf[GraphStage[FlowShape[T, T]]]
final class Breaker(callback: Breaker.Operation Unit) {
import Breaker._
def complete(): Unit = callback(Complete)
def fail(ex: Throwable): Unit = callback(Fail(ex))
}
object Breaker extends GraphStageWithMaterializedValue[FlowShape[Any, Any], Future[Breaker]] {
sealed trait Operation
case object Complete extends Operation
case class Fail(ex: Throwable) extends Operation
override val initialAttributes = Attributes.name("breaker")
override val shape = FlowShape(Inlet[Any]("breaker.in"), Outlet[Any]("breaker.out"))
override def createLogicAndMaterializedValue(attr: Attributes) = {
val promise = Promise[Breaker]
val logic = new GraphStageLogic(shape) {
passAlong(shape.in, shape.out)
setHandler(shape.out, eagerTerminateOutput)
override def preStart(): Unit = {
pull(shape.in)
promise.success(new Breaker(getAsyncCallback[Operation] {
case Complete completeStage()
case Fail(ex) failStage(ex)
}.invoke))
}
}
(logic, promise.future)
}
}
def breaker[T]: Graph[FlowShape[T, T], Future[Breaker]] = Breaker.asInstanceOf[Graph[FlowShape[T, T], Future[Breaker]]]
object BidiBreaker extends GraphStageWithMaterializedValue[BidiShape[Any, Any, Any, Any], Future[Breaker]] {
import Breaker._
override val initialAttributes = Attributes.name("breaker")
override val shape = BidiShape(
Inlet[Any]("breaker.in1"), Outlet[Any]("breaker.out1"),
Inlet[Any]("breaker.in2"), Outlet[Any]("breaker.out2"))
override def createLogicAndMaterializedValue(attr: Attributes) = {
val promise = Promise[Breaker]
val logic = new GraphStageLogic(shape) {
setHandler(shape.in1, new InHandler {
override def onPush(): Unit = push(shape.out1, grab(shape.in1))
override def onUpstreamFinish(): Unit = complete(shape.out1)
override def onUpstreamFailure(ex: Throwable): Unit = fail(shape.out1, ex)
})
setHandler(shape.in2, new InHandler {
override def onPush(): Unit = push(shape.out2, grab(shape.in2))
override def onUpstreamFinish(): Unit = complete(shape.out2)
override def onUpstreamFailure(ex: Throwable): Unit = fail(shape.out2, ex)
})
setHandler(shape.out1, new OutHandler {
override def onPull(): Unit = pull(shape.in1)
override def onDownstreamFinish(): Unit = cancel(shape.in1)
})
setHandler(shape.out2, new OutHandler {
override def onPull(): Unit = pull(shape.in2)
override def onDownstreamFinish(): Unit = cancel(shape.in2)
})
override def preStart(): Unit = {
promise.success(new Breaker(getAsyncCallback[Operation] {
case Complete completeStage()
case Fail(ex) failStage(ex)
}.invoke))
}
}
(logic, promise.future)
}
}
def bidiBreaker[T1, T2]: Graph[BidiShape[T1, T1, T2, T2], Future[Breaker]] = BidiBreaker.asInstanceOf[Graph[BidiShape[T1, T1, T2, T2], Future[Breaker]]]
private object TickSource { private object TickSource {
class TickSourceCancellable(cancelled: AtomicBoolean) extends Cancellable { class TickSourceCancellable(cancelled: AtomicBoolean) extends Cancellable {
private val cancelPromise = Promise[Unit]() private val cancelPromise = Promise[Unit]()

View file

@ -261,11 +261,13 @@ private[stream] object TcpConnectionStage {
// (or half-close is turned off) // (or half-close is turned off)
if (isClosed(bytesOut) || !role.halfClose) connection ! Close if (isClosed(bytesOut) || !role.halfClose) connection ! Close
// We still read, so we only close the write side // We still read, so we only close the write side
else connection ! ConfirmedClose else if (connection != null) connection ! ConfirmedClose
else completeStage()
} }
override def onUpstreamFailure(ex: Throwable): Unit = { override def onUpstreamFailure(ex: Throwable): Unit = {
connection ! Abort if (connection != null) connection ! Abort
else failStage(ex)
} }
}) })