diff --git a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardCoordinator.scala b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardCoordinator.scala
index c53c78ded0..1361637cc0 100644
--- a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardCoordinator.scala
+++ b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardCoordinator.scala
@@ -669,6 +669,7 @@ abstract class ShardCoordinator(
   var unAckedHostShards = Map.empty[ShardId, Cancellable]
   // regions that have requested handoff, for graceful shutdown
   var gracefulShutdownInProgress = Set.empty[ActorRef]
+  var waitingForLocalRegionToTerminate = false
   var aliveRegions = Set.empty[ActorRef]
   var regionTerminationInProgress = Set.empty[ActorRef]
 
@@ -847,17 +848,19 @@ abstract class ShardCoordinator(
       }
 
     case GracefulShutdownReq(region) =>
-      if (!gracefulShutdownInProgress(region))
+      if (!gracefulShutdownInProgress(region)) {
        state.regions.get(region) match {
          case Some(shards) =>
            if (log.isDebugEnabled) {
              if (verboseDebug)
                log.debug(
-                 "{}: Graceful shutdown of region [{}] with [{}] shards [{}]",
-                 typeName,
-                 region,
-                 shards.size,
-                 shards.mkString(", "))
+                 "{}: Graceful shutdown of {} region [{}] with [{}] shards [{}] started",
+                 Array(
+                   typeName,
+                   if (region.path.address.hasLocalScope) "local" else "",
+                   region,
+                   shards.size,
+                   shards.mkString(", ")))
              else
                log.debug("{}: Graceful shutdown of region [{}] with [{}] shards", typeName, region, shards.size)
            }
@@ -867,6 +870,7 @@ abstract class ShardCoordinator(
          case None =>
            log.debug("{}: Unknown region requested graceful shutdown [{}]", typeName, region)
        }
+      }
 
     case ShardRegion.GetClusterShardingStats(waitMax) =>
       import akka.pattern.ask
@@ -917,24 +921,40 @@ abstract class ShardCoordinator(
       sender() ! reply
 
     case ShardCoordinator.Internal.Terminate =>
-      if (rebalanceInProgress.isEmpty)
-        log.debug("{}: Received termination message.", typeName)
-      else if (log.isDebugEnabled) {
-        if (verboseDebug)
-          log.debug(
-            "{}: Received termination message. Rebalance in progress of [{}] shards [{}].",
-            typeName,
-            rebalanceInProgress.size,
-            rebalanceInProgress.keySet.mkString(", "))
-        else
-          log.debug(
-            "{}: Received termination message. Rebalance in progress of [{}] shards.",
-            typeName,
-            rebalanceInProgress.size)
-      }
-      context.stop(self)
+      terminate()
   }: Receive).orElse[Any, Unit](receiveTerminated)
 
+  private def terminate(): Unit = {
+    if (aliveRegions.exists(_.path.address.hasLocalScope) || gracefulShutdownInProgress.exists(
+          _.path.address.hasLocalScope)) {
+      aliveRegions
+        .find(_.path.address.hasLocalScope)
+        .foreach(region =>
+          // region will get this from taking part in coordinated shutdown, but for good measure
+          region ! ShardRegion.GracefulShutdown)
+
+      log.debug("{}: Deferring coordinator termination until local region has terminated", typeName)
+      waitingForLocalRegionToTerminate = true
+    } else {
+      if (rebalanceInProgress.isEmpty)
+        log.debug("{}: Received termination message.", typeName)
+      else if (log.isDebugEnabled) {
+        if (verboseDebug)
+          log.debug(
+            "{}: Received termination message. Rebalance in progress of [{}] shards [{}].",
+            typeName,
+            rebalanceInProgress.size,
+            rebalanceInProgress.keySet.mkString(", "))
+        else
+          log.debug(
+            "{}: Received termination message. Rebalance in progress of [{}] shards.",
+            typeName,
+            rebalanceInProgress.size)
+      }
+      context.stop(self)
+    }
+  }
+
   private def clearRebalanceInProgress(shard: String): Unit = {
     rebalanceInProgress.get(shard) match {
       case Some(pendingGetShardHome) =>
@@ -1058,7 +1078,13 @@ abstract class ShardCoordinator(
   def regionTerminated(ref: ActorRef): Unit = {
     rebalanceWorkers.foreach(_ ! RebalanceWorker.ShardRegionTerminated(ref))
     if (state.regions.contains(ref)) {
-      log.debug("{}: ShardRegion terminated: [{}]", typeName, ref)
+      if (log.isDebugEnabled) {
+        log.debug(
+          "{}: ShardRegion terminated{}: [{}]",
+          typeName,
+          if (gracefulShutdownInProgress.contains(ref)) " (gracefully)" else "",
+          ref)
+      }
       regionTerminationInProgress += ref
       state.regions(ref).foreach { s =>
         self.tell(GetShardHome(s), ignoreRef)
@@ -1070,6 +1096,12 @@ abstract class ShardCoordinator(
         regionTerminationInProgress -= ref
         aliveRegions -= ref
         allocateShardHomesForRememberEntities()
+        if (ref.path.address.hasLocalScope && waitingForLocalRegionToTerminate) {
+          // handoff optimization: singleton told coordinator to stop but we deferred stop until the local region
+          // had completed the handoff
+          log.debug("{}: Local region stopped, terminating coordinator", typeName)
+          terminate()
+        }
       }
     }
   }
diff --git a/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/ClusterShardingGracefulShutdownOldestSpec.scala b/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/ClusterShardingGracefulShutdownOldestSpec.scala
new file mode 100644
index 0000000000..67ca592b15
--- /dev/null
+++ b/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/ClusterShardingGracefulShutdownOldestSpec.scala
@@ -0,0 +1,172 @@
+/*
+ * Copyright (C) 2009-2021 Lightbend Inc.
+ */
+
+package akka.cluster.sharding
+
+import scala.concurrent.duration._
+import akka.actor._
+import akka.cluster.sharding.ShardCoordinator.ShardAllocationStrategy
+import akka.remote.testconductor.RoleName
+import akka.testkit._
+
+import scala.concurrent.Await
+
+abstract class ClusterShardingGracefulShutdownOldestSpecConfig(mode: String)
+    extends MultiNodeClusterShardingConfig(
+      mode,
+      additionalConfig = "akka.persistence.journal.leveldb-shared.store.native = off") {
+  val first = role("first")
+  val second = role("second")
+}
+
+object ClusterShardingGracefulShutdownOldestSpec {
+  object TerminationOrderActor {
+    case object RegionTerminated
+
+    case object CoordinatorTerminated
+
+    def props(probe: ActorRef, coordinator: ActorRef, region: ActorRef) =
+      Props(new TerminationOrderActor(probe, coordinator, region))
+  }
+
+  class TerminationOrderActor(probe: ActorRef, coordinator: ActorRef, region: ActorRef) extends Actor {
+
+    import TerminationOrderActor._
+
+    context.watch(coordinator)
+    context.watch(region)
+
+    def receive = {
+      case Terminated(`coordinator`) =>
+        probe ! CoordinatorTerminated
+      case Terminated(`region`) =>
+        probe ! RegionTerminated
+    }
+
+  }
+
+  object SlowStopShardedEntity {
+    case object Stop
+    case object ActualStop
+  }
+
+  // slow stop previously made it more likely that the coordinator would stop before the local region
+  class SlowStopShardedEntity extends Actor with Timers {
+    import SlowStopShardedEntity._
+
+    def receive: Receive = {
+      case id: Int => sender() ! id
+      case SlowStopShardedEntity.Stop =>
+        timers.startSingleTimer(ActualStop, ActualStop, 50.millis)
+      case SlowStopShardedEntity.ActualStop =>
+        context.stop(self)
+    }
+  }
+
+}
+
+object PersistentClusterShardingGracefulShutdownOldestSpecConfig
+    extends ClusterShardingGracefulShutdownOldestSpecConfig(ClusterShardingSettings.StateStoreModePersistence)
+object DDataClusterShardingGracefulShutdownOldestSpecConfig
+    extends ClusterShardingGracefulShutdownOldestSpecConfig(ClusterShardingSettings.StateStoreModeDData)
+
+class PersistentClusterShardingGracefulShutdownOldestSpec
+    extends ClusterShardingGracefulShutdownOldestSpec(PersistentClusterShardingGracefulShutdownOldestSpecConfig)
+class DDataClusterShardingGracefulShutdownOldestSpec
+    extends ClusterShardingGracefulShutdownOldestSpec(DDataClusterShardingGracefulShutdownOldestSpecConfig)
+
+class PersistentClusterShardingGracefulShutdownOldestMultiJvmNode1
+    extends PersistentClusterShardingGracefulShutdownOldestSpec
+class PersistentClusterShardingGracefulShutdownOldestMultiJvmNode2
+    extends PersistentClusterShardingGracefulShutdownOldestSpec
+
+class DDataClusterShardingGracefulShutdownOldestMultiJvmNode1 extends DDataClusterShardingGracefulShutdownOldestSpec
+class DDataClusterShardingGracefulShutdownOldestMultiJvmNode2 extends DDataClusterShardingGracefulShutdownOldestSpec
+
+abstract class ClusterShardingGracefulShutdownOldestSpec(
+    multiNodeConfig: ClusterShardingGracefulShutdownOldestSpecConfig)
+    extends MultiNodeClusterShardingSpec(multiNodeConfig)
+    with ImplicitSender {
+
+  import ClusterShardingGracefulShutdownOldestSpec._
+  import multiNodeConfig._
+
+  private val typeName = "Entity"
+
+  def join(from: RoleName, to: RoleName, typeName: String): Unit = {
+    super.join(from, to)
+    runOn(from) {
+      startSharding(typeName)
+    }
+    enterBarrier(s"$from-started")
+  }
+
+  def startSharding(typeName: String): ActorRef =
+    startSharding(
+      system,
+      typeName,
+      entityProps = Props[SlowStopShardedEntity](),
+      extractEntityId = MultiNodeClusterShardingSpec.intExtractEntityId,
+      extractShardId = MultiNodeClusterShardingSpec.intExtractShardId,
+      allocationStrategy = ShardAllocationStrategy.leastShardAllocationStrategy(absoluteLimit = 2, relativeLimit = 1.0),
+      handOffStopMessage = SlowStopShardedEntity.Stop)
+
+  lazy val region = ClusterSharding(system).shardRegion(typeName)
+
+  s"Cluster sharding ($mode)" must {
+
+    "start some shards in both regions" in within(30.seconds) {
+      startPersistenceIfNeeded(startOn = first, setStoreOn = Seq(first, second))
+
+      join(first, first, typeName)
+      join(second, first, typeName)
+
+      awaitAssert {
+        val p = TestProbe()
+        val regionAddresses = (1 to 100).map { n =>
+          region.tell(n, p.ref)
+          p.expectMsg(1.second, n)
+          p.lastSender.path.address
+        }.toSet
+        regionAddresses.size should be(2)
+      }
+      enterBarrier("after-2")
+    }
+
+    "gracefully shutdown the oldest region" in within(30.seconds) {
+      runOn(first) {
+        val coordinator = awaitAssert {
+          Await.result(
+            system
+              .actorSelection(s"/system/sharding/${typeName}Coordinator/singleton/coordinator")
+              .resolveOne(remainingOrDefault),
+            remainingOrDefault)
+        }
+        val terminationProbe = TestProbe()
+        system.actorOf(TerminationOrderActor.props(terminationProbe.ref, coordinator, region))
+
+        // trigger graceful shutdown
+        cluster.leave(address(first))
+
+        // region first
+        terminationProbe.expectMsg(TerminationOrderActor.RegionTerminated)
+        terminationProbe.expectMsg(TerminationOrderActor.CoordinatorTerminated)
+      }
+      enterBarrier("terminated")
+
+      runOn(second) {
+        awaitAssert {
+          val p = TestProbe()
+          val responses = (1 to 100).map { n =>
+            region.tell(n, p.ref)
+            p.expectMsg(1.second, n)
+          }.toSet
+          responses.size should be(100)
+        }
+      }
+      enterBarrier("done-o")
+    }
+
+  }
+}
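Note, not part of the patch: a minimal sketch of how the graceful path exercised by the new spec is typically triggered from application code. The object and method names below are made up for illustration, and it assumes an ActorSystem on which cluster sharding has already been started.

import akka.actor.ActorSystem
import akka.cluster.Cluster

object GracefulLeaveSketch {
  // Hypothetical helper: ask this node to leave the cluster gracefully.
  def leaveGracefully(system: ActorSystem): Unit = {
    val cluster = Cluster(system)
    // With default settings, leaving the cluster triggers CoordinatedShutdown on this node.
    // Its cluster-sharding-shutdown-region phase hands off the shards of the local region
    // before the node exits, and with the change above the coordinator singleton also defers
    // its own stop until that local region has terminated.
    cluster.leave(cluster.selfAddress)
    // Roughly equivalent alternative:
    // akka.actor.CoordinatedShutdown(system).run(akka.actor.CoordinatedShutdown.ClusterLeavingReason)
  }
}

The multi-jvm test above exercises the same path with cluster.leave(address(first)) on the oldest node, which hosts both the coordinator singleton and a local region, and then asserts that the region terminates before the coordinator.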