diff --git a/akka-stm/src/main/scala/stm/Atomic.scala b/akka-stm/src/main/scala/stm/Atomic.scala index 00092019a8..dc5f192573 100644 --- a/akka-stm/src/main/scala/stm/Atomic.scala +++ b/akka-stm/src/main/scala/stm/Atomic.scala @@ -33,7 +33,7 @@ package akka.stm * }.execute(); * }}} */ -abstract class Atomic[T](factory: TransactionFactory) { +abstract class Atomic[T](val factory: TransactionFactory) { def this() = this(DefaultTransactionFactory) def atomically: T def execute: T = atomic(factory)(atomically) diff --git a/akka-stm/src/main/scala/transactor/Coordinated.scala b/akka-stm/src/main/scala/transactor/Coordinated.scala index fa3274123e..97fcaffb2c 100644 --- a/akka-stm/src/main/scala/transactor/Coordinated.scala +++ b/akka-stm/src/main/scala/transactor/Coordinated.scala @@ -5,7 +5,7 @@ package akka.transactor import akka.config.Config -import akka.stm.{DefaultTransactionConfig, TransactionFactory} +import akka.stm.{Atomic, DefaultTransactionConfig, TransactionFactory} import org.multiverse.api.{Transaction => MultiverseTransaction} import org.multiverse.commitbarriers.CountDownCommitBarrier @@ -84,14 +84,42 @@ object Coordinated { * @see [[akka.actor.Transactor]] for an actor that implements coordinated transactions */ class Coordinated(val message: Any, barrier: CountDownCommitBarrier) { + + // Java API constructors + def this(message: Any) = this(message, Coordinated.createBarrier) + def this() = this(null, Coordinated.createBarrier) + + /** + * Create a new Coordinated object and increment the number of parties by one. + * Use this method to ''pass on'' the coordination. + */ def apply(msg: Any) = { barrier.incParties(1) new Coordinated(msg, barrier) } + /** + * Java API: get the message for this Coordinated. + */ + def getMessage() = message + + /** + * Java API: create a new Coordinated object and increment the number of parties by one. + * Use this method to ''pass on'' the coordination. + */ + def coordinate(msg: Any) = apply(msg) + + /** + * Delimits the coordinated transaction. The transaction will wait for all other transactions + * in this coordination before committing. The timeout is specified by the transaction factory. + */ def atomic[T](body: => T)(implicit factory: TransactionFactory = Coordinated.DefaultFactory): T = atomic(factory)(body) + /** + * Delimits the coordinated transaction. The transaction will wait for all other transactions + * in this coordination before committing. The timeout is specified by the transaction factory. + */ def atomic[T](factory: TransactionFactory)(body: => T): T = { factory.boilerplate.execute(new TransactionalCallable[T]() { def call(mtx: MultiverseTransaction): T = { @@ -108,4 +136,11 @@ class Coordinated(val message: Any, barrier: CountDownCommitBarrier) { } }) } + + /** + * Java API: coordinated atomic method that accepts an [[akka.stm.Atomic]]. + * Delimits the coordinated transaction. The transaction will wait for all other transactions + * in this coordination before committing. The timeout is specified by the transaction factory. + */ + def atomic[T](jatomic: Atomic[T]): T = atomic(jatomic.factory)(jatomic.atomically) } diff --git a/akka-stm/src/test/java/akka/transactor/test/CoordinatedCounter.java b/akka-stm/src/test/java/akka/transactor/test/CoordinatedCounter.java new file mode 100644 index 0000000000..d84673b206 --- /dev/null +++ b/akka-stm/src/test/java/akka/transactor/test/CoordinatedCounter.java @@ -0,0 +1,63 @@ +package akka.transactor.test; + +import akka.transactor.Coordinated; +import akka.actor.ActorRef; +import akka.actor.UntypedActor; +import akka.stm.*; +import akka.util.Duration; + +import org.multiverse.api.StmUtils; + +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +public class CoordinatedCounter extends UntypedActor { + String name; + Ref count = new Ref(0); + TransactionFactory txFactory = new TransactionFactoryBuilder() + .setTimeout(new Duration(3, TimeUnit.SECONDS)) + .build(); + + public CoordinatedCounter(String name) { + this.name = name; + } + + private void increment() { + System.out.println(name + ": incrementing"); + count.set(count.get() + 1); + } + + public void onReceive(Object incoming) throws Exception { + if (incoming instanceof Coordinated) { + Coordinated coordinated = (Coordinated) incoming; + Object message = coordinated.getMessage(); + if (message instanceof Increment) { + Increment increment = (Increment) message; + List friends = increment.friends; + final CountDownLatch latch = increment.latch; + if (!friends.isEmpty()) { + Increment coordMessage = new Increment(friends.subList(1, friends.size()), latch); + friends.get(0).sendOneWay(coordinated.coordinate(coordMessage)); + } + coordinated.atomic(new Atomic(txFactory) { + public Object atomically() { + increment(); + StmUtils.scheduleDeferredTask(new Runnable() { + public void run() { latch.countDown(); } + }); + StmUtils.scheduleCompensatingTask(new Runnable() { + public void run() { latch.countDown(); } + }); + return null; + } + }); + } + } else if (incoming instanceof String) { + String message = (String) incoming; + if (message.equals("GetCount")) { + getContext().replyUnsafe(count.get()); + } + } + } +} diff --git a/akka-stm/src/test/java/akka/transactor/test/CoordinatedFailer.java b/akka-stm/src/test/java/akka/transactor/test/CoordinatedFailer.java new file mode 100644 index 0000000000..ab1f21d01f --- /dev/null +++ b/akka-stm/src/test/java/akka/transactor/test/CoordinatedFailer.java @@ -0,0 +1,9 @@ +package akka.transactor.test; + +import akka.actor.UntypedActor; + +public class CoordinatedFailer extends UntypedActor { + public void onReceive(Object incoming) throws Exception { + throw new RuntimeException("Expected failure"); + } +} diff --git a/akka-stm/src/test/java/akka/transactor/test/CoordinatedIncrementTest.java b/akka-stm/src/test/java/akka/transactor/test/CoordinatedIncrementTest.java new file mode 100644 index 0000000000..d520288141 --- /dev/null +++ b/akka-stm/src/test/java/akka/transactor/test/CoordinatedIncrementTest.java @@ -0,0 +1,88 @@ +package akka.transactor.test; + +import static org.junit.Assert.*; +import org.junit.Test; +import org.junit.Before; + +import akka.transactor.Coordinated; +import akka.actor.ActorRef; +import akka.actor.UntypedActor; +import akka.actor.UntypedActorFactory; +import akka.dispatch.Future; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import scala.Option; + +public class CoordinatedIncrementTest { + List counters; + ActorRef failer; + + int numCounters = 5; + int timeout = 5; + + @Before public void initialise() { + counters = new ArrayList(); + for (int i = 1; i <= numCounters; i++) { + final String name = "counter" + i; + ActorRef counter = UntypedActor.actorOf(new UntypedActorFactory() { + public UntypedActor create() { + return new CoordinatedCounter(name); + } + }); + counter.start(); + counters.add(counter); + } + failer = UntypedActor.actorOf(CoordinatedFailer.class); + failer.start(); + } + + @Test public void incrementAllCountersWithSuccessfulTransaction() { + CountDownLatch incrementLatch = new CountDownLatch(numCounters); + Increment message = new Increment(counters.subList(1, counters.size()), incrementLatch); + counters.get(0).sendOneWay(new Coordinated(message)); + try { + incrementLatch.await(timeout, TimeUnit.SECONDS); + } catch (InterruptedException exception) {} + for (ActorRef counter : counters) { + Future future = counter.sendRequestReplyFuture("GetCount"); + future.await(); + if (future.isCompleted()) { + Option resultOption = future.result(); + if (resultOption.isDefined()) { + Object result = resultOption.get(); + int count = (Integer) result; + assertEquals(1, count); + } + } + } + } + + @Test public void incrementNoCountersWithFailingTransaction() { + CountDownLatch incrementLatch = new CountDownLatch(numCounters); + List actors = new ArrayList(counters); + actors.add(failer); + Increment message = new Increment(actors.subList(1, actors.size()), incrementLatch); + actors.get(0).sendOneWay(new Coordinated(message)); + try { + incrementLatch.await(timeout, TimeUnit.SECONDS); + } catch (InterruptedException exception) {} + for (ActorRef counter : counters) { + Future future = counter.sendRequestReplyFuture("GetCount"); + future.await(); + if (future.isCompleted()) { + Option resultOption = future.result(); + if (resultOption.isDefined()) { + Object result = resultOption.get(); + int count = (Integer) result; + assertEquals(0, count); + } + } + } + } +} + + diff --git a/akka-stm/src/test/java/akka/transactor/test/Increment.java b/akka-stm/src/test/java/akka/transactor/test/Increment.java new file mode 100644 index 0000000000..2ec3c356d6 --- /dev/null +++ b/akka-stm/src/test/java/akka/transactor/test/Increment.java @@ -0,0 +1,15 @@ +package akka.transactor.test; + +import akka.actor.ActorRef; +import java.util.List; +import java.util.concurrent.CountDownLatch; + +public class Increment { + List friends; + CountDownLatch latch; + + public Increment(List friends, CountDownLatch latch) { + this.friends = friends; + this.latch = latch; + } +} \ No newline at end of file diff --git a/akka-stm/src/test/scala/transactor/JavaCoordinatedSpec.scala b/akka-stm/src/test/scala/transactor/JavaCoordinatedSpec.scala new file mode 100644 index 0000000000..25f3853bd8 --- /dev/null +++ b/akka-stm/src/test/scala/transactor/JavaCoordinatedSpec.scala @@ -0,0 +1,5 @@ +package akka.transactor.test + +import org.scalatest.junit.JUnitWrapperSuite + +class JavaCoordinatedSpec extends JUnitWrapperSuite("akka.transactor.test.CoordinatedIncrementTest", Thread.currentThread.getContextClassLoader)