diff --git a/akka-core/src/main/scala/stm/JTA.scala b/akka-core/src/main/scala/stm/JTA.scala index 5be0460681..510e9cf78c 100644 --- a/akka-core/src/main/scala/stm/JTA.scala +++ b/akka-core/src/main/scala/stm/JTA.scala @@ -104,6 +104,21 @@ object TransactionContainer extends Logging { * @author Jonas Bonér */ class TransactionContainer private (val tm: Either[Option[UserTransaction], Option[TransactionManager]]) { + + def registerSynchronization(sync: Synchronization) = { + TransactionContainer.findSynchronizationRegistry match { // try to use SynchronizationRegistry in JNDI + case Some(registry) => + registry.asInstanceOf[TransactionSynchronizationRegistry].registerInterposedSynchronization(sync) + case None => + tm match { + case Right(Some(txMan)) => // try to use TransactionManager + txMan.getTransaction.registerSynchronization(sync) + case _ => + log.warning("Cannot find TransactionSynchronizationRegistry in JNDI, can't register STM synchronization") + } + } + } + def begin = tm match { case Left(Some(userTx)) => userTx.begin case Right(Some(txMan)) => txMan.begin diff --git a/akka-core/src/main/scala/stm/Transaction.scala b/akka-core/src/main/scala/stm/Transaction.scala index 97d588f400..37b38b670b 100644 --- a/akka-core/src/main/scala/stm/Transaction.scala +++ b/akka-core/src/main/scala/stm/Transaction.scala @@ -316,13 +316,7 @@ object Transaction { def begin = synchronized { jta.foreach { txContainer => txContainer.begin - TransactionContainer.findSynchronizationRegistry match { - case Some(registry) => - registry.asInstanceOf[TransactionSynchronizationRegistry].registerInterposedSynchronization( - new StmSynchronization(txContainer, this)) - case None => - log.warning("Cannot find TransactionSynchronizationRegistry in JNDI, can't register STM synchronization") - } + txContainer.registerSynchronization(new StmSynchronization(txContainer, this)) } } diff --git a/akka-jta/src/main/scala/TransactionContext.scala b/akka-jta/src/main/scala/TransactionContext.scala index 43d2c082d8..b2574dd478 100644 --- a/akka-jta/src/main/scala/TransactionContext.scala +++ b/akka-jta/src/main/scala/TransactionContext.scala @@ -4,7 +4,7 @@ package se.scalablesolutions.akka.jta -import javax.transaction.{Transaction, Status, TransactionManager} +import javax.transaction.{Transaction, Status, TransactionManager, Synchronization} import se.scalablesolutions.akka.stm.{TransactionService, TransactionContainer} import se.scalablesolutions.akka.util.Logging @@ -56,6 +56,41 @@ object TransactionContext extends TransactionProtocol with Logging { implicit val tc = TransactionContainer() private[TransactionContext] val stack = new scala.util.DynamicVariable(new TransactionContext(tc)) + /** + * This method can be used to register a Synchronization instance for participating with the JTA transaction. + * Here is an example of how to add a JPA EntityManager integration. + *
+   *   TransactionContext.registerSynchronization(new javax.transaction.Synchronization() {
+   *     def beforeCompletion = {
+   *       try {
+   *         val status = tm.getStatus
+   *         if (status != Status.STATUS_ROLLEDBACK &&
+   *             status != Status.STATUS_ROLLING_BACK &&
+   *             status != Status.STATUS_MARKED_ROLLBACK) {
+   *           log.debug("Flushing EntityManager...") 
+   *           em.flush // flush EntityManager on success
+   *         }
+   *       } catch {
+   *         case e: javax.transaction.SystemException => throw new RuntimeException(e)
+   *       }
+   *     }
+   *
+   *     def afterCompletion(status: Int) = {
+   *       val status = tm.getStatus
+   *       if (closeAtTxCompletion) em.close
+   *       if (status == Status.STATUS_ROLLEDBACK ||
+   *           status == Status.STATUS_ROLLING_BACK ||
+   *           status == Status.STATUS_MARKED_ROLLBACK) {
+   *         em.close
+   *       }
+   *     }
+   *   })
+   * 
+ * You should also override the 'joinTransaction' and 'handleException' methods. + * See ScalaDoc for these methods in the 'TransactionProtocol' for details. + */ + def registerSynchronization(sync: Synchronization) = synchronization.add(sync) + object Required extends TransactionMonad { def map[T](f: TransactionMonad => T): T = withTxRequired { f(this) } def flatMap[T](f: TransactionMonad => T): T = withTxRequired { f(this) } @@ -170,7 +205,8 @@ trait TransactionMonad { * @author Jonas Bonér */ class TransactionContext(val tc: TransactionContainer) { - private def setRollbackOnly = tc.setRollbackOnly - private def isRollbackOnly: Boolean = tc.getStatus == Status.STATUS_MARKED_ROLLBACK - private def getTransactionContainer: TransactionContainer = tc + def registerSynchronization(sync: Synchronization) = TransactionContext.registerSynchronization(sync) + def setRollbackOnly = tc.setRollbackOnly + def isRollbackOnly: Boolean = tc.getStatus == Status.STATUS_MARKED_ROLLBACK + def getTransactionContainer: TransactionContainer = tc } diff --git a/akka-jta/src/main/scala/TransactionProtocol.scala b/akka-jta/src/main/scala/TransactionProtocol.scala index 2036f0d013..004cf99657 100644 --- a/akka-jta/src/main/scala/TransactionProtocol.scala +++ b/akka-jta/src/main/scala/TransactionProtocol.scala @@ -7,6 +7,9 @@ package se.scalablesolutions.akka.jta import se.scalablesolutions.akka.util.Logging import se.scalablesolutions.akka.stm.TransactionContainer +import java.util.{List => JList} +import java.util.concurrent.CopyOnWriteArrayList + import javax.naming.{NamingException, Context, InitialContext} import javax.transaction.{ Transaction, @@ -15,6 +18,7 @@ import javax.transaction.{ Status, RollbackException, SystemException, + Synchronization, TransactionRequiredException } @@ -62,6 +66,8 @@ import javax.transaction.{ * @author Jonas Bonér */ trait TransactionProtocol extends Logging { + + protected val synchronization: JList[Synchronization] = new CopyOnWriteArrayList[Synchronization] /** * Join JTA transaction. Can be overriden by concrete transaction service implementation @@ -71,36 +77,10 @@ trait TransactionProtocol extends Logging { * *
    * override def joinTransaction = {
-   *   val em = TransactionContext.getEntityManager
-   *   val tm = TransactionContext.getTransactionContainer
-   *   val closeAtTxCompletion: Boolean) 
-   *   tm.getTransaction.registerSynchronization(new javax.transaction.Synchronization() {
-   *     def beforeCompletion = {
-   *       try {
-   *         val status = tm.getStatus
-   *         if (status != Status.STATUS_ROLLEDBACK &&
-   *             status != Status.STATUS_ROLLING_BACK &&
-   *             status != Status.STATUS_MARKED_ROLLBACK) {
-   *           log.debug("Flushing EntityManager...") 
-   *           em.flush // flush EntityManager on success
-   *         }
-   *       } catch {
-   *         case e: javax.transaction.SystemException => throw new RuntimeException(e)
-   *       }
-   *     }
-   *
-   *     def afterCompletion(status: Int) = {
-   *       val status = tm.getStatus
-   *       if (closeAtTxCompletion) em.close
-   *       if (status == Status.STATUS_ROLLEDBACK ||
-   *           status == Status.STATUS_ROLLING_BACK ||
-   *           status == Status.STATUS_MARKED_ROLLBACK) {
-   *         em.close
-   *       }
-   *     }
-   *   })
+   *   val em: EntityManager = ... // get the EntityManager
    *   em.joinTransaction // join JTA transaction
    * }
+   * 
*/ def joinTransaction: Unit = {} @@ -137,6 +117,7 @@ trait TransactionProtocol extends Logging { val tm = TransactionContext.getTransactionContainer if (!isInExistingTransaction(tm)) { tm.begin + registerSynchronization try { joinTransaction body @@ -157,6 +138,7 @@ trait TransactionProtocol extends Logging { def withTxRequiresNew[T](body: => T): T = TransactionContext.withNewContext { val tm = TransactionContext.getTransactionContainer tm.begin + registerSynchronization try { joinTransaction body @@ -224,6 +206,10 @@ trait TransactionProtocol extends Logging { // Helper methods // --------------------------- + protected def registerSynchronization = { + val it = synchronization.iterator + while (it.hasNext) TransactionContext.getTransactionContainer.registerSynchronization(it.next) + } /** * Checks if a transaction is an existing transaction. *