#19398 fix stream leak in ProtocolSwitchStage
also fix potential NPE in TCP streams when failed or canceled early for an outgoing connection
This commit is contained in:
parent
8c1350b0d4
commit
e0361ece66
5 changed files with 209 additions and 57 deletions
|
|
@ -371,6 +371,11 @@ private[http] object HttpServerBluePrint {
|
|||
def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) {
|
||||
import akka.http.impl.engine.rendering.ResponseRenderingOutput._
|
||||
|
||||
/*
|
||||
* These handlers are in charge until a switch command comes in, then they
|
||||
* are replaced.
|
||||
*/
|
||||
|
||||
setHandler(fromHttp, new InHandler {
|
||||
override def onPush(): Unit =
|
||||
grab(fromHttp) match {
|
||||
|
|
@ -381,21 +386,22 @@ private[http] object HttpServerBluePrint {
|
|||
cancel(fromHttp)
|
||||
switchToWebsocket(handlerFlow)
|
||||
}
|
||||
override def onUpstreamFinish(): Unit = complete(toNet)
|
||||
override def onUpstreamFailure(ex: Throwable): Unit = fail(toNet, ex)
|
||||
})
|
||||
setHandler(toNet, new OutHandler {
|
||||
override def onPull(): Unit = pull(fromHttp)
|
||||
override def onDownstreamFinish(): Unit = completeStage()
|
||||
})
|
||||
|
||||
setHandler(fromNet, new InHandler {
|
||||
def onPush(): Unit = push(toHttp, grab(fromNet))
|
||||
|
||||
// propagate error but don't close stage yet to prevent fromHttp/fromWs being cancelled
|
||||
// too eagerly
|
||||
override def onPush(): Unit = push(toHttp, grab(fromNet))
|
||||
override def onUpstreamFinish(): Unit = complete(toHttp)
|
||||
override def onUpstreamFailure(ex: Throwable): Unit = fail(toHttp, ex)
|
||||
})
|
||||
setHandler(toHttp, new OutHandler {
|
||||
override def onPull(): Unit = pull(fromNet)
|
||||
override def onDownstreamFinish(): Unit = ()
|
||||
override def onDownstreamFinish(): Unit = cancel(fromNet)
|
||||
})
|
||||
|
||||
private var activeTimers = 0
|
||||
|
|
@ -438,13 +444,22 @@ private[http] object HttpServerBluePrint {
|
|||
|
||||
sinkIn.setHandler(new InHandler {
|
||||
override def onPush(): Unit = push(toNet, sinkIn.grab())
|
||||
override def onUpstreamFinish(): Unit = complete(toNet)
|
||||
override def onUpstreamFailure(ex: Throwable): Unit = fail(toNet, ex)
|
||||
})
|
||||
setHandler(toNet, new OutHandler {
|
||||
override def onPull(): Unit = sinkIn.pull()
|
||||
override def onDownstreamFinish(): Unit = {
|
||||
completeStage()
|
||||
sinkIn.cancel()
|
||||
sourceOut.complete()
|
||||
}
|
||||
})
|
||||
|
||||
setHandler(fromNet, new InHandler {
|
||||
override def onPush(): Unit = sourceOut.push(grab(fromNet).bytes)
|
||||
override def onUpstreamFinish(): Unit = sourceOut.complete()
|
||||
override def onUpstreamFailure(ex: Throwable): Unit = sourceOut.fail(ex)
|
||||
})
|
||||
sourceOut.setHandler(new OutHandler {
|
||||
override def onPull(): Unit = {
|
||||
|
|
@ -454,6 +469,7 @@ private[http] object HttpServerBluePrint {
|
|||
override def onPull(): Unit = if (!hasBeenPulled(fromNet)) pull(fromNet)
|
||||
})
|
||||
}
|
||||
override def onDownstreamFinish(): Unit = cancel(fromNet)
|
||||
})
|
||||
|
||||
Websocket.framing.join(frameHandler).runWith(sourceOut.source, sinkIn.sink)(subFusingMaterializer)
|
||||
|
|
|
|||
|
|
@ -1,50 +0,0 @@
|
|||
/**
|
||||
* Copyright (C) 2015 Typesafe Inc. <http://www.typesafe.com>
|
||||
*/
|
||||
package akka.http.impl.engine.ws
|
||||
|
||||
import akka.stream.testkit.AkkaSpec
|
||||
import scala.concurrent.Await
|
||||
import com.typesafe.config.ConfigFactory
|
||||
import com.typesafe.config.Config
|
||||
import akka.actor.ActorSystem
|
||||
import akka.http.scaladsl.model.HttpRequest
|
||||
import akka.http.scaladsl.model.ws._
|
||||
import akka.http.scaladsl._
|
||||
import akka.stream.scaladsl._
|
||||
import akka.stream._
|
||||
import scala.concurrent.duration._
|
||||
import org.scalatest.concurrent.ScalaFutures
|
||||
import org.scalactic.ConversionCheckedTripleEquals
|
||||
import akka.stream.testkit.Utils
|
||||
|
||||
class BypassRouterSpec extends AkkaSpec("akka.stream.materializer.debug.fuzzing-mode = off") with ScalaFutures with ConversionCheckedTripleEquals {
|
||||
|
||||
implicit val patience = PatienceConfig(3.seconds)
|
||||
import system.dispatcher
|
||||
implicit val materializer = ActorMaterializer()
|
||||
|
||||
"BypassRouter" must {
|
||||
|
||||
"work without double pull-ing some ports" in Utils.assertAllStagesStopped {
|
||||
val bindingFuture = Http().bindAndHandleSync({
|
||||
case HttpRequest(_, _, headers, _, _) ⇒
|
||||
val upgrade = headers.collectFirst { case u: UpgradeToWebsocket ⇒ u }.get
|
||||
upgrade.handleMessages(Flow.apply, None)
|
||||
}, interface = "localhost", port = 8080)
|
||||
val binding = Await.result(bindingFuture, 3.seconds)
|
||||
|
||||
val N = 100
|
||||
val (response, count) = Http().singleWebsocketRequest(
|
||||
WebsocketRequest("ws://127.0.0.1:8080"),
|
||||
Flow.fromSinkAndSourceMat(
|
||||
Sink.fold(0)((n, _: Message) ⇒ n + 1),
|
||||
Source.repeat(TextMessage("hello")).take(N))(Keep.left))
|
||||
|
||||
count.futureValue should ===(N)
|
||||
binding.unbind()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,101 @@
|
|||
/**
|
||||
* Copyright (C) 2015 Typesafe Inc. <http://www.typesafe.com>
|
||||
*/
|
||||
package akka.http.impl.engine.ws
|
||||
|
||||
import scala.concurrent.Await
|
||||
import scala.concurrent.duration.DurationInt
|
||||
import org.scalactic.ConversionCheckedTripleEquals
|
||||
import org.scalatest.concurrent.ScalaFutures
|
||||
import org.scalatest.time.Span.convertDurationToSpan
|
||||
import akka.http.scaladsl.Http
|
||||
import akka.http.scaladsl.model.HttpRequest
|
||||
import akka.http.scaladsl.model.Uri.apply
|
||||
import akka.http.scaladsl.model.ws._
|
||||
import akka.stream._
|
||||
import akka.stream.scaladsl._
|
||||
import akka.stream.testkit._
|
||||
import akka.stream.scaladsl.GraphDSL.Implicits._
|
||||
import org.scalatest.concurrent.Eventually
|
||||
import akka.stream.io.SslTlsPlacebo
|
||||
import java.net.InetSocketAddress
|
||||
import akka.stream.impl.fusing.GraphStages
|
||||
import akka.util.ByteString
|
||||
|
||||
class WebsocketIntegrationSpec extends AkkaSpec("akka.stream.materializer.debug.fuzzing-mode=off")
|
||||
with ScalaFutures with ConversionCheckedTripleEquals with Eventually {
|
||||
|
||||
implicit val patience = PatienceConfig(3.seconds)
|
||||
import system.dispatcher
|
||||
implicit val materializer = ActorMaterializer()
|
||||
|
||||
"A Websocket server" must {
|
||||
|
||||
"echo 100 elements and then shut down without error" in Utils.assertAllStagesStopped {
|
||||
val bindingFuture = Http().bindAndHandleSync({
|
||||
case HttpRequest(_, _, headers, _, _) ⇒
|
||||
val upgrade = headers.collectFirst { case u: UpgradeToWebsocket ⇒ u }.get
|
||||
upgrade.handleMessages(Flow.apply, None)
|
||||
}, interface = "localhost", port = 8080)
|
||||
val binding = Await.result(bindingFuture, 3.seconds)
|
||||
|
||||
val N = 100
|
||||
val (response, count) = Http().singleWebsocketRequest(
|
||||
WebsocketRequest("ws://127.0.0.1:8080"),
|
||||
Flow.fromSinkAndSourceMat(
|
||||
Sink.fold(0)((n, _: Message) ⇒ n + 1),
|
||||
Source.repeat(TextMessage("hello")).take(N))(Keep.left))
|
||||
|
||||
count.futureValue should ===(N)
|
||||
binding.unbind()
|
||||
}
|
||||
|
||||
"send back 100 elements and then terminate without error even when not ordinarily closed" in Utils.assertAllStagesStopped {
|
||||
val N = 100
|
||||
|
||||
val handler = Flow.fromGraph(GraphDSL.create() { implicit b ⇒
|
||||
val merge = b.add(Merge[Int](2))
|
||||
|
||||
// convert to int so we can connect to merge
|
||||
val mapMsgToInt = b.add(Flow[Message].map(_ ⇒ -1))
|
||||
val mapIntToMsg = b.add(Flow[Int].map(x ⇒ TextMessage.Strict(s"Sending: $x")))
|
||||
|
||||
// source we want to use to send message to the connected websocket sink
|
||||
val rangeSource = b.add(Source(1 to N))
|
||||
|
||||
mapMsgToInt ~> merge // this part of the merge will never provide msgs
|
||||
rangeSource ~> merge ~> mapIntToMsg
|
||||
|
||||
FlowShape(mapMsgToInt.in, mapIntToMsg.out)
|
||||
})
|
||||
|
||||
val bindingFuture = Http().bindAndHandleSync({
|
||||
case HttpRequest(_, _, headers, _, _) ⇒
|
||||
val upgrade = headers.collectFirst { case u: UpgradeToWebsocket ⇒ u }.get
|
||||
upgrade.handleMessages(handler, None)
|
||||
}, interface = "localhost", port = 8080)
|
||||
val binding = Await.result(bindingFuture, 3.seconds)
|
||||
|
||||
@volatile var messages = 0
|
||||
val (breaker, completion) =
|
||||
Source.maybe
|
||||
.viaMat {
|
||||
Http().websocketClientLayer(WebsocketRequest("ws://localhost:8080"))
|
||||
.atop(SslTlsPlacebo.forScala)
|
||||
// the resource leak of #19398 existed only for severed websocket connections
|
||||
.atopMat(GraphStages.bidiBreaker[ByteString, ByteString])(Keep.right)
|
||||
.join(Tcp().outgoingConnection(new InetSocketAddress("localhost", 8080), halfClose = true))
|
||||
}(Keep.right)
|
||||
.toMat(Sink.foreach(_ ⇒ messages += 1))(Keep.both)
|
||||
.run()
|
||||
eventually(messages should ===(N))
|
||||
// breaker should have been fulfilled long ago
|
||||
breaker.value.get.get.complete()
|
||||
completion.futureValue
|
||||
|
||||
binding.unbind()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -105,6 +105,89 @@ object GraphStages {
|
|||
private val _detacher = new Detacher[Any]
|
||||
def detacher[T]: GraphStage[FlowShape[T, T]] = _detacher.asInstanceOf[GraphStage[FlowShape[T, T]]]
|
||||
|
||||
final class Breaker(callback: Breaker.Operation ⇒ Unit) {
|
||||
import Breaker._
|
||||
def complete(): Unit = callback(Complete)
|
||||
def fail(ex: Throwable): Unit = callback(Fail(ex))
|
||||
}
|
||||
|
||||
object Breaker extends GraphStageWithMaterializedValue[FlowShape[Any, Any], Future[Breaker]] {
|
||||
sealed trait Operation
|
||||
case object Complete extends Operation
|
||||
case class Fail(ex: Throwable) extends Operation
|
||||
|
||||
override val initialAttributes = Attributes.name("breaker")
|
||||
override val shape = FlowShape(Inlet[Any]("breaker.in"), Outlet[Any]("breaker.out"))
|
||||
|
||||
override def createLogicAndMaterializedValue(attr: Attributes) = {
|
||||
val promise = Promise[Breaker]
|
||||
|
||||
val logic = new GraphStageLogic(shape) {
|
||||
|
||||
passAlong(shape.in, shape.out)
|
||||
setHandler(shape.out, eagerTerminateOutput)
|
||||
|
||||
override def preStart(): Unit = {
|
||||
pull(shape.in)
|
||||
promise.success(new Breaker(getAsyncCallback[Operation] {
|
||||
case Complete ⇒ completeStage()
|
||||
case Fail(ex) ⇒ failStage(ex)
|
||||
}.invoke))
|
||||
}
|
||||
}
|
||||
|
||||
(logic, promise.future)
|
||||
}
|
||||
}
|
||||
|
||||
def breaker[T]: Graph[FlowShape[T, T], Future[Breaker]] = Breaker.asInstanceOf[Graph[FlowShape[T, T], Future[Breaker]]]
|
||||
|
||||
object BidiBreaker extends GraphStageWithMaterializedValue[BidiShape[Any, Any, Any, Any], Future[Breaker]] {
|
||||
import Breaker._
|
||||
|
||||
override val initialAttributes = Attributes.name("breaker")
|
||||
override val shape = BidiShape(
|
||||
Inlet[Any]("breaker.in1"), Outlet[Any]("breaker.out1"),
|
||||
Inlet[Any]("breaker.in2"), Outlet[Any]("breaker.out2"))
|
||||
|
||||
override def createLogicAndMaterializedValue(attr: Attributes) = {
|
||||
val promise = Promise[Breaker]
|
||||
|
||||
val logic = new GraphStageLogic(shape) {
|
||||
|
||||
setHandler(shape.in1, new InHandler {
|
||||
override def onPush(): Unit = push(shape.out1, grab(shape.in1))
|
||||
override def onUpstreamFinish(): Unit = complete(shape.out1)
|
||||
override def onUpstreamFailure(ex: Throwable): Unit = fail(shape.out1, ex)
|
||||
})
|
||||
setHandler(shape.in2, new InHandler {
|
||||
override def onPush(): Unit = push(shape.out2, grab(shape.in2))
|
||||
override def onUpstreamFinish(): Unit = complete(shape.out2)
|
||||
override def onUpstreamFailure(ex: Throwable): Unit = fail(shape.out2, ex)
|
||||
})
|
||||
setHandler(shape.out1, new OutHandler {
|
||||
override def onPull(): Unit = pull(shape.in1)
|
||||
override def onDownstreamFinish(): Unit = cancel(shape.in1)
|
||||
})
|
||||
setHandler(shape.out2, new OutHandler {
|
||||
override def onPull(): Unit = pull(shape.in2)
|
||||
override def onDownstreamFinish(): Unit = cancel(shape.in2)
|
||||
})
|
||||
|
||||
override def preStart(): Unit = {
|
||||
promise.success(new Breaker(getAsyncCallback[Operation] {
|
||||
case Complete ⇒ completeStage()
|
||||
case Fail(ex) ⇒ failStage(ex)
|
||||
}.invoke))
|
||||
}
|
||||
}
|
||||
|
||||
(logic, promise.future)
|
||||
}
|
||||
}
|
||||
|
||||
def bidiBreaker[T1, T2]: Graph[BidiShape[T1, T1, T2, T2], Future[Breaker]] = BidiBreaker.asInstanceOf[Graph[BidiShape[T1, T1, T2, T2], Future[Breaker]]]
|
||||
|
||||
private object TickSource {
|
||||
class TickSourceCancellable(cancelled: AtomicBoolean) extends Cancellable {
|
||||
private val cancelPromise = Promise[Unit]()
|
||||
|
|
|
|||
|
|
@ -261,11 +261,13 @@ private[stream] object TcpConnectionStage {
|
|||
// (or half-close is turned off)
|
||||
if (isClosed(bytesOut) || !role.halfClose) connection ! Close
|
||||
// We still read, so we only close the write side
|
||||
else connection ! ConfirmedClose
|
||||
else if (connection != null) connection ! ConfirmedClose
|
||||
else completeStage()
|
||||
}
|
||||
|
||||
override def onUpstreamFailure(ex: Throwable): Unit = {
|
||||
connection ! Abort
|
||||
if (connection != null) connection ! Abort
|
||||
else failStage(ex)
|
||||
}
|
||||
})
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue