diff --git a/akka-persistence/src/test/scala/akka/persistence/PersistentActorSpec.scala b/akka-persistence/src/test/scala/akka/persistence/PersistentActorSpec.scala index a2f466eda1..3071dca90d 100644 --- a/akka-persistence/src/test/scala/akka/persistence/PersistentActorSpec.scala +++ b/akka-persistence/src/test/scala/akka/persistence/PersistentActorSpec.scala @@ -188,35 +188,6 @@ object PersistentActorSpec { } } - class UserStashPersistentActor(name: String) extends ExamplePersistentActor(name) { - var stashed = false - val receiveCommand: Receive = { - case Cmd("a") ⇒ if (!stashed) { stash(); stashed = true } else sender() ! "a" - case Cmd("b") ⇒ persist(Evt("b"))(evt ⇒ sender() ! evt.data) - case Cmd("c") ⇒ unstashAll(); sender() ! "c" - } - } - - class UserStashManyPersistentActor(name: String) extends ExamplePersistentActor(name) { - val receiveCommand: Receive = commonBehavior orElse { - case Cmd("a") ⇒ persist(Evt("a")) { evt ⇒ - updateState(evt) - context.become(processC) - } - case Cmd("b-1") ⇒ persist(Evt("b-1"))(updateState) - case Cmd("b-2") ⇒ persist(Evt("b-2"))(updateState) - } - - val processC: Receive = { - case Cmd("c") ⇒ - persist(Evt("c")) { evt ⇒ - updateState(evt) - context.unbecome() - } - unstashAll() - case other ⇒ stash() - } - } class AsyncPersistPersistentActor(name: String) extends ExamplePersistentActor(name) { var counter = 0 @@ -346,27 +317,6 @@ object PersistentActorSpec { } } - class UserStashFailurePersistentActor(name: String) extends ExamplePersistentActor(name) { - val receiveCommand: Receive = commonBehavior orElse { - case Cmd(data) ⇒ - if (data == "b-2") throw new TestException("boom") - persist(Evt(data)) { event ⇒ - updateState(event) - if (data == "a") context.become(otherCommandHandler) - } - } - - val otherCommandHandler: Receive = { - case Cmd("c") ⇒ - persist(Evt("c")) { event ⇒ - updateState(event) - context.unbecome() - } - unstashAll() - case other ⇒ stash() - } - } - class AnyValEventPersistentActor(name: String) extends ExamplePersistentActor(name) { val receiveCommand: Receive = { case Cmd("a") ⇒ persist(5)(evt ⇒ sender() ! evt) @@ -776,34 +726,6 @@ abstract class PersistentActorSpec(config: Config) extends PersistenceSpec(confi persistentActor ! Cmd("a") expectMsg("a") } - "support user stash operations" in { - val persistentActor = namedPersistentActor[UserStashPersistentActor] - persistentActor ! Cmd("a") - persistentActor ! Cmd("b") - persistentActor ! Cmd("c") - expectMsg("b") - expectMsg("c") - expectMsg("a") - } - "support user stash operations with several stashed messages" in { - val persistentActor = namedPersistentActor[UserStashManyPersistentActor] - val n = 10 - val cmds = 1 to n flatMap (_ ⇒ List(Cmd("a"), Cmd("b-1"), Cmd("b-2"), Cmd("c"))) - val evts = 1 to n flatMap (_ ⇒ List("a", "c", "b-1", "b-2")) - - cmds foreach (persistentActor ! _) - persistentActor ! GetState - expectMsg((List("a-1", "a-2") ++ evts)) - } - "support user stash operations under failures" in { - val persistentActor = namedPersistentActor[UserStashFailurePersistentActor] - val bs = 1 to 10 map ("b-" + _) - persistentActor ! Cmd("a") - bs foreach (persistentActor ! Cmd(_)) - persistentActor ! Cmd("c") - persistentActor ! GetState - expectMsg(List("a-1", "a-2", "a", "c") ++ bs.filter(_ != "b-2")) - } "be able to persist events that extend AnyVal" in { val persistentActor = namedPersistentActor[AnyValEventPersistentActor] persistentActor ! Cmd("a") diff --git a/akka-persistence/src/test/scala/akka/persistence/PersistentActorStashingSpec.scala b/akka-persistence/src/test/scala/akka/persistence/PersistentActorStashingSpec.scala new file mode 100644 index 0000000000..65709fdbc6 --- /dev/null +++ b/akka-persistence/src/test/scala/akka/persistence/PersistentActorStashingSpec.scala @@ -0,0 +1,175 @@ +/** + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.persistence + +import akka.actor.ActorRef +import akka.persistence.journal.SteppingInmemJournal +import akka.testkit.ImplicitSender +import com.typesafe.config.Config + +import scala.concurrent.duration._ + +object PersistentActorStashingSpec { + final case class Cmd(data: Any) + final case class Evt(data: Any) + + abstract class ExamplePersistentActor(name: String) extends NamedPersistentActor(name) { + var events: List[Any] = Nil + var askedForDelete: Option[ActorRef] = None + + val updateState: Receive = { + case Evt(data) ⇒ events = data :: events + case d @ Some(ref: ActorRef) ⇒ askedForDelete = d.asInstanceOf[Some[ActorRef]] + } + + val commonBehavior: Receive = { + case "boom" ⇒ throw new TestException("boom") + case GetState ⇒ sender() ! events.reverse + } + + def receiveRecover = updateState + } + + class UserStashPersistentActor(name: String) extends ExamplePersistentActor(name) { + var stashed = false + val receiveCommand: Receive = { + case Cmd("a") ⇒ if (!stashed) { stash(); stashed = true } else sender() ! "a" + case Cmd("b") ⇒ persist(Evt("b"))(evt ⇒ sender() ! evt.data) + case Cmd("c") ⇒ unstashAll(); sender() ! "c" + } + } + + class UserStashManyPersistentActor(name: String) extends ExamplePersistentActor(name) { + val receiveCommand: Receive = commonBehavior orElse { + case Cmd("a") ⇒ persist(Evt("a")) { evt ⇒ + updateState(evt) + context.become(processC) + } + case Cmd("b-1") ⇒ persist(Evt("b-1"))(updateState) + case Cmd("b-2") ⇒ persist(Evt("b-2"))(updateState) + } + + val processC: Receive = { + case Cmd("c") ⇒ + persist(Evt("c")) { evt ⇒ + updateState(evt) + context.unbecome() + } + unstashAll() + case other ⇒ stash() + } + } + + class UserStashFailurePersistentActor(name: String) extends ExamplePersistentActor(name) { + val receiveCommand: Receive = commonBehavior orElse { + case Cmd(data) ⇒ + if (data == "b-2") throw new TestException("boom") + persist(Evt(data)) { event ⇒ + updateState(event) + if (data == "a") context.become(otherCommandHandler) + } + } + + val otherCommandHandler: Receive = { + case Cmd("c") ⇒ + persist(Evt("c")) { event ⇒ + updateState(event) + context.unbecome() + } + unstashAll() + case other ⇒ stash() + } + } + + class AsyncStashingPersistentActor(name: String) extends ExamplePersistentActor(name) { + var stashed = false + val receiveCommand: Receive = commonBehavior orElse { + case Cmd("a") ⇒ persistAsync(Evt("a"))(updateState) + case Cmd("b") if !stashed ⇒ + stash(); stashed = true + case Cmd("b") ⇒ persistAsync(Evt("b"))(updateState) + case Cmd("c") ⇒ persistAsync(Evt("c"))(updateState); unstashAll() + } + } + +} + +abstract class PersistentActorStashingSpec(config: Config) extends PersistenceSpec(config) + with ImplicitSender { + + import PersistentActorStashingSpec._ + + "Stashing in a persistent actor" must { + + "support user stash operations" in { + val persistentActor = namedPersistentActor[UserStashPersistentActor] + persistentActor ! Cmd("a") + persistentActor ! Cmd("b") + persistentActor ! Cmd("c") + expectMsg("b") + expectMsg("c") + expectMsg("a") + } + + "support user stash operations with several stashed messages" in { + val persistentActor = namedPersistentActor[UserStashManyPersistentActor] + val n = 10 + val cmds = 1 to n flatMap (_ ⇒ List(Cmd("a"), Cmd("b-1"), Cmd("b-2"), Cmd("c"))) + val evts = 1 to n flatMap (_ ⇒ List("a", "c", "b-1", "b-2")) + + cmds foreach (persistentActor ! _) + persistentActor ! GetState + expectMsg(evts) + } + + "support user stash operations under failures" in { + val persistentActor = namedPersistentActor[UserStashFailurePersistentActor] + val bs = 1 to 10 map ("b-" + _) + persistentActor ! Cmd("a") + bs foreach (persistentActor ! Cmd(_)) + persistentActor ! Cmd("c") + persistentActor ! GetState + expectMsg(List("a", "c") ++ bs.filter(_ != "b-2")) + } + } +} + +class SteppingInMemPersistentActorStashingSpec extends PersistenceSpec( + SteppingInmemJournal.config("persistence-stash").withFallback(PersistenceSpec.config("stepping-inmem", "SteppingInMemPersistentActorStashingSpec"))) + with ImplicitSender { + + import PersistentActorStashingSpec._ + + "Stashing in a persistent actor mixed with persistAsync" should { + + "handle async callback not happening until next message has been stashed" in { + val persistentActor = namedPersistentActor[AsyncStashingPersistentActor] + awaitAssert(SteppingInmemJournal.getRef("persistence-stash"), 3.seconds) + val journal = SteppingInmemJournal.getRef("persistence-stash") + + // initial read highest + SteppingInmemJournal.step(journal) + + persistentActor ! Cmd("a") + persistentActor ! Cmd("b") + + // allow the write to complete, after the stash + SteppingInmemJournal.step(journal) + + persistentActor ! Cmd("c") + // writing of c and b + SteppingInmemJournal.step(journal) + SteppingInmemJournal.step(journal) + + persistentActor ! GetState + expectMsg(List("a", "c", "b")) + } + + } + +} + +class LeveldbPersistentActorStashingSpec extends PersistentActorStashingSpec(PersistenceSpec.config("leveldb", "LeveldbPersistentActorStashingSpec")) +class InmemPersistentActorStashingSpec extends PersistentActorStashingSpec(PersistenceSpec.config("inmem", "InmemPersistentActorStashingSpec")) \ No newline at end of file diff --git a/akka-persistence/src/test/scala/akka/persistence/journal/SteppingInmemJournal.scala b/akka-persistence/src/test/scala/akka/persistence/journal/SteppingInmemJournal.scala new file mode 100644 index 0000000000..261bc37286 --- /dev/null +++ b/akka-persistence/src/test/scala/akka/persistence/journal/SteppingInmemJournal.scala @@ -0,0 +1,150 @@ +/** + * Copyright (C) 2009-2015 Typesafe Inc. + */ + +package akka.persistence.journal + +import akka.actor.{ ActorSystem, ActorRef } +import akka.pattern.ask +import akka.persistence.journal.inmem.InmemJournal +import akka.persistence.{ AtomicWrite, PersistentRepr } +import akka.util.Timeout +import akka.testkit._ +import com.typesafe.config.{ ConfigFactory, Config } + +import scala.collection.immutable.Seq +import scala.concurrent.duration._ +import scala.concurrent.{ Await, Future, Promise } +import scala.util.Try + +object SteppingInmemJournal { + + /** allow the journal to do one operation */ + case object Token + case object TokenConsumed + + /** + * Allow the journal to do one operation, will block until that completes + */ + def step(journal: ActorRef)(implicit system: ActorSystem): Unit = { + implicit val timeout: Timeout = 3.seconds.dilated + Await.result(journal ? SteppingInmemJournal.Token, timeout.duration) + } + + def config(instanceId: String): Config = + ConfigFactory.parseString( + s""" + |akka.persistence.journal.stepping-inmem.class=${classOf[SteppingInmemJournal].getName} + |akka.persistence.journal.plugin = "akka.persistence.journal.stepping-inmem" + |akka.persistence.journal.stepping-inmem.instance-id = "$instanceId" + """.stripMargin) + + // keep it in a thread safe:d global so that tests can get their + // hand on the actor ref and send Steps to it + private[this] var _current: Map[String, ActorRef] = Map() + + // shhh don't tell anyone I sinn-croniz-ed + /** get the actor ref to the journal for a given instance id, throws exception if not found */ + def getRef(instanceId: String): ActorRef = synchronized(_current(instanceId)) + + private def putRef(instanceId: String, instance: ActorRef): Unit = synchronized { + _current = _current + (instanceId -> instance) + } + private def remove(instanceId: String): Unit = synchronized( + _current -= instanceId) +} + +/** + * An in memory journal that will not complete any persists or persistAsyncs until it gets tokens + * to trigger those steps. Allows for tests that need to deterministically trigger the callbacks + * intermixed with receiving messages. + * + * Configure your actor system using {{{SteppingInMemJournal.config}}} and then access + * it using {{{SteppingInmemJournal.getRef(String)}}}, send it {{{SteppingInmemJournal.Token}}}s to + * allow one journal operation to complete. + */ +final class SteppingInmemJournal extends InmemJournal { + + import SteppingInmemJournal._ + import context.dispatcher + + val instanceId = context.system.settings.config.getString("akka.persistence.journal.stepping-inmem.instance-id") + + var queuedOps: Seq[() ⇒ Future[Unit]] = Seq.empty + var queuedTokenRecipients = List.empty[ActorRef] + + override def receivePluginInternal = super.receivePluginInternal orElse { + case Token if queuedOps.isEmpty ⇒ queuedTokenRecipients = queuedTokenRecipients :+ sender() + case Token ⇒ + val op +: rest = queuedOps + queuedOps = rest + val tokenConsumer = sender() + op().onComplete(_ ⇒ tokenConsumer ! TokenConsumed) + } + + override def preStart(): Unit = { + SteppingInmemJournal.putRef(instanceId, self) + super.preStart() + } + + override def postStop(): Unit = { + super.postStop() + SteppingInmemJournal.remove(instanceId) + } + + override def asyncWriteMessages(messages: Seq[AtomicWrite]): Future[Seq[Try[Unit]]] = { + val futures = messages.map { message ⇒ + val promise = Promise[Try[Unit]]() + val future = promise.future + doOrEnqueue { () ⇒ + promise.completeWith(super.asyncWriteMessages(Seq(message)).map(_.head)) + future.map(_ ⇒ ()) + } + future + } + + Future.sequence(futures) + } + + override def asyncDeleteMessagesTo(persistenceId: String, toSequenceNr: Long): Future[Unit] = { + val promise = Promise[Unit]() + val future = promise.future + doOrEnqueue { () ⇒ + promise.completeWith(super.asyncDeleteMessagesTo(persistenceId, toSequenceNr)) + future + } + future + } + + override def asyncReadHighestSequenceNr(persistenceId: String, fromSequenceNr: Long): Future[Long] = { + val promise = Promise[Long]() + val future = promise.future + doOrEnqueue { () ⇒ + promise.completeWith(super.asyncReadHighestSequenceNr(persistenceId, fromSequenceNr)) + future.map(_ ⇒ ()) + } + future + } + + override def asyncReplayMessages(persistenceId: String, fromSequenceNr: Long, toSequenceNr: Long, max: Long)(recoveryCallback: (PersistentRepr) ⇒ Unit): Future[Unit] = { + val promise = Promise[Unit]() + val future = promise.future + doOrEnqueue { () ⇒ + promise.completeWith(super.asyncReplayMessages(persistenceId, fromSequenceNr, toSequenceNr, max)(recoveryCallback)) + future + } + + future + } + + private def doOrEnqueue(op: () ⇒ Future[Unit]): Unit = { + if (queuedTokenRecipients.nonEmpty) { + val completed = op() + val tokenRecipient +: rest = queuedTokenRecipients + queuedTokenRecipients = rest + completed.onComplete(_ ⇒ tokenRecipient ! TokenConsumed) + } else { + queuedOps = queuedOps :+ op + } + } +}