=cls #17261 Use persistent shard only when rememberEntities=true
This commit is contained in:
parent
d02b003628
commit
985510d6ac
1 changed files with 120 additions and 62 deletions
|
|
@ -12,6 +12,8 @@ import akka.actor.Terminated
|
|||
import akka.cluster.sharding.Shard.ShardCommand
|
||||
import akka.persistence.PersistentActor
|
||||
import akka.persistence.SnapshotOffer
|
||||
import akka.actor.Actor
|
||||
import akka.persistence.RecoveryCompleted
|
||||
|
||||
/**
|
||||
* INTERNAL API
|
||||
|
|
@ -58,6 +60,8 @@ private[akka] object Shard {
|
|||
|
||||
/**
|
||||
* Factory method for the [[akka.actor.Props]] of the [[Shard]] actor.
|
||||
* If `settings.rememberEntities` is enabled the `PersistentShard`
|
||||
* subclass is used, otherwise `Shard`.
|
||||
*/
|
||||
def props(typeName: String,
|
||||
shardId: ShardRegion.ShardId,
|
||||
|
|
@ -65,9 +69,14 @@ private[akka] object Shard {
|
|||
settings: ClusterShardingSettings,
|
||||
extractEntityId: ShardRegion.ExtractEntityId,
|
||||
extractShardId: ShardRegion.ExtractShardId,
|
||||
handOffStopMessage: Any): Props =
|
||||
Props(new Shard(typeName, shardId, entityProps, settings, extractEntityId, extractShardId, handOffStopMessage))
|
||||
.withDeploy(Deploy.local)
|
||||
handOffStopMessage: Any): Props = {
|
||||
if (settings.rememberEntities)
|
||||
Props(new PersistentShard(typeName, shardId, entityProps, settings, extractEntityId, extractShardId, handOffStopMessage))
|
||||
.withDeploy(Deploy.local)
|
||||
else
|
||||
Props(new Shard(typeName, shardId, entityProps, settings, extractEntityId, extractShardId, handOffStopMessage))
|
||||
.withDeploy(Deploy.local)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -85,57 +94,31 @@ private[akka] class Shard(
|
|||
settings: ClusterShardingSettings,
|
||||
extractEntityId: ShardRegion.ExtractEntityId,
|
||||
extractShardId: ShardRegion.ExtractShardId,
|
||||
handOffStopMessage: Any) extends PersistentActor with ActorLogging {
|
||||
handOffStopMessage: Any) extends Actor with ActorLogging {
|
||||
|
||||
import ShardRegion.{ handOffStopperProps, EntityId, Msg, Passivate }
|
||||
import ShardCoordinator.Internal.{ HandOff, ShardStopped }
|
||||
import Shard.{ State, RestartEntity, EntityStopped, EntityStarted }
|
||||
import akka.cluster.sharding.ShardCoordinator.Internal.CoordinatorMessage
|
||||
import akka.cluster.sharding.ShardRegion.ShardRegionCommand
|
||||
import akka.persistence.RecoveryCompleted
|
||||
|
||||
import settings.rememberEntities
|
||||
import settings.tuningParameters._
|
||||
|
||||
override def persistenceId = s"/sharding/${typeName}Shard/${shardId}"
|
||||
|
||||
override def journalPluginId: String = settings.journalPluginId
|
||||
|
||||
override def snapshotPluginId: String = settings.snapshotPluginId
|
||||
|
||||
var state = State.Empty
|
||||
var idByRef = Map.empty[ActorRef, EntityId]
|
||||
var refById = Map.empty[EntityId, ActorRef]
|
||||
var passivating = Set.empty[ActorRef]
|
||||
var messageBuffers = Map.empty[EntityId, Vector[(Msg, ActorRef)]]
|
||||
var persistCount = 0
|
||||
|
||||
var handOffStopper: Option[ActorRef] = None
|
||||
|
||||
def totalBufferSize = messageBuffers.foldLeft(0) { (sum, entity) ⇒ sum + entity._2.size }
|
||||
|
||||
def processChange[A](event: A)(handler: A ⇒ Unit): Unit =
|
||||
if (rememberEntities) {
|
||||
saveSnapshotWhenNeeded()
|
||||
persist(event)(handler)
|
||||
} else handler(event)
|
||||
handler(event)
|
||||
|
||||
def saveSnapshotWhenNeeded(): Unit = {
|
||||
persistCount += 1
|
||||
if (persistCount % snapshotAfter == 0) {
|
||||
log.debug("Saving snapshot, sequence number [{}]", snapshotSequenceNr)
|
||||
saveSnapshot(state)
|
||||
}
|
||||
}
|
||||
def receive = receiveCommand
|
||||
|
||||
override def receiveRecover: Receive = {
|
||||
case EntityStarted(id) if rememberEntities ⇒ state = state.copy(state.entities + id)
|
||||
case EntityStopped(id) if rememberEntities ⇒ state = state.copy(state.entities - id)
|
||||
case SnapshotOffer(_, snapshot: State) ⇒ state = snapshot
|
||||
case RecoveryCompleted ⇒ state.entities foreach getEntity
|
||||
}
|
||||
|
||||
override def receiveCommand: Receive = {
|
||||
def receiveCommand: Receive = {
|
||||
case Terminated(ref) ⇒ receiveTerminated(ref)
|
||||
case msg: CoordinatorMessage ⇒ receiveCoordinatorMessage(msg)
|
||||
case msg: ShardCommand ⇒ receiveShardCommand(msg)
|
||||
|
|
@ -178,25 +161,22 @@ private[akka] class Shard(
|
|||
}
|
||||
|
||||
def receiveTerminated(ref: ActorRef): Unit = {
|
||||
if (handOffStopper.exists(_ == ref)) {
|
||||
if (handOffStopper.exists(_ == ref))
|
||||
context stop self
|
||||
} else if (idByRef.contains(ref) && handOffStopper.isEmpty) {
|
||||
val id = idByRef(ref)
|
||||
if (messageBuffers.getOrElse(id, Vector.empty).nonEmpty) {
|
||||
//Note; because we're not persisting the EntityStopped, we don't need
|
||||
// to persist the EntityStarted either.
|
||||
log.debug("Starting entity [{}] again, there are buffered messages for it", id)
|
||||
sendMsgBuffer(EntityStarted(id))
|
||||
} else {
|
||||
if (rememberEntities && !passivating.contains(ref)) {
|
||||
log.debug("Entity [{}] stopped without passivating, will restart after backoff", id)
|
||||
import context.dispatcher
|
||||
context.system.scheduler.scheduleOnce(entityRestartBackoff, self, RestartEntity(id))
|
||||
} else processChange(EntityStopped(id))(passivateCompleted)
|
||||
}
|
||||
else if (idByRef.contains(ref) && handOffStopper.isEmpty)
|
||||
entityTerminated(ref)
|
||||
}
|
||||
|
||||
passivating = passivating - ref
|
||||
def entityTerminated(ref: ActorRef): Unit = {
|
||||
val id = idByRef(ref)
|
||||
if (messageBuffers.getOrElse(id, Vector.empty).nonEmpty) {
|
||||
log.debug("Starting entity [{}] again, there are buffered messages for it", id)
|
||||
sendMsgBuffer(EntityStarted(id))
|
||||
} else {
|
||||
processChange(EntityStopped(id))(passivateCompleted)
|
||||
}
|
||||
|
||||
passivating = passivating - ref
|
||||
}
|
||||
|
||||
def passivate(entity: ActorRef, stopMessage: Any): Unit = {
|
||||
|
|
@ -212,7 +192,7 @@ private[akka] class Shard(
|
|||
}
|
||||
}
|
||||
|
||||
// EntityStopped persistence handler
|
||||
// EntityStopped handler
|
||||
def passivateCompleted(event: EntityStopped): Unit = {
|
||||
log.debug("Entity stopped [{}]", event.entityId)
|
||||
|
||||
|
|
@ -224,7 +204,7 @@ private[akka] class Shard(
|
|||
messageBuffers = messageBuffers - event.entityId
|
||||
}
|
||||
|
||||
// EntityStarted persistence handler
|
||||
// EntityStarted handler
|
||||
def sendMsgBuffer(event: EntityStarted): Unit = {
|
||||
//Get the buffered messages and remove the buffer
|
||||
val messages = messageBuffers.getOrElse(event.entityId, Vector.empty)
|
||||
|
|
@ -265,17 +245,8 @@ private[akka] class Shard(
|
|||
def deliverTo(id: EntityId, msg: Any, payload: Msg, snd: ActorRef): Unit = {
|
||||
val name = URLEncoder.encode(id, "utf-8")
|
||||
context.child(name) match {
|
||||
case Some(actor) ⇒
|
||||
actor.tell(payload, snd)
|
||||
|
||||
case None if rememberEntities ⇒
|
||||
//Note; we only do this if remembering, otherwise the buffer is an overhead
|
||||
messageBuffers = messageBuffers.updated(id, Vector((msg, snd)))
|
||||
saveSnapshotWhenNeeded()
|
||||
persist(EntityStarted(id))(sendMsgBuffer)
|
||||
|
||||
case None ⇒
|
||||
getEntity(id).tell(payload, snd)
|
||||
case Some(actor) ⇒ actor.tell(payload, snd)
|
||||
case None ⇒ getEntity(id).tell(payload, snd)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -292,3 +263,90 @@ private[akka] class Shard(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* INTERNAL API
|
||||
*
|
||||
* This actor creates children entity actors on demand that it is told to be
|
||||
* responsible for. It is used when `rememberEntities` is enabled.
|
||||
*
|
||||
* @see [[ClusterSharding$ ClusterSharding extension]]
|
||||
*/
|
||||
private[akka] class PersistentShard(
|
||||
typeName: String,
|
||||
shardId: ShardRegion.ShardId,
|
||||
entityProps: Props,
|
||||
settings: ClusterShardingSettings,
|
||||
extractEntityId: ShardRegion.ExtractEntityId,
|
||||
extractShardId: ShardRegion.ExtractShardId,
|
||||
handOffStopMessage: Any) extends Shard(
|
||||
typeName, shardId, entityProps, settings, extractEntityId, extractShardId, handOffStopMessage)
|
||||
with PersistentActor with ActorLogging {
|
||||
|
||||
import ShardRegion.{ EntityId, Msg }
|
||||
import Shard.{ State, RestartEntity, EntityStopped, EntityStarted }
|
||||
import settings.tuningParameters._
|
||||
|
||||
override def persistenceId = s"/sharding/${typeName}Shard/${shardId}"
|
||||
|
||||
override def journalPluginId: String = settings.journalPluginId
|
||||
|
||||
override def snapshotPluginId: String = settings.snapshotPluginId
|
||||
|
||||
var persistCount = 0
|
||||
|
||||
override def receive = receiveCommand
|
||||
|
||||
override def processChange[A](event: A)(handler: A ⇒ Unit): Unit = {
|
||||
saveSnapshotWhenNeeded()
|
||||
persist(event)(handler)
|
||||
}
|
||||
|
||||
def saveSnapshotWhenNeeded(): Unit = {
|
||||
persistCount += 1
|
||||
if (persistCount % snapshotAfter == 0) {
|
||||
log.debug("Saving snapshot, sequence number [{}]", snapshotSequenceNr)
|
||||
saveSnapshot(state)
|
||||
}
|
||||
}
|
||||
|
||||
override def receiveRecover: Receive = {
|
||||
case EntityStarted(id) ⇒ state = state.copy(state.entities + id)
|
||||
case EntityStopped(id) ⇒ state = state.copy(state.entities - id)
|
||||
case SnapshotOffer(_, snapshot: State) ⇒ state = snapshot
|
||||
case RecoveryCompleted ⇒ state.entities foreach getEntity
|
||||
}
|
||||
|
||||
override def entityTerminated(ref: ActorRef): Unit = {
|
||||
val id = idByRef(ref)
|
||||
if (messageBuffers.getOrElse(id, Vector.empty).nonEmpty) {
|
||||
//Note; because we're not persisting the EntityStopped, we don't need
|
||||
// to persist the EntityStarted either.
|
||||
log.debug("Starting entity [{}] again, there are buffered messages for it", id)
|
||||
sendMsgBuffer(EntityStarted(id))
|
||||
} else {
|
||||
if (!passivating.contains(ref)) {
|
||||
log.debug("Entity [{}] stopped without passivating, will restart after backoff", id)
|
||||
import context.dispatcher
|
||||
context.system.scheduler.scheduleOnce(entityRestartBackoff, self, RestartEntity(id))
|
||||
} else processChange(EntityStopped(id))(passivateCompleted)
|
||||
}
|
||||
|
||||
passivating = passivating - ref
|
||||
}
|
||||
|
||||
override def deliverTo(id: EntityId, msg: Any, payload: Msg, snd: ActorRef): Unit = {
|
||||
val name = URLEncoder.encode(id, "utf-8")
|
||||
context.child(name) match {
|
||||
case Some(actor) ⇒
|
||||
actor.tell(payload, snd)
|
||||
|
||||
case None ⇒
|
||||
//Note; we only do this if remembering, otherwise the buffer is an overhead
|
||||
messageBuffers = messageBuffers.updated(id, Vector((msg, snd)))
|
||||
saveSnapshotWhenNeeded()
|
||||
persist(EntityStarted(id))(sendMsgBuffer)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue