Detect illegal access to context (#29431)
This commit is contained in:
parent
f63ca66e56
commit
e4f5781d65
5 changed files with 187 additions and 16 deletions
|
|
@ -0,0 +1,101 @@
|
||||||
|
/*
|
||||||
|
* Copyright (C) 2020 Lightbend Inc. <https://www.lightbend.com>
|
||||||
|
*/
|
||||||
|
|
||||||
|
package akka.persistence.typed
|
||||||
|
|
||||||
|
import akka.actor.testkit.typed.scaladsl.{ LogCapturing, ScalaTestWithActorTestKit }
|
||||||
|
import akka.actor.typed.{ ActorRef, Behavior }
|
||||||
|
import akka.persistence.testkit.PersistenceTestKitPlugin
|
||||||
|
import akka.persistence.testkit.query.scaladsl.PersistenceTestKitReadJournal
|
||||||
|
import akka.persistence.typed.scaladsl.{ ActiveActiveEventSourcing, Effect, EventSourcedBehavior }
|
||||||
|
import akka.serialization.jackson.CborSerializable
|
||||||
|
import org.scalatest.concurrent.Eventually
|
||||||
|
import org.scalatest.wordspec.AnyWordSpecLike
|
||||||
|
|
||||||
|
object ActiveActiveIllegalAccessSpec {
|
||||||
|
|
||||||
|
val R1 = ReplicaId("R1")
|
||||||
|
val R2 = ReplicaId("R1")
|
||||||
|
val AllReplicas = Set(R1, R2)
|
||||||
|
|
||||||
|
sealed trait Command
|
||||||
|
case class AccessInCommandHandler(replyTo: ActorRef[Thrown]) extends Command
|
||||||
|
case class AccessInPersistCallback(replyTo: ActorRef[Thrown]) extends Command
|
||||||
|
|
||||||
|
case class Thrown(exception: Option[Throwable])
|
||||||
|
|
||||||
|
case class State(all: List[String]) extends CborSerializable
|
||||||
|
|
||||||
|
def apply(entityId: String, replica: ReplicaId): Behavior[Command] = {
|
||||||
|
ActiveActiveEventSourcing.withSharedJournal(
|
||||||
|
entityId,
|
||||||
|
replica,
|
||||||
|
AllReplicas,
|
||||||
|
PersistenceTestKitReadJournal.Identifier)(
|
||||||
|
aaContext =>
|
||||||
|
EventSourcedBehavior[Command, String, State](
|
||||||
|
aaContext.persistenceId,
|
||||||
|
State(Nil),
|
||||||
|
(_, command) =>
|
||||||
|
command match {
|
||||||
|
case AccessInCommandHandler(replyTo) =>
|
||||||
|
val exception = try {
|
||||||
|
aaContext.origin
|
||||||
|
None
|
||||||
|
} catch {
|
||||||
|
case t: Throwable =>
|
||||||
|
Some(t)
|
||||||
|
}
|
||||||
|
replyTo ! Thrown(exception)
|
||||||
|
Effect.none
|
||||||
|
case AccessInPersistCallback(replyTo) =>
|
||||||
|
Effect.persist("cat").thenRun { _ =>
|
||||||
|
val exception = try {
|
||||||
|
aaContext.concurrent
|
||||||
|
None
|
||||||
|
} catch {
|
||||||
|
case t: Throwable =>
|
||||||
|
Some(t)
|
||||||
|
}
|
||||||
|
replyTo ! Thrown(exception)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
(state, event) => state.copy(all = event :: state.all)))
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
class ActiveActiveIllegalAccessSpec
|
||||||
|
extends ScalaTestWithActorTestKit(PersistenceTestKitPlugin.config)
|
||||||
|
with AnyWordSpecLike
|
||||||
|
with LogCapturing
|
||||||
|
with Eventually {
|
||||||
|
import ActiveActiveIllegalAccessSpec._
|
||||||
|
"ActiveActive" should {
|
||||||
|
"detect illegal access to context in command handler" in {
|
||||||
|
val probe = createTestProbe[Thrown]()
|
||||||
|
val ref = spawn(ActiveActiveIllegalAccessSpec("id1", R1))
|
||||||
|
ref ! AccessInCommandHandler(probe.ref)
|
||||||
|
val thrown: Throwable = probe.expectMessageType[Thrown].exception.get
|
||||||
|
thrown.getMessage should include("from the event handler")
|
||||||
|
}
|
||||||
|
"detect illegal access to context in persist thenRun" in {
|
||||||
|
val probe = createTestProbe[Thrown]()
|
||||||
|
val ref = spawn(ActiveActiveIllegalAccessSpec("id1", R1))
|
||||||
|
ref ! AccessInPersistCallback(probe.ref)
|
||||||
|
val thrown: Throwable = probe.expectMessageType[Thrown].exception.get
|
||||||
|
thrown.getMessage should include("from the event handler")
|
||||||
|
}
|
||||||
|
"detect illegal access in the factory" in {
|
||||||
|
val exception = intercept[UnsupportedOperationException] {
|
||||||
|
ActiveActiveEventSourcing.withSharedJournal("id2", R1, AllReplicas, PersistenceTestKitReadJournal.Identifier) {
|
||||||
|
aaContext =>
|
||||||
|
aaContext.origin
|
||||||
|
???
|
||||||
|
}
|
||||||
|
}
|
||||||
|
exception.getMessage should include("from the event handler")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -23,6 +23,7 @@ import akka.persistence.typed.SingleEventSeq
|
||||||
import akka.persistence.typed.internal.EventSourcedBehaviorImpl.GetState
|
import akka.persistence.typed.internal.EventSourcedBehaviorImpl.GetState
|
||||||
import akka.persistence.typed.internal.ReplayingEvents.ReplayingState
|
import akka.persistence.typed.internal.ReplayingEvents.ReplayingState
|
||||||
import akka.persistence.typed.internal.Running.WithSeqNrAccessible
|
import akka.persistence.typed.internal.Running.WithSeqNrAccessible
|
||||||
|
import akka.persistence.typed.scaladsl.EventSourcedBehavior.ActiveActive
|
||||||
import akka.util.OptionVal
|
import akka.util.OptionVal
|
||||||
import akka.util.PrettyDuration._
|
import akka.util.PrettyDuration._
|
||||||
import akka.util.unused
|
import akka.util.unused
|
||||||
|
|
@ -122,7 +123,7 @@ private[akka] final class ReplayingEvents[C, E, S](
|
||||||
eventForErrorReporting = OptionVal.Some(event)
|
eventForErrorReporting = OptionVal.Some(event)
|
||||||
state = state.copy(seqNr = repr.sequenceNr)
|
state = state.copy(seqNr = repr.sequenceNr)
|
||||||
|
|
||||||
val aaMetaAndSelfReplica: Option[(ReplicatedEventMetaData, ReplicaId)] =
|
val aaMetaAndSelfReplica: Option[(ReplicatedEventMetaData, ReplicaId, ActiveActive)] =
|
||||||
setup.activeActive match {
|
setup.activeActive match {
|
||||||
case Some(aa) =>
|
case Some(aa) =>
|
||||||
val meta = repr.metadata match {
|
val meta = repr.metadata match {
|
||||||
|
|
@ -133,19 +134,30 @@ private[akka] final class ReplayingEvents[C, E, S](
|
||||||
|
|
||||||
}
|
}
|
||||||
aa.setContext(recoveryRunning = true, meta.originReplica, meta.concurrent)
|
aa.setContext(recoveryRunning = true, meta.originReplica, meta.concurrent)
|
||||||
Some(meta -> aa.replicaId)
|
Some((meta, aa.replicaId, aa))
|
||||||
case None => None
|
case None => None
|
||||||
}
|
}
|
||||||
|
|
||||||
val newState = setup.eventHandler(state.state, event)
|
val newState = setup.eventHandler(state.state, event)
|
||||||
|
|
||||||
|
setup.activeActive match {
|
||||||
|
case Some(aa) =>
|
||||||
|
aa.clearContext()
|
||||||
|
case None =>
|
||||||
|
}
|
||||||
|
|
||||||
aaMetaAndSelfReplica match {
|
aaMetaAndSelfReplica match {
|
||||||
case Some((meta, selfReplica)) if meta.originReplica != selfReplica =>
|
case Some((meta, selfReplica, aa)) if meta.originReplica != selfReplica =>
|
||||||
// keep track of highest origin seqnr per other replica
|
// keep track of highest origin seqnr per other replica
|
||||||
state = state.copy(
|
state = state.copy(
|
||||||
state = newState,
|
state = newState,
|
||||||
eventSeenInInterval = true,
|
eventSeenInInterval = true,
|
||||||
version = meta.version,
|
version = meta.version,
|
||||||
seenSeqNrPerReplica = state.seenSeqNrPerReplica + (meta.originReplica -> meta.originSequenceNr))
|
seenSeqNrPerReplica = state.seenSeqNrPerReplica + (meta.originReplica -> meta.originSequenceNr))
|
||||||
|
aa.clearContext()
|
||||||
|
case Some((_, _, aa)) =>
|
||||||
|
aa.clearContext()
|
||||||
|
state = state.copy(state = newState, eventSeenInInterval = true)
|
||||||
case _ =>
|
case _ =>
|
||||||
state = state.copy(state = newState, eventSeenInInterval = true)
|
state = state.copy(state = newState, eventSeenInInterval = true)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -377,13 +377,19 @@ private[akka] object Running {
|
||||||
this
|
this
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def withContext[A](aa: ActiveActive, withActiveActive: ActiveActive => Unit, f: () => A): A = {
|
||||||
|
withActiveActive(aa)
|
||||||
|
val result = f()
|
||||||
|
aa.clearContext()
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
private def handleExternalReplicatedEventPersist(
|
private def handleExternalReplicatedEventPersist(
|
||||||
activeActive: ActiveActive,
|
activeActive: ActiveActive,
|
||||||
event: ReplicatedEvent[E]): Behavior[InternalProtocol] = {
|
event: ReplicatedEvent[E]): Behavior[InternalProtocol] = {
|
||||||
_currentSequenceNumber = state.seqNr + 1
|
_currentSequenceNumber = state.seqNr + 1
|
||||||
val isConcurrent: Boolean = event.originVersion <> state.version
|
val isConcurrent: Boolean = event.originVersion <> state.version
|
||||||
val updatedVersion = event.originVersion.merge(state.version)
|
val updatedVersion = event.originVersion.merge(state.version)
|
||||||
activeActive.setContext(false, event.originReplica, isConcurrent)
|
|
||||||
|
|
||||||
if (setup.log.isDebugEnabled())
|
if (setup.log.isDebugEnabled())
|
||||||
setup.log.debugN(
|
setup.log.debugN(
|
||||||
|
|
@ -394,7 +400,11 @@ private[akka] object Running {
|
||||||
updatedVersion,
|
updatedVersion,
|
||||||
isConcurrent)
|
isConcurrent)
|
||||||
|
|
||||||
val newState: RunningState[S] = state.applyEvent(setup, event.event)
|
val newState: RunningState[S] = withContext(
|
||||||
|
activeActive,
|
||||||
|
aa => aa.setContext(recoveryRunning = false, event.originReplica, concurrent = isConcurrent),
|
||||||
|
() => state.applyEvent(setup, event.event))
|
||||||
|
|
||||||
val newState2: RunningState[S] = internalPersist(
|
val newState2: RunningState[S] = internalPersist(
|
||||||
setup.context,
|
setup.context,
|
||||||
null,
|
null,
|
||||||
|
|
@ -423,13 +433,17 @@ private[akka] object Running {
|
||||||
// also, ensure that there is an event handler for each single event
|
// also, ensure that there is an event handler for each single event
|
||||||
_currentSequenceNumber = state.seqNr + 1
|
_currentSequenceNumber = state.seqNr + 1
|
||||||
|
|
||||||
setup.activeActive.foreach { aa =>
|
val newState: RunningState[S] = setup.activeActive match {
|
||||||
// set concurrent to false, local events are never concurrent
|
case Some(aa) =>
|
||||||
aa.setContext(recoveryRunning = false, aa.replicaId, concurrent = false)
|
// set concurrent to false, local events are never concurrent
|
||||||
|
withContext(
|
||||||
|
aa,
|
||||||
|
aa => aa.setContext(recoveryRunning = false, aa.replicaId, concurrent = false),
|
||||||
|
() => state.applyEvent(setup, event))
|
||||||
|
case None =>
|
||||||
|
state.applyEvent(setup, event)
|
||||||
}
|
}
|
||||||
|
|
||||||
val newState: RunningState[S] = state.applyEvent(setup, event)
|
|
||||||
|
|
||||||
val eventToPersist = adaptEvent(event)
|
val eventToPersist = adaptEvent(event)
|
||||||
val eventAdapterManifest = setup.eventAdapter.manifest(event)
|
val eventAdapterManifest = setup.eventAdapter.manifest(event)
|
||||||
|
|
||||||
|
|
@ -500,7 +514,17 @@ private[akka] object Running {
|
||||||
Some(template.copy(originSequenceNr = _currentSequenceNumber, version = updatedVersion))
|
Some(template.copy(originSequenceNr = _currentSequenceNumber, version = updatedVersion))
|
||||||
case None => None
|
case None => None
|
||||||
}
|
}
|
||||||
currentState = currentState.applyEvent(setup, event)
|
|
||||||
|
currentState = setup.activeActive match {
|
||||||
|
case Some(aa) =>
|
||||||
|
withContext(
|
||||||
|
aa,
|
||||||
|
aa => aa.setContext(recoveryRunning = false, aa.replicaId, concurrent = false),
|
||||||
|
() => currentState.applyEvent(setup, event))
|
||||||
|
case None =>
|
||||||
|
currentState.applyEvent(setup, event)
|
||||||
|
}
|
||||||
|
|
||||||
eventsToPersist = EventToPersist(adaptedEvent, evtManifest, eventMetadata) :: eventsToPersist
|
eventsToPersist = EventToPersist(adaptedEvent, evtManifest, eventMetadata) :: eventsToPersist
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,8 @@ package akka.persistence.typed.scaladsl
|
||||||
|
|
||||||
import akka.persistence.typed.PersistenceId
|
import akka.persistence.typed.PersistenceId
|
||||||
import akka.persistence.typed.ReplicaId
|
import akka.persistence.typed.ReplicaId
|
||||||
import akka.util.WallClock
|
import akka.util.{ OptionVal, WallClock }
|
||||||
|
|
||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
|
|
||||||
// FIXME docs
|
// FIXME docs
|
||||||
|
|
@ -34,30 +35,54 @@ private[akka] class ActiveActiveContextImpl(
|
||||||
extends ActiveActiveContext
|
extends ActiveActiveContext
|
||||||
with akka.persistence.typed.javadsl.ActiveActiveContext {
|
with akka.persistence.typed.javadsl.ActiveActiveContext {
|
||||||
val allReplicas: Set[ReplicaId] = replicasAndQueryPlugins.keySet
|
val allReplicas: Set[ReplicaId] = replicasAndQueryPlugins.keySet
|
||||||
|
|
||||||
|
// these are not volatile as they are set on the same thread as they should be accessed
|
||||||
var _origin: ReplicaId = null
|
var _origin: ReplicaId = null
|
||||||
var _recoveryRunning: Boolean = false
|
var _recoveryRunning: Boolean = false
|
||||||
var _concurrent: Boolean = false
|
var _concurrent: Boolean = false
|
||||||
|
var _currentThread: OptionVal[Thread] = OptionVal.None
|
||||||
|
|
||||||
// FIXME check illegal access https://github.com/akka/akka/issues/29264
|
private def checkAccess(functionName: String): Unit = {
|
||||||
|
val callerThread = Thread.currentThread()
|
||||||
|
def error() =
|
||||||
|
throw new UnsupportedOperationException(
|
||||||
|
s"Unsupported access to ActiveActiveContext operation from the outside of event handler. " +
|
||||||
|
s"$functionName can only be called from the event handler")
|
||||||
|
_currentThread match {
|
||||||
|
case OptionVal.Some(t) =>
|
||||||
|
if (callerThread ne t) error()
|
||||||
|
case OptionVal.None =>
|
||||||
|
error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The origin of the current event.
|
* The origin of the current event.
|
||||||
* Undefined result if called from anywhere other than an event handler.
|
* Undefined result if called from anywhere other than an event handler.
|
||||||
*/
|
*/
|
||||||
override def origin: ReplicaId = _origin
|
override def origin: ReplicaId = {
|
||||||
|
checkAccess("origin")
|
||||||
|
_origin
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Whether the happened concurrently with an event from another replica.
|
* Whether the happened concurrently with an event from another replica.
|
||||||
* Undefined result if called from any where other than an event handler.
|
* Undefined result if called from any where other than an event handler.
|
||||||
*/
|
*/
|
||||||
override def concurrent: Boolean = _concurrent
|
override def concurrent: Boolean = {
|
||||||
|
checkAccess("concurrent")
|
||||||
|
_concurrent
|
||||||
|
}
|
||||||
|
|
||||||
override def persistenceId: PersistenceId = PersistenceId.replicatedUniqueId(entityId, replicaId)
|
override def persistenceId: PersistenceId = PersistenceId.replicatedUniqueId(entityId, replicaId)
|
||||||
|
|
||||||
override def currentTimeMillis(): Long = {
|
override def currentTimeMillis(): Long = {
|
||||||
WallClock.AlwaysIncreasingClock.currentTimeMillis()
|
WallClock.AlwaysIncreasingClock.currentTimeMillis()
|
||||||
}
|
}
|
||||||
override def recoveryRunning: Boolean = _recoveryRunning
|
override def recoveryRunning: Boolean = {
|
||||||
|
checkAccess("recoveryRunning")
|
||||||
|
_recoveryRunning
|
||||||
|
}
|
||||||
|
|
||||||
override def getAllReplicas: java.util.Set[ReplicaId] = allReplicas.asJava
|
override def getAllReplicas: java.util.Set[ReplicaId] = allReplicas.asJava
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ import akka.persistence.typed.ReplicaId
|
||||||
import akka.persistence.typed.SnapshotAdapter
|
import akka.persistence.typed.SnapshotAdapter
|
||||||
import akka.persistence.typed.SnapshotSelectionCriteria
|
import akka.persistence.typed.SnapshotSelectionCriteria
|
||||||
import akka.persistence.typed.internal._
|
import akka.persistence.typed.internal._
|
||||||
|
import akka.util.OptionVal
|
||||||
|
|
||||||
object EventSourcedBehavior {
|
object EventSourcedBehavior {
|
||||||
|
|
||||||
|
|
@ -36,11 +37,19 @@ object EventSourcedBehavior {
|
||||||
* Must only be called on the same thread that will execute the user code
|
* Must only be called on the same thread that will execute the user code
|
||||||
*/
|
*/
|
||||||
def setContext(recoveryRunning: Boolean, originReplica: ReplicaId, concurrent: Boolean): Unit = {
|
def setContext(recoveryRunning: Boolean, originReplica: ReplicaId, concurrent: Boolean): Unit = {
|
||||||
|
aaContext._currentThread = OptionVal.Some(Thread.currentThread())
|
||||||
aaContext._recoveryRunning = recoveryRunning
|
aaContext._recoveryRunning = recoveryRunning
|
||||||
aaContext._concurrent = concurrent
|
aaContext._concurrent = concurrent
|
||||||
aaContext._origin = originReplica
|
aaContext._origin = originReplica
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def clearContext(): Unit = {
|
||||||
|
aaContext._currentThread = OptionVal.None
|
||||||
|
aaContext._recoveryRunning = false
|
||||||
|
aaContext._concurrent = false
|
||||||
|
aaContext._origin = null
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue