From 2757869c62d1bd879f23dae1facd942e106aee7b Mon Sep 17 00:00:00 2001 From: Peter Vlugter Date: Wed, 18 Jan 2012 15:59:59 +1300 Subject: [PATCH 1/3] Update to new Java API for Scala STM --- .../docs/transactor/CoordinatedCounter.java | 18 ++--- .../akka/docs/transactor/Coordinator.java | 5 +- .../code/akka/docs/transactor/Counter.java | 12 ++-- .../akka/docs/transactor/FriendlyCounter.java | 12 ++-- akka-docs/java/transactors.rst | 2 +- .../scala/akka/transactor/Atomically.scala | 67 ------------------- .../scala/akka/transactor/Coordinated.scala | 18 +++-- .../akka/transactor/UntypedTransactor.scala | 4 +- .../transactor/UntypedCoordinatedCounter.java | 25 +++---- .../java/akka/transactor/UntypedCounter.java | 23 +++---- .../java/akka/transactor/UntypedFailer.java | 2 +- project/AkkaBuild.scala | 2 +- 12 files changed, 61 insertions(+), 129 deletions(-) delete mode 100644 akka-transactor/src/main/scala/akka/transactor/Atomically.scala diff --git a/akka-docs/java/code/akka/docs/transactor/CoordinatedCounter.java b/akka-docs/java/code/akka/docs/transactor/CoordinatedCounter.java index dca10b8984..f17e86ade0 100644 --- a/akka-docs/java/code/akka/docs/transactor/CoordinatedCounter.java +++ b/akka-docs/java/code/akka/docs/transactor/CoordinatedCounter.java @@ -7,15 +7,11 @@ package akka.docs.transactor; //#class import akka.actor.*; import akka.transactor.*; -import scala.concurrent.stm.*; +import scala.concurrent.stm.Ref; +import static scala.concurrent.stm.JavaAPI.*; public class CoordinatedCounter extends UntypedActor { - private Ref count = Stm.ref(0); - - private void increment(InTxn txn) { - Integer newValue = count.get(txn) + 1; - count.set(newValue, txn); - } + private Ref.View count = newRef(0); public void onReceive(Object incoming) throws Exception { if (incoming instanceof Coordinated) { @@ -26,14 +22,14 @@ public class CoordinatedCounter extends UntypedActor { if (increment.hasFriend()) { increment.getFriend().tell(coordinated.coordinate(new Increment())); } - coordinated.atomic(new Atomically() { - public void atomically(InTxn txn) { - increment(txn); + coordinated.atomic(new Runnable() { + public void run() { + increment(count, 1); } }); } } else if ("GetCount".equals(incoming)) { - getSender().tell(count.single().get()); + getSender().tell(count.get()); } else { unhandled(incoming); } diff --git a/akka-docs/java/code/akka/docs/transactor/Coordinator.java b/akka-docs/java/code/akka/docs/transactor/Coordinator.java index 6854ed99f6..195906f5f6 100644 --- a/akka-docs/java/code/akka/docs/transactor/Coordinator.java +++ b/akka-docs/java/code/akka/docs/transactor/Coordinator.java @@ -6,7 +6,6 @@ package akka.docs.transactor; import akka.actor.*; import akka.transactor.*; -import scala.concurrent.stm.*; public class Coordinator extends UntypedActor { public void onReceive(Object incoming) throws Exception { @@ -15,8 +14,8 @@ public class Coordinator extends UntypedActor { Object message = coordinated.getMessage(); if (message instanceof Message) { //#coordinated-atomic - coordinated.atomic(new Atomically() { - public void atomically(InTxn txn) { + coordinated.atomic(new Runnable() { + public void run() { // do something in the coordinated transaction ... } }); diff --git a/akka-docs/java/code/akka/docs/transactor/Counter.java b/akka-docs/java/code/akka/docs/transactor/Counter.java index 0a6b7b2219..efe2aaed72 100644 --- a/akka-docs/java/code/akka/docs/transactor/Counter.java +++ b/akka-docs/java/code/akka/docs/transactor/Counter.java @@ -6,21 +6,21 @@ package akka.docs.transactor; //#class import akka.transactor.*; -import scala.concurrent.stm.*; +import scala.concurrent.stm.Ref; +import static scala.concurrent.stm.JavaAPI.*; public class Counter extends UntypedTransactor { - Ref count = Stm.ref(0); + Ref.View count = newRef(0); - public void atomically(InTxn txn, Object message) { + public void atomically(Object message) { if (message instanceof Increment) { - Integer newValue = count.get(txn) + 1; - count.set(newValue, txn); + increment(count, 1); } } @Override public boolean normally(Object message) { if ("GetCount".equals(message)) { - getSender().tell(count.single().get()); + getSender().tell(count.get()); return true; } else return false; } diff --git a/akka-docs/java/code/akka/docs/transactor/FriendlyCounter.java b/akka-docs/java/code/akka/docs/transactor/FriendlyCounter.java index d70c653063..7ef31c5bea 100644 --- a/akka-docs/java/code/akka/docs/transactor/FriendlyCounter.java +++ b/akka-docs/java/code/akka/docs/transactor/FriendlyCounter.java @@ -8,10 +8,11 @@ package akka.docs.transactor; import akka.actor.*; import akka.transactor.*; import java.util.Set; -import scala.concurrent.stm.*; +import scala.concurrent.stm.Ref; +import static scala.concurrent.stm.JavaAPI.*; public class FriendlyCounter extends UntypedTransactor { - Ref count = Stm.ref(0); + Ref.View count = newRef(0); @Override public Set coordinate(Object message) { if (message instanceof Increment) { @@ -22,16 +23,15 @@ public class FriendlyCounter extends UntypedTransactor { return nobody(); } - public void atomically(InTxn txn, Object message) { + public void atomically(Object message) { if (message instanceof Increment) { - Integer newValue = count.get(txn) + 1; - count.set(newValue, txn); + increment(count, 1); } } @Override public boolean normally(Object message) { if ("GetCount".equals(message)) { - getSender().tell(count.single().get()); + getSender().tell(count.get()); return true; } else return false; } diff --git a/akka-docs/java/transactors.rst b/akka-docs/java/transactors.rst index f7471412a9..9dd69664b6 100644 --- a/akka-docs/java/transactors.rst +++ b/akka-docs/java/transactors.rst @@ -102,7 +102,7 @@ be sent. :language: java To enter the coordinated transaction use the atomic method of the coordinated -object, passing in an ``akka.transactor.Atomically`` object. +object, passing in a ``java.lang.Runnable``. .. includecode:: code/akka/docs/transactor/Coordinator.java#coordinated-atomic :language: java diff --git a/akka-transactor/src/main/scala/akka/transactor/Atomically.scala b/akka-transactor/src/main/scala/akka/transactor/Atomically.scala deleted file mode 100644 index 4995a6b8bd..0000000000 --- a/akka-transactor/src/main/scala/akka/transactor/Atomically.scala +++ /dev/null @@ -1,67 +0,0 @@ -/** - * Copyright (C) 2009-2011 Typesafe Inc. - */ - -package akka.transactor - -import scala.concurrent.stm._ - -/** - * Java API. - * - * For creating Java-friendly coordinated atomic blocks. - * - * @see [[akka.transactor.Coordinated]] - */ -trait Atomically { - def atomically(txn: InTxn): Unit -} - -/** - * Java API. - * - * For creating completion handlers. - */ -trait CompletionHandler { - def handle(status: Txn.Status): Unit -} - -/** - * Java API. - * - * To ease some of the pain of using Scala STM from Java until - * the proper Java API is created. - */ -object Stm { - /** - * Create an STM Ref with an initial value. - */ - def ref[A](initialValue: A): Ref[A] = Ref(initialValue) - - /** - * Add a CompletionHandler to run after the current transaction - * has committed. - */ - def afterCommit(handler: CompletionHandler): Unit = { - val txn = Txn.findCurrent - if (txn.isDefined) Txn.afterCommit(status ⇒ handler.handle(status))(txn.get) - } - - /** - * Add a CompletionHandler to run after the current transaction - * has rolled back. - */ - def afterRollback(handler: CompletionHandler): Unit = { - val txn = Txn.findCurrent - if (txn.isDefined) Txn.afterRollback(status ⇒ handler.handle(status))(txn.get) - } - - /** - * Add a CompletionHandler to run after the current transaction - * has committed or rolled back. - */ - def afterCompletion(handler: CompletionHandler): Unit = { - val txn = Txn.findCurrent - if (txn.isDefined) Txn.afterCompletion(status ⇒ handler.handle(status))(txn.get) - } -} diff --git a/akka-transactor/src/main/scala/akka/transactor/Coordinated.scala b/akka-transactor/src/main/scala/akka/transactor/Coordinated.scala index f9ef8538be..a7c709b9fe 100644 --- a/akka-transactor/src/main/scala/akka/transactor/Coordinated.scala +++ b/akka-transactor/src/main/scala/akka/transactor/Coordinated.scala @@ -6,7 +6,8 @@ package akka.transactor import akka.AkkaException import akka.util.Timeout -import scala.concurrent.stm._ +import scala.concurrent.stm.{ CommitBarrier, InTxn } +import java.util.concurrent.Callable /** * Akka-specific exception for coordinated transactions. @@ -125,7 +126,7 @@ class Coordinated(val message: Any, member: CommitBarrier.Member) { * * @throws CoordinatedTransactionException if the coordinated transaction fails. */ - def atomic[T](body: InTxn ⇒ T): T = { + def atomic[A](body: InTxn ⇒ A): A = { member.atomic(body) match { case Right(result) ⇒ result case Left(CommitBarrier.MemberUncaughtExceptionCause(x)) ⇒ @@ -136,13 +137,22 @@ class Coordinated(val message: Any, member: CommitBarrier.Member) { } /** - * Java API: coordinated atomic method that accepts an [[akka.transactor.Atomically]]. + * Java API: coordinated atomic method that accepts a `java.lang.Runnable`. * Delimits the coordinated transaction. The transaction will wait for all other transactions * in this coordination before committing. The timeout is specified when creating the Coordinated. * * @throws CoordinatedTransactionException if the coordinated transaction fails. */ - def atomic(atomically: Atomically): Unit = atomic(txn ⇒ atomically.atomically(txn)) + def atomic(runnable: Runnable): Unit = atomic { _ ⇒ runnable.run } + + /** + * Java API: coordinated atomic method that accepts a `java.util.concurrent.Callable`. + * Delimits the coordinated transaction. The transaction will wait for all other transactions + * in this coordination before committing. The timeout is specified when creating the Coordinated. + * + * @throws CoordinatedTransactionException if the coordinated transaction fails. + */ + def atomic[A](callable: Callable[A]): A = atomic { _ ⇒ callable.call } /** * An empty coordinated atomic block. Can be used to complete the number of members involved diff --git a/akka-transactor/src/main/scala/akka/transactor/UntypedTransactor.scala b/akka-transactor/src/main/scala/akka/transactor/UntypedTransactor.scala index 9a37f81915..59dc8f049d 100644 --- a/akka-transactor/src/main/scala/akka/transactor/UntypedTransactor.scala +++ b/akka-transactor/src/main/scala/akka/transactor/UntypedTransactor.scala @@ -25,7 +25,7 @@ abstract class UntypedTransactor extends UntypedActor { sendTo.actor.tell(coordinated(sendTo.message.getOrElse(message))) } before(message) - coordinated.atomic { txn ⇒ atomically(txn, message) } + coordinated.atomic { txn ⇒ atomically(message) } after(message) } case message ⇒ { @@ -84,7 +84,7 @@ abstract class UntypedTransactor extends UntypedActor { * The Receive block to run inside the coordinated transaction. */ @throws(classOf[Exception]) - def atomically(txn: InTxn, message: Any) {} + def atomically(message: Any) /** * A Receive block that runs after the coordinated transaction. diff --git a/akka-transactor/src/test/java/akka/transactor/UntypedCoordinatedCounter.java b/akka-transactor/src/test/java/akka/transactor/UntypedCoordinatedCounter.java index 694a675d8e..7c92930e02 100644 --- a/akka-transactor/src/test/java/akka/transactor/UntypedCoordinatedCounter.java +++ b/akka-transactor/src/test/java/akka/transactor/UntypedCoordinatedCounter.java @@ -7,24 +7,20 @@ package akka.transactor; import akka.actor.ActorRef; import akka.actor.Actors; import akka.actor.UntypedActor; -import scala.concurrent.stm.*; +import static scala.concurrent.stm.JavaAPI.*; +import scala.concurrent.stm.Ref; import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; public class UntypedCoordinatedCounter extends UntypedActor { private String name; - private Ref count = Stm.ref(0); + private Ref.View count = newRef(0); public UntypedCoordinatedCounter(String name) { this.name = name; } - private void increment(InTxn txn) { - Integer newValue = count.get(txn) + 1; - count.set(newValue, txn); - } - public void onReceive(Object incoming) throws Exception { if (incoming instanceof Coordinated) { Coordinated coordinated = (Coordinated) incoming; @@ -33,8 +29,8 @@ public class UntypedCoordinatedCounter extends UntypedActor { Increment increment = (Increment) message; List friends = increment.getFriends(); final CountDownLatch latch = increment.getLatch(); - final CompletionHandler countDown = new CompletionHandler() { - public void handle(Txn.Status status) { + final Runnable countDown = new Runnable() { + public void run() { latch.countDown(); } }; @@ -42,15 +38,16 @@ public class UntypedCoordinatedCounter extends UntypedActor { Increment coordMessage = new Increment(friends.subList(1, friends.size()), latch); friends.get(0).tell(coordinated.coordinate(coordMessage)); } - coordinated.atomic(new Atomically() { - public void atomically(InTxn txn) { - increment(txn); - Stm.afterCompletion(countDown); + coordinated.atomic(new Runnable() { + public void run() { + increment(count, 1); + afterRollback(countDown); + afterCommit(countDown); } }); } } else if ("GetCount".equals(incoming)) { - getSender().tell(count.single().get()); + getSender().tell(count.get()); } } } diff --git a/akka-transactor/src/test/java/akka/transactor/UntypedCounter.java b/akka-transactor/src/test/java/akka/transactor/UntypedCounter.java index f03f74b10f..392bfbca42 100644 --- a/akka-transactor/src/test/java/akka/transactor/UntypedCounter.java +++ b/akka-transactor/src/test/java/akka/transactor/UntypedCounter.java @@ -7,7 +7,8 @@ package akka.transactor; import akka.actor.ActorRef; import akka.transactor.UntypedTransactor; import akka.transactor.SendTo; -import scala.concurrent.stm.*; +import static scala.concurrent.stm.JavaAPI.*; +import scala.concurrent.stm.Ref; import java.util.List; import java.util.Set; import java.util.concurrent.CountDownLatch; @@ -15,17 +16,12 @@ import java.util.concurrent.TimeUnit; public class UntypedCounter extends UntypedTransactor { private String name; - private Ref count = Stm.ref(0); + private Ref.View count = newRef(0); public UntypedCounter(String name) { this.name = name; } - private void increment(InTxn txn) { - Integer newValue = count.get(txn) + 1; - count.set(newValue, txn); - } - @Override public Set coordinate(Object message) { if (message instanceof Increment) { Increment increment = (Increment) message; @@ -41,22 +37,23 @@ public class UntypedCounter extends UntypedTransactor { } } - public void atomically(InTxn txn, Object message) { + public void atomically(Object message) { if (message instanceof Increment) { - increment(txn); + increment(count, 1); final Increment increment = (Increment) message; - CompletionHandler countDown = new CompletionHandler() { - public void handle(Txn.Status status) { + Runnable countDown = new Runnable() { + public void run() { increment.getLatch().countDown(); } }; - Stm.afterCompletion(countDown); + afterRollback(countDown); + afterCommit(countDown); } } @Override public boolean normally(Object message) { if ("GetCount".equals(message)) { - getSender().tell(count.single().get()); + getSender().tell(count.get()); return true; } else return false; } diff --git a/akka-transactor/src/test/java/akka/transactor/UntypedFailer.java b/akka-transactor/src/test/java/akka/transactor/UntypedFailer.java index 1f9e6ff41c..8ead9ae2ea 100644 --- a/akka-transactor/src/test/java/akka/transactor/UntypedFailer.java +++ b/akka-transactor/src/test/java/akka/transactor/UntypedFailer.java @@ -7,7 +7,7 @@ package akka.transactor; import scala.concurrent.stm.InTxn; public class UntypedFailer extends UntypedTransactor { - public void atomically(InTxn txn, Object message) throws Exception { + public void atomically(Object message) throws Exception { throw new ExpectedFailureException(); } } diff --git a/project/AkkaBuild.scala b/project/AkkaBuild.scala index 091346de34..626ede1834 100644 --- a/project/AkkaBuild.scala +++ b/project/AkkaBuild.scala @@ -450,7 +450,7 @@ object Dependency { val Netty = "3.2.5.Final" val Protobuf = "2.4.1" val Rabbit = "2.3.1" - val ScalaStm = "0.4" + val ScalaStm = "0.5.0-SNAPSHOT" val Scalatest = "1.6.1" val Slf4j = "1.6.4" val Spring = "3.0.5.RELEASE" From 20587654852144aadc824ea25cbd202081c6e9db Mon Sep 17 00:00:00 2001 From: Peter Vlugter Date: Thu, 19 Jan 2012 11:09:03 +1300 Subject: [PATCH 2/3] Add basic java api for scala stm to transactor module Note: this commit will be reverted once a java api is published for scala stm. Adding to transactor module for M3 release. --- .../scala/scala/concurrent/stm/JavaAPI.scala | 112 ++++++++++++ .../scala/concurrent/stm/JavaAPITests.java | 161 ++++++++++++++++++ .../scala/concurrent/stm/TestException.java | 9 + .../scala/concurrent/stm/JavaAPISuite.scala | 7 + project/AkkaBuild.scala | 2 +- 5 files changed, 290 insertions(+), 1 deletion(-) create mode 100644 akka-transactor/src/main/scala/scala/concurrent/stm/JavaAPI.scala create mode 100644 akka-transactor/src/test/java/scala/concurrent/stm/JavaAPITests.java create mode 100644 akka-transactor/src/test/java/scala/concurrent/stm/TestException.java create mode 100644 akka-transactor/src/test/scala/scala/concurrent/stm/JavaAPISuite.scala diff --git a/akka-transactor/src/main/scala/scala/concurrent/stm/JavaAPI.scala b/akka-transactor/src/main/scala/scala/concurrent/stm/JavaAPI.scala new file mode 100644 index 0000000000..964664fe55 --- /dev/null +++ b/akka-transactor/src/main/scala/scala/concurrent/stm/JavaAPI.scala @@ -0,0 +1,112 @@ +/* scala-stm - (c) 2009-2011, Stanford University, PPL */ + +package scala.concurrent.stm + +import java.util.concurrent.Callable +import scala.runtime.AbstractFunction1 + +/** + * Java-friendly API. + */ +object JavaAPI { + + /** + * Create a Ref with an initial value. Return a `Ref.View`, which does not + * require implicit transactions. + * @param initialValue the initial value for the newly created `Ref.View` + * @return a new `Ref.View` + */ + def newRef[A](initialValue: A): Ref.View[A] = Ref(initialValue).single + + /** + * Create an empty TMap. Return a `TMap.View`, which does not require + * implicit transactions. + * @return a new, empty `TMap.View` + */ + def newTMap[A, B](): TMap.View[A, B] = TMap.empty[A, B].single + + /** + * Create an empty TSet. Return a `TSet.View`, which does not require + * implicit transactions. + * @return a new, empty `TSet.View` + */ + def newTSet[A](): TSet.View[A] = TSet.empty[A].single + + /** + * Create a TArray containing `length` elements. Return a `TArray.View`, + * which does not require implicit transactions. + * @param length the length of the `TArray.View` to be created + * @return a new `TArray.View` containing `length` elements (initially null) + */ + def newTArray[A <: AnyRef](length: Int): TArray.View[A] = TArray.ofDim[A](length)(ClassManifest.classType(AnyRef.getClass)).single + + /** + * Atomic block that takes a `Runnable`. + * @param runnable the `Runnable` to run within a transaction + */ + def atomic(runnable: Runnable): Unit = scala.concurrent.stm.atomic { txn ⇒ runnable.run } + + /** + * Atomic block that takes a `Callable`. + * @param callable the `Callable` to run within a transaction + * @return the value returned by the `Callable` + */ + def atomic[A](callable: Callable[A]): A = scala.concurrent.stm.atomic { txn ⇒ callable.call } + + /** + * Transform the value stored by `ref` by applying the function `f`. + * @param ref the `Ref.View` to be transformed + * @param f the function to be applied + */ + def transform[A](ref: Ref.View[A], f: AbstractFunction1[A, A]): Unit = ref.transform(f) + + /** + * Transform the value stored by `ref` by applying the function `f` and + * return the old value. + * @param ref the `Ref.View` to be transformed + * @param f the function to be applied + * @return the old value of `ref` + */ + def getAndTransform[A](ref: Ref.View[A], f: AbstractFunction1[A, A]): A = ref.getAndTransform(f) + + /** + * Transform the value stored by `ref` by applying the function `f` and + * return the new value. + * @param ref the `Ref.View` to be transformed + * @param f the function to be applied + * @return the new value of `ref` + */ + def transformAndGet[A](ref: Ref.View[A], f: AbstractFunction1[A, A]): A = ref.transformAndGet(f) + + /** + * Increment the `java.lang.Integer` value of a `Ref.View`. + * @param ref the `Ref.View` to be incremented + * @param delta the amount to increment + */ + def increment(ref: Ref.View[java.lang.Integer], delta: Int): Unit = ref.transform { v ⇒ v.intValue + delta } + + /** + * Increment the `java.lang.Long` value of a `Ref.View`. + * @param ref the `Ref.View` to be incremented + * @param delta the amount to increment + */ + def increment(ref: Ref.View[java.lang.Long], delta: Long): Unit = ref.transform { v ⇒ v.longValue + delta } + + /** + * Add a task to run after the current transaction has committed. + * @param task the `Runnable` task to run after transaction commit + */ + def afterCommit(task: Runnable): Unit = { + val txn = Txn.findCurrent + if (txn.isDefined) Txn.afterCommit(status ⇒ task.run)(txn.get) + } + + /** + * Add a task to run after the current transaction has rolled back. + * @param task the `Runnable` task to run after transaction rollback + */ + def afterRollback(task: Runnable): Unit = { + val txn = Txn.findCurrent + if (txn.isDefined) Txn.afterRollback(status ⇒ task.run)(txn.get) + } +} diff --git a/akka-transactor/src/test/java/scala/concurrent/stm/JavaAPITests.java b/akka-transactor/src/test/java/scala/concurrent/stm/JavaAPITests.java new file mode 100644 index 0000000000..e2d0631590 --- /dev/null +++ b/akka-transactor/src/test/java/scala/concurrent/stm/JavaAPITests.java @@ -0,0 +1,161 @@ +/* scala-stm - (c) 2009-2011, Stanford University, PPL */ + +package scala.concurrent.stm; + +import static org.junit.Assert.*; +import org.junit.Test; + +import scala.concurrent.stm.Ref; +import static scala.concurrent.stm.JavaAPI.*; + +import scala.runtime.AbstractFunction1; +import java.util.concurrent.Callable; + +import static scala.collection.JavaConversions.*; +import java.util.Map; +import java.util.Set; +import java.util.List; + +public class JavaAPITests { + @Test + public void createIntegerRef() { + Ref.View ref = newRef(0); + int unboxed = ref.get(); + assertEquals(0, unboxed); + } + + @Test + public void atomicWithRunnable() { + final Ref.View ref = newRef(0); + atomic(new Runnable() { + public void run() { + ref.set(10); + } + }); + int value = ref.get(); + assertEquals(10, value); + } + + @Test + public void atomicWithCallable() { + final Ref.View ref = newRef(0); + int oldValue = atomic(new Callable() { + public Integer call() { + return ref.swap(10); + } + }); + assertEquals(0, oldValue); + int newValue = ref.get(); + assertEquals(10, newValue); + } + + @Test(expected = TestException.class) + public void failingTransaction() { + final Ref.View ref = newRef(0); + try { + atomic(new Runnable() { + public void run() { + ref.set(10); + throw new TestException(); + } + }); + } catch (TestException e) { + int value = ref.get(); + assertEquals(0, value); + throw e; + } + } + + @Test + public void transformInteger() { + Ref.View ref = newRef(0); + transform(ref, new AbstractFunction1() { + public Integer apply(Integer i) { + return i + 10; + } + }); + int value = ref.get(); + assertEquals(10, value); + } + + @Test + public void incrementInteger() { + Ref.View ref = newRef(0); + increment(ref, 10); + int value = ref.get(); + assertEquals(10, value); + } + + @Test + public void incrementLong() { + Ref.View ref = newRef(0L); + increment(ref, 10L); + long value = ref.get(); + assertEquals(10L, value); + } + + @Test + public void createAndUseTMap() { + TMap.View tmap = newTMap(); + Map map = mutableMapAsJavaMap(tmap); + map.put(1, "one"); + map.put(2, "two"); + assertEquals("one", map.get(1)); + assertEquals("two", map.get(2)); + assertTrue(map.containsKey(2)); + map.remove(2); + assertFalse(map.containsKey(2)); + } + + @Test(expected = TestException.class) + public void failingTMapTransaction() { + TMap.View tmap = newTMap(); + final Map map = mutableMapAsJavaMap(tmap); + try { + atomic(new Runnable() { + public void run() { + map.put(1, "one"); + map.put(2, "two"); + assertTrue(map.containsKey(1)); + assertTrue(map.containsKey(2)); + throw new TestException(); + } + }); + } catch (TestException e) { + assertFalse(map.containsKey(1)); + assertFalse(map.containsKey(2)); + throw e; + } + } + + @Test + public void createAndUseTSet() { + TSet.View tset = newTSet(); + Set set = mutableSetAsJavaSet(tset); + set.add("one"); + set.add("two"); + assertTrue(set.contains("one")); + assertTrue(set.contains("two")); + assertEquals(2, set.size()); + set.add("one"); + assertEquals(2, set.size()); + set.remove("two"); + assertFalse(set.contains("two")); + assertEquals(1, set.size()); + } + + @Test + public void createAndUseTArray() { + TArray.View tarray = newTArray(3); + List seq = mutableSeqAsJavaList(tarray); + assertEquals(null, seq.get(0)); + assertEquals(null, seq.get(1)); + assertEquals(null, seq.get(2)); + seq.set(0, "zero"); + seq.set(1, "one"); + seq.set(2, "two"); + assertEquals("zero", seq.get(0)); + assertEquals("one", seq.get(1)); + assertEquals("two", seq.get(2)); + } +} diff --git a/akka-transactor/src/test/java/scala/concurrent/stm/TestException.java b/akka-transactor/src/test/java/scala/concurrent/stm/TestException.java new file mode 100644 index 0000000000..cc810761d4 --- /dev/null +++ b/akka-transactor/src/test/java/scala/concurrent/stm/TestException.java @@ -0,0 +1,9 @@ +/* scala-stm - (c) 2009-2011, Stanford University, PPL */ + +package scala.concurrent.stm; + +public class TestException extends RuntimeException { + public TestException() { + super("Expected failure"); + } +} diff --git a/akka-transactor/src/test/scala/scala/concurrent/stm/JavaAPISuite.scala b/akka-transactor/src/test/scala/scala/concurrent/stm/JavaAPISuite.scala new file mode 100644 index 0000000000..3d0c48e90f --- /dev/null +++ b/akka-transactor/src/test/scala/scala/concurrent/stm/JavaAPISuite.scala @@ -0,0 +1,7 @@ +/* scala-stm - (c) 2009-2011, Stanford University, PPL */ + +package scala.concurrent.stm + +import org.scalatest.junit.JUnitWrapperSuite + +class JavaAPISuite extends JUnitWrapperSuite("scala.concurrent.stm.JavaAPITests", Thread.currentThread.getContextClassLoader) diff --git a/project/AkkaBuild.scala b/project/AkkaBuild.scala index 626ede1834..091346de34 100644 --- a/project/AkkaBuild.scala +++ b/project/AkkaBuild.scala @@ -450,7 +450,7 @@ object Dependency { val Netty = "3.2.5.Final" val Protobuf = "2.4.1" val Rabbit = "2.3.1" - val ScalaStm = "0.5.0-SNAPSHOT" + val ScalaStm = "0.4" val Scalatest = "1.6.1" val Slf4j = "1.6.4" val Spring = "3.0.5.RELEASE" From 27da7c4d128243ca41025fad24bce5ab19e8e87c Mon Sep 17 00:00:00 2001 From: Peter Vlugter Date: Fri, 20 Jan 2012 11:31:28 +1300 Subject: [PATCH 3/3] Update java-friendly api for scala stm - move to japi.Stm - add newMap, newSet, newList methods with java conversions - add afterCompletion lifecycle callback --- .../docs/transactor/CoordinatedCounter.java | 6 +-- .../code/akka/docs/transactor/Counter.java | 6 +-- .../akka/docs/transactor/FriendlyCounter.java | 6 +-- .../stm/{JavaAPI.scala => japi/Stm.scala} | 51 ++++++++++++++++--- .../transactor/UntypedCoordinatedCounter.java | 9 ++-- .../java/akka/transactor/UntypedCounter.java | 9 ++-- .../scala/concurrent/stm/JavaAPITests.java | 35 ++++++------- 7 files changed, 75 insertions(+), 47 deletions(-) rename akka-transactor/src/main/scala/scala/concurrent/stm/{JavaAPI.scala => japi/Stm.scala} (66%) diff --git a/akka-docs/java/code/akka/docs/transactor/CoordinatedCounter.java b/akka-docs/java/code/akka/docs/transactor/CoordinatedCounter.java index f17e86ade0..a00d26ed88 100644 --- a/akka-docs/java/code/akka/docs/transactor/CoordinatedCounter.java +++ b/akka-docs/java/code/akka/docs/transactor/CoordinatedCounter.java @@ -8,10 +8,10 @@ package akka.docs.transactor; import akka.actor.*; import akka.transactor.*; import scala.concurrent.stm.Ref; -import static scala.concurrent.stm.JavaAPI.*; +import scala.concurrent.stm.japi.Stm; public class CoordinatedCounter extends UntypedActor { - private Ref.View count = newRef(0); + private Ref.View count = Stm.newRef(0); public void onReceive(Object incoming) throws Exception { if (incoming instanceof Coordinated) { @@ -24,7 +24,7 @@ public class CoordinatedCounter extends UntypedActor { } coordinated.atomic(new Runnable() { public void run() { - increment(count, 1); + Stm.increment(count, 1); } }); } diff --git a/akka-docs/java/code/akka/docs/transactor/Counter.java b/akka-docs/java/code/akka/docs/transactor/Counter.java index efe2aaed72..acd0d8f516 100644 --- a/akka-docs/java/code/akka/docs/transactor/Counter.java +++ b/akka-docs/java/code/akka/docs/transactor/Counter.java @@ -7,14 +7,14 @@ package akka.docs.transactor; //#class import akka.transactor.*; import scala.concurrent.stm.Ref; -import static scala.concurrent.stm.JavaAPI.*; +import scala.concurrent.stm.japi.Stm; public class Counter extends UntypedTransactor { - Ref.View count = newRef(0); + Ref.View count = Stm.newRef(0); public void atomically(Object message) { if (message instanceof Increment) { - increment(count, 1); + Stm.increment(count, 1); } } diff --git a/akka-docs/java/code/akka/docs/transactor/FriendlyCounter.java b/akka-docs/java/code/akka/docs/transactor/FriendlyCounter.java index 7ef31c5bea..fe3d759539 100644 --- a/akka-docs/java/code/akka/docs/transactor/FriendlyCounter.java +++ b/akka-docs/java/code/akka/docs/transactor/FriendlyCounter.java @@ -9,10 +9,10 @@ import akka.actor.*; import akka.transactor.*; import java.util.Set; import scala.concurrent.stm.Ref; -import static scala.concurrent.stm.JavaAPI.*; +import scala.concurrent.stm.japi.Stm; public class FriendlyCounter extends UntypedTransactor { - Ref.View count = newRef(0); + Ref.View count = Stm.newRef(0); @Override public Set coordinate(Object message) { if (message instanceof Increment) { @@ -25,7 +25,7 @@ public class FriendlyCounter extends UntypedTransactor { public void atomically(Object message) { if (message instanceof Increment) { - increment(count, 1); + Stm.increment(count, 1); } } diff --git a/akka-transactor/src/main/scala/scala/concurrent/stm/JavaAPI.scala b/akka-transactor/src/main/scala/scala/concurrent/stm/japi/Stm.scala similarity index 66% rename from akka-transactor/src/main/scala/scala/concurrent/stm/JavaAPI.scala rename to akka-transactor/src/main/scala/scala/concurrent/stm/japi/Stm.scala index 964664fe55..d9ed5a8330 100644 --- a/akka-transactor/src/main/scala/scala/concurrent/stm/JavaAPI.scala +++ b/akka-transactor/src/main/scala/scala/concurrent/stm/japi/Stm.scala @@ -1,14 +1,19 @@ /* scala-stm - (c) 2009-2011, Stanford University, PPL */ -package scala.concurrent.stm +package scala.concurrent.stm.japi import java.util.concurrent.Callable +import java.util.{ List ⇒ JList, Map ⇒ JMap, Set ⇒ JSet } +import scala.collection.JavaConversions +import scala.concurrent.stm +import scala.concurrent.stm._ import scala.runtime.AbstractFunction1 /** - * Java-friendly API. + * Java-friendly API for ScalaSTM. + * These methods can also be statically imported. */ -object JavaAPI { +object Stm { /** * Create a Ref with an initial value. Return a `Ref.View`, which does not @@ -20,38 +25,58 @@ object JavaAPI { /** * Create an empty TMap. Return a `TMap.View`, which does not require - * implicit transactions. + * implicit transactions. See newMap for included java conversion. * @return a new, empty `TMap.View` */ def newTMap[A, B](): TMap.View[A, B] = TMap.empty[A, B].single + /** + * Create an empty TMap. Return a `java.util.Map` view of this TMap. + * @return a new, empty `TMap.View` wrapped as a `java.util.Map`. + */ + def newMap[A, B](): JMap[A, B] = JavaConversions.mutableMapAsJavaMap(newTMap[A, B]) + /** * Create an empty TSet. Return a `TSet.View`, which does not require - * implicit transactions. + * implicit transactions. See newSet for included java conversion. * @return a new, empty `TSet.View` */ def newTSet[A](): TSet.View[A] = TSet.empty[A].single + /** + * Create an empty TSet. Return a `java.util.Set` view of this TSet. + * @return a new, empty `TSet.View` wrapped as a `java.util.Set`. + */ + def newSet[A](): JSet[A] = JavaConversions.mutableSetAsJavaSet(newTSet[A]) + /** * Create a TArray containing `length` elements. Return a `TArray.View`, - * which does not require implicit transactions. + * which does not require implicit transactions. See newList for included + * java conversion. * @param length the length of the `TArray.View` to be created * @return a new `TArray.View` containing `length` elements (initially null) */ def newTArray[A <: AnyRef](length: Int): TArray.View[A] = TArray.ofDim[A](length)(ClassManifest.classType(AnyRef.getClass)).single + /** + * Create an empty TArray. Return a `java.util.List` view of this Array. + * @param length the length of the `TArray.View` to be created + * @return a new, empty `TArray.View` wrapped as a `java.util.List`. + */ + def newList[A <: AnyRef](length: Int): JList[A] = JavaConversions.mutableSeqAsJavaList(newTArray[A](length)) + /** * Atomic block that takes a `Runnable`. * @param runnable the `Runnable` to run within a transaction */ - def atomic(runnable: Runnable): Unit = scala.concurrent.stm.atomic { txn ⇒ runnable.run } + def atomic(runnable: Runnable): Unit = stm.atomic { txn ⇒ runnable.run } /** * Atomic block that takes a `Callable`. * @param callable the `Callable` to run within a transaction * @return the value returned by the `Callable` */ - def atomic[A](callable: Callable[A]): A = scala.concurrent.stm.atomic { txn ⇒ callable.call } + def atomic[A](callable: Callable[A]): A = stm.atomic { txn ⇒ callable.call } /** * Transform the value stored by `ref` by applying the function `f`. @@ -109,4 +134,14 @@ object JavaAPI { val txn = Txn.findCurrent if (txn.isDefined) Txn.afterRollback(status ⇒ task.run)(txn.get) } + + /** + * Add a task to run after the current transaction has either rolled back + * or committed. + * @param task the `Runnable` task to run after transaction completion + */ + def afterCompletion(task: Runnable): Unit = { + val txn = Txn.findCurrent + if (txn.isDefined) Txn.afterCompletion(status ⇒ task.run)(txn.get) + } } diff --git a/akka-transactor/src/test/java/akka/transactor/UntypedCoordinatedCounter.java b/akka-transactor/src/test/java/akka/transactor/UntypedCoordinatedCounter.java index 7c92930e02..435fb0df54 100644 --- a/akka-transactor/src/test/java/akka/transactor/UntypedCoordinatedCounter.java +++ b/akka-transactor/src/test/java/akka/transactor/UntypedCoordinatedCounter.java @@ -7,15 +7,15 @@ package akka.transactor; import akka.actor.ActorRef; import akka.actor.Actors; import akka.actor.UntypedActor; -import static scala.concurrent.stm.JavaAPI.*; import scala.concurrent.stm.Ref; +import scala.concurrent.stm.japi.Stm; import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; public class UntypedCoordinatedCounter extends UntypedActor { private String name; - private Ref.View count = newRef(0); + private Ref.View count = Stm.newRef(0); public UntypedCoordinatedCounter(String name) { this.name = name; @@ -40,9 +40,8 @@ public class UntypedCoordinatedCounter extends UntypedActor { } coordinated.atomic(new Runnable() { public void run() { - increment(count, 1); - afterRollback(countDown); - afterCommit(countDown); + Stm.increment(count, 1); + Stm.afterCompletion(countDown); } }); } diff --git a/akka-transactor/src/test/java/akka/transactor/UntypedCounter.java b/akka-transactor/src/test/java/akka/transactor/UntypedCounter.java index 392bfbca42..e4e680f74b 100644 --- a/akka-transactor/src/test/java/akka/transactor/UntypedCounter.java +++ b/akka-transactor/src/test/java/akka/transactor/UntypedCounter.java @@ -7,8 +7,8 @@ package akka.transactor; import akka.actor.ActorRef; import akka.transactor.UntypedTransactor; import akka.transactor.SendTo; -import static scala.concurrent.stm.JavaAPI.*; import scala.concurrent.stm.Ref; +import scala.concurrent.stm.japi.Stm; import java.util.List; import java.util.Set; import java.util.concurrent.CountDownLatch; @@ -16,7 +16,7 @@ import java.util.concurrent.TimeUnit; public class UntypedCounter extends UntypedTransactor { private String name; - private Ref.View count = newRef(0); + private Ref.View count = Stm.newRef(0); public UntypedCounter(String name) { this.name = name; @@ -39,15 +39,14 @@ public class UntypedCounter extends UntypedTransactor { public void atomically(Object message) { if (message instanceof Increment) { - increment(count, 1); + Stm.increment(count, 1); final Increment increment = (Increment) message; Runnable countDown = new Runnable() { public void run() { increment.getLatch().countDown(); } }; - afterRollback(countDown); - afterCommit(countDown); + Stm.afterCompletion(countDown); } } diff --git a/akka-transactor/src/test/java/scala/concurrent/stm/JavaAPITests.java b/akka-transactor/src/test/java/scala/concurrent/stm/JavaAPITests.java index e2d0631590..63fb6abb74 100644 --- a/akka-transactor/src/test/java/scala/concurrent/stm/JavaAPITests.java +++ b/akka-transactor/src/test/java/scala/concurrent/stm/JavaAPITests.java @@ -5,13 +5,12 @@ package scala.concurrent.stm; import static org.junit.Assert.*; import org.junit.Test; -import scala.concurrent.stm.Ref; -import static scala.concurrent.stm.JavaAPI.*; +import scala.concurrent.stm.japi.Stm; +import static scala.concurrent.stm.japi.Stm.*; import scala.runtime.AbstractFunction1; import java.util.concurrent.Callable; -import static scala.collection.JavaConversions.*; import java.util.Map; import java.util.Set; import java.util.List; @@ -96,8 +95,7 @@ public class JavaAPITests { @Test public void createAndUseTMap() { - TMap.View tmap = newTMap(); - Map map = mutableMapAsJavaMap(tmap); + Map map = newMap(); map.put(1, "one"); map.put(2, "two"); assertEquals("one", map.get(1)); @@ -109,8 +107,7 @@ public class JavaAPITests { @Test(expected = TestException.class) public void failingTMapTransaction() { - TMap.View tmap = newTMap(); - final Map map = mutableMapAsJavaMap(tmap); + final Map map = newMap(); try { atomic(new Runnable() { public void run() { @@ -130,8 +127,7 @@ public class JavaAPITests { @Test public void createAndUseTSet() { - TSet.View tset = newTSet(); - Set set = mutableSetAsJavaSet(tset); + Set set = newSet(); set.add("one"); set.add("two"); assertTrue(set.contains("one")); @@ -146,16 +142,15 @@ public class JavaAPITests { @Test public void createAndUseTArray() { - TArray.View tarray = newTArray(3); - List seq = mutableSeqAsJavaList(tarray); - assertEquals(null, seq.get(0)); - assertEquals(null, seq.get(1)); - assertEquals(null, seq.get(2)); - seq.set(0, "zero"); - seq.set(1, "one"); - seq.set(2, "two"); - assertEquals("zero", seq.get(0)); - assertEquals("one", seq.get(1)); - assertEquals("two", seq.get(2)); + List list = newList(3); + assertEquals(null, list.get(0)); + assertEquals(null, list.get(1)); + assertEquals(null, list.get(2)); + list.set(0, "zero"); + list.set(1, "one"); + list.set(2, "two"); + assertEquals("zero", list.get(0)); + assertEquals("one", list.get(1)); + assertEquals("two", list.get(2)); } }