Detect illegal access to context (#29431)

This commit is contained in:
Christopher Batey 2020-07-28 13:45:25 +01:00
parent f63ca66e56
commit e4f5781d65
5 changed files with 187 additions and 16 deletions

View file

@ -0,0 +1,101 @@
/*
* Copyright (C) 2020 Lightbend Inc. <https://www.lightbend.com>
*/
package akka.persistence.typed
import akka.actor.testkit.typed.scaladsl.{ LogCapturing, ScalaTestWithActorTestKit }
import akka.actor.typed.{ ActorRef, Behavior }
import akka.persistence.testkit.PersistenceTestKitPlugin
import akka.persistence.testkit.query.scaladsl.PersistenceTestKitReadJournal
import akka.persistence.typed.scaladsl.{ ActiveActiveEventSourcing, Effect, EventSourcedBehavior }
import akka.serialization.jackson.CborSerializable
import org.scalatest.concurrent.Eventually
import org.scalatest.wordspec.AnyWordSpecLike
object ActiveActiveIllegalAccessSpec {
val R1 = ReplicaId("R1")
val R2 = ReplicaId("R1")
val AllReplicas = Set(R1, R2)
sealed trait Command
case class AccessInCommandHandler(replyTo: ActorRef[Thrown]) extends Command
case class AccessInPersistCallback(replyTo: ActorRef[Thrown]) extends Command
case class Thrown(exception: Option[Throwable])
case class State(all: List[String]) extends CborSerializable
def apply(entityId: String, replica: ReplicaId): Behavior[Command] = {
ActiveActiveEventSourcing.withSharedJournal(
entityId,
replica,
AllReplicas,
PersistenceTestKitReadJournal.Identifier)(
aaContext =>
EventSourcedBehavior[Command, String, State](
aaContext.persistenceId,
State(Nil),
(_, command) =>
command match {
case AccessInCommandHandler(replyTo) =>
val exception = try {
aaContext.origin
None
} catch {
case t: Throwable =>
Some(t)
}
replyTo ! Thrown(exception)
Effect.none
case AccessInPersistCallback(replyTo) =>
Effect.persist("cat").thenRun { _ =>
val exception = try {
aaContext.concurrent
None
} catch {
case t: Throwable =>
Some(t)
}
replyTo ! Thrown(exception)
}
},
(state, event) => state.copy(all = event :: state.all)))
}
}
class ActiveActiveIllegalAccessSpec
extends ScalaTestWithActorTestKit(PersistenceTestKitPlugin.config)
with AnyWordSpecLike
with LogCapturing
with Eventually {
import ActiveActiveIllegalAccessSpec._
"ActiveActive" should {
"detect illegal access to context in command handler" in {
val probe = createTestProbe[Thrown]()
val ref = spawn(ActiveActiveIllegalAccessSpec("id1", R1))
ref ! AccessInCommandHandler(probe.ref)
val thrown: Throwable = probe.expectMessageType[Thrown].exception.get
thrown.getMessage should include("from the event handler")
}
"detect illegal access to context in persist thenRun" in {
val probe = createTestProbe[Thrown]()
val ref = spawn(ActiveActiveIllegalAccessSpec("id1", R1))
ref ! AccessInPersistCallback(probe.ref)
val thrown: Throwable = probe.expectMessageType[Thrown].exception.get
thrown.getMessage should include("from the event handler")
}
"detect illegal access in the factory" in {
val exception = intercept[UnsupportedOperationException] {
ActiveActiveEventSourcing.withSharedJournal("id2", R1, AllReplicas, PersistenceTestKitReadJournal.Identifier) {
aaContext =>
aaContext.origin
???
}
}
exception.getMessage should include("from the event handler")
}
}
}

View file

@ -23,6 +23,7 @@ import akka.persistence.typed.SingleEventSeq
import akka.persistence.typed.internal.EventSourcedBehaviorImpl.GetState
import akka.persistence.typed.internal.ReplayingEvents.ReplayingState
import akka.persistence.typed.internal.Running.WithSeqNrAccessible
import akka.persistence.typed.scaladsl.EventSourcedBehavior.ActiveActive
import akka.util.OptionVal
import akka.util.PrettyDuration._
import akka.util.unused
@ -122,7 +123,7 @@ private[akka] final class ReplayingEvents[C, E, S](
eventForErrorReporting = OptionVal.Some(event)
state = state.copy(seqNr = repr.sequenceNr)
val aaMetaAndSelfReplica: Option[(ReplicatedEventMetaData, ReplicaId)] =
val aaMetaAndSelfReplica: Option[(ReplicatedEventMetaData, ReplicaId, ActiveActive)] =
setup.activeActive match {
case Some(aa) =>
val meta = repr.metadata match {
@ -133,19 +134,30 @@ private[akka] final class ReplayingEvents[C, E, S](
}
aa.setContext(recoveryRunning = true, meta.originReplica, meta.concurrent)
Some(meta -> aa.replicaId)
Some((meta, aa.replicaId, aa))
case None => None
}
val newState = setup.eventHandler(state.state, event)
setup.activeActive match {
case Some(aa) =>
aa.clearContext()
case None =>
}
aaMetaAndSelfReplica match {
case Some((meta, selfReplica)) if meta.originReplica != selfReplica =>
case Some((meta, selfReplica, aa)) if meta.originReplica != selfReplica =>
// keep track of highest origin seqnr per other replica
state = state.copy(
state = newState,
eventSeenInInterval = true,
version = meta.version,
seenSeqNrPerReplica = state.seenSeqNrPerReplica + (meta.originReplica -> meta.originSequenceNr))
aa.clearContext()
case Some((_, _, aa)) =>
aa.clearContext()
state = state.copy(state = newState, eventSeenInInterval = true)
case _ =>
state = state.copy(state = newState, eventSeenInInterval = true)
}

View file

@ -377,13 +377,19 @@ private[akka] object Running {
this
}
def withContext[A](aa: ActiveActive, withActiveActive: ActiveActive => Unit, f: () => A): A = {
withActiveActive(aa)
val result = f()
aa.clearContext()
result
}
private def handleExternalReplicatedEventPersist(
activeActive: ActiveActive,
event: ReplicatedEvent[E]): Behavior[InternalProtocol] = {
_currentSequenceNumber = state.seqNr + 1
val isConcurrent: Boolean = event.originVersion <> state.version
val updatedVersion = event.originVersion.merge(state.version)
activeActive.setContext(false, event.originReplica, isConcurrent)
if (setup.log.isDebugEnabled())
setup.log.debugN(
@ -394,7 +400,11 @@ private[akka] object Running {
updatedVersion,
isConcurrent)
val newState: RunningState[S] = state.applyEvent(setup, event.event)
val newState: RunningState[S] = withContext(
activeActive,
aa => aa.setContext(recoveryRunning = false, event.originReplica, concurrent = isConcurrent),
() => state.applyEvent(setup, event.event))
val newState2: RunningState[S] = internalPersist(
setup.context,
null,
@ -423,13 +433,17 @@ private[akka] object Running {
// also, ensure that there is an event handler for each single event
_currentSequenceNumber = state.seqNr + 1
setup.activeActive.foreach { aa =>
// set concurrent to false, local events are never concurrent
aa.setContext(recoveryRunning = false, aa.replicaId, concurrent = false)
val newState: RunningState[S] = setup.activeActive match {
case Some(aa) =>
// set concurrent to false, local events are never concurrent
withContext(
aa,
aa => aa.setContext(recoveryRunning = false, aa.replicaId, concurrent = false),
() => state.applyEvent(setup, event))
case None =>
state.applyEvent(setup, event)
}
val newState: RunningState[S] = state.applyEvent(setup, event)
val eventToPersist = adaptEvent(event)
val eventAdapterManifest = setup.eventAdapter.manifest(event)
@ -500,7 +514,17 @@ private[akka] object Running {
Some(template.copy(originSequenceNr = _currentSequenceNumber, version = updatedVersion))
case None => None
}
currentState = currentState.applyEvent(setup, event)
currentState = setup.activeActive match {
case Some(aa) =>
withContext(
aa,
aa => aa.setContext(recoveryRunning = false, aa.replicaId, concurrent = false),
() => currentState.applyEvent(setup, event))
case None =>
currentState.applyEvent(setup, event)
}
eventsToPersist = EventToPersist(adaptedEvent, evtManifest, eventMetadata) :: eventsToPersist
}

View file

@ -6,7 +6,8 @@ package akka.persistence.typed.scaladsl
import akka.persistence.typed.PersistenceId
import akka.persistence.typed.ReplicaId
import akka.util.WallClock
import akka.util.{ OptionVal, WallClock }
import scala.collection.JavaConverters._
// FIXME docs
@ -34,30 +35,54 @@ private[akka] class ActiveActiveContextImpl(
extends ActiveActiveContext
with akka.persistence.typed.javadsl.ActiveActiveContext {
val allReplicas: Set[ReplicaId] = replicasAndQueryPlugins.keySet
// these are not volatile as they are set on the same thread as they should be accessed
var _origin: ReplicaId = null
var _recoveryRunning: Boolean = false
var _concurrent: Boolean = false
var _currentThread: OptionVal[Thread] = OptionVal.None
// FIXME check illegal access https://github.com/akka/akka/issues/29264
private def checkAccess(functionName: String): Unit = {
val callerThread = Thread.currentThread()
def error() =
throw new UnsupportedOperationException(
s"Unsupported access to ActiveActiveContext operation from the outside of event handler. " +
s"$functionName can only be called from the event handler")
_currentThread match {
case OptionVal.Some(t) =>
if (callerThread ne t) error()
case OptionVal.None =>
error()
}
}
/**
* The origin of the current event.
* Undefined result if called from anywhere other than an event handler.
*/
override def origin: ReplicaId = _origin
override def origin: ReplicaId = {
checkAccess("origin")
_origin
}
/**
* Whether the happened concurrently with an event from another replica.
* Undefined result if called from any where other than an event handler.
*/
override def concurrent: Boolean = _concurrent
override def concurrent: Boolean = {
checkAccess("concurrent")
_concurrent
}
override def persistenceId: PersistenceId = PersistenceId.replicatedUniqueId(entityId, replicaId)
override def currentTimeMillis(): Long = {
WallClock.AlwaysIncreasingClock.currentTimeMillis()
}
override def recoveryRunning: Boolean = _recoveryRunning
override def recoveryRunning: Boolean = {
checkAccess("recoveryRunning")
_recoveryRunning
}
override def getAllReplicas: java.util.Set[ReplicaId] = allReplicas.asJava
}

View file

@ -20,6 +20,7 @@ import akka.persistence.typed.ReplicaId
import akka.persistence.typed.SnapshotAdapter
import akka.persistence.typed.SnapshotSelectionCriteria
import akka.persistence.typed.internal._
import akka.util.OptionVal
object EventSourcedBehavior {
@ -36,11 +37,19 @@ object EventSourcedBehavior {
* Must only be called on the same thread that will execute the user code
*/
def setContext(recoveryRunning: Boolean, originReplica: ReplicaId, concurrent: Boolean): Unit = {
aaContext._currentThread = OptionVal.Some(Thread.currentThread())
aaContext._recoveryRunning = recoveryRunning
aaContext._concurrent = concurrent
aaContext._origin = originReplica
}
def clearContext(): Unit = {
aaContext._currentThread = OptionVal.None
aaContext._recoveryRunning = false
aaContext._concurrent = false
aaContext._origin = null
}
}
/**