#19398 fix stream leak in ProtocolSwitchStage
also fix potential NPE in TCP streams when failed or canceled early for an outgoing connection
This commit is contained in:
parent
8c1350b0d4
commit
e0361ece66
5 changed files with 209 additions and 57 deletions
|
|
@ -371,6 +371,11 @@ private[http] object HttpServerBluePrint {
|
|||
def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) {
|
||||
import akka.http.impl.engine.rendering.ResponseRenderingOutput._
|
||||
|
||||
/*
|
||||
* These handlers are in charge until a switch command comes in, then they
|
||||
* are replaced.
|
||||
*/
|
||||
|
||||
setHandler(fromHttp, new InHandler {
|
||||
override def onPush(): Unit =
|
||||
grab(fromHttp) match {
|
||||
|
|
@ -381,21 +386,22 @@ private[http] object HttpServerBluePrint {
|
|||
cancel(fromHttp)
|
||||
switchToWebsocket(handlerFlow)
|
||||
}
|
||||
override def onUpstreamFinish(): Unit = complete(toNet)
|
||||
override def onUpstreamFailure(ex: Throwable): Unit = fail(toNet, ex)
|
||||
})
|
||||
setHandler(toNet, new OutHandler {
|
||||
override def onPull(): Unit = pull(fromHttp)
|
||||
override def onDownstreamFinish(): Unit = completeStage()
|
||||
})
|
||||
|
||||
setHandler(fromNet, new InHandler {
|
||||
def onPush(): Unit = push(toHttp, grab(fromNet))
|
||||
|
||||
// propagate error but don't close stage yet to prevent fromHttp/fromWs being cancelled
|
||||
// too eagerly
|
||||
override def onPush(): Unit = push(toHttp, grab(fromNet))
|
||||
override def onUpstreamFinish(): Unit = complete(toHttp)
|
||||
override def onUpstreamFailure(ex: Throwable): Unit = fail(toHttp, ex)
|
||||
})
|
||||
setHandler(toHttp, new OutHandler {
|
||||
override def onPull(): Unit = pull(fromNet)
|
||||
override def onDownstreamFinish(): Unit = ()
|
||||
override def onDownstreamFinish(): Unit = cancel(fromNet)
|
||||
})
|
||||
|
||||
private var activeTimers = 0
|
||||
|
|
@ -438,13 +444,22 @@ private[http] object HttpServerBluePrint {
|
|||
|
||||
sinkIn.setHandler(new InHandler {
|
||||
override def onPush(): Unit = push(toNet, sinkIn.grab())
|
||||
override def onUpstreamFinish(): Unit = complete(toNet)
|
||||
override def onUpstreamFailure(ex: Throwable): Unit = fail(toNet, ex)
|
||||
})
|
||||
setHandler(toNet, new OutHandler {
|
||||
override def onPull(): Unit = sinkIn.pull()
|
||||
override def onDownstreamFinish(): Unit = {
|
||||
completeStage()
|
||||
sinkIn.cancel()
|
||||
sourceOut.complete()
|
||||
}
|
||||
})
|
||||
|
||||
setHandler(fromNet, new InHandler {
|
||||
override def onPush(): Unit = sourceOut.push(grab(fromNet).bytes)
|
||||
override def onUpstreamFinish(): Unit = sourceOut.complete()
|
||||
override def onUpstreamFailure(ex: Throwable): Unit = sourceOut.fail(ex)
|
||||
})
|
||||
sourceOut.setHandler(new OutHandler {
|
||||
override def onPull(): Unit = {
|
||||
|
|
@ -454,6 +469,7 @@ private[http] object HttpServerBluePrint {
|
|||
override def onPull(): Unit = if (!hasBeenPulled(fromNet)) pull(fromNet)
|
||||
})
|
||||
}
|
||||
override def onDownstreamFinish(): Unit = cancel(fromNet)
|
||||
})
|
||||
|
||||
Websocket.framing.join(frameHandler).runWith(sourceOut.source, sinkIn.sink)(subFusingMaterializer)
|
||||
|
|
|
|||
|
|
@ -1,50 +0,0 @@
|
|||
/**
|
||||
* Copyright (C) 2015 Typesafe Inc. <http://www.typesafe.com>
|
||||
*/
|
||||
package akka.http.impl.engine.ws
|
||||
|
||||
import akka.stream.testkit.AkkaSpec
|
||||
import scala.concurrent.Await
|
||||
import com.typesafe.config.ConfigFactory
|
||||
import com.typesafe.config.Config
|
||||
import akka.actor.ActorSystem
|
||||
import akka.http.scaladsl.model.HttpRequest
|
||||
import akka.http.scaladsl.model.ws._
|
||||
import akka.http.scaladsl._
|
||||
import akka.stream.scaladsl._
|
||||
import akka.stream._
|
||||
import scala.concurrent.duration._
|
||||
import org.scalatest.concurrent.ScalaFutures
|
||||
import org.scalactic.ConversionCheckedTripleEquals
|
||||
import akka.stream.testkit.Utils
|
||||
|
||||
class BypassRouterSpec extends AkkaSpec("akka.stream.materializer.debug.fuzzing-mode = off") with ScalaFutures with ConversionCheckedTripleEquals {
|
||||
|
||||
implicit val patience = PatienceConfig(3.seconds)
|
||||
import system.dispatcher
|
||||
implicit val materializer = ActorMaterializer()
|
||||
|
||||
"BypassRouter" must {
|
||||
|
||||
"work without double pull-ing some ports" in Utils.assertAllStagesStopped {
|
||||
val bindingFuture = Http().bindAndHandleSync({
|
||||
case HttpRequest(_, _, headers, _, _) ⇒
|
||||
val upgrade = headers.collectFirst { case u: UpgradeToWebsocket ⇒ u }.get
|
||||
upgrade.handleMessages(Flow.apply, None)
|
||||
}, interface = "localhost", port = 8080)
|
||||
val binding = Await.result(bindingFuture, 3.seconds)
|
||||
|
||||
val N = 100
|
||||
val (response, count) = Http().singleWebsocketRequest(
|
||||
WebsocketRequest("ws://127.0.0.1:8080"),
|
||||
Flow.fromSinkAndSourceMat(
|
||||
Sink.fold(0)((n, _: Message) ⇒ n + 1),
|
||||
Source.repeat(TextMessage("hello")).take(N))(Keep.left))
|
||||
|
||||
count.futureValue should ===(N)
|
||||
binding.unbind()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,101 @@
|
|||
/**
|
||||
* Copyright (C) 2015 Typesafe Inc. <http://www.typesafe.com>
|
||||
*/
|
||||
package akka.http.impl.engine.ws
|
||||
|
||||
import scala.concurrent.Await
|
||||
import scala.concurrent.duration.DurationInt
|
||||
import org.scalactic.ConversionCheckedTripleEquals
|
||||
import org.scalatest.concurrent.ScalaFutures
|
||||
import org.scalatest.time.Span.convertDurationToSpan
|
||||
import akka.http.scaladsl.Http
|
||||
import akka.http.scaladsl.model.HttpRequest
|
||||
import akka.http.scaladsl.model.Uri.apply
|
||||
import akka.http.scaladsl.model.ws._
|
||||
import akka.stream._
|
||||
import akka.stream.scaladsl._
|
||||
import akka.stream.testkit._
|
||||
import akka.stream.scaladsl.GraphDSL.Implicits._
|
||||
import org.scalatest.concurrent.Eventually
|
||||
import akka.stream.io.SslTlsPlacebo
|
||||
import java.net.InetSocketAddress
|
||||
import akka.stream.impl.fusing.GraphStages
|
||||
import akka.util.ByteString
|
||||
|
||||
class WebsocketIntegrationSpec extends AkkaSpec("akka.stream.materializer.debug.fuzzing-mode=off")
|
||||
with ScalaFutures with ConversionCheckedTripleEquals with Eventually {
|
||||
|
||||
implicit val patience = PatienceConfig(3.seconds)
|
||||
import system.dispatcher
|
||||
implicit val materializer = ActorMaterializer()
|
||||
|
||||
"A Websocket server" must {
|
||||
|
||||
"echo 100 elements and then shut down without error" in Utils.assertAllStagesStopped {
|
||||
val bindingFuture = Http().bindAndHandleSync({
|
||||
case HttpRequest(_, _, headers, _, _) ⇒
|
||||
val upgrade = headers.collectFirst { case u: UpgradeToWebsocket ⇒ u }.get
|
||||
upgrade.handleMessages(Flow.apply, None)
|
||||
}, interface = "localhost", port = 8080)
|
||||
val binding = Await.result(bindingFuture, 3.seconds)
|
||||
|
||||
val N = 100
|
||||
val (response, count) = Http().singleWebsocketRequest(
|
||||
WebsocketRequest("ws://127.0.0.1:8080"),
|
||||
Flow.fromSinkAndSourceMat(
|
||||
Sink.fold(0)((n, _: Message) ⇒ n + 1),
|
||||
Source.repeat(TextMessage("hello")).take(N))(Keep.left))
|
||||
|
||||
count.futureValue should ===(N)
|
||||
binding.unbind()
|
||||
}
|
||||
|
||||
"send back 100 elements and then terminate without error even when not ordinarily closed" in Utils.assertAllStagesStopped {
|
||||
val N = 100
|
||||
|
||||
val handler = Flow.fromGraph(GraphDSL.create() { implicit b ⇒
|
||||
val merge = b.add(Merge[Int](2))
|
||||
|
||||
// convert to int so we can connect to merge
|
||||
val mapMsgToInt = b.add(Flow[Message].map(_ ⇒ -1))
|
||||
val mapIntToMsg = b.add(Flow[Int].map(x ⇒ TextMessage.Strict(s"Sending: $x")))
|
||||
|
||||
// source we want to use to send message to the connected websocket sink
|
||||
val rangeSource = b.add(Source(1 to N))
|
||||
|
||||
mapMsgToInt ~> merge // this part of the merge will never provide msgs
|
||||
rangeSource ~> merge ~> mapIntToMsg
|
||||
|
||||
FlowShape(mapMsgToInt.in, mapIntToMsg.out)
|
||||
})
|
||||
|
||||
val bindingFuture = Http().bindAndHandleSync({
|
||||
case HttpRequest(_, _, headers, _, _) ⇒
|
||||
val upgrade = headers.collectFirst { case u: UpgradeToWebsocket ⇒ u }.get
|
||||
upgrade.handleMessages(handler, None)
|
||||
}, interface = "localhost", port = 8080)
|
||||
val binding = Await.result(bindingFuture, 3.seconds)
|
||||
|
||||
@volatile var messages = 0
|
||||
val (breaker, completion) =
|
||||
Source.maybe
|
||||
.viaMat {
|
||||
Http().websocketClientLayer(WebsocketRequest("ws://localhost:8080"))
|
||||
.atop(SslTlsPlacebo.forScala)
|
||||
// the resource leak of #19398 existed only for severed websocket connections
|
||||
.atopMat(GraphStages.bidiBreaker[ByteString, ByteString])(Keep.right)
|
||||
.join(Tcp().outgoingConnection(new InetSocketAddress("localhost", 8080), halfClose = true))
|
||||
}(Keep.right)
|
||||
.toMat(Sink.foreach(_ ⇒ messages += 1))(Keep.both)
|
||||
.run()
|
||||
eventually(messages should ===(N))
|
||||
// breaker should have been fulfilled long ago
|
||||
breaker.value.get.get.complete()
|
||||
completion.futureValue
|
||||
|
||||
binding.unbind()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue