#19398 fix stream leak in ProtocolSwitchStage

also fix potential NPE in TCP streams when failed or canceled early for
an outgoing connection
This commit is contained in:
Roland Kuhn 2016-01-13 17:52:46 +01:00
parent 8c1350b0d4
commit e0361ece66
5 changed files with 209 additions and 57 deletions

View file

@ -371,6 +371,11 @@ private[http] object HttpServerBluePrint {
def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) {
import akka.http.impl.engine.rendering.ResponseRenderingOutput._
/*
* These handlers are in charge until a switch command comes in, then they
* are replaced.
*/
setHandler(fromHttp, new InHandler {
override def onPush(): Unit =
grab(fromHttp) match {
@ -381,21 +386,22 @@ private[http] object HttpServerBluePrint {
cancel(fromHttp)
switchToWebsocket(handlerFlow)
}
override def onUpstreamFinish(): Unit = complete(toNet)
override def onUpstreamFailure(ex: Throwable): Unit = fail(toNet, ex)
})
setHandler(toNet, new OutHandler {
override def onPull(): Unit = pull(fromHttp)
override def onDownstreamFinish(): Unit = completeStage()
})
setHandler(fromNet, new InHandler {
def onPush(): Unit = push(toHttp, grab(fromNet))
// propagate error but don't close stage yet to prevent fromHttp/fromWs being cancelled
// too eagerly
override def onPush(): Unit = push(toHttp, grab(fromNet))
override def onUpstreamFinish(): Unit = complete(toHttp)
override def onUpstreamFailure(ex: Throwable): Unit = fail(toHttp, ex)
})
setHandler(toHttp, new OutHandler {
override def onPull(): Unit = pull(fromNet)
override def onDownstreamFinish(): Unit = ()
override def onDownstreamFinish(): Unit = cancel(fromNet)
})
private var activeTimers = 0
@ -438,13 +444,22 @@ private[http] object HttpServerBluePrint {
sinkIn.setHandler(new InHandler {
override def onPush(): Unit = push(toNet, sinkIn.grab())
override def onUpstreamFinish(): Unit = complete(toNet)
override def onUpstreamFailure(ex: Throwable): Unit = fail(toNet, ex)
})
setHandler(toNet, new OutHandler {
override def onPull(): Unit = sinkIn.pull()
override def onDownstreamFinish(): Unit = {
completeStage()
sinkIn.cancel()
sourceOut.complete()
}
})
setHandler(fromNet, new InHandler {
override def onPush(): Unit = sourceOut.push(grab(fromNet).bytes)
override def onUpstreamFinish(): Unit = sourceOut.complete()
override def onUpstreamFailure(ex: Throwable): Unit = sourceOut.fail(ex)
})
sourceOut.setHandler(new OutHandler {
override def onPull(): Unit = {
@ -454,6 +469,7 @@ private[http] object HttpServerBluePrint {
override def onPull(): Unit = if (!hasBeenPulled(fromNet)) pull(fromNet)
})
}
override def onDownstreamFinish(): Unit = cancel(fromNet)
})
Websocket.framing.join(frameHandler).runWith(sourceOut.source, sinkIn.sink)(subFusingMaterializer)

View file

@ -1,50 +0,0 @@
/**
* Copyright (C) 2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.impl.engine.ws
import akka.stream.testkit.AkkaSpec
import scala.concurrent.Await
import com.typesafe.config.ConfigFactory
import com.typesafe.config.Config
import akka.actor.ActorSystem
import akka.http.scaladsl.model.HttpRequest
import akka.http.scaladsl.model.ws._
import akka.http.scaladsl._
import akka.stream.scaladsl._
import akka.stream._
import scala.concurrent.duration._
import org.scalatest.concurrent.ScalaFutures
import org.scalactic.ConversionCheckedTripleEquals
import akka.stream.testkit.Utils
class BypassRouterSpec extends AkkaSpec("akka.stream.materializer.debug.fuzzing-mode = off") with ScalaFutures with ConversionCheckedTripleEquals {
implicit val patience = PatienceConfig(3.seconds)
import system.dispatcher
implicit val materializer = ActorMaterializer()
"BypassRouter" must {
"work without double pull-ing some ports" in Utils.assertAllStagesStopped {
val bindingFuture = Http().bindAndHandleSync({
case HttpRequest(_, _, headers, _, _)
val upgrade = headers.collectFirst { case u: UpgradeToWebsocket u }.get
upgrade.handleMessages(Flow.apply, None)
}, interface = "localhost", port = 8080)
val binding = Await.result(bindingFuture, 3.seconds)
val N = 100
val (response, count) = Http().singleWebsocketRequest(
WebsocketRequest("ws://127.0.0.1:8080"),
Flow.fromSinkAndSourceMat(
Sink.fold(0)((n, _: Message) n + 1),
Source.repeat(TextMessage("hello")).take(N))(Keep.left))
count.futureValue should ===(N)
binding.unbind()
}
}
}

View file

@ -0,0 +1,101 @@
/**
* Copyright (C) 2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.impl.engine.ws
import scala.concurrent.Await
import scala.concurrent.duration.DurationInt
import org.scalactic.ConversionCheckedTripleEquals
import org.scalatest.concurrent.ScalaFutures
import org.scalatest.time.Span.convertDurationToSpan
import akka.http.scaladsl.Http
import akka.http.scaladsl.model.HttpRequest
import akka.http.scaladsl.model.Uri.apply
import akka.http.scaladsl.model.ws._
import akka.stream._
import akka.stream.scaladsl._
import akka.stream.testkit._
import akka.stream.scaladsl.GraphDSL.Implicits._
import org.scalatest.concurrent.Eventually
import akka.stream.io.SslTlsPlacebo
import java.net.InetSocketAddress
import akka.stream.impl.fusing.GraphStages
import akka.util.ByteString
class WebsocketIntegrationSpec extends AkkaSpec("akka.stream.materializer.debug.fuzzing-mode=off")
with ScalaFutures with ConversionCheckedTripleEquals with Eventually {
implicit val patience = PatienceConfig(3.seconds)
import system.dispatcher
implicit val materializer = ActorMaterializer()
"A Websocket server" must {
"echo 100 elements and then shut down without error" in Utils.assertAllStagesStopped {
val bindingFuture = Http().bindAndHandleSync({
case HttpRequest(_, _, headers, _, _)
val upgrade = headers.collectFirst { case u: UpgradeToWebsocket u }.get
upgrade.handleMessages(Flow.apply, None)
}, interface = "localhost", port = 8080)
val binding = Await.result(bindingFuture, 3.seconds)
val N = 100
val (response, count) = Http().singleWebsocketRequest(
WebsocketRequest("ws://127.0.0.1:8080"),
Flow.fromSinkAndSourceMat(
Sink.fold(0)((n, _: Message) n + 1),
Source.repeat(TextMessage("hello")).take(N))(Keep.left))
count.futureValue should ===(N)
binding.unbind()
}
"send back 100 elements and then terminate without error even when not ordinarily closed" in Utils.assertAllStagesStopped {
val N = 100
val handler = Flow.fromGraph(GraphDSL.create() { implicit b
val merge = b.add(Merge[Int](2))
// convert to int so we can connect to merge
val mapMsgToInt = b.add(Flow[Message].map(_ -1))
val mapIntToMsg = b.add(Flow[Int].map(x TextMessage.Strict(s"Sending: $x")))
// source we want to use to send message to the connected websocket sink
val rangeSource = b.add(Source(1 to N))
mapMsgToInt ~> merge // this part of the merge will never provide msgs
rangeSource ~> merge ~> mapIntToMsg
FlowShape(mapMsgToInt.in, mapIntToMsg.out)
})
val bindingFuture = Http().bindAndHandleSync({
case HttpRequest(_, _, headers, _, _)
val upgrade = headers.collectFirst { case u: UpgradeToWebsocket u }.get
upgrade.handleMessages(handler, None)
}, interface = "localhost", port = 8080)
val binding = Await.result(bindingFuture, 3.seconds)
@volatile var messages = 0
val (breaker, completion) =
Source.maybe
.viaMat {
Http().websocketClientLayer(WebsocketRequest("ws://localhost:8080"))
.atop(SslTlsPlacebo.forScala)
// the resource leak of #19398 existed only for severed websocket connections
.atopMat(GraphStages.bidiBreaker[ByteString, ByteString])(Keep.right)
.join(Tcp().outgoingConnection(new InetSocketAddress("localhost", 8080), halfClose = true))
}(Keep.right)
.toMat(Sink.foreach(_ messages += 1))(Keep.both)
.run()
eventually(messages should ===(N))
// breaker should have been fulfilled long ago
breaker.value.get.get.complete()
completion.futureValue
binding.unbind()
}
}
}