chore: avoid the double evaluation of entityId in ClusterSharding (#1304)

* chore: avoid the double evaluation of entityId in ClusterSharding

* new cacheable partial function

* optimized for review

* fix the right type
This commit is contained in:
AndyChen(Jingzhang) 2024-06-05 23:23:33 +08:00 committed by GitHub
parent 67211737da
commit b0e9886439
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 33 additions and 13 deletions

View file

@ -20,6 +20,7 @@ import java.util.concurrent.CompletionStage
import java.util.concurrent.ConcurrentHashMap
import scala.concurrent.Future
import scala.runtime.AbstractPartialFunction
import org.apache.pekko
import pekko.actor.ActorRefProvider
@ -172,10 +173,16 @@ import pekko.util.JavaDurationConverters._
allocationStrategy: Option[ShardAllocationStrategy]): ActorRef[E] = {
val extractorAdapter = new ExtractorAdapter(extractor)
val extractEntityId: ShardRegion.ExtractEntityId = {
// TODO is it possible to avoid the double evaluation of entityId
case message if extractorAdapter.entityId(message) != null =>
(extractorAdapter.entityId(message), extractorAdapter.unwrapMessage(message))
// !!!important is only applicable if you know that isDefinedAt(x) is always called before apply(x) (with the same x)
val extractEntityId: ShardRegion.ExtractEntityId = new AbstractPartialFunction[Any, (String, Any)] {
var cache: String = _
override def isDefinedAt(msg: Any): Boolean = {
cache = extractorAdapter.entityId(msg)
cache != null
}
override def apply(x: Any): (String, Any) = (cache, extractorAdapter.unwrapMessage(x))
}
val extractShardId: ShardRegion.ExtractShardId = { message =>
extractorAdapter.entityId(message) match {

View file

@ -19,6 +19,7 @@ import java.util.concurrent.ConcurrentHashMap
import scala.collection.immutable
import scala.concurrent.Await
import scala.runtime.AbstractPartialFunction
import scala.util.control.NonFatal
import org.apache.pekko
@ -429,15 +430,26 @@ class ClusterSharding(system: ExtendedActorSystem) extends Extension {
typeName,
_ => entityProps,
settings,
extractEntityId = {
case msg if messageExtractor.entityId(msg) ne null =>
(messageExtractor.entityId(msg), messageExtractor.entityMessage(msg))
},
extractEntityId = extractEntityIdFromExtractor(messageExtractor),
extractShardId = msg => messageExtractor.shardId(msg),
allocationStrategy = allocationStrategy,
handOffStopMessage = handOffStopMessage)
}
// !!!important is only applicable if you know that isDefinedAt(x) is always called before apply(x) (with the same x)
private def extractEntityIdFromExtractor(
messageExtractor: ShardRegion.MessageExtractor): ShardRegion.ExtractEntityId =
new AbstractPartialFunction[Any, (String, Any)] {
var cache: String = _
override def isDefinedAt(msg: Any): Boolean = {
cache = messageExtractor.entityId(msg)
cache != null
}
override def apply(x: Any): (String, Any) = (cache, messageExtractor.entityMessage(x))
}
/**
* Java/Scala API: Register a named entity type by defining the [[pekko.actor.Props]] of the entity actor
* and functions to extract entity and shard identifier from messages. The [[ShardRegion]] actor
@ -612,11 +624,12 @@ class ClusterSharding(system: ExtendedActorSystem) extends Extension {
dataCenter: Optional[String],
messageExtractor: ShardRegion.MessageExtractor): ActorRef = {
startProxy(typeName, Option(role.orElse(null)), Option(dataCenter.orElse(null)),
extractEntityId = {
case msg if messageExtractor.entityId(msg) ne null =>
(messageExtractor.entityId(msg), messageExtractor.entityMessage(msg))
}, extractShardId = msg => messageExtractor.shardId(msg))
startProxy(
typeName,
Option(role.orElse(null)),
Option(dataCenter.orElse(null)),
extractEntityId = extractEntityIdFromExtractor(messageExtractor),
msg => messageExtractor.shardId(msg))
}