From 228c19e688d3fef890048e02b910672205539b2d Mon Sep 17 00:00:00 2001
From: Patrik Nordwall
Date: Tue, 19 May 2020 13:50:22 +0200
Subject: [PATCH] Allow ShardCoordinator to watch old region ActorRef that is
 not in cluster, #29034

* Otherwise the remote watch is disabled and the old region ActorRef
  remains in the coordinator's state
---
 .../ClusterShardCoordinatorDowningSpec.scala  | 170 ++++++++++++++++++
 .../akka/cluster/ClusterRemoteWatcher.scala   |  17 +-
 .../akka/remote/RemoteActorRefProvider.scala  |   2 +-
 .../scala/akka/remote/RemoteWatcher.scala     |   4 +-
 4 files changed, 189 insertions(+), 4 deletions(-)
 create mode 100644 akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/ClusterShardCoordinatorDowningSpec.scala

diff --git a/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/ClusterShardCoordinatorDowningSpec.scala b/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/ClusterShardCoordinatorDowningSpec.scala
new file mode 100644
index 0000000000..aeaf33e254
--- /dev/null
+++ b/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/ClusterShardCoordinatorDowningSpec.scala
@@ -0,0 +1,170 @@
+/*
+ * Copyright (C) 2020 Lightbend Inc.
+ */
+
+package akka.cluster.sharding
+
+import scala.concurrent.duration._
+
+import akka.actor.Actor
+import akka.actor.ActorRef
+import akka.actor.Props
+import akka.cluster.MemberStatus
+import akka.remote.transport.ThrottlerTransportAdapter.Direction
+import akka.serialization.jackson.CborSerializable
+import akka.testkit._
+import akka.util.ccompat._
+
+@ccompatUsedUntil213
+object ClusterShardCoordinatorDowningSpec {
+  case class Ping(id: String) extends CborSerializable
+
+  class Entity extends Actor {
+    def receive = {
+      case Ping(_) => sender() ! self
+    }
+  }
+
+  case object GetLocations extends CborSerializable
+  case class Locations(locations: Map[String, ActorRef]) extends CborSerializable
+
+  class ShardLocations extends Actor {
+    var locations: Locations = _
+    def receive = {
+      case GetLocations => sender() ! locations
+      case l: Locations => locations = l
+    }
+  }
+
+  val extractEntityId: ShardRegion.ExtractEntityId = {
+    case m @ Ping(id) => (id, m)
+  }
+
+  val extractShardId: ShardRegion.ExtractShardId = {
+    case Ping(id: String) => id.charAt(0).toString
+  }
+}
+
+abstract class ClusterShardCoordinatorDowningSpecConfig(mode: String)
+    extends MultiNodeClusterShardingConfig(
+      mode,
+      loglevel = "INFO",
+      additionalConfig = """
+      akka.cluster.sharding.rebalance-interval = 120 s
+      akka.cluster.down-removal-margin = 3 s
+      akka.remote.watch-failure-detector.acceptable-heartbeat-pause = 3s
+      """) {
+  val controller = role("controller")
+  val first = role("first")
+  val second = role("second")
+
+  testTransport(on = true)
+
+}
+
+object PersistentClusterShardCoordinatorDowningSpecConfig
+    extends ClusterShardCoordinatorDowningSpecConfig(ClusterShardingSettings.StateStoreModePersistence)
+object DDataClusterShardCoordinatorDowningSpecConfig
+    extends ClusterShardCoordinatorDowningSpecConfig(ClusterShardingSettings.StateStoreModeDData)
+
+class PersistentClusterShardCoordinatorDowningSpec
+    extends ClusterShardCoordinatorDowningSpec(PersistentClusterShardCoordinatorDowningSpecConfig)
+class DDataClusterShardCoordinatorDowningSpec
+    extends ClusterShardCoordinatorDowningSpec(DDataClusterShardCoordinatorDowningSpecConfig)
+
+class PersistentClusterShardCoordinatorDowningMultiJvmNode1 extends PersistentClusterShardCoordinatorDowningSpec
+class PersistentClusterShardCoordinatorDowningMultiJvmNode2 extends PersistentClusterShardCoordinatorDowningSpec
+class PersistentClusterShardCoordinatorDowningMultiJvmNode3 extends PersistentClusterShardCoordinatorDowningSpec
+
+class DDataClusterShardCoordinatorDowningMultiJvmNode1 extends DDataClusterShardCoordinatorDowningSpec
+class DDataClusterShardCoordinatorDowningMultiJvmNode2 extends DDataClusterShardCoordinatorDowningSpec
+class DDataClusterShardCoordinatorDowningMultiJvmNode3 extends DDataClusterShardCoordinatorDowningSpec
+
+abstract class ClusterShardCoordinatorDowningSpec(multiNodeConfig: ClusterShardCoordinatorDowningSpecConfig)
+    extends MultiNodeClusterShardingSpec(multiNodeConfig)
+    with ImplicitSender {
+  import multiNodeConfig._
+
+  import ClusterShardCoordinatorDowningSpec._
+
+  def startSharding(): Unit = {
+    startSharding(
+      system,
+      typeName = "Entity",
+      entityProps = Props[Entity](),
+      extractEntityId = extractEntityId,
+      extractShardId = extractShardId)
+  }
+
+  lazy val region = ClusterSharding(system).shardRegion("Entity")
+
+  s"Cluster sharding ($mode) with downing of coordinator node" must {
+
+    "join cluster" in within(20.seconds) {
+      startPersistenceIfNotDdataMode(startOn = controller, setStoreOn = Seq(first, second))
+
+      join(first, first, onJoinedRunOnFrom = startSharding())
+      join(second, first, onJoinedRunOnFrom = startSharding(), assertNodeUp = false)
+
+      // all Up, everywhere before continuing
+      runOn(first, second) {
+        awaitAssert {
+          cluster.state.members.size should ===(2)
+          cluster.state.members.unsorted.map(_.status) should ===(Set(MemberStatus.Up))
+        }
+      }
+
+      enterBarrier("after-2")
+    }
+
+    "initialize shards" in {
+      runOn(first) {
+        val shardLocations = system.actorOf(Props[ShardLocations](), "shardLocations")
+        val locations = (for (n <- 1 to 4) yield {
+          val id = n.toString
+          region ! Ping(id)
+          id -> expectMsgType[ActorRef]
+        }).toMap
+        shardLocations ! Locations(locations)
+        system.log.debug("Original locations: {}", locations)
+      }
+      enterBarrier("after-3")
+    }
+
+    "recover after downing coordinator node" in within(20.seconds) {
+      val firstAddress = address(first)
+      system.actorSelection(node(first) / "user" / "shardLocations") ! GetLocations
+      val Locations(originalLocations) = expectMsgType[Locations]
+
+      runOn(controller) {
+        testConductor.blackhole(first, second, Direction.Both).await
+      }
+
+      Thread.sleep(3000)
+
+      runOn(second) {
+        cluster.down(first)
+        awaitAssert {
+          cluster.state.members.size should ===(1)
+        }
+
+        awaitAssert {
+          val probe = TestProbe()
+          originalLocations.foreach {
+            case (id, ref) =>
+              region.tell(Ping(id), probe.ref)
+              if (ref.path.address == firstAddress) {
+                val newRef = probe.expectMsgType[ActorRef](1.second)
+                newRef should not be (ref)
+                system.log.debug("Moved [{}] from [{}] to [{}]", id, ref, newRef)
+              } else
+                probe.expectMsg(1.second, ref) // should not move
+          }
+        }
+      }
+
+      enterBarrier("after-4")
+    }
+
+  }
+}
diff --git a/akka-cluster/src/main/scala/akka/cluster/ClusterRemoteWatcher.scala b/akka-cluster/src/main/scala/akka/cluster/ClusterRemoteWatcher.scala
index 2958e8e172..fc7a7770ca 100644
--- a/akka-cluster/src/main/scala/akka/cluster/ClusterRemoteWatcher.scala
+++ b/akka-cluster/src/main/scala/akka/cluster/ClusterRemoteWatcher.scala
@@ -70,6 +70,9 @@ private[cluster] class ClusterRemoteWatcher(
 
   override val log = Logging(context.system, ActorWithLogClass(this, ClusterLogClass.ClusterCore))
 
+  // allowed to watch even though address not in cluster membership, i.e. remote watch
+  private val watchPathWhitelist = Set("/system/sharding/")
+
   private var pendingDelayedQuarantine: Set[UniqueAddress] = Set.empty
 
   var clusterNodes: Set[Address] = Set.empty
@@ -164,7 +167,19 @@ private[cluster] class ClusterRemoteWatcher(
     if (!clusterNodes(watchee.path.address)) super.watchNode(watchee)
 
   override protected def shouldWatch(watchee: InternalActorRef): Boolean =
-    clusterNodes(watchee.path.address) || super.shouldWatch(watchee)
+    clusterNodes(watchee.path.address) || super.shouldWatch(watchee) || isWatchOutsideClusterAllowed(watchee)
+
+  /**
+   * Allowed to watch some paths even though address not in cluster membership, i.e. remote watch.
+   * Needed for ShardCoordinator that has to watch old incarnations of region ActorRef from the
+   * recovered state.
+   */
+  private def isWatchOutsideClusterAllowed(watchee: InternalActorRef): Boolean = {
+    context.system.name == watchee.path.address.system && {
+      val pathPrefix = watchee.path.elements.take(2).mkString("/", "/", "/")
+      watchPathWhitelist.contains(pathPrefix)
+    }
+  }
 
   /**
    * When a cluster node is added this class takes over the
diff --git a/akka-remote/src/main/scala/akka/remote/RemoteActorRefProvider.scala b/akka-remote/src/main/scala/akka/remote/RemoteActorRefProvider.scala
index b83ab570d5..889c37a174 100644
--- a/akka-remote/src/main/scala/akka/remote/RemoteActorRefProvider.scala
+++ b/akka-remote/src/main/scala/akka/remote/RemoteActorRefProvider.scala
@@ -715,7 +715,7 @@ private[akka] class RemoteActorRef private[akka] (
         else if (provider.remoteWatcher.isDefined)
           remote.send(message, OptionVal.None, this)
         else
-          provider.warnIfUnsafeDeathwatchWithoutCluster(watchee, watcher, "remote Watch")
+          provider.warnIfUnsafeDeathwatchWithoutCluster(watchee, watcher, "Watch")
 
       //Unwatch has a different signature, need to pattern match arguments against InternalActorRef
       case Unwatch(watchee: InternalActorRef, watcher: InternalActorRef) =>
diff --git a/akka-remote/src/main/scala/akka/remote/RemoteWatcher.scala b/akka-remote/src/main/scala/akka/remote/RemoteWatcher.scala
index a727ef21aa..5eaf45a864 100644
--- a/akka-remote/src/main/scala/akka/remote/RemoteWatcher.scala
+++ b/akka-remote/src/main/scala/akka/remote/RemoteWatcher.scala
@@ -221,7 +221,7 @@ private[akka] class RemoteWatcher(
 
       // add watch from self, this will actually send a Watch to the target when necessary
      context.watch(watchee)
-    } else remoteProvider.warnIfUnsafeDeathwatchWithoutCluster(watcher, watchee, "Watch")
+    } else remoteProvider.warnIfUnsafeDeathwatchWithoutCluster(watchee, watcher, "Watch")
   }
 
   def watchNode(watchee: InternalActorRef): Unit = {
@@ -250,7 +250,7 @@
           }
        case None =>
      }
-    } else remoteProvider.warnIfUnsafeDeathwatchWithoutCluster(watcher, watchee, "Unwatch")
+    } else remoteProvider.warnIfUnsafeDeathwatchWithoutCluster(watchee, watcher, "Unwatch")
   }
 
   def removeWatchee(watchee: InternalActorRef): Unit = {
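
Note (not part of the patch): the whitelist check added to ClusterRemoteWatcher compares
only the first two elements of the watchee's actor path against watchPathWhitelist,
together with an actor-system name match, so actors under /system/sharding/ may be
remote-watched even when their address is no longer in the cluster membership. The
standalone Scala sketch below reproduces just that prefix computation; the example path
elements ("Entity" as the region name) and the object name are hypothetical, chosen for
illustration only.

  // Minimal sketch of the path-prefix whitelist check, mirroring
  // watchee.path.elements.take(2).mkString("/", "/", "/") from the patch.
  // The elements stand in for a path such as
  // akka://MySystem/system/sharding/Entity (address part excluded).
  object WatchWhitelistSketch extends App {
    // allowed to watch even though the address is not in cluster membership
    val watchPathWhitelist = Set("/system/sharding/")

    val elements = Vector("system", "sharding", "Entity")

    // first two elements rendered as "/system/sharding/"
    val pathPrefix = elements.take(2).mkString("/", "/", "/")

    println(pathPrefix)                              // prints /system/sharding/
    println(watchPathWhitelist.contains(pathPrefix)) // prints true
  }

Limiting the exception to the /system/sharding/ prefix keeps user-level remote death watch
outside the cluster disallowed (it still goes through warnIfUnsafeDeathwatchWithoutCluster),
while sharding internals such as the coordinator watching an old region incarnation
recovered from its state are allowed to establish the remote watch.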