Add new mechanism for coordinated transactions
This commit is contained in:
parent
99d6d6b20d
commit
2d3f3df971
3 changed files with 263 additions and 0 deletions
51
akka-stm/src/main/scala/stm/Coordinated.scala
Normal file
51
akka-stm/src/main/scala/stm/Coordinated.scala
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
package akka.stm
|
||||
|
||||
import akka.config.Config
|
||||
|
||||
import org.multiverse.api.{Transaction => MultiverseTransaction}
|
||||
import org.multiverse.commitbarriers.CountDownCommitBarrier
|
||||
import org.multiverse.templates.TransactionalCallable
|
||||
|
||||
/**
|
||||
* Coordinated transactions across actors.
|
||||
*/
|
||||
object Coordinated {
|
||||
val DefaultFactory = TransactionFactory(DefaultTransactionConfig, "DefaultCoordinatedTransaction")
|
||||
val Fair = Config.config.getBool("akka.stm.fair", true)
|
||||
|
||||
def apply(message: Any = null) = new Coordinated(message, createBarrier)
|
||||
|
||||
def unapply(c: Coordinated): Option[Any] = Some(c.message)
|
||||
|
||||
def createBarrier = new CountDownCommitBarrier(1, Fair)
|
||||
}
|
||||
|
||||
/**
|
||||
* Coordinated transactions across actors.
|
||||
*/
|
||||
class Coordinated(val message: Any, barrier: CountDownCommitBarrier) {
|
||||
def apply(msg: Any) = {
|
||||
barrier.incParties(1)
|
||||
new Coordinated(msg, barrier)
|
||||
}
|
||||
|
||||
def atomic[T](body: => T)(implicit factory: TransactionFactory = Coordinated.DefaultFactory): T =
|
||||
atomic(factory)(body)
|
||||
|
||||
def atomic[T](factory: TransactionFactory)(body: => T): T = {
|
||||
factory.boilerplate.execute(new TransactionalCallable[T]() {
|
||||
def call(mtx: MultiverseTransaction): T = {
|
||||
factory.addHooks
|
||||
val result = body
|
||||
val timeout = factory.config.timeout
|
||||
try {
|
||||
barrier.tryJoinCommit(mtx, timeout.length, timeout.unit)
|
||||
} catch {
|
||||
// Need to catch IllegalStateException until we have fix in Multiverse, since it throws it by mistake
|
||||
case e: IllegalStateException => ()
|
||||
}
|
||||
result
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
91
akka-stm/src/test/scala/stm/CoordinatedIncrementSpec.scala
Normal file
91
akka-stm/src/test/scala/stm/CoordinatedIncrementSpec.scala
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
package akka.stm.test
|
||||
|
||||
import org.scalatest.WordSpec
|
||||
import org.scalatest.matchers.MustMatchers
|
||||
|
||||
import akka.actor.{Actor, ActorRef}
|
||||
import akka.stm._
|
||||
import akka.util.duration._
|
||||
|
||||
import java.util.concurrent.CountDownLatch
|
||||
|
||||
object CoordinatedIncrement {
|
||||
case class Increment(friends: Seq[ActorRef], latch: CountDownLatch)
|
||||
case object GetCount
|
||||
|
||||
class Counter(name: String) extends Actor {
|
||||
val count = Ref(0)
|
||||
|
||||
implicit val txFactory = TransactionFactory(timeout = 3 seconds)
|
||||
|
||||
def increment = {
|
||||
log.info(name + ": incrementing")
|
||||
count alter (_ + 1)
|
||||
}
|
||||
|
||||
def receive = {
|
||||
case coordinated @ Coordinated(Increment(friends, latch)) => {
|
||||
if (friends.nonEmpty) {
|
||||
friends.head ! coordinated(Increment(friends.tail, latch))
|
||||
}
|
||||
coordinated atomic {
|
||||
increment
|
||||
deferred { latch.countDown }
|
||||
compensating { latch.countDown }
|
||||
}
|
||||
}
|
||||
|
||||
case GetCount => self.reply(count.get)
|
||||
}
|
||||
}
|
||||
|
||||
class Failer extends Actor {
|
||||
def receive = {
|
||||
case Coordinated(Increment(friends, latch)) => {
|
||||
throw new RuntimeException("FAIL")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class CoordinatedIncrementSpec extends WordSpec with MustMatchers {
|
||||
import CoordinatedIncrement._
|
||||
|
||||
val numCounters = 10
|
||||
val timeout = 5 seconds
|
||||
|
||||
def createActors = {
|
||||
def createCounter(i: Int) = Actor.actorOf(new Counter("counter" + i)).start
|
||||
val counters = (1 to numCounters) map createCounter
|
||||
val failer = Actor.actorOf(new Failer).start
|
||||
(counters, failer)
|
||||
}
|
||||
|
||||
"coordinated friendly increment" should {
|
||||
"increment all counters by one with successful transactions" in {
|
||||
val (counters, failer) = createActors
|
||||
val incrementLatch = new CountDownLatch(numCounters)
|
||||
counters(0) ! Coordinated(Increment(counters.tail, incrementLatch))
|
||||
incrementLatch.await(timeout.length, timeout.unit)
|
||||
for (counter <- counters) {
|
||||
(counter !! GetCount).get must be === 1
|
||||
}
|
||||
counters foreach (_.stop)
|
||||
failer.stop
|
||||
}
|
||||
|
||||
"increment no counters with a failing transaction" in {
|
||||
val (counters, failer) = createActors
|
||||
val failLatch = new CountDownLatch(numCounters + 1)
|
||||
counters(0) ! Coordinated(Increment(counters.tail :+ failer, failLatch))
|
||||
failLatch.await(timeout.length, timeout.unit)
|
||||
for (counter <- counters) {
|
||||
(counter !! GetCount).get must be === 0
|
||||
}
|
||||
counters foreach (_.stop)
|
||||
failer.stop
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
121
akka-stm/src/test/scala/stm/FickleFriendsSpec.scala
Normal file
121
akka-stm/src/test/scala/stm/FickleFriendsSpec.scala
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
package akka.stm.test
|
||||
|
||||
import org.scalatest.WordSpec
|
||||
import org.scalatest.matchers.MustMatchers
|
||||
|
||||
import akka.actor.{Actor, ActorRef}
|
||||
import akka.stm._
|
||||
import akka.util.duration._
|
||||
|
||||
import scala.util.Random.{nextInt => random}
|
||||
|
||||
import java.util.concurrent.CountDownLatch
|
||||
|
||||
object FickleFriends {
|
||||
case class FriendlyIncrement(friends: Seq[ActorRef], latch: CountDownLatch)
|
||||
case class Increment(friends: Seq[ActorRef])
|
||||
case object GetCount
|
||||
|
||||
/**
|
||||
* Coordinator will keep trying to coordinate an increment until successful.
|
||||
*/
|
||||
class Coordinator(name: String) extends Actor {
|
||||
val count = Ref(0)
|
||||
|
||||
implicit val txFactory = TransactionFactory(timeout = 3 seconds)
|
||||
|
||||
def increment = {
|
||||
log.info(name + ": incrementing")
|
||||
count alter (_ + 1)
|
||||
}
|
||||
|
||||
def receive = {
|
||||
case FriendlyIncrement(friends, latch) => {
|
||||
var success = false
|
||||
while (!success) {
|
||||
try {
|
||||
val coordinated = Coordinated()
|
||||
if (friends.nonEmpty) {
|
||||
friends.head ! coordinated(Increment(friends.tail))
|
||||
}
|
||||
coordinated atomic {
|
||||
increment
|
||||
deferred {
|
||||
success = true
|
||||
latch.countDown
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
case _ => () // swallow exceptions
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case GetCount => self.reply(count.get)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* FickleCounter randomly fails at different points with 50% chance of failing overall.
|
||||
*/
|
||||
class FickleCounter(name: String) extends Actor {
|
||||
val count = Ref(0)
|
||||
|
||||
implicit val txFactory = TransactionFactory(timeout = 3 seconds)
|
||||
|
||||
def increment = {
|
||||
log.info(name + ": incrementing")
|
||||
count alter (_ + 1)
|
||||
}
|
||||
|
||||
def failIf(x: Int, y: Int) = {
|
||||
if (x == y) throw new RuntimeException("Fail at " + x)
|
||||
}
|
||||
|
||||
def receive = {
|
||||
case coordinated @ Coordinated(Increment(friends)) => {
|
||||
val failAt = random(8)
|
||||
failIf(failAt, 0)
|
||||
if (friends.nonEmpty) {
|
||||
friends.head ! coordinated(Increment(friends.tail))
|
||||
}
|
||||
failIf(failAt, 1)
|
||||
coordinated atomic {
|
||||
failIf(failAt, 2)
|
||||
increment
|
||||
failIf(failAt, 3)
|
||||
}
|
||||
}
|
||||
|
||||
case GetCount => self.reply(count.get)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class FickleFriendsSpec extends WordSpec with MustMatchers {
|
||||
import FickleFriends._
|
||||
|
||||
val numCounters = 2
|
||||
|
||||
def createActors = {
|
||||
def createCounter(i: Int) = Actor.actorOf(new FickleCounter("counter" + i)).start
|
||||
val counters = (1 to numCounters) map createCounter
|
||||
val coordinator = Actor.actorOf(new Coordinator("coordinator")).start
|
||||
(counters, coordinator)
|
||||
}
|
||||
|
||||
"coordinated fickle friends" should {
|
||||
"eventually succeed to increment all counters by one" in {
|
||||
val (counters, coordinator) = createActors
|
||||
val latch = new CountDownLatch(1)
|
||||
coordinator ! FriendlyIncrement(counters, latch)
|
||||
latch.await // this could take a while
|
||||
(coordinator !! GetCount).get must be === 1
|
||||
for (counter <- counters) {
|
||||
(counter !! GetCount).get must be === 1
|
||||
}
|
||||
counters foreach (_.stop)
|
||||
coordinator.stop
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue