From f04fbba47b8501090cb7ea4be4285e44f5b8b51c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonas=20Bon=C3=A9r?= Date: Mon, 5 Apr 2010 11:53:43 +0200 Subject: [PATCH] Refactored STM API into Transaction.Global and Transaction.Local, fixes issues with "atomic" outside actors --- akka-core/src/main/scala/actor/Actor.scala | 2 +- .../src/main/scala/stm/Transaction.scala | 116 ++++++++++++------ .../scala/stm/TransactionManagement.scala | 11 +- .../main/scala/stm/TransactionalState.scala | 1 - akka-core/src/test/scala/AgentTest.scala | 2 +- akka-core/src/test/scala/StmSpec.scala | 81 ++++++++++++ .../src/main/scala/ChatServer.scala | 2 +- 7 files changed, 171 insertions(+), 44 deletions(-) create mode 100644 akka-core/src/test/scala/StmSpec.scala diff --git a/akka-core/src/main/scala/actor/Actor.scala b/akka-core/src/main/scala/actor/Actor.scala index efd635c21a..f466333388 100644 --- a/akka-core/src/main/scala/actor/Actor.scala +++ b/akka-core/src/main/scala/actor/Actor.scala @@ -8,7 +8,7 @@ import se.scalablesolutions.akka.dispatch._ import se.scalablesolutions.akka.config.Config._ import se.scalablesolutions.akka.config.{AllForOneStrategy, OneForOneStrategy, FaultHandlingStrategy} import se.scalablesolutions.akka.config.ScalaConfig._ -import se.scalablesolutions.akka.stm.Transaction._ +import se.scalablesolutions.akka.stm.Transaction.Global._ import se.scalablesolutions.akka.stm.TransactionManagement._ import se.scalablesolutions.akka.stm.TransactionManagement import se.scalablesolutions.akka.remote.protobuf.RemoteProtocol.RemoteRequest diff --git a/akka-core/src/main/scala/stm/Transaction.scala b/akka-core/src/main/scala/stm/Transaction.scala index 29d4c586e9..221ab86bd6 100644 --- a/akka-core/src/main/scala/stm/Transaction.scala +++ b/akka-core/src/main/scala/stm/Transaction.scala @@ -12,7 +12,7 @@ import scala.collection.mutable.HashMap import se.scalablesolutions.akka.util.Logging -import org.multiverse.api.{Transaction => MultiverseTransaction} +import org.multiverse.api.{Transaction => MultiverseTransaction, TransactionLifecycleListener, TransactionLifecycleEvent} import org.multiverse.api.GlobalStmInstance.getGlobalStmInstance import org.multiverse.api.ThreadLocalTransaction._ import org.multiverse.templates.{TransactionTemplate, OrElseTemplate} @@ -97,9 +97,21 @@ class TransactionRetryException(message: String) extends RuntimeException(messag * * @author Jonas Bonér */ -object Transaction extends TransactionManagement with Logging { +object Transaction { val idFactory = new AtomicLong(-1L) + /** + * Creates a STM atomic transaction and by-passes all transactions hooks + * such as persistence etc. + * + * Only for internal usage. + */ + private[akka] def atomic0[T](body: => T): T = new TransactionTemplate[T]() { + def execute(mtx: MultiverseTransaction): T = body + }.execute() + + object Local extends TransactionManagement with Logging { + /** * See ScalaDoc on Transaction class. */ @@ -116,40 +128,22 @@ object Transaction extends TransactionManagement with Logging { def foreach(f: => Unit): Unit = atomic {f} /** - * See ScalaDoc on Transaction class. + * See ScalaDoc on class. */ def atomic[T](body: => T): T = { - var isTopLevelTransaction = true new TransactionTemplate[T]() { - def execute(mtx: MultiverseTransaction): T = { - val result = body - - val txSet = getTransactionSetInScope - log.trace("Committing transaction [%s]\n\tby joining transaction set [%s]", mtx, txSet) - txSet.joinCommit(mtx) - - // FIXME tryJoinCommit(mtx, TransactionManagement.TRANSACTION_TIMEOUT, TimeUnit.MILLISECONDS) - //getTransactionSetInScope.tryJoinCommit(mtx, TransactionManagement.TRANSACTION_TIMEOUT, TimeUnit.MILLISECONDS) - - clearTransaction - result - } + def execute(mtx: MultiverseTransaction): T = body override def onStart(mtx: MultiverseTransaction) = { - val txSet = - if (!isTransactionSetInScope) { - isTopLevelTransaction = true - createNewTransactionSet - } else getTransactionSetInScope val tx = new Transaction tx.transaction = Some(mtx) setTransaction(Some(tx)) - - txSet.registerOnCommitTask(new Runnable() { - def run = tx.commit - }) - txSet.registerOnAbortTask(new Runnable() { - def run = tx.abort + mtx.registerLifecycleListener(new TransactionLifecycleListener() { + def notify(tx: MultiverseTransaction, event: TransactionLifecycleEvent) = event.name match { + case "postCommit" => tx.commit + case "postAbort" => tx.abort + case _ => {} + } }) } }.execute() @@ -170,24 +164,70 @@ object Transaction extends TransactionManagement with Logging { def orelserun(t: MultiverseTransaction) = secondBody }.execute() } + } + + object Global extends TransactionManagement with Logging { + /** + * See ScalaDoc on Transaction class. + */ + def map[T](f: => T): T = atomic {f} /** - * Creates a STM atomic transaction and by-passes all transactions hooks - * such as persistence etc. - * - * Only for internal usage. + * See ScalaDoc on Transaction class. */ - private[akka] def atomic0[T](body: => T): T = new TransactionTemplate[T]() { - def execute(mtx: MultiverseTransaction): T = body - }.execute() + def flatMap[T](f: => T): T = atomic {f} + + /** + * See ScalaDoc on Transaction class. + */ + def foreach(f: => Unit): Unit = atomic {f} + + /** + * See ScalaDoc on Transaction class. + */ + def atomic[T](body: => T): T = { + var isTopLevelTransaction = false + new TransactionTemplate[T]() { + def execute(mtx: MultiverseTransaction): T = { + val result = body + + val txSet = getTransactionSetInScope + log.trace("Committing transaction [%s]\n\tby joining transaction set [%s]", mtx, txSet) + txSet.joinCommit(mtx) + + // FIXME tryJoinCommit(mtx, TransactionManagement.TRANSACTION_TIMEOUT, TimeUnit.MILLISECONDS) + //getTransactionSetInScope.tryJoinCommit(mtx, TransactionManagement.TRANSACTION_TIMEOUT, TimeUnit.MILLISECONDS) + + clearTransaction + result + } + + override def onStart(mtx: MultiverseTransaction) = { + val txSet = + if (!isTransactionSetInScope) { + isTopLevelTransaction = true + createNewTransactionSet + } else getTransactionSetInScope + val tx = new Transaction + tx.transaction = Some(mtx) + setTransaction(Some(tx)) + + txSet.registerOnCommitTask(new Runnable() { + def run = tx.commit + }) + txSet.registerOnAbortTask(new Runnable() { + def run = tx.abort + }) + } + }.execute() + } + } } /** * @author Jonas Bonér */ @serializable class Transaction extends Logging { - import Transaction._ - val id = Transaction.idFactory.incrementAndGet @volatile private[this] var status: TransactionStatus = TransactionStatus.New private[akka] var transaction: Option[MultiverseTransaction] = None @@ -200,7 +240,7 @@ object Transaction extends TransactionManagement with Logging { def commit = synchronized { log.trace("Committing transaction %s", toString) - atomic0 { + Transaction.atomic0 { persistentStateMap.valuesIterator.foreach(_.commit) } status = TransactionStatus.Completed diff --git a/akka-core/src/main/scala/stm/TransactionManagement.scala b/akka-core/src/main/scala/stm/TransactionManagement.scala index 48c8c7dd95..401c2379ee 100644 --- a/akka-core/src/main/scala/stm/TransactionManagement.scala +++ b/akka-core/src/main/scala/stm/TransactionManagement.scala @@ -4,6 +4,8 @@ package se.scalablesolutions.akka.stm +import se.scalablesolutions.akka.util.Logging + import java.util.concurrent.atomic.AtomicBoolean import org.multiverse.api.ThreadLocalTransaction._ @@ -49,9 +51,10 @@ object TransactionManagement extends TransactionManagement { } } -trait TransactionManagement { +trait TransactionManagement extends Logging { private[akka] def createNewTransactionSet: CountDownCommitBarrier = { + log.trace("Creating new transaction set") val txSet = new CountDownCommitBarrier(1, TransactionManagement.FAIR_TRANSACTIONS) TransactionManagement.transactionSet.set(Some(txSet)) txSet @@ -63,9 +66,13 @@ trait TransactionManagement { private[akka] def setTransaction(tx: Option[Transaction]) = if (tx.isDefined) TransactionManagement.transaction.set(tx) - private[akka] def clearTransactionSet = TransactionManagement.transactionSet.set(None) + private[akka] def clearTransactionSet = { + log.trace("Clearing transaction set") + TransactionManagement.transactionSet.set(None) + } private[akka] def clearTransaction = { + log.trace("Clearing transaction") TransactionManagement.transaction.set(None) setThreadLocalTransaction(null) } diff --git a/akka-core/src/main/scala/stm/TransactionalState.scala b/akka-core/src/main/scala/stm/TransactionalState.scala index a8b5da83b9..7afb3fb6bb 100644 --- a/akka-core/src/main/scala/stm/TransactionalState.scala +++ b/akka-core/src/main/scala/stm/TransactionalState.scala @@ -4,7 +4,6 @@ package se.scalablesolutions.akka.stm -import se.scalablesolutions.akka.stm.Transaction.atomic import se.scalablesolutions.akka.util.UUID import org.multiverse.stms.alpha.AlphaRef diff --git a/akka-core/src/test/scala/AgentTest.scala b/akka-core/src/test/scala/AgentTest.scala index 01818e1938..b378b67207 100644 --- a/akka-core/src/test/scala/AgentTest.scala +++ b/akka-core/src/test/scala/AgentTest.scala @@ -2,7 +2,7 @@ package se.scalablesolutions.akka.actor import _root_.java.util.concurrent.TimeUnit import se.scalablesolutions.akka.actor.Actor.transactor -import se.scalablesolutions.akka.stm.Transaction.atomic +import se.scalablesolutions.akka.stm.Transaction.Global.atomic import se.scalablesolutions.akka.util.Logging import org.scalatest.Suite diff --git a/akka-core/src/test/scala/StmSpec.scala b/akka-core/src/test/scala/StmSpec.scala new file mode 100644 index 0000000000..121407426a --- /dev/null +++ b/akka-core/src/test/scala/StmSpec.scala @@ -0,0 +1,81 @@ +package se.scalablesolutions.akka.actor + +import se.scalablesolutions.akka.stm.Transaction.Local._ +import se.scalablesolutions.akka.stm._ + +import org.scalatest.Spec +import org.scalatest.Assertions +import org.scalatest.matchers.ShouldMatchers +import org.scalatest.BeforeAndAfterAll +import org.scalatest.junit.JUnitRunner +import org.junit.runner.RunWith + +@RunWith(classOf[JUnitRunner]) +class StmSpec extends + Spec with + ShouldMatchers with + BeforeAndAfterAll { + + describe("STM outside actors") { + it("should be able to do multiple consecutive atomic {..} statements") { + + lazy val ref = TransactionalState.newRef[Int] + + def increment = atomic { + ref.swap(ref.get.getOrElse(0) + 1) + } + + def total: Int = atomic { + ref.get.getOrElse(0) + } + + increment + increment + increment + total should equal(3) + } + + it("should be able to do nested atomic {..} statements") { + + lazy val ref = TransactionalState.newRef[Int] + + def increment = atomic { + ref.swap(ref.get.getOrElse(0) + 1) + } + def total: Int = atomic { + ref.get.getOrElse(0) + } + + atomic { + increment + increment + } + atomic { + increment + total should equal(3) + } + } + + it("should roll back failing nested atomic {..} statements") { + + lazy val ref = TransactionalState.newRef[Int] + + def increment = atomic { + ref.swap(ref.get.getOrElse(0) + 1) + } + def total: Int = atomic { + ref.get.getOrElse(0) + } + try { + atomic { + increment + increment + throw new Exception + } + } catch { + case e => {} + } + total should equal(0) + } + } +} diff --git a/akka-samples/akka-sample-chat/src/main/scala/ChatServer.scala b/akka-samples/akka-sample-chat/src/main/scala/ChatServer.scala index 6afd7baaf2..6f1d93feb6 100644 --- a/akka-samples/akka-sample-chat/src/main/scala/ChatServer.scala +++ b/akka-samples/akka-sample-chat/src/main/scala/ChatServer.scala @@ -10,7 +10,7 @@ import se.scalablesolutions.akka.actor.{SupervisorFactory, Actor, RemoteActor} import se.scalablesolutions.akka.remote.{RemoteNode, RemoteClient} import se.scalablesolutions.akka.persistence.common.PersistentVector import se.scalablesolutions.akka.persistence.redis.RedisStorage -import se.scalablesolutions.akka.stm.Transaction._ +import se.scalablesolutions.akka.stm.Transaction.Global._ import se.scalablesolutions.akka.config.ScalaConfig._ import se.scalablesolutions.akka.config.OneForOneStrategy import se.scalablesolutions.akka.util.Logging