diff --git a/akka-core/src/main/scala/stm/JTA.scala b/akka-core/src/main/scala/stm/JTA.scala index 5be0460681..510e9cf78c 100644 --- a/akka-core/src/main/scala/stm/JTA.scala +++ b/akka-core/src/main/scala/stm/JTA.scala @@ -104,6 +104,21 @@ object TransactionContainer extends Logging { * @author Jonas Bonér */ class TransactionContainer private (val tm: Either[Option[UserTransaction], Option[TransactionManager]]) { + + def registerSynchronization(sync: Synchronization) = { + TransactionContainer.findSynchronizationRegistry match { // try to use SynchronizationRegistry in JNDI + case Some(registry) => + registry.asInstanceOf[TransactionSynchronizationRegistry].registerInterposedSynchronization(sync) + case None => + tm match { + case Right(Some(txMan)) => // try to use TransactionManager + txMan.getTransaction.registerSynchronization(sync) + case _ => + log.warning("Cannot find TransactionSynchronizationRegistry in JNDI, can't register STM synchronization") + } + } + } + def begin = tm match { case Left(Some(userTx)) => userTx.begin case Right(Some(txMan)) => txMan.begin diff --git a/akka-core/src/main/scala/stm/Transaction.scala b/akka-core/src/main/scala/stm/Transaction.scala index 97d588f400..37b38b670b 100644 --- a/akka-core/src/main/scala/stm/Transaction.scala +++ b/akka-core/src/main/scala/stm/Transaction.scala @@ -316,13 +316,7 @@ object Transaction { def begin = synchronized { jta.foreach { txContainer => txContainer.begin - TransactionContainer.findSynchronizationRegistry match { - case Some(registry) => - registry.asInstanceOf[TransactionSynchronizationRegistry].registerInterposedSynchronization( - new StmSynchronization(txContainer, this)) - case None => - log.warning("Cannot find TransactionSynchronizationRegistry in JNDI, can't register STM synchronization") - } + txContainer.registerSynchronization(new StmSynchronization(txContainer, this)) } } diff --git a/akka-jta/src/main/scala/TransactionContext.scala b/akka-jta/src/main/scala/TransactionContext.scala index 43d2c082d8..b2574dd478 100644 --- a/akka-jta/src/main/scala/TransactionContext.scala +++ b/akka-jta/src/main/scala/TransactionContext.scala @@ -4,7 +4,7 @@ package se.scalablesolutions.akka.jta -import javax.transaction.{Transaction, Status, TransactionManager} +import javax.transaction.{Transaction, Status, TransactionManager, Synchronization} import se.scalablesolutions.akka.stm.{TransactionService, TransactionContainer} import se.scalablesolutions.akka.util.Logging @@ -56,6 +56,41 @@ object TransactionContext extends TransactionProtocol with Logging { implicit val tc = TransactionContainer() private[TransactionContext] val stack = new scala.util.DynamicVariable(new TransactionContext(tc)) + /** + * This method can be used to register a Synchronization instance for participating with the JTA transaction. + * Here is an example of how to add a JPA EntityManager integration. + *
+ * TransactionContext.registerSynchronization(new javax.transaction.Synchronization() {
+ * def beforeCompletion = {
+ * try {
+ * val status = tm.getStatus
+ * if (status != Status.STATUS_ROLLEDBACK &&
+ * status != Status.STATUS_ROLLING_BACK &&
+ * status != Status.STATUS_MARKED_ROLLBACK) {
+ * log.debug("Flushing EntityManager...")
+ * em.flush // flush EntityManager on success
+ * }
+ * } catch {
+ * case e: javax.transaction.SystemException => throw new RuntimeException(e)
+ * }
+ * }
+ *
+ * def afterCompletion(status: Int) = {
+ * val status = tm.getStatus
+ * if (closeAtTxCompletion) em.close
+ * if (status == Status.STATUS_ROLLEDBACK ||
+ * status == Status.STATUS_ROLLING_BACK ||
+ * status == Status.STATUS_MARKED_ROLLBACK) {
+ * em.close
+ * }
+ * }
+ * })
+ *
+ * You should also override the 'joinTransaction' and 'handleException' methods.
+ * See ScalaDoc for these methods in the 'TransactionProtocol' for details.
+ */
+ def registerSynchronization(sync: Synchronization) = synchronization.add(sync)
+
object Required extends TransactionMonad {
def map[T](f: TransactionMonad => T): T = withTxRequired { f(this) }
def flatMap[T](f: TransactionMonad => T): T = withTxRequired { f(this) }
@@ -170,7 +205,8 @@ trait TransactionMonad {
* @author Jonas Bonér
*/
class TransactionContext(val tc: TransactionContainer) {
- private def setRollbackOnly = tc.setRollbackOnly
- private def isRollbackOnly: Boolean = tc.getStatus == Status.STATUS_MARKED_ROLLBACK
- private def getTransactionContainer: TransactionContainer = tc
+ def registerSynchronization(sync: Synchronization) = TransactionContext.registerSynchronization(sync)
+ def setRollbackOnly = tc.setRollbackOnly
+ def isRollbackOnly: Boolean = tc.getStatus == Status.STATUS_MARKED_ROLLBACK
+ def getTransactionContainer: TransactionContainer = tc
}
diff --git a/akka-jta/src/main/scala/TransactionProtocol.scala b/akka-jta/src/main/scala/TransactionProtocol.scala
index 2036f0d013..004cf99657 100644
--- a/akka-jta/src/main/scala/TransactionProtocol.scala
+++ b/akka-jta/src/main/scala/TransactionProtocol.scala
@@ -7,6 +7,9 @@ package se.scalablesolutions.akka.jta
import se.scalablesolutions.akka.util.Logging
import se.scalablesolutions.akka.stm.TransactionContainer
+import java.util.{List => JList}
+import java.util.concurrent.CopyOnWriteArrayList
+
import javax.naming.{NamingException, Context, InitialContext}
import javax.transaction.{
Transaction,
@@ -15,6 +18,7 @@ import javax.transaction.{
Status,
RollbackException,
SystemException,
+ Synchronization,
TransactionRequiredException
}
@@ -62,6 +66,8 @@ import javax.transaction.{
* @author Jonas Bonér
*/
trait TransactionProtocol extends Logging {
+
+ protected val synchronization: JList[Synchronization] = new CopyOnWriteArrayList[Synchronization]
/**
* Join JTA transaction. Can be overriden by concrete transaction service implementation
@@ -71,36 +77,10 @@ trait TransactionProtocol extends Logging {
*
*
* override def joinTransaction = {
- * val em = TransactionContext.getEntityManager
- * val tm = TransactionContext.getTransactionContainer
- * val closeAtTxCompletion: Boolean)
- * tm.getTransaction.registerSynchronization(new javax.transaction.Synchronization() {
- * def beforeCompletion = {
- * try {
- * val status = tm.getStatus
- * if (status != Status.STATUS_ROLLEDBACK &&
- * status != Status.STATUS_ROLLING_BACK &&
- * status != Status.STATUS_MARKED_ROLLBACK) {
- * log.debug("Flushing EntityManager...")
- * em.flush // flush EntityManager on success
- * }
- * } catch {
- * case e: javax.transaction.SystemException => throw new RuntimeException(e)
- * }
- * }
- *
- * def afterCompletion(status: Int) = {
- * val status = tm.getStatus
- * if (closeAtTxCompletion) em.close
- * if (status == Status.STATUS_ROLLEDBACK ||
- * status == Status.STATUS_ROLLING_BACK ||
- * status == Status.STATUS_MARKED_ROLLBACK) {
- * em.close
- * }
- * }
- * })
+ * val em: EntityManager = ... // get the EntityManager
* em.joinTransaction // join JTA transaction
* }
+ *
*/
def joinTransaction: Unit = {}
@@ -137,6 +117,7 @@ trait TransactionProtocol extends Logging {
val tm = TransactionContext.getTransactionContainer
if (!isInExistingTransaction(tm)) {
tm.begin
+ registerSynchronization
try {
joinTransaction
body
@@ -157,6 +138,7 @@ trait TransactionProtocol extends Logging {
def withTxRequiresNew[T](body: => T): T = TransactionContext.withNewContext {
val tm = TransactionContext.getTransactionContainer
tm.begin
+ registerSynchronization
try {
joinTransaction
body
@@ -224,6 +206,10 @@ trait TransactionProtocol extends Logging {
// Helper methods
// ---------------------------
+ protected def registerSynchronization = {
+ val it = synchronization.iterator
+ while (it.hasNext) TransactionContext.getTransactionContainer.registerSynchronization(it.next)
+ }
/**
* Checks if a transaction is an existing transaction.
*