diff --git a/akka-jta/src/main/scala/TransactionContext.scala b/akka-jta/src/main/scala/TransactionContext.scala index b2574dd478..e6ac00de8b 100644 --- a/akka-jta/src/main/scala/TransactionContext.scala +++ b/akka-jta/src/main/scala/TransactionContext.scala @@ -54,6 +54,7 @@ import se.scalablesolutions.akka.config.Config._ */ object TransactionContext extends TransactionProtocol with Logging { implicit val tc = TransactionContainer() + private[TransactionContext] val stack = new scala.util.DynamicVariable(new TransactionContext(tc)) /** @@ -91,6 +92,32 @@ object TransactionContext extends TransactionProtocol with Logging { */ def registerSynchronization(sync: Synchronization) = synchronization.add(sync) + /** + * Registeres a join transaction function. + *
+ * Here is an example on how to integrate with JPA EntityManager. + * + *
+ * TransactionContext.registerJoinTransactionFun(() => {
+ * val em: EntityManager = ... // get the EntityManager
+ * em.joinTransaction // join JTA transaction
+ * })
+ *
+ */
+ def registerJoinTransactionFun(fn: () => Unit) = joinTransactionFuns.add(fn)
+
+ /**
+ * Handle exception. Can be overriden by concrete transaction service implementation.
+ *
+ * Here is an example on how to handle JPA exceptions.
+ *
+ * + * TransactionContext.registerExceptionNotToRollbackOn(classOf[NoResultException]) + * TransactionContext.registerExceptionNotToRollbackOn(classOf[NonUniqueResultException]) + *+ */ + def registerExceptionNotToRollbackOn(e: Class[_ <: Exception]) = exceptionsNotToRollbackOn.add(e) + object Required extends TransactionMonad { def map[T](f: TransactionMonad => T): T = withTxRequired { f(this) } def flatMap[T](f: TransactionMonad => T): T = withTxRequired { f(this) } diff --git a/akka-jta/src/main/scala/TransactionProtocol.scala b/akka-jta/src/main/scala/TransactionProtocol.scala index 004cf99657..c23ec26fd7 100644 --- a/akka-jta/src/main/scala/TransactionProtocol.scala +++ b/akka-jta/src/main/scala/TransactionProtocol.scala @@ -68,43 +68,26 @@ import javax.transaction.{ trait TransactionProtocol extends Logging { protected val synchronization: JList[Synchronization] = new CopyOnWriteArrayList[Synchronization] + protected val joinTransactionFuns: JList[() => Unit] = new CopyOnWriteArrayList[() => Unit] + protected val exceptionsNotToRollbackOn: JList[Class[_ <: Exception]] = new CopyOnWriteArrayList[Class[_ <: Exception]] - /** - * Join JTA transaction. Can be overriden by concrete transaction service implementation - * to hook into other transaction services. - * - * Here is an example on how to integrate with JPA EntityManager. - * - *
- * override def joinTransaction = {
- * val em: EntityManager = ... // get the EntityManager
- * em.joinTransaction // join JTA transaction
- * }
- *
- */
- def joinTransaction: Unit = {}
+ def joinTransaction: Unit = {
+ val it = joinTransactionFuns.iterator
+ while (it.hasNext) {
+ val fn = it.next
+ fn()
+ }
+ }
- /**
- * Handle exception. Can be overriden by concrete transaction service implementation.
- *
- * Here is an example on how to handle JPA exceptions.
- *
- *
- * def handleException(tm: TransactionContainer, e: Exception) = {
- * if (isInExistingTransaction(tm)) {
- * // Do not roll back in case of NoResultException or NonUniqueResultException
- * if (!e.isInstanceOf[NoResultException] &&
- * !e.isInstanceOf[NonUniqueResultException]) {
- * log.debug("Setting TX to ROLLBACK_ONLY, due to: %s", e)
- * tm.setRollbackOnly
- * }
- * }
- * throw e
- * }
- *
- */
def handleException(tm: TransactionContainer, e: Exception) = {
- tm.setRollbackOnly
+ var rollback = true
+ val it = joinTransactionFuns.iterator
+ while (it.hasNext) {
+ val exception = it.next
+ if (e.getClass.isAssignableFrom(exception.getClass))
+ rollback = false
+ }
+ if (rollback) tm.setRollbackOnly
throw e
}