Detect illegal access to context (#29431)
This commit is contained in:
parent
f63ca66e56
commit
e4f5781d65
5 changed files with 187 additions and 16 deletions
|
|
@ -0,0 +1,101 @@
|
|||
/*
|
||||
* Copyright (C) 2020 Lightbend Inc. <https://www.lightbend.com>
|
||||
*/
|
||||
|
||||
package akka.persistence.typed
|
||||
|
||||
import akka.actor.testkit.typed.scaladsl.{ LogCapturing, ScalaTestWithActorTestKit }
|
||||
import akka.actor.typed.{ ActorRef, Behavior }
|
||||
import akka.persistence.testkit.PersistenceTestKitPlugin
|
||||
import akka.persistence.testkit.query.scaladsl.PersistenceTestKitReadJournal
|
||||
import akka.persistence.typed.scaladsl.{ ActiveActiveEventSourcing, Effect, EventSourcedBehavior }
|
||||
import akka.serialization.jackson.CborSerializable
|
||||
import org.scalatest.concurrent.Eventually
|
||||
import org.scalatest.wordspec.AnyWordSpecLike
|
||||
|
||||
object ActiveActiveIllegalAccessSpec {
|
||||
|
||||
val R1 = ReplicaId("R1")
|
||||
val R2 = ReplicaId("R1")
|
||||
val AllReplicas = Set(R1, R2)
|
||||
|
||||
sealed trait Command
|
||||
case class AccessInCommandHandler(replyTo: ActorRef[Thrown]) extends Command
|
||||
case class AccessInPersistCallback(replyTo: ActorRef[Thrown]) extends Command
|
||||
|
||||
case class Thrown(exception: Option[Throwable])
|
||||
|
||||
case class State(all: List[String]) extends CborSerializable
|
||||
|
||||
def apply(entityId: String, replica: ReplicaId): Behavior[Command] = {
|
||||
ActiveActiveEventSourcing.withSharedJournal(
|
||||
entityId,
|
||||
replica,
|
||||
AllReplicas,
|
||||
PersistenceTestKitReadJournal.Identifier)(
|
||||
aaContext =>
|
||||
EventSourcedBehavior[Command, String, State](
|
||||
aaContext.persistenceId,
|
||||
State(Nil),
|
||||
(_, command) =>
|
||||
command match {
|
||||
case AccessInCommandHandler(replyTo) =>
|
||||
val exception = try {
|
||||
aaContext.origin
|
||||
None
|
||||
} catch {
|
||||
case t: Throwable =>
|
||||
Some(t)
|
||||
}
|
||||
replyTo ! Thrown(exception)
|
||||
Effect.none
|
||||
case AccessInPersistCallback(replyTo) =>
|
||||
Effect.persist("cat").thenRun { _ =>
|
||||
val exception = try {
|
||||
aaContext.concurrent
|
||||
None
|
||||
} catch {
|
||||
case t: Throwable =>
|
||||
Some(t)
|
||||
}
|
||||
replyTo ! Thrown(exception)
|
||||
}
|
||||
},
|
||||
(state, event) => state.copy(all = event :: state.all)))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
class ActiveActiveIllegalAccessSpec
|
||||
extends ScalaTestWithActorTestKit(PersistenceTestKitPlugin.config)
|
||||
with AnyWordSpecLike
|
||||
with LogCapturing
|
||||
with Eventually {
|
||||
import ActiveActiveIllegalAccessSpec._
|
||||
"ActiveActive" should {
|
||||
"detect illegal access to context in command handler" in {
|
||||
val probe = createTestProbe[Thrown]()
|
||||
val ref = spawn(ActiveActiveIllegalAccessSpec("id1", R1))
|
||||
ref ! AccessInCommandHandler(probe.ref)
|
||||
val thrown: Throwable = probe.expectMessageType[Thrown].exception.get
|
||||
thrown.getMessage should include("from the event handler")
|
||||
}
|
||||
"detect illegal access to context in persist thenRun" in {
|
||||
val probe = createTestProbe[Thrown]()
|
||||
val ref = spawn(ActiveActiveIllegalAccessSpec("id1", R1))
|
||||
ref ! AccessInPersistCallback(probe.ref)
|
||||
val thrown: Throwable = probe.expectMessageType[Thrown].exception.get
|
||||
thrown.getMessage should include("from the event handler")
|
||||
}
|
||||
"detect illegal access in the factory" in {
|
||||
val exception = intercept[UnsupportedOperationException] {
|
||||
ActiveActiveEventSourcing.withSharedJournal("id2", R1, AllReplicas, PersistenceTestKitReadJournal.Identifier) {
|
||||
aaContext =>
|
||||
aaContext.origin
|
||||
???
|
||||
}
|
||||
}
|
||||
exception.getMessage should include("from the event handler")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -23,6 +23,7 @@ import akka.persistence.typed.SingleEventSeq
|
|||
import akka.persistence.typed.internal.EventSourcedBehaviorImpl.GetState
|
||||
import akka.persistence.typed.internal.ReplayingEvents.ReplayingState
|
||||
import akka.persistence.typed.internal.Running.WithSeqNrAccessible
|
||||
import akka.persistence.typed.scaladsl.EventSourcedBehavior.ActiveActive
|
||||
import akka.util.OptionVal
|
||||
import akka.util.PrettyDuration._
|
||||
import akka.util.unused
|
||||
|
|
@ -122,7 +123,7 @@ private[akka] final class ReplayingEvents[C, E, S](
|
|||
eventForErrorReporting = OptionVal.Some(event)
|
||||
state = state.copy(seqNr = repr.sequenceNr)
|
||||
|
||||
val aaMetaAndSelfReplica: Option[(ReplicatedEventMetaData, ReplicaId)] =
|
||||
val aaMetaAndSelfReplica: Option[(ReplicatedEventMetaData, ReplicaId, ActiveActive)] =
|
||||
setup.activeActive match {
|
||||
case Some(aa) =>
|
||||
val meta = repr.metadata match {
|
||||
|
|
@ -133,19 +134,30 @@ private[akka] final class ReplayingEvents[C, E, S](
|
|||
|
||||
}
|
||||
aa.setContext(recoveryRunning = true, meta.originReplica, meta.concurrent)
|
||||
Some(meta -> aa.replicaId)
|
||||
Some((meta, aa.replicaId, aa))
|
||||
case None => None
|
||||
}
|
||||
|
||||
val newState = setup.eventHandler(state.state, event)
|
||||
|
||||
setup.activeActive match {
|
||||
case Some(aa) =>
|
||||
aa.clearContext()
|
||||
case None =>
|
||||
}
|
||||
|
||||
aaMetaAndSelfReplica match {
|
||||
case Some((meta, selfReplica)) if meta.originReplica != selfReplica =>
|
||||
case Some((meta, selfReplica, aa)) if meta.originReplica != selfReplica =>
|
||||
// keep track of highest origin seqnr per other replica
|
||||
state = state.copy(
|
||||
state = newState,
|
||||
eventSeenInInterval = true,
|
||||
version = meta.version,
|
||||
seenSeqNrPerReplica = state.seenSeqNrPerReplica + (meta.originReplica -> meta.originSequenceNr))
|
||||
aa.clearContext()
|
||||
case Some((_, _, aa)) =>
|
||||
aa.clearContext()
|
||||
state = state.copy(state = newState, eventSeenInInterval = true)
|
||||
case _ =>
|
||||
state = state.copy(state = newState, eventSeenInInterval = true)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -377,13 +377,19 @@ private[akka] object Running {
|
|||
this
|
||||
}
|
||||
|
||||
def withContext[A](aa: ActiveActive, withActiveActive: ActiveActive => Unit, f: () => A): A = {
|
||||
withActiveActive(aa)
|
||||
val result = f()
|
||||
aa.clearContext()
|
||||
result
|
||||
}
|
||||
|
||||
private def handleExternalReplicatedEventPersist(
|
||||
activeActive: ActiveActive,
|
||||
event: ReplicatedEvent[E]): Behavior[InternalProtocol] = {
|
||||
_currentSequenceNumber = state.seqNr + 1
|
||||
val isConcurrent: Boolean = event.originVersion <> state.version
|
||||
val updatedVersion = event.originVersion.merge(state.version)
|
||||
activeActive.setContext(false, event.originReplica, isConcurrent)
|
||||
|
||||
if (setup.log.isDebugEnabled())
|
||||
setup.log.debugN(
|
||||
|
|
@ -394,7 +400,11 @@ private[akka] object Running {
|
|||
updatedVersion,
|
||||
isConcurrent)
|
||||
|
||||
val newState: RunningState[S] = state.applyEvent(setup, event.event)
|
||||
val newState: RunningState[S] = withContext(
|
||||
activeActive,
|
||||
aa => aa.setContext(recoveryRunning = false, event.originReplica, concurrent = isConcurrent),
|
||||
() => state.applyEvent(setup, event.event))
|
||||
|
||||
val newState2: RunningState[S] = internalPersist(
|
||||
setup.context,
|
||||
null,
|
||||
|
|
@ -423,13 +433,17 @@ private[akka] object Running {
|
|||
// also, ensure that there is an event handler for each single event
|
||||
_currentSequenceNumber = state.seqNr + 1
|
||||
|
||||
setup.activeActive.foreach { aa =>
|
||||
// set concurrent to false, local events are never concurrent
|
||||
aa.setContext(recoveryRunning = false, aa.replicaId, concurrent = false)
|
||||
val newState: RunningState[S] = setup.activeActive match {
|
||||
case Some(aa) =>
|
||||
// set concurrent to false, local events are never concurrent
|
||||
withContext(
|
||||
aa,
|
||||
aa => aa.setContext(recoveryRunning = false, aa.replicaId, concurrent = false),
|
||||
() => state.applyEvent(setup, event))
|
||||
case None =>
|
||||
state.applyEvent(setup, event)
|
||||
}
|
||||
|
||||
val newState: RunningState[S] = state.applyEvent(setup, event)
|
||||
|
||||
val eventToPersist = adaptEvent(event)
|
||||
val eventAdapterManifest = setup.eventAdapter.manifest(event)
|
||||
|
||||
|
|
@ -500,7 +514,17 @@ private[akka] object Running {
|
|||
Some(template.copy(originSequenceNr = _currentSequenceNumber, version = updatedVersion))
|
||||
case None => None
|
||||
}
|
||||
currentState = currentState.applyEvent(setup, event)
|
||||
|
||||
currentState = setup.activeActive match {
|
||||
case Some(aa) =>
|
||||
withContext(
|
||||
aa,
|
||||
aa => aa.setContext(recoveryRunning = false, aa.replicaId, concurrent = false),
|
||||
() => currentState.applyEvent(setup, event))
|
||||
case None =>
|
||||
currentState.applyEvent(setup, event)
|
||||
}
|
||||
|
||||
eventsToPersist = EventToPersist(adaptedEvent, evtManifest, eventMetadata) :: eventsToPersist
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,8 @@ package akka.persistence.typed.scaladsl
|
|||
|
||||
import akka.persistence.typed.PersistenceId
|
||||
import akka.persistence.typed.ReplicaId
|
||||
import akka.util.WallClock
|
||||
import akka.util.{ OptionVal, WallClock }
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
// FIXME docs
|
||||
|
|
@ -34,30 +35,54 @@ private[akka] class ActiveActiveContextImpl(
|
|||
extends ActiveActiveContext
|
||||
with akka.persistence.typed.javadsl.ActiveActiveContext {
|
||||
val allReplicas: Set[ReplicaId] = replicasAndQueryPlugins.keySet
|
||||
|
||||
// these are not volatile as they are set on the same thread as they should be accessed
|
||||
var _origin: ReplicaId = null
|
||||
var _recoveryRunning: Boolean = false
|
||||
var _concurrent: Boolean = false
|
||||
var _currentThread: OptionVal[Thread] = OptionVal.None
|
||||
|
||||
// FIXME check illegal access https://github.com/akka/akka/issues/29264
|
||||
private def checkAccess(functionName: String): Unit = {
|
||||
val callerThread = Thread.currentThread()
|
||||
def error() =
|
||||
throw new UnsupportedOperationException(
|
||||
s"Unsupported access to ActiveActiveContext operation from the outside of event handler. " +
|
||||
s"$functionName can only be called from the event handler")
|
||||
_currentThread match {
|
||||
case OptionVal.Some(t) =>
|
||||
if (callerThread ne t) error()
|
||||
case OptionVal.None =>
|
||||
error()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* The origin of the current event.
|
||||
* Undefined result if called from anywhere other than an event handler.
|
||||
*/
|
||||
override def origin: ReplicaId = _origin
|
||||
override def origin: ReplicaId = {
|
||||
checkAccess("origin")
|
||||
_origin
|
||||
}
|
||||
|
||||
/**
|
||||
* Whether the happened concurrently with an event from another replica.
|
||||
* Undefined result if called from any where other than an event handler.
|
||||
*/
|
||||
override def concurrent: Boolean = _concurrent
|
||||
override def concurrent: Boolean = {
|
||||
checkAccess("concurrent")
|
||||
_concurrent
|
||||
}
|
||||
|
||||
override def persistenceId: PersistenceId = PersistenceId.replicatedUniqueId(entityId, replicaId)
|
||||
|
||||
override def currentTimeMillis(): Long = {
|
||||
WallClock.AlwaysIncreasingClock.currentTimeMillis()
|
||||
}
|
||||
override def recoveryRunning: Boolean = _recoveryRunning
|
||||
override def recoveryRunning: Boolean = {
|
||||
checkAccess("recoveryRunning")
|
||||
_recoveryRunning
|
||||
}
|
||||
|
||||
override def getAllReplicas: java.util.Set[ReplicaId] = allReplicas.asJava
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ import akka.persistence.typed.ReplicaId
|
|||
import akka.persistence.typed.SnapshotAdapter
|
||||
import akka.persistence.typed.SnapshotSelectionCriteria
|
||||
import akka.persistence.typed.internal._
|
||||
import akka.util.OptionVal
|
||||
|
||||
object EventSourcedBehavior {
|
||||
|
||||
|
|
@ -36,11 +37,19 @@ object EventSourcedBehavior {
|
|||
* Must only be called on the same thread that will execute the user code
|
||||
*/
|
||||
def setContext(recoveryRunning: Boolean, originReplica: ReplicaId, concurrent: Boolean): Unit = {
|
||||
aaContext._currentThread = OptionVal.Some(Thread.currentThread())
|
||||
aaContext._recoveryRunning = recoveryRunning
|
||||
aaContext._concurrent = concurrent
|
||||
aaContext._origin = originReplica
|
||||
}
|
||||
|
||||
def clearContext(): Unit = {
|
||||
aaContext._currentThread = OptionVal.None
|
||||
aaContext._recoveryRunning = false
|
||||
aaContext._concurrent = false
|
||||
aaContext._origin = null
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue