diff --git a/akka-actor/src/main/scala/actor/ActorRef.scala b/akka-actor/src/main/scala/actor/ActorRef.scala
index 5dce8b11cd..3472ec4696 100644
--- a/akka-actor/src/main/scala/actor/ActorRef.scala
+++ b/akka-actor/src/main/scala/actor/ActorRef.scala
@@ -196,19 +196,10 @@ trait ActorRef extends
*/
@volatile private[akka] var _transactionFactory: Option[TransactionFactory] = None
- /**
- * This lock ensures thread safety in the dispatching: only one message can
- * be dispatched at once on the actor.
- */
- protected[akka] val dispatcherLock = new ReentrantLock
-
/**
* This is a reference to the message currently being processed by the actor
*/
- protected[akka] var _currentMessage: Option[MessageInvocation] = None
-
- protected[akka] def currentMessage_=(msg: Option[MessageInvocation]) = guard.withGuard { _currentMessage = msg }
- protected[akka] def currentMessage = guard.withGuard { _currentMessage }
+ @volatile protected[akka] var currentMessage: MessageInvocation = null
/**
* Comparison only takes uuid into account.
@@ -978,12 +969,12 @@ class LocalActorRef private[akka](
protected[akka] def postMessageToMailbox(message: Any, senderOption: Option[ActorRef]): Unit = {
joinTransaction(message)
- if (isRemotingEnabled && remoteAddress.isDefined) {
+ if (remoteAddress.isDefined && isRemotingEnabled) {
RemoteClientModule.send[Any](
message, senderOption, None, remoteAddress.get, timeout, true, this, None, ActorType.ScalaActor)
} else {
val invocation = new MessageInvocation(this, message, senderOption, None, transactionSet.get)
- invocation.send
+ dispatcher dispatch invocation
}
}
@@ -994,7 +985,7 @@ class LocalActorRef private[akka](
senderFuture: Option[CompletableFuture[T]]): CompletableFuture[T] = {
joinTransaction(message)
- if (isRemotingEnabled && remoteAddress.isDefined) {
+ if (remoteAddress.isDefined && isRemotingEnabled) {
val future = RemoteClientModule.send[T](
message, senderOption, senderFuture, remoteAddress.get, timeout, false, this, None, ActorType.ScalaActor)
if (future.isDefined) future.get
@@ -1004,7 +995,7 @@ class LocalActorRef private[akka](
else new DefaultCompletableFuture[T](timeout)
val invocation = new MessageInvocation(
this, message, senderOption, Some(future.asInstanceOf[CompletableFuture[Any]]), transactionSet.get)
- invocation.send
+ dispatcher dispatch invocation
future
}
}
@@ -1016,7 +1007,7 @@ class LocalActorRef private[akka](
if (isShutdown)
Actor.log.warning("Actor [%s] is shut down,\n\tignoring message [%s]", toString, messageHandle)
else {
- currentMessage = Option(messageHandle)
+ currentMessage = messageHandle
try {
dispatch(messageHandle)
} catch {
@@ -1024,7 +1015,7 @@ class LocalActorRef private[akka](
Actor.log.error(e, "Could not invoke actor [%s]", this)
throw e
} finally {
- currentMessage = None //TODO: Don't reset this, we might want to resend the message
+ currentMessage = null //TODO: Don't reset this, we might want to resend the message
}
}
}
@@ -1187,7 +1178,7 @@ class LocalActorRef private[akka](
}
private def dispatch[T](messageHandle: MessageInvocation) = {
- Actor.log.trace("Invoking actor with message:\n" + messageHandle)
+ Actor.log.trace("Invoking actor with message: %s\n",messageHandle)
val message = messageHandle.message //serializeMessage(messageHandle.message)
var topLevelTransaction = false
val txSet: Option[CountDownCommitBarrier] =
@@ -1351,17 +1342,18 @@ object RemoteActorSystemMessage {
* @author Jonas Bonér
*/
private[akka] case class RemoteActorRef private[akka] (
- uuuid: String,
+ classOrServiceName: String,
val className: String,
val hostname: String,
val port: Int,
_timeout: Long,
- loader: Option[ClassLoader])
+ loader: Option[ClassLoader],
+ val actorType: ActorType = ActorType.ScalaActor)
extends ActorRef with ScalaActorRef {
ensureRemotingEnabled
- _uuid = uuuid
+ id = classOrServiceName
timeout = _timeout
start
@@ -1369,7 +1361,7 @@ private[akka] case class RemoteActorRef private[akka] (
def postMessageToMailbox(message: Any, senderOption: Option[ActorRef]): Unit =
RemoteClientModule.send[Any](
- message, senderOption, None, remoteAddress.get, timeout, true, this, None, ActorType.ScalaActor)
+ message, senderOption, None, remoteAddress.get, timeout, true, this, None, actorType)
def postMessageToMailboxAndCreateFutureResultWithTimeout[T](
message: Any,
@@ -1377,7 +1369,7 @@ private[akka] case class RemoteActorRef private[akka] (
senderOption: Option[ActorRef],
senderFuture: Option[CompletableFuture[T]]): CompletableFuture[T] = {
val future = RemoteClientModule.send[T](
- message, senderOption, senderFuture, remoteAddress.get, timeout, false, this, None, ActorType.ScalaActor)
+ message, senderOption, senderFuture, remoteAddress.get, timeout, false, this, None, actorType)
if (future.isDefined) future.get
else throw new IllegalActorStateException("Expected a future from remote call to actor " + toString)
}
@@ -1533,10 +1525,9 @@ trait ScalaActorRef extends ActorRefShared { ref: ActorRef =>
* Is defined if the message was sent from another Actor, else None.
*/
def sender: Option[ActorRef] = {
- // Five lines of map-performance-avoidance, could be just: currentMessage map { _.sender }
val msg = currentMessage
- if (msg.isEmpty) None
- else msg.get.sender
+ if (msg eq null) None
+ else msg.sender
}
/**
@@ -1544,10 +1535,9 @@ trait ScalaActorRef extends ActorRefShared { ref: ActorRef =>
* Is defined if the message was sent with sent with '!!' or '!!!', else None.
*/
def senderFuture(): Option[CompletableFuture[Any]] = {
- // Five lines of map-performance-avoidance, could be just: currentMessage map { _.senderFuture }
val msg = currentMessage
- if (msg.isEmpty) None
- else msg.get.senderFuture
+ if (msg eq null) None
+ else msg.senderFuture
}
diff --git a/akka-actor/src/main/scala/actor/ActorRegistry.scala b/akka-actor/src/main/scala/actor/ActorRegistry.scala
index f3a479e6fd..51bbfd3477 100644
--- a/akka-actor/src/main/scala/actor/ActorRegistry.scala
+++ b/akka-actor/src/main/scala/actor/ActorRegistry.scala
@@ -125,7 +125,7 @@ object ActorRegistry extends ListenerManagement {
actorsByUUID.put(actor.uuid, actor)
// notify listeners
- foreachListener(_ ! ActorRegistered(actor))
+ notifyListeners(ActorRegistered(actor))
}
/**
@@ -137,7 +137,7 @@ object ActorRegistry extends ListenerManagement {
actorsById.remove(actor.id,actor)
// notify listeners
- foreachListener(_ ! ActorUnregistered(actor))
+ notifyListeners(ActorUnregistered(actor))
}
/**
diff --git a/akka-actor/src/main/scala/dispatch/ExecutorBasedEventDrivenDispatcher.scala b/akka-actor/src/main/scala/dispatch/ExecutorBasedEventDrivenDispatcher.scala
index dbbfe3442a..6cabdec5e5 100644
--- a/akka-actor/src/main/scala/dispatch/ExecutorBasedEventDrivenDispatcher.scala
+++ b/akka-actor/src/main/scala/dispatch/ExecutorBasedEventDrivenDispatcher.scala
@@ -7,7 +7,7 @@ package se.scalablesolutions.akka.dispatch
import se.scalablesolutions.akka.actor.{ActorRef, IllegalActorStateException}
import java.util.Queue
-import java.util.concurrent.{ConcurrentLinkedQueue, LinkedBlockingQueue}
+import java.util.concurrent.{RejectedExecutionException, ConcurrentLinkedQueue, LinkedBlockingQueue}
/**
* Default settings are:
@@ -64,7 +64,7 @@ import java.util.concurrent.{ConcurrentLinkedQueue, LinkedBlockingQueue}
*/
class ExecutorBasedEventDrivenDispatcher(
_name: String,
- throughput: Int = Dispatchers.THROUGHPUT,
+ val throughput: Int = Dispatchers.THROUGHPUT,
mailboxConfig: MailboxConfig = Dispatchers.MAILBOX_CONFIG,
config: (ThreadPoolBuilder) => Unit = _ => ()) extends MessageDispatcher with ThreadPoolBuilder {
@@ -80,71 +80,84 @@ class ExecutorBasedEventDrivenDispatcher(
val name = "akka:event-driven:dispatcher:" + _name
init
+ /**
+ * This is the behavior of an ExecutorBasedEventDrivenDispatchers mailbox
+ */
+ trait ExecutableMailbox extends Runnable { self: MessageQueue =>
+ final def run = {
+
+ val reschedule = try {
+ processMailbox()
+ } finally {
+ dispatcherLock.unlock()
+ }
+
+ if (reschedule || !self.isEmpty)
+ registerForExecution(self)
+ }
+
+ /**
+ * Process the messages in the mailbox
+ *
+ * @return true if the processing finished before the mailbox was empty, due to the throughput constraint
+ */
+ final def processMailbox(): Boolean = {
+ val throttle = throughput > 0
+ var processedMessages = 0
+ var nextMessage = self.dequeue
+ if (nextMessage ne null) {
+ do {
+ nextMessage.invoke
+
+ if(throttle) { //Will be elided when false
+ processedMessages += 1
+ if (processedMessages >= throughput) //If we're throttled, break out
+ return !self.isEmpty
+ }
+ nextMessage = self.dequeue
+ }
+ while (nextMessage ne null)
+ }
+
+ false
+ }
+ }
+
def dispatch(invocation: MessageInvocation) = {
- getMailbox(invocation.receiver) enqueue invocation
- dispatch(invocation.receiver)
+ val mbox = getMailbox(invocation.receiver)
+ mbox enqueue invocation
+ registerForExecution(mbox)
+ }
+
+ protected def registerForExecution(mailbox: MessageQueue with ExecutableMailbox): Unit = if (active) {
+ if (mailbox.dispatcherLock.tryLock()) {
+ try {
+ executor execute mailbox
+ } catch {
+ case e: RejectedExecutionException =>
+ mailbox.dispatcherLock.unlock()
+ throw e
+ }
+ }
+ } else {
+ log.warning("%s is shut down,\n\tignoring the rest of the messages in the mailbox of\n\t%s", toString, mailbox)
}
/**
* @return the mailbox associated with the actor
*/
- private def getMailbox(receiver: ActorRef) = receiver.mailbox.asInstanceOf[MessageQueue]
+ private def getMailbox(receiver: ActorRef) = receiver.mailbox.asInstanceOf[MessageQueue with ExecutableMailbox]
override def mailboxSize(actorRef: ActorRef) = getMailbox(actorRef).size
- override def createMailbox(actorRef: ActorRef): AnyRef = mailboxConfig.newMailbox(bounds = mailboxCapacity, blockDequeue = false)
-
- def dispatch(receiver: ActorRef): Unit = if (active) {
-
- executor.execute(new Runnable() {
- def run = {
- var lockAcquiredOnce = false
- var finishedBeforeMailboxEmpty = false
- val lock = receiver.dispatcherLock
- val mailbox = getMailbox(receiver)
- // this do-while loop is required to prevent missing new messages between the end of the inner while
- // loop and releasing the lock
- do {
- if (lock.tryLock) {
- // Only dispatch if we got the lock. Otherwise another thread is already dispatching.
- lockAcquiredOnce = true
- try {
- finishedBeforeMailboxEmpty = processMailbox(receiver)
- } finally {
- lock.unlock
- if (finishedBeforeMailboxEmpty) dispatch(receiver)
- }
- }
- } while ((lockAcquiredOnce && !finishedBeforeMailboxEmpty && !mailbox.isEmpty))
- }
- })
- } else {
- log.warning("%s is shut down,\n\tignoring the rest of the messages in the mailbox of\n\t%s", toString, receiver)
+ override def createMailbox(actorRef: ActorRef): AnyRef = {
+ if (mailboxCapacity > 0)
+ new DefaultBoundedMessageQueue(mailboxCapacity,mailboxConfig.pushTimeOut,blockDequeue = false) with ExecutableMailbox
+ else
+ new DefaultUnboundedMessageQueue(blockDequeue = false) with ExecutableMailbox
}
- /**
- * Process the messages in the mailbox of the given actor.
- *
- * @return true if the processing finished before the mailbox was empty, due to the throughput constraint
- */
- def processMailbox(receiver: ActorRef): Boolean = {
- var processedMessages = 0
- val mailbox = getMailbox(receiver)
- var messageInvocation = mailbox.dequeue
- while (messageInvocation != null) {
- messageInvocation.invoke
- processedMessages += 1
- // check if we simply continue with other messages, or reached the throughput limit
- if (throughput <= 0 || processedMessages < throughput) messageInvocation = mailbox.dequeue
- else {
- messageInvocation = null
- return !mailbox.isEmpty
- }
- }
- false
- }
-
def start = if (!active) {
log.debug("Starting up %s\n\twith throughput [%d]", toString, throughput)
active = true
@@ -157,8 +170,10 @@ class ExecutorBasedEventDrivenDispatcher(
uuids.clear
}
- def ensureNotActive(): Unit = if (active) throw new IllegalActorStateException(
+ def ensureNotActive(): Unit = if (active) {
+ throw new IllegalActorStateException(
"Can't build a new thread pool for a dispatcher that is already up and running")
+ }
override def toString = "ExecutorBasedEventDrivenDispatcher[" + name + "]"
diff --git a/akka-actor/src/main/scala/dispatch/ExecutorBasedEventDrivenWorkStealingDispatcher.scala b/akka-actor/src/main/scala/dispatch/ExecutorBasedEventDrivenWorkStealingDispatcher.scala
index 9b1097213e..10afb1bfb6 100644
--- a/akka-actor/src/main/scala/dispatch/ExecutorBasedEventDrivenWorkStealingDispatcher.scala
+++ b/akka-actor/src/main/scala/dispatch/ExecutorBasedEventDrivenWorkStealingDispatcher.scala
@@ -56,21 +56,14 @@ class ExecutorBasedEventDrivenWorkStealingDispatcher(
/**
* @return the mailbox associated with the actor
*/
- private def getMailbox(receiver: ActorRef) = receiver.mailbox.asInstanceOf[Deque[MessageInvocation]]
+ private def getMailbox(receiver: ActorRef) = receiver.mailbox.asInstanceOf[Deque[MessageInvocation] with MessageQueue with Runnable]
override def mailboxSize(actorRef: ActorRef) = getMailbox(actorRef).size
def dispatch(invocation: MessageInvocation) = if (active) {
- getMailbox(invocation.receiver).add(invocation)
- executor.execute(new Runnable() {
- def run = {
- if (!tryProcessMailbox(invocation.receiver)) {
- // we are not able to process our mailbox (another thread is busy with it), so lets donate some of our mailbox
- // to another actor and then process his mailbox in stead.
- findThief(invocation.receiver).foreach( tryDonateAndProcessMessages(invocation.receiver,_) )
- }
- }
- })
+ val mbox = getMailbox(invocation.receiver)
+ mbox enqueue invocation
+ executor execute mbox
} else throw new IllegalActorStateException("Can't submit invocations to dispatcher since it's not started")
/**
@@ -79,22 +72,21 @@ class ExecutorBasedEventDrivenWorkStealingDispatcher(
*
* @return true if the mailbox was processed, false otherwise
*/
- private def tryProcessMailbox(receiver: ActorRef): Boolean = {
+ private def tryProcessMailbox(mailbox: MessageQueue): Boolean = {
var lockAcquiredOnce = false
- val lock = receiver.dispatcherLock
// this do-wile loop is required to prevent missing new messages between the end of processing
// the mailbox and releasing the lock
do {
- if (lock.tryLock) {
+ if (mailbox.dispatcherLock.tryLock) {
lockAcquiredOnce = true
try {
- processMailbox(receiver)
+ processMailbox(mailbox)
} finally {
- lock.unlock
+ mailbox.dispatcherLock.unlock
}
}
- } while ((lockAcquiredOnce && !getMailbox(receiver).isEmpty))
+ } while ((lockAcquiredOnce && !mailbox.isEmpty))
lockAcquiredOnce
}
@@ -102,12 +94,11 @@ class ExecutorBasedEventDrivenWorkStealingDispatcher(
/**
* Process the messages in the mailbox of the given actor.
*/
- private def processMailbox(receiver: ActorRef) = {
- val mailbox = getMailbox(receiver)
- var messageInvocation = mailbox.poll
- while (messageInvocation != null) {
+ private def processMailbox(mailbox: MessageQueue) = {
+ var messageInvocation = mailbox.dequeue
+ while (messageInvocation ne null) {
messageInvocation.invoke
- messageInvocation = mailbox.poll
+ messageInvocation = mailbox.dequeue
}
}
@@ -145,11 +136,12 @@ class ExecutorBasedEventDrivenWorkStealingDispatcher(
* the thiefs dispatching lock, because in that case another thread is already processing the thiefs mailbox.
*/
private def tryDonateAndProcessMessages(receiver: ActorRef, thief: ActorRef) = {
- if (thief.dispatcherLock.tryLock) {
+ val mailbox = getMailbox(thief)
+ if (mailbox.dispatcherLock.tryLock) {
try {
- while(donateMessage(receiver, thief)) processMailbox(thief)
+ while(donateMessage(receiver, thief)) processMailbox(mailbox)
} finally {
- thief.dispatcherLock.unlock
+ mailbox.dispatcherLock.unlock
}
}
}
@@ -191,18 +183,44 @@ class ExecutorBasedEventDrivenWorkStealingDispatcher(
}
protected override def createMailbox(actorRef: ActorRef): AnyRef = {
- if (mailboxCapacity <= 0) new ConcurrentLinkedDeque[MessageInvocation]
- else new LinkedBlockingDeque[MessageInvocation](mailboxCapacity)
+ if (mailboxCapacity <= 0) {
+ new ConcurrentLinkedDeque[MessageInvocation] with MessageQueue with Runnable {
+ def enqueue(handle: MessageInvocation): Unit = this.add(handle)
+ def dequeue: MessageInvocation = this.poll()
+
+ def run = {
+ if (!tryProcessMailbox(this)) {
+ // we are not able to process our mailbox (another thread is busy with it), so lets donate some of our mailbox
+ // to another actor and then process his mailbox in stead.
+ findThief(actorRef).foreach( tryDonateAndProcessMessages(actorRef,_) )
+ }
+ }
+ }
+ }
+ else {
+ new LinkedBlockingDeque[MessageInvocation](mailboxCapacity) with MessageQueue with Runnable {
+ def enqueue(handle: MessageInvocation): Unit = this.add(handle)
+ def dequeue: MessageInvocation = this.poll()
+
+ def run = {
+ if (!tryProcessMailbox(this)) {
+ // we are not able to process our mailbox (another thread is busy with it), so lets donate some of our mailbox
+ // to another actor and then process his mailbox in stead.
+ findThief(actorRef).foreach( tryDonateAndProcessMessages(actorRef,_) )
+ }
+ }
+ }
+ }
}
override def register(actorRef: ActorRef) = {
verifyActorsAreOfSameType(actorRef)
- pooledActors.add(actorRef)
+ pooledActors add actorRef
super.register(actorRef)
}
override def unregister(actorRef: ActorRef) = {
- pooledActors.remove(actorRef)
+ pooledActors remove actorRef
super.unregister(actorRef)
}
diff --git a/akka-actor/src/main/scala/dispatch/MessageHandling.scala b/akka-actor/src/main/scala/dispatch/MessageHandling.scala
index 383c58905a..25a02f2603 100644
--- a/akka-actor/src/main/scala/dispatch/MessageHandling.scala
+++ b/akka-actor/src/main/scala/dispatch/MessageHandling.scala
@@ -8,10 +8,10 @@ import se.scalablesolutions.akka.actor.{Actor, ActorRef, ActorInitializationExce
import org.multiverse.commitbarriers.CountDownCommitBarrier
import se.scalablesolutions.akka.AkkaException
-import se.scalablesolutions.akka.util.{Duration, HashCode, Logging}
import java.util.{Queue, List}
import java.util.concurrent._
import concurrent.forkjoin.LinkedTransferQueue
+import se.scalablesolutions.akka.util.{SimpleLock, Duration, HashCode, Logging}
/**
* @author Jonas Bonér
@@ -30,8 +30,6 @@ final class MessageInvocation(val receiver: ActorRef,
"Don't call 'self ! message' in the Actor's constructor (e.g. body of the class).")
}
- def send = receiver.dispatcher.dispatch(this)
-
override def hashCode(): Int = synchronized {
var result = HashCode.SEED
result = HashCode.hash(result, receiver.actor)
@@ -63,6 +61,7 @@ class MessageQueueAppendFailedException(message: String) extends AkkaException(m
* @author Jonas Bonér
*/
trait MessageQueue {
+ val dispatcherLock = new SimpleLock
def enqueue(handle: MessageInvocation)
def dequeue(): MessageInvocation
def size: Int
@@ -84,40 +83,36 @@ case class MailboxConfig(capacity: Int, pushTimeOut: Option[Duration], blockingD
*/
def newMailbox(bounds: Int = capacity,
pushTime: Option[Duration] = pushTimeOut,
- blockDequeue: Boolean = blockingDequeue) : MessageQueue = {
- if (bounds <= 0) { //UNBOUNDED: Will never block enqueue and optionally blocking dequeue
- new LinkedTransferQueue[MessageInvocation] with MessageQueue {
- def enqueue(handle: MessageInvocation): Unit = this add handle
- def dequeue(): MessageInvocation = {
- if(blockDequeue) this.take()
- else this.poll()
- }
- }
- }
- else if (pushTime.isDefined) { //BOUNDED: Timeouted enqueue with MessageQueueAppendFailedException and optionally blocking dequeue
- val time = pushTime.get
- new BoundedTransferQueue[MessageInvocation](bounds) with MessageQueue {
- def enqueue(handle: MessageInvocation) {
- if (!this.offer(handle,time.length,time.unit))
- throw new MessageQueueAppendFailedException("Couldn't enqueue message " + handle + " to " + this.toString)
- }
+ blockDequeue: Boolean = blockingDequeue) : MessageQueue =
+ if (capacity > 0) new DefaultBoundedMessageQueue(bounds,pushTime,blockDequeue)
+ else new DefaultUnboundedMessageQueue(blockDequeue)
+}
- def dequeue(): MessageInvocation = {
- if (blockDequeue) this.take()
- else this.poll()
- }
- }
+class DefaultUnboundedMessageQueue(blockDequeue: Boolean) extends LinkedBlockingQueue[MessageInvocation] with MessageQueue {
+ final def enqueue(handle: MessageInvocation) {
+ this add handle
+ }
+
+ final def dequeue(): MessageInvocation =
+ if (blockDequeue) this.take()
+ else this.poll()
+}
+
+class DefaultBoundedMessageQueue(capacity: Int, pushTimeOut: Option[Duration], blockDequeue: Boolean) extends LinkedBlockingQueue[MessageInvocation](capacity) with MessageQueue {
+ final def enqueue(handle: MessageInvocation) {
+ if (pushTimeOut.isDefined) {
+ if(!this.offer(handle,pushTimeOut.get.length,pushTimeOut.get.unit))
+ throw new MessageQueueAppendFailedException("Couldn't enqueue message " + handle + " to " + toString)
}
- else { //BOUNDED: Blocking enqueue and optionally blocking dequeue
- new LinkedBlockingQueue[MessageInvocation](bounds) with MessageQueue {
- def enqueue(handle: MessageInvocation): Unit = this put handle
- def dequeue(): MessageInvocation = {
- if(blockDequeue) this.take()
- else this.poll()
- }
- }
+ else {
+ this put handle
}
}
+
+ final def dequeue(): MessageInvocation =
+ if (blockDequeue) this.take()
+ else this.poll()
+
}
/**
@@ -139,7 +134,7 @@ trait MessageDispatcher extends Logging {
}
def unregister(actorRef: ActorRef) = {
uuids remove actorRef.uuid
- //actorRef.mailbox = null //FIXME should we null out the mailbox here?
+ actorRef.mailbox = null
if (canBeShutDown) shutdown // shut down in the dispatcher's references is zero
}
@@ -156,14 +151,4 @@ trait MessageDispatcher extends Logging {
* Creates and returns a mailbox for the given actor
*/
protected def createMailbox(actorRef: ActorRef): AnyRef = null
-}
-
-/**
- * @author Jonas Bonér
- */
-trait MessageDemultiplexer {
- def select
- def wakeUp
- def acquireSelectedInvocations: List[MessageInvocation]
- def releaseSelectedInvocations
-}
+}
\ No newline at end of file
diff --git a/akka-actor/src/main/scala/dispatch/Queues.scala b/akka-actor/src/main/scala/dispatch/Queues.scala
deleted file mode 100644
index 2ba88f25c3..0000000000
--- a/akka-actor/src/main/scala/dispatch/Queues.scala
+++ /dev/null
@@ -1,141 +0,0 @@
-/**
- * Copyright (C) 2009-2010 Scalable Solutions AB
- */
-
-package se.scalablesolutions.akka.dispatch
-
-import concurrent.forkjoin.LinkedTransferQueue
-import java.util.concurrent.{TimeUnit, Semaphore}
-import java.util.Iterator
-import se.scalablesolutions.akka.util.Logger
-
-class BoundedTransferQueue[E <: AnyRef](val capacity: Int) extends LinkedTransferQueue[E] {
- require(capacity > 0)
-
- protected val guard = new Semaphore(capacity)
-
- override def take(): E = {
- val e = super.take
- if (e ne null) guard.release
- e
- }
-
- override def poll(): E = {
- val e = super.poll
- if (e ne null) guard.release
- e
- }
-
- override def poll(timeout: Long, unit: TimeUnit): E = {
- val e = super.poll(timeout,unit)
- if (e ne null) guard.release
- e
- }
-
- override def remainingCapacity = guard.availablePermits
-
- override def remove(o: AnyRef): Boolean = {
- if (super.remove(o)) {
- guard.release
- true
- } else {
- false
- }
- }
-
- override def offer(e: E): Boolean = {
- if (guard.tryAcquire) {
- val result = try {
- super.offer(e)
- } catch {
- case e => guard.release; throw e
- }
- if (!result) guard.release
- result
- } else
- false
- }
-
- override def offer(e: E, timeout: Long, unit: TimeUnit): Boolean = {
- if (guard.tryAcquire(timeout,unit)) {
- val result = try {
- super.offer(e)
- } catch {
- case e => guard.release; throw e
- }
- if (!result) guard.release
- result
- } else
- false
- }
-
- override def add(e: E): Boolean = {
- if (guard.tryAcquire) {
- val result = try {
- super.add(e)
- } catch {
- case e => guard.release; throw e
- }
- if (!result) guard.release
- result
- } else
- false
- }
-
- override def put(e :E): Unit = {
- guard.acquire
- try {
- super.put(e)
- } catch {
- case e => guard.release; throw e
- }
- }
-
- override def tryTransfer(e: E): Boolean = {
- if (guard.tryAcquire) {
- val result = try {
- super.tryTransfer(e)
- } catch {
- case e => guard.release; throw e
- }
- if (!result) guard.release
- result
- } else
- false
- }
-
- override def tryTransfer(e: E, timeout: Long, unit: TimeUnit): Boolean = {
- if (guard.tryAcquire(timeout,unit)) {
- val result = try {
- super.tryTransfer(e)
- } catch {
- case e => guard.release; throw e
- }
- if (!result) guard.release
- result
- } else
- false
- }
-
- override def transfer(e: E): Unit = {
- if (guard.tryAcquire) {
- try {
- super.transfer(e)
- } catch {
- case e => guard.release; throw e
- }
- }
- }
-
- override def iterator: Iterator[E] = {
- val it = super.iterator
- new Iterator[E] {
- def hasNext = it.hasNext
- def next = it.next
- def remove {
- it.remove
- guard.release //Assume remove worked if no exception was thrown
- }
- }
- }
-}
\ No newline at end of file
diff --git a/akka-actor/src/main/scala/util/ListenerManagement.scala b/akka-actor/src/main/scala/util/ListenerManagement.scala
index 0e17058380..7ad0f451f1 100644
--- a/akka-actor/src/main/scala/util/ListenerManagement.scala
+++ b/akka-actor/src/main/scala/util/ListenerManagement.scala
@@ -40,6 +40,23 @@ trait ListenerManagement extends Logging {
if (manageLifeCycleOfListeners) listener.stop
}
+ /*
+ * Returns whether there are any listeners currently
+ */
+ def hasListeners: Boolean = !listeners.isEmpty
+
+ protected def notifyListeners(message: => Any) {
+ if (hasListeners) {
+ val msg = message
+ val iterator = listeners.iterator
+ while (iterator.hasNext) {
+ val listener = iterator.next
+ if (listener.isRunning) listener ! msg
+ else log.warning("Can't notify [%s] since it is not running.", listener)
+ }
+ }
+ }
+
/**
* Execute f with each listener as argument.
*/
diff --git a/akka-actor/src/main/scala/util/LockUtil.scala b/akka-actor/src/main/scala/util/LockUtil.scala
index 885e11def7..3d1261e468 100644
--- a/akka-actor/src/main/scala/util/LockUtil.scala
+++ b/akka-actor/src/main/scala/util/LockUtil.scala
@@ -5,6 +5,7 @@
package se.scalablesolutions.akka.util
import java.util.concurrent.locks.{ReentrantReadWriteLock, ReentrantLock}
+import java.util.concurrent.atomic.AtomicBoolean
/**
* @author Jonas Bonér
@@ -58,3 +59,56 @@ class ReadWriteGuard {
}
}
+/**
+ * A very simple lock that uses CCAS (Compare Compare-And-Swap)
+ * Does not keep track of the owner and isn't Reentrant, so don't nest and try to stick to the if*-methods
+ */
+class SimpleLock {
+ val acquired = new AtomicBoolean(false)
+
+ def ifPossible(perform: () => Unit): Boolean = {
+ if (tryLock()) {
+ try {
+ perform
+ } finally {
+ unlock()
+ }
+ true
+ } else false
+ }
+
+ def ifPossibleYield[T](perform: () => T): Option[T] = {
+ if (tryLock()) {
+ try {
+ Some(perform())
+ } finally {
+ unlock()
+ }
+ } else None
+ }
+
+ def ifPossibleApply[T,R](value: T)(function: (T) => R): Option[R] = {
+ if (tryLock()) {
+ try {
+ Some(function(value))
+ } finally {
+ unlock()
+ }
+ } else None
+ }
+
+ def tryLock() = {
+ if (acquired.get) false
+ else acquired.compareAndSet(false,true)
+ }
+
+ def tryUnlock() = {
+ acquired.compareAndSet(true,false)
+ }
+
+ def locked = acquired.get
+
+ def unlock() {
+ acquired.set(false)
+ }
+}
\ No newline at end of file
diff --git a/akka-actor/src/test/scala/dispatch/ExecutorBasedEventDrivenWorkStealingDispatcherSpec.scala b/akka-actor/src/test/scala/dispatch/ExecutorBasedEventDrivenWorkStealingDispatcherSpec.scala
index cde57a0544..3285e450c6 100644
--- a/akka-actor/src/test/scala/dispatch/ExecutorBasedEventDrivenWorkStealingDispatcherSpec.scala
+++ b/akka-actor/src/test/scala/dispatch/ExecutorBasedEventDrivenWorkStealingDispatcherSpec.scala
@@ -5,11 +5,10 @@ import org.scalatest.junit.JUnitSuite
import org.junit.Test
-import se.scalablesolutions.akka.dispatch.Dispatchers
-
import java.util.concurrent.{TimeUnit, CountDownLatch}
import se.scalablesolutions.akka.actor.{IllegalActorStateException, Actor}
import Actor._
+import se.scalablesolutions.akka.dispatch.{MessageQueue, Dispatchers}
object ExecutorBasedEventDrivenWorkStealingDispatcherSpec {
val delayableActorDispatcher = Dispatchers.newExecutorBasedEventDrivenWorkStealingDispatcher("pooled-dispatcher")
@@ -18,7 +17,7 @@ object ExecutorBasedEventDrivenWorkStealingDispatcherSpec {
class DelayableActor(name: String, delay: Int, finishedCounter: CountDownLatch) extends Actor {
self.dispatcher = delayableActorDispatcher
- var invocationCount = 0
+ @volatile var invocationCount = 0
self.id = name
def receive = {
@@ -61,10 +60,14 @@ class ExecutorBasedEventDrivenWorkStealingDispatcherSpec extends JUnitSuite with
val slow = actorOf(new DelayableActor("slow", 50, finishedCounter)).start
val fast = actorOf(new DelayableActor("fast", 10, finishedCounter)).start
+ var sentToFast = 0
+
for (i <- 1 to 100) {
// send most work to slow actor
- if (i % 20 == 0)
+ if (i % 20 == 0) {
fast ! i
+ sentToFast += 1
+ }
else
slow ! i
}
@@ -72,13 +75,18 @@ class ExecutorBasedEventDrivenWorkStealingDispatcherSpec extends JUnitSuite with
// now send some messages to actors to keep the dispatcher dispatching messages
for (i <- 1 to 10) {
Thread.sleep(150)
- if (i % 2 == 0)
+ if (i % 2 == 0) {
fast ! i
+ sentToFast += 1
+ }
else
slow ! i
}
finishedCounter.await(5, TimeUnit.SECONDS)
+ fast.mailbox.asInstanceOf[MessageQueue].isEmpty must be(true)
+ slow.mailbox.asInstanceOf[MessageQueue].isEmpty must be(true)
+ fast.actor.asInstanceOf[DelayableActor].invocationCount must be > sentToFast
fast.actor.asInstanceOf[DelayableActor].invocationCount must be >
(slow.actor.asInstanceOf[DelayableActor].invocationCount)
slow.stop
diff --git a/akka-persistence/akka-persistence-cassandra/src/main/scala/CassandraStorage.scala b/akka-persistence/akka-persistence-cassandra/src/main/scala/CassandraStorage.scala
index be5fc4f4c7..0c6f239ef7 100644
--- a/akka-persistence/akka-persistence-cassandra/src/main/scala/CassandraStorage.scala
+++ b/akka-persistence/akka-persistence-cassandra/src/main/scala/CassandraStorage.scala
@@ -29,7 +29,7 @@ object CassandraStorage extends Storage {
*
* @author Jonas Bonér
*/
-class CassandraPersistentMap(id: String) extends PersistentMap[Array[Byte], Array[Byte]] {
+class CassandraPersistentMap(id: String) extends PersistentMapBinary {
val uuid = id
val storage = CassandraStorageBackend
}
diff --git a/akka-persistence/akka-persistence-common/src/main/scala/Storage.scala b/akka-persistence/akka-persistence-common/src/main/scala/Storage.scala
index ccaf7518f1..4d9ff48a60 100644
--- a/akka-persistence/akka-persistence-common/src/main/scala/Storage.scala
+++ b/akka-persistence/akka-persistence-common/src/main/scala/Storage.scala
@@ -7,9 +7,11 @@ package se.scalablesolutions.akka.persistence.common
import se.scalablesolutions.akka.stm._
import se.scalablesolutions.akka.stm.TransactionManagement.transaction
import se.scalablesolutions.akka.util.Logging
-import se.scalablesolutions.akka.AkkaException
-class StorageException(message: String) extends AkkaException(message)
+// FIXME move to 'stm' package + add message with more info
+class NoTransactionInScopeException extends RuntimeException
+
+class StorageException(message: String) extends RuntimeException(message)
/**
* Example Scala usage.
@@ -80,24 +82,90 @@ trait Storage {
*/
trait PersistentMap[K, V] extends scala.collection.mutable.Map[K, V]
with Transactional with Committable with Abortable with Logging {
- protected val newAndUpdatedEntries = TransactionalMap[K, V]()
- protected val removedEntries = TransactionalVector[K]()
protected val shouldClearOnCommit = Ref[Boolean]()
+ // operations on the Map
+ trait Op
+ case object GET extends Op
+ case object PUT extends Op
+ case object REM extends Op
+ case object UPD extends Op
+
+ // append only log: records all mutating operations
+ protected val appendOnlyTxLog = TransactionalVector[LogEntry]()
+
+ case class LogEntry(key: K, value: Option[V], op: Op)
+
+ // need to override in subclasses e.g. "sameElements" for Array[Byte]
+ def equal(k1: K, k2: K): Boolean = k1 == k2
+
+ // Seqable type that's required for maintaining the log of distinct keys affected in current transaction
+ type T <: Equals
+
+ // converts key K to the Seqable type Equals
+ def toEquals(k: K): T
+
+ // keys affected in the current transaction
+ protected val keysInCurrentTx = TransactionalMap[T, K]()
+
+ protected def addToListOfKeysInTx(key: K): Unit =
+ keysInCurrentTx += (toEquals(key), key)
+
+ protected def clearDistinctKeys = keysInCurrentTx.clear
+
+ protected def filterTxLogByKey(key: K): IndexedSeq[LogEntry] =
+ appendOnlyTxLog filter(e => equal(e.key, key))
+
+ // need to get current value considering the underlying storage as well as the transaction log
+ protected def getCurrentValue(key: K): Option[V] = {
+
+ // get all mutating entries for this key for this tx
+ val txEntries = filterTxLogByKey(key)
+
+ // get the snapshot from the underlying store for this key
+ val underlying = try {
+ storage.getMapStorageEntryFor(uuid, key)
+ } catch { case e: Exception => None }
+
+ if (txEntries.isEmpty) underlying
+ else replay(txEntries, key, underlying)
+ }
+
+ // replay all tx entries for key k with seed = initial
+ private def replay(txEntries: IndexedSeq[LogEntry], key: K, initial: Option[V]): Option[V] = {
+ import scala.collection.mutable._
+
+ val m = initial match {
+ case None => Map.empty[K, V]
+ case Some(v) => Map((key, v))
+ }
+ txEntries.foreach {case LogEntry(k, v, o) => o match {
+ case PUT => m.put(k, v.get)
+ case REM => m -= k
+ case UPD => m.update(k, v.get)
+ }}
+ m get key
+ }
+
// to be concretized in subclasses
val storage: MapStorageBackend[K, V]
def commit = {
- if (shouldClearOnCommit.isDefined && shouldClearOnCommit.get) storage.removeMapStorageFor(uuid)
- removedEntries.toList.foreach(key => storage.removeMapStorageFor(uuid, key))
- storage.insertMapStorageEntriesFor(uuid, newAndUpdatedEntries.toList)
- newAndUpdatedEntries.clear
- removedEntries.clear
+ // if (shouldClearOnCommit.isDefined && shouldClearOnCommit.get) storage.removeMapStorageFor(uuid)
+
+ appendOnlyTxLog.foreach { case LogEntry(k, v, o) => o match {
+ case PUT => storage.insertMapStorageEntryFor(uuid, k, v.get)
+ case UPD => storage.insertMapStorageEntryFor(uuid, k, v.get)
+ case REM => storage.removeMapStorageFor(uuid, k)
+ }}
+
+ appendOnlyTxLog.clear
+ clearDistinctKeys
}
def abort = {
- newAndUpdatedEntries.clear
- removedEntries.clear
+ appendOnlyTxLog.clear
+ clearDistinctKeys
shouldClearOnCommit.swap(false)
}
@@ -118,68 +186,84 @@ trait PersistentMap[K, V] extends scala.collection.mutable.Map[K, V]
override def put(key: K, value: V): Option[V] = {
register
- newAndUpdatedEntries.put(key, value)
+ val curr = getCurrentValue(key)
+ appendOnlyTxLog add LogEntry(key, Some(value), PUT)
+ addToListOfKeysInTx(key)
+ curr
}
override def update(key: K, value: V) = {
register
- newAndUpdatedEntries.update(key, value)
+ val curr = getCurrentValue(key)
+ appendOnlyTxLog add LogEntry(key, Some(value), UPD)
+ addToListOfKeysInTx(key)
+ curr
}
override def remove(key: K) = {
register
- removedEntries.add(key)
- newAndUpdatedEntries.get(key)
+ val curr = getCurrentValue(key)
+ appendOnlyTxLog add LogEntry(key, None, REM)
+ addToListOfKeysInTx(key)
+ curr
}
- def slice(start: Option[K], count: Int): List[Tuple2[K, V]] =
+ def slice(start: Option[K], count: Int): List[(K, V)] =
slice(start, None, count)
- def slice(start: Option[K], finish: Option[K], count: Int): List[Tuple2[K, V]] = try {
- storage.getMapStorageRangeFor(uuid, start, finish, count)
- } catch { case e: Exception => Nil }
+ def slice(start: Option[K], finish: Option[K], count: Int): List[(K, V)]
override def clear = {
register
+ appendOnlyTxLog.clear
+ clearDistinctKeys
shouldClearOnCommit.swap(true)
}
override def contains(key: K): Boolean = try {
- newAndUpdatedEntries.contains(key) ||
- storage.getMapStorageEntryFor(uuid, key).isDefined
+ filterTxLogByKey(key) match {
+ case Seq() => // current tx doesn't use this
+ storage.getMapStorageEntryFor(uuid, key).isDefined // check storage
+ case txs => // present in log
+ txs.last.op != REM // last entry cannot be a REM
+ }
} catch { case e: Exception => false }
+ protected def existsInStorage(key: K): Option[V] = try {
+ storage.getMapStorageEntryFor(uuid, key)
+ } catch {
+ case e: Exception => None
+ }
+
override def size: Int = try {
- storage.getMapStorageSizeFor(uuid)
- } catch { case e: Exception => 0 }
+ // partition key set affected in current tx into those which r added & which r deleted
+ val (keysAdded, keysRemoved) = keysInCurrentTx.map {
+ case (kseq, k) => ((kseq, k), getCurrentValue(k))
+ }.partition(_._2.isDefined)
- override def get(key: K): Option[V] = {
- if (newAndUpdatedEntries.contains(key)) {
- newAndUpdatedEntries.get(key)
- }
- else try {
- storage.getMapStorageEntryFor(uuid, key)
- } catch { case e: Exception => None }
+ // keys which existed in storage but removed in current tx
+ val inStorageRemovedInTx =
+ keysRemoved.keySet
+ .map(_._2)
+ .filter(k => existsInStorage(k).isDefined)
+ .size
+
+ // all keys in storage
+ val keysInStorage =
+ storage.getMapStorageFor(uuid)
+ .map { case (k, v) => toEquals(k) }
+ .toSet
+
+ // (keys that existed UNION keys added ) - (keys removed)
+ (keysInStorage union keysAdded.keySet.map(_._1)).size - inStorageRemovedInTx
+ } catch {
+ case e: Exception => 0
}
- def iterator = elements
+ // get must consider underlying storage & current uncommitted tx log
+ override def get(key: K): Option[V] = getCurrentValue(key)
- override def elements: Iterator[Tuple2[K, V]] = {
- new Iterator[Tuple2[K, V]] {
- private val originalList: List[Tuple2[K, V]] = try {
- storage.getMapStorageFor(uuid)
- } catch {
- case e: Throwable => Nil
- }
- private var elements = newAndUpdatedEntries.toList union originalList.reverse
- override def next: Tuple2[K, V]= synchronized {
- val element = elements.head
- elements = elements.tail
- element
- }
- override def hasNext: Boolean = synchronized { !elements.isEmpty }
- }
- }
+ def iterator: Iterator[Tuple2[K, V]]
private def register = {
if (transaction.get.isEmpty) throw new NoTransactionInScopeException
@@ -187,6 +271,95 @@ trait PersistentMap[K, V] extends scala.collection.mutable.Map[K, V]
}
}
+trait PersistentMapBinary extends PersistentMap[Array[Byte], Array[Byte]] {
+ import scala.collection.mutable.ArraySeq
+
+ type T = ArraySeq[Byte]
+ def toEquals(k: Array[Byte]) = ArraySeq(k: _*)
+ override def equal(k1: Array[Byte], k2: Array[Byte]): Boolean = k1 sameElements k2
+
+ object COrdering {
+ implicit object ArraySeqOrdering extends Ordering[ArraySeq[Byte]] {
+ def compare(o1: ArraySeq[Byte], o2: ArraySeq[Byte]) =
+ new String(o1.toArray) compare new String(o2.toArray)
+ }
+ }
+
+ import scala.collection.immutable.{TreeMap, SortedMap}
+ private def replayAllKeys: SortedMap[ArraySeq[Byte], Array[Byte]] = {
+ import COrdering._
+
+ // need ArraySeq for ordering
+ val fromStorage =
+ TreeMap(storage.getMapStorageFor(uuid).map { case (k, v) => (ArraySeq(k: _*), v) }: _*)
+
+ val (keysAdded, keysRemoved) = keysInCurrentTx.map {
+ case (_, k) => (k, getCurrentValue(k))
+ }.partition(_._2.isDefined)
+
+ val inStorageRemovedInTx =
+ keysRemoved.keySet
+ .filter(k => existsInStorage(k).isDefined)
+ .map(k => ArraySeq(k: _*))
+
+ (fromStorage -- inStorageRemovedInTx) ++ keysAdded.map { case (k, Some(v)) => (ArraySeq(k: _*), v) }
+ }
+
+ override def slice(start: Option[Array[Byte]], finish: Option[Array[Byte]], count: Int): List[(Array[Byte], Array[Byte])] = try {
+ val newMap = replayAllKeys
+
+ if (newMap isEmpty) List[(Array[Byte], Array[Byte])]()
+
+ val startKey =
+ start match {
+ case Some(bytes) => Some(ArraySeq(bytes: _*))
+ case None => None
+ }
+
+ val endKey =
+ finish match {
+ case Some(bytes) => Some(ArraySeq(bytes: _*))
+ case None => None
+ }
+
+ ((startKey, endKey, count): @unchecked) match {
+ case ((Some(s), Some(e), _)) =>
+ newMap.range(s, e)
+ .toList
+ .map(e => (e._1.toArray, e._2))
+ .toList
+ case ((Some(s), None, c)) if c > 0 =>
+ newMap.from(s)
+ .iterator
+ .take(count)
+ .map(e => (e._1.toArray, e._2))
+ .toList
+ case ((Some(s), None, _)) =>
+ newMap.from(s)
+ .toList
+ .map(e => (e._1.toArray, e._2))
+ .toList
+ case ((None, Some(e), _)) =>
+ newMap.until(e)
+ .toList
+ .map(e => (e._1.toArray, e._2))
+ .toList
+ }
+ } catch { case e: Exception => Nil }
+
+ override def iterator: Iterator[(Array[Byte], Array[Byte])] = {
+ new Iterator[(Array[Byte], Array[Byte])] {
+ private var elements = replayAllKeys
+ override def next: (Array[Byte], Array[Byte]) = synchronized {
+ val (k, v) = elements.head
+ elements = elements.tail
+ (k.toArray, v)
+ }
+ override def hasNext: Boolean = synchronized { !elements.isEmpty }
+ }
+ }
+}
+
/**
* Implements a template for a concrete persistent transactional vector based storage.
*
@@ -198,42 +371,83 @@ trait PersistentVector[T] extends IndexedSeq[T] with Transactional with Committa
protected val removedElems = TransactionalVector[T]()
protected val shouldClearOnCommit = Ref[Boolean]()
+ // operations on the Vector
+ trait Op
+ case object ADD extends Op
+ case object UPD extends Op
+ case object POP extends Op
+
+ // append only log: records all mutating operations
+ protected val appendOnlyTxLog = TransactionalVector[LogEntry]()
+
+ case class LogEntry(index: Option[Int], value: Option[T], op: Op)
+
+ // need to override in subclasses e.g. "sameElements" for Array[Byte]
+ def equal(v1: T, v2: T): Boolean = v1 == v2
+
val storage: VectorStorageBackend[T]
def commit = {
- for (element <- newElems) storage.insertVectorStorageEntryFor(uuid, element)
- for (entry <- updatedElems) storage.updateVectorStorageEntryFor(uuid, entry._1, entry._2)
- newElems.clear
- updatedElems.clear
+ for(entry <- appendOnlyTxLog) {
+ entry match {
+ case LogEntry(_, Some(v), ADD) => storage.insertVectorStorageEntryFor(uuid, v)
+ case LogEntry(Some(i), Some(v), UPD) => storage.updateVectorStorageEntryFor(uuid, i, v)
+ case LogEntry(_, _, POP) => //..
+ }
+ }
+ appendOnlyTxLog.clear
}
def abort = {
- newElems.clear
- updatedElems.clear
- removedElems.clear
+ appendOnlyTxLog.clear
shouldClearOnCommit.swap(false)
}
+ private def replay: List[T] = {
+ import scala.collection.mutable.ArrayBuffer
+ var elemsStorage = ArrayBuffer(storage.getVectorStorageRangeFor(uuid, None, None, storage.getVectorStorageSizeFor(uuid)).reverse: _*)
+
+ for(entry <- appendOnlyTxLog) {
+ entry match {
+ case LogEntry(_, Some(v), ADD) => elemsStorage += v
+ case LogEntry(Some(i), Some(v), UPD) => elemsStorage.update(i, v)
+ case LogEntry(_, _, POP) => elemsStorage = elemsStorage.drop(1)
+ }
+ }
+ elemsStorage.toList.reverse
+ }
+
def +(elem: T) = add(elem)
def add(elem: T) = {
register
- newElems + elem
+ appendOnlyTxLog + LogEntry(None, Some(elem), ADD)
}
def apply(index: Int): T = get(index)
def get(index: Int): T = {
- if (newElems.size > index) newElems(index)
- else storage.getVectorStorageEntryFor(uuid, index)
+ if (appendOnlyTxLog.isEmpty) {
+ storage.getVectorStorageEntryFor(uuid, index)
+ } else {
+ val curr = replay
+ curr(index)
+ }
}
override def slice(start: Int, finish: Int): IndexedSeq[T] = slice(Some(start), Some(finish))
def slice(start: Option[Int], finish: Option[Int], count: Int = 0): IndexedSeq[T] = {
- val buffer = new scala.collection.mutable.ArrayBuffer[T]
- storage.getVectorStorageRangeFor(uuid, start, finish, count).foreach(buffer.append(_))
- buffer
+ val curr = replay
+ val s = if (start.isDefined) start.get else 0
+ val cnt =
+ if (finish.isDefined) {
+ val f = finish.get
+ if (f >= s) (f - s) else count
+ }
+ else count
+ if (s == 0 && cnt == 0) List().toIndexedSeq
+ else curr.slice(s, s + cnt).toIndexedSeq
}
/**
@@ -241,12 +455,13 @@ trait PersistentVector[T] extends IndexedSeq[T] with Transactional with Committa
*/
def pop: T = {
register
+ appendOnlyTxLog + LogEntry(None, None, POP)
throw new UnsupportedOperationException("PersistentVector::pop is not implemented")
}
def update(index: Int, newElem: T) = {
register
- storage.updateVectorStorageEntryFor(uuid, index, newElem)
+ appendOnlyTxLog + LogEntry(Some(index), Some(newElem), UPD)
}
override def first: T = get(0)
@@ -260,7 +475,7 @@ trait PersistentVector[T] extends IndexedSeq[T] with Transactional with Committa
}
}
- def length: Int = storage.getVectorStorageSizeFor(uuid) + newElems.length
+ def length: Int = replay.length
private def register = {
if (transaction.get.isEmpty) throw new NoTransactionInScopeException
diff --git a/akka-persistence/akka-persistence-mongo/src/main/scala/MongoStorage.scala b/akka-persistence/akka-persistence-mongo/src/main/scala/MongoStorage.scala
index 98776253a5..83e47e3ba5 100644
--- a/akka-persistence/akka-persistence-mongo/src/main/scala/MongoStorage.scala
+++ b/akka-persistence/akka-persistence-mongo/src/main/scala/MongoStorage.scala
@@ -9,7 +9,7 @@ import se.scalablesolutions.akka.persistence.common._
import se.scalablesolutions.akka.util.UUID
object MongoStorage extends Storage {
- type ElementType = AnyRef
+ type ElementType = Array[Byte]
def newMap: PersistentMap[ElementType, ElementType] = newMap(UUID.newUuid.toString)
def newVector: PersistentVector[ElementType] = newVector(UUID.newUuid.toString)
@@ -29,7 +29,7 @@ object MongoStorage extends Storage {
*
* @author Debasish Ghosh
*/
-class MongoPersistentMap(id: String) extends PersistentMap[AnyRef, AnyRef] {
+class MongoPersistentMap(id: String) extends PersistentMapBinary {
val uuid = id
val storage = MongoStorageBackend
}
@@ -40,12 +40,12 @@ class MongoPersistentMap(id: String) extends PersistentMap[AnyRef, AnyRef] {
*
* @author Debaissh Ghosh
*/
-class MongoPersistentVector(id: String) extends PersistentVector[AnyRef] {
+class MongoPersistentVector(id: String) extends PersistentVector[Array[Byte]] {
val uuid = id
val storage = MongoStorageBackend
}
-class MongoPersistentRef(id: String) extends PersistentRef[AnyRef] {
+class MongoPersistentRef(id: String) extends PersistentRef[Array[Byte]] {
val uuid = id
val storage = MongoStorageBackend
}
diff --git a/akka-persistence/akka-persistence-mongo/src/main/scala/MongoStorageBackend.scala b/akka-persistence/akka-persistence-mongo/src/main/scala/MongoStorageBackend.scala
index 950165567d..01d8ababce 100644
--- a/akka-persistence/akka-persistence-mongo/src/main/scala/MongoStorageBackend.scala
+++ b/akka-persistence/akka-persistence-mongo/src/main/scala/MongoStorageBackend.scala
@@ -9,13 +9,8 @@ import se.scalablesolutions.akka.persistence.common._
import se.scalablesolutions.akka.util.Logging
import se.scalablesolutions.akka.config.Config.config
-import sjson.json.Serializer._
-
import java.util.NoSuchElementException
-
-import com.mongodb._
-
-import java.util.{Map=>JMap, List=>JList, ArrayList=>JArrayList}
+import com.novus.casbah.mongodb.Imports._
/**
* A module for supporting MongoDB based persistence.
@@ -28,294 +23,208 @@ import java.util.{Map=>JMap, List=>JList, ArrayList=>JArrayList}
* @author Debasish Ghosh
*/
private[akka] object MongoStorageBackend extends
- MapStorageBackend[AnyRef, AnyRef] with
- VectorStorageBackend[AnyRef] with
- RefStorageBackend[AnyRef] with
+ MapStorageBackend[Array[Byte], Array[Byte]] with
+ VectorStorageBackend[Array[Byte]] with
+ RefStorageBackend[Array[Byte]] with
Logging {
- // enrich with null safe findOne
- class RichDBCollection(value: DBCollection) {
- def findOneNS(o: DBObject): Option[DBObject] = {
- value.findOne(o) match {
- case null => None
- case x => Some(x)
- }
- }
- }
-
- implicit def enrichDBCollection(c: DBCollection) = new RichDBCollection(c)
-
- val KEY = "key"
- val VALUE = "value"
+ val KEY = "__key"
+ val REF = "__ref"
val COLLECTION = "akka_coll"
- val MONGODB_SERVER_HOSTNAME = config.getString("akka.storage.mongodb.hostname", "127.0.0.1")
- val MONGODB_SERVER_DBNAME = config.getString("akka.storage.mongodb.dbname", "testdb")
- val MONGODB_SERVER_PORT = config.getInt("akka.storage.mongodb.port", 27017)
+ val HOSTNAME = config.getString("akka.storage.mongodb.hostname", "127.0.0.1")
+ val DBNAME = config.getString("akka.storage.mongodb.dbname", "testdb")
+ val PORT = config.getInt("akka.storage.mongodb.port", 27017)
- val db = new Mongo(MONGODB_SERVER_HOSTNAME, MONGODB_SERVER_PORT)
- val coll = db.getDB(MONGODB_SERVER_DBNAME).getCollection(COLLECTION)
+ val db: MongoDB = MongoConnection(HOSTNAME, PORT)(DBNAME)
+ val coll: MongoCollection = db(COLLECTION)
- private[this] val serializer = SJSON
+ def drop() { db.dropDatabase() }
- def insertMapStorageEntryFor(name: String, key: AnyRef, value: AnyRef) {
+ def insertMapStorageEntryFor(name: String, key: Array[Byte], value: Array[Byte]) {
insertMapStorageEntriesFor(name, List((key, value)))
}
- def insertMapStorageEntriesFor(name: String, entries: List[Tuple2[AnyRef, AnyRef]]) {
- import java.util.{Map, HashMap}
-
- val m: Map[AnyRef, AnyRef] = new HashMap
- for ((k, v) <- entries) {
- m.put(k, serializer.out(v))
- }
-
- nullSafeFindOne(name) match {
- case None =>
- coll.insert(new BasicDBObject().append(KEY, name).append(VALUE, m))
- case Some(dbo) => {
- // collate the maps
- val o = dbo.get(VALUE).asInstanceOf[Map[AnyRef, AnyRef]]
- o.putAll(m)
-
- val newdbo = new BasicDBObject().append(KEY, name).append(VALUE, o)
- coll.update(new BasicDBObject().append(KEY, name), newdbo, true, false)
+ def insertMapStorageEntriesFor(name: String, entries: List[(Array[Byte], Array[Byte])]) {
+ db.safely { db =>
+ val q: DBObject = MongoDBObject(KEY -> name)
+ coll.findOne(q) match {
+ case Some(dbo) =>
+ entries.foreach { case (k, v) => dbo += new String(k) -> v }
+ db.safely { db => coll.update(q, dbo, true, false) }
+ case None =>
+ val builder = MongoDBObject.newBuilder
+ builder += KEY -> name
+ entries.foreach { case (k, v) => builder += new String(k) -> v }
+ coll += builder.result.asDBObject
}
}
}
def removeMapStorageFor(name: String): Unit = {
- val q = new BasicDBObject
- q.put(KEY, name)
- coll.remove(q)
+ val q: DBObject = MongoDBObject(KEY -> name)
+ db.safely { db => coll.remove(q) }
}
- def removeMapStorageFor(name: String, key: AnyRef): Unit = {
- nullSafeFindOne(name) match {
- case None =>
- case Some(dbo) => {
- val orig = dbo.get(VALUE).asInstanceOf[DBObject].toMap
- if (key.isInstanceOf[List[_]]) {
- val keys = key.asInstanceOf[List[_]]
- keys.foreach(k => orig.remove(k.asInstanceOf[String]))
- } else {
- orig.remove(key.asInstanceOf[String])
- }
- // remove existing reference
- removeMapStorageFor(name)
- // and insert
- coll.insert(new BasicDBObject().append(KEY, name).append(VALUE, orig))
- }
+ private def queryFor[T](name: String)(body: (MongoDBObject, Option[DBObject]) => T): T = {
+ val q = MongoDBObject(KEY -> name)
+ body(q, coll.findOne(q))
+ }
+
+ def removeMapStorageFor(name: String, key: Array[Byte]): Unit = queryFor(name) { (q, dbo) =>
+ dbo.foreach { d =>
+ d -= new String(key)
+ db.safely { db => coll.update(q, d, true, false) }
}
}
- def getMapStorageEntryFor(name: String, key: AnyRef): Option[AnyRef] =
- getValueForKey(name, key.asInstanceOf[String])
-
- def getMapStorageSizeFor(name: String): Int = {
- nullSafeFindOne(name) match {
- case None => 0
- case Some(dbo) =>
- dbo.get(VALUE).asInstanceOf[JMap[String, AnyRef]].keySet.size
- }
+ def getMapStorageEntryFor(name: String, key: Array[Byte]): Option[Array[Byte]] = queryFor(name) { (q, dbo) =>
+ dbo.map { d =>
+ d.getAs[Array[Byte]](new String(key))
+ }.getOrElse(None)
}
- def getMapStorageFor(name: String): List[Tuple2[AnyRef, AnyRef]] = {
- val m =
- nullSafeFindOne(name) match {
- case None =>
- throw new NoSuchElementException(name + " not present")
- case Some(dbo) =>
- dbo.get(VALUE).asInstanceOf[JMap[String, AnyRef]]
- }
- val n =
- List(m.keySet.toArray: _*).asInstanceOf[List[String]]
- val vals =
- for(s <- n)
- yield (s, serializer.in[AnyRef](m.get(s).asInstanceOf[Array[Byte]]))
- vals.asInstanceOf[List[Tuple2[String, AnyRef]]]
+ def getMapStorageSizeFor(name: String): Int = queryFor(name) { (q, dbo) =>
+ dbo.map { d =>
+ d.size - 2 // need to exclude object id and our KEY
+ }.getOrElse(0)
}
- def getMapStorageRangeFor(name: String, start: Option[AnyRef],
- finish: Option[AnyRef],
- count: Int): List[Tuple2[AnyRef, AnyRef]] = {
- val m =
- nullSafeFindOne(name) match {
- case None =>
- throw new NoSuchElementException(name + " not present")
- case Some(dbo) =>
- dbo.get(VALUE).asInstanceOf[JMap[String, AnyRef]]
- }
-
- /**
- * count is the max number of results to return. Start with
- * start or 0 (if start is not defined) and go until
- * you hit finish or count.
- */
- val s = if (start.isDefined) start.get.asInstanceOf[Int] else 0
- val cnt =
- if (finish.isDefined) {
- val f = finish.get.asInstanceOf[Int]
- if (f >= s) math.min(count, (f - s)) else count
- }
- else count
-
- val n =
- List(m.keySet.toArray: _*).asInstanceOf[List[String]].sortWith((e1, e2) => (e1 compareTo e2) < 0).slice(s, s + cnt)
- val vals =
- for(s <- n)
- yield (s, serializer.in[AnyRef](m.get(s).asInstanceOf[Array[Byte]]))
- vals.asInstanceOf[List[Tuple2[String, AnyRef]]]
+ def getMapStorageFor(name: String): List[(Array[Byte], Array[Byte])] = queryFor(name) { (q, dbo) =>
+ dbo.map { d =>
+ for {
+ (k, v) <- d.toList
+ if k != "_id" && k != KEY
+ } yield (k.getBytes, v.asInstanceOf[Array[Byte]])
+ }.getOrElse(List.empty[(Array[Byte], Array[Byte])])
}
- private def getValueForKey(name: String, key: String): Option[AnyRef] = {
- try {
- nullSafeFindOne(name) match {
- case None => None
- case Some(dbo) =>
- Some(serializer.in[AnyRef](
- dbo.get(VALUE)
- .asInstanceOf[JMap[String, AnyRef]]
- .get(key).asInstanceOf[Array[Byte]]))
- }
- } catch {
- case e =>
- throw new NoSuchElementException(e.getMessage)
- }
+ def getMapStorageRangeFor(name: String, start: Option[Array[Byte]],
+ finish: Option[Array[Byte]],
+ count: Int): List[(Array[Byte], Array[Byte])] = queryFor(name) { (q, dbo) =>
+ dbo.map { d =>
+ // get all keys except the special ones
+ val keys = d.keys
+ .filter(k => k != "_id" && k != KEY)
+ .toList
+ .sortWith(_ < _)
+
+ // if the supplied start is not defined, get the head of keys
+ val s = start.map(new String(_)).getOrElse(keys.head)
+
+ // if the supplied finish is not defined, get the last element of keys
+ val f = finish.map(new String(_)).getOrElse(keys.last)
+
+ // slice from keys: both ends inclusive
+ val ks = keys.slice(keys.indexOf(s), scala.math.min(count, keys.indexOf(f) + 1))
+ ks.map(k => (k.getBytes, d.getAs[Array[Byte]](k).get))
+ }.getOrElse(List.empty[(Array[Byte], Array[Byte])])
}
- def insertVectorStorageEntriesFor(name: String, elements: List[AnyRef]) = {
- val q = new BasicDBObject
- q.put(KEY, name)
-
- val currentList =
- coll.findOneNS(q) match {
- case None =>
- new JArrayList[AnyRef]
- case Some(dbo) =>
- dbo.get(VALUE).asInstanceOf[JArrayList[AnyRef]]
- }
- if (!currentList.isEmpty) {
- // record exists
- // remove before adding
- coll.remove(q)
- }
-
- // add to the current list
- elements.map(serializer.out(_)).foreach(currentList.add(_))
-
- coll.insert(
- new BasicDBObject()
- .append(KEY, name)
- .append(VALUE, currentList)
- )
- }
-
- def insertVectorStorageEntryFor(name: String, element: AnyRef) = {
+ def insertVectorStorageEntryFor(name: String, element: Array[Byte]) = {
insertVectorStorageEntriesFor(name, List(element))
}
- def getVectorStorageEntryFor(name: String, index: Int): AnyRef = {
- try {
- val o =
- nullSafeFindOne(name) match {
- case None =>
- throw new NoSuchElementException(name + " not present")
+ def insertVectorStorageEntriesFor(name: String, elements: List[Array[Byte]]) = {
+ // lookup with name
+ val q: DBObject = MongoDBObject(KEY -> name)
- case Some(dbo) =>
- dbo.get(VALUE).asInstanceOf[JList[AnyRef]]
+ db.safely { db =>
+ coll.findOne(q) match {
+ // exists : need to update
+ case Some(dbo) =>
+ dbo -= KEY
+ dbo -= "_id"
+ val listBuilder = MongoDBList.newBuilder
+
+ // expensive!
+ listBuilder ++= (elements ++ dbo.toSeq.sortWith((e1, e2) => (e1._1.toInt < e2._1.toInt)).map(_._2))
+
+ val builder = MongoDBObject.newBuilder
+ builder += KEY -> name
+ builder ++= listBuilder.result
+ coll.update(q, builder.result.asDBObject, true, false)
+
+ // new : just add
+ case None =>
+ val listBuilder = MongoDBList.newBuilder
+ listBuilder ++= elements
+
+ val builder = MongoDBObject.newBuilder
+ builder += KEY -> name
+ builder ++= listBuilder.result
+ coll += builder.result.asDBObject
}
- serializer.in[AnyRef](
- o.get(index).asInstanceOf[Array[Byte]])
- } catch {
- case e =>
- throw new NoSuchElementException(e.getMessage)
}
}
- def getVectorStorageRangeFor(name: String,
- start: Option[Int], finish: Option[Int], count: Int): List[AnyRef] = {
- try {
- val o =
- nullSafeFindOne(name) match {
- case None =>
- throw new NoSuchElementException(name + " not present")
+ def updateVectorStorageEntryFor(name: String, index: Int, elem: Array[Byte]) = queryFor(name) { (q, dbo) =>
+ dbo.foreach { d =>
+ d += ((index.toString, elem))
+ db.safely { db => coll.update(q, d, true, false) }
+ }
+ }
- case Some(dbo) =>
- dbo.get(VALUE).asInstanceOf[JList[AnyRef]]
- }
+ def getVectorStorageEntryFor(name: String, index: Int): Array[Byte] = queryFor(name) { (q, dbo) =>
+ dbo.map { d =>
+ d(index.toString).asInstanceOf[Array[Byte]]
+ }.getOrElse(Array.empty[Byte])
+ }
- val s = if (start.isDefined) start.get else 0
+ /**
+ * if start and finish both are defined, ignore count and
+ * report the range [start, finish)
+ * if start is not defined, assume start = 0
+ * if start == 0 and finish == 0, return an empty collection
+ */
+ def getVectorStorageRangeFor(name: String, start: Option[Int], finish: Option[Int], count: Int): List[Array[Byte]] = queryFor(name) { (q, dbo) =>
+ dbo.map { d =>
+ val ls = d.filter { case (k, v) => k != KEY && k != "_id" }
+ .toSeq
+ .sortWith((e1, e2) => (e1._1.toInt < e2._1.toInt))
+ .map(_._2)
+
+ val st = start.getOrElse(0)
val cnt =
if (finish.isDefined) {
val f = finish.get
- if (f >= s) (f - s) else count
+ if (f >= st) (f - st) else count
}
else count
-
- // pick the subrange and make a Scala list
- val l =
- List(o.subList(s, s + cnt).toArray: _*)
-
- for(e <- l)
- yield serializer.in[AnyRef](e.asInstanceOf[Array[Byte]])
- } catch {
- case e =>
- throw new NoSuchElementException(e.getMessage)
- }
+ if (st == 0 && cnt == 0) List()
+ ls.slice(st, st + cnt).asInstanceOf[List[Array[Byte]]]
+ }.getOrElse(List.empty[Array[Byte]])
}
- def updateVectorStorageEntryFor(name: String, index: Int, elem: AnyRef) = {
- val q = new BasicDBObject
- q.put(KEY, name)
-
- val dbobj =
- coll.findOneNS(q) match {
- case None =>
- throw new NoSuchElementException(name + " not present")
- case Some(dbo) => dbo
- }
- val currentList = dbobj.get(VALUE).asInstanceOf[JArrayList[AnyRef]]
- currentList.set(index, serializer.out(elem))
- coll.update(q,
- new BasicDBObject().append(KEY, name).append(VALUE, currentList))
+ def getVectorStorageSizeFor(name: String): Int = queryFor(name) { (q, dbo) =>
+ dbo.map { d => d.size - 2 }.getOrElse(0)
}
- def getVectorStorageSizeFor(name: String): Int = {
- nullSafeFindOne(name) match {
- case None => 0
- case Some(dbo) =>
- dbo.get(VALUE).asInstanceOf[JList[AnyRef]].size
- }
- }
+ def insertRefStorageFor(name: String, element: Array[Byte]) = {
+ // lookup with name
+ val q: DBObject = MongoDBObject(KEY -> name)
- private def nullSafeFindOne(name: String): Option[DBObject] = {
- val o = new BasicDBObject
- o.put(KEY, name)
- coll.findOneNS(o)
- }
+ db.safely { db =>
+ coll.findOne(q) match {
+ // exists : need to update
+ case Some(dbo) =>
+ dbo += ((REF, element))
+ coll.update(q, dbo, true, false)
- def insertRefStorageFor(name: String, element: AnyRef) = {
- nullSafeFindOne(name) match {
- case None =>
- case Some(dbo) => {
- val q = new BasicDBObject
- q.put(KEY, name)
- coll.remove(q)
+ // not found : make one
+ case None =>
+ val builder = MongoDBObject.newBuilder
+ builder += KEY -> name
+ builder += REF -> element
+ coll += builder.result.asDBObject
}
}
- coll.insert(
- new BasicDBObject()
- .append(KEY, name)
- .append(VALUE, serializer.out(element)))
}
- def getRefStorageFor(name: String): Option[AnyRef] = {
- nullSafeFindOne(name) match {
- case None => None
- case Some(dbo) =>
- Some(serializer.in[AnyRef](dbo.get(VALUE).asInstanceOf[Array[Byte]]))
- }
+ def getRefStorageFor(name: String): Option[Array[Byte]] = queryFor(name) { (q, dbo) =>
+ dbo.map { d =>
+ d.getAs[Array[Byte]](REF)
+ }.getOrElse(None)
}
}
diff --git a/akka-persistence/akka-persistence-mongo/src/test/scala/MongoPersistentActorSpec.scala b/akka-persistence/akka-persistence-mongo/src/test/scala/MongoPersistentActorSpec.scala
index 1acc9ee97d..01f735b254 100644
--- a/akka-persistence/akka-persistence-mongo/src/test/scala/MongoPersistentActorSpec.scala
+++ b/akka-persistence/akka-persistence-mongo/src/test/scala/MongoPersistentActorSpec.scala
@@ -1,32 +1,19 @@
package se.scalablesolutions.akka.persistence.mongo
-import org.junit.{Test, Before}
-import org.junit.Assert._
-import org.scalatest.junit.JUnitSuite
-
-import _root_.dispatch.json.{JsNumber, JsValue}
-import _root_.dispatch.json.Js._
+import org.scalatest.Spec
+import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.junit.JUnitRunner
+import org.junit.runner.RunWith
import se.scalablesolutions.akka.actor.{Transactor, Actor, ActorRef}
import Actor._
-/**
- * A persistent actor based on MongoDB storage.
- *
- * Demonstrates a bank account operation consisting of messages that:
- * checks balance Balance
- * debits amountDebit
- * debits multiple amountsMultiDebit
- * credits amountCredit
- *
- * Needs a running Mongo server.
- * @author Debasish Ghosh
- */
case class Balance(accountNo: String)
-case class Debit(accountNo: String, amount: BigInt, failer: ActorRef)
-case class MultiDebit(accountNo: String, amounts: List[BigInt], failer: ActorRef)
-case class Credit(accountNo: String, amount: BigInt)
+case class Debit(accountNo: String, amount: Int, failer: ActorRef)
+case class MultiDebit(accountNo: String, amounts: List[Int], failer: ActorRef)
+case class Credit(accountNo: String, amount: Int)
case class Log(start: Int, finish: Int)
case object LogSize
@@ -35,63 +22,65 @@ class BankAccountActor extends Transactor {
private lazy val accountState = MongoStorage.newMap
private lazy val txnLog = MongoStorage.newVector
+ import sjson.json.DefaultProtocol._
+ import sjson.json.JsonSerialization._
+
def receive: Receive = {
// check balance
case Balance(accountNo) =>
- txnLog.add("Balance:" + accountNo)
- self.reply(accountState.get(accountNo).get)
+ txnLog.add(("Balance:" + accountNo).getBytes)
+ self.reply(
+ accountState.get(accountNo.getBytes)
+ .map(frombinary[Int](_))
+ .getOrElse(0))
// debit amount: can fail
case Debit(accountNo, amount, failer) =>
- txnLog.add("Debit:" + accountNo + " " + amount)
+ txnLog.add(("Debit:" + accountNo + " " + amount).getBytes)
+ val m = accountState.get(accountNo.getBytes)
+ .map(frombinary[Int](_))
+ .getOrElse(0)
+
+ accountState.put(accountNo.getBytes, tobinary(m - amount))
+ if (amount > m) failer !! "Failure"
- val m: BigInt =
- accountState.get(accountNo) match {
- case Some(JsNumber(n)) =>
- BigInt(n.asInstanceOf[BigDecimal].intValue)
- case None => 0
- }
- accountState.put(accountNo, (m - amount))
- if (amount > m)
- failer !! "Failure"
self.reply(m - amount)
// many debits: can fail
// demonstrates true rollback even if multiple puts have been done
case MultiDebit(accountNo, amounts, failer) =>
- txnLog.add("MultiDebit:" + accountNo + " " + amounts.map(_.intValue).foldLeft(0)(_ + _))
+ val sum = amounts.foldRight(0)(_ + _)
+ txnLog.add(("MultiDebit:" + accountNo + " " + sum).getBytes)
- val m: BigInt =
- accountState.get(accountNo) match {
- case Some(JsNumber(n)) => BigInt(n.toString)
- case None => 0
+ val m = accountState.get(accountNo.getBytes)
+ .map(frombinary[Int](_))
+ .getOrElse(0)
+
+ var cbal = m
+ amounts.foreach { amount =>
+ accountState.put(accountNo.getBytes, tobinary(m - amount))
+ cbal = cbal - amount
+ if (cbal < 0) failer !! "Failure"
}
- var bal: BigInt = 0
- amounts.foreach {amount =>
- bal = bal + amount
- accountState.put(accountNo, (m - bal))
- }
- if (bal > m) failer !! "Failure"
- self.reply(m - bal)
+
+ self.reply(m - sum)
// credit amount
case Credit(accountNo, amount) =>
- txnLog.add("Credit:" + accountNo + " " + amount)
+ txnLog.add(("Credit:" + accountNo + " " + amount).getBytes)
+ val m = accountState.get(accountNo.getBytes)
+ .map(frombinary[Int](_))
+ .getOrElse(0)
+
+ accountState.put(accountNo.getBytes, tobinary(m + amount))
- val m: BigInt =
- accountState.get(accountNo) match {
- case Some(JsNumber(n)) =>
- BigInt(n.asInstanceOf[BigDecimal].intValue)
- case None => 0
- }
- accountState.put(accountNo, (m + amount))
self.reply(m + amount)
case LogSize =>
- self.reply(txnLog.length.asInstanceOf[AnyRef])
+ self.reply(txnLog.length)
case Log(start, finish) =>
- self.reply(txnLog.slice(start, finish))
+ self.reply(txnLog.slice(start, finish).map(new String(_)))
}
}
@@ -102,82 +91,71 @@ class BankAccountActor extends Transactor {
}
}
-class MongoPersistentActorSpec extends JUnitSuite {
- @Test
- def testSuccessfulDebit = {
- val bactor = actorOf[BankAccountActor]
- bactor.start
- val failer = actorOf[PersistentFailerActor]
- failer.start
- bactor !! Credit("a-123", 5000)
- bactor !! Debit("a-123", 3000, failer)
+@RunWith(classOf[JUnitRunner])
+class MongoPersistentActorSpec extends
+ Spec with
+ ShouldMatchers with
+ BeforeAndAfterEach {
- val JsNumber(b) = (bactor !! Balance("a-123")).get.asInstanceOf[JsValue]
- assertEquals(BigInt(2000), BigInt(b.intValue))
-
- bactor !! Credit("a-123", 7000)
-
- val JsNumber(b1) = (bactor !! Balance("a-123")).get.asInstanceOf[JsValue]
- assertEquals(BigInt(9000), BigInt(b1.intValue))
-
- bactor !! Debit("a-123", 8000, failer)
-
- val JsNumber(b2) = (bactor !! Balance("a-123")).get.asInstanceOf[JsValue]
- assertEquals(BigInt(1000), BigInt(b2.intValue))
-
- assert(7 == (bactor !! LogSize).get.asInstanceOf[Int])
-
- import scala.collection.mutable.ArrayBuffer
- assert((bactor !! Log(0, 7)).get.asInstanceOf[ArrayBuffer[String]].size == 7)
- assert((bactor !! Log(0, 0)).get.asInstanceOf[ArrayBuffer[String]].size == 0)
- assert((bactor !! Log(1, 2)).get.asInstanceOf[ArrayBuffer[String]].size == 1)
- assert((bactor !! Log(6, 7)).get.asInstanceOf[ArrayBuffer[String]].size == 1)
- assert((bactor !! Log(0, 1)).get.asInstanceOf[ArrayBuffer[String]].size == 1)
+ override def beforeEach {
+ MongoStorageBackend.drop
}
- @Test
- def testUnsuccessfulDebit = {
- val bactor = actorOf[BankAccountActor]
- bactor.start
- bactor !! Credit("a-123", 5000)
-
- val JsNumber(b) = (bactor !! Balance("a-123")).get.asInstanceOf[JsValue]
- assertEquals(BigInt(5000), BigInt(b.intValue))
-
- val failer = actorOf[PersistentFailerActor]
- failer.start
- try {
- bactor !! Debit("a-123", 7000, failer)
- fail("should throw exception")
- } catch { case e: RuntimeException => {}}
-
- val JsNumber(b1) = (bactor !! Balance("a-123")).get.asInstanceOf[JsValue]
- assertEquals(BigInt(5000), BigInt(b1.intValue))
-
- // should not count the failed one
- assert(3 == (bactor !! LogSize).get.asInstanceOf[Int])
+ override def afterEach {
+ MongoStorageBackend.drop
}
- @Test
- def testUnsuccessfulMultiDebit = {
- val bactor = actorOf[BankAccountActor]
- bactor.start
- bactor !! Credit("a-123", 5000)
+ describe("successful debit") {
+ it("should debit successfully") {
+ val bactor = actorOf[BankAccountActor]
+ bactor.start
+ val failer = actorOf[PersistentFailerActor]
+ failer.start
+ bactor !! Credit("a-123", 5000)
+ bactor !! Debit("a-123", 3000, failer)
- val JsNumber(b) = (bactor !! Balance("a-123")).get.asInstanceOf[JsValue]
- assertEquals(BigInt(5000), BigInt(b.intValue))
+ (bactor !! Balance("a-123")).get.asInstanceOf[Int] should equal(2000)
- val failer = actorOf[PersistentFailerActor]
- failer.start
- try {
- bactor !! MultiDebit("a-123", List(500, 2000, 1000, 3000), failer)
- fail("should throw exception")
- } catch { case e: RuntimeException => {}}
+ bactor !! Credit("a-123", 7000)
+ (bactor !! Balance("a-123")).get.asInstanceOf[Int] should equal(9000)
- val JsNumber(b1) = (bactor !! Balance("a-123")).get.asInstanceOf[JsValue]
- assertEquals(BigInt(5000), BigInt(b1.intValue))
+ bactor !! Debit("a-123", 8000, failer)
+ (bactor !! Balance("a-123")).get.asInstanceOf[Int] should equal(1000)
- // should not count the failed one
- assert(3 == (bactor !! LogSize).get.asInstanceOf[Int])
+ (bactor !! LogSize).get.asInstanceOf[Int] should equal(7)
+ (bactor !! Log(0, 7)).get.asInstanceOf[Iterable[String]].size should equal(7)
+ }
+ }
+
+ describe("unsuccessful debit") {
+ it("debit should fail") {
+ val bactor = actorOf[BankAccountActor]
+ bactor.start
+ val failer = actorOf[PersistentFailerActor]
+ failer.start
+ bactor !! Credit("a-123", 5000)
+ (bactor !! Balance("a-123")).get.asInstanceOf[Int] should equal(5000)
+ evaluating {
+ bactor !! Debit("a-123", 7000, failer)
+ } should produce [Exception]
+ (bactor !! Balance("a-123")).get.asInstanceOf[Int] should equal(5000)
+ (bactor !! LogSize).get.asInstanceOf[Int] should equal(3)
+ }
+ }
+
+ describe("unsuccessful multidebit") {
+ it("multidebit should fail") {
+ val bactor = actorOf[BankAccountActor]
+ bactor.start
+ val failer = actorOf[PersistentFailerActor]
+ failer.start
+ bactor !! Credit("a-123", 5000)
+ (bactor !! Balance("a-123")).get.asInstanceOf[Int] should equal(5000)
+ evaluating {
+ bactor !! MultiDebit("a-123", List(1000, 2000, 4000), failer)
+ } should produce [Exception]
+ (bactor !! Balance("a-123")).get.asInstanceOf[Int] should equal(5000)
+ (bactor !! LogSize).get.asInstanceOf[Int] should equal(3)
+ }
}
}
diff --git a/akka-persistence/akka-persistence-mongo/src/test/scala/MongoStorageSpec.scala b/akka-persistence/akka-persistence-mongo/src/test/scala/MongoStorageSpec.scala
index e518b28d66..e9576cc152 100644
--- a/akka-persistence/akka-persistence-mongo/src/test/scala/MongoStorageSpec.scala
+++ b/akka-persistence/akka-persistence-mongo/src/test/scala/MongoStorageSpec.scala
@@ -1,364 +1,158 @@
package se.scalablesolutions.akka.persistence.mongo
-import org.junit.{Test, Before}
-import org.junit.Assert._
-import org.scalatest.junit.JUnitSuite
-import _root_.dispatch.json._
-import _root_.dispatch.json.Js._
+import org.scalatest.Spec
+import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.junit.JUnitRunner
+import org.junit.runner.RunWith
import java.util.NoSuchElementException
-@scala.reflect.BeanInfo case class Foo(no: Int, name: String)
-class MongoStorageSpec extends JUnitSuite {
+@RunWith(classOf[JUnitRunner])
+class MongoStorageSpec extends
+ Spec with
+ ShouldMatchers with
+ BeforeAndAfterEach {
- val changeSetV = new scala.collection.mutable.ArrayBuffer[AnyRef]
- val changeSetM = new scala.collection.mutable.HashMap[AnyRef, AnyRef]
-
- @Before def initialize() = {
- MongoStorageBackend.coll.drop
+ override def beforeEach {
+ MongoStorageBackend.drop
}
- @Test
- def testVectorInsertForTransactionId = {
- changeSetV += "debasish" // string
- changeSetV += List(1, 2, 3) // Scala List
- changeSetV += List(100, 200)
- MongoStorageBackend.insertVectorStorageEntriesFor("U-A1", changeSetV.toList)
- assertEquals(
- 3,
- MongoStorageBackend.getVectorStorageSizeFor("U-A1"))
- changeSetV.clear
-
- // changeSetV should be reinitialized
- changeSetV += List(12, 23, 45)
- changeSetV += "maulindu"
- MongoStorageBackend.insertVectorStorageEntriesFor("U-A1", changeSetV.toList)
- assertEquals(
- 5,
- MongoStorageBackend.getVectorStorageSizeFor("U-A1"))
-
- // add more to the same changeSetV
- changeSetV += "ramanendu"
- changeSetV += Map(1 -> "dg", 2 -> "mc")
-
- // add for a diff transaction
- MongoStorageBackend.insertVectorStorageEntriesFor("U-A2", changeSetV.toList)
- assertEquals(
- 4,
- MongoStorageBackend.getVectorStorageSizeFor("U-A2"))
-
- // previous transaction change set should remain same
- assertEquals(
- 5,
- MongoStorageBackend.getVectorStorageSizeFor("U-A1"))
-
- // test single element entry
- MongoStorageBackend.insertVectorStorageEntryFor("U-A1", Map(1->1, 2->4, 3->9))
- assertEquals(
- 6,
- MongoStorageBackend.getVectorStorageSizeFor("U-A1"))
+ override def afterEach {
+ MongoStorageBackend.drop
}
- @Test
- def testVectorFetchForKeys = {
+ describe("persistent maps") {
+ it("should insert with single key and value") {
+ import MongoStorageBackend._
- // initially everything 0
- assertEquals(
- 0,
- MongoStorageBackend.getVectorStorageSizeFor("U-A2"))
-
- assertEquals(
- 0,
- MongoStorageBackend.getVectorStorageSizeFor("U-A1"))
-
- // get some stuff
- changeSetV += "debasish"
- changeSetV += List(BigDecimal(12), BigDecimal(13), BigDecimal(14))
- MongoStorageBackend.insertVectorStorageEntriesFor("U-A1", changeSetV.toList)
-
- assertEquals(
- 2,
- MongoStorageBackend.getVectorStorageSizeFor("U-A1"))
-
- val JsString(str) = MongoStorageBackend.getVectorStorageEntryFor("U-A1", 0).asInstanceOf[JsString]
- assertEquals("debasish", str)
-
- val l = MongoStorageBackend.getVectorStorageEntryFor("U-A1", 1).asInstanceOf[JsValue]
- val num_list = list ! num
- val num_list(l0) = l
- assertEquals(List(12, 13, 14), l0)
-
- changeSetV.clear
- changeSetV += Map(1->1, 2->4, 3->9)
- changeSetV += BigInt(2310)
- changeSetV += List(100, 200, 300)
- MongoStorageBackend.insertVectorStorageEntriesFor("U-A1", changeSetV.toList)
-
- assertEquals(
- 5,
- MongoStorageBackend.getVectorStorageSizeFor("U-A1"))
-
- val r =
- MongoStorageBackend.getVectorStorageRangeFor("U-A1", Some(1), None, 3)
-
- assertEquals(3, r.size)
- val lr = r(0).asInstanceOf[JsValue]
- val num_list(l1) = lr
- assertEquals(List(12, 13, 14), l1)
- }
-
- @Test
- def testVectorFetchForNonExistentKeys = {
- try {
- MongoStorageBackend.getVectorStorageEntryFor("U-A1", 1)
- fail("should throw an exception")
- } catch {case e: NoSuchElementException => {}}
-
- try {
- MongoStorageBackend.getVectorStorageRangeFor("U-A1", Some(2), None, 12)
- fail("should throw an exception")
- } catch {case e: NoSuchElementException => {}}
- }
-
- @Test
- def testVectorUpdateForTransactionId = {
- import MongoStorageBackend._
-
- changeSetV += "debasish" // string
- changeSetV += List(1, 2, 3) // Scala List
- changeSetV += List(100, 200)
-
- insertVectorStorageEntriesFor("U-A1", changeSetV.toList)
- assertEquals(3, getVectorStorageSizeFor("U-A1"))
- updateVectorStorageEntryFor("U-A1", 0, "maulindu")
- val JsString(str) = getVectorStorageEntryFor("U-A1", 0).asInstanceOf[JsString]
- assertEquals("maulindu", str)
-
- updateVectorStorageEntryFor("U-A1", 1, Map("1"->"dg", "2"->"mc"))
- val JsObject(m) = getVectorStorageEntryFor("U-A1", 1).asInstanceOf[JsObject]
- assertEquals(m.keySet.size, 2)
- }
-
- @Test
- def testMapInsertForTransactionId = {
- fillMap
-
- // add some more to changeSet
- changeSetM += "5" -> Foo(12, "dg")
- changeSetM += "6" -> java.util.Calendar.getInstance.getTime
-
- // insert all into Mongo
- MongoStorageBackend.insertMapStorageEntriesFor("U-M1", changeSetM.toList)
- assertEquals(
- 6,
- MongoStorageBackend.getMapStorageSizeFor("U-M1"))
-
- // individual insert api
- MongoStorageBackend.insertMapStorageEntryFor("U-M1", "7", "akka")
- MongoStorageBackend.insertMapStorageEntryFor("U-M1", "8", List(23, 25))
- assertEquals(
- 8,
- MongoStorageBackend.getMapStorageSizeFor("U-M1"))
-
- // add the same changeSet for another transaction
- MongoStorageBackend.insertMapStorageEntriesFor("U-M2", changeSetM.toList)
- assertEquals(
- 6,
- MongoStorageBackend.getMapStorageSizeFor("U-M2"))
-
- // the first transaction should remain the same
- assertEquals(
- 8,
- MongoStorageBackend.getMapStorageSizeFor("U-M1"))
- changeSetM.clear
- }
-
- @Test
- def testMapContents = {
- fillMap
- MongoStorageBackend.insertMapStorageEntriesFor("U-M1", changeSetM.toList)
- MongoStorageBackend.getMapStorageEntryFor("U-M1", "2") match {
- case Some(x) => {
- val JsString(str) = x.asInstanceOf[JsValue]
- assertEquals("peter", str)
- }
- case None => fail("should fetch peter")
- }
- MongoStorageBackend.getMapStorageEntryFor("U-M1", "4") match {
- case Some(x) => {
- val num_list = list ! num
- val num_list(l0) = x.asInstanceOf[JsValue]
- assertEquals(3, l0.size)
- }
- case None => fail("should fetch list")
- }
- MongoStorageBackend.getMapStorageEntryFor("U-M1", "3") match {
- case Some(x) => {
- val num_list = list ! num
- val num_list(l0) = x.asInstanceOf[JsValue]
- assertEquals(2, l0.size)
- }
- case None => fail("should fetch list")
+ insertMapStorageEntryFor("t1", "odersky".getBytes, "scala".getBytes)
+ insertMapStorageEntryFor("t1", "gosling".getBytes, "java".getBytes)
+ insertMapStorageEntryFor("t1", "stroustrup".getBytes, "c++".getBytes)
+ getMapStorageSizeFor("t1") should equal(3)
+ new String(getMapStorageEntryFor("t1", "odersky".getBytes).get) should equal("scala")
+ new String(getMapStorageEntryFor("t1", "gosling".getBytes).get) should equal("java")
+ new String(getMapStorageEntryFor("t1", "stroustrup".getBytes).get) should equal("c++")
+ getMapStorageEntryFor("t1", "torvalds".getBytes) should equal(None)
}
- // get the entire map
- val l: List[Tuple2[AnyRef, AnyRef]] =
- MongoStorageBackend.getMapStorageFor("U-M1")
+ it("should insert with multiple keys and values") {
+ import MongoStorageBackend._
- assertEquals(4, l.size)
- assertTrue(l.map(_._1).contains("1"))
- assertTrue(l.map(_._1).contains("2"))
- assertTrue(l.map(_._1).contains("3"))
- assertTrue(l.map(_._1).contains("4"))
+ val l = List(("stroustrup", "c++"), ("odersky", "scala"), ("gosling", "java"))
+ insertMapStorageEntriesFor("t1", l.map { case (k, v) => (k.getBytes, v.getBytes) })
+ getMapStorageSizeFor("t1") should equal(3)
+ new String(getMapStorageEntryFor("t1", "stroustrup".getBytes).get) should equal("c++")
+ new String(getMapStorageEntryFor("t1", "gosling".getBytes).get) should equal("java")
+ new String(getMapStorageEntryFor("t1", "odersky".getBytes).get) should equal("scala")
+ getMapStorageEntryFor("t1", "torvalds".getBytes) should equal(None)
- val JsString(str) = l.filter(_._1 == "2").head._2
- assertEquals(str, "peter")
+ getMapStorageEntryFor("t2", "torvalds".getBytes) should equal(None)
- // trying to fetch for a non-existent transaction will throw
- try {
- MongoStorageBackend.getMapStorageFor("U-M2")
- fail("should throw an exception")
- } catch {case e: NoSuchElementException => {}}
+ getMapStorageFor("t1").map { case (k, v) => (new String(k), new String(v)) } should equal (l)
- changeSetM.clear
- }
+ removeMapStorageFor("t1", "gosling".getBytes)
+ getMapStorageSizeFor("t1") should equal(2)
- @Test
- def testMapContentsByRange = {
- fillMap
- changeSetM += "5" -> Map(1 -> "dg", 2 -> "mc")
- MongoStorageBackend.insertMapStorageEntriesFor("U-M1", changeSetM.toList)
-
- // specify start and count
- val l: List[Tuple2[AnyRef, AnyRef]] =
- MongoStorageBackend.getMapStorageRangeFor(
- "U-M1", Some(Integer.valueOf(2)), None, 3)
-
- assertEquals(3, l.size)
- assertEquals("3", l(0)._1.asInstanceOf[String])
- val lst = l(0)._2.asInstanceOf[JsValue]
- val num_list = list ! num
- val num_list(l0) = lst
- assertEquals(List(100, 200), l0)
- assertEquals("4", l(1)._1.asInstanceOf[String])
- val ls = l(1)._2.asInstanceOf[JsValue]
- val num_list(l1) = ls
- assertEquals(List(10, 20, 30), l1)
-
- // specify start, finish and count where finish - start == count
- assertEquals(3,
- MongoStorageBackend.getMapStorageRangeFor(
- "U-M1", Some(Integer.valueOf(2)), Some(Integer.valueOf(5)), 3).size)
-
- // specify start, finish and count where finish - start > count
- assertEquals(3,
- MongoStorageBackend.getMapStorageRangeFor(
- "U-M1", Some(Integer.valueOf(2)), Some(Integer.valueOf(9)), 3).size)
-
- // do not specify start or finish
- assertEquals(3,
- MongoStorageBackend.getMapStorageRangeFor(
- "U-M1", None, None, 3).size)
-
- // specify finish and count
- assertEquals(3,
- MongoStorageBackend.getMapStorageRangeFor(
- "U-M1", None, Some(Integer.valueOf(3)), 3).size)
-
- // specify start, finish and count where finish < start
- assertEquals(3,
- MongoStorageBackend.getMapStorageRangeFor(
- "U-M1", Some(Integer.valueOf(2)), Some(Integer.valueOf(1)), 3).size)
-
- changeSetM.clear
- }
-
- @Test
- def testMapStorageRemove = {
- fillMap
- changeSetM += "5" -> Map(1 -> "dg", 2 -> "mc")
-
- MongoStorageBackend.insertMapStorageEntriesFor("U-M1", changeSetM.toList)
- assertEquals(5,
- MongoStorageBackend.getMapStorageSizeFor("U-M1"))
-
- // remove key "3"
- MongoStorageBackend.removeMapStorageFor("U-M1", "3")
- assertEquals(4,
- MongoStorageBackend.getMapStorageSizeFor("U-M1"))
-
- try {
- MongoStorageBackend.getMapStorageEntryFor("U-M1", "3")
- fail("should throw exception")
- } catch { case e => {}}
-
- // remove key "4"
- MongoStorageBackend.removeMapStorageFor("U-M1", "4")
- assertEquals(3,
- MongoStorageBackend.getMapStorageSizeFor("U-M1"))
-
- // remove key "2"
- MongoStorageBackend.removeMapStorageFor("U-M1", "2")
- assertEquals(2,
- MongoStorageBackend.getMapStorageSizeFor("U-M1"))
-
- // remove the whole stuff
- MongoStorageBackend.removeMapStorageFor("U-M1")
-
- try {
- MongoStorageBackend.getMapStorageFor("U-M1")
- fail("should throw exception")
- } catch { case e: NoSuchElementException => {}}
-
- changeSetM.clear
- }
-
- private def fillMap = {
- changeSetM += "1" -> "john"
- changeSetM += "2" -> "peter"
- changeSetM += "3" -> List(100, 200)
- changeSetM += "4" -> List(10, 20, 30)
- changeSetM
- }
-
- @Test
- def testRefStorage = {
- MongoStorageBackend.getRefStorageFor("U-R1") match {
- case None =>
- case Some(o) => fail("should be None")
+ removeMapStorageFor("t1")
+ getMapStorageSizeFor("t1") should equal(0)
}
- val m = Map("1"->1, "2"->4, "3"->9)
- MongoStorageBackend.insertRefStorageFor("U-R1", m)
- MongoStorageBackend.getRefStorageFor("U-R1") match {
- case None => fail("should not be empty")
- case Some(r) => {
- val a = r.asInstanceOf[JsValue]
- val m1 = Symbol("1") ? num
- val m2 = Symbol("2") ? num
- val m3 = Symbol("3") ? num
+ it("should do proper range queries") {
+ import MongoStorageBackend._
+ val l = List(
+ ("bjarne stroustrup", "c++"),
+ ("martin odersky", "scala"),
+ ("james gosling", "java"),
+ ("yukihiro matsumoto", "ruby"),
+ ("slava pestov", "factor"),
+ ("rich hickey", "clojure"),
+ ("ola bini", "ioke"),
+ ("dennis ritchie", "c"),
+ ("larry wall", "perl"),
+ ("guido van rossum", "python"),
+ ("james strachan", "groovy"))
+ insertMapStorageEntriesFor("t1", l.map { case (k, v) => (k.getBytes, v.getBytes) })
+ getMapStorageSizeFor("t1") should equal(l.size)
+ getMapStorageRangeFor("t1", None, None, 100).map { case (k, v) => (new String(k), new String(v)) } should equal(l.sortWith(_._1 < _._1))
+ getMapStorageRangeFor("t1", None, None, 5).map { case (k, v) => (new String(k), new String(v)) }.size should equal(5)
+ }
+ }
- val m1(n1) = a
- val m2(n2) = a
- val m3(n3) = a
+ describe("persistent vectors") {
+ it("should insert a single value") {
+ import MongoStorageBackend._
- assertEquals(n1, 1)
- assertEquals(n2, 4)
- assertEquals(n3, 9)
- }
+ insertVectorStorageEntryFor("t1", "martin odersky".getBytes)
+ insertVectorStorageEntryFor("t1", "james gosling".getBytes)
+ new String(getVectorStorageEntryFor("t1", 0)) should equal("james gosling")
+ new String(getVectorStorageEntryFor("t1", 1)) should equal("martin odersky")
}
- // insert another one
- // the previous one should be replaced
- val b = List("100", "jonas")
- MongoStorageBackend.insertRefStorageFor("U-R1", b)
- MongoStorageBackend.getRefStorageFor("U-R1") match {
- case None => fail("should not be empty")
- case Some(r) => {
- val a = r.asInstanceOf[JsValue]
- val str_lst = list ! str
- val str_lst(l) = a
- assertEquals(b, l)
- }
+ it("should insert multiple values") {
+ import MongoStorageBackend._
+
+ insertVectorStorageEntryFor("t1", "martin odersky".getBytes)
+ insertVectorStorageEntryFor("t1", "james gosling".getBytes)
+ insertVectorStorageEntriesFor("t1", List("ola bini".getBytes, "james strachan".getBytes, "dennis ritchie".getBytes))
+ new String(getVectorStorageEntryFor("t1", 0)) should equal("ola bini")
+ new String(getVectorStorageEntryFor("t1", 1)) should equal("james strachan")
+ new String(getVectorStorageEntryFor("t1", 2)) should equal("dennis ritchie")
+ new String(getVectorStorageEntryFor("t1", 3)) should equal("james gosling")
+ new String(getVectorStorageEntryFor("t1", 4)) should equal("martin odersky")
+ }
+
+ it("should fetch a range of values") {
+ import MongoStorageBackend._
+
+ insertVectorStorageEntryFor("t1", "martin odersky".getBytes)
+ insertVectorStorageEntryFor("t1", "james gosling".getBytes)
+ getVectorStorageSizeFor("t1") should equal(2)
+ insertVectorStorageEntriesFor("t1", List("ola bini".getBytes, "james strachan".getBytes, "dennis ritchie".getBytes))
+ getVectorStorageRangeFor("t1", None, None, 100).map(new String(_)) should equal(List("ola bini", "james strachan", "dennis ritchie", "james gosling", "martin odersky"))
+ getVectorStorageRangeFor("t1", Some(0), Some(5), 100).map(new String(_)) should equal(List("ola bini", "james strachan", "dennis ritchie", "james gosling", "martin odersky"))
+ getVectorStorageRangeFor("t1", Some(2), Some(5), 100).map(new String(_)) should equal(List("dennis ritchie", "james gosling", "martin odersky"))
+
+ getVectorStorageSizeFor("t1") should equal(5)
+ }
+
+ it("should insert and query complex structures") {
+ import MongoStorageBackend._
+ import sjson.json.DefaultProtocol._
+ import sjson.json.JsonSerialization._
+
+ // a list[AnyRef] should be added successfully
+ val l = List("ola bini".getBytes, tobinary(List(100, 200, 300)), tobinary(List(1, 2, 3)))
+
+ // for id = t1
+ insertVectorStorageEntriesFor("t1", l)
+ new String(getVectorStorageEntryFor("t1", 0)) should equal("ola bini")
+ frombinary[List[Int]](getVectorStorageEntryFor("t1", 1)) should equal(List(100, 200, 300))
+ frombinary[List[Int]](getVectorStorageEntryFor("t1", 2)) should equal(List(1, 2, 3))
+
+ getVectorStorageSizeFor("t1") should equal(3)
+
+ // some more for id = t1
+ val m = List(tobinary(Map(1 -> "dg", 2 -> "mc", 3 -> "nd")), tobinary(List("martin odersky", "james gosling")))
+ insertVectorStorageEntriesFor("t1", m)
+
+ // size should add up
+ getVectorStorageSizeFor("t1") should equal(5)
+
+ // now for a diff id
+ insertVectorStorageEntriesFor("t2", l)
+ getVectorStorageSizeFor("t2") should equal(3)
+ }
+ }
+
+ describe("persistent refs") {
+ it("should insert a ref") {
+ import MongoStorageBackend._
+
+ insertRefStorageFor("t1", "martin odersky".getBytes)
+ new String(getRefStorageFor("t1").get) should equal("martin odersky")
+ insertRefStorageFor("t1", "james gosling".getBytes)
+ new String(getRefStorageFor("t1").get) should equal("james gosling")
+ getRefStorageFor("t2") should equal(None)
}
}
}
diff --git a/akka-persistence/akka-persistence-mongo/src/test/scala/MongoTicket343Spec.scala b/akka-persistence/akka-persistence-mongo/src/test/scala/MongoTicket343Spec.scala
new file mode 100644
index 0000000000..3b160c8c50
--- /dev/null
+++ b/akka-persistence/akka-persistence-mongo/src/test/scala/MongoTicket343Spec.scala
@@ -0,0 +1,347 @@
+package se.scalablesolutions.akka.persistence.mongo
+
+import org.scalatest.Spec
+import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
+import org.scalatest.junit.JUnitRunner
+import org.junit.runner.RunWith
+
+import se.scalablesolutions.akka.actor.{Actor, ActorRef}
+import se.scalablesolutions.akka.config.OneForOneStrategy
+import Actor._
+import se.scalablesolutions.akka.stm.global._
+import se.scalablesolutions.akka.config.ScalaConfig._
+import se.scalablesolutions.akka.util.Logging
+
+import MongoStorageBackend._
+
+case class GET(k: String)
+case class SET(k: String, v: String)
+case class REM(k: String)
+case class CONTAINS(k: String)
+case object MAP_SIZE
+case class MSET(kvs: List[(String, String)])
+case class REMOVE_AFTER_PUT(kvsToAdd: List[(String, String)], ksToRem: List[String])
+case class CLEAR_AFTER_PUT(kvsToAdd: List[(String, String)])
+case class PUT_WITH_SLICE(kvsToAdd: List[(String, String)], start: String, cnt: Int)
+case class PUT_REM_WITH_SLICE(kvsToAdd: List[(String, String)], ksToRem: List[String], start: String, cnt: Int)
+
+case class VADD(v: String)
+case class VUPD(i: Int, v: String)
+case class VUPD_AND_ABORT(i: Int, v: String)
+case class VGET(i: Int)
+case object VSIZE
+case class VGET_AFTER_VADD(vsToAdd: List[String], isToFetch: List[Int])
+case class VADD_WITH_SLICE(vsToAdd: List[String], start: Int, cnt: Int)
+
+object Storage {
+ class MongoSampleMapStorage extends Actor {
+ self.lifeCycle = Some(LifeCycle(Permanent))
+ val FOO_MAP = "akka.sample.map"
+
+ private var fooMap = atomic { MongoStorage.getMap(FOO_MAP) }
+
+ def receive = {
+ case SET(k, v) =>
+ atomic {
+ fooMap += (k.getBytes, v.getBytes)
+ }
+ self.reply((k, v))
+
+ case GET(k) =>
+ val v = atomic {
+ fooMap.get(k.getBytes).map(new String(_)).getOrElse(k + " Not found")
+ }
+ self.reply(v)
+
+ case REM(k) =>
+ val v = atomic {
+ fooMap -= k.getBytes
+ }
+ self.reply(k)
+
+ case CONTAINS(k) =>
+ val v = atomic {
+ fooMap contains k.getBytes
+ }
+ self.reply(v)
+
+ case MAP_SIZE =>
+ val v = atomic {
+ fooMap.size
+ }
+ self.reply(v)
+
+ case MSET(kvs) => atomic {
+ kvs.foreach {kv => fooMap += (kv._1.getBytes, kv._2.getBytes) }
+ }
+ self.reply(kvs.size)
+
+ case REMOVE_AFTER_PUT(kvs2add, ks2rem) => atomic {
+ kvs2add.foreach {kv =>
+ fooMap += (kv._1.getBytes, kv._2.getBytes)
+ }
+
+ ks2rem.foreach {k =>
+ fooMap -= k.getBytes
+ }}
+ self.reply(fooMap.size)
+
+ case CLEAR_AFTER_PUT(kvs2add) => atomic {
+ kvs2add.foreach {kv =>
+ fooMap += (kv._1.getBytes, kv._2.getBytes)
+ }
+ fooMap.clear
+ }
+ self.reply(true)
+
+ case PUT_WITH_SLICE(kvs2add, from, cnt) =>
+ val v = atomic {
+ kvs2add.foreach {kv =>
+ fooMap += (kv._1.getBytes, kv._2.getBytes)
+ }
+ fooMap.slice(Some(from.getBytes), cnt)
+ }
+ self.reply(v: List[(Array[Byte], Array[Byte])])
+
+ case PUT_REM_WITH_SLICE(kvs2add, ks2rem, from, cnt) =>
+ val v = atomic {
+ kvs2add.foreach {kv =>
+ fooMap += (kv._1.getBytes, kv._2.getBytes)
+ }
+ ks2rem.foreach {k =>
+ fooMap -= k.getBytes
+ }
+ fooMap.slice(Some(from.getBytes), cnt)
+ }
+ self.reply(v: List[(Array[Byte], Array[Byte])])
+ }
+ }
+
+ class MongoSampleVectorStorage extends Actor {
+ self.lifeCycle = Some(LifeCycle(Permanent))
+ val FOO_VECTOR = "akka.sample.vector"
+
+ private var fooVector = atomic { MongoStorage.getVector(FOO_VECTOR) }
+
+ def receive = {
+ case VADD(v) =>
+ val size =
+ atomic {
+ fooVector + v.getBytes
+ fooVector length
+ }
+ self.reply(size)
+
+ case VGET(index) =>
+ val ind =
+ atomic {
+ fooVector get index
+ }
+ self.reply(ind)
+
+ case VGET_AFTER_VADD(vs, is) =>
+ val els =
+ atomic {
+ vs.foreach(fooVector + _.getBytes)
+ (is.foldRight(List[Array[Byte]]())(fooVector.get(_) :: _)).map(new String(_))
+ }
+ self.reply(els)
+
+ case VUPD_AND_ABORT(index, value) =>
+ val l =
+ atomic {
+ fooVector.update(index, value.getBytes)
+ // force fail
+ fooVector get 100
+ }
+ self.reply(index)
+
+ case VADD_WITH_SLICE(vs, s, c) =>
+ val l =
+ atomic {
+ vs.foreach(fooVector + _.getBytes)
+ fooVector.slice(Some(s), None, c)
+ }
+ self.reply(l.map(new String(_)))
+ }
+ }
+}
+
+import Storage._
+
+@RunWith(classOf[JUnitRunner])
+class MongoTicket343Spec extends
+ Spec with
+ ShouldMatchers with
+ BeforeAndAfterAll with
+ BeforeAndAfterEach {
+
+
+ override def beforeAll {
+ MongoStorageBackend.drop
+ println("** destroyed database")
+ }
+
+ override def beforeEach {
+ MongoStorageBackend.drop
+ println("** destroyed database")
+ }
+
+ override def afterEach {
+ MongoStorageBackend.drop
+ println("** destroyed database")
+ }
+
+ describe("Ticket 343 Issue #1") {
+ it("remove after put should work within the same transaction") {
+ val proc = actorOf[MongoSampleMapStorage]
+ proc.start
+
+ (proc !! SET("debasish", "anshinsoft")).getOrElse("Set failed") should equal(("debasish", "anshinsoft"))
+ (proc !! GET("debasish")).getOrElse("Get failed") should equal("anshinsoft")
+ (proc !! MAP_SIZE).getOrElse("Size failed") should equal(1)
+
+ (proc !! MSET(List(("dg", "1"), ("mc", "2"), ("nd", "3")))).getOrElse("Mset failed") should equal(3)
+
+ (proc !! GET("dg")).getOrElse("Get failed") should equal("1")
+ (proc !! GET("mc")).getOrElse("Get failed") should equal("2")
+ (proc !! GET("nd")).getOrElse("Get failed") should equal("3")
+
+ (proc !! MAP_SIZE).getOrElse("Size failed") should equal(4)
+
+ val add = List(("a", "1"), ("b", "2"), ("c", "3"))
+ val rem = List("a", "debasish")
+ (proc !! REMOVE_AFTER_PUT(add, rem)).getOrElse("REMOVE_AFTER_PUT failed") should equal(5)
+
+ (proc !! GET("debasish")).getOrElse("debasish not found") should equal("debasish Not found")
+ (proc !! GET("a")).getOrElse("a not found") should equal("a Not found")
+
+ (proc !! GET("b")).getOrElse("b not found") should equal("2")
+
+ (proc !! CONTAINS("b")).getOrElse("b not found") should equal(true)
+ (proc !! CONTAINS("debasish")).getOrElse("debasish not found") should equal(false)
+ (proc !! MAP_SIZE).getOrElse("Size failed") should equal(5)
+ proc.stop
+ }
+ }
+
+ describe("Ticket 343 Issue #2") {
+ it("clear after put should work within the same transaction") {
+ val proc = actorOf[MongoSampleMapStorage]
+ proc.start
+
+ (proc !! SET("debasish", "anshinsoft")).getOrElse("Set failed") should equal(("debasish", "anshinsoft"))
+ (proc !! GET("debasish")).getOrElse("Get failed") should equal("anshinsoft")
+ (proc !! MAP_SIZE).getOrElse("Size failed") should equal(1)
+
+ val add = List(("a", "1"), ("b", "2"), ("c", "3"))
+ (proc !! CLEAR_AFTER_PUT(add)).getOrElse("CLEAR_AFTER_PUT failed") should equal(true)
+
+ (proc !! MAP_SIZE).getOrElse("Size failed") should equal(1)
+ proc.stop
+ }
+ }
+
+ describe("Ticket 343 Issue #3") {
+ it("map size should change after the transaction") {
+ val proc = actorOf[MongoSampleMapStorage]
+ proc.start
+
+ (proc !! SET("debasish", "anshinsoft")).getOrElse("Set failed") should equal(("debasish", "anshinsoft"))
+ (proc !! GET("debasish")).getOrElse("Get failed") should equal("anshinsoft")
+ (proc !! MAP_SIZE).getOrElse("Size failed") should equal(1)
+
+ (proc !! MSET(List(("dg", "1"), ("mc", "2"), ("nd", "3")))).getOrElse("Mset failed") should equal(3)
+ (proc !! MAP_SIZE).getOrElse("Size failed") should equal(4)
+
+ (proc !! GET("dg")).getOrElse("Get failed") should equal("1")
+ (proc !! GET("mc")).getOrElse("Get failed") should equal("2")
+ (proc !! GET("nd")).getOrElse("Get failed") should equal("3")
+ proc.stop
+ }
+ }
+
+ describe("slice test") {
+ it("should pass") {
+ val proc = actorOf[MongoSampleMapStorage]
+ proc.start
+
+ (proc !! SET("debasish", "anshinsoft")).getOrElse("Set failed") should equal(("debasish", "anshinsoft"))
+ (proc !! GET("debasish")).getOrElse("Get failed") should equal("anshinsoft")
+ // (proc !! MAP_SIZE).getOrElse("Size failed") should equal(1)
+
+ (proc !! MSET(List(("dg", "1"), ("mc", "2"), ("nd", "3")))).getOrElse("Mset failed") should equal(3)
+ (proc !! MAP_SIZE).getOrElse("Size failed") should equal(4)
+
+ (proc !! PUT_WITH_SLICE(List(("ec", "1"), ("tb", "2"), ("mc", "10")), "dg", 3)).get.asInstanceOf[List[(Array[Byte], Array[Byte])]].map { case (k, v) => (new String(k), new String(v)) } should equal(List(("dg", "1"), ("ec", "1"), ("mc", "10")))
+
+ (proc !! PUT_REM_WITH_SLICE(List(("fc", "1"), ("gb", "2"), ("xy", "10")), List("tb", "fc"), "dg", 5)).get.asInstanceOf[List[(Array[Byte], Array[Byte])]].map { case (k, v) => (new String(k), new String(v)) } should equal(List(("dg", "1"), ("ec", "1"), ("gb", "2"), ("mc", "10"), ("nd", "3")))
+ proc.stop
+ }
+ }
+
+ describe("Ticket 343 Issue #4") {
+ it("vector get should not ignore elements that were in vector before transaction") {
+
+ val proc = actorOf[MongoSampleVectorStorage]
+ proc.start
+
+ // add 4 elements in separate transactions
+ (proc !! VADD("debasish")).getOrElse("VADD failed") should equal(1)
+ (proc !! VADD("maulindu")).getOrElse("VADD failed") should equal(2)
+ (proc !! VADD("ramanendu")).getOrElse("VADD failed") should equal(3)
+ (proc !! VADD("nilanjan")).getOrElse("VADD failed") should equal(4)
+
+ new String((proc !! VGET(0)).get.asInstanceOf[Array[Byte]] ) should equal("nilanjan")
+ new String((proc !! VGET(1)).get.asInstanceOf[Array[Byte]] ) should equal("ramanendu")
+ new String((proc !! VGET(2)).get.asInstanceOf[Array[Byte]] ) should equal("maulindu")
+ new String((proc !! VGET(3)).get.asInstanceOf[Array[Byte]] ) should equal("debasish")
+
+ // now add 3 more and do gets in the same transaction
+ (proc !! VGET_AFTER_VADD(List("a", "b", "c"), List(0, 2, 4))).get.asInstanceOf[List[String]] should equal(List("c", "a", "ramanendu"))
+ proc.stop
+ }
+ }
+
+ describe("Ticket 343 Issue #6") {
+ it("vector update should not ignore transaction") {
+ val proc = actorOf[MongoSampleVectorStorage]
+ proc.start
+
+ // add 4 elements in separate transactions
+ (proc !! VADD("debasish")).getOrElse("VADD failed") should equal(1)
+ (proc !! VADD("maulindu")).getOrElse("VADD failed") should equal(2)
+ (proc !! VADD("ramanendu")).getOrElse("VADD failed") should equal(3)
+ (proc !! VADD("nilanjan")).getOrElse("VADD failed") should equal(4)
+
+ evaluating {
+ (proc !! VUPD_AND_ABORT(0, "virat")).getOrElse("VUPD_AND_ABORT failed")
+ } should produce [Exception]
+
+ // update aborts and hence values will remain unchanged
+ new String((proc !! VGET(0)).get.asInstanceOf[Array[Byte]] ) should equal("nilanjan")
+ proc.stop
+ }
+ }
+
+ describe("Ticket 343 Issue #5") {
+ it("vector slice() should not ignore elements added in current transaction") {
+ val proc = actorOf[MongoSampleVectorStorage]
+ proc.start
+
+ // add 4 elements in separate transactions
+ (proc !! VADD("debasish")).getOrElse("VADD failed") should equal(1)
+ (proc !! VADD("maulindu")).getOrElse("VADD failed") should equal(2)
+ (proc !! VADD("ramanendu")).getOrElse("VADD failed") should equal(3)
+ (proc !! VADD("nilanjan")).getOrElse("VADD failed") should equal(4)
+
+ // slice with no new elements added in current transaction
+ (proc !! VADD_WITH_SLICE(List(), 2, 2)).getOrElse("VADD_WITH_SLICE failed") should equal(Vector("maulindu", "debasish"))
+
+ // slice with new elements added in current transaction
+ (proc !! VADD_WITH_SLICE(List("a", "b", "c", "d"), 2, 2)).getOrElse("VADD_WITH_SLICE failed") should equal(Vector("b", "a"))
+ proc.stop
+ }
+ }
+}
diff --git a/akka-persistence/akka-persistence-redis/src/main/scala/RedisStorage.scala b/akka-persistence/akka-persistence-redis/src/main/scala/RedisStorage.scala
index c92761beea..1eca775567 100644
--- a/akka-persistence/akka-persistence-redis/src/main/scala/RedisStorage.scala
+++ b/akka-persistence/akka-persistence-redis/src/main/scala/RedisStorage.scala
@@ -36,7 +36,7 @@ object RedisStorage extends Storage {
*
* @author Debasish Ghosh
*/
-class RedisPersistentMap(id: String) extends PersistentMap[Array[Byte], Array[Byte]] {
+class RedisPersistentMap(id: String) extends PersistentMapBinary {
val uuid = id
val storage = RedisStorageBackend
}
diff --git a/akka-persistence/akka-persistence-redis/src/main/scala/RedisStorageBackend.scala b/akka-persistence/akka-persistence-redis/src/main/scala/RedisStorageBackend.scala
index 9200393ef9..61595ec21f 100644
--- a/akka-persistence/akka-persistence-redis/src/main/scala/RedisStorageBackend.scala
+++ b/akka-persistence/akka-persistence-redis/src/main/scala/RedisStorageBackend.scala
@@ -96,12 +96,12 @@ private [akka] object RedisStorageBackend extends
* both parts of the key need to be based64 encoded since there can be spaces within each of them
*/
private [this] def makeRedisKey(name: String, key: Array[Byte]): String = withErrorHandling {
- "%s:%s".format(name, byteArrayToString(key))
+ "%s:%s".format(name, new String(key))
}
private [this] def makeKeyFromRedisKey(redisKey: String) = withErrorHandling {
val nk = redisKey.split(':')
- (nk(0), stringToByteArray(nk(1)))
+ (nk(0), nk(1).getBytes)
}
private [this] def mset(entries: List[(String, String)]): Unit = withErrorHandling {
@@ -124,27 +124,22 @@ private [akka] object RedisStorageBackend extends
}
def getMapStorageEntryFor(name: String, key: Array[Byte]): Option[Array[Byte]] = withErrorHandling {
- db.get(makeRedisKey(name, key)) match {
- case None =>
- throw new NoSuchElementException(new String(key) + " not present")
- case Some(s) => Some(stringToByteArray(s))
+ db.get(makeRedisKey(name, key))
+ .map(stringToByteArray(_))
+ .orElse(throw new NoSuchElementException(new String(key) + " not present"))
}
- }
def getMapStorageSizeFor(name: String): Int = withErrorHandling {
- db.keys("%s:*".format(name)) match {
- case None => 0
- case Some(keys) => keys.length
- }
+ db.keys("%s:*".format(name)).map(_.length).getOrElse(0)
}
def getMapStorageFor(name: String): List[(Array[Byte], Array[Byte])] = withErrorHandling {
- db.keys("%s:*".format(name)) match {
- case None =>
- throw new NoSuchElementException(name + " not present")
- case Some(keys) =>
+ db.keys("%s:*".format(name))
+ .map { keys =>
keys.map(key => (makeKeyFromRedisKey(key.get)._2, stringToByteArray(db.get(key.get).get))).toList
- }
+ }.getOrElse {
+ throw new NoSuchElementException(name + " not present")
+ }
}
def getMapStorageRangeFor(name: String, start: Option[Array[Byte]],
@@ -207,12 +202,11 @@ private [akka] object RedisStorageBackend extends
}
def getVectorStorageEntryFor(name: String, index: Int): Array[Byte] = withErrorHandling {
- db.lindex(name, index) match {
- case None =>
+ db.lindex(name, index)
+ .map(stringToByteArray(_))
+ .getOrElse {
throw new NoSuchElementException(name + " does not have element at " + index)
- case Some(e) =>
- stringToByteArray(e)
- }
+ }
}
/**
@@ -252,11 +246,11 @@ private [akka] object RedisStorageBackend extends
}
def getRefStorageFor(name: String): Option[Array[Byte]] = withErrorHandling {
- db.get(name) match {
- case None =>
+ db.get(name)
+ .map(stringToByteArray(_))
+ .orElse {
throw new NoSuchElementException(name + " not present")
- case Some(s) => Some(stringToByteArray(s))
- }
+ }
}
// add to the end of the queue
@@ -266,11 +260,11 @@ private [akka] object RedisStorageBackend extends
// pop from the front of the queue
def dequeue(name: String): Option[Array[Byte]] = withErrorHandling {
- db.lpop(name) match {
- case None =>
+ db.lpop(name)
+ .map(stringToByteArray(_))
+ .orElse {
throw new NoSuchElementException(name + " not present")
- case Some(s) => Some(stringToByteArray(s))
- }
+ }
}
// get the size of the queue
@@ -302,26 +296,19 @@ private [akka] object RedisStorageBackend extends
// completely delete the queue
def remove(name: String): Boolean = withErrorHandling {
- db.del(name) match {
- case Some(1) => true
- case _ => false
- }
+ db.del(name).map { case 1 => true }.getOrElse(false)
}
// add item to sorted set identified by name
def zadd(name: String, zscore: String, item: Array[Byte]): Boolean = withErrorHandling {
- db.zadd(name, zscore, byteArrayToString(item)) match {
- case Some(1) => true
- case _ => false
- }
+ db.zadd(name, zscore, byteArrayToString(item))
+ .map { case 1 => true }.getOrElse(false)
}
// remove item from sorted set identified by name
def zrem(name: String, item: Array[Byte]): Boolean = withErrorHandling {
- db.zrem(name, byteArrayToString(item)) match {
- case Some(1) => true
- case _ => false
- }
+ db.zrem(name, byteArrayToString(item))
+ .map { case 1 => true }.getOrElse(false)
}
// cardinality of the set identified by name
@@ -330,29 +317,23 @@ private [akka] object RedisStorageBackend extends
}
def zscore(name: String, item: Array[Byte]): Option[Float] = withErrorHandling {
- db.zscore(name, byteArrayToString(item)) match {
- case Some(s) => Some(s.toFloat)
- case None => None
- }
+ db.zscore(name, byteArrayToString(item)).map(_.toFloat)
}
def zrange(name: String, start: Int, end: Int): List[Array[Byte]] = withErrorHandling {
- db.zrange(name, start.toString, end.toString, RedisClient.ASC, false) match {
- case None =>
+ db.zrange(name, start.toString, end.toString, RedisClient.ASC, false)
+ .map(_.map(e => stringToByteArray(e.get)))
+ .getOrElse {
throw new NoSuchElementException(name + " not present")
- case Some(s) =>
- s.map(e => stringToByteArray(e.get))
- }
+ }
}
def zrangeWithScore(name: String, start: Int, end: Int): List[(Array[Byte], Float)] = withErrorHandling {
- db.zrangeWithScore(
- name, start.toString, end.toString, RedisClient.ASC) match {
- case None =>
- throw new NoSuchElementException(name + " not present")
- case Some(l) =>
- l.map{ case (elem, score) => (stringToByteArray(elem.get), score.get.toFloat) }
- }
+ db.zrangeWithScore(name, start.toString, end.toString, RedisClient.ASC)
+ .map(_.map { case (elem, score) => (stringToByteArray(elem.get), score.get.toFloat) })
+ .getOrElse {
+ throw new NoSuchElementException(name + " not present")
+ }
}
def flushDB = withErrorHandling(db.flushdb)
diff --git a/akka-persistence/akka-persistence-redis/src/test/scala/RedisTicket343Spec.scala b/akka-persistence/akka-persistence-redis/src/test/scala/RedisTicket343Spec.scala
new file mode 100644
index 0000000000..de236b9a5a
--- /dev/null
+++ b/akka-persistence/akka-persistence-redis/src/test/scala/RedisTicket343Spec.scala
@@ -0,0 +1,351 @@
+package se.scalablesolutions.akka.persistence.redis
+
+import org.scalatest.Spec
+import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
+import org.scalatest.junit.JUnitRunner
+import org.junit.runner.RunWith
+
+import se.scalablesolutions.akka.actor.{Actor}
+import se.scalablesolutions.akka.config.OneForOneStrategy
+import Actor._
+import se.scalablesolutions.akka.persistence.common.PersistentVector
+import se.scalablesolutions.akka.stm.global._
+import se.scalablesolutions.akka.config.ScalaConfig._
+import se.scalablesolutions.akka.util.Logging
+
+import RedisStorageBackend._
+
+case class GET(k: String)
+case class SET(k: String, v: String)
+case class REM(k: String)
+case class CONTAINS(k: String)
+case object MAP_SIZE
+case class MSET(kvs: List[(String, String)])
+case class REMOVE_AFTER_PUT(kvsToAdd: List[(String, String)], ksToRem: List[String])
+case class CLEAR_AFTER_PUT(kvsToAdd: List[(String, String)])
+case class PUT_WITH_SLICE(kvsToAdd: List[(String, String)], start: String, cnt: Int)
+case class PUT_REM_WITH_SLICE(kvsToAdd: List[(String, String)], ksToRem: List[String], start: String, cnt: Int)
+
+case class VADD(v: String)
+case class VUPD(i: Int, v: String)
+case class VUPD_AND_ABORT(i: Int, v: String)
+case class VGET(i: Int)
+case object VSIZE
+case class VGET_AFTER_VADD(vsToAdd: List[String], isToFetch: List[Int])
+case class VADD_WITH_SLICE(vsToAdd: List[String], start: Int, cnt: Int)
+
+object Storage {
+ class RedisSampleMapStorage extends Actor {
+ self.lifeCycle = Some(LifeCycle(Permanent))
+ val FOO_MAP = "akka.sample.map"
+
+ private var fooMap = atomic { RedisStorage.getMap(FOO_MAP) }
+
+ def receive = {
+ case SET(k, v) =>
+ atomic {
+ fooMap += (k.getBytes, v.getBytes)
+ }
+ self.reply((k, v))
+
+ case GET(k) =>
+ val v = atomic {
+ fooMap.get(k.getBytes)
+ }
+ self.reply(v.collect {case byte => new String(byte)}.getOrElse(k + " Not found"))
+
+ case REM(k) =>
+ val v = atomic {
+ fooMap -= k.getBytes
+ }
+ self.reply(k)
+
+ case CONTAINS(k) =>
+ val v = atomic {
+ fooMap contains k.getBytes
+ }
+ self.reply(v)
+
+ case MAP_SIZE =>
+ val v = atomic {
+ fooMap.size
+ }
+ self.reply(v)
+
+ case MSET(kvs) =>
+ atomic {
+ kvs.foreach {kv =>
+ fooMap += (kv._1.getBytes, kv._2.getBytes)
+ }
+ }
+ self.reply(kvs.size)
+
+ case REMOVE_AFTER_PUT(kvs2add, ks2rem) =>
+ val v =
+ atomic {
+ kvs2add.foreach {kv =>
+ fooMap += (kv._1.getBytes, kv._2.getBytes)
+ }
+
+ ks2rem.foreach {k =>
+ fooMap -= k.getBytes
+ }
+ fooMap.size
+ }
+ self.reply(v)
+
+ case CLEAR_AFTER_PUT(kvs2add) =>
+ atomic {
+ kvs2add.foreach {kv =>
+ fooMap += (kv._1.getBytes, kv._2.getBytes)
+ }
+ fooMap.clear
+ }
+ self.reply(true)
+
+ case PUT_WITH_SLICE(kvs2add, from, cnt) =>
+ val v =
+ atomic {
+ kvs2add.foreach {kv =>
+ fooMap += (kv._1.getBytes, kv._2.getBytes)
+ }
+ fooMap.slice(Some(from.getBytes), cnt)
+ }
+ self.reply(v: List[(Array[Byte], Array[Byte])])
+
+ case PUT_REM_WITH_SLICE(kvs2add, ks2rem, from, cnt) =>
+ val v =
+ atomic {
+ kvs2add.foreach {kv =>
+ fooMap += (kv._1.getBytes, kv._2.getBytes)
+ }
+ ks2rem.foreach {k =>
+ fooMap -= k.getBytes
+ }
+ fooMap.slice(Some(from.getBytes), cnt)
+ }
+ self.reply(v: List[(Array[Byte], Array[Byte])])
+ }
+ }
+
+ class RedisSampleVectorStorage extends Actor {
+ self.lifeCycle = Some(LifeCycle(Permanent))
+ val FOO_VECTOR = "akka.sample.vector"
+
+ private var fooVector = atomic { RedisStorage.getVector(FOO_VECTOR) }
+
+ def receive = {
+ case VADD(v) =>
+ val size =
+ atomic {
+ fooVector + v.getBytes
+ fooVector length
+ }
+ self.reply(size)
+
+ case VGET(index) =>
+ val ind =
+ atomic {
+ fooVector get index
+ }
+ self.reply(ind)
+
+ case VGET_AFTER_VADD(vs, is) =>
+ val els =
+ atomic {
+ vs.foreach(fooVector + _.getBytes)
+ (is.foldRight(List[Array[Byte]]())(fooVector.get(_) :: _)).map(new String(_))
+ }
+ self.reply(els)
+
+ case VUPD_AND_ABORT(index, value) =>
+ val l =
+ atomic {
+ fooVector.update(index, value.getBytes)
+ // force fail
+ fooVector get 100
+ }
+ self.reply(index)
+
+ case VADD_WITH_SLICE(vs, s, c) =>
+ val l =
+ atomic {
+ vs.foreach(fooVector + _.getBytes)
+ fooVector.slice(Some(s), None, c)
+ }
+ self.reply(l.map(new String(_)))
+ }
+ }
+}
+
+import Storage._
+
+@RunWith(classOf[JUnitRunner])
+class RedisTicket343Spec extends
+ Spec with
+ ShouldMatchers with
+ BeforeAndAfterAll with
+ BeforeAndAfterEach {
+
+ override def beforeAll {
+ flushDB
+ println("** destroyed database")
+ }
+
+ override def afterEach {
+ flushDB
+ println("** destroyed database")
+ }
+
+ describe("Ticket 343 Issue #1") {
+ it("remove after put should work within the same transaction") {
+ val proc = actorOf[RedisSampleMapStorage]
+ proc.start
+
+ (proc !! SET("debasish", "anshinsoft")).getOrElse("Set failed") should equal(("debasish", "anshinsoft"))
+ (proc !! GET("debasish")).getOrElse("Get failed") should equal("anshinsoft")
+ (proc !! MAP_SIZE).getOrElse("Size failed") should equal(1)
+
+ (proc !! MSET(List(("dg", "1"), ("mc", "2"), ("nd", "3")))).getOrElse("Mset failed") should equal(3)
+
+ (proc !! GET("dg")).getOrElse("Get failed") should equal("1")
+ (proc !! GET("mc")).getOrElse("Get failed") should equal("2")
+ (proc !! GET("nd")).getOrElse("Get failed") should equal("3")
+
+ (proc !! MAP_SIZE).getOrElse("Size failed") should equal(4)
+
+ val add = List(("a", "1"), ("b", "2"), ("c", "3"))
+ val rem = List("a", "debasish")
+ (proc !! REMOVE_AFTER_PUT(add, rem)).getOrElse("REMOVE_AFTER_PUT failed") should equal(5)
+
+ (proc !! GET("debasish")).getOrElse("debasish not found") should equal("debasish Not found")
+ (proc !! GET("a")).getOrElse("a not found") should equal("a Not found")
+
+ (proc !! GET("b")).getOrElse("b not found") should equal("2")
+
+ (proc !! CONTAINS("b")).getOrElse("b not found") should equal(true)
+ (proc !! CONTAINS("debasish")).getOrElse("debasish not found") should equal(false)
+ (proc !! MAP_SIZE).getOrElse("Size failed") should equal(5)
+ proc.stop
+ }
+ }
+
+ describe("Ticket 343 Issue #2") {
+ it("clear after put should work within the same transaction") {
+ val proc = actorOf[RedisSampleMapStorage]
+ proc.start
+
+ (proc !! SET("debasish", "anshinsoft")).getOrElse("Set failed") should equal(("debasish", "anshinsoft"))
+ (proc !! GET("debasish")).getOrElse("Get failed") should equal("anshinsoft")
+ (proc !! MAP_SIZE).getOrElse("Size failed") should equal(1)
+
+ val add = List(("a", "1"), ("b", "2"), ("c", "3"))
+ (proc !! CLEAR_AFTER_PUT(add)).getOrElse("CLEAR_AFTER_PUT failed") should equal(true)
+
+ (proc !! MAP_SIZE).getOrElse("Size failed") should equal(1)
+ proc.stop
+ }
+ }
+
+ describe("Ticket 343 Issue #3") {
+ it("map size should change after the transaction") {
+ val proc = actorOf[RedisSampleMapStorage]
+ proc.start
+
+ (proc !! SET("debasish", "anshinsoft")).getOrElse("Set failed") should equal(("debasish", "anshinsoft"))
+ (proc !! GET("debasish")).getOrElse("Get failed") should equal("anshinsoft")
+ (proc !! MAP_SIZE).getOrElse("Size failed") should equal(1)
+
+ (proc !! MSET(List(("dg", "1"), ("mc", "2"), ("nd", "3")))).getOrElse("Mset failed") should equal(3)
+ (proc !! MAP_SIZE).getOrElse("Size failed") should equal(4)
+
+ (proc !! GET("dg")).getOrElse("Get failed") should equal("1")
+ (proc !! GET("mc")).getOrElse("Get failed") should equal("2")
+ (proc !! GET("nd")).getOrElse("Get failed") should equal("3")
+ proc.stop
+ }
+ }
+
+ describe("slice test") {
+ it("should pass") {
+ val proc = actorOf[RedisSampleMapStorage]
+ proc.start
+
+ (proc !! SET("debasish", "anshinsoft")).getOrElse("Set failed") should equal(("debasish", "anshinsoft"))
+ (proc !! GET("debasish")).getOrElse("Get failed") should equal("anshinsoft")
+ (proc !! MAP_SIZE).getOrElse("Size failed") should equal(1)
+
+ (proc !! MSET(List(("dg", "1"), ("mc", "2"), ("nd", "3")))).getOrElse("Mset failed") should equal(3)
+ (proc !! MAP_SIZE).getOrElse("Size failed") should equal(4)
+
+ (proc !! PUT_WITH_SLICE(List(("ec", "1"), ("tb", "2"), ("mc", "10")), "dg", 3)).get.asInstanceOf[List[(Array[Byte], Array[Byte])]].map { case (k, v) => (new String(k), new String(v)) } should equal(List(("dg", "1"), ("ec", "1"), ("mc", "10")))
+
+ (proc !! PUT_REM_WITH_SLICE(List(("fc", "1"), ("gb", "2"), ("xy", "10")), List("tb", "fc"), "dg", 5)).get.asInstanceOf[List[(Array[Byte], Array[Byte])]].map { case (k, v) => (new String(k), new String(v)) } should equal(List(("dg", "1"), ("ec", "1"), ("gb", "2"), ("mc", "10"), ("nd", "3")))
+ proc.stop
+ }
+ }
+
+ describe("Ticket 343 Issue #4") {
+ it("vector get should not ignore elements that were in vector before transaction") {
+ val proc = actorOf[RedisSampleVectorStorage]
+ proc.start
+
+ // add 4 elements in separate transactions
+ (proc !! VADD("debasish")).getOrElse("VADD failed") should equal(1)
+ (proc !! VADD("maulindu")).getOrElse("VADD failed") should equal(2)
+ (proc !! VADD("ramanendu")).getOrElse("VADD failed") should equal(3)
+ (proc !! VADD("nilanjan")).getOrElse("VADD failed") should equal(4)
+
+ new String((proc !! VGET(0)).get.asInstanceOf[Array[Byte]] ) should equal("nilanjan")
+ new String((proc !! VGET(1)).get.asInstanceOf[Array[Byte]] ) should equal("ramanendu")
+ new String((proc !! VGET(2)).get.asInstanceOf[Array[Byte]] ) should equal("maulindu")
+ new String((proc !! VGET(3)).get.asInstanceOf[Array[Byte]] ) should equal("debasish")
+
+ // now add 3 more and do gets in the same transaction
+ (proc !! VGET_AFTER_VADD(List("a", "b", "c"), List(0, 2, 4))).get.asInstanceOf[List[String]] should equal(List("c", "a", "ramanendu"))
+ proc.stop
+ }
+ }
+
+ describe("Ticket 343 Issue #6") {
+ it("vector update should not ignore transaction") {
+ val proc = actorOf[RedisSampleVectorStorage]
+ proc.start
+
+ // add 4 elements in separate transactions
+ (proc !! VADD("debasish")).getOrElse("VADD failed") should equal(1)
+ (proc !! VADD("maulindu")).getOrElse("VADD failed") should equal(2)
+ (proc !! VADD("ramanendu")).getOrElse("VADD failed") should equal(3)
+ (proc !! VADD("nilanjan")).getOrElse("VADD failed") should equal(4)
+
+ evaluating {
+ (proc !! VUPD_AND_ABORT(0, "virat")).getOrElse("VUPD_AND_ABORT failed")
+ } should produce [Exception]
+
+ // update aborts and hence values will remain unchanged
+ new String((proc !! VGET(0)).get.asInstanceOf[Array[Byte]] ) should equal("nilanjan")
+ proc.stop
+ }
+ }
+
+ describe("Ticket 343 Issue #5") {
+ it("vector slice() should not ignore elements added in current transaction") {
+ val proc = actorOf[RedisSampleVectorStorage]
+ proc.start
+
+ // add 4 elements in separate transactions
+ (proc !! VADD("debasish")).getOrElse("VADD failed") should equal(1)
+ (proc !! VADD("maulindu")).getOrElse("VADD failed") should equal(2)
+ (proc !! VADD("ramanendu")).getOrElse("VADD failed") should equal(3)
+ (proc !! VADD("nilanjan")).getOrElse("VADD failed") should equal(4)
+
+ // slice with no new elements added in current transaction
+ (proc !! VADD_WITH_SLICE(List(), 2, 2)).getOrElse("VADD_WITH_SLICE failed") should equal(Vector("maulindu", "debasish"))
+
+ // slice with new elements added in current transaction
+ (proc !! VADD_WITH_SLICE(List("a", "b", "c", "d"), 2, 2)).getOrElse("VADD_WITH_SLICE failed") should equal(Vector("b", "a"))
+ proc.stop
+ }
+ }
+}
diff --git a/akka-remote/src/main/scala/remote/RemoteClient.scala b/akka-remote/src/main/scala/remote/RemoteClient.scala
index 62fec595d3..97e2f3070f 100644
--- a/akka-remote/src/main/scala/remote/RemoteClient.scala
+++ b/akka-remote/src/main/scala/remote/RemoteClient.scala
@@ -30,6 +30,7 @@ import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable.{HashSet, HashMap}
import scala.reflect.BeanProperty
+import se.scalablesolutions.akka.actor._
/**
* Atomic remote request/reply message id generator.
@@ -76,8 +77,6 @@ object RemoteClient extends Logging {
private val remoteClients = new HashMap[String, RemoteClient]
private val remoteActors = new HashMap[RemoteServer.Address, HashSet[String]]
- // FIXME: simplify overloaded methods when we have Scala 2.8
-
def actorFor(classNameOrServiceId: String, hostname: String, port: Int): ActorRef =
actorFor(classNameOrServiceId, classNameOrServiceId, 5000L, hostname, port, None)
@@ -99,6 +98,27 @@ object RemoteClient extends Logging {
def actorFor(serviceId: String, className: String, timeout: Long, hostname: String, port: Int): ActorRef =
RemoteActorRef(serviceId, className, hostname, port, timeout, None)
+ def typedActorFor[T](intfClass: Class[T], serviceIdOrClassName: String, hostname: String, port: Int) : T = {
+ typedActorFor(intfClass, serviceIdOrClassName, serviceIdOrClassName, 5000L, hostname, port, None)
+ }
+
+ def typedActorFor[T](intfClass: Class[T], serviceIdOrClassName: String, timeout: Long, hostname: String, port: Int) : T = {
+ typedActorFor(intfClass, serviceIdOrClassName, serviceIdOrClassName, timeout, hostname, port, None)
+ }
+
+ def typedActorFor[T](intfClass: Class[T], serviceIdOrClassName: String, timeout: Long, hostname: String, port: Int, loader: ClassLoader) : T = {
+ typedActorFor(intfClass, serviceIdOrClassName, serviceIdOrClassName, timeout, hostname, port, Some(loader))
+ }
+
+ def typedActorFor[T](intfClass: Class[T], serviceId: String, implClassName: String, timeout: Long, hostname: String, port: Int, loader: ClassLoader) : T = {
+ typedActorFor(intfClass, serviceId, implClassName, timeout, hostname, port, Some(loader))
+ }
+
+ private[akka] def typedActorFor[T](intfClass: Class[T], serviceId: String, implClassName: String, timeout: Long, hostname: String, port: Int, loader: Option[ClassLoader]) : T = {
+ val actorRef = RemoteActorRef(serviceId, implClassName, hostname, port, timeout, loader, ActorType.TypedActor)
+ TypedActor.createProxyForRemoteActorRef(intfClass, actorRef)
+ }
+
private[akka] def actorFor(serviceId: String, className: String, timeout: Long, hostname: String, port: Int, loader: ClassLoader): ActorRef =
RemoteActorRef(serviceId, className, hostname, port, timeout, Some(loader))
@@ -220,10 +240,10 @@ class RemoteClient private[akka] (
val channel = connection.awaitUninterruptibly.getChannel
openChannels.add(channel)
if (!connection.isSuccess) {
- foreachListener(_ ! RemoteClientError(connection.getCause, this))
+ notifyListeners(RemoteClientError(connection.getCause, this))
log.error(connection.getCause, "Remote client connection to [%s:%s] has failed", hostname, port)
}
- foreachListener(_ ! RemoteClientStarted(this))
+ notifyListeners(RemoteClientStarted(this))
isRunning = true
}
}
@@ -232,7 +252,7 @@ class RemoteClient private[akka] (
log.info("Shutting down %s", name)
if (isRunning) {
isRunning = false
- foreachListener(_ ! RemoteClientShutdown(this))
+ notifyListeners(RemoteClientShutdown(this))
timer.stop
timer = null
openChannels.close.awaitUninterruptibly
@@ -250,7 +270,7 @@ class RemoteClient private[akka] (
@deprecated("Use removeListener instead")
def deregisterListener(actorRef: ActorRef) = removeListener(actorRef)
- override def foreachListener(f: (ActorRef) => Unit): Unit = super.foreachListener(f)
+ override def notifyListeners(message: => Any): Unit = super.notifyListeners(message)
protected override def manageLifeCycleOfListeners = false
@@ -287,7 +307,7 @@ class RemoteClient private[akka] (
} else {
val exception = new RemoteClientException(
"Remote client is not running, make sure you have invoked 'RemoteClient.connect' before using it.", this)
- foreachListener(l => l ! RemoteClientError(exception, this))
+ notifyListeners(RemoteClientError(exception, this))
throw exception
}
@@ -403,12 +423,12 @@ class RemoteClientHandler(
futures.remove(reply.getId)
} else {
val exception = new RemoteClientException("Unknown message received in remote client handler: " + result, client)
- client.foreachListener(_ ! RemoteClientError(exception, client))
+ client.notifyListeners(RemoteClientError(exception, client))
throw exception
}
} catch {
case e: Exception =>
- client.foreachListener(_ ! RemoteClientError(e, client))
+ client.notifyListeners(RemoteClientError(e, client))
log.error("Unexpected exception in remote client handler: %s", e)
throw e
}
@@ -423,7 +443,7 @@ class RemoteClientHandler(
client.connection = bootstrap.connect(remoteAddress)
client.connection.awaitUninterruptibly // Wait until the connection attempt succeeds or fails.
if (!client.connection.isSuccess) {
- client.foreachListener(_ ! RemoteClientError(client.connection.getCause, client))
+ client.notifyListeners(RemoteClientError(client.connection.getCause, client))
log.error(client.connection.getCause, "Reconnection to [%s] has failed", remoteAddress)
}
}
@@ -433,7 +453,7 @@ class RemoteClientHandler(
override def channelConnected(ctx: ChannelHandlerContext, event: ChannelStateEvent) = {
def connect = {
- client.foreachListener(_ ! RemoteClientConnected(client))
+ client.notifyListeners(RemoteClientConnected(client))
log.debug("Remote client connected to [%s]", ctx.getChannel.getRemoteAddress)
client.resetReconnectionTimeWindow
}
@@ -450,12 +470,12 @@ class RemoteClientHandler(
}
override def channelDisconnected(ctx: ChannelHandlerContext, event: ChannelStateEvent) = {
- client.foreachListener(_ ! RemoteClientDisconnected(client))
+ client.notifyListeners(RemoteClientDisconnected(client))
log.debug("Remote client disconnected from [%s]", ctx.getChannel.getRemoteAddress)
}
override def exceptionCaught(ctx: ChannelHandlerContext, event: ExceptionEvent) = {
- client.foreachListener(_ ! RemoteClientError(event.getCause, client))
+ client.notifyListeners(RemoteClientError(event.getCause, client))
log.error(event.getCause, "Unexpected exception from downstream in remote client")
event.getChannel.close
}
diff --git a/akka-remote/src/main/scala/remote/RemoteServer.scala b/akka-remote/src/main/scala/remote/RemoteServer.scala
index f68e602866..4bcd4861ff 100644
--- a/akka-remote/src/main/scala/remote/RemoteServer.scala
+++ b/akka-remote/src/main/scala/remote/RemoteServer.scala
@@ -10,7 +10,7 @@ import java.util.concurrent.{ConcurrentHashMap, Executors}
import java.util.{Map => JMap}
import se.scalablesolutions.akka.actor.{
- Actor, TypedActor, ActorRef, LocalActorRef, RemoteActorRef, IllegalActorStateException, RemoteActorSystemMessage}
+ Actor, TypedActor, ActorRef, IllegalActorStateException, RemoteActorSystemMessage}
import se.scalablesolutions.akka.actor.Actor._
import se.scalablesolutions.akka.util._
import se.scalablesolutions.akka.remote.protocol.RemoteProtocol._
@@ -133,8 +133,8 @@ object RemoteServer {
actorsFor(RemoteServer.Address(address.getHostName, address.getPort)).actors.put(uuid, actor)
}
- private[akka] def registerTypedActor(address: InetSocketAddress, name: String, typedActor: AnyRef) = guard.withWriteGuard {
- actorsFor(RemoteServer.Address(address.getHostName, address.getPort)).typedActors.put(name, typedActor)
+ private[akka] def registerTypedActor(address: InetSocketAddress, uuid: String, typedActor: AnyRef) = guard.withWriteGuard {
+ actorsFor(RemoteServer.Address(address.getHostName, address.getPort)).typedActors.put(uuid, typedActor)
}
private[akka] def getOrCreateServer(address: InetSocketAddress): RemoteServer = guard.withWriteGuard {
@@ -245,12 +245,12 @@ class RemoteServer extends Logging with ListenerManagement {
openChannels.add(bootstrap.bind(new InetSocketAddress(hostname, port)))
_isRunning = true
Cluster.registerLocalNode(hostname, port)
- foreachListener(_ ! RemoteServerStarted(this))
+ notifyListeners(RemoteServerStarted(this))
}
} catch {
case e =>
log.error(e, "Could not start up remote server")
- foreachListener(_ ! RemoteServerError(e, this))
+ notifyListeners(RemoteServerError(e, this))
}
this
}
@@ -263,7 +263,7 @@ class RemoteServer extends Logging with ListenerManagement {
openChannels.close.awaitUninterruptibly
bootstrap.releaseExternalResources
Cluster.deregisterLocalNode(hostname, port)
- foreachListener(_ ! RemoteServerShutdown(this))
+ notifyListeners(RemoteServerShutdown(this))
} catch {
case e: java.nio.channels.ClosedChannelException => {}
case e => log.warning("Could not close remote server channel in a graceful way")
@@ -271,12 +271,28 @@ class RemoteServer extends Logging with ListenerManagement {
}
}
- // TODO: register typed actor in RemoteServer as well
+ /**
+ * Register typed actor by interface name.
+ */
+ def registerTypedActor(intfClass: Class[_], typedActor: AnyRef) : Unit = registerTypedActor(intfClass.getName, typedActor)
/**
- * Register Remote Actor by the Actor's 'uuid' field. It starts the Actor if it is not started already.
+ * Register remote typed actor by a specific id.
+ * @param id custom actor id
+ * @param typedActor typed actor to register
*/
- def register(actorRef: ActorRef): Unit = register(actorRef.id,actorRef)
+ def registerTypedActor(id: String, typedActor: AnyRef): Unit = synchronized {
+ val typedActors = RemoteServer.actorsFor(RemoteServer.Address(hostname, port)).typedActors
+ if (!typedActors.contains(id)) {
+ log.debug("Registering server side remote actor [%s] with id [%s] on [%s:%d]", typedActor.getClass.getName, id, hostname, port)
+ typedActors.put(id, typedActor)
+ }
+ }
+
+ /**
+ * Register Remote Actor by the Actor's 'id' field. It starts the Actor if it is not started already.
+ */
+ def register(actorRef: ActorRef): Unit = register(actorRef.id, actorRef)
/**
* Register Remote Actor by a specific 'id' passed as argument.
@@ -321,16 +337,25 @@ class RemoteServer extends Logging with ListenerManagement {
}
}
+ /**
+ * Unregister Remote Typed Actor by specific 'id'.
+ *
+ * NOTE: You need to call this method if you have registered an actor by a custom ID.
+ */
+ def unregisterTypedActor(id: String):Unit = synchronized {
+ if (_isRunning) {
+ log.info("Unregistering server side remote typed actor with id [%s]", id)
+ val registeredTypedActors = typedActors()
+ registeredTypedActors.remove(id)
+ }
+ }
+
protected override def manageLifeCycleOfListeners = false
- protected[akka] override def foreachListener(f: (ActorRef) => Unit): Unit = super.foreachListener(f)
+ protected[akka] override def notifyListeners(message: => Any): Unit = super.notifyListeners(message)
- private[akka] def actors() : ConcurrentHashMap[String, ActorRef] = {
- RemoteServer.actorsFor(address).actors
- }
- private[akka] def typedActors() : ConcurrentHashMap[String, AnyRef] = {
- RemoteServer.actorsFor(address).typedActors
- }
+ private[akka] def actors() = RemoteServer.actorsFor(address).actors
+ private[akka] def typedActors() = RemoteServer.actorsFor(address).typedActors
}
object RemoteServerSslContext {
@@ -413,18 +438,18 @@ class RemoteServerHandler(
def operationComplete(future: ChannelFuture): Unit = {
if (future.isSuccess) {
openChannels.add(future.getChannel)
- server.foreachListener(_ ! RemoteServerClientConnected(server))
+ server.notifyListeners(RemoteServerClientConnected(server))
} else future.getChannel.close
}
})
} else {
- server.foreachListener(_ ! RemoteServerClientConnected(server))
+ server.notifyListeners(RemoteServerClientConnected(server))
}
}
override def channelClosed(ctx: ChannelHandlerContext, event: ChannelStateEvent) = {
log.debug("Remote client disconnected from [%s]", server.name)
- server.foreachListener(_ ! RemoteServerClientDisconnected(server))
+ server.notifyListeners(RemoteServerClientDisconnected(server))
}
override def handleUpstream(ctx: ChannelHandlerContext, event: ChannelEvent) = {
@@ -446,7 +471,7 @@ class RemoteServerHandler(
override def exceptionCaught(ctx: ChannelHandlerContext, event: ExceptionEvent) = {
log.error(event.getCause, "Unexpected exception from remote downstream")
event.getChannel.close
- server.foreachListener(_ ! RemoteServerError(event.getCause, server))
+ server.notifyListeners(RemoteServerError(event.getCause, server))
}
private def handleRemoteRequestProtocol(request: RemoteRequestProtocol, channel: Channel) = {
@@ -491,7 +516,7 @@ class RemoteServerHandler(
} catch {
case e: Throwable =>
channel.write(createErrorReplyMessage(e, request, true))
- server.foreachListener(_ ! RemoteServerError(e, server))
+ server.notifyListeners(RemoteServerError(e, server))
}
}
}
@@ -523,13 +548,39 @@ class RemoteServerHandler(
} catch {
case e: InvocationTargetException =>
channel.write(createErrorReplyMessage(e.getCause, request, false))
- server.foreachListener(_ ! RemoteServerError(e, server))
+ server.notifyListeners(RemoteServerError(e, server))
case e: Throwable =>
channel.write(createErrorReplyMessage(e, request, false))
- server.foreachListener(_ ! RemoteServerError(e, server))
+ server.notifyListeners(RemoteServerError(e, server))
}
}
+ /**
+ * Find a registered actor by ID (default) or UUID.
+ * Actors are registered by id apart from registering during serialization see SerializationProtocol.
+ */
+ private def findActorByIdOrUuid(id: String, uuid: String) : ActorRef = {
+ val registeredActors = server.actors()
+ var actorRefOrNull = registeredActors get id
+ if (actorRefOrNull eq null) {
+ actorRefOrNull = registeredActors get uuid
+ }
+ actorRefOrNull
+ }
+
+ /**
+ * Find a registered typed actor by ID (default) or UUID.
+ * Actors are registered by id apart from registering during serialization see SerializationProtocol.
+ */
+ private def findTypedActorByIdOrUUid(id: String, uuid: String) : AnyRef = {
+ val registeredActors = server.typedActors()
+ var actorRefOrNull = registeredActors get id
+ if (actorRefOrNull eq null) {
+ actorRefOrNull = registeredActors get uuid
+ }
+ actorRefOrNull
+ }
+
/**
* Creates a new instance of the actor with name, uuid and timeout specified as arguments.
*
@@ -538,12 +589,14 @@ class RemoteServerHandler(
* Does not start the actor.
*/
private def createActor(actorInfo: ActorInfoProtocol): ActorRef = {
- val uuid = actorInfo.getUuid
+ val ids = actorInfo.getUuid.split(':')
+ val uuid = ids(0)
+ val id = ids(1)
+
val name = actorInfo.getTarget
val timeout = actorInfo.getTimeout
- val registeredActors = server.actors()
- val actorRefOrNull = registeredActors get uuid
+ val actorRefOrNull = findActorByIdOrUuid(id, uuid)
if (actorRefOrNull eq null) {
try {
@@ -552,23 +605,26 @@ class RemoteServerHandler(
else Class.forName(name)
val actorRef = Actor.actorOf(clazz.newInstance.asInstanceOf[Actor])
actorRef.uuid = uuid
+ actorRef.id = id
actorRef.timeout = timeout
actorRef.remoteAddress = None
- registeredActors.put(uuid, actorRef)
+ server.actors.put(id, actorRef) // register by id
actorRef
} catch {
case e =>
log.error(e, "Could not create remote actor instance")
- server.foreachListener(_ ! RemoteServerError(e, server))
+ server.notifyListeners(RemoteServerError(e, server))
throw e
}
} else actorRefOrNull
}
private def createTypedActor(actorInfo: ActorInfoProtocol): AnyRef = {
- val uuid = actorInfo.getUuid
- val registeredTypedActors = server.typedActors()
- val typedActorOrNull = registeredTypedActors get uuid
+ val ids = actorInfo.getUuid.split(':')
+ val uuid = ids(0)
+ val id = ids(1)
+
+ val typedActorOrNull = findTypedActorByIdOrUUid(id, uuid)
if (typedActorOrNull eq null) {
val typedActorInfo = actorInfo.getTypedActorInfo
@@ -585,12 +641,12 @@ class RemoteServerHandler(
val newInstance = TypedActor.newInstance(
interfaceClass, targetClass.asInstanceOf[Class[_ <: TypedActor]], actorInfo.getTimeout).asInstanceOf[AnyRef]
- registeredTypedActors.put(uuid, newInstance)
+ server.typedActors.put(id, newInstance) // register by id
newInstance
} catch {
case e =>
log.error(e, "Could not create remote typed actor instance")
- server.foreachListener(_ ! RemoteServerError(e, server))
+ server.notifyListeners(RemoteServerError(e, server))
throw e
}
} else typedActorOrNull
diff --git a/akka-remote/src/main/scala/serialization/SerializationProtocol.scala b/akka-remote/src/main/scala/serialization/SerializationProtocol.scala
index 7da001edab..afebae8f3b 100644
--- a/akka-remote/src/main/scala/serialization/SerializationProtocol.scala
+++ b/akka-remote/src/main/scala/serialization/SerializationProtocol.scala
@@ -230,7 +230,7 @@ object RemoteActorSerialization {
}
RemoteActorRefProtocol.newBuilder
- .setUuid(uuid)
+ .setUuid(uuid + ":" + id)
.setActorClassname(actorClass.getName)
.setHomeAddress(AddressProtocol.newBuilder.setHostname(host).setPort(port).build)
.setTimeout(timeout)
@@ -248,7 +248,7 @@ object RemoteActorSerialization {
import actorRef._
val actorInfoBuilder = ActorInfoProtocol.newBuilder
- .setUuid(uuid)
+ .setUuid(uuid + ":" + actorRef.id)
.setTarget(actorClassName)
.setTimeout(timeout)
diff --git a/akka-remote/src/test/resources/META-INF/aop.xml b/akka-remote/src/test/resources/META-INF/aop.xml
index bdc167ca54..be133a51b8 100644
--- a/akka-remote/src/test/resources/META-INF/aop.xml
+++ b/akka-remote/src/test/resources/META-INF/aop.xml
@@ -2,6 +2,7 @@
+
diff --git a/akka-remote/src/test/scala/remote/ClientInitiatedRemoteActorSpec.scala b/akka-remote/src/test/scala/remote/ClientInitiatedRemoteActorSpec.scala
index 7ff46ab910..6670722b02 100644
--- a/akka-remote/src/test/scala/remote/ClientInitiatedRemoteActorSpec.scala
+++ b/akka-remote/src/test/scala/remote/ClientInitiatedRemoteActorSpec.scala
@@ -93,6 +93,7 @@ class ClientInitiatedRemoteActorSpec extends JUnitSuite {
actor.stop
}
+
@Test
def shouldSendOneWayAndReceiveReply = {
val actor = actorOf[SendOneWayAndReplyReceiverActor]
@@ -103,7 +104,7 @@ class ClientInitiatedRemoteActorSpec extends JUnitSuite {
sender.actor.asInstanceOf[SendOneWayAndReplySenderActor].sendTo = actor
sender.start
sender.actor.asInstanceOf[SendOneWayAndReplySenderActor].sendOff
- assert(SendOneWayAndReplySenderActor.latch.await(1, TimeUnit.SECONDS))
+ assert(SendOneWayAndReplySenderActor.latch.await(3, TimeUnit.SECONDS))
assert(sender.actor.asInstanceOf[SendOneWayAndReplySenderActor].state.isDefined === true)
assert("World" === sender.actor.asInstanceOf[SendOneWayAndReplySenderActor].state.get.asInstanceOf[String])
actor.stop
@@ -134,6 +135,6 @@ class ClientInitiatedRemoteActorSpec extends JUnitSuite {
assert("Expected exception; to test fault-tolerance" === e.getMessage())
}
actor.stop
- }
+ }
}
diff --git a/akka-remote/src/test/scala/remote/RemoteTypedActorSpec.scala b/akka-remote/src/test/scala/remote/RemoteTypedActorSpec.scala
index 780828c310..8b28b35f57 100644
--- a/akka-remote/src/test/scala/remote/RemoteTypedActorSpec.scala
+++ b/akka-remote/src/test/scala/remote/RemoteTypedActorSpec.scala
@@ -4,10 +4,7 @@
package se.scalablesolutions.akka.actor.remote
-import org.scalatest.Spec
-import org.scalatest.Assertions
import org.scalatest.matchers.ShouldMatchers
-import org.scalatest.BeforeAndAfterAll
import org.scalatest.junit.JUnitRunner
import org.junit.runner.RunWith
@@ -19,6 +16,7 @@ import se.scalablesolutions.akka.actor._
import se.scalablesolutions.akka.remote.{RemoteServer, RemoteClient}
import java.util.concurrent.{LinkedBlockingQueue, TimeUnit, BlockingQueue}
+import org.scalatest.{BeforeAndAfterEach, Spec, Assertions, BeforeAndAfterAll}
object RemoteTypedActorSpec {
val HOSTNAME = "localhost"
@@ -40,7 +38,7 @@ object RemoteTypedActorLog {
class RemoteTypedActorSpec extends
Spec with
ShouldMatchers with
- BeforeAndAfterAll {
+ BeforeAndAfterEach with BeforeAndAfterAll {
import RemoteTypedActorLog._
import RemoteTypedActorSpec._
@@ -82,6 +80,10 @@ class RemoteTypedActorSpec extends
ActorRegistry.shutdownAll
}
+ override def afterEach() {
+ server.typedActors.clear
+ }
+
describe("Remote Typed Actor ") {
it("should receive one-way message") {
diff --git a/akka-remote/src/test/scala/remote/ServerInitiatedRemoteActorSpec.scala b/akka-remote/src/test/scala/remote/ServerInitiatedRemoteActorSpec.scala
index 4ef1abf0c4..59f122c656 100644
--- a/akka-remote/src/test/scala/remote/ServerInitiatedRemoteActorSpec.scala
+++ b/akka-remote/src/test/scala/remote/ServerInitiatedRemoteActorSpec.scala
@@ -79,6 +79,7 @@ class ServerInitiatedRemoteActorSpec extends JUnitSuite {
}
}
+
@Test
def shouldSendWithBang {
val actor = RemoteClient.actorFor(
@@ -153,10 +154,29 @@ class ServerInitiatedRemoteActorSpec extends JUnitSuite {
server.register(actorOf[RemoteActorSpecActorUnidirectional])
val actor = RemoteClient.actorFor("se.scalablesolutions.akka.actor.remote.ServerInitiatedRemoteActorSpec$RemoteActorSpecActorUnidirectional", HOSTNAME, PORT)
val numberOfActorsInRegistry = ActorRegistry.actors.length
- val result = actor ! "OneWay"
+ actor ! "OneWay"
assert(RemoteActorSpecActorUnidirectional.latch.await(1, TimeUnit.SECONDS))
assert(numberOfActorsInRegistry === ActorRegistry.actors.length)
actor.stop
}
+
+ @Test
+ def shouldUseServiceNameAsIdForRemoteActorRef {
+ server.register(actorOf[RemoteActorSpecActorUnidirectional])
+ server.register("my-service", actorOf[RemoteActorSpecActorUnidirectional])
+ val actor1 = RemoteClient.actorFor("se.scalablesolutions.akka.actor.remote.ServerInitiatedRemoteActorSpec$RemoteActorSpecActorUnidirectional", HOSTNAME, PORT)
+ val actor2 = RemoteClient.actorFor("my-service", HOSTNAME, PORT)
+ val actor3 = RemoteClient.actorFor("my-service", HOSTNAME, PORT)
+
+ actor1 ! "OneWay"
+ actor2 ! "OneWay"
+ actor3 ! "OneWay"
+
+ assert(actor1.uuid != actor2.uuid)
+ assert(actor1.uuid != actor3.uuid)
+ assert(actor1.id != actor2.id)
+ assert(actor2.id == actor3.id)
+ }
+
}
diff --git a/akka-remote/src/test/scala/remote/ServerInitiatedRemoteTypedActorSpec.scala b/akka-remote/src/test/scala/remote/ServerInitiatedRemoteTypedActorSpec.scala
new file mode 100644
index 0000000000..b800fbf2c3
--- /dev/null
+++ b/akka-remote/src/test/scala/remote/ServerInitiatedRemoteTypedActorSpec.scala
@@ -0,0 +1,112 @@
+/**
+ * Copyright (C) 2009-2010 Scalable Solutions AB
+ */
+
+package se.scalablesolutions.akka.actor.remote
+
+import org.scalatest.Spec
+import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.junit.JUnitRunner
+import org.junit.runner.RunWith
+
+import java.util.concurrent.TimeUnit
+
+import se.scalablesolutions.akka.remote.{RemoteServer, RemoteClient}
+import se.scalablesolutions.akka.actor._
+import RemoteTypedActorLog._
+
+object ServerInitiatedRemoteTypedActorSpec {
+ val HOSTNAME = "localhost"
+ val PORT = 9990
+ var server: RemoteServer = null
+}
+
+@RunWith(classOf[JUnitRunner])
+class ServerInitiatedRemoteTypedActorSpec extends
+ Spec with
+ ShouldMatchers with
+ BeforeAndAfterAll {
+ import ServerInitiatedRemoteTypedActorSpec._
+
+ private val unit = TimeUnit.MILLISECONDS
+
+
+ override def beforeAll = {
+ server = new RemoteServer()
+ server.start(HOSTNAME, PORT)
+
+ val typedActor = TypedActor.newInstance(classOf[RemoteTypedActorOne], classOf[RemoteTypedActorOneImpl], 1000)
+ server.registerTypedActor("typed-actor-service", typedActor)
+
+ Thread.sleep(1000)
+ }
+
+ // make sure the servers shutdown cleanly after the test has finished
+ override def afterAll = {
+ try {
+ server.shutdown
+ RemoteClient.shutdownAll
+ Thread.sleep(1000)
+ } catch {
+ case e => ()
+ }
+ }
+
+ describe("Server managed remote typed Actor ") {
+
+ it("should receive one-way message") {
+ clearMessageLogs
+ val actor = RemoteClient.typedActorFor(classOf[RemoteTypedActorOne], "typed-actor-service", 5000L, HOSTNAME, PORT)
+ expect("oneway") {
+ actor.oneWay
+ oneWayLog.poll(5, TimeUnit.SECONDS)
+ }
+ }
+
+ it("should respond to request-reply message") {
+ clearMessageLogs
+ val actor = RemoteClient.typedActorFor(classOf[RemoteTypedActorOne], "typed-actor-service", 5000L, HOSTNAME, PORT)
+ expect("pong") {
+ actor.requestReply("ping")
+ }
+ }
+
+ it("should not recreate registered actors") {
+ val actor = RemoteClient.typedActorFor(classOf[RemoteTypedActorOne], "typed-actor-service", 5000L, HOSTNAME, PORT)
+ val numberOfActorsInRegistry = ActorRegistry.actors.length
+ expect("oneway") {
+ actor.oneWay
+ oneWayLog.poll(5, TimeUnit.SECONDS)
+ }
+ assert(numberOfActorsInRegistry === ActorRegistry.actors.length)
+ }
+
+ it("should support multiple variants to get the actor from client side") {
+ var actor = RemoteClient.typedActorFor(classOf[RemoteTypedActorOne], "typed-actor-service", 5000L, HOSTNAME, PORT)
+ expect("oneway") {
+ actor.oneWay
+ oneWayLog.poll(5, TimeUnit.SECONDS)
+ }
+ actor = RemoteClient.typedActorFor(classOf[RemoteTypedActorOne], "typed-actor-service", HOSTNAME, PORT)
+ expect("oneway") {
+ actor.oneWay
+ oneWayLog.poll(5, TimeUnit.SECONDS)
+ }
+ actor = RemoteClient.typedActorFor(classOf[RemoteTypedActorOne], "typed-actor-service", 5000L, HOSTNAME, PORT, this.getClass().getClassLoader)
+ expect("oneway") {
+ actor.oneWay
+ oneWayLog.poll(5, TimeUnit.SECONDS)
+ }
+ }
+
+ it("should register and unregister typed actors") {
+ val typedActor = TypedActor.newInstance(classOf[RemoteTypedActorOne], classOf[RemoteTypedActorOneImpl], 1000)
+ server.registerTypedActor("my-test-service", typedActor)
+ assert(server.typedActors().get("my-test-service") != null)
+ server.unregisterTypedActor("my-test-service")
+ assert(server.typedActors().get("my-test-service") == null)
+ }
+ }
+}
+
diff --git a/akka-spring/src/main/resources/se/scalablesolutions/akka/spring/akka-1.0-SNAPSHOT.xsd b/akka-spring/src/main/resources/se/scalablesolutions/akka/spring/akka-1.0-SNAPSHOT.xsd
index cf3c8ffafc..84a382a78e 100644
--- a/akka-spring/src/main/resources/se/scalablesolutions/akka/spring/akka-1.0-SNAPSHOT.xsd
+++ b/akka-spring/src/main/resources/se/scalablesolutions/akka/spring/akka-1.0-SNAPSHOT.xsd
@@ -64,6 +64,14 @@
+
+
+
+
+
+
+
+
@@ -105,6 +113,20 @@
+
+
+
+ Management type for remote actors: client managed or server managed.
+
+
+
+
+
+
+ Custom service name for server managed actor.
+
+
+
@@ -133,7 +155,7 @@
- Theh default timeout for '!!' invocations.
+ The default timeout for '!!' invocations.
@@ -227,6 +249,41 @@
+
+
+
+
+
+
+
+ Name of the remote host.
+
+
+
+
+
+
+ Port of the remote host.
+
+
+
+
+
+
+ Custom service name or class name for the server managed actor.
+
+
+
+
+
+
+ Name of the interface the typed actor implements.
+
+
+
+
+
+
@@ -292,4 +349,7 @@
+
+
+
diff --git a/akka-spring/src/main/scala/ActorBeanDefinitionParser.scala b/akka-spring/src/main/scala/ActorBeanDefinitionParser.scala
new file mode 100644
index 0000000000..55aa82b8e4
--- /dev/null
+++ b/akka-spring/src/main/scala/ActorBeanDefinitionParser.scala
@@ -0,0 +1,72 @@
+/**
+ * Copyright (C) 2009-2010 Scalable Solutions AB
+ */
+package se.scalablesolutions.akka.spring
+
+import org.springframework.beans.factory.support.BeanDefinitionBuilder
+import org.springframework.beans.factory.xml.AbstractSingleBeanDefinitionParser
+import org.springframework.beans.factory.xml.ParserContext
+import AkkaSpringConfigurationTags._
+import org.w3c.dom.Element
+
+
+/**
+ * Parser for custom namespace configuration.
+ * @author michaelkober
+ */
+class TypedActorBeanDefinitionParser extends AbstractSingleBeanDefinitionParser with ActorParser {
+ /*
+ * @see org.springframework.beans.factory.xml.AbstractSingleBeanDefinitionParser#doParse(org.w3c.dom.Element, org.springframework.beans.factory.xml.ParserContext, org.springframework.beans.factory.support.BeanDefinitionBuilder)
+ */
+ override def doParse(element: Element, parserContext: ParserContext, builder: BeanDefinitionBuilder) {
+ val typedActorConf = parseActor(element)
+ typedActorConf.typed = TYPED_ACTOR_TAG
+ typedActorConf.setAsProperties(builder)
+ }
+
+ /*
+ * @see org.springframework.beans.factory.xml.AbstractSingleBeanDefinitionParser#getBeanClass(org.w3c.dom.Element)
+ */
+ override def getBeanClass(element: Element): Class[_] = classOf[ActorFactoryBean]
+}
+
+
+/**
+ * Parser for custom namespace configuration.
+ * @author michaelkober
+ */
+class UntypedActorBeanDefinitionParser extends AbstractSingleBeanDefinitionParser with ActorParser {
+ /*
+ * @see org.springframework.beans.factory.xml.AbstractSingleBeanDefinitionParser#doParse(org.w3c.dom.Element, org.springframework.beans.factory.xml.ParserContext, org.springframework.beans.factory.support.BeanDefinitionBuilder)
+ */
+ override def doParse(element: Element, parserContext: ParserContext, builder: BeanDefinitionBuilder) {
+ val untypedActorConf = parseActor(element)
+ untypedActorConf.typed = UNTYPED_ACTOR_TAG
+ untypedActorConf.setAsProperties(builder)
+ }
+
+ /*
+ * @see org.springframework.beans.factory.xml.AbstractSingleBeanDefinitionParser#getBeanClass(org.w3c.dom.Element)
+ */
+ override def getBeanClass(element: Element): Class[_] = classOf[ActorFactoryBean]
+}
+
+
+/**
+ * Parser for custom namespace configuration.
+ * @author michaelkober
+ */
+class ActorForBeanDefinitionParser extends AbstractSingleBeanDefinitionParser with ActorForParser {
+ /*
+ * @see org.springframework.beans.factory.xml.AbstractSingleBeanDefinitionParser#doParse(org.w3c.dom.Element, org.springframework.beans.factory.xml.ParserContext, org.springframework.beans.factory.support.BeanDefinitionBuilder)
+ */
+ override def doParse(element: Element, parserContext: ParserContext, builder: BeanDefinitionBuilder) {
+ val actorForConf = parseActorFor(element)
+ actorForConf.setAsProperties(builder)
+ }
+
+ /*
+ * @see org.springframework.beans.factory.xml.AbstractSingleBeanDefinitionParser#getBeanClass(org.w3c.dom.Element)
+ */
+ override def getBeanClass(element: Element): Class[_] = classOf[ActorForFactoryBean]
+}
diff --git a/akka-spring/src/main/scala/ActorFactoryBean.scala b/akka-spring/src/main/scala/ActorFactoryBean.scala
index ee4b370b4f..caa344825a 100644
--- a/akka-spring/src/main/scala/ActorFactoryBean.scala
+++ b/akka-spring/src/main/scala/ActorFactoryBean.scala
@@ -4,22 +4,19 @@
package se.scalablesolutions.akka.spring
-import java.beans.PropertyDescriptor
-import java.lang.reflect.Method
-import javax.annotation.PreDestroy
-import javax.annotation.PostConstruct
-
import org.springframework.beans.{BeanUtils,BeansException,BeanWrapper,BeanWrapperImpl}
-import org.springframework.beans.factory.BeanFactory
+import se.scalablesolutions.akka.remote.{RemoteClient, RemoteServer}
+//import org.springframework.beans.factory.BeanFactory
import org.springframework.beans.factory.config.AbstractFactoryBean
import org.springframework.context.{ApplicationContext,ApplicationContextAware}
-import org.springframework.util.ReflectionUtils
+//import org.springframework.util.ReflectionUtils
import org.springframework.util.StringUtils
import se.scalablesolutions.akka.actor.{ActorRef, AspectInitRegistry, TypedActorConfiguration, TypedActor,Actor}
import se.scalablesolutions.akka.dispatch.MessageDispatcher
import se.scalablesolutions.akka.util.{Logging, Duration}
import scala.reflect.BeanProperty
+import java.net.InetSocketAddress
/**
* Exception to use when something goes wrong during bean creation.
@@ -49,6 +46,8 @@ class ActorFactoryBean extends AbstractFactoryBean[AnyRef] with Logging with App
@BeanProperty var transactional: Boolean = false
@BeanProperty var host: String = ""
@BeanProperty var port: Int = _
+ @BeanProperty var serverManaged: Boolean = false
+ @BeanProperty var serviceName: String = ""
@BeanProperty var lifecycle: String = ""
@BeanProperty var dispatcher: DispatcherProperties = _
@BeanProperty var scope: String = VAL_SCOPE_SINGLETON
@@ -94,7 +93,16 @@ class ActorFactoryBean extends AbstractFactoryBean[AnyRef] with Logging with App
if (implementation == null || implementation == "") throw new AkkaBeansException(
"The 'implementation' part of the 'akka:typed-actor' element in the Spring config file can't be null or empty string")
- TypedActor.newInstance(interface.toClass, implementation.toClass, createConfig)
+ val typedActor: AnyRef = TypedActor.newInstance(interface.toClass, implementation.toClass, createConfig)
+ if (isRemote && serverManaged) {
+ val server = RemoteServer.getOrCreateServer(new InetSocketAddress(host, port))
+ if (serviceName.isEmpty) {
+ server.registerTypedActor(interface, typedActor)
+ } else {
+ server.registerTypedActor(serviceName, typedActor)
+ }
+ }
+ typedActor
}
/**
@@ -111,7 +119,16 @@ class ActorFactoryBean extends AbstractFactoryBean[AnyRef] with Logging with App
actorRef.makeTransactionRequired
}
if (isRemote) {
- actorRef.makeRemote(host, port)
+ if (serverManaged) {
+ val server = RemoteServer.getOrCreateServer(new InetSocketAddress(host, port))
+ if (serviceName.isEmpty) {
+ server.register(actorRef)
+ } else {
+ server.register(serviceName, actorRef)
+ }
+ } else {
+ actorRef.makeRemote(host, port)
+ }
}
if (hasDispatcher) {
if (dispatcher.dispatcherType != THREAD_BASED){
@@ -159,7 +176,7 @@ class ActorFactoryBean extends AbstractFactoryBean[AnyRef] with Logging with App
private[akka] def createConfig: TypedActorConfiguration = {
val config = new TypedActorConfiguration().timeout(Duration(timeout, "millis"))
if (transactional) config.makeTransactionRequired
- if (isRemote) config.makeRemote(host, port)
+ if (isRemote && !serverManaged) config.makeRemote(host, port)
if (hasDispatcher) {
if (dispatcher.dispatcherType != THREAD_BASED) {
config.dispatcher(dispatcherInstance())
@@ -191,3 +208,39 @@ class ActorFactoryBean extends AbstractFactoryBean[AnyRef] with Logging with App
}
}
}
+
+/**
+ * Factory bean for remote client actor-for.
+ *
+ * @author michaelkober
+ */
+class ActorForFactoryBean extends AbstractFactoryBean[AnyRef] with Logging with ApplicationContextAware {
+ import StringReflect._
+ import AkkaSpringConfigurationTags._
+
+ @BeanProperty var interface: String = ""
+ @BeanProperty var host: String = ""
+ @BeanProperty var port: Int = _
+ @BeanProperty var serviceName: String = ""
+ //@BeanProperty var scope: String = VAL_SCOPE_SINGLETON
+ @BeanProperty var applicationContext: ApplicationContext = _
+
+ override def isSingleton = false
+
+ /*
+ * @see org.springframework.beans.factory.FactoryBean#getObjectType()
+ */
+ def getObjectType: Class[AnyRef] = classOf[AnyRef]
+
+ /*
+ * @see org.springframework.beans.factory.config.AbstractFactoryBean#createInstance()
+ */
+ def createInstance: AnyRef = {
+ if (interface.isEmpty) {
+ RemoteClient.actorFor(serviceName, host, port)
+ } else {
+ RemoteClient.typedActorFor(interface.toClass, serviceName, host, port)
+ }
+ }
+}
+
diff --git a/akka-spring/src/main/scala/ActorParser.scala b/akka-spring/src/main/scala/ActorParser.scala
index 69073bd52f..9858e9ad7e 100644
--- a/akka-spring/src/main/scala/ActorParser.scala
+++ b/akka-spring/src/main/scala/ActorParser.scala
@@ -6,6 +6,7 @@ package se.scalablesolutions.akka.spring
import org.springframework.util.xml.DomUtils
import org.w3c.dom.Element
import scala.collection.JavaConversions._
+import se.scalablesolutions.akka.util.Logging
import se.scalablesolutions.akka.actor.IllegalActorStateException
@@ -27,11 +28,17 @@ trait ActorParser extends BeanParser with DispatcherParser {
val objectProperties = new ActorProperties()
val remoteElement = DomUtils.getChildElementByTagName(element, REMOTE_TAG);
val dispatcherElement = DomUtils.getChildElementByTagName(element, DISPATCHER_TAG)
- val propertyEntries = DomUtils.getChildElementsByTagName(element,PROPERTYENTRY_TAG)
+ val propertyEntries = DomUtils.getChildElementsByTagName(element, PROPERTYENTRY_TAG)
if (remoteElement != null) {
objectProperties.host = mandatory(remoteElement, HOST)
objectProperties.port = mandatory(remoteElement, PORT).toInt
+ objectProperties.serverManaged = (remoteElement.getAttribute(MANAGED_BY) != null) && (remoteElement.getAttribute(MANAGED_BY).equals(SERVER_MANAGED))
+ val serviceName = remoteElement.getAttribute(SERVICE_NAME)
+ if ((serviceName != null) && (!serviceName.isEmpty)) {
+ objectProperties.serviceName = serviceName
+ objectProperties.serverManaged = true
+ }
}
if (dispatcherElement != null) {
@@ -43,7 +50,7 @@ trait ActorParser extends BeanParser with DispatcherParser {
val entry = new PropertyEntry
entry.name = element.getAttribute("name");
entry.value = element.getAttribute("value")
- entry.ref = element.getAttribute("ref")
+ entry.ref = element.getAttribute("ref")
objectProperties.propertyEntries.add(entry)
}
@@ -59,15 +66,13 @@ trait ActorParser extends BeanParser with DispatcherParser {
objectProperties.target = mandatory(element, IMPLEMENTATION)
objectProperties.transactional = if (element.getAttribute(TRANSACTIONAL).isEmpty) false else element.getAttribute(TRANSACTIONAL).toBoolean
- if (!element.getAttribute(INTERFACE).isEmpty) {
+ if (element.hasAttribute(INTERFACE)) {
objectProperties.interface = element.getAttribute(INTERFACE)
}
-
- if (!element.getAttribute(LIFECYCLE).isEmpty) {
+ if (element.hasAttribute(LIFECYCLE)) {
objectProperties.lifecycle = element.getAttribute(LIFECYCLE)
}
-
- if (!element.getAttribute(SCOPE).isEmpty) {
+ if (element.hasAttribute(SCOPE)) {
objectProperties.scope = element.getAttribute(SCOPE)
}
@@ -75,3 +80,158 @@ trait ActorParser extends BeanParser with DispatcherParser {
}
}
+
+/**
+ * Parser trait for custom namespace configuration for RemoteClient actor-for.
+ * @author michaelkober
+ */
+trait ActorForParser extends BeanParser {
+ import AkkaSpringConfigurationTags._
+
+ /**
+ * Parses the given element and returns a ActorForProperties.
+ * @param element dom element to parse
+ * @return configuration for the typed actor
+ */
+ def parseActorFor(element: Element): ActorForProperties = {
+ val objectProperties = new ActorForProperties()
+
+ objectProperties.host = mandatory(element, HOST)
+ objectProperties.port = mandatory(element, PORT).toInt
+ objectProperties.serviceName = mandatory(element, SERVICE_NAME)
+ if (element.hasAttribute(INTERFACE)) {
+ objectProperties.interface = element.getAttribute(INTERFACE)
+ }
+ objectProperties
+ }
+
+}
+
+/**
+ * Base trait with utility methods for bean parsing.
+ */
+trait BeanParser extends Logging {
+
+ /**
+ * Get a mandatory element attribute.
+ * @param element the element with the mandatory attribute
+ * @param attribute name of the mandatory attribute
+ */
+ def mandatory(element: Element, attribute: String): String = {
+ if ((element.getAttribute(attribute) == null) || (element.getAttribute(attribute).isEmpty)) {
+ throw new IllegalArgumentException("Mandatory attribute missing: " + attribute)
+ } else {
+ element.getAttribute(attribute)
+ }
+ }
+
+ /**
+ * Get a mandatory child element.
+ * @param element the parent element
+ * @param childName name of the mandatory child element
+ */
+ def mandatoryElement(element: Element, childName: String): Element = {
+ val childElement = DomUtils.getChildElementByTagName(element, childName);
+ if (childElement == null) {
+ throw new IllegalArgumentException("Mandatory element missing: ''")
+ } else {
+ childElement
+ }
+ }
+
+}
+
+
+/**
+ * Parser trait for custom namespace for Akka dispatcher configuration.
+ * @author michaelkober
+ */
+trait DispatcherParser extends BeanParser {
+ import AkkaSpringConfigurationTags._
+
+ /**
+ * Parses the given element and returns a DispatcherProperties.
+ * @param element dom element to parse
+ * @return configuration for the dispatcher
+ */
+ def parseDispatcher(element: Element): DispatcherProperties = {
+ val properties = new DispatcherProperties()
+ var dispatcherElement = element
+ if (hasRef(element)) {
+ val ref = element.getAttribute(REF)
+ dispatcherElement = element.getOwnerDocument.getElementById(ref)
+ if (dispatcherElement == null) {
+ throw new IllegalArgumentException("Referenced dispatcher not found: '" + ref + "'")
+ }
+ }
+
+ properties.dispatcherType = mandatory(dispatcherElement, TYPE)
+ if (properties.dispatcherType == THREAD_BASED) {
+ val allowedParentNodes = "akka:typed-actor" :: "akka:untyped-actor" :: "typed-actor" :: "untyped-actor" :: Nil
+ if (!allowedParentNodes.contains(dispatcherElement.getParentNode.getNodeName)) {
+ throw new IllegalArgumentException("Thread based dispatcher must be nested in 'typed-actor' or 'untyped-actor' element!")
+ }
+ }
+
+ if (properties.dispatcherType == HAWT) { // no name for HawtDispatcher
+ properties.name = dispatcherElement.getAttribute(NAME)
+ if (dispatcherElement.hasAttribute(AGGREGATE)) {
+ properties.aggregate = dispatcherElement.getAttribute(AGGREGATE).toBoolean
+ }
+ } else {
+ properties.name = mandatory(dispatcherElement, NAME)
+ }
+
+ val threadPoolElement = DomUtils.getChildElementByTagName(dispatcherElement, THREAD_POOL_TAG);
+ if (threadPoolElement != null) {
+ if (properties.dispatcherType == THREAD_BASED) {
+ throw new IllegalArgumentException("Element 'thread-pool' not allowed for this dispatcher type.")
+ }
+ val threadPoolProperties = parseThreadPool(threadPoolElement)
+ properties.threadPool = threadPoolProperties
+ }
+ properties
+ }
+
+ /**
+ * Parses the given element and returns a ThreadPoolProperties.
+ * @param element dom element to parse
+ * @return configuration for the thread pool
+ */
+ def parseThreadPool(element: Element): ThreadPoolProperties = {
+ val properties = new ThreadPoolProperties()
+ properties.queue = element.getAttribute(QUEUE)
+ if (element.hasAttribute(CAPACITY)) {
+ properties.capacity = element.getAttribute(CAPACITY).toInt
+ }
+ if (element.hasAttribute(BOUND)) {
+ properties.bound = element.getAttribute(BOUND).toInt
+ }
+ if (element.hasAttribute(FAIRNESS)) {
+ properties.fairness = element.getAttribute(FAIRNESS).toBoolean
+ }
+ if (element.hasAttribute(CORE_POOL_SIZE)) {
+ properties.corePoolSize = element.getAttribute(CORE_POOL_SIZE).toInt
+ }
+ if (element.hasAttribute(MAX_POOL_SIZE)) {
+ properties.maxPoolSize = element.getAttribute(MAX_POOL_SIZE).toInt
+ }
+ if (element.hasAttribute(KEEP_ALIVE)) {
+ properties.keepAlive = element.getAttribute(KEEP_ALIVE).toLong
+ }
+ if (element.hasAttribute(REJECTION_POLICY)) {
+ properties.rejectionPolicy = element.getAttribute(REJECTION_POLICY)
+ }
+ if (element.hasAttribute(MAILBOX_CAPACITY)) {
+ properties.mailboxCapacity = element.getAttribute(MAILBOX_CAPACITY).toInt
+ }
+ properties
+ }
+
+ def hasRef(element: Element): Boolean = {
+ val ref = element.getAttribute(REF)
+ (ref != null) && !ref.isEmpty
+ }
+
+}
+
diff --git a/akka-spring/src/main/scala/ActorProperties.scala b/akka-spring/src/main/scala/ActorProperties.scala
index 15c7e61fe0..0f86942935 100644
--- a/akka-spring/src/main/scala/ActorProperties.scala
+++ b/akka-spring/src/main/scala/ActorProperties.scala
@@ -8,7 +8,7 @@ import org.springframework.beans.factory.support.BeanDefinitionBuilder
import AkkaSpringConfigurationTags._
/**
- * Data container for typed actor configuration data.
+ * Data container for actor configuration data.
* @author michaelkober
* @author Martin Krasser
*/
@@ -20,6 +20,8 @@ class ActorProperties {
var transactional: Boolean = false
var host: String = ""
var port: Int = _
+ var serverManaged: Boolean = false
+ var serviceName: String = ""
var lifecycle: String = ""
var scope:String = VAL_SCOPE_SINGLETON
var dispatcher: DispatcherProperties = _
@@ -34,6 +36,8 @@ class ActorProperties {
builder.addPropertyValue("typed", typed)
builder.addPropertyValue(HOST, host)
builder.addPropertyValue(PORT, port)
+ builder.addPropertyValue("serverManaged", serverManaged)
+ builder.addPropertyValue("serviceName", serviceName)
builder.addPropertyValue(TIMEOUT, timeout)
builder.addPropertyValue(IMPLEMENTATION, target)
builder.addPropertyValue(INTERFACE, interface)
@@ -45,3 +49,26 @@ class ActorProperties {
}
}
+
+/**
+ * Data container for actor configuration data.
+ * @author michaelkober
+ */
+class ActorForProperties {
+ var interface: String = ""
+ var host: String = ""
+ var port: Int = _
+ var serviceName: String = ""
+
+ /**
+ * Sets the properties to the given builder.
+ * @param builder bean definition builder
+ */
+ def setAsProperties(builder: BeanDefinitionBuilder) {
+ builder.addPropertyValue(HOST, host)
+ builder.addPropertyValue(PORT, port)
+ builder.addPropertyValue("serviceName", serviceName)
+ builder.addPropertyValue(INTERFACE, interface)
+ }
+
+}
diff --git a/akka-spring/src/main/scala/AkkaNamespaceHandler.scala b/akka-spring/src/main/scala/AkkaNamespaceHandler.scala
index a478b7b262..b1c58baa20 100644
--- a/akka-spring/src/main/scala/AkkaNamespaceHandler.scala
+++ b/akka-spring/src/main/scala/AkkaNamespaceHandler.scala
@@ -12,10 +12,11 @@ import AkkaSpringConfigurationTags._
*/
class AkkaNamespaceHandler extends NamespaceHandlerSupport {
def init = {
- registerBeanDefinitionParser(TYPED_ACTOR_TAG, new TypedActorBeanDefinitionParser());
- registerBeanDefinitionParser(UNTYPED_ACTOR_TAG, new UntypedActorBeanDefinitionParser());
- registerBeanDefinitionParser(SUPERVISION_TAG, new SupervisionBeanDefinitionParser());
- registerBeanDefinitionParser(DISPATCHER_TAG, new DispatcherBeanDefinitionParser());
- registerBeanDefinitionParser(CAMEL_SERVICE_TAG, new CamelServiceBeanDefinitionParser);
+ registerBeanDefinitionParser(TYPED_ACTOR_TAG, new TypedActorBeanDefinitionParser())
+ registerBeanDefinitionParser(UNTYPED_ACTOR_TAG, new UntypedActorBeanDefinitionParser())
+ registerBeanDefinitionParser(SUPERVISION_TAG, new SupervisionBeanDefinitionParser())
+ registerBeanDefinitionParser(DISPATCHER_TAG, new DispatcherBeanDefinitionParser())
+ registerBeanDefinitionParser(CAMEL_SERVICE_TAG, new CamelServiceBeanDefinitionParser)
+ registerBeanDefinitionParser(ACTOR_FOR_TAG, new ActorForBeanDefinitionParser());
}
}
diff --git a/akka-spring/src/main/scala/AkkaSpringConfigurationTags.scala b/akka-spring/src/main/scala/AkkaSpringConfigurationTags.scala
index 4037f2c3ba..0e4de3576f 100644
--- a/akka-spring/src/main/scala/AkkaSpringConfigurationTags.scala
+++ b/akka-spring/src/main/scala/AkkaSpringConfigurationTags.scala
@@ -19,6 +19,7 @@ object AkkaSpringConfigurationTags {
val DISPATCHER_TAG = "dispatcher"
val PROPERTYENTRY_TAG = "property"
val CAMEL_SERVICE_TAG = "camel-service"
+ val ACTOR_FOR_TAG = "actor-for"
// actor sub tags
val REMOTE_TAG = "remote"
@@ -45,6 +46,8 @@ object AkkaSpringConfigurationTags {
val TRANSACTIONAL = "transactional"
val HOST = "host"
val PORT = "port"
+ val MANAGED_BY = "managed-by"
+ val SERVICE_NAME = "service-name"
val LIFECYCLE = "lifecycle"
val SCOPE = "scope"
@@ -101,4 +104,8 @@ object AkkaSpringConfigurationTags {
val THREAD_BASED = "thread-based"
val HAWT = "hawt"
+ // managed by types
+ val SERVER_MANAGED = "server"
+ val CLIENT_MANAGED = "client"
+
}
diff --git a/akka-spring/src/main/scala/BeanParser.scala b/akka-spring/src/main/scala/BeanParser.scala
deleted file mode 100644
index 1bbba9f09f..0000000000
--- a/akka-spring/src/main/scala/BeanParser.scala
+++ /dev/null
@@ -1,42 +0,0 @@
-/**
- * Copyright (C) 2009-2010 Scalable Solutions AB
- */
-package se.scalablesolutions.akka.spring
-
-import se.scalablesolutions.akka.util.Logging
-import org.w3c.dom.Element
-import org.springframework.util.xml.DomUtils
-
-/**
- * Base trait with utility methods for bean parsing.
- */
-trait BeanParser extends Logging {
-
- /**
- * Get a mandatory element attribute.
- * @param element the element with the mandatory attribute
- * @param attribute name of the mandatory attribute
- */
- def mandatory(element: Element, attribute: String): String = {
- if ((element.getAttribute(attribute) == null) || (element.getAttribute(attribute).isEmpty)) {
- throw new IllegalArgumentException("Mandatory attribute missing: " + attribute)
- } else {
- element.getAttribute(attribute)
- }
- }
-
- /**
- * Get a mandatory child element.
- * @param element the parent element
- * @param childName name of the mandatory child element
- */
- def mandatoryElement(element: Element, childName: String): Element = {
- val childElement = DomUtils.getChildElementByTagName(element, childName);
- if (childElement == null) {
- throw new IllegalArgumentException("Mandatory element missing: ''")
- } else {
- childElement
- }
- }
-
-}
diff --git a/akka-spring/src/main/scala/DispatcherParser.scala b/akka-spring/src/main/scala/DispatcherParser.scala
deleted file mode 100644
index 4eaa4e05a7..0000000000
--- a/akka-spring/src/main/scala/DispatcherParser.scala
+++ /dev/null
@@ -1,100 +0,0 @@
-/**
- * Copyright (C) 2009-2010 Scalable Solutions AB
- */
-package se.scalablesolutions.akka.spring
-
-import org.w3c.dom.Element
-import org.springframework.util.xml.DomUtils
-
-/**
- * Parser trait for custom namespace for Akka dispatcher configuration.
- * @author michaelkober
- */
-trait DispatcherParser extends BeanParser {
- import AkkaSpringConfigurationTags._
-
- /**
- * Parses the given element and returns a DispatcherProperties.
- * @param element dom element to parse
- * @return configuration for the dispatcher
- */
- def parseDispatcher(element: Element): DispatcherProperties = {
- val properties = new DispatcherProperties()
- var dispatcherElement = element
- if (hasRef(element)) {
- val ref = element.getAttribute(REF)
- dispatcherElement = element.getOwnerDocument.getElementById(ref)
- if (dispatcherElement == null) {
- throw new IllegalArgumentException("Referenced dispatcher not found: '" + ref + "'")
- }
- }
-
- properties.dispatcherType = mandatory(dispatcherElement, TYPE)
- if (properties.dispatcherType == THREAD_BASED) {
- val allowedParentNodes = "akka:typed-actor" :: "akka:untyped-actor" :: "typed-actor" :: "untyped-actor" :: Nil
- if (!allowedParentNodes.contains(dispatcherElement.getParentNode.getNodeName)) {
- throw new IllegalArgumentException("Thread based dispatcher must be nested in 'typed-actor' or 'untyped-actor' element!")
- }
- }
-
- if (properties.dispatcherType == HAWT) { // no name for HawtDispatcher
- properties.name = dispatcherElement.getAttribute(NAME)
- if (dispatcherElement.hasAttribute(AGGREGATE)) {
- properties.aggregate = dispatcherElement.getAttribute(AGGREGATE).toBoolean
- }
- } else {
- properties.name = mandatory(dispatcherElement, NAME)
- }
-
- val threadPoolElement = DomUtils.getChildElementByTagName(dispatcherElement, THREAD_POOL_TAG);
- if (threadPoolElement != null) {
- if (properties.dispatcherType == THREAD_BASED) {
- throw new IllegalArgumentException("Element 'thread-pool' not allowed for this dispatcher type.")
- }
- val threadPoolProperties = parseThreadPool(threadPoolElement)
- properties.threadPool = threadPoolProperties
- }
- properties
- }
-
- /**
- * Parses the given element and returns a ThreadPoolProperties.
- * @param element dom element to parse
- * @return configuration for the thread pool
- */
- def parseThreadPool(element: Element): ThreadPoolProperties = {
- val properties = new ThreadPoolProperties()
- properties.queue = element.getAttribute(QUEUE)
- if (element.hasAttribute(CAPACITY)) {
- properties.capacity = element.getAttribute(CAPACITY).toInt
- }
- if (element.hasAttribute(BOUND)) {
- properties.bound = element.getAttribute(BOUND).toInt
- }
- if (element.hasAttribute(FAIRNESS)) {
- properties.fairness = element.getAttribute(FAIRNESS).toBoolean
- }
- if (element.hasAttribute(CORE_POOL_SIZE)) {
- properties.corePoolSize = element.getAttribute(CORE_POOL_SIZE).toInt
- }
- if (element.hasAttribute(MAX_POOL_SIZE)) {
- properties.maxPoolSize = element.getAttribute(MAX_POOL_SIZE).toInt
- }
- if (element.hasAttribute(KEEP_ALIVE)) {
- properties.keepAlive = element.getAttribute(KEEP_ALIVE).toLong
- }
- if (element.hasAttribute(REJECTION_POLICY)) {
- properties.rejectionPolicy = element.getAttribute(REJECTION_POLICY)
- }
- if (element.hasAttribute(MAILBOX_CAPACITY)) {
- properties.mailboxCapacity = element.getAttribute(MAILBOX_CAPACITY).toInt
- }
- properties
- }
-
- def hasRef(element: Element): Boolean = {
- val ref = element.getAttribute(REF)
- (ref != null) && !ref.isEmpty
- }
-
-}
diff --git a/akka-spring/src/main/scala/PropertyEntries.scala b/akka-spring/src/main/scala/PropertyEntries.scala
index bf1898a805..9a7dc098de 100644
--- a/akka-spring/src/main/scala/PropertyEntries.scala
+++ b/akka-spring/src/main/scala/PropertyEntries.scala
@@ -18,3 +18,19 @@ class PropertyEntries {
entryList.append(entry)
}
}
+
+/**
+ * Represents a property element
+ * @author Johan Rask
+ */
+class PropertyEntry {
+ var name: String = _
+ var value: String = null
+ var ref: String = null
+
+
+ override def toString(): String = {
+ format("name = %s,value = %s, ref = %s", name, value, ref)
+ }
+}
+
diff --git a/akka-spring/src/main/scala/PropertyEntry.scala b/akka-spring/src/main/scala/PropertyEntry.scala
deleted file mode 100644
index 9fe6357fc0..0000000000
--- a/akka-spring/src/main/scala/PropertyEntry.scala
+++ /dev/null
@@ -1,19 +0,0 @@
-/**
- * Copyright (C) 2009-2010 Scalable Solutions AB
- */
-package se.scalablesolutions.akka.spring
-
-/**
- * Represents a property element
- * @author Johan Rask
- */
-class PropertyEntry {
- var name: String = _
- var value: String = null
- var ref: String = null
-
-
- override def toString(): String = {
- format("name = %s,value = %s, ref = %s", name, value, ref)
- }
-}
diff --git a/akka-spring/src/main/scala/TypedActorBeanDefinitionParser.scala b/akka-spring/src/main/scala/TypedActorBeanDefinitionParser.scala
deleted file mode 100644
index e8e0cef7d4..0000000000
--- a/akka-spring/src/main/scala/TypedActorBeanDefinitionParser.scala
+++ /dev/null
@@ -1,31 +0,0 @@
-/**
- * Copyright (C) 2009-2010 Scalable Solutions AB
- */
-package se.scalablesolutions.akka.spring
-
-import org.springframework.beans.factory.support.BeanDefinitionBuilder
-import org.springframework.beans.factory.xml.AbstractSingleBeanDefinitionParser
-import org.springframework.beans.factory.xml.ParserContext
-import AkkaSpringConfigurationTags._
-import org.w3c.dom.Element
-
-
-/**
- * Parser for custom namespace configuration.
- * @author michaelkober
- */
-class TypedActorBeanDefinitionParser extends AbstractSingleBeanDefinitionParser with ActorParser {
- /*
- * @see org.springframework.beans.factory.xml.AbstractSingleBeanDefinitionParser#doParse(org.w3c.dom.Element, org.springframework.beans.factory.xml.ParserContext, org.springframework.beans.factory.support.BeanDefinitionBuilder)
- */
- override def doParse(element: Element, parserContext: ParserContext, builder: BeanDefinitionBuilder) {
- val typedActorConf = parseActor(element)
- typedActorConf.typed = TYPED_ACTOR_TAG
- typedActorConf.setAsProperties(builder)
- }
-
- /*
- * @see org.springframework.beans.factory.xml.AbstractSingleBeanDefinitionParser#getBeanClass(org.w3c.dom.Element)
- */
- override def getBeanClass(element: Element): Class[_] = classOf[ActorFactoryBean]
-}
diff --git a/akka-spring/src/main/scala/UntypedActorBeanDefinitionParser.scala b/akka-spring/src/main/scala/UntypedActorBeanDefinitionParser.scala
deleted file mode 100644
index 752e18559f..0000000000
--- a/akka-spring/src/main/scala/UntypedActorBeanDefinitionParser.scala
+++ /dev/null
@@ -1,31 +0,0 @@
-/**
- * Copyright (C) 2009-2010 Scalable Solutions AB
- */
-package se.scalablesolutions.akka.spring
-
-import org.springframework.beans.factory.support.BeanDefinitionBuilder
-import org.springframework.beans.factory.xml.AbstractSingleBeanDefinitionParser
-import org.springframework.beans.factory.xml.ParserContext
-import AkkaSpringConfigurationTags._
-import org.w3c.dom.Element
-
-
-/**
- * Parser for custom namespace configuration.
- * @author michaelkober
- */
-class UntypedActorBeanDefinitionParser extends AbstractSingleBeanDefinitionParser with ActorParser {
- /*
- * @see org.springframework.beans.factory.xml.AbstractSingleBeanDefinitionParser#doParse(org.w3c.dom.Element, org.springframework.beans.factory.xml.ParserContext, org.springframework.beans.factory.support.BeanDefinitionBuilder)
- */
- override def doParse(element: Element, parserContext: ParserContext, builder: BeanDefinitionBuilder) {
- val untypedActorConf = parseActor(element)
- untypedActorConf.typed = UNTYPED_ACTOR_TAG
- untypedActorConf.setAsProperties(builder)
- }
-
- /*
- * @see org.springframework.beans.factory.xml.AbstractSingleBeanDefinitionParser#getBeanClass(org.w3c.dom.Element)
- */
- override def getBeanClass(element: Element): Class[_] = classOf[ActorFactoryBean]
-}
diff --git a/akka-spring/src/test/java/se/scalablesolutions/akka/spring/foo/IMyPojo.java b/akka-spring/src/test/java/se/scalablesolutions/akka/spring/foo/IMyPojo.java
index f2c5e24884..5a2a272e6c 100644
--- a/akka-spring/src/test/java/se/scalablesolutions/akka/spring/foo/IMyPojo.java
+++ b/akka-spring/src/test/java/se/scalablesolutions/akka/spring/foo/IMyPojo.java
@@ -8,14 +8,12 @@ package se.scalablesolutions.akka.spring.foo;
* To change this template use File | Settings | File Templates.
*/
public interface IMyPojo {
+ public void oneWay(String message);
+
public String getFoo();
- public String getBar();
-
- public void preRestart();
-
- public void postRestart();
-
public String longRunning();
+
+
}
diff --git a/akka-spring/src/test/java/se/scalablesolutions/akka/spring/foo/MyPojo.java b/akka-spring/src/test/java/se/scalablesolutions/akka/spring/foo/MyPojo.java
index fe3e9ba767..8f610eef63 100644
--- a/akka-spring/src/test/java/se/scalablesolutions/akka/spring/foo/MyPojo.java
+++ b/akka-spring/src/test/java/se/scalablesolutions/akka/spring/foo/MyPojo.java
@@ -1,42 +1,34 @@
package se.scalablesolutions.akka.spring.foo;
-import se.scalablesolutions.akka.actor.*;
+import se.scalablesolutions.akka.actor.TypedActor;
-public class MyPojo extends TypedActor implements IMyPojo{
+import java.util.concurrent.CountDownLatch;
- private String foo;
- private String bar;
+public class MyPojo extends TypedActor implements IMyPojo {
+
+ public static CountDownLatch latch = new CountDownLatch(1);
+ public static String lastOneWayMessage = null;
+ private String foo = "foo";
- public MyPojo() {
- this.foo = "foo";
- this.bar = "bar";
- }
+ public MyPojo() {
+ }
+ public String getFoo() {
+ return foo;
+ }
- public String getFoo() {
- return foo;
- }
+ public void oneWay(String message) {
+ lastOneWayMessage = message;
+ latch.countDown();
+ }
-
- public String getBar() {
- return bar;
- }
-
- public void preRestart() {
- System.out.println("pre restart");
- }
-
- public void postRestart() {
- System.out.println("post restart");
- }
-
- public String longRunning() {
- try {
- Thread.sleep(6000);
- } catch (InterruptedException e) {
- }
- return "this took long";
+ public String longRunning() {
+ try {
+ Thread.sleep(6000);
+ } catch (InterruptedException e) {
}
+ return "this took long";
+ }
}
diff --git a/akka-spring/src/test/java/se/scalablesolutions/akka/spring/foo/PingActor.java b/akka-spring/src/test/java/se/scalablesolutions/akka/spring/foo/PingActor.java
index e447b26a28..3063a1b529 100644
--- a/akka-spring/src/test/java/se/scalablesolutions/akka/spring/foo/PingActor.java
+++ b/akka-spring/src/test/java/se/scalablesolutions/akka/spring/foo/PingActor.java
@@ -6,6 +6,8 @@ import se.scalablesolutions.akka.actor.ActorRef;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
+import java.util.concurrent.CountDownLatch;
+
/**
* test class
@@ -14,6 +16,9 @@ public class PingActor extends UntypedActor implements ApplicationContextAware {
private String stringFromVal;
private String stringFromRef;
+ public static String lastMessage = null;
+ public static CountDownLatch latch = new CountDownLatch(1);
+
private boolean gotApplicationContext = false;
@@ -42,7 +47,6 @@ public class PingActor extends UntypedActor implements ApplicationContextAware {
stringFromRef = s;
}
-
private String longRunning() {
try {
Thread.sleep(6000);
@@ -53,12 +57,12 @@ public class PingActor extends UntypedActor implements ApplicationContextAware {
public void onReceive(Object message) throws Exception {
if (message instanceof String) {
- System.out.println("Ping received String message: " + message);
+ lastMessage = (String) message;
if (message.equals("longRunning")) {
- System.out.println("### starting pong");
ActorRef pongActor = UntypedActor.actorOf(PongActor.class).start();
pongActor.sendRequestReply("longRunning", getContext());
}
+ latch.countDown();
} else {
throw new IllegalArgumentException("Unknown message: " + message);
}
diff --git a/akka-spring/src/test/resources/server-managed-config.xml b/akka-spring/src/test/resources/server-managed-config.xml
new file mode 100644
index 0000000000..128b16c8b6
--- /dev/null
+++ b/akka-spring/src/test/resources/server-managed-config.xml
@@ -0,0 +1,57 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/akka-spring/src/test/resources/typed-actor-config.xml b/akka-spring/src/test/resources/typed-actor-config.xml
index faca749469..989884e4fa 100644
--- a/akka-spring/src/test/resources/typed-actor-config.xml
+++ b/akka-spring/src/test/resources/typed-actor-config.xml
@@ -37,7 +37,7 @@ http://scalablesolutions.se/akka/akka-1.0-SNAPSHOT.xsd">
implementation="se.scalablesolutions.akka.spring.foo.MyPojo"
timeout="2000"
transactional="true">
-
+
-
+
+
+
+ val props = parser.parseActor(dom(xml).getDocumentElement);
+ assert(props != null)
+ assert(props.host === "com.some.host")
+ assert(props.port === 9999)
+ assert(props.serviceName === "my-service")
+ assert(props.serverManaged)
}
}
}
diff --git a/akka-spring/src/test/scala/TypedActorSpringFeatureTest.scala b/akka-spring/src/test/scala/TypedActorSpringFeatureTest.scala
index 8767b2e75a..3cdcd17cb0 100644
--- a/akka-spring/src/test/scala/TypedActorSpringFeatureTest.scala
+++ b/akka-spring/src/test/scala/TypedActorSpringFeatureTest.scala
@@ -4,10 +4,8 @@
package se.scalablesolutions.akka.spring
-import foo.{IMyPojo, MyPojo}
+import foo.{PingActor, IMyPojo, MyPojo}
import se.scalablesolutions.akka.dispatch.FutureTimeoutException
-import se.scalablesolutions.akka.remote.RemoteNode
-import org.scalatest.FeatureSpec
import org.scalatest.matchers.ShouldMatchers
import org.scalatest.junit.JUnitRunner
import org.junit.runner.RunWith
@@ -16,13 +14,52 @@ import org.springframework.beans.factory.xml.XmlBeanDefinitionReader
import org.springframework.context.ApplicationContext
import org.springframework.context.support.ClassPathXmlApplicationContext
import org.springframework.core.io.{ClassPathResource, Resource}
+import org.scalatest.{BeforeAndAfterAll, FeatureSpec}
+import se.scalablesolutions.akka.remote.{RemoteClient, RemoteServer, RemoteNode}
+import java.util.concurrent.CountDownLatch
+import se.scalablesolutions.akka.actor.{TypedActor, RemoteTypedActorOne, Actor}
+import se.scalablesolutions.akka.actor.remote.RemoteTypedActorOneImpl
/**
* Tests for spring configuration of typed actors.
* @author michaelkober
*/
@RunWith(classOf[JUnitRunner])
-class TypedActorSpringFeatureTest extends FeatureSpec with ShouldMatchers {
+class TypedActorSpringFeatureTest extends FeatureSpec with ShouldMatchers with BeforeAndAfterAll {
+
+ var server1: RemoteServer = null
+ var server2: RemoteServer = null
+
+ override def beforeAll = {
+ val actor = Actor.actorOf[PingActor] // FIXME: remove this line when ticket 425 is fixed
+ server1 = new RemoteServer()
+ server1.start("localhost", 9990)
+ server2 = new RemoteServer()
+ server2.start("localhost", 9992)
+
+ val typedActor = TypedActor.newInstance(classOf[RemoteTypedActorOne], classOf[RemoteTypedActorOneImpl], 1000)
+ server1.registerTypedActor("typed-actor-service", typedActor)
+ }
+
+ // make sure the servers shutdown cleanly after the test has finished
+ override def afterAll = {
+ try {
+ server1.shutdown
+ server2.shutdown
+ RemoteClient.shutdownAll
+ Thread.sleep(1000)
+ } catch {
+ case e => ()
+ }
+ }
+
+ def getTypedActorFromContext(config: String, id: String) : IMyPojo = {
+ MyPojo.latch = new CountDownLatch(1)
+ val context = new ClassPathXmlApplicationContext(config)
+ val myPojo: IMyPojo = context.getBean(id).asInstanceOf[IMyPojo]
+ myPojo
+ }
+
feature("parse Spring application context") {
scenario("akka:typed-actor and akka:supervision and akka:dispatcher can be used as top level elements") {
@@ -37,41 +74,79 @@ class TypedActorSpringFeatureTest extends FeatureSpec with ShouldMatchers {
}
scenario("get a typed actor") {
- val context = new ClassPathXmlApplicationContext("/typed-actor-config.xml")
- val myPojo = context.getBean("simple-typed-actor").asInstanceOf[IMyPojo]
- var msg = myPojo.getFoo()
- msg += myPojo.getBar()
- assert(msg === "foobar")
+ val myPojo = getTypedActorFromContext("/typed-actor-config.xml", "simple-typed-actor")
+ assert(myPojo.getFoo() === "foo")
+ myPojo.oneWay("hello 1")
+ MyPojo.latch.await
+ assert(MyPojo.lastOneWayMessage === "hello 1")
}
scenario("FutureTimeoutException when timed out") {
- val context = new ClassPathXmlApplicationContext("/typed-actor-config.xml")
- val myPojo = context.getBean("simple-typed-actor").asInstanceOf[IMyPojo]
+ val myPojo = getTypedActorFromContext("/typed-actor-config.xml", "simple-typed-actor")
evaluating {myPojo.longRunning()} should produce[FutureTimeoutException]
-
}
scenario("typed-actor with timeout") {
- val context = new ClassPathXmlApplicationContext("/typed-actor-config.xml")
- val myPojo = context.getBean("simple-typed-actor-long-timeout").asInstanceOf[IMyPojo]
+ val myPojo = getTypedActorFromContext("/typed-actor-config.xml", "simple-typed-actor-long-timeout")
assert(myPojo.longRunning() === "this took long");
}
scenario("transactional typed-actor") {
- val context = new ClassPathXmlApplicationContext("/typed-actor-config.xml")
- val myPojo = context.getBean("transactional-typed-actor").asInstanceOf[IMyPojo]
- var msg = myPojo.getFoo()
- msg += myPojo.getBar()
- assert(msg === "foobar")
+ val myPojo = getTypedActorFromContext("/typed-actor-config.xml", "transactional-typed-actor")
+ assert(myPojo.getFoo() === "foo")
+ myPojo.oneWay("hello 2")
+ MyPojo.latch.await
+ assert(MyPojo.lastOneWayMessage === "hello 2")
}
scenario("get a remote typed-actor") {
- RemoteNode.start
- Thread.sleep(1000)
- val context = new ClassPathXmlApplicationContext("/typed-actor-config.xml")
- val myPojo = context.getBean("remote-typed-actor").asInstanceOf[IMyPojo]
- assert(myPojo.getFoo === "foo")
+ val myPojo = getTypedActorFromContext("/typed-actor-config.xml", "remote-typed-actor")
+ assert(myPojo.getFoo() === "foo")
+ myPojo.oneWay("hello 3")
+ MyPojo.latch.await
+ assert(MyPojo.lastOneWayMessage === "hello 3")
}
+
+ scenario("get a client-managed-remote-typed-actor") {
+ val myPojo = getTypedActorFromContext("/server-managed-config.xml", "client-managed-remote-typed-actor")
+ assert(myPojo.getFoo() === "foo")
+ myPojo.oneWay("hello client-managed-remote-typed-actor")
+ MyPojo.latch.await
+ assert(MyPojo.lastOneWayMessage === "hello client-managed-remote-typed-actor")
+ }
+
+ scenario("get a server-managed-remote-typed-actor") {
+ val serverPojo = getTypedActorFromContext("/server-managed-config.xml", "server-managed-remote-typed-actor")
+ //
+ val myPojoProxy = RemoteClient.typedActorFor(classOf[IMyPojo], classOf[IMyPojo].getName, 5000L, "localhost", 9990)
+ assert(myPojoProxy.getFoo() === "foo")
+ myPojoProxy.oneWay("hello server-managed-remote-typed-actor")
+ MyPojo.latch.await
+ assert(MyPojo.lastOneWayMessage === "hello server-managed-remote-typed-actor")
+ }
+
+ scenario("get a server-managed-remote-typed-actor-custom-id") {
+ val serverPojo = getTypedActorFromContext("/server-managed-config.xml", "server-managed-remote-typed-actor-custom-id")
+ //
+ val myPojoProxy = RemoteClient.typedActorFor(classOf[IMyPojo], "mypojo-service", 5000L, "localhost", 9990)
+ assert(myPojoProxy.getFoo() === "foo")
+ myPojoProxy.oneWay("hello server-managed-remote-typed-actor 2")
+ MyPojo.latch.await
+ assert(MyPojo.lastOneWayMessage === "hello server-managed-remote-typed-actor 2")
+ }
+
+ scenario("get a client proxy for server-managed-remote-typed-actor") {
+ MyPojo.latch = new CountDownLatch(1)
+ val context = new ClassPathXmlApplicationContext("/server-managed-config.xml")
+ val myPojo: IMyPojo = context.getBean("server-managed-remote-typed-actor-custom-id").asInstanceOf[IMyPojo]
+ // get client proxy from spring context
+ val myPojoProxy = context.getBean("typed-client-1").asInstanceOf[IMyPojo]
+ assert(myPojoProxy.getFoo() === "foo")
+ myPojoProxy.oneWay("hello")
+ MyPojo.latch.await
+ }
+
+
}
}
diff --git a/akka-spring/src/test/scala/UntypedActorSpringFeatureTest.scala b/akka-spring/src/test/scala/UntypedActorSpringFeatureTest.scala
index 11246cdc91..0397d30bf0 100644
--- a/akka-spring/src/test/scala/UntypedActorSpringFeatureTest.scala
+++ b/akka-spring/src/test/scala/UntypedActorSpringFeatureTest.scala
@@ -6,74 +6,146 @@ package se.scalablesolutions.akka.spring
import foo.PingActor
import se.scalablesolutions.akka.dispatch.ExecutorBasedEventDrivenWorkStealingDispatcher
-import se.scalablesolutions.akka.remote.RemoteNode
-import se.scalablesolutions.akka.actor.ActorRef
-import org.scalatest.FeatureSpec
import org.scalatest.matchers.ShouldMatchers
import org.scalatest.junit.JUnitRunner
import org.junit.runner.RunWith
-import org.springframework.context.ApplicationContext
import org.springframework.context.support.ClassPathXmlApplicationContext
+import se.scalablesolutions.akka.remote.{RemoteClient, RemoteServer}
+import org.scalatest.{BeforeAndAfterAll, FeatureSpec}
+import java.util.concurrent.CountDownLatch
+import se.scalablesolutions.akka.actor.{RemoteActorRef, ActorRegistry, Actor, ActorRef}
/**
* Tests for spring configuration of typed actors.
* @author michaelkober
*/
@RunWith(classOf[JUnitRunner])
-class UntypedActorSpringFeatureTest extends FeatureSpec with ShouldMatchers {
+class UntypedActorSpringFeatureTest extends FeatureSpec with ShouldMatchers with BeforeAndAfterAll {
+
+ var server1: RemoteServer = null
+ var server2: RemoteServer = null
+
+
+ override def beforeAll = {
+ val actor = Actor.actorOf[PingActor] // FIXME: remove this line when ticket 425 is fixed
+ server1 = new RemoteServer()
+ server1.start("localhost", 9990)
+ server2 = new RemoteServer()
+ server2.start("localhost", 9992)
+ }
+
+ // make sure the servers shutdown cleanly after the test has finished
+ override def afterAll = {
+ try {
+ server1.shutdown
+ server2.shutdown
+ RemoteClient.shutdownAll
+ Thread.sleep(1000)
+ } catch {
+ case e => ()
+ }
+ }
+
+
+ def getPingActorFromContext(config: String, id: String) : ActorRef = {
+ PingActor.latch = new CountDownLatch(1)
+ val context = new ClassPathXmlApplicationContext(config)
+ val pingActor = context.getBean(id).asInstanceOf[ActorRef]
+ assert(pingActor.getActorClassName() === "se.scalablesolutions.akka.spring.foo.PingActor")
+ pingActor.start()
+ }
+
+
feature("parse Spring application context") {
- scenario("get an untyped actor") {
- val context = new ClassPathXmlApplicationContext("/untyped-actor-config.xml")
- val myactor = context.getBean("simple-untyped-actor").asInstanceOf[ActorRef]
- assert(myactor.getActorClassName() === "se.scalablesolutions.akka.spring.foo.PingActor")
- myactor.start()
+ scenario("get a untyped actor") {
+ val myactor = getPingActorFromContext("/untyped-actor-config.xml", "simple-untyped-actor")
myactor.sendOneWay("Hello")
+ PingActor.latch.await
+ assert(PingActor.lastMessage === "Hello")
assert(myactor.isDefinedAt("some string message"))
}
scenario("untyped-actor with timeout") {
- val context = new ClassPathXmlApplicationContext("/untyped-actor-config.xml")
- val myactor = context.getBean("simple-untyped-actor-long-timeout").asInstanceOf[ActorRef]
- assert(myactor.getActorClassName() === "se.scalablesolutions.akka.spring.foo.PingActor")
- myactor.start()
- myactor.sendOneWay("Hello")
+ val myactor = getPingActorFromContext("/untyped-actor-config.xml", "simple-untyped-actor-long-timeout")
assert(myactor.getTimeout() === 10000)
+ myactor.sendOneWay("Hello 2")
+ PingActor.latch.await
+ assert(PingActor.lastMessage === "Hello 2")
}
scenario("transactional untyped-actor") {
- val context = new ClassPathXmlApplicationContext("/untyped-actor-config.xml")
- val myactor = context.getBean("transactional-untyped-actor").asInstanceOf[ActorRef]
- assert(myactor.getActorClassName() === "se.scalablesolutions.akka.spring.foo.PingActor")
- myactor.start()
- myactor.sendOneWay("Hello")
- assert(myactor.isDefinedAt("some string message"))
+ val myactor = getPingActorFromContext("/untyped-actor-config.xml", "transactional-untyped-actor")
+ myactor.sendOneWay("Hello 3")
+ PingActor.latch.await
+ assert(PingActor.lastMessage === "Hello 3")
}
scenario("get a remote typed-actor") {
- RemoteNode.start
- Thread.sleep(1000)
- val context = new ClassPathXmlApplicationContext("/untyped-actor-config.xml")
- val myactor = context.getBean("remote-untyped-actor").asInstanceOf[ActorRef]
- assert(myactor.getActorClassName() === "se.scalablesolutions.akka.spring.foo.PingActor")
- myactor.start()
- myactor.sendOneWay("Hello")
- assert(myactor.isDefinedAt("some string message"))
+ val myactor = getPingActorFromContext("/untyped-actor-config.xml", "remote-untyped-actor")
+ myactor.sendOneWay("Hello 4")
assert(myactor.getRemoteAddress().isDefined)
assert(myactor.getRemoteAddress().get.getHostName() === "localhost")
- assert(myactor.getRemoteAddress().get.getPort() === 9999)
+ assert(myactor.getRemoteAddress().get.getPort() === 9992)
+ PingActor.latch.await
+ assert(PingActor.lastMessage === "Hello 4")
}
scenario("untyped-actor with custom dispatcher") {
- val context = new ClassPathXmlApplicationContext("/untyped-actor-config.xml")
- val myactor = context.getBean("untyped-actor-with-dispatcher").asInstanceOf[ActorRef]
- assert(myactor.getActorClassName() === "se.scalablesolutions.akka.spring.foo.PingActor")
- myactor.start()
- myactor.sendOneWay("Hello")
+ val myactor = getPingActorFromContext("/untyped-actor-config.xml", "untyped-actor-with-dispatcher")
assert(myactor.getTimeout() === 1000)
assert(myactor.getDispatcher.isInstanceOf[ExecutorBasedEventDrivenWorkStealingDispatcher])
+ myactor.sendOneWay("Hello 5")
+ PingActor.latch.await
+ assert(PingActor.lastMessage === "Hello 5")
}
+
+ scenario("create client managed remote untyped-actor") {
+ val myactor = getPingActorFromContext("/server-managed-config.xml", "client-managed-remote-untyped-actor")
+ myactor.sendOneWay("Hello client managed remote untyped-actor")
+ PingActor.latch.await
+ assert(PingActor.lastMessage === "Hello client managed remote untyped-actor")
+ assert(myactor.getRemoteAddress().isDefined)
+ assert(myactor.getRemoteAddress().get.getHostName() === "localhost")
+ assert(myactor.getRemoteAddress().get.getPort() === 9990)
+ }
+
+ scenario("create server managed remote untyped-actor") {
+ val myactor = getPingActorFromContext("/server-managed-config.xml", "server-managed-remote-untyped-actor")
+ val nrOfActors = ActorRegistry.actors.length
+ val actorRef = RemoteClient.actorFor("se.scalablesolutions.akka.spring.foo.PingActor", "localhost", 9990)
+ actorRef.sendOneWay("Hello server managed remote untyped-actor")
+ PingActor.latch.await
+ assert(PingActor.lastMessage === "Hello server managed remote untyped-actor")
+ assert(ActorRegistry.actors.length === nrOfActors)
+ }
+
+ scenario("create server managed remote untyped-actor with custom service id") {
+ val myactor = getPingActorFromContext("/server-managed-config.xml", "server-managed-remote-untyped-actor-custom-id")
+ val nrOfActors = ActorRegistry.actors.length
+ val actorRef = RemoteClient.actorFor("ping-service", "localhost", 9990)
+ actorRef.sendOneWay("Hello server managed remote untyped-actor")
+ PingActor.latch.await
+ assert(PingActor.lastMessage === "Hello server managed remote untyped-actor")
+ assert(ActorRegistry.actors.length === nrOfActors)
+ }
+
+ scenario("get client actor for server managed remote untyped-actor") {
+ PingActor.latch = new CountDownLatch(1)
+ val context = new ClassPathXmlApplicationContext("/server-managed-config.xml")
+ val pingActor = context.getBean("server-managed-remote-untyped-actor-custom-id").asInstanceOf[ActorRef]
+ assert(pingActor.getActorClassName() === "se.scalablesolutions.akka.spring.foo.PingActor")
+ pingActor.start()
+ val nrOfActors = ActorRegistry.actors.length
+ // get client actor ref from spring context
+ val actorRef = context.getBean("client-1").asInstanceOf[ActorRef]
+ assert(actorRef.isInstanceOf[RemoteActorRef])
+ actorRef.sendOneWay("Hello")
+ PingActor.latch.await
+ assert(ActorRegistry.actors.length === nrOfActors)
+ }
+
}
}
diff --git a/akka-typed-actor/src/main/scala/actor/TypedActor.scala b/akka-typed-actor/src/main/scala/actor/TypedActor.scala
index d3c6a56f9f..c3457cb43b 100644
--- a/akka-typed-actor/src/main/scala/actor/TypedActor.scala
+++ b/akka-typed-actor/src/main/scala/actor/TypedActor.scala
@@ -16,9 +16,8 @@ import org.codehaus.aspectwerkz.proxy.Proxy
import org.codehaus.aspectwerkz.annotation.{Aspect, Around}
import java.net.InetSocketAddress
-import java.lang.reflect.{InvocationTargetException, Method, Field}
-
import scala.reflect.BeanProperty
+import java.lang.reflect.{Method, Field, InvocationHandler, Proxy => JProxy}
/**
* TypedActor is a type-safe actor made out of a POJO with interface.
@@ -390,7 +389,8 @@ object TypedActor extends Logging {
typedActor.initialize(proxy)
if (config._messageDispatcher.isDefined) actorRef.dispatcher = config._messageDispatcher.get
if (config._threadBasedDispatcher.isDefined) actorRef.dispatcher = Dispatchers.newThreadBasedDispatcher(actorRef)
- AspectInitRegistry.register(proxy, AspectInit(intfClass, typedActor, actorRef, None, config.timeout))
+ if (config._host.isDefined) actorRef.makeRemote(config._host.get)
+ AspectInitRegistry.register(proxy, AspectInit(intfClass, typedActor, actorRef, config._host, config.timeout))
actorRef.start
proxy.asInstanceOf[T]
}
@@ -408,24 +408,47 @@ object TypedActor extends Logging {
proxy.asInstanceOf[T]
}
-/*
- // NOTE: currently not used - but keep it around
- private[akka] def newInstance[T <: TypedActor](targetClass: Class[T],
- remoteAddress: Option[InetSocketAddress], timeout: Long): T = {
- val proxy = {
- val instance = Proxy.newInstance(targetClass, true, false)
- if (instance.isInstanceOf[TypedActor]) instance.asInstanceOf[TypedActor]
- else throw new IllegalActorStateException("Actor [" + targetClass.getName + "] is not a sub class of 'TypedActor'")
+ /**
+ * Create a proxy for a RemoteActorRef representing a server managed remote typed actor.
+ *
+ */
+ private[akka] def createProxyForRemoteActorRef[T](intfClass: Class[T], actorRef: ActorRef): T = {
+
+ class MyInvocationHandler extends InvocationHandler {
+ def invoke(proxy: AnyRef, method: Method, args: Array[AnyRef]): AnyRef = {
+ // do nothing, this is just a dummy
+ null
+ }
}
- val context = injectTypedActorContext(proxy)
- actorRef.actor.asInstanceOf[Dispatcher].initialize(targetClass, proxy, proxy, context)
- actorRef.timeout = timeout
- if (remoteAddress.isDefined) actorRef.makeRemote(remoteAddress.get)
- AspectInitRegistry.register(proxy, AspectInit(targetClass, proxy, actorRef, remoteAddress, timeout))
- actorRef.start
- proxy.asInstanceOf[T]
+ val handler = new MyInvocationHandler()
+
+ val interfaces = Array(intfClass, classOf[ServerManagedTypedActor]).asInstanceOf[Array[java.lang.Class[_]]]
+ val jProxy = JProxy.newProxyInstance(intfClass.getClassLoader(), interfaces, handler)
+ val awProxy = Proxy.newInstance(interfaces, Array(jProxy, jProxy), true, false)
+
+ AspectInitRegistry.register(awProxy, AspectInit(intfClass, null, actorRef, None, 5000L))
+ awProxy.asInstanceOf[T]
}
-*/
+
+
+ /*
+ // NOTE: currently not used - but keep it around
+ private[akka] def newInstance[T <: TypedActor](targetClass: Class[T],
+ remoteAddress: Option[InetSocketAddress], timeout: Long): T = {
+ val proxy = {
+ val instance = Proxy.newInstance(targetClass, true, false)
+ if (instance.isInstanceOf[TypedActor]) instance.asInstanceOf[TypedActor]
+ else throw new IllegalActorStateException("Actor [" + targetClass.getName + "] is not a sub class of 'TypedActor'")
+ }
+ val context = injectTypedActorContext(proxy)
+ actorRef.actor.asInstanceOf[Dispatcher].initialize(targetClass, proxy, proxy, context)
+ actorRef.timeout = timeout
+ if (remoteAddress.isDefined) actorRef.makeRemote(remoteAddress.get)
+ AspectInitRegistry.register(proxy, AspectInit(targetClass, proxy, actorRef, remoteAddress, timeout))
+ actorRef.start
+ proxy.asInstanceOf[T]
+ }
+ */
/**
* Stops the current Typed Actor.
@@ -542,6 +565,30 @@ object TypedActor extends Logging {
private[akka] def isJoinPoint(message: Any): Boolean = message.isInstanceOf[JoinPoint]
}
+
+/**
+ * AspectWerkz Aspect that is turning POJO into proxy to a server managed remote TypedActor.
+ *
+ * Is deployed on a 'perInstance' basis with the pointcut 'execution(* *.*(..))',
+ * e.g. all methods on the instance.
+ *
+ * @author Jonas Bonér
+ */
+@Aspect("perInstance")
+private[akka] sealed class ServerManagedTypedActorAspect extends ActorAspect {
+
+ @Around("execution(* *.*(..)) && this(se.scalablesolutions.akka.actor.ServerManagedTypedActor)")
+ def invoke(joinPoint: JoinPoint): AnyRef = {
+ if (!isInitialized) initialize(joinPoint)
+ remoteDispatch(joinPoint)
+ }
+
+ override def initialize(joinPoint: JoinPoint): Unit = {
+ super.initialize(joinPoint)
+ remoteAddress = actorRef.remoteAddress
+ }
+}
+
/**
* AspectWerkz Aspect that is turning POJO into TypedActor.
*
@@ -551,18 +598,9 @@ object TypedActor extends Logging {
* @author Jonas Bonér
*/
@Aspect("perInstance")
-private[akka] sealed class TypedActorAspect {
- @volatile private var isInitialized = false
- @volatile private var isStopped = false
- private var interfaceClass: Class[_] = _
- private var typedActor: TypedActor = _
- private var actorRef: ActorRef = _
- private var remoteAddress: Option[InetSocketAddress] = _
- private var timeout: Long = _
- private var uuid: String = _
- @volatile private var instance: TypedActor = _
+private[akka] sealed class TypedActorAspect extends ActorAspect {
- @Around("execution(* *.*(..))")
+ @Around("execution(* *.*(..)) && !this(se.scalablesolutions.akka.actor.ServerManagedTypedActor)")
def invoke(joinPoint: JoinPoint): AnyRef = {
if (!isInitialized) initialize(joinPoint)
dispatch(joinPoint)
@@ -572,12 +610,26 @@ private[akka] sealed class TypedActorAspect {
if (remoteAddress.isDefined) remoteDispatch(joinPoint)
else localDispatch(joinPoint)
}
+}
- private def localDispatch(joinPoint: JoinPoint): AnyRef = {
- val methodRtti = joinPoint.getRtti.asInstanceOf[MethodRtti]
- val isOneWay = TypedActor.isOneWay(methodRtti)
+/**
+ * Base class for TypedActorAspect and ServerManagedTypedActorAspect to reduce code duplication.
+ */
+private[akka] abstract class ActorAspect {
+ @volatile protected var isInitialized = false
+ @volatile protected var isStopped = false
+ protected var interfaceClass: Class[_] = _
+ protected var typedActor: TypedActor = _
+ protected var actorRef: ActorRef = _
+ protected var timeout: Long = _
+ protected var uuid: String = _
+ protected var remoteAddress: Option[InetSocketAddress] = _
+
+ protected def localDispatch(joinPoint: JoinPoint): AnyRef = {
+ val methodRtti = joinPoint.getRtti.asInstanceOf[MethodRtti]
+ val isOneWay = TypedActor.isOneWay(methodRtti)
val senderActorRef = Some(SenderContextInfo.senderActorRef.value)
- val senderProxy = Some(SenderContextInfo.senderProxy.value)
+ val senderProxy = Some(SenderContextInfo.senderProxy.value)
typedActor.context._sender = senderProxy
if (!actorRef.isRunning && !isStopped) {
@@ -598,7 +650,7 @@ private[akka] sealed class TypedActorAspect {
}
}
- private def remoteDispatch(joinPoint: JoinPoint): AnyRef = {
+ protected def remoteDispatch(joinPoint: JoinPoint): AnyRef = {
val methodRtti = joinPoint.getRtti.asInstanceOf[MethodRtti]
val isOneWay = TypedActor.isOneWay(methodRtti)
@@ -637,7 +689,7 @@ private[akka] sealed class TypedActorAspect {
(escapedArgs, isEscaped)
}
- private def initialize(joinPoint: JoinPoint): Unit = {
+ protected def initialize(joinPoint: JoinPoint): Unit = {
val init = AspectInitRegistry.initFor(joinPoint.getThis)
interfaceClass = init.interfaceClass
typedActor = init.targetInstance
@@ -649,6 +701,7 @@ private[akka] sealed class TypedActorAspect {
}
}
+
/**
* Internal helper class to help pass the contextual information between threads.
*
@@ -670,7 +723,7 @@ private[akka] object AspectInitRegistry extends ListenerManagement {
def register(proxy: AnyRef, init: AspectInit) = {
val res = initializations.put(proxy, init)
- foreachListener(_ ! AspectInitRegistered(proxy, init))
+ notifyListeners(AspectInitRegistered(proxy, init))
res
}
@@ -679,7 +732,7 @@ private[akka] object AspectInitRegistry extends ListenerManagement {
*/
def unregister(proxy: AnyRef): AspectInit = {
val init = initializations.remove(proxy)
- foreachListener(_ ! AspectInitUnregistered(proxy, init))
+ notifyListeners(AspectInitUnregistered(proxy, init))
init.actorRef.stop
init
}
@@ -700,5 +753,11 @@ private[akka] sealed case class AspectInit(
val timeout: Long) {
def this(interfaceClass: Class[_], targetInstance: TypedActor, actorRef: ActorRef, timeout: Long) =
this(interfaceClass, targetInstance, actorRef, None, timeout)
+
}
+
+/**
+ * Marker interface for server manager typed actors.
+ */
+private[akka] sealed trait ServerManagedTypedActor extends TypedActor
diff --git a/akka-typed-actor/src/main/scala/config/TypedActorGuiceConfigurator.scala b/akka-typed-actor/src/main/scala/config/TypedActorGuiceConfigurator.scala
index 339c4d297d..5ca249a3ec 100644
--- a/akka-typed-actor/src/main/scala/config/TypedActorGuiceConfigurator.scala
+++ b/akka-typed-actor/src/main/scala/config/TypedActorGuiceConfigurator.scala
@@ -122,7 +122,6 @@ private[akka] class TypedActorGuiceConfigurator extends TypedActorConfiguratorBa
remoteAddress.foreach { address =>
actorRef.makeRemote(remoteAddress.get)
- RemoteServerModule.registerTypedActor(address, implementationClass.getName, proxy)
}
AspectInitRegistry.register(
diff --git a/akka-typed-actor/src/test/resources/META-INF/aop.xml b/akka-typed-actor/src/test/resources/META-INF/aop.xml
index bdc167ca54..be133a51b8 100644
--- a/akka-typed-actor/src/test/resources/META-INF/aop.xml
+++ b/akka-typed-actor/src/test/resources/META-INF/aop.xml
@@ -2,6 +2,7 @@
+
diff --git a/embedded-repo/com/redis/redisclient/2.8.0-2.0/redisclient-2.8.0-2.0.pom b/embedded-repo/com/redis/redisclient/2.8.0-2.0/redisclient-2.8.0-2.0.pom
new file mode 100644
index 0000000000..12558da1c4
--- /dev/null
+++ b/embedded-repo/com/redis/redisclient/2.8.0-2.0/redisclient-2.8.0-2.0.pom
@@ -0,0 +1,8 @@
+
+
+ 4.0.0
+ com.redis
+ redisclient
+ 2.8.0-2.0
+ jar
+
\ No newline at end of file
diff --git a/embedded-repo/org/scala-tools/time/2.8.0-0.2-SNAPSHOT/time-2.8.0-0.2-SNAPSHOT.jar b/embedded-repo/org/scala-tools/time/2.8.0-0.2-SNAPSHOT/time-2.8.0-0.2-SNAPSHOT.jar
new file mode 100644
index 0000000000..038768fe14
Binary files /dev/null and b/embedded-repo/org/scala-tools/time/2.8.0-0.2-SNAPSHOT/time-2.8.0-0.2-SNAPSHOT.jar differ
diff --git a/embedded-repo/org/scala-tools/time/2.8.0-0.2-SNAPSHOT/time-2.8.0-0.2-SNAPSHOT.pom b/embedded-repo/org/scala-tools/time/2.8.0-0.2-SNAPSHOT/time-2.8.0-0.2-SNAPSHOT.pom
new file mode 100644
index 0000000000..fc1cf3406e
--- /dev/null
+++ b/embedded-repo/org/scala-tools/time/2.8.0-0.2-SNAPSHOT/time-2.8.0-0.2-SNAPSHOT.pom
@@ -0,0 +1,8 @@
+
+
+ 4.0.0
+ org.scala-tools
+ time
+ 2.8.0-0.2-SNAPSHOT
+ jar
+
\ No newline at end of file
diff --git a/embedded-repo/sjson/json/sjson/0.8-SNAPSHOT-2.8.0/sjson-0.8-SNAPSHOT-2.8.0.jar b/embedded-repo/sjson/json/sjson/0.8-SNAPSHOT-2.8.0/sjson-0.8-SNAPSHOT-2.8.0.jar
new file mode 100644
index 0000000000..1542632a82
Binary files /dev/null and b/embedded-repo/sjson/json/sjson/0.8-SNAPSHOT-2.8.0/sjson-0.8-SNAPSHOT-2.8.0.jar differ
diff --git a/project/build/AkkaProject.scala b/project/build/AkkaProject.scala
index 393b531dd6..6a97dbccfd 100644
--- a/project/build/AkkaProject.scala
+++ b/project/build/AkkaProject.scala
@@ -49,6 +49,7 @@ class AkkaParentProject(info: ProjectInfo) extends DefaultProject(info) {
lazy val JavaNetRepo = MavenRepository("java.net Repo", "http://download.java.net/maven/2")
lazy val SonatypeSnapshotRepo = MavenRepository("Sonatype OSS Repo", "http://oss.sonatype.org/content/repositories/releases")
lazy val SunJDMKRepo = MavenRepository("Sun JDMK Repo", "http://wp5.e-taxonomy.eu/cdmlib/mavenrepo")
+ lazy val CasbahRepoReleases = MavenRepository("Casbah Release Repo", "http://repo.bumnetworks.com/releases")
}
// -------------------------------------------------------------------------------------------------------------------
@@ -75,6 +76,7 @@ class AkkaParentProject(info: ProjectInfo) extends DefaultProject(info) {
lazy val scalaTestModuleConfig = ModuleConfiguration("org.scalatest", ScalaToolsSnapshots)
lazy val logbackModuleConfig = ModuleConfiguration("ch.qos.logback",sbt.DefaultMavenRepository)
lazy val atomikosModuleConfig = ModuleConfiguration("com.atomikos",sbt.DefaultMavenRepository)
+ lazy val casbahRelease = ModuleConfiguration("com.novus",CasbahRepoReleases)
lazy val embeddedRepo = EmbeddedRepo // This is the only exception, because the embedded repo is fast!
// -------------------------------------------------------------------------------------------------------------------
@@ -166,6 +168,10 @@ class AkkaParentProject(info: ProjectInfo) extends DefaultProject(info) {
lazy val mongo = "org.mongodb" % "mongo-java-driver" % "2.0" % "compile"
+ lazy val casbah = "com.novus" % "casbah_2.8.0" % "1.0.8.5" % "compile"
+
+ lazy val time = "org.scala-tools" % "time" % "2.8.0-SNAPSHOT-0.2-SNAPSHOT" % "compile"
+
lazy val multiverse = "org.multiverse" % "multiverse-alpha" % MULTIVERSE_VERSION % "compile" intransitive
lazy val netty = "org.jboss.netty" % "netty" % "3.2.2.Final" % "compile"
@@ -180,7 +186,7 @@ class AkkaParentProject(info: ProjectInfo) extends DefaultProject(info) {
lazy val sbinary = "sbinary" % "sbinary" % "2.8.0-0.3.1" % "compile"
- lazy val sjson = "sjson.json" % "sjson" % "0.7-2.8.0" % "compile"
+ lazy val sjson = "sjson.json" % "sjson" % "0.8-SNAPSHOT-2.8.0" % "compile"
lazy val slf4j = "org.slf4j" % "slf4j-api" % SLF4J_VERSION % "compile"
@@ -483,6 +489,7 @@ class AkkaParentProject(info: ProjectInfo) extends DefaultProject(info) {
class AkkaMongoProject(info: ProjectInfo) extends AkkaDefaultProject(info, distPath) {
val mongo = Dependencies.mongo
+ val casbah = Dependencies.casbah
override def testOptions = TestFilter((name: String) => name.endsWith("Test")) :: Nil
}