From 664e0a786cdd8082fe3cd316b245382c961d8dc8 Mon Sep 17 00:00:00 2001 From: Roland Kuhn Date: Wed, 16 Dec 2015 22:51:29 +0100 Subject: [PATCH] #19196 fix StageActorRef termination watch leak --- .../akka/http/scaladsl/ClientServerSpec.scala | 15 +++++- .../scala/akka/stream/stage/GraphStage.scala | 48 ++++++++++++++----- 2 files changed, 50 insertions(+), 13 deletions(-) diff --git a/akka-http-core/src/test/scala/akka/http/scaladsl/ClientServerSpec.scala b/akka-http-core/src/test/scala/akka/http/scaladsl/ClientServerSpec.scala index 26705d52a6..02791b03bd 100644 --- a/akka-http-core/src/test/scala/akka/http/scaladsl/ClientServerSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/scaladsl/ClientServerSpec.scala @@ -26,8 +26,9 @@ import akka.testkit.EventFilter import akka.util.ByteString import com.typesafe.config.{ Config, ConfigFactory } import org.scalatest.{ BeforeAndAfterAll, Matchers, WordSpec } +import org.scalatest.concurrent.ScalaFutures -class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll { +class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll with ScalaFutures { val testConf: Config = ConfigFactory.parseString(""" akka.loggers = ["akka.testkit.TestEventListener"] akka.loglevel = ERROR @@ -37,6 +38,7 @@ class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll { implicit val system = ActorSystem(getClass.getSimpleName, testConf) import system.dispatcher implicit val materializer = ActorMaterializer() + implicit val patience = PatienceConfig(3.seconds) val testConf2: Config = ConfigFactory.parseString("akka.stream.materializer.subscription-timeout.timeout = 1 s") @@ -86,6 +88,17 @@ class ClientServerSpec extends WordSpec with Matchers with BeforeAndAfterAll { } } + "properly terminate client when server is not running" in Utils.assertAllStagesStopped { + for (i ← 1 to 100) + withClue(s"iterator $i: ") { + Source.single(HttpRequest(HttpMethods.POST, "/test", List.empty, HttpEntity(MediaTypes.`text/plain`.withCharset(HttpCharsets.`UTF-8`), "buh"))) + .via(Http(actorSystem).outgoingConnection("localhost", 7777)) + .runWith(Sink.head) + .failed + .futureValue shouldBe a[StreamTcpException] + } + } + "run with bindAndHandleSync" in { val (_, hostname, port) = TestUtils.temporaryServerHostnameAndPort() val binding = Http().bindAndHandleSync(_ ⇒ HttpResponse(), hostname, port) diff --git a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala index 2a57396f85..69ee95be10 100644 --- a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala +++ b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala @@ -138,8 +138,16 @@ object GraphStageLogic { private[this] var behaviour = initialReceive /** INTERNAL API */ - private[akka] def internalReceive(pack: (ActorRef, Any)): Unit = - behaviour(pack) + private[akka] def internalReceive(pack: (ActorRef, Any)): Unit = { + pack._2 match { + case Terminated(ref) ⇒ + if (watching contains ref) { + watching -= ref + behaviour(pack) + } + case _ ⇒ behaviour(pack) + } + } override def !(message: Any)(implicit sender: ActorRef = Actor.noSender): Unit = { message match { @@ -167,6 +175,7 @@ object GraphStageLogic { behaviour = receive } + private[this] var watching = ActorCell.emptyActorRefSet private[this] val _watchedBy = new AtomicReference[Set[ActorRef]](ActorCell.emptyActorRefSet) override def isTerminated = _watchedBy.get() == StageTerminatedTombstone @@ -174,16 +183,25 @@ object GraphStageLogic { //noinspection EmptyCheck protected def sendTerminated(): Unit = { val watchedBy = _watchedBy.getAndSet(StageTerminatedTombstone) - if (!(watchedBy == StageTerminatedTombstone) && !watchedBy.isEmpty) { - - watchedBy foreach sendTerminated(ifLocal = false) - watchedBy foreach sendTerminated(ifLocal = true) + if (watchedBy != StageTerminatedTombstone) { + if (watchedBy.nonEmpty) { + watchedBy foreach sendTerminated(ifLocal = false) + watchedBy foreach sendTerminated(ifLocal = true) + } + if (watching.nonEmpty) { + watching foreach unwatchWatched + watching = Set.empty + } } } + private def sendTerminated(ifLocal: Boolean)(watcher: ActorRef): Unit = if (watcher.asInstanceOf[ActorRefScope].isLocal == ifLocal) watcher.asInstanceOf[InternalActorRef].sendSystemMessage(DeathWatchNotification(this, existenceConfirmed = true, addressTerminated = false)) + private def unwatchWatched(watched: ActorRef): Unit = + watched.asInstanceOf[InternalActorRef].sendSystemMessage(Unwatch(watched, this)) + override def stop(): Unit = sendTerminated() @tailrec final def addWatcher(watchee: ActorRef, watcher: ActorRef): Unit = @@ -201,7 +219,7 @@ object GraphStageLogic { if (!_watchedBy.compareAndSet(watchedBy, watchedBy + watcher)) addWatcher(watchee, watcher) // try again } else if (!watcheeSelf && watcherSelf) { - watch(watchee) + log.warning("externally triggered watch from {} to {} is illegal on StageActorRef", watcher, watchee) } else { log.error("BUG: illegal Watch(%s,%s) for %s".format(watchee, watcher, this)) } @@ -219,18 +237,22 @@ object GraphStageLogic { if (!_watchedBy.compareAndSet(watchedBy, watchedBy - watcher)) remWatcher(watchee, watcher) // try again } else if (!watcheeSelf && watcherSelf) { - unwatch(watchee) + log.warning("externally triggered unwatch from {} to {} is illegal on StageActorRef", watcher, watchee) } else { log.error("BUG: illegal Unwatch(%s,%s) for %s".format(watchee, watcher, this)) } } } - def watch(actorRef: ActorRef): Unit = + def watch(actorRef: ActorRef): Unit = { + watching += actorRef actorRef.asInstanceOf[InternalActorRef].sendSystemMessage(Watch(actorRef.asInstanceOf[InternalActorRef], this)) + } - def unwatch(actorRef: ActorRef): Unit = + def unwatch(actorRef: ActorRef): Unit = { + watching -= actorRef actorRef.asInstanceOf[InternalActorRef].sendSystemMessage(Unwatch(actorRef.asInstanceOf[InternalActorRef], this)) + } } object StageActorRef { type Receive = ((ActorRef, Any)) ⇒ Unit @@ -885,8 +907,10 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: // Internal hooks to avoid reliance on user calling super in postStop /** INTERNAL API */ protected[stream] def afterPostStop(): Unit = { - if (_stageActorRef ne null) _stageActorRef.stop() - _stageActorRef = null + if (_stageActorRef ne null) { + _stageActorRef.stop() + _stageActorRef = null + } } /**