=con #3880 Keep track of all shards per region in ClusterSharding
* The problem was that the ShardRegion actor only kept track of one shard id per region actor, so a Terminated message removed only one of the shards from its registry when a region hosted several shards.
* Added a failing test and solved the problem by keeping track of all shards per region.
* Also, rebalance must not be performed before any regions have been registered.
parent a3b8f51079
commit 21e8f89f53
2 changed files with 47 additions and 18 deletions
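A minimal sketch of the bookkeeping change described above, in plain Scala with String stand-ins for ActorRef so it runs standalone; the names ShardRegistrySketch, registerShard, and regionTerminated are illustrative, not Akka's API:

```scala
// Simplified model of the ShardRegion registry fix, assuming String ids instead of ActorRef.
object ShardRegistrySketch {
  type ShardId = String

  // Before the fix: Map[ActorRef, ShardId] kept only one shard per region.
  // After the fix: each region maps to the full set of shards it hosts.
  var regions = Map.empty[String, Set[ShardId]]   // region ref -> hosted shards
  var regionByShard = Map.empty[ShardId, String]  // shard -> region ref

  def registerShard(shard: ShardId, region: String): Unit = {
    regionByShard = regionByShard.updated(shard, region)
    regions = regions.updated(region, regions.getOrElse(region, Set.empty) + shard)
  }

  def regionTerminated(region: String): Unit =
    regions.get(region).foreach { shards =>
      regionByShard --= shards // remove every shard hosted by the dead region
      regions -= region
    }

  def main(args: Array[String]): Unit = {
    registerShard("11", "regionA")
    registerShard("12", "regionA")
    regionTerminated("regionA")
    assert(regionByShard.isEmpty && regions.isEmpty) // both shards cleaned up
  }
}
```

Tracking the full Set[ShardId] per region is what lets the Terminated handler remove every shard hosted by a dead region in one pass, instead of leaving stale regionByShard entries behind.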
@@ -601,7 +601,7 @@ class ShardRegion(
   val ageOrdering = Ordering.fromLessThan[Member] { (a, b) ⇒ a.isOlderThan(b) }
   var membersByAge: immutable.SortedSet[Member] = immutable.SortedSet.empty(ageOrdering)

-  var regions = Map.empty[ActorRef, ShardId]
+  var regions = Map.empty[ActorRef, Set[ShardId]]
   var regionByShard = Map.empty[ShardId, ActorRef]
   var entries = Map.empty[ActorRef, ShardId]
   var entriesByShard = Map.empty[ShardId, Set[ActorRef]]
@@ -685,7 +685,7 @@ class ShardRegion(
         case _ ⇒
       }
       regionByShard = regionByShard.updated(shard, ref)
-      regions = regions.updated(ref, shard)
+      regions = regions.updated(ref, regions.getOrElse(ref, Set.empty) + shard)
       if (ref != self)
         context.watch(ref)
       shardBuffers.get(shard) match {
@@ -705,7 +705,10 @@ class ShardRegion(
     case BeginHandOff(shard) ⇒
       log.debug("BeginHandOff shard [{}]", shard)
       if (regionByShard.contains(shard)) {
-        regions -= regionByShard(shard)
+        val regionRef = regionByShard(shard)
+        val updatedShards = regions(regionRef) - shard
+        if (updatedShards.isEmpty) regions -= regionRef
+        else regions = regions.updated(regionRef, updatedShards)
         regionByShard -= shard
       }
       sender() ! BeginHandOffAck(shard)
@@ -745,9 +748,11 @@ class ShardRegion(
       if (coordinator.exists(_ == ref))
         coordinator = None
       else if (regions.contains(ref)) {
-        val shard = regions(ref)
-        regionByShard -= shard
+        val shards = regions(ref)
+        regionByShard --= shards
         regions -= ref
+        if (log.isDebugEnabled)
+          log.debug("Region [{}] with shards [{}] terminated", ref, shards.mkString(", "))
       } else if (entries.contains(ref)) {
         val shard = entries(ref)
         val newShardEntities = entriesByShard(shard) - ref
@@ -1220,13 +1225,14 @@ class ShardCoordinator(handOffTimeout: FiniteDuration, rebalanceInterval: Finite
       }

     case RebalanceTick ⇒
-      allocationStrategy.rebalance(persistentState.regions, rebalanceInProgress).foreach { shard ⇒
-        rebalanceInProgress += shard
-        val rebalanceFromRegion = persistentState.shards(shard)
-        log.debug("Rebalance shard [{}] from [{}]", shard, rebalanceFromRegion)
-        context.actorOf(Props(classOf[RebalanceWorker], shard, rebalanceFromRegion, handOffTimeout,
-          persistentState.regions.keySet ++ persistentState.regionProxies))
-      }
+      if (persistentState.regions.nonEmpty)
+        allocationStrategy.rebalance(persistentState.regions, rebalanceInProgress).foreach { shard ⇒
+          rebalanceInProgress += shard
+          val rebalanceFromRegion = persistentState.shards(shard)
+          log.debug("Rebalance shard [{}] from [{}]", shard, rebalanceFromRegion)
+          context.actorOf(Props(classOf[RebalanceWorker], shard, rebalanceFromRegion, handOffTimeout,
+            persistentState.regions.keySet ++ persistentState.regionProxies))
+        }

     case RebalanceDone(shard, ok) ⇒
       rebalanceInProgress -= shard
@@ -112,8 +112,8 @@ object ClusterShardingSpec extends MultiNodeConfig {
   }

   val shardResolver: ShardRegion.ShardResolver = msg ⇒ msg match {
-    case EntryEnvelope(id, _) ⇒ (id % 10).toString
-    case Get(id) ⇒ (id % 10).toString
+    case EntryEnvelope(id, _) ⇒ (id % 12).toString
+    case Get(id) ⇒ (id % 12).toString
   }
   //#counter-extractor
@@ -222,6 +222,13 @@ class ClusterShardingSpec extends MultiNodeSpec(ClusterShardingSpec) with STMult
         region ! EntryEnvelope(2, Decrement)
         region ! Get(2)
         expectMsg(2)
+
+        region ! EntryEnvelope(11, Increment)
+        region ! EntryEnvelope(12, Increment)
+        region ! Get(11)
+        expectMsg(1)
+        region ! Get(12)
+        expectMsg(1)
       }
       enterBarrier("second-update")
       runOn(first) {
@@ -229,6 +236,14 @@ class ClusterShardingSpec extends MultiNodeSpec(ClusterShardingSpec) with STMult
         region ! Get(2)
         expectMsg(3)
         lastSender.path should be(node(second) / "user" / "counterRegion" / "2")
+
+        region ! Get(11)
+        expectMsg(1)
+        // local on first
+        lastSender.path should be(region.path / "11")
+        region ! Get(12)
+        expectMsg(1)
+        lastSender.path should be(node(second) / "user" / "counterRegion" / "12")
       }
       enterBarrier("first-update")

@@ -267,12 +282,20 @@ class ClusterShardingSpec extends MultiNodeSpec(ClusterShardingSpec) with STMult
       enterBarrier("crash-second")

       runOn(first) {
-        val probe = TestProbe()
+        val probe1 = TestProbe()
         awaitAssert {
           within(1.second) {
-            region.tell(Get(2), probe.ref)
-            probe.expectMsg(4)
-            probe.lastSender.path should be(region.path / "2")
+            region.tell(Get(2), probe1.ref)
+            probe1.expectMsg(4)
+            probe1.lastSender.path should be(region.path / "2")
           }
         }
+        val probe2 = TestProbe()
+        awaitAssert {
+          within(1.second) {
+            region.tell(Get(12), probe2.ref)
+            probe2.expectMsg(1)
+            probe2.lastSender.path should be(region.path / "12")
+          }
+        }
       }