Update java-friendly api for scala stm

- move to japi.Stm
- add newMap, newSet, newList methods with java conversions
- add afterCompletion lifecycle callback
This commit is contained in:
Peter Vlugter 2012-01-20 11:31:28 +13:00
parent 2058765485
commit 27da7c4d12
7 changed files with 75 additions and 47 deletions

View file

@ -8,10 +8,10 @@ package akka.docs.transactor;
import akka.actor.*; import akka.actor.*;
import akka.transactor.*; import akka.transactor.*;
import scala.concurrent.stm.Ref; import scala.concurrent.stm.Ref;
import static scala.concurrent.stm.JavaAPI.*; import scala.concurrent.stm.japi.Stm;
public class CoordinatedCounter extends UntypedActor { public class CoordinatedCounter extends UntypedActor {
private Ref.View<Integer> count = newRef(0); private Ref.View<Integer> count = Stm.newRef(0);
public void onReceive(Object incoming) throws Exception { public void onReceive(Object incoming) throws Exception {
if (incoming instanceof Coordinated) { if (incoming instanceof Coordinated) {
@ -24,7 +24,7 @@ public class CoordinatedCounter extends UntypedActor {
} }
coordinated.atomic(new Runnable() { coordinated.atomic(new Runnable() {
public void run() { public void run() {
increment(count, 1); Stm.increment(count, 1);
} }
}); });
} }

View file

@ -7,14 +7,14 @@ package akka.docs.transactor;
//#class //#class
import akka.transactor.*; import akka.transactor.*;
import scala.concurrent.stm.Ref; import scala.concurrent.stm.Ref;
import static scala.concurrent.stm.JavaAPI.*; import scala.concurrent.stm.japi.Stm;
public class Counter extends UntypedTransactor { public class Counter extends UntypedTransactor {
Ref.View<Integer> count = newRef(0); Ref.View<Integer> count = Stm.newRef(0);
public void atomically(Object message) { public void atomically(Object message) {
if (message instanceof Increment) { if (message instanceof Increment) {
increment(count, 1); Stm.increment(count, 1);
} }
} }

View file

@ -9,10 +9,10 @@ import akka.actor.*;
import akka.transactor.*; import akka.transactor.*;
import java.util.Set; import java.util.Set;
import scala.concurrent.stm.Ref; import scala.concurrent.stm.Ref;
import static scala.concurrent.stm.JavaAPI.*; import scala.concurrent.stm.japi.Stm;
public class FriendlyCounter extends UntypedTransactor { public class FriendlyCounter extends UntypedTransactor {
Ref.View<Integer> count = newRef(0); Ref.View<Integer> count = Stm.newRef(0);
@Override public Set<SendTo> coordinate(Object message) { @Override public Set<SendTo> coordinate(Object message) {
if (message instanceof Increment) { if (message instanceof Increment) {
@ -25,7 +25,7 @@ public class FriendlyCounter extends UntypedTransactor {
public void atomically(Object message) { public void atomically(Object message) {
if (message instanceof Increment) { if (message instanceof Increment) {
increment(count, 1); Stm.increment(count, 1);
} }
} }

View file

@ -1,14 +1,19 @@
/* scala-stm - (c) 2009-2011, Stanford University, PPL */ /* scala-stm - (c) 2009-2011, Stanford University, PPL */
package scala.concurrent.stm package scala.concurrent.stm.japi
import java.util.concurrent.Callable import java.util.concurrent.Callable
import java.util.{ List JList, Map JMap, Set JSet }
import scala.collection.JavaConversions
import scala.concurrent.stm
import scala.concurrent.stm._
import scala.runtime.AbstractFunction1 import scala.runtime.AbstractFunction1
/** /**
* Java-friendly API. * Java-friendly API for ScalaSTM.
* These methods can also be statically imported.
*/ */
object JavaAPI { object Stm {
/** /**
* Create a Ref with an initial value. Return a `Ref.View`, which does not * Create a Ref with an initial value. Return a `Ref.View`, which does not
@ -20,38 +25,58 @@ object JavaAPI {
/** /**
* Create an empty TMap. Return a `TMap.View`, which does not require * Create an empty TMap. Return a `TMap.View`, which does not require
* implicit transactions. * implicit transactions. See newMap for included java conversion.
* @return a new, empty `TMap.View` * @return a new, empty `TMap.View`
*/ */
def newTMap[A, B](): TMap.View[A, B] = TMap.empty[A, B].single def newTMap[A, B](): TMap.View[A, B] = TMap.empty[A, B].single
/**
* Create an empty TMap. Return a `java.util.Map` view of this TMap.
* @return a new, empty `TMap.View` wrapped as a `java.util.Map`.
*/
def newMap[A, B](): JMap[A, B] = JavaConversions.mutableMapAsJavaMap(newTMap[A, B])
/** /**
* Create an empty TSet. Return a `TSet.View`, which does not require * Create an empty TSet. Return a `TSet.View`, which does not require
* implicit transactions. * implicit transactions. See newSet for included java conversion.
* @return a new, empty `TSet.View` * @return a new, empty `TSet.View`
*/ */
def newTSet[A](): TSet.View[A] = TSet.empty[A].single def newTSet[A](): TSet.View[A] = TSet.empty[A].single
/**
* Create an empty TSet. Return a `java.util.Set` view of this TSet.
* @return a new, empty `TSet.View` wrapped as a `java.util.Set`.
*/
def newSet[A](): JSet[A] = JavaConversions.mutableSetAsJavaSet(newTSet[A])
/** /**
* Create a TArray containing `length` elements. Return a `TArray.View`, * Create a TArray containing `length` elements. Return a `TArray.View`,
* which does not require implicit transactions. * which does not require implicit transactions. See newList for included
* java conversion.
* @param length the length of the `TArray.View` to be created * @param length the length of the `TArray.View` to be created
* @return a new `TArray.View` containing `length` elements (initially null) * @return a new `TArray.View` containing `length` elements (initially null)
*/ */
def newTArray[A <: AnyRef](length: Int): TArray.View[A] = TArray.ofDim[A](length)(ClassManifest.classType(AnyRef.getClass)).single def newTArray[A <: AnyRef](length: Int): TArray.View[A] = TArray.ofDim[A](length)(ClassManifest.classType(AnyRef.getClass)).single
/**
* Create an empty TArray. Return a `java.util.List` view of this Array.
* @param length the length of the `TArray.View` to be created
* @return a new, empty `TArray.View` wrapped as a `java.util.List`.
*/
def newList[A <: AnyRef](length: Int): JList[A] = JavaConversions.mutableSeqAsJavaList(newTArray[A](length))
/** /**
* Atomic block that takes a `Runnable`. * Atomic block that takes a `Runnable`.
* @param runnable the `Runnable` to run within a transaction * @param runnable the `Runnable` to run within a transaction
*/ */
def atomic(runnable: Runnable): Unit = scala.concurrent.stm.atomic { txn runnable.run } def atomic(runnable: Runnable): Unit = stm.atomic { txn runnable.run }
/** /**
* Atomic block that takes a `Callable`. * Atomic block that takes a `Callable`.
* @param callable the `Callable` to run within a transaction * @param callable the `Callable` to run within a transaction
* @return the value returned by the `Callable` * @return the value returned by the `Callable`
*/ */
def atomic[A](callable: Callable[A]): A = scala.concurrent.stm.atomic { txn callable.call } def atomic[A](callable: Callable[A]): A = stm.atomic { txn callable.call }
/** /**
* Transform the value stored by `ref` by applying the function `f`. * Transform the value stored by `ref` by applying the function `f`.
@ -109,4 +134,14 @@ object JavaAPI {
val txn = Txn.findCurrent val txn = Txn.findCurrent
if (txn.isDefined) Txn.afterRollback(status task.run)(txn.get) if (txn.isDefined) Txn.afterRollback(status task.run)(txn.get)
} }
/**
* Add a task to run after the current transaction has either rolled back
* or committed.
* @param task the `Runnable` task to run after transaction completion
*/
def afterCompletion(task: Runnable): Unit = {
val txn = Txn.findCurrent
if (txn.isDefined) Txn.afterCompletion(status task.run)(txn.get)
}
} }

View file

@ -7,15 +7,15 @@ package akka.transactor;
import akka.actor.ActorRef; import akka.actor.ActorRef;
import akka.actor.Actors; import akka.actor.Actors;
import akka.actor.UntypedActor; import akka.actor.UntypedActor;
import static scala.concurrent.stm.JavaAPI.*;
import scala.concurrent.stm.Ref; import scala.concurrent.stm.Ref;
import scala.concurrent.stm.japi.Stm;
import java.util.List; import java.util.List;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
public class UntypedCoordinatedCounter extends UntypedActor { public class UntypedCoordinatedCounter extends UntypedActor {
private String name; private String name;
private Ref.View<Integer> count = newRef(0); private Ref.View<Integer> count = Stm.newRef(0);
public UntypedCoordinatedCounter(String name) { public UntypedCoordinatedCounter(String name) {
this.name = name; this.name = name;
@ -40,9 +40,8 @@ public class UntypedCoordinatedCounter extends UntypedActor {
} }
coordinated.atomic(new Runnable() { coordinated.atomic(new Runnable() {
public void run() { public void run() {
increment(count, 1); Stm.increment(count, 1);
afterRollback(countDown); Stm.afterCompletion(countDown);
afterCommit(countDown);
} }
}); });
} }

View file

@ -7,8 +7,8 @@ package akka.transactor;
import akka.actor.ActorRef; import akka.actor.ActorRef;
import akka.transactor.UntypedTransactor; import akka.transactor.UntypedTransactor;
import akka.transactor.SendTo; import akka.transactor.SendTo;
import static scala.concurrent.stm.JavaAPI.*;
import scala.concurrent.stm.Ref; import scala.concurrent.stm.Ref;
import scala.concurrent.stm.japi.Stm;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
@ -16,7 +16,7 @@ import java.util.concurrent.TimeUnit;
public class UntypedCounter extends UntypedTransactor { public class UntypedCounter extends UntypedTransactor {
private String name; private String name;
private Ref.View<Integer> count = newRef(0); private Ref.View<Integer> count = Stm.newRef(0);
public UntypedCounter(String name) { public UntypedCounter(String name) {
this.name = name; this.name = name;
@ -39,15 +39,14 @@ public class UntypedCounter extends UntypedTransactor {
public void atomically(Object message) { public void atomically(Object message) {
if (message instanceof Increment) { if (message instanceof Increment) {
increment(count, 1); Stm.increment(count, 1);
final Increment increment = (Increment) message; final Increment increment = (Increment) message;
Runnable countDown = new Runnable() { Runnable countDown = new Runnable() {
public void run() { public void run() {
increment.getLatch().countDown(); increment.getLatch().countDown();
} }
}; };
afterRollback(countDown); Stm.afterCompletion(countDown);
afterCommit(countDown);
} }
} }

View file

@ -5,13 +5,12 @@ package scala.concurrent.stm;
import static org.junit.Assert.*; import static org.junit.Assert.*;
import org.junit.Test; import org.junit.Test;
import scala.concurrent.stm.Ref; import scala.concurrent.stm.japi.Stm;
import static scala.concurrent.stm.JavaAPI.*; import static scala.concurrent.stm.japi.Stm.*;
import scala.runtime.AbstractFunction1; import scala.runtime.AbstractFunction1;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
import static scala.collection.JavaConversions.*;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.List; import java.util.List;
@ -96,8 +95,7 @@ public class JavaAPITests {
@Test @Test
public void createAndUseTMap() { public void createAndUseTMap() {
TMap.View<Integer, String> tmap = newTMap(); Map<Integer, String> map = newMap();
Map<Integer, String> map = mutableMapAsJavaMap(tmap);
map.put(1, "one"); map.put(1, "one");
map.put(2, "two"); map.put(2, "two");
assertEquals("one", map.get(1)); assertEquals("one", map.get(1));
@ -109,8 +107,7 @@ public class JavaAPITests {
@Test(expected = TestException.class) @Test(expected = TestException.class)
public void failingTMapTransaction() { public void failingTMapTransaction() {
TMap.View<Integer, String> tmap = newTMap(); final Map<Integer, String> map = newMap();
final Map<Integer, String> map = mutableMapAsJavaMap(tmap);
try { try {
atomic(new Runnable() { atomic(new Runnable() {
public void run() { public void run() {
@ -130,8 +127,7 @@ public class JavaAPITests {
@Test @Test
public void createAndUseTSet() { public void createAndUseTSet() {
TSet.View<String> tset = newTSet(); Set<String> set = newSet();
Set<String> set = mutableSetAsJavaSet(tset);
set.add("one"); set.add("one");
set.add("two"); set.add("two");
assertTrue(set.contains("one")); assertTrue(set.contains("one"));
@ -146,16 +142,15 @@ public class JavaAPITests {
@Test @Test
public void createAndUseTArray() { public void createAndUseTArray() {
TArray.View<String> tarray = newTArray(3); List<String> list = newList(3);
List<String> seq = mutableSeqAsJavaList(tarray); assertEquals(null, list.get(0));
assertEquals(null, seq.get(0)); assertEquals(null, list.get(1));
assertEquals(null, seq.get(1)); assertEquals(null, list.get(2));
assertEquals(null, seq.get(2)); list.set(0, "zero");
seq.set(0, "zero"); list.set(1, "one");
seq.set(1, "one"); list.set(2, "two");
seq.set(2, "two"); assertEquals("zero", list.get(0));
assertEquals("zero", seq.get(0)); assertEquals("one", list.get(1));
assertEquals("one", seq.get(1)); assertEquals("two", list.get(2));
assertEquals("two", seq.get(2));
} }
} }