#19398 fix stream leak in ProtocolSwitchStage
also fix potential NPE in TCP streams when failed or canceled early for an outgoing connection
This commit is contained in:
parent
8c1350b0d4
commit
e0361ece66
5 changed files with 209 additions and 57 deletions
|
|
@ -371,6 +371,11 @@ private[http] object HttpServerBluePrint {
|
||||||
def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) {
|
def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) {
|
||||||
import akka.http.impl.engine.rendering.ResponseRenderingOutput._
|
import akka.http.impl.engine.rendering.ResponseRenderingOutput._
|
||||||
|
|
||||||
|
/*
|
||||||
|
* These handlers are in charge until a switch command comes in, then they
|
||||||
|
* are replaced.
|
||||||
|
*/
|
||||||
|
|
||||||
setHandler(fromHttp, new InHandler {
|
setHandler(fromHttp, new InHandler {
|
||||||
override def onPush(): Unit =
|
override def onPush(): Unit =
|
||||||
grab(fromHttp) match {
|
grab(fromHttp) match {
|
||||||
|
|
@ -381,21 +386,22 @@ private[http] object HttpServerBluePrint {
|
||||||
cancel(fromHttp)
|
cancel(fromHttp)
|
||||||
switchToWebsocket(handlerFlow)
|
switchToWebsocket(handlerFlow)
|
||||||
}
|
}
|
||||||
|
override def onUpstreamFinish(): Unit = complete(toNet)
|
||||||
|
override def onUpstreamFailure(ex: Throwable): Unit = fail(toNet, ex)
|
||||||
})
|
})
|
||||||
setHandler(toNet, new OutHandler {
|
setHandler(toNet, new OutHandler {
|
||||||
override def onPull(): Unit = pull(fromHttp)
|
override def onPull(): Unit = pull(fromHttp)
|
||||||
|
override def onDownstreamFinish(): Unit = completeStage()
|
||||||
})
|
})
|
||||||
|
|
||||||
setHandler(fromNet, new InHandler {
|
setHandler(fromNet, new InHandler {
|
||||||
def onPush(): Unit = push(toHttp, grab(fromNet))
|
override def onPush(): Unit = push(toHttp, grab(fromNet))
|
||||||
|
override def onUpstreamFinish(): Unit = complete(toHttp)
|
||||||
// propagate error but don't close stage yet to prevent fromHttp/fromWs being cancelled
|
|
||||||
// too eagerly
|
|
||||||
override def onUpstreamFailure(ex: Throwable): Unit = fail(toHttp, ex)
|
override def onUpstreamFailure(ex: Throwable): Unit = fail(toHttp, ex)
|
||||||
})
|
})
|
||||||
setHandler(toHttp, new OutHandler {
|
setHandler(toHttp, new OutHandler {
|
||||||
override def onPull(): Unit = pull(fromNet)
|
override def onPull(): Unit = pull(fromNet)
|
||||||
override def onDownstreamFinish(): Unit = ()
|
override def onDownstreamFinish(): Unit = cancel(fromNet)
|
||||||
})
|
})
|
||||||
|
|
||||||
private var activeTimers = 0
|
private var activeTimers = 0
|
||||||
|
|
@ -438,13 +444,22 @@ private[http] object HttpServerBluePrint {
|
||||||
|
|
||||||
sinkIn.setHandler(new InHandler {
|
sinkIn.setHandler(new InHandler {
|
||||||
override def onPush(): Unit = push(toNet, sinkIn.grab())
|
override def onPush(): Unit = push(toNet, sinkIn.grab())
|
||||||
|
override def onUpstreamFinish(): Unit = complete(toNet)
|
||||||
|
override def onUpstreamFailure(ex: Throwable): Unit = fail(toNet, ex)
|
||||||
})
|
})
|
||||||
setHandler(toNet, new OutHandler {
|
setHandler(toNet, new OutHandler {
|
||||||
override def onPull(): Unit = sinkIn.pull()
|
override def onPull(): Unit = sinkIn.pull()
|
||||||
|
override def onDownstreamFinish(): Unit = {
|
||||||
|
completeStage()
|
||||||
|
sinkIn.cancel()
|
||||||
|
sourceOut.complete()
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
setHandler(fromNet, new InHandler {
|
setHandler(fromNet, new InHandler {
|
||||||
override def onPush(): Unit = sourceOut.push(grab(fromNet).bytes)
|
override def onPush(): Unit = sourceOut.push(grab(fromNet).bytes)
|
||||||
|
override def onUpstreamFinish(): Unit = sourceOut.complete()
|
||||||
|
override def onUpstreamFailure(ex: Throwable): Unit = sourceOut.fail(ex)
|
||||||
})
|
})
|
||||||
sourceOut.setHandler(new OutHandler {
|
sourceOut.setHandler(new OutHandler {
|
||||||
override def onPull(): Unit = {
|
override def onPull(): Unit = {
|
||||||
|
|
@ -454,6 +469,7 @@ private[http] object HttpServerBluePrint {
|
||||||
override def onPull(): Unit = if (!hasBeenPulled(fromNet)) pull(fromNet)
|
override def onPull(): Unit = if (!hasBeenPulled(fromNet)) pull(fromNet)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
override def onDownstreamFinish(): Unit = cancel(fromNet)
|
||||||
})
|
})
|
||||||
|
|
||||||
Websocket.framing.join(frameHandler).runWith(sourceOut.source, sinkIn.sink)(subFusingMaterializer)
|
Websocket.framing.join(frameHandler).runWith(sourceOut.source, sinkIn.sink)(subFusingMaterializer)
|
||||||
|
|
|
||||||
|
|
@ -1,50 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright (C) 2015 Typesafe Inc. <http://www.typesafe.com>
|
|
||||||
*/
|
|
||||||
package akka.http.impl.engine.ws
|
|
||||||
|
|
||||||
import akka.stream.testkit.AkkaSpec
|
|
||||||
import scala.concurrent.Await
|
|
||||||
import com.typesafe.config.ConfigFactory
|
|
||||||
import com.typesafe.config.Config
|
|
||||||
import akka.actor.ActorSystem
|
|
||||||
import akka.http.scaladsl.model.HttpRequest
|
|
||||||
import akka.http.scaladsl.model.ws._
|
|
||||||
import akka.http.scaladsl._
|
|
||||||
import akka.stream.scaladsl._
|
|
||||||
import akka.stream._
|
|
||||||
import scala.concurrent.duration._
|
|
||||||
import org.scalatest.concurrent.ScalaFutures
|
|
||||||
import org.scalactic.ConversionCheckedTripleEquals
|
|
||||||
import akka.stream.testkit.Utils
|
|
||||||
|
|
||||||
class BypassRouterSpec extends AkkaSpec("akka.stream.materializer.debug.fuzzing-mode = off") with ScalaFutures with ConversionCheckedTripleEquals {
|
|
||||||
|
|
||||||
implicit val patience = PatienceConfig(3.seconds)
|
|
||||||
import system.dispatcher
|
|
||||||
implicit val materializer = ActorMaterializer()
|
|
||||||
|
|
||||||
"BypassRouter" must {
|
|
||||||
|
|
||||||
"work without double pull-ing some ports" in Utils.assertAllStagesStopped {
|
|
||||||
val bindingFuture = Http().bindAndHandleSync({
|
|
||||||
case HttpRequest(_, _, headers, _, _) ⇒
|
|
||||||
val upgrade = headers.collectFirst { case u: UpgradeToWebsocket ⇒ u }.get
|
|
||||||
upgrade.handleMessages(Flow.apply, None)
|
|
||||||
}, interface = "localhost", port = 8080)
|
|
||||||
val binding = Await.result(bindingFuture, 3.seconds)
|
|
||||||
|
|
||||||
val N = 100
|
|
||||||
val (response, count) = Http().singleWebsocketRequest(
|
|
||||||
WebsocketRequest("ws://127.0.0.1:8080"),
|
|
||||||
Flow.fromSinkAndSourceMat(
|
|
||||||
Sink.fold(0)((n, _: Message) ⇒ n + 1),
|
|
||||||
Source.repeat(TextMessage("hello")).take(N))(Keep.left))
|
|
||||||
|
|
||||||
count.futureValue should ===(N)
|
|
||||||
binding.unbind()
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
@ -0,0 +1,101 @@
|
||||||
|
/**
|
||||||
|
* Copyright (C) 2015 Typesafe Inc. <http://www.typesafe.com>
|
||||||
|
*/
|
||||||
|
package akka.http.impl.engine.ws
|
||||||
|
|
||||||
|
import scala.concurrent.Await
|
||||||
|
import scala.concurrent.duration.DurationInt
|
||||||
|
import org.scalactic.ConversionCheckedTripleEquals
|
||||||
|
import org.scalatest.concurrent.ScalaFutures
|
||||||
|
import org.scalatest.time.Span.convertDurationToSpan
|
||||||
|
import akka.http.scaladsl.Http
|
||||||
|
import akka.http.scaladsl.model.HttpRequest
|
||||||
|
import akka.http.scaladsl.model.Uri.apply
|
||||||
|
import akka.http.scaladsl.model.ws._
|
||||||
|
import akka.stream._
|
||||||
|
import akka.stream.scaladsl._
|
||||||
|
import akka.stream.testkit._
|
||||||
|
import akka.stream.scaladsl.GraphDSL.Implicits._
|
||||||
|
import org.scalatest.concurrent.Eventually
|
||||||
|
import akka.stream.io.SslTlsPlacebo
|
||||||
|
import java.net.InetSocketAddress
|
||||||
|
import akka.stream.impl.fusing.GraphStages
|
||||||
|
import akka.util.ByteString
|
||||||
|
|
||||||
|
class WebsocketIntegrationSpec extends AkkaSpec("akka.stream.materializer.debug.fuzzing-mode=off")
|
||||||
|
with ScalaFutures with ConversionCheckedTripleEquals with Eventually {
|
||||||
|
|
||||||
|
implicit val patience = PatienceConfig(3.seconds)
|
||||||
|
import system.dispatcher
|
||||||
|
implicit val materializer = ActorMaterializer()
|
||||||
|
|
||||||
|
"A Websocket server" must {
|
||||||
|
|
||||||
|
"echo 100 elements and then shut down without error" in Utils.assertAllStagesStopped {
|
||||||
|
val bindingFuture = Http().bindAndHandleSync({
|
||||||
|
case HttpRequest(_, _, headers, _, _) ⇒
|
||||||
|
val upgrade = headers.collectFirst { case u: UpgradeToWebsocket ⇒ u }.get
|
||||||
|
upgrade.handleMessages(Flow.apply, None)
|
||||||
|
}, interface = "localhost", port = 8080)
|
||||||
|
val binding = Await.result(bindingFuture, 3.seconds)
|
||||||
|
|
||||||
|
val N = 100
|
||||||
|
val (response, count) = Http().singleWebsocketRequest(
|
||||||
|
WebsocketRequest("ws://127.0.0.1:8080"),
|
||||||
|
Flow.fromSinkAndSourceMat(
|
||||||
|
Sink.fold(0)((n, _: Message) ⇒ n + 1),
|
||||||
|
Source.repeat(TextMessage("hello")).take(N))(Keep.left))
|
||||||
|
|
||||||
|
count.futureValue should ===(N)
|
||||||
|
binding.unbind()
|
||||||
|
}
|
||||||
|
|
||||||
|
"send back 100 elements and then terminate without error even when not ordinarily closed" in Utils.assertAllStagesStopped {
|
||||||
|
val N = 100
|
||||||
|
|
||||||
|
val handler = Flow.fromGraph(GraphDSL.create() { implicit b ⇒
|
||||||
|
val merge = b.add(Merge[Int](2))
|
||||||
|
|
||||||
|
// convert to int so we can connect to merge
|
||||||
|
val mapMsgToInt = b.add(Flow[Message].map(_ ⇒ -1))
|
||||||
|
val mapIntToMsg = b.add(Flow[Int].map(x ⇒ TextMessage.Strict(s"Sending: $x")))
|
||||||
|
|
||||||
|
// source we want to use to send message to the connected websocket sink
|
||||||
|
val rangeSource = b.add(Source(1 to N))
|
||||||
|
|
||||||
|
mapMsgToInt ~> merge // this part of the merge will never provide msgs
|
||||||
|
rangeSource ~> merge ~> mapIntToMsg
|
||||||
|
|
||||||
|
FlowShape(mapMsgToInt.in, mapIntToMsg.out)
|
||||||
|
})
|
||||||
|
|
||||||
|
val bindingFuture = Http().bindAndHandleSync({
|
||||||
|
case HttpRequest(_, _, headers, _, _) ⇒
|
||||||
|
val upgrade = headers.collectFirst { case u: UpgradeToWebsocket ⇒ u }.get
|
||||||
|
upgrade.handleMessages(handler, None)
|
||||||
|
}, interface = "localhost", port = 8080)
|
||||||
|
val binding = Await.result(bindingFuture, 3.seconds)
|
||||||
|
|
||||||
|
@volatile var messages = 0
|
||||||
|
val (breaker, completion) =
|
||||||
|
Source.maybe
|
||||||
|
.viaMat {
|
||||||
|
Http().websocketClientLayer(WebsocketRequest("ws://localhost:8080"))
|
||||||
|
.atop(SslTlsPlacebo.forScala)
|
||||||
|
// the resource leak of #19398 existed only for severed websocket connections
|
||||||
|
.atopMat(GraphStages.bidiBreaker[ByteString, ByteString])(Keep.right)
|
||||||
|
.join(Tcp().outgoingConnection(new InetSocketAddress("localhost", 8080), halfClose = true))
|
||||||
|
}(Keep.right)
|
||||||
|
.toMat(Sink.foreach(_ ⇒ messages += 1))(Keep.both)
|
||||||
|
.run()
|
||||||
|
eventually(messages should ===(N))
|
||||||
|
// breaker should have been fulfilled long ago
|
||||||
|
breaker.value.get.get.complete()
|
||||||
|
completion.futureValue
|
||||||
|
|
||||||
|
binding.unbind()
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
@ -105,6 +105,89 @@ object GraphStages {
|
||||||
private val _detacher = new Detacher[Any]
|
private val _detacher = new Detacher[Any]
|
||||||
def detacher[T]: GraphStage[FlowShape[T, T]] = _detacher.asInstanceOf[GraphStage[FlowShape[T, T]]]
|
def detacher[T]: GraphStage[FlowShape[T, T]] = _detacher.asInstanceOf[GraphStage[FlowShape[T, T]]]
|
||||||
|
|
||||||
|
final class Breaker(callback: Breaker.Operation ⇒ Unit) {
|
||||||
|
import Breaker._
|
||||||
|
def complete(): Unit = callback(Complete)
|
||||||
|
def fail(ex: Throwable): Unit = callback(Fail(ex))
|
||||||
|
}
|
||||||
|
|
||||||
|
object Breaker extends GraphStageWithMaterializedValue[FlowShape[Any, Any], Future[Breaker]] {
|
||||||
|
sealed trait Operation
|
||||||
|
case object Complete extends Operation
|
||||||
|
case class Fail(ex: Throwable) extends Operation
|
||||||
|
|
||||||
|
override val initialAttributes = Attributes.name("breaker")
|
||||||
|
override val shape = FlowShape(Inlet[Any]("breaker.in"), Outlet[Any]("breaker.out"))
|
||||||
|
|
||||||
|
override def createLogicAndMaterializedValue(attr: Attributes) = {
|
||||||
|
val promise = Promise[Breaker]
|
||||||
|
|
||||||
|
val logic = new GraphStageLogic(shape) {
|
||||||
|
|
||||||
|
passAlong(shape.in, shape.out)
|
||||||
|
setHandler(shape.out, eagerTerminateOutput)
|
||||||
|
|
||||||
|
override def preStart(): Unit = {
|
||||||
|
pull(shape.in)
|
||||||
|
promise.success(new Breaker(getAsyncCallback[Operation] {
|
||||||
|
case Complete ⇒ completeStage()
|
||||||
|
case Fail(ex) ⇒ failStage(ex)
|
||||||
|
}.invoke))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(logic, promise.future)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def breaker[T]: Graph[FlowShape[T, T], Future[Breaker]] = Breaker.asInstanceOf[Graph[FlowShape[T, T], Future[Breaker]]]
|
||||||
|
|
||||||
|
object BidiBreaker extends GraphStageWithMaterializedValue[BidiShape[Any, Any, Any, Any], Future[Breaker]] {
|
||||||
|
import Breaker._
|
||||||
|
|
||||||
|
override val initialAttributes = Attributes.name("breaker")
|
||||||
|
override val shape = BidiShape(
|
||||||
|
Inlet[Any]("breaker.in1"), Outlet[Any]("breaker.out1"),
|
||||||
|
Inlet[Any]("breaker.in2"), Outlet[Any]("breaker.out2"))
|
||||||
|
|
||||||
|
override def createLogicAndMaterializedValue(attr: Attributes) = {
|
||||||
|
val promise = Promise[Breaker]
|
||||||
|
|
||||||
|
val logic = new GraphStageLogic(shape) {
|
||||||
|
|
||||||
|
setHandler(shape.in1, new InHandler {
|
||||||
|
override def onPush(): Unit = push(shape.out1, grab(shape.in1))
|
||||||
|
override def onUpstreamFinish(): Unit = complete(shape.out1)
|
||||||
|
override def onUpstreamFailure(ex: Throwable): Unit = fail(shape.out1, ex)
|
||||||
|
})
|
||||||
|
setHandler(shape.in2, new InHandler {
|
||||||
|
override def onPush(): Unit = push(shape.out2, grab(shape.in2))
|
||||||
|
override def onUpstreamFinish(): Unit = complete(shape.out2)
|
||||||
|
override def onUpstreamFailure(ex: Throwable): Unit = fail(shape.out2, ex)
|
||||||
|
})
|
||||||
|
setHandler(shape.out1, new OutHandler {
|
||||||
|
override def onPull(): Unit = pull(shape.in1)
|
||||||
|
override def onDownstreamFinish(): Unit = cancel(shape.in1)
|
||||||
|
})
|
||||||
|
setHandler(shape.out2, new OutHandler {
|
||||||
|
override def onPull(): Unit = pull(shape.in2)
|
||||||
|
override def onDownstreamFinish(): Unit = cancel(shape.in2)
|
||||||
|
})
|
||||||
|
|
||||||
|
override def preStart(): Unit = {
|
||||||
|
promise.success(new Breaker(getAsyncCallback[Operation] {
|
||||||
|
case Complete ⇒ completeStage()
|
||||||
|
case Fail(ex) ⇒ failStage(ex)
|
||||||
|
}.invoke))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(logic, promise.future)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def bidiBreaker[T1, T2]: Graph[BidiShape[T1, T1, T2, T2], Future[Breaker]] = BidiBreaker.asInstanceOf[Graph[BidiShape[T1, T1, T2, T2], Future[Breaker]]]
|
||||||
|
|
||||||
private object TickSource {
|
private object TickSource {
|
||||||
class TickSourceCancellable(cancelled: AtomicBoolean) extends Cancellable {
|
class TickSourceCancellable(cancelled: AtomicBoolean) extends Cancellable {
|
||||||
private val cancelPromise = Promise[Unit]()
|
private val cancelPromise = Promise[Unit]()
|
||||||
|
|
|
||||||
|
|
@ -261,11 +261,13 @@ private[stream] object TcpConnectionStage {
|
||||||
// (or half-close is turned off)
|
// (or half-close is turned off)
|
||||||
if (isClosed(bytesOut) || !role.halfClose) connection ! Close
|
if (isClosed(bytesOut) || !role.halfClose) connection ! Close
|
||||||
// We still read, so we only close the write side
|
// We still read, so we only close the write side
|
||||||
else connection ! ConfirmedClose
|
else if (connection != null) connection ! ConfirmedClose
|
||||||
|
else completeStage()
|
||||||
}
|
}
|
||||||
|
|
||||||
override def onUpstreamFailure(ex: Throwable): Unit = {
|
override def onUpstreamFailure(ex: Throwable): Unit = {
|
||||||
connection ! Abort
|
if (connection != null) connection ! Abort
|
||||||
|
else failStage(ex)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue