diff --git a/akka-contrib/src/main/scala/akka/contrib/pattern/ClusterSharding.scala b/akka-contrib/src/main/scala/akka/contrib/pattern/ClusterSharding.scala index 765f3a654c..86e5705814 100644 --- a/akka-contrib/src/main/scala/akka/contrib/pattern/ClusterSharding.scala +++ b/akka-contrib/src/main/scala/akka/contrib/pattern/ClusterSharding.scala @@ -601,7 +601,7 @@ class ShardRegion( val ageOrdering = Ordering.fromLessThan[Member] { (a, b) ⇒ a.isOlderThan(b) } var membersByAge: immutable.SortedSet[Member] = immutable.SortedSet.empty(ageOrdering) - var regions = Map.empty[ActorRef, ShardId] + var regions = Map.empty[ActorRef, Set[ShardId]] var regionByShard = Map.empty[ShardId, ActorRef] var entries = Map.empty[ActorRef, ShardId] var entriesByShard = Map.empty[ShardId, Set[ActorRef]] @@ -685,7 +685,7 @@ class ShardRegion( case _ ⇒ } regionByShard = regionByShard.updated(shard, ref) - regions = regions.updated(ref, shard) + regions = regions.updated(ref, regions.getOrElse(ref, Set.empty) + shard) if (ref != self) context.watch(ref) shardBuffers.get(shard) match { @@ -705,7 +705,10 @@ class ShardRegion( case BeginHandOff(shard) ⇒ log.debug("BeginHandOff shard [{}]", shard) if (regionByShard.contains(shard)) { - regions -= regionByShard(shard) + val regionRef = regionByShard(shard) + val updatedShards = regions(regionRef) - shard + if (updatedShards.isEmpty) regions -= regionRef + else regions = regions.updated(regionRef, updatedShards) regionByShard -= shard } sender() ! BeginHandOffAck(shard) @@ -745,9 +748,11 @@ class ShardRegion( if (coordinator.exists(_ == ref)) coordinator = None else if (regions.contains(ref)) { - val shard = regions(ref) - regionByShard -= shard + val shards = regions(ref) + regionByShard --= shards regions -= ref + if (log.isDebugEnabled) + log.debug("Region [{}] with shards [{}] terminated", ref, shards.mkString(", ")) } else if (entries.contains(ref)) { val shard = entries(ref) val newShardEntities = entriesByShard(shard) - ref @@ -1220,13 +1225,14 @@ class ShardCoordinator(handOffTimeout: FiniteDuration, rebalanceInterval: Finite } case RebalanceTick ⇒ - allocationStrategy.rebalance(persistentState.regions, rebalanceInProgress).foreach { shard ⇒ - rebalanceInProgress += shard - val rebalanceFromRegion = persistentState.shards(shard) - log.debug("Rebalance shard [{}] from [{}]", shard, rebalanceFromRegion) - context.actorOf(Props(classOf[RebalanceWorker], shard, rebalanceFromRegion, handOffTimeout, - persistentState.regions.keySet ++ persistentState.regionProxies)) - } + if (persistentState.regions.nonEmpty) + allocationStrategy.rebalance(persistentState.regions, rebalanceInProgress).foreach { shard ⇒ + rebalanceInProgress += shard + val rebalanceFromRegion = persistentState.shards(shard) + log.debug("Rebalance shard [{}] from [{}]", shard, rebalanceFromRegion) + context.actorOf(Props(classOf[RebalanceWorker], shard, rebalanceFromRegion, handOffTimeout, + persistentState.regions.keySet ++ persistentState.regionProxies)) + } case RebalanceDone(shard, ok) ⇒ rebalanceInProgress -= shard diff --git a/akka-contrib/src/multi-jvm/scala/akka/contrib/pattern/ClusterShardingSpec.scala b/akka-contrib/src/multi-jvm/scala/akka/contrib/pattern/ClusterShardingSpec.scala index 151c3dc495..f3fc476eeb 100644 --- a/akka-contrib/src/multi-jvm/scala/akka/contrib/pattern/ClusterShardingSpec.scala +++ b/akka-contrib/src/multi-jvm/scala/akka/contrib/pattern/ClusterShardingSpec.scala @@ -112,8 +112,8 @@ object ClusterShardingSpec extends MultiNodeConfig { } val shardResolver: ShardRegion.ShardResolver = msg ⇒ msg match { - case EntryEnvelope(id, _) ⇒ (id % 10).toString - case Get(id) ⇒ (id % 10).toString + case EntryEnvelope(id, _) ⇒ (id % 12).toString + case Get(id) ⇒ (id % 12).toString } //#counter-extractor @@ -222,6 +222,13 @@ class ClusterShardingSpec extends MultiNodeSpec(ClusterShardingSpec) with STMult region ! EntryEnvelope(2, Decrement) region ! Get(2) expectMsg(2) + + region ! EntryEnvelope(11, Increment) + region ! EntryEnvelope(12, Increment) + region ! Get(11) + expectMsg(1) + region ! Get(12) + expectMsg(1) } enterBarrier("second-update") runOn(first) { @@ -229,6 +236,14 @@ class ClusterShardingSpec extends MultiNodeSpec(ClusterShardingSpec) with STMult region ! Get(2) expectMsg(3) lastSender.path should be(node(second) / "user" / "counterRegion" / "2") + + region ! Get(11) + expectMsg(1) + // local on first + lastSender.path should be(region.path / "11") + region ! Get(12) + expectMsg(1) + lastSender.path should be(node(second) / "user" / "counterRegion" / "12") } enterBarrier("first-update") @@ -267,12 +282,20 @@ class ClusterShardingSpec extends MultiNodeSpec(ClusterShardingSpec) with STMult enterBarrier("crash-second") runOn(first) { - val probe = TestProbe() + val probe1 = TestProbe() awaitAssert { within(1.second) { - region.tell(Get(2), probe.ref) - probe.expectMsg(4) - probe.lastSender.path should be(region.path / "2") + region.tell(Get(2), probe1.ref) + probe1.expectMsg(4) + probe1.lastSender.path should be(region.path / "2") + } + } + val probe2 = TestProbe() + awaitAssert { + within(1.second) { + region.tell(Get(12), probe2.ref) + probe2.expectMsg(1) + probe2.lastSender.path should be(region.path / "12") } } }