diff --git a/akka-persistence-typed-tests/src/test/scala/akka/persistence/typed/ActiveActiveIllegalAccessSpec.scala b/akka-persistence-typed-tests/src/test/scala/akka/persistence/typed/ActiveActiveIllegalAccessSpec.scala new file mode 100644 index 0000000000..a8145639b0 --- /dev/null +++ b/akka-persistence-typed-tests/src/test/scala/akka/persistence/typed/ActiveActiveIllegalAccessSpec.scala @@ -0,0 +1,101 @@ +/* + * Copyright (C) 2020 Lightbend Inc. + */ + +package akka.persistence.typed + +import akka.actor.testkit.typed.scaladsl.{ LogCapturing, ScalaTestWithActorTestKit } +import akka.actor.typed.{ ActorRef, Behavior } +import akka.persistence.testkit.PersistenceTestKitPlugin +import akka.persistence.testkit.query.scaladsl.PersistenceTestKitReadJournal +import akka.persistence.typed.scaladsl.{ ActiveActiveEventSourcing, Effect, EventSourcedBehavior } +import akka.serialization.jackson.CborSerializable +import org.scalatest.concurrent.Eventually +import org.scalatest.wordspec.AnyWordSpecLike + +object ActiveActiveIllegalAccessSpec { + + val R1 = ReplicaId("R1") + val R2 = ReplicaId("R1") + val AllReplicas = Set(R1, R2) + + sealed trait Command + case class AccessInCommandHandler(replyTo: ActorRef[Thrown]) extends Command + case class AccessInPersistCallback(replyTo: ActorRef[Thrown]) extends Command + + case class Thrown(exception: Option[Throwable]) + + case class State(all: List[String]) extends CborSerializable + + def apply(entityId: String, replica: ReplicaId): Behavior[Command] = { + ActiveActiveEventSourcing.withSharedJournal( + entityId, + replica, + AllReplicas, + PersistenceTestKitReadJournal.Identifier)( + aaContext => + EventSourcedBehavior[Command, String, State]( + aaContext.persistenceId, + State(Nil), + (_, command) => + command match { + case AccessInCommandHandler(replyTo) => + val exception = try { + aaContext.origin + None + } catch { + case t: Throwable => + Some(t) + } + replyTo ! Thrown(exception) + Effect.none + case AccessInPersistCallback(replyTo) => + Effect.persist("cat").thenRun { _ => + val exception = try { + aaContext.concurrent + None + } catch { + case t: Throwable => + Some(t) + } + replyTo ! Thrown(exception) + } + }, + (state, event) => state.copy(all = event :: state.all))) + } + +} + +class ActiveActiveIllegalAccessSpec + extends ScalaTestWithActorTestKit(PersistenceTestKitPlugin.config) + with AnyWordSpecLike + with LogCapturing + with Eventually { + import ActiveActiveIllegalAccessSpec._ + "ActiveActive" should { + "detect illegal access to context in command handler" in { + val probe = createTestProbe[Thrown]() + val ref = spawn(ActiveActiveIllegalAccessSpec("id1", R1)) + ref ! AccessInCommandHandler(probe.ref) + val thrown: Throwable = probe.expectMessageType[Thrown].exception.get + thrown.getMessage should include("from the event handler") + } + "detect illegal access to context in persist thenRun" in { + val probe = createTestProbe[Thrown]() + val ref = spawn(ActiveActiveIllegalAccessSpec("id1", R1)) + ref ! AccessInPersistCallback(probe.ref) + val thrown: Throwable = probe.expectMessageType[Thrown].exception.get + thrown.getMessage should include("from the event handler") + } + "detect illegal access in the factory" in { + val exception = intercept[UnsupportedOperationException] { + ActiveActiveEventSourcing.withSharedJournal("id2", R1, AllReplicas, PersistenceTestKitReadJournal.Identifier) { + aaContext => + aaContext.origin + ??? + } + } + exception.getMessage should include("from the event handler") + } + } +} diff --git a/akka-persistence-typed/src/main/scala/akka/persistence/typed/internal/ReplayingEvents.scala b/akka-persistence-typed/src/main/scala/akka/persistence/typed/internal/ReplayingEvents.scala index 96b90d565f..aa7dd1c1d1 100644 --- a/akka-persistence-typed/src/main/scala/akka/persistence/typed/internal/ReplayingEvents.scala +++ b/akka-persistence-typed/src/main/scala/akka/persistence/typed/internal/ReplayingEvents.scala @@ -23,6 +23,7 @@ import akka.persistence.typed.SingleEventSeq import akka.persistence.typed.internal.EventSourcedBehaviorImpl.GetState import akka.persistence.typed.internal.ReplayingEvents.ReplayingState import akka.persistence.typed.internal.Running.WithSeqNrAccessible +import akka.persistence.typed.scaladsl.EventSourcedBehavior.ActiveActive import akka.util.OptionVal import akka.util.PrettyDuration._ import akka.util.unused @@ -122,7 +123,7 @@ private[akka] final class ReplayingEvents[C, E, S]( eventForErrorReporting = OptionVal.Some(event) state = state.copy(seqNr = repr.sequenceNr) - val aaMetaAndSelfReplica: Option[(ReplicatedEventMetaData, ReplicaId)] = + val aaMetaAndSelfReplica: Option[(ReplicatedEventMetaData, ReplicaId, ActiveActive)] = setup.activeActive match { case Some(aa) => val meta = repr.metadata match { @@ -133,19 +134,30 @@ private[akka] final class ReplayingEvents[C, E, S]( } aa.setContext(recoveryRunning = true, meta.originReplica, meta.concurrent) - Some(meta -> aa.replicaId) + Some((meta, aa.replicaId, aa)) case None => None } + val newState = setup.eventHandler(state.state, event) + setup.activeActive match { + case Some(aa) => + aa.clearContext() + case None => + } + aaMetaAndSelfReplica match { - case Some((meta, selfReplica)) if meta.originReplica != selfReplica => + case Some((meta, selfReplica, aa)) if meta.originReplica != selfReplica => // keep track of highest origin seqnr per other replica state = state.copy( state = newState, eventSeenInInterval = true, version = meta.version, seenSeqNrPerReplica = state.seenSeqNrPerReplica + (meta.originReplica -> meta.originSequenceNr)) + aa.clearContext() + case Some((_, _, aa)) => + aa.clearContext() + state = state.copy(state = newState, eventSeenInInterval = true) case _ => state = state.copy(state = newState, eventSeenInInterval = true) } diff --git a/akka-persistence-typed/src/main/scala/akka/persistence/typed/internal/Running.scala b/akka-persistence-typed/src/main/scala/akka/persistence/typed/internal/Running.scala index fcadd33a99..7ec2bf42d0 100644 --- a/akka-persistence-typed/src/main/scala/akka/persistence/typed/internal/Running.scala +++ b/akka-persistence-typed/src/main/scala/akka/persistence/typed/internal/Running.scala @@ -377,13 +377,19 @@ private[akka] object Running { this } + def withContext[A](aa: ActiveActive, withActiveActive: ActiveActive => Unit, f: () => A): A = { + withActiveActive(aa) + val result = f() + aa.clearContext() + result + } + private def handleExternalReplicatedEventPersist( activeActive: ActiveActive, event: ReplicatedEvent[E]): Behavior[InternalProtocol] = { _currentSequenceNumber = state.seqNr + 1 val isConcurrent: Boolean = event.originVersion <> state.version val updatedVersion = event.originVersion.merge(state.version) - activeActive.setContext(false, event.originReplica, isConcurrent) if (setup.log.isDebugEnabled()) setup.log.debugN( @@ -394,7 +400,11 @@ private[akka] object Running { updatedVersion, isConcurrent) - val newState: RunningState[S] = state.applyEvent(setup, event.event) + val newState: RunningState[S] = withContext( + activeActive, + aa => aa.setContext(recoveryRunning = false, event.originReplica, concurrent = isConcurrent), + () => state.applyEvent(setup, event.event)) + val newState2: RunningState[S] = internalPersist( setup.context, null, @@ -423,13 +433,17 @@ private[akka] object Running { // also, ensure that there is an event handler for each single event _currentSequenceNumber = state.seqNr + 1 - setup.activeActive.foreach { aa => - // set concurrent to false, local events are never concurrent - aa.setContext(recoveryRunning = false, aa.replicaId, concurrent = false) + val newState: RunningState[S] = setup.activeActive match { + case Some(aa) => + // set concurrent to false, local events are never concurrent + withContext( + aa, + aa => aa.setContext(recoveryRunning = false, aa.replicaId, concurrent = false), + () => state.applyEvent(setup, event)) + case None => + state.applyEvent(setup, event) } - val newState: RunningState[S] = state.applyEvent(setup, event) - val eventToPersist = adaptEvent(event) val eventAdapterManifest = setup.eventAdapter.manifest(event) @@ -500,7 +514,17 @@ private[akka] object Running { Some(template.copy(originSequenceNr = _currentSequenceNumber, version = updatedVersion)) case None => None } - currentState = currentState.applyEvent(setup, event) + + currentState = setup.activeActive match { + case Some(aa) => + withContext( + aa, + aa => aa.setContext(recoveryRunning = false, aa.replicaId, concurrent = false), + () => currentState.applyEvent(setup, event)) + case None => + currentState.applyEvent(setup, event) + } + eventsToPersist = EventToPersist(adaptedEvent, evtManifest, eventMetadata) :: eventsToPersist } diff --git a/akka-persistence-typed/src/main/scala/akka/persistence/typed/scaladsl/ActiveActiveEventSourcing.scala b/akka-persistence-typed/src/main/scala/akka/persistence/typed/scaladsl/ActiveActiveEventSourcing.scala index 1f7b3dee1f..eec1a91976 100644 --- a/akka-persistence-typed/src/main/scala/akka/persistence/typed/scaladsl/ActiveActiveEventSourcing.scala +++ b/akka-persistence-typed/src/main/scala/akka/persistence/typed/scaladsl/ActiveActiveEventSourcing.scala @@ -6,7 +6,8 @@ package akka.persistence.typed.scaladsl import akka.persistence.typed.PersistenceId import akka.persistence.typed.ReplicaId -import akka.util.WallClock +import akka.util.{ OptionVal, WallClock } + import scala.collection.JavaConverters._ // FIXME docs @@ -34,30 +35,54 @@ private[akka] class ActiveActiveContextImpl( extends ActiveActiveContext with akka.persistence.typed.javadsl.ActiveActiveContext { val allReplicas: Set[ReplicaId] = replicasAndQueryPlugins.keySet + + // these are not volatile as they are set on the same thread as they should be accessed var _origin: ReplicaId = null var _recoveryRunning: Boolean = false var _concurrent: Boolean = false + var _currentThread: OptionVal[Thread] = OptionVal.None - // FIXME check illegal access https://github.com/akka/akka/issues/29264 + private def checkAccess(functionName: String): Unit = { + val callerThread = Thread.currentThread() + def error() = + throw new UnsupportedOperationException( + s"Unsupported access to ActiveActiveContext operation from the outside of event handler. " + + s"$functionName can only be called from the event handler") + _currentThread match { + case OptionVal.Some(t) => + if (callerThread ne t) error() + case OptionVal.None => + error() + } + } /** * The origin of the current event. * Undefined result if called from anywhere other than an event handler. */ - override def origin: ReplicaId = _origin + override def origin: ReplicaId = { + checkAccess("origin") + _origin + } /** * Whether the happened concurrently with an event from another replica. * Undefined result if called from any where other than an event handler. */ - override def concurrent: Boolean = _concurrent + override def concurrent: Boolean = { + checkAccess("concurrent") + _concurrent + } override def persistenceId: PersistenceId = PersistenceId.replicatedUniqueId(entityId, replicaId) override def currentTimeMillis(): Long = { WallClock.AlwaysIncreasingClock.currentTimeMillis() } - override def recoveryRunning: Boolean = _recoveryRunning + override def recoveryRunning: Boolean = { + checkAccess("recoveryRunning") + _recoveryRunning + } override def getAllReplicas: java.util.Set[ReplicaId] = allReplicas.asJava } diff --git a/akka-persistence-typed/src/main/scala/akka/persistence/typed/scaladsl/EventSourcedBehavior.scala b/akka-persistence-typed/src/main/scala/akka/persistence/typed/scaladsl/EventSourcedBehavior.scala index eba5e0a6a8..b40d3a327e 100644 --- a/akka-persistence-typed/src/main/scala/akka/persistence/typed/scaladsl/EventSourcedBehavior.scala +++ b/akka-persistence-typed/src/main/scala/akka/persistence/typed/scaladsl/EventSourcedBehavior.scala @@ -20,6 +20,7 @@ import akka.persistence.typed.ReplicaId import akka.persistence.typed.SnapshotAdapter import akka.persistence.typed.SnapshotSelectionCriteria import akka.persistence.typed.internal._ +import akka.util.OptionVal object EventSourcedBehavior { @@ -36,11 +37,19 @@ object EventSourcedBehavior { * Must only be called on the same thread that will execute the user code */ def setContext(recoveryRunning: Boolean, originReplica: ReplicaId, concurrent: Boolean): Unit = { + aaContext._currentThread = OptionVal.Some(Thread.currentThread()) aaContext._recoveryRunning = recoveryRunning aaContext._concurrent = concurrent aaContext._origin = originReplica } + def clearContext(): Unit = { + aaContext._currentThread = OptionVal.None + aaContext._recoveryRunning = false + aaContext._concurrent = false + aaContext._origin = null + } + } /**