From 2d3f3df9718456624c4116febe1f14fd227ff63d Mon Sep 17 00:00:00 2001 From: Peter Vlugter Date: Sat, 6 Nov 2010 21:13:00 +1300 Subject: [PATCH] Add new mechanism for coordinated transactions --- akka-stm/src/main/scala/stm/Coordinated.scala | 51 ++++++++ .../scala/stm/CoordinatedIncrementSpec.scala | 91 +++++++++++++ .../test/scala/stm/FickleFriendsSpec.scala | 121 ++++++++++++++++++ 3 files changed, 263 insertions(+) create mode 100644 akka-stm/src/main/scala/stm/Coordinated.scala create mode 100644 akka-stm/src/test/scala/stm/CoordinatedIncrementSpec.scala create mode 100644 akka-stm/src/test/scala/stm/FickleFriendsSpec.scala diff --git a/akka-stm/src/main/scala/stm/Coordinated.scala b/akka-stm/src/main/scala/stm/Coordinated.scala new file mode 100644 index 0000000000..94a7d08828 --- /dev/null +++ b/akka-stm/src/main/scala/stm/Coordinated.scala @@ -0,0 +1,51 @@ +package akka.stm + +import akka.config.Config + +import org.multiverse.api.{Transaction => MultiverseTransaction} +import org.multiverse.commitbarriers.CountDownCommitBarrier +import org.multiverse.templates.TransactionalCallable + +/** + * Coordinated transactions across actors. + */ +object Coordinated { + val DefaultFactory = TransactionFactory(DefaultTransactionConfig, "DefaultCoordinatedTransaction") + val Fair = Config.config.getBool("akka.stm.fair", true) + + def apply(message: Any = null) = new Coordinated(message, createBarrier) + + def unapply(c: Coordinated): Option[Any] = Some(c.message) + + def createBarrier = new CountDownCommitBarrier(1, Fair) +} + +/** + * Coordinated transactions across actors. + */ +class Coordinated(val message: Any, barrier: CountDownCommitBarrier) { + def apply(msg: Any) = { + barrier.incParties(1) + new Coordinated(msg, barrier) + } + + def atomic[T](body: => T)(implicit factory: TransactionFactory = Coordinated.DefaultFactory): T = + atomic(factory)(body) + + def atomic[T](factory: TransactionFactory)(body: => T): T = { + factory.boilerplate.execute(new TransactionalCallable[T]() { + def call(mtx: MultiverseTransaction): T = { + factory.addHooks + val result = body + val timeout = factory.config.timeout + try { + barrier.tryJoinCommit(mtx, timeout.length, timeout.unit) + } catch { + // Need to catch IllegalStateException until we have fix in Multiverse, since it throws it by mistake + case e: IllegalStateException => () + } + result + } + }) + } +} diff --git a/akka-stm/src/test/scala/stm/CoordinatedIncrementSpec.scala b/akka-stm/src/test/scala/stm/CoordinatedIncrementSpec.scala new file mode 100644 index 0000000000..702f2abdc3 --- /dev/null +++ b/akka-stm/src/test/scala/stm/CoordinatedIncrementSpec.scala @@ -0,0 +1,91 @@ +package akka.stm.test + +import org.scalatest.WordSpec +import org.scalatest.matchers.MustMatchers + +import akka.actor.{Actor, ActorRef} +import akka.stm._ +import akka.util.duration._ + +import java.util.concurrent.CountDownLatch + +object CoordinatedIncrement { + case class Increment(friends: Seq[ActorRef], latch: CountDownLatch) + case object GetCount + + class Counter(name: String) extends Actor { + val count = Ref(0) + + implicit val txFactory = TransactionFactory(timeout = 3 seconds) + + def increment = { + log.info(name + ": incrementing") + count alter (_ + 1) + } + + def receive = { + case coordinated @ Coordinated(Increment(friends, latch)) => { + if (friends.nonEmpty) { + friends.head ! coordinated(Increment(friends.tail, latch)) + } + coordinated atomic { + increment + deferred { latch.countDown } + compensating { latch.countDown } + } + } + + case GetCount => self.reply(count.get) + } + } + + class Failer extends Actor { + def receive = { + case Coordinated(Increment(friends, latch)) => { + throw new RuntimeException("FAIL") + } + } + } +} + +class CoordinatedIncrementSpec extends WordSpec with MustMatchers { + import CoordinatedIncrement._ + + val numCounters = 10 + val timeout = 5 seconds + + def createActors = { + def createCounter(i: Int) = Actor.actorOf(new Counter("counter" + i)).start + val counters = (1 to numCounters) map createCounter + val failer = Actor.actorOf(new Failer).start + (counters, failer) + } + + "coordinated friendly increment" should { + "increment all counters by one with successful transactions" in { + val (counters, failer) = createActors + val incrementLatch = new CountDownLatch(numCounters) + counters(0) ! Coordinated(Increment(counters.tail, incrementLatch)) + incrementLatch.await(timeout.length, timeout.unit) + for (counter <- counters) { + (counter !! GetCount).get must be === 1 + } + counters foreach (_.stop) + failer.stop + } + + "increment no counters with a failing transaction" in { + val (counters, failer) = createActors + val failLatch = new CountDownLatch(numCounters + 1) + counters(0) ! Coordinated(Increment(counters.tail :+ failer, failLatch)) + failLatch.await(timeout.length, timeout.unit) + for (counter <- counters) { + (counter !! GetCount).get must be === 0 + } + counters foreach (_.stop) + failer.stop + } + } +} + + diff --git a/akka-stm/src/test/scala/stm/FickleFriendsSpec.scala b/akka-stm/src/test/scala/stm/FickleFriendsSpec.scala new file mode 100644 index 0000000000..825eae5960 --- /dev/null +++ b/akka-stm/src/test/scala/stm/FickleFriendsSpec.scala @@ -0,0 +1,121 @@ +package akka.stm.test + +import org.scalatest.WordSpec +import org.scalatest.matchers.MustMatchers + +import akka.actor.{Actor, ActorRef} +import akka.stm._ +import akka.util.duration._ + +import scala.util.Random.{nextInt => random} + +import java.util.concurrent.CountDownLatch + +object FickleFriends { + case class FriendlyIncrement(friends: Seq[ActorRef], latch: CountDownLatch) + case class Increment(friends: Seq[ActorRef]) + case object GetCount + + /** + * Coordinator will keep trying to coordinate an increment until successful. + */ + class Coordinator(name: String) extends Actor { + val count = Ref(0) + + implicit val txFactory = TransactionFactory(timeout = 3 seconds) + + def increment = { + log.info(name + ": incrementing") + count alter (_ + 1) + } + + def receive = { + case FriendlyIncrement(friends, latch) => { + var success = false + while (!success) { + try { + val coordinated = Coordinated() + if (friends.nonEmpty) { + friends.head ! coordinated(Increment(friends.tail)) + } + coordinated atomic { + increment + deferred { + success = true + latch.countDown + } + } + } catch { + case _ => () // swallow exceptions + } + } + } + + case GetCount => self.reply(count.get) + } + } + + /** + * FickleCounter randomly fails at different points with 50% chance of failing overall. + */ + class FickleCounter(name: String) extends Actor { + val count = Ref(0) + + implicit val txFactory = TransactionFactory(timeout = 3 seconds) + + def increment = { + log.info(name + ": incrementing") + count alter (_ + 1) + } + + def failIf(x: Int, y: Int) = { + if (x == y) throw new RuntimeException("Fail at " + x) + } + + def receive = { + case coordinated @ Coordinated(Increment(friends)) => { + val failAt = random(8) + failIf(failAt, 0) + if (friends.nonEmpty) { + friends.head ! coordinated(Increment(friends.tail)) + } + failIf(failAt, 1) + coordinated atomic { + failIf(failAt, 2) + increment + failIf(failAt, 3) + } + } + + case GetCount => self.reply(count.get) + } + } +} + +class FickleFriendsSpec extends WordSpec with MustMatchers { + import FickleFriends._ + + val numCounters = 2 + + def createActors = { + def createCounter(i: Int) = Actor.actorOf(new FickleCounter("counter" + i)).start + val counters = (1 to numCounters) map createCounter + val coordinator = Actor.actorOf(new Coordinator("coordinator")).start + (counters, coordinator) + } + + "coordinated fickle friends" should { + "eventually succeed to increment all counters by one" in { + val (counters, coordinator) = createActors + val latch = new CountDownLatch(1) + coordinator ! FriendlyIncrement(counters, latch) + latch.await // this could take a while + (coordinator !! GetCount).get must be === 1 + for (counter <- counters) { + (counter !! GetCount).get must be === 1 + } + counters foreach (_.stop) + coordinator.stop + } + } +}