From 21e8f89f5338a20a4f7f53c61583ab9e8db4e1b3 Mon Sep 17 00:00:00 2001 From: Patrik Nordwall Date: Wed, 19 Feb 2014 08:06:32 +0100 Subject: [PATCH] =con #3880 Keep track of all shards per region in ClusterSharding * The problem was that ShardRegion actor only kept track of one shard id per region actor. Therefore the Terminated message only removes one of the shards from its registry when there are multiple shards per region. * Added failing test and solved the problem by keeping track of all shards per region * Also, rebalance must not be done before any regions have been registered --- .../contrib/pattern/ClusterSharding.scala | 30 +++++++++------- .../contrib/pattern/ClusterShardingSpec.scala | 35 +++++++++++++++---- 2 files changed, 47 insertions(+), 18 deletions(-) diff --git a/akka-contrib/src/main/scala/akka/contrib/pattern/ClusterSharding.scala b/akka-contrib/src/main/scala/akka/contrib/pattern/ClusterSharding.scala index 765f3a654c..86e5705814 100644 --- a/akka-contrib/src/main/scala/akka/contrib/pattern/ClusterSharding.scala +++ b/akka-contrib/src/main/scala/akka/contrib/pattern/ClusterSharding.scala @@ -601,7 +601,7 @@ class ShardRegion( val ageOrdering = Ordering.fromLessThan[Member] { (a, b) ⇒ a.isOlderThan(b) } var membersByAge: immutable.SortedSet[Member] = immutable.SortedSet.empty(ageOrdering) - var regions = Map.empty[ActorRef, ShardId] + var regions = Map.empty[ActorRef, Set[ShardId]] var regionByShard = Map.empty[ShardId, ActorRef] var entries = Map.empty[ActorRef, ShardId] var entriesByShard = Map.empty[ShardId, Set[ActorRef]] @@ -685,7 +685,7 @@ class ShardRegion( case _ ⇒ } regionByShard = regionByShard.updated(shard, ref) - regions = regions.updated(ref, shard) + regions = regions.updated(ref, regions.getOrElse(ref, Set.empty) + shard) if (ref != self) context.watch(ref) shardBuffers.get(shard) match { @@ -705,7 +705,10 @@ class ShardRegion( case BeginHandOff(shard) ⇒ log.debug("BeginHandOff shard [{}]", shard) if (regionByShard.contains(shard)) { - regions -= regionByShard(shard) + val regionRef = regionByShard(shard) + val updatedShards = regions(regionRef) - shard + if (updatedShards.isEmpty) regions -= regionRef + else regions = regions.updated(regionRef, updatedShards) regionByShard -= shard } sender() ! BeginHandOffAck(shard) @@ -745,9 +748,11 @@ class ShardRegion( if (coordinator.exists(_ == ref)) coordinator = None else if (regions.contains(ref)) { - val shard = regions(ref) - regionByShard -= shard + val shards = regions(ref) + regionByShard --= shards regions -= ref + if (log.isDebugEnabled) + log.debug("Region [{}] with shards [{}] terminated", ref, shards.mkString(", ")) } else if (entries.contains(ref)) { val shard = entries(ref) val newShardEntities = entriesByShard(shard) - ref @@ -1220,13 +1225,14 @@ class ShardCoordinator(handOffTimeout: FiniteDuration, rebalanceInterval: Finite } case RebalanceTick ⇒ - allocationStrategy.rebalance(persistentState.regions, rebalanceInProgress).foreach { shard ⇒ - rebalanceInProgress += shard - val rebalanceFromRegion = persistentState.shards(shard) - log.debug("Rebalance shard [{}] from [{}]", shard, rebalanceFromRegion) - context.actorOf(Props(classOf[RebalanceWorker], shard, rebalanceFromRegion, handOffTimeout, - persistentState.regions.keySet ++ persistentState.regionProxies)) - } + if (persistentState.regions.nonEmpty) + allocationStrategy.rebalance(persistentState.regions, rebalanceInProgress).foreach { shard ⇒ + rebalanceInProgress += shard + val rebalanceFromRegion = persistentState.shards(shard) + log.debug("Rebalance shard [{}] from [{}]", shard, rebalanceFromRegion) + context.actorOf(Props(classOf[RebalanceWorker], shard, rebalanceFromRegion, handOffTimeout, + persistentState.regions.keySet ++ persistentState.regionProxies)) + } case RebalanceDone(shard, ok) ⇒ rebalanceInProgress -= shard diff --git a/akka-contrib/src/multi-jvm/scala/akka/contrib/pattern/ClusterShardingSpec.scala b/akka-contrib/src/multi-jvm/scala/akka/contrib/pattern/ClusterShardingSpec.scala index 151c3dc495..f3fc476eeb 100644 --- a/akka-contrib/src/multi-jvm/scala/akka/contrib/pattern/ClusterShardingSpec.scala +++ b/akka-contrib/src/multi-jvm/scala/akka/contrib/pattern/ClusterShardingSpec.scala @@ -112,8 +112,8 @@ object ClusterShardingSpec extends MultiNodeConfig { } val shardResolver: ShardRegion.ShardResolver = msg ⇒ msg match { - case EntryEnvelope(id, _) ⇒ (id % 10).toString - case Get(id) ⇒ (id % 10).toString + case EntryEnvelope(id, _) ⇒ (id % 12).toString + case Get(id) ⇒ (id % 12).toString } //#counter-extractor @@ -222,6 +222,13 @@ class ClusterShardingSpec extends MultiNodeSpec(ClusterShardingSpec) with STMult region ! EntryEnvelope(2, Decrement) region ! Get(2) expectMsg(2) + + region ! EntryEnvelope(11, Increment) + region ! EntryEnvelope(12, Increment) + region ! Get(11) + expectMsg(1) + region ! Get(12) + expectMsg(1) } enterBarrier("second-update") runOn(first) { @@ -229,6 +236,14 @@ class ClusterShardingSpec extends MultiNodeSpec(ClusterShardingSpec) with STMult region ! Get(2) expectMsg(3) lastSender.path should be(node(second) / "user" / "counterRegion" / "2") + + region ! Get(11) + expectMsg(1) + // local on first + lastSender.path should be(region.path / "11") + region ! Get(12) + expectMsg(1) + lastSender.path should be(node(second) / "user" / "counterRegion" / "12") } enterBarrier("first-update") @@ -267,12 +282,20 @@ class ClusterShardingSpec extends MultiNodeSpec(ClusterShardingSpec) with STMult enterBarrier("crash-second") runOn(first) { - val probe = TestProbe() + val probe1 = TestProbe() awaitAssert { within(1.second) { - region.tell(Get(2), probe.ref) - probe.expectMsg(4) - probe.lastSender.path should be(region.path / "2") + region.tell(Get(2), probe1.ref) + probe1.expectMsg(4) + probe1.lastSender.path should be(region.path / "2") + } + } + val probe2 = TestProbe() + awaitAssert { + within(1.second) { + region.tell(Get(12), probe2.ref) + probe2.expectMsg(1) + probe2.lastSender.path should be(region.path / "12") } } }