diff --git a/akka-remote/src/main/scala/akka/remote/artery/compress/InboundCompressions.scala b/akka-remote/src/main/scala/akka/remote/artery/compress/InboundCompressions.scala index c6deac61cf..75dd11bfe9 100644 --- a/akka-remote/src/main/scala/akka/remote/artery/compress/InboundCompressions.scala +++ b/akka-remote/src/main/scala/akka/remote/artery/compress/InboundCompressions.scala @@ -60,66 +60,97 @@ private[remote] final class InboundCompressionsImpl( private val stopped = new AtomicBoolean - private[this] val _actorRefsIns = new Long2ObjectHashMap[InboundActorRefCompression]() - private val createInboundActorRefsForOrigin = new LongFunction[InboundActorRefCompression] { - override def apply(originUid: Long): InboundActorRefCompression = { + // None is used as tombstone value after closed + // TOOD would be nice if we can cleanup the tombstones + private[this] val _actorRefsIns = new Long2ObjectHashMap[Option[InboundActorRefCompression]]() + private val createInboundActorRefsForOrigin = new LongFunction[Option[InboundActorRefCompression]] { + override def apply(originUid: Long): Option[InboundActorRefCompression] = { val actorRefHitters = new TopHeavyHitters[ActorRef](settings.ActorRefs.Max) - new InboundActorRefCompression(system, settings, originUid, inboundContext, actorRefHitters, stopped) + Some(new InboundActorRefCompression(system, settings, originUid, inboundContext, actorRefHitters, stopped)) } } - private def actorRefsIn(originUid: Long): InboundActorRefCompression = + private def actorRefsIn(originUid: Long): Option[InboundActorRefCompression] = _actorRefsIns.computeIfAbsent(originUid, createInboundActorRefsForOrigin) - private[this] val _classManifestsIns = new Long2ObjectHashMap[InboundManifestCompression]() - private val createInboundManifestsForOrigin = new LongFunction[InboundManifestCompression] { - override def apply(originUid: Long): InboundManifestCompression = { + // None is used as tombstone value after closed + private[this] val _classManifestsIns = new Long2ObjectHashMap[Option[InboundManifestCompression]]() + private val createInboundManifestsForOrigin = new LongFunction[Option[InboundManifestCompression]] { + override def apply(originUid: Long): Option[InboundManifestCompression] = { val manifestHitters = new TopHeavyHitters[String](settings.Manifests.Max) - new InboundManifestCompression(system, settings, originUid, inboundContext, manifestHitters, stopped) + Some(new InboundManifestCompression(system, settings, originUid, inboundContext, manifestHitters, stopped)) } } - private def classManifestsIn(originUid: Long): InboundManifestCompression = + private def classManifestsIn(originUid: Long): Option[InboundManifestCompression] = _classManifestsIns.computeIfAbsent(originUid, createInboundManifestsForOrigin) // actor ref compression --- override def decompressActorRef(originUid: Long, tableVersion: Int, idx: Int): OptionVal[ActorRef] = - actorRefsIn(originUid).decompress(tableVersion, idx) + actorRefsIn(originUid) match { + case Some(a) ⇒ a.decompress(tableVersion, idx) + case None ⇒ OptionVal.None + } + override def hitActorRef(originUid: Long, address: Address, ref: ActorRef, n: Int): Unit = { if (ArterySettings.Compression.Debug) println(s"[compress] hitActorRef($originUid, $address, $ref, $n)") - actorRefsIn(originUid).increment(address, ref, n) + actorRefsIn(originUid) match { + case Some(a) ⇒ a.increment(address, ref, n) + case None ⇒ // closed + } } override def confirmActorRefCompressionAdvertisement(originUid: Long, tableVersion: Int): Unit = { _actorRefsIns.get(originUid) match { - case null ⇒ // ignore, it was closed - case a ⇒ a.confirmAdvertisement(tableVersion) + case null ⇒ // ignore + case Some(a) ⇒ a.confirmAdvertisement(tableVersion) + case None ⇒ // closed } } // class manifest compression --- override def decompressClassManifest(originUid: Long, tableVersion: Int, idx: Int): OptionVal[String] = - classManifestsIn(originUid).decompress(tableVersion, idx) + classManifestsIn(originUid) match { + case Some(a) ⇒ a.decompress(tableVersion, idx) + case None ⇒ OptionVal.None + } + override def hitClassManifest(originUid: Long, address: Address, manifest: String, n: Int): Unit = { if (ArterySettings.Compression.Debug) println(s"[compress] hitClassManifest($originUid, $address, $manifest, $n)") - classManifestsIn(originUid).increment(address, manifest, n) + classManifestsIn(originUid) match { + case Some(a) ⇒ a.increment(address, manifest, n) + case None ⇒ // closed + } } override def confirmClassManifestCompressionAdvertisement(originUid: Long, tableVersion: Int): Unit = { _classManifestsIns.get(originUid) match { - case null ⇒ // ignore, it was closed - case a ⇒ a.confirmAdvertisement(tableVersion) + case null ⇒ // ignore + case Some(a) ⇒ a.confirmAdvertisement(tableVersion) + case None ⇒ // closed } } override def close(): Unit = stopped.set(true) override def close(originUid: Long): Unit = { - actorRefsIn(originUid).close() - classManifestsIn(originUid).close() - // FIXME This is not safe, it can be created again (concurrently), at least in theory. - // However, we should make the inbound compressions owned by the Decoder and it doesn't have to be thread-safe - _actorRefsIns.remove(originUid) - _classManifestsIns.remove(originUid) + _actorRefsIns.get(originUid) match { + case null ⇒ + if (_actorRefsIns.putIfAbsent(originUid, None) != null) + close(originUid) + case oldValue @ Some(a) ⇒ + if (_actorRefsIns.replace(originUid, oldValue, None)) + a.close() + case None ⇒ // already closed + } + _classManifestsIns.get(originUid) match { + case null ⇒ + if (_classManifestsIns.putIfAbsent(originUid, None) != null) + close(originUid) + case oldValue @ Some(a) ⇒ + if (_classManifestsIns.replace(originUid, oldValue, None)) + a.close() + case None ⇒ // already closed + } } // testing utilities --- @@ -127,13 +158,19 @@ private[remote] final class InboundCompressionsImpl( /** INTERNAL API: for testing only */ private[remote] def runNextActorRefAdvertisement() = { import scala.collection.JavaConverters._ - _actorRefsIns.values().asScala.foreach { inbound ⇒ inbound.runNextTableAdvertisement() } + _actorRefsIns.values().asScala.foreach { + case Some(inbound) ⇒ inbound.runNextTableAdvertisement() + case None ⇒ // closed + } } /** INTERNAL API: for testing only */ private[remote] def runNextClassManifestAdvertisement() = { import scala.collection.JavaConverters._ - _classManifestsIns.values().asScala.foreach { inbound ⇒ inbound.runNextTableAdvertisement() } + _classManifestsIns.values().asScala.foreach { + case Some(inbound) ⇒ inbound.runNextTableAdvertisement() + case None ⇒ // closed + } } }