diff --git a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardCoordinator.scala b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardCoordinator.scala index 62832ed412..8be730f4b1 100644 --- a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardCoordinator.scala +++ b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardCoordinator.scala @@ -257,7 +257,15 @@ object ShardCoordinator { // shards for each region regions: Map[ActorRef, Vector[ShardId]] = Map.empty, regionProxies: Set[ActorRef] = Set.empty, - unallocatedShards: Set[ShardId] = Set.empty) extends ClusterShardingSerializable { + unallocatedShards: Set[ShardId] = Set.empty, + rememberEntities: Boolean = false) extends ClusterShardingSerializable { + + def withRememberEntities(enabled: Boolean): State = { + if (enabled) + copy(rememberEntities = enabled) + else + copy(unallocatedShards = Set.empty, rememberEntities = enabled) + } def updated(event: DomainEvent): State = event match { case ShardRegionRegistered(region) ⇒ @@ -268,28 +276,34 @@ object ShardCoordinator { copy(regionProxies = regionProxies + proxy) case ShardRegionTerminated(region) ⇒ require(regions.contains(region), s"Terminated region $region not registered: $this") + val newUnallocatedShards = + if (rememberEntities) (unallocatedShards ++ regions(region)) else unallocatedShards copy( regions = regions - region, shards = shards -- regions(region), - unallocatedShards = unallocatedShards ++ regions(region)) + unallocatedShards = newUnallocatedShards) case ShardRegionProxyTerminated(proxy) ⇒ require(regionProxies.contains(proxy), s"Terminated region proxy $proxy not registered: $this") copy(regionProxies = regionProxies - proxy) case ShardHomeAllocated(shard, region) ⇒ require(regions.contains(region), s"Region $region not registered: $this") require(!shards.contains(shard), s"Shard [$shard] already allocated: $this") + val newUnallocatedShards = + if (rememberEntities) (unallocatedShards - shard) else unallocatedShards copy( shards = shards.updated(shard, region), regions = regions.updated(region, regions(region) :+ shard), - unallocatedShards = unallocatedShards - shard) + unallocatedShards = newUnallocatedShards) case ShardHomeDeallocated(shard) ⇒ require(shards.contains(shard), s"Shard [$shard] not allocated: $this") val region = shards(shard) require(regions.contains(region), s"Region $region for shard [$shard] not registered: $this") + val newUnallocatedShards = + if (rememberEntities) (unallocatedShards + shard) else unallocatedShards copy( shards = shards - shard, regions = regions.updated(region, regions(region).filterNot(_ == shard)), - unallocatedShards = unallocatedShards + shard) + unallocatedShards = newUnallocatedShards) } } @@ -381,7 +395,7 @@ abstract class ShardCoordinator(typeName: String, settings: ClusterShardingSetti val cluster = Cluster(context.system) val removalMargin = cluster.settings.DownRemovalMargin - var state = State.empty + var state = State.empty.withRememberEntities(settings.rememberEntities) var rebalanceInProgress = Set.empty[ShardId] var unAckedHostShards = Map.empty[ShardId, Cancellable] // regions that have requested handoff, for graceful shutdown @@ -653,7 +667,10 @@ abstract class ShardCoordinator(typeName: String, settings: ClusterShardingSetti unAckedHostShards = unAckedHostShards.updated(shard, cancel) } - def allocateShardHomes(): Unit = state.unallocatedShards.foreach { self ! GetShardHome(_) } + def allocateShardHomes(): Unit = { + if (settings.rememberEntities) + state.unallocatedShards.foreach { self ! GetShardHome(_) } + } def continueGetShardHome(shard: ShardId, region: ActorRef, getShardHomeSender: ActorRef): Unit = if (!rebalanceInProgress.contains(shard)) { @@ -747,6 +764,7 @@ class PersistentShardCoordinator(typeName: String, settings: ClusterShardingSett state = st case RecoveryCompleted ⇒ + state = state.withRememberEntities(settings.rememberEntities) watchStateActors() } @@ -798,6 +816,7 @@ class DDataShardCoordinator(typeName: String, settings: ClusterShardingSettings, implicit val node = Cluster(context.system) val CoordinatorStateKey = LWWRegisterKey[State](s"${typeName}CoordinatorState") + val initEmptyState = State.empty.withRememberEntities(settings.rememberEntities) node.subscribe(self, ClusterShuttingDown.getClass) @@ -809,7 +828,7 @@ class DDataShardCoordinator(typeName: String, settings: ClusterShardingSettings, // This state will drop all other messages since they will be retried def waitingForState: Receive = ({ case g @ GetSuccess(CoordinatorStateKey, _) ⇒ - state = g.get(CoordinatorStateKey).value + state = g.get(CoordinatorStateKey).value.withRememberEntities(settings.rememberEntities) context.become(waitingForStateInitialized) // note that watchStateActors may call update watchStateActors() @@ -880,7 +899,7 @@ class DDataShardCoordinator(typeName: String, settings: ClusterShardingSettings, def sendUpdate(evt: DomainEvent) = { val s = state.updated(evt) - replicator ! Update(CoordinatorStateKey, LWWRegister(State.empty), WriteMajority(updatingStateTimeout), Some(evt)) { reg ⇒ + replicator ! Update(CoordinatorStateKey, LWWRegister(initEmptyState), WriteMajority(updatingStateTimeout), Some(evt)) { reg ⇒ reg.withValue(s) } } diff --git a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardRegion.scala b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardRegion.scala index 086ecbe3db..884000caff 100644 --- a/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardRegion.scala +++ b/akka-cluster-sharding/src/main/scala/akka/cluster/sharding/ShardRegion.scala @@ -752,8 +752,7 @@ class ShardRegion( None case None ⇒ throw new IllegalStateException("Shard must not be allocated to a proxy only ShardRegion") - } - ) + }) } } diff --git a/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/ClusterShardingSpec.scala b/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/ClusterShardingSpec.scala index 0fb3defe07..68e196c159 100644 --- a/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/ClusterShardingSpec.scala +++ b/akka-cluster-sharding/src/multi-jvm/scala/akka/cluster/sharding/ClusterShardingSpec.scala @@ -232,25 +232,26 @@ abstract class ClusterShardingSpec(config: ClusterShardingSpecConfig) extends Mu val replicator = system.actorOf(Replicator.props( ReplicatorSettings(system).withGossipInterval(1.second).withMaxDeltaElements(10)), "replicator") - def coordinatorProps(typeName: String, rebalanceEnabled: Boolean) = { + def coordinatorProps(typeName: String, rebalanceEnabled: Boolean, rememberEntities: Boolean) = { val allocationStrategy = new ShardCoordinator.LeastShardAllocationStrategy(rebalanceThreshold = 2, maxSimultaneousRebalance = 1) val cfg = ConfigFactory.parseString(s""" handoff-timeout = 10s shard-start-timeout = 10s rebalance-interval = ${if (rebalanceEnabled) "2s" else "3600s"} """).withFallback(system.settings.config.getConfig("akka.cluster.sharding")) - val settings = ClusterShardingSettings(cfg) + val settings = ClusterShardingSettings(cfg).withRememberEntities(rememberEntities) if (settings.stateStoreMode == "persistence") ShardCoordinator.props(typeName, settings, allocationStrategy) else ShardCoordinator.props(typeName, settings, allocationStrategy, replicator) } - List("counter", "rebalancingCounter", "PersistentCounterEntities", "AnotherPersistentCounter", - "PersistentCounter", "RebalancingPersistentCounter", "AutoMigrateRegionTest").foreach { typeName ⇒ + List("counter", "rebalancingCounter", "RememberCounterEntities", "AnotherRememberCounter", + "RememberCounter", "RebalancingRememberCounter", "AutoMigrateRememberRegionTest").foreach { typeName ⇒ val rebalanceEnabled = typeName.toLowerCase.startsWith("rebalancing") + val rememberEnabled = typeName.toLowerCase.contains("remember") val singletonProps = BackoffSupervisor.props( - childProps = coordinatorProps(typeName, rebalanceEnabled), + childProps = coordinatorProps(typeName, rebalanceEnabled, rememberEnabled), childName = "coordinator", minBackoff = 5.seconds, maxBackoff = 5.seconds, @@ -286,11 +287,11 @@ abstract class ClusterShardingSpec(config: ClusterShardingSpecConfig) extends Mu lazy val region = createRegion("counter", rememberEntities = false) lazy val rebalancingRegion = createRegion("rebalancingCounter", rememberEntities = false) - lazy val persistentEntitiesRegion = createRegion("PersistentCounterEntities", rememberEntities = true) - lazy val anotherPersistentRegion = createRegion("AnotherPersistentCounter", rememberEntities = true) - lazy val persistentRegion = createRegion("PersistentCounter", rememberEntities = true) - lazy val rebalancingPersistentRegion = createRegion("RebalancingPersistentCounter", rememberEntities = true) - lazy val autoMigrateRegion = createRegion("AutoMigrateRegionTest", rememberEntities = true) + lazy val persistentEntitiesRegion = createRegion("RememberCounterEntities", rememberEntities = true) + lazy val anotherPersistentRegion = createRegion("AnotherRememberCounter", rememberEntities = true) + lazy val persistentRegion = createRegion("RememberCounter", rememberEntities = true) + lazy val rebalancingPersistentRegion = createRegion("RebalancingRememberCounter", rememberEntities = true) + lazy val autoMigrateRegion = createRegion("AutoMigrateRememberRegionTest", rememberEntities = true) s"Cluster sharding ($mode)" must { @@ -529,7 +530,6 @@ abstract class ClusterShardingSpec(config: ClusterShardingSpecConfig) extends Mu } "rebalance to nodes with less shards" in within(60 seconds) { - runOn(fourth) { for (n ← 1 to 10) { rebalancingRegion ! EntityEnvelope(n, Increment) @@ -558,7 +558,6 @@ abstract class ClusterShardingSpec(config: ClusterShardingSpecConfig) extends Mu } enterBarrier("after-9") - } } @@ -811,7 +810,7 @@ abstract class ClusterShardingSpec(config: ClusterShardingSpecConfig) extends Mu autoMigrateRegion ! Get(1) expectMsg(2) - lastSender.path should ===(node(third) / "user" / "AutoMigrateRegionTestRegion" / "1" / "1") + lastSender.path should ===(node(third) / "user" / "AutoMigrateRememberRegionTestRegion" / "1" / "1") //Kill region 3 system.actorSelection(lastSender.path.parent.parent) ! PoisonPill @@ -821,7 +820,7 @@ abstract class ClusterShardingSpec(config: ClusterShardingSpecConfig) extends Mu // Wait for migration to happen //Test the shard, thus counter was moved onto node 4 and started. runOn(fourth) { - val counter1 = system.actorSelection(system / "AutoMigrateRegionTestRegion" / "1" / "1") + val counter1 = system.actorSelection(system / "AutoMigrateRememberRegionTestRegion" / "1" / "1") val probe = TestProbe() awaitAssert({ counter1.tell(Identify(1), probe.ref)