Merge branch 'master' into ticket_419

This commit is contained in:
Viktor Klang 2010-09-14 13:48:52 +02:00
commit e43d9b2b61
57 changed files with 2776 additions and 1576 deletions

View file

@ -196,19 +196,10 @@ trait ActorRef extends
*/
@volatile private[akka] var _transactionFactory: Option[TransactionFactory] = None
/**
* This lock ensures thread safety in the dispatching: only one message can
* be dispatched at once on the actor.
*/
protected[akka] val dispatcherLock = new ReentrantLock
/**
* This is a reference to the message currently being processed by the actor
*/
protected[akka] var _currentMessage: Option[MessageInvocation] = None
protected[akka] def currentMessage_=(msg: Option[MessageInvocation]) = guard.withGuard { _currentMessage = msg }
protected[akka] def currentMessage = guard.withGuard { _currentMessage }
@volatile protected[akka] var currentMessage: MessageInvocation = null
/**
* Comparison only takes uuid into account.
@ -978,12 +969,12 @@ class LocalActorRef private[akka](
protected[akka] def postMessageToMailbox(message: Any, senderOption: Option[ActorRef]): Unit = {
joinTransaction(message)
if (isRemotingEnabled && remoteAddress.isDefined) {
if (remoteAddress.isDefined && isRemotingEnabled) {
RemoteClientModule.send[Any](
message, senderOption, None, remoteAddress.get, timeout, true, this, None, ActorType.ScalaActor)
} else {
val invocation = new MessageInvocation(this, message, senderOption, None, transactionSet.get)
invocation.send
dispatcher dispatch invocation
}
}
@ -994,7 +985,7 @@ class LocalActorRef private[akka](
senderFuture: Option[CompletableFuture[T]]): CompletableFuture[T] = {
joinTransaction(message)
if (isRemotingEnabled && remoteAddress.isDefined) {
if (remoteAddress.isDefined && isRemotingEnabled) {
val future = RemoteClientModule.send[T](
message, senderOption, senderFuture, remoteAddress.get, timeout, false, this, None, ActorType.ScalaActor)
if (future.isDefined) future.get
@ -1004,7 +995,7 @@ class LocalActorRef private[akka](
else new DefaultCompletableFuture[T](timeout)
val invocation = new MessageInvocation(
this, message, senderOption, Some(future.asInstanceOf[CompletableFuture[Any]]), transactionSet.get)
invocation.send
dispatcher dispatch invocation
future
}
}
@ -1016,7 +1007,7 @@ class LocalActorRef private[akka](
if (isShutdown)
Actor.log.warning("Actor [%s] is shut down,\n\tignoring message [%s]", toString, messageHandle)
else {
currentMessage = Option(messageHandle)
currentMessage = messageHandle
try {
dispatch(messageHandle)
} catch {
@ -1024,7 +1015,7 @@ class LocalActorRef private[akka](
Actor.log.error(e, "Could not invoke actor [%s]", this)
throw e
} finally {
currentMessage = None //TODO: Don't reset this, we might want to resend the message
currentMessage = null //TODO: Don't reset this, we might want to resend the message
}
}
}
@ -1187,7 +1178,7 @@ class LocalActorRef private[akka](
}
private def dispatch[T](messageHandle: MessageInvocation) = {
Actor.log.trace("Invoking actor with message:\n" + messageHandle)
Actor.log.trace("Invoking actor with message: %s\n",messageHandle)
val message = messageHandle.message //serializeMessage(messageHandle.message)
var topLevelTransaction = false
val txSet: Option[CountDownCommitBarrier] =
@ -1351,17 +1342,18 @@ object RemoteActorSystemMessage {
* @author <a href="http://jonasboner.com">Jonas Bon&#233;r</a>
*/
private[akka] case class RemoteActorRef private[akka] (
uuuid: String,
classOrServiceName: String,
val className: String,
val hostname: String,
val port: Int,
_timeout: Long,
loader: Option[ClassLoader])
loader: Option[ClassLoader],
val actorType: ActorType = ActorType.ScalaActor)
extends ActorRef with ScalaActorRef {
ensureRemotingEnabled
_uuid = uuuid
id = classOrServiceName
timeout = _timeout
start
@ -1369,7 +1361,7 @@ private[akka] case class RemoteActorRef private[akka] (
def postMessageToMailbox(message: Any, senderOption: Option[ActorRef]): Unit =
RemoteClientModule.send[Any](
message, senderOption, None, remoteAddress.get, timeout, true, this, None, ActorType.ScalaActor)
message, senderOption, None, remoteAddress.get, timeout, true, this, None, actorType)
def postMessageToMailboxAndCreateFutureResultWithTimeout[T](
message: Any,
@ -1377,7 +1369,7 @@ private[akka] case class RemoteActorRef private[akka] (
senderOption: Option[ActorRef],
senderFuture: Option[CompletableFuture[T]]): CompletableFuture[T] = {
val future = RemoteClientModule.send[T](
message, senderOption, senderFuture, remoteAddress.get, timeout, false, this, None, ActorType.ScalaActor)
message, senderOption, senderFuture, remoteAddress.get, timeout, false, this, None, actorType)
if (future.isDefined) future.get
else throw new IllegalActorStateException("Expected a future from remote call to actor " + toString)
}
@ -1533,10 +1525,9 @@ trait ScalaActorRef extends ActorRefShared { ref: ActorRef =>
* Is defined if the message was sent from another Actor, else None.
*/
def sender: Option[ActorRef] = {
// Five lines of map-performance-avoidance, could be just: currentMessage map { _.sender }
val msg = currentMessage
if (msg.isEmpty) None
else msg.get.sender
if (msg eq null) None
else msg.sender
}
/**
@ -1544,10 +1535,9 @@ trait ScalaActorRef extends ActorRefShared { ref: ActorRef =>
* Is defined if the message was sent with sent with '!!' or '!!!', else None.
*/
def senderFuture(): Option[CompletableFuture[Any]] = {
// Five lines of map-performance-avoidance, could be just: currentMessage map { _.senderFuture }
val msg = currentMessage
if (msg.isEmpty) None
else msg.get.senderFuture
if (msg eq null) None
else msg.senderFuture
}

View file

@ -125,7 +125,7 @@ object ActorRegistry extends ListenerManagement {
actorsByUUID.put(actor.uuid, actor)
// notify listeners
foreachListener(_ ! ActorRegistered(actor))
notifyListeners(ActorRegistered(actor))
}
/**
@ -137,7 +137,7 @@ object ActorRegistry extends ListenerManagement {
actorsById.remove(actor.id,actor)
// notify listeners
foreachListener(_ ! ActorUnregistered(actor))
notifyListeners(ActorUnregistered(actor))
}
/**

View file

@ -7,7 +7,7 @@ package se.scalablesolutions.akka.dispatch
import se.scalablesolutions.akka.actor.{ActorRef, IllegalActorStateException}
import java.util.Queue
import java.util.concurrent.{ConcurrentLinkedQueue, LinkedBlockingQueue}
import java.util.concurrent.{RejectedExecutionException, ConcurrentLinkedQueue, LinkedBlockingQueue}
/**
* Default settings are:
@ -64,7 +64,7 @@ import java.util.concurrent.{ConcurrentLinkedQueue, LinkedBlockingQueue}
*/
class ExecutorBasedEventDrivenDispatcher(
_name: String,
throughput: Int = Dispatchers.THROUGHPUT,
val throughput: Int = Dispatchers.THROUGHPUT,
mailboxConfig: MailboxConfig = Dispatchers.MAILBOX_CONFIG,
config: (ThreadPoolBuilder) => Unit = _ => ()) extends MessageDispatcher with ThreadPoolBuilder {
@ -80,71 +80,84 @@ class ExecutorBasedEventDrivenDispatcher(
val name = "akka:event-driven:dispatcher:" + _name
init
/**
* This is the behavior of an ExecutorBasedEventDrivenDispatchers mailbox
*/
trait ExecutableMailbox extends Runnable { self: MessageQueue =>
final def run = {
val reschedule = try {
processMailbox()
} finally {
dispatcherLock.unlock()
}
if (reschedule || !self.isEmpty)
registerForExecution(self)
}
/**
* Process the messages in the mailbox
*
* @return true if the processing finished before the mailbox was empty, due to the throughput constraint
*/
final def processMailbox(): Boolean = {
val throttle = throughput > 0
var processedMessages = 0
var nextMessage = self.dequeue
if (nextMessage ne null) {
do {
nextMessage.invoke
if(throttle) { //Will be elided when false
processedMessages += 1
if (processedMessages >= throughput) //If we're throttled, break out
return !self.isEmpty
}
nextMessage = self.dequeue
}
while (nextMessage ne null)
}
false
}
}
def dispatch(invocation: MessageInvocation) = {
getMailbox(invocation.receiver) enqueue invocation
dispatch(invocation.receiver)
val mbox = getMailbox(invocation.receiver)
mbox enqueue invocation
registerForExecution(mbox)
}
protected def registerForExecution(mailbox: MessageQueue with ExecutableMailbox): Unit = if (active) {
if (mailbox.dispatcherLock.tryLock()) {
try {
executor execute mailbox
} catch {
case e: RejectedExecutionException =>
mailbox.dispatcherLock.unlock()
throw e
}
}
} else {
log.warning("%s is shut down,\n\tignoring the rest of the messages in the mailbox of\n\t%s", toString, mailbox)
}
/**
* @return the mailbox associated with the actor
*/
private def getMailbox(receiver: ActorRef) = receiver.mailbox.asInstanceOf[MessageQueue]
private def getMailbox(receiver: ActorRef) = receiver.mailbox.asInstanceOf[MessageQueue with ExecutableMailbox]
override def mailboxSize(actorRef: ActorRef) = getMailbox(actorRef).size
override def createMailbox(actorRef: ActorRef): AnyRef = mailboxConfig.newMailbox(bounds = mailboxCapacity, blockDequeue = false)
def dispatch(receiver: ActorRef): Unit = if (active) {
executor.execute(new Runnable() {
def run = {
var lockAcquiredOnce = false
var finishedBeforeMailboxEmpty = false
val lock = receiver.dispatcherLock
val mailbox = getMailbox(receiver)
// this do-while loop is required to prevent missing new messages between the end of the inner while
// loop and releasing the lock
do {
if (lock.tryLock) {
// Only dispatch if we got the lock. Otherwise another thread is already dispatching.
lockAcquiredOnce = true
try {
finishedBeforeMailboxEmpty = processMailbox(receiver)
} finally {
lock.unlock
if (finishedBeforeMailboxEmpty) dispatch(receiver)
}
}
} while ((lockAcquiredOnce && !finishedBeforeMailboxEmpty && !mailbox.isEmpty))
}
})
} else {
log.warning("%s is shut down,\n\tignoring the rest of the messages in the mailbox of\n\t%s", toString, receiver)
override def createMailbox(actorRef: ActorRef): AnyRef = {
if (mailboxCapacity > 0)
new DefaultBoundedMessageQueue(mailboxCapacity,mailboxConfig.pushTimeOut,blockDequeue = false) with ExecutableMailbox
else
new DefaultUnboundedMessageQueue(blockDequeue = false) with ExecutableMailbox
}
/**
* Process the messages in the mailbox of the given actor.
*
* @return true if the processing finished before the mailbox was empty, due to the throughput constraint
*/
def processMailbox(receiver: ActorRef): Boolean = {
var processedMessages = 0
val mailbox = getMailbox(receiver)
var messageInvocation = mailbox.dequeue
while (messageInvocation != null) {
messageInvocation.invoke
processedMessages += 1
// check if we simply continue with other messages, or reached the throughput limit
if (throughput <= 0 || processedMessages < throughput) messageInvocation = mailbox.dequeue
else {
messageInvocation = null
return !mailbox.isEmpty
}
}
false
}
def start = if (!active) {
log.debug("Starting up %s\n\twith throughput [%d]", toString, throughput)
active = true
@ -157,8 +170,10 @@ class ExecutorBasedEventDrivenDispatcher(
uuids.clear
}
def ensureNotActive(): Unit = if (active) throw new IllegalActorStateException(
def ensureNotActive(): Unit = if (active) {
throw new IllegalActorStateException(
"Can't build a new thread pool for a dispatcher that is already up and running")
}
override def toString = "ExecutorBasedEventDrivenDispatcher[" + name + "]"

View file

@ -56,21 +56,14 @@ class ExecutorBasedEventDrivenWorkStealingDispatcher(
/**
* @return the mailbox associated with the actor
*/
private def getMailbox(receiver: ActorRef) = receiver.mailbox.asInstanceOf[Deque[MessageInvocation]]
private def getMailbox(receiver: ActorRef) = receiver.mailbox.asInstanceOf[Deque[MessageInvocation] with MessageQueue with Runnable]
override def mailboxSize(actorRef: ActorRef) = getMailbox(actorRef).size
def dispatch(invocation: MessageInvocation) = if (active) {
getMailbox(invocation.receiver).add(invocation)
executor.execute(new Runnable() {
def run = {
if (!tryProcessMailbox(invocation.receiver)) {
// we are not able to process our mailbox (another thread is busy with it), so lets donate some of our mailbox
// to another actor and then process his mailbox in stead.
findThief(invocation.receiver).foreach( tryDonateAndProcessMessages(invocation.receiver,_) )
}
}
})
val mbox = getMailbox(invocation.receiver)
mbox enqueue invocation
executor execute mbox
} else throw new IllegalActorStateException("Can't submit invocations to dispatcher since it's not started")
/**
@ -79,22 +72,21 @@ class ExecutorBasedEventDrivenWorkStealingDispatcher(
*
* @return true if the mailbox was processed, false otherwise
*/
private def tryProcessMailbox(receiver: ActorRef): Boolean = {
private def tryProcessMailbox(mailbox: MessageQueue): Boolean = {
var lockAcquiredOnce = false
val lock = receiver.dispatcherLock
// this do-wile loop is required to prevent missing new messages between the end of processing
// the mailbox and releasing the lock
do {
if (lock.tryLock) {
if (mailbox.dispatcherLock.tryLock) {
lockAcquiredOnce = true
try {
processMailbox(receiver)
processMailbox(mailbox)
} finally {
lock.unlock
mailbox.dispatcherLock.unlock
}
}
} while ((lockAcquiredOnce && !getMailbox(receiver).isEmpty))
} while ((lockAcquiredOnce && !mailbox.isEmpty))
lockAcquiredOnce
}
@ -102,12 +94,11 @@ class ExecutorBasedEventDrivenWorkStealingDispatcher(
/**
* Process the messages in the mailbox of the given actor.
*/
private def processMailbox(receiver: ActorRef) = {
val mailbox = getMailbox(receiver)
var messageInvocation = mailbox.poll
while (messageInvocation != null) {
private def processMailbox(mailbox: MessageQueue) = {
var messageInvocation = mailbox.dequeue
while (messageInvocation ne null) {
messageInvocation.invoke
messageInvocation = mailbox.poll
messageInvocation = mailbox.dequeue
}
}
@ -145,11 +136,12 @@ class ExecutorBasedEventDrivenWorkStealingDispatcher(
* the thiefs dispatching lock, because in that case another thread is already processing the thiefs mailbox.
*/
private def tryDonateAndProcessMessages(receiver: ActorRef, thief: ActorRef) = {
if (thief.dispatcherLock.tryLock) {
val mailbox = getMailbox(thief)
if (mailbox.dispatcherLock.tryLock) {
try {
while(donateMessage(receiver, thief)) processMailbox(thief)
while(donateMessage(receiver, thief)) processMailbox(mailbox)
} finally {
thief.dispatcherLock.unlock
mailbox.dispatcherLock.unlock
}
}
}
@ -191,18 +183,44 @@ class ExecutorBasedEventDrivenWorkStealingDispatcher(
}
protected override def createMailbox(actorRef: ActorRef): AnyRef = {
if (mailboxCapacity <= 0) new ConcurrentLinkedDeque[MessageInvocation]
else new LinkedBlockingDeque[MessageInvocation](mailboxCapacity)
if (mailboxCapacity <= 0) {
new ConcurrentLinkedDeque[MessageInvocation] with MessageQueue with Runnable {
def enqueue(handle: MessageInvocation): Unit = this.add(handle)
def dequeue: MessageInvocation = this.poll()
def run = {
if (!tryProcessMailbox(this)) {
// we are not able to process our mailbox (another thread is busy with it), so lets donate some of our mailbox
// to another actor and then process his mailbox in stead.
findThief(actorRef).foreach( tryDonateAndProcessMessages(actorRef,_) )
}
}
}
}
else {
new LinkedBlockingDeque[MessageInvocation](mailboxCapacity) with MessageQueue with Runnable {
def enqueue(handle: MessageInvocation): Unit = this.add(handle)
def dequeue: MessageInvocation = this.poll()
def run = {
if (!tryProcessMailbox(this)) {
// we are not able to process our mailbox (another thread is busy with it), so lets donate some of our mailbox
// to another actor and then process his mailbox in stead.
findThief(actorRef).foreach( tryDonateAndProcessMessages(actorRef,_) )
}
}
}
}
}
override def register(actorRef: ActorRef) = {
verifyActorsAreOfSameType(actorRef)
pooledActors.add(actorRef)
pooledActors add actorRef
super.register(actorRef)
}
override def unregister(actorRef: ActorRef) = {
pooledActors.remove(actorRef)
pooledActors remove actorRef
super.unregister(actorRef)
}

View file

@ -8,10 +8,10 @@ import se.scalablesolutions.akka.actor.{Actor, ActorRef, ActorInitializationExce
import org.multiverse.commitbarriers.CountDownCommitBarrier
import se.scalablesolutions.akka.AkkaException
import se.scalablesolutions.akka.util.{Duration, HashCode, Logging}
import java.util.{Queue, List}
import java.util.concurrent._
import concurrent.forkjoin.LinkedTransferQueue
import se.scalablesolutions.akka.util.{SimpleLock, Duration, HashCode, Logging}
/**
* @author <a href="http://jonasboner.com">Jonas Bon&#233;r</a>
@ -30,8 +30,6 @@ final class MessageInvocation(val receiver: ActorRef,
"Don't call 'self ! message' in the Actor's constructor (e.g. body of the class).")
}
def send = receiver.dispatcher.dispatch(this)
override def hashCode(): Int = synchronized {
var result = HashCode.SEED
result = HashCode.hash(result, receiver.actor)
@ -63,6 +61,7 @@ class MessageQueueAppendFailedException(message: String) extends AkkaException(m
* @author <a href="http://jonasboner.com">Jonas Bon&#233;r</a>
*/
trait MessageQueue {
val dispatcherLock = new SimpleLock
def enqueue(handle: MessageInvocation)
def dequeue(): MessageInvocation
def size: Int
@ -84,40 +83,36 @@ case class MailboxConfig(capacity: Int, pushTimeOut: Option[Duration], blockingD
*/
def newMailbox(bounds: Int = capacity,
pushTime: Option[Duration] = pushTimeOut,
blockDequeue: Boolean = blockingDequeue) : MessageQueue = {
if (bounds <= 0) { //UNBOUNDED: Will never block enqueue and optionally blocking dequeue
new LinkedTransferQueue[MessageInvocation] with MessageQueue {
def enqueue(handle: MessageInvocation): Unit = this add handle
def dequeue(): MessageInvocation = {
if(blockDequeue) this.take()
else this.poll()
}
}
}
else if (pushTime.isDefined) { //BOUNDED: Timeouted enqueue with MessageQueueAppendFailedException and optionally blocking dequeue
val time = pushTime.get
new BoundedTransferQueue[MessageInvocation](bounds) with MessageQueue {
def enqueue(handle: MessageInvocation) {
if (!this.offer(handle,time.length,time.unit))
throw new MessageQueueAppendFailedException("Couldn't enqueue message " + handle + " to " + this.toString)
}
blockDequeue: Boolean = blockingDequeue) : MessageQueue =
if (capacity > 0) new DefaultBoundedMessageQueue(bounds,pushTime,blockDequeue)
else new DefaultUnboundedMessageQueue(blockDequeue)
}
def dequeue(): MessageInvocation = {
if (blockDequeue) this.take()
else this.poll()
}
}
class DefaultUnboundedMessageQueue(blockDequeue: Boolean) extends LinkedBlockingQueue[MessageInvocation] with MessageQueue {
final def enqueue(handle: MessageInvocation) {
this add handle
}
final def dequeue(): MessageInvocation =
if (blockDequeue) this.take()
else this.poll()
}
class DefaultBoundedMessageQueue(capacity: Int, pushTimeOut: Option[Duration], blockDequeue: Boolean) extends LinkedBlockingQueue[MessageInvocation](capacity) with MessageQueue {
final def enqueue(handle: MessageInvocation) {
if (pushTimeOut.isDefined) {
if(!this.offer(handle,pushTimeOut.get.length,pushTimeOut.get.unit))
throw new MessageQueueAppendFailedException("Couldn't enqueue message " + handle + " to " + toString)
}
else { //BOUNDED: Blocking enqueue and optionally blocking dequeue
new LinkedBlockingQueue[MessageInvocation](bounds) with MessageQueue {
def enqueue(handle: MessageInvocation): Unit = this put handle
def dequeue(): MessageInvocation = {
if(blockDequeue) this.take()
else this.poll()
}
}
else {
this put handle
}
}
final def dequeue(): MessageInvocation =
if (blockDequeue) this.take()
else this.poll()
}
/**
@ -139,7 +134,7 @@ trait MessageDispatcher extends Logging {
}
def unregister(actorRef: ActorRef) = {
uuids remove actorRef.uuid
//actorRef.mailbox = null //FIXME should we null out the mailbox here?
actorRef.mailbox = null
if (canBeShutDown) shutdown // shut down in the dispatcher's references is zero
}
@ -156,14 +151,4 @@ trait MessageDispatcher extends Logging {
* Creates and returns a mailbox for the given actor
*/
protected def createMailbox(actorRef: ActorRef): AnyRef = null
}
/**
* @author <a href="http://jonasboner.com">Jonas Bon&#233;r</a>
*/
trait MessageDemultiplexer {
def select
def wakeUp
def acquireSelectedInvocations: List[MessageInvocation]
def releaseSelectedInvocations
}
}

View file

@ -1,141 +0,0 @@
/**
* Copyright (C) 2009-2010 Scalable Solutions AB <http://scalablesolutions.se>
*/
package se.scalablesolutions.akka.dispatch
import concurrent.forkjoin.LinkedTransferQueue
import java.util.concurrent.{TimeUnit, Semaphore}
import java.util.Iterator
import se.scalablesolutions.akka.util.Logger
class BoundedTransferQueue[E <: AnyRef](val capacity: Int) extends LinkedTransferQueue[E] {
require(capacity > 0)
protected val guard = new Semaphore(capacity)
override def take(): E = {
val e = super.take
if (e ne null) guard.release
e
}
override def poll(): E = {
val e = super.poll
if (e ne null) guard.release
e
}
override def poll(timeout: Long, unit: TimeUnit): E = {
val e = super.poll(timeout,unit)
if (e ne null) guard.release
e
}
override def remainingCapacity = guard.availablePermits
override def remove(o: AnyRef): Boolean = {
if (super.remove(o)) {
guard.release
true
} else {
false
}
}
override def offer(e: E): Boolean = {
if (guard.tryAcquire) {
val result = try {
super.offer(e)
} catch {
case e => guard.release; throw e
}
if (!result) guard.release
result
} else
false
}
override def offer(e: E, timeout: Long, unit: TimeUnit): Boolean = {
if (guard.tryAcquire(timeout,unit)) {
val result = try {
super.offer(e)
} catch {
case e => guard.release; throw e
}
if (!result) guard.release
result
} else
false
}
override def add(e: E): Boolean = {
if (guard.tryAcquire) {
val result = try {
super.add(e)
} catch {
case e => guard.release; throw e
}
if (!result) guard.release
result
} else
false
}
override def put(e :E): Unit = {
guard.acquire
try {
super.put(e)
} catch {
case e => guard.release; throw e
}
}
override def tryTransfer(e: E): Boolean = {
if (guard.tryAcquire) {
val result = try {
super.tryTransfer(e)
} catch {
case e => guard.release; throw e
}
if (!result) guard.release
result
} else
false
}
override def tryTransfer(e: E, timeout: Long, unit: TimeUnit): Boolean = {
if (guard.tryAcquire(timeout,unit)) {
val result = try {
super.tryTransfer(e)
} catch {
case e => guard.release; throw e
}
if (!result) guard.release
result
} else
false
}
override def transfer(e: E): Unit = {
if (guard.tryAcquire) {
try {
super.transfer(e)
} catch {
case e => guard.release; throw e
}
}
}
override def iterator: Iterator[E] = {
val it = super.iterator
new Iterator[E] {
def hasNext = it.hasNext
def next = it.next
def remove {
it.remove
guard.release //Assume remove worked if no exception was thrown
}
}
}
}

View file

@ -40,6 +40,23 @@ trait ListenerManagement extends Logging {
if (manageLifeCycleOfListeners) listener.stop
}
/*
* Returns whether there are any listeners currently
*/
def hasListeners: Boolean = !listeners.isEmpty
protected def notifyListeners(message: => Any) {
if (hasListeners) {
val msg = message
val iterator = listeners.iterator
while (iterator.hasNext) {
val listener = iterator.next
if (listener.isRunning) listener ! msg
else log.warning("Can't notify [%s] since it is not running.", listener)
}
}
}
/**
* Execute <code>f</code> with each listener as argument.
*/

View file

@ -5,6 +5,7 @@
package se.scalablesolutions.akka.util
import java.util.concurrent.locks.{ReentrantReadWriteLock, ReentrantLock}
import java.util.concurrent.atomic.AtomicBoolean
/**
* @author <a href="http://jonasboner.com">Jonas Bon&#233;r</a>
@ -58,3 +59,56 @@ class ReadWriteGuard {
}
}
/**
* A very simple lock that uses CCAS (Compare Compare-And-Swap)
* Does not keep track of the owner and isn't Reentrant, so don't nest and try to stick to the if*-methods
*/
class SimpleLock {
val acquired = new AtomicBoolean(false)
def ifPossible(perform: () => Unit): Boolean = {
if (tryLock()) {
try {
perform
} finally {
unlock()
}
true
} else false
}
def ifPossibleYield[T](perform: () => T): Option[T] = {
if (tryLock()) {
try {
Some(perform())
} finally {
unlock()
}
} else None
}
def ifPossibleApply[T,R](value: T)(function: (T) => R): Option[R] = {
if (tryLock()) {
try {
Some(function(value))
} finally {
unlock()
}
} else None
}
def tryLock() = {
if (acquired.get) false
else acquired.compareAndSet(false,true)
}
def tryUnlock() = {
acquired.compareAndSet(true,false)
}
def locked = acquired.get
def unlock() {
acquired.set(false)
}
}

View file

@ -5,11 +5,10 @@ import org.scalatest.junit.JUnitSuite
import org.junit.Test
import se.scalablesolutions.akka.dispatch.Dispatchers
import java.util.concurrent.{TimeUnit, CountDownLatch}
import se.scalablesolutions.akka.actor.{IllegalActorStateException, Actor}
import Actor._
import se.scalablesolutions.akka.dispatch.{MessageQueue, Dispatchers}
object ExecutorBasedEventDrivenWorkStealingDispatcherSpec {
val delayableActorDispatcher = Dispatchers.newExecutorBasedEventDrivenWorkStealingDispatcher("pooled-dispatcher")
@ -18,7 +17,7 @@ object ExecutorBasedEventDrivenWorkStealingDispatcherSpec {
class DelayableActor(name: String, delay: Int, finishedCounter: CountDownLatch) extends Actor {
self.dispatcher = delayableActorDispatcher
var invocationCount = 0
@volatile var invocationCount = 0
self.id = name
def receive = {
@ -61,10 +60,14 @@ class ExecutorBasedEventDrivenWorkStealingDispatcherSpec extends JUnitSuite with
val slow = actorOf(new DelayableActor("slow", 50, finishedCounter)).start
val fast = actorOf(new DelayableActor("fast", 10, finishedCounter)).start
var sentToFast = 0
for (i <- 1 to 100) {
// send most work to slow actor
if (i % 20 == 0)
if (i % 20 == 0) {
fast ! i
sentToFast += 1
}
else
slow ! i
}
@ -72,13 +75,18 @@ class ExecutorBasedEventDrivenWorkStealingDispatcherSpec extends JUnitSuite with
// now send some messages to actors to keep the dispatcher dispatching messages
for (i <- 1 to 10) {
Thread.sleep(150)
if (i % 2 == 0)
if (i % 2 == 0) {
fast ! i
sentToFast += 1
}
else
slow ! i
}
finishedCounter.await(5, TimeUnit.SECONDS)
fast.mailbox.asInstanceOf[MessageQueue].isEmpty must be(true)
slow.mailbox.asInstanceOf[MessageQueue].isEmpty must be(true)
fast.actor.asInstanceOf[DelayableActor].invocationCount must be > sentToFast
fast.actor.asInstanceOf[DelayableActor].invocationCount must be >
(slow.actor.asInstanceOf[DelayableActor].invocationCount)
slow.stop

View file

@ -29,7 +29,7 @@ object CassandraStorage extends Storage {
*
* @author <a href="http://jonasboner.com">Jonas Bon&#233;r</a>
*/
class CassandraPersistentMap(id: String) extends PersistentMap[Array[Byte], Array[Byte]] {
class CassandraPersistentMap(id: String) extends PersistentMapBinary {
val uuid = id
val storage = CassandraStorageBackend
}

View file

@ -7,9 +7,11 @@ package se.scalablesolutions.akka.persistence.common
import se.scalablesolutions.akka.stm._
import se.scalablesolutions.akka.stm.TransactionManagement.transaction
import se.scalablesolutions.akka.util.Logging
import se.scalablesolutions.akka.AkkaException
class StorageException(message: String) extends AkkaException(message)
// FIXME move to 'stm' package + add message with more info
class NoTransactionInScopeException extends RuntimeException
class StorageException(message: String) extends RuntimeException(message)
/**
* Example Scala usage.
@ -80,24 +82,90 @@ trait Storage {
*/
trait PersistentMap[K, V] extends scala.collection.mutable.Map[K, V]
with Transactional with Committable with Abortable with Logging {
protected val newAndUpdatedEntries = TransactionalMap[K, V]()
protected val removedEntries = TransactionalVector[K]()
protected val shouldClearOnCommit = Ref[Boolean]()
// operations on the Map
trait Op
case object GET extends Op
case object PUT extends Op
case object REM extends Op
case object UPD extends Op
// append only log: records all mutating operations
protected val appendOnlyTxLog = TransactionalVector[LogEntry]()
case class LogEntry(key: K, value: Option[V], op: Op)
// need to override in subclasses e.g. "sameElements" for Array[Byte]
def equal(k1: K, k2: K): Boolean = k1 == k2
// Seqable type that's required for maintaining the log of distinct keys affected in current transaction
type T <: Equals
// converts key K to the Seqable type Equals
def toEquals(k: K): T
// keys affected in the current transaction
protected val keysInCurrentTx = TransactionalMap[T, K]()
protected def addToListOfKeysInTx(key: K): Unit =
keysInCurrentTx += (toEquals(key), key)
protected def clearDistinctKeys = keysInCurrentTx.clear
protected def filterTxLogByKey(key: K): IndexedSeq[LogEntry] =
appendOnlyTxLog filter(e => equal(e.key, key))
// need to get current value considering the underlying storage as well as the transaction log
protected def getCurrentValue(key: K): Option[V] = {
// get all mutating entries for this key for this tx
val txEntries = filterTxLogByKey(key)
// get the snapshot from the underlying store for this key
val underlying = try {
storage.getMapStorageEntryFor(uuid, key)
} catch { case e: Exception => None }
if (txEntries.isEmpty) underlying
else replay(txEntries, key, underlying)
}
// replay all tx entries for key k with seed = initial
private def replay(txEntries: IndexedSeq[LogEntry], key: K, initial: Option[V]): Option[V] = {
import scala.collection.mutable._
val m = initial match {
case None => Map.empty[K, V]
case Some(v) => Map((key, v))
}
txEntries.foreach {case LogEntry(k, v, o) => o match {
case PUT => m.put(k, v.get)
case REM => m -= k
case UPD => m.update(k, v.get)
}}
m get key
}
// to be concretized in subclasses
val storage: MapStorageBackend[K, V]
def commit = {
if (shouldClearOnCommit.isDefined && shouldClearOnCommit.get) storage.removeMapStorageFor(uuid)
removedEntries.toList.foreach(key => storage.removeMapStorageFor(uuid, key))
storage.insertMapStorageEntriesFor(uuid, newAndUpdatedEntries.toList)
newAndUpdatedEntries.clear
removedEntries.clear
// if (shouldClearOnCommit.isDefined && shouldClearOnCommit.get) storage.removeMapStorageFor(uuid)
appendOnlyTxLog.foreach { case LogEntry(k, v, o) => o match {
case PUT => storage.insertMapStorageEntryFor(uuid, k, v.get)
case UPD => storage.insertMapStorageEntryFor(uuid, k, v.get)
case REM => storage.removeMapStorageFor(uuid, k)
}}
appendOnlyTxLog.clear
clearDistinctKeys
}
def abort = {
newAndUpdatedEntries.clear
removedEntries.clear
appendOnlyTxLog.clear
clearDistinctKeys
shouldClearOnCommit.swap(false)
}
@ -118,68 +186,84 @@ trait PersistentMap[K, V] extends scala.collection.mutable.Map[K, V]
override def put(key: K, value: V): Option[V] = {
register
newAndUpdatedEntries.put(key, value)
val curr = getCurrentValue(key)
appendOnlyTxLog add LogEntry(key, Some(value), PUT)
addToListOfKeysInTx(key)
curr
}
override def update(key: K, value: V) = {
register
newAndUpdatedEntries.update(key, value)
val curr = getCurrentValue(key)
appendOnlyTxLog add LogEntry(key, Some(value), UPD)
addToListOfKeysInTx(key)
curr
}
override def remove(key: K) = {
register
removedEntries.add(key)
newAndUpdatedEntries.get(key)
val curr = getCurrentValue(key)
appendOnlyTxLog add LogEntry(key, None, REM)
addToListOfKeysInTx(key)
curr
}
def slice(start: Option[K], count: Int): List[Tuple2[K, V]] =
def slice(start: Option[K], count: Int): List[(K, V)] =
slice(start, None, count)
def slice(start: Option[K], finish: Option[K], count: Int): List[Tuple2[K, V]] = try {
storage.getMapStorageRangeFor(uuid, start, finish, count)
} catch { case e: Exception => Nil }
def slice(start: Option[K], finish: Option[K], count: Int): List[(K, V)]
override def clear = {
register
appendOnlyTxLog.clear
clearDistinctKeys
shouldClearOnCommit.swap(true)
}
override def contains(key: K): Boolean = try {
newAndUpdatedEntries.contains(key) ||
storage.getMapStorageEntryFor(uuid, key).isDefined
filterTxLogByKey(key) match {
case Seq() => // current tx doesn't use this
storage.getMapStorageEntryFor(uuid, key).isDefined // check storage
case txs => // present in log
txs.last.op != REM // last entry cannot be a REM
}
} catch { case e: Exception => false }
protected def existsInStorage(key: K): Option[V] = try {
storage.getMapStorageEntryFor(uuid, key)
} catch {
case e: Exception => None
}
override def size: Int = try {
storage.getMapStorageSizeFor(uuid)
} catch { case e: Exception => 0 }
// partition key set affected in current tx into those which r added & which r deleted
val (keysAdded, keysRemoved) = keysInCurrentTx.map {
case (kseq, k) => ((kseq, k), getCurrentValue(k))
}.partition(_._2.isDefined)
override def get(key: K): Option[V] = {
if (newAndUpdatedEntries.contains(key)) {
newAndUpdatedEntries.get(key)
}
else try {
storage.getMapStorageEntryFor(uuid, key)
} catch { case e: Exception => None }
// keys which existed in storage but removed in current tx
val inStorageRemovedInTx =
keysRemoved.keySet
.map(_._2)
.filter(k => existsInStorage(k).isDefined)
.size
// all keys in storage
val keysInStorage =
storage.getMapStorageFor(uuid)
.map { case (k, v) => toEquals(k) }
.toSet
// (keys that existed UNION keys added ) - (keys removed)
(keysInStorage union keysAdded.keySet.map(_._1)).size - inStorageRemovedInTx
} catch {
case e: Exception => 0
}
def iterator = elements
// get must consider underlying storage & current uncommitted tx log
override def get(key: K): Option[V] = getCurrentValue(key)
override def elements: Iterator[Tuple2[K, V]] = {
new Iterator[Tuple2[K, V]] {
private val originalList: List[Tuple2[K, V]] = try {
storage.getMapStorageFor(uuid)
} catch {
case e: Throwable => Nil
}
private var elements = newAndUpdatedEntries.toList union originalList.reverse
override def next: Tuple2[K, V]= synchronized {
val element = elements.head
elements = elements.tail
element
}
override def hasNext: Boolean = synchronized { !elements.isEmpty }
}
}
def iterator: Iterator[Tuple2[K, V]]
private def register = {
if (transaction.get.isEmpty) throw new NoTransactionInScopeException
@ -187,6 +271,95 @@ trait PersistentMap[K, V] extends scala.collection.mutable.Map[K, V]
}
}
trait PersistentMapBinary extends PersistentMap[Array[Byte], Array[Byte]] {
import scala.collection.mutable.ArraySeq
type T = ArraySeq[Byte]
def toEquals(k: Array[Byte]) = ArraySeq(k: _*)
override def equal(k1: Array[Byte], k2: Array[Byte]): Boolean = k1 sameElements k2
object COrdering {
implicit object ArraySeqOrdering extends Ordering[ArraySeq[Byte]] {
def compare(o1: ArraySeq[Byte], o2: ArraySeq[Byte]) =
new String(o1.toArray) compare new String(o2.toArray)
}
}
import scala.collection.immutable.{TreeMap, SortedMap}
private def replayAllKeys: SortedMap[ArraySeq[Byte], Array[Byte]] = {
import COrdering._
// need ArraySeq for ordering
val fromStorage =
TreeMap(storage.getMapStorageFor(uuid).map { case (k, v) => (ArraySeq(k: _*), v) }: _*)
val (keysAdded, keysRemoved) = keysInCurrentTx.map {
case (_, k) => (k, getCurrentValue(k))
}.partition(_._2.isDefined)
val inStorageRemovedInTx =
keysRemoved.keySet
.filter(k => existsInStorage(k).isDefined)
.map(k => ArraySeq(k: _*))
(fromStorage -- inStorageRemovedInTx) ++ keysAdded.map { case (k, Some(v)) => (ArraySeq(k: _*), v) }
}
override def slice(start: Option[Array[Byte]], finish: Option[Array[Byte]], count: Int): List[(Array[Byte], Array[Byte])] = try {
val newMap = replayAllKeys
if (newMap isEmpty) List[(Array[Byte], Array[Byte])]()
val startKey =
start match {
case Some(bytes) => Some(ArraySeq(bytes: _*))
case None => None
}
val endKey =
finish match {
case Some(bytes) => Some(ArraySeq(bytes: _*))
case None => None
}
((startKey, endKey, count): @unchecked) match {
case ((Some(s), Some(e), _)) =>
newMap.range(s, e)
.toList
.map(e => (e._1.toArray, e._2))
.toList
case ((Some(s), None, c)) if c > 0 =>
newMap.from(s)
.iterator
.take(count)
.map(e => (e._1.toArray, e._2))
.toList
case ((Some(s), None, _)) =>
newMap.from(s)
.toList
.map(e => (e._1.toArray, e._2))
.toList
case ((None, Some(e), _)) =>
newMap.until(e)
.toList
.map(e => (e._1.toArray, e._2))
.toList
}
} catch { case e: Exception => Nil }
override def iterator: Iterator[(Array[Byte], Array[Byte])] = {
new Iterator[(Array[Byte], Array[Byte])] {
private var elements = replayAllKeys
override def next: (Array[Byte], Array[Byte]) = synchronized {
val (k, v) = elements.head
elements = elements.tail
(k.toArray, v)
}
override def hasNext: Boolean = synchronized { !elements.isEmpty }
}
}
}
/**
* Implements a template for a concrete persistent transactional vector based storage.
*
@ -198,42 +371,83 @@ trait PersistentVector[T] extends IndexedSeq[T] with Transactional with Committa
protected val removedElems = TransactionalVector[T]()
protected val shouldClearOnCommit = Ref[Boolean]()
// operations on the Vector
trait Op
case object ADD extends Op
case object UPD extends Op
case object POP extends Op
// append only log: records all mutating operations
protected val appendOnlyTxLog = TransactionalVector[LogEntry]()
case class LogEntry(index: Option[Int], value: Option[T], op: Op)
// need to override in subclasses e.g. "sameElements" for Array[Byte]
def equal(v1: T, v2: T): Boolean = v1 == v2
val storage: VectorStorageBackend[T]
def commit = {
for (element <- newElems) storage.insertVectorStorageEntryFor(uuid, element)
for (entry <- updatedElems) storage.updateVectorStorageEntryFor(uuid, entry._1, entry._2)
newElems.clear
updatedElems.clear
for(entry <- appendOnlyTxLog) {
entry match {
case LogEntry(_, Some(v), ADD) => storage.insertVectorStorageEntryFor(uuid, v)
case LogEntry(Some(i), Some(v), UPD) => storage.updateVectorStorageEntryFor(uuid, i, v)
case LogEntry(_, _, POP) => //..
}
}
appendOnlyTxLog.clear
}
def abort = {
newElems.clear
updatedElems.clear
removedElems.clear
appendOnlyTxLog.clear
shouldClearOnCommit.swap(false)
}
private def replay: List[T] = {
import scala.collection.mutable.ArrayBuffer
var elemsStorage = ArrayBuffer(storage.getVectorStorageRangeFor(uuid, None, None, storage.getVectorStorageSizeFor(uuid)).reverse: _*)
for(entry <- appendOnlyTxLog) {
entry match {
case LogEntry(_, Some(v), ADD) => elemsStorage += v
case LogEntry(Some(i), Some(v), UPD) => elemsStorage.update(i, v)
case LogEntry(_, _, POP) => elemsStorage = elemsStorage.drop(1)
}
}
elemsStorage.toList.reverse
}
def +(elem: T) = add(elem)
def add(elem: T) = {
register
newElems + elem
appendOnlyTxLog + LogEntry(None, Some(elem), ADD)
}
def apply(index: Int): T = get(index)
def get(index: Int): T = {
if (newElems.size > index) newElems(index)
else storage.getVectorStorageEntryFor(uuid, index)
if (appendOnlyTxLog.isEmpty) {
storage.getVectorStorageEntryFor(uuid, index)
} else {
val curr = replay
curr(index)
}
}
override def slice(start: Int, finish: Int): IndexedSeq[T] = slice(Some(start), Some(finish))
def slice(start: Option[Int], finish: Option[Int], count: Int = 0): IndexedSeq[T] = {
val buffer = new scala.collection.mutable.ArrayBuffer[T]
storage.getVectorStorageRangeFor(uuid, start, finish, count).foreach(buffer.append(_))
buffer
val curr = replay
val s = if (start.isDefined) start.get else 0
val cnt =
if (finish.isDefined) {
val f = finish.get
if (f >= s) (f - s) else count
}
else count
if (s == 0 && cnt == 0) List().toIndexedSeq
else curr.slice(s, s + cnt).toIndexedSeq
}
/**
@ -241,12 +455,13 @@ trait PersistentVector[T] extends IndexedSeq[T] with Transactional with Committa
*/
def pop: T = {
register
appendOnlyTxLog + LogEntry(None, None, POP)
throw new UnsupportedOperationException("PersistentVector::pop is not implemented")
}
def update(index: Int, newElem: T) = {
register
storage.updateVectorStorageEntryFor(uuid, index, newElem)
appendOnlyTxLog + LogEntry(Some(index), Some(newElem), UPD)
}
override def first: T = get(0)
@ -260,7 +475,7 @@ trait PersistentVector[T] extends IndexedSeq[T] with Transactional with Committa
}
}
def length: Int = storage.getVectorStorageSizeFor(uuid) + newElems.length
def length: Int = replay.length
private def register = {
if (transaction.get.isEmpty) throw new NoTransactionInScopeException

View file

@ -9,7 +9,7 @@ import se.scalablesolutions.akka.persistence.common._
import se.scalablesolutions.akka.util.UUID
object MongoStorage extends Storage {
type ElementType = AnyRef
type ElementType = Array[Byte]
def newMap: PersistentMap[ElementType, ElementType] = newMap(UUID.newUuid.toString)
def newVector: PersistentVector[ElementType] = newVector(UUID.newUuid.toString)
@ -29,7 +29,7 @@ object MongoStorage extends Storage {
*
* @author <a href="http://debasishg.blogspot.com">Debasish Ghosh</a>
*/
class MongoPersistentMap(id: String) extends PersistentMap[AnyRef, AnyRef] {
class MongoPersistentMap(id: String) extends PersistentMapBinary {
val uuid = id
val storage = MongoStorageBackend
}
@ -40,12 +40,12 @@ class MongoPersistentMap(id: String) extends PersistentMap[AnyRef, AnyRef] {
*
* @author <a href="http://debasishg.blogspot.com">Debaissh Ghosh</a>
*/
class MongoPersistentVector(id: String) extends PersistentVector[AnyRef] {
class MongoPersistentVector(id: String) extends PersistentVector[Array[Byte]] {
val uuid = id
val storage = MongoStorageBackend
}
class MongoPersistentRef(id: String) extends PersistentRef[AnyRef] {
class MongoPersistentRef(id: String) extends PersistentRef[Array[Byte]] {
val uuid = id
val storage = MongoStorageBackend
}

View file

@ -9,13 +9,8 @@ import se.scalablesolutions.akka.persistence.common._
import se.scalablesolutions.akka.util.Logging
import se.scalablesolutions.akka.config.Config.config
import sjson.json.Serializer._
import java.util.NoSuchElementException
import com.mongodb._
import java.util.{Map=>JMap, List=>JList, ArrayList=>JArrayList}
import com.novus.casbah.mongodb.Imports._
/**
* A module for supporting MongoDB based persistence.
@ -28,294 +23,208 @@ import java.util.{Map=>JMap, List=>JList, ArrayList=>JArrayList}
* @author <a href="http://debasishg.blogspot.com">Debasish Ghosh</a>
*/
private[akka] object MongoStorageBackend extends
MapStorageBackend[AnyRef, AnyRef] with
VectorStorageBackend[AnyRef] with
RefStorageBackend[AnyRef] with
MapStorageBackend[Array[Byte], Array[Byte]] with
VectorStorageBackend[Array[Byte]] with
RefStorageBackend[Array[Byte]] with
Logging {
// enrich with null safe findOne
class RichDBCollection(value: DBCollection) {
def findOneNS(o: DBObject): Option[DBObject] = {
value.findOne(o) match {
case null => None
case x => Some(x)
}
}
}
implicit def enrichDBCollection(c: DBCollection) = new RichDBCollection(c)
val KEY = "key"
val VALUE = "value"
val KEY = "__key"
val REF = "__ref"
val COLLECTION = "akka_coll"
val MONGODB_SERVER_HOSTNAME = config.getString("akka.storage.mongodb.hostname", "127.0.0.1")
val MONGODB_SERVER_DBNAME = config.getString("akka.storage.mongodb.dbname", "testdb")
val MONGODB_SERVER_PORT = config.getInt("akka.storage.mongodb.port", 27017)
val HOSTNAME = config.getString("akka.storage.mongodb.hostname", "127.0.0.1")
val DBNAME = config.getString("akka.storage.mongodb.dbname", "testdb")
val PORT = config.getInt("akka.storage.mongodb.port", 27017)
val db = new Mongo(MONGODB_SERVER_HOSTNAME, MONGODB_SERVER_PORT)
val coll = db.getDB(MONGODB_SERVER_DBNAME).getCollection(COLLECTION)
val db: MongoDB = MongoConnection(HOSTNAME, PORT)(DBNAME)
val coll: MongoCollection = db(COLLECTION)
private[this] val serializer = SJSON
def drop() { db.dropDatabase() }
def insertMapStorageEntryFor(name: String, key: AnyRef, value: AnyRef) {
def insertMapStorageEntryFor(name: String, key: Array[Byte], value: Array[Byte]) {
insertMapStorageEntriesFor(name, List((key, value)))
}
def insertMapStorageEntriesFor(name: String, entries: List[Tuple2[AnyRef, AnyRef]]) {
import java.util.{Map, HashMap}
val m: Map[AnyRef, AnyRef] = new HashMap
for ((k, v) <- entries) {
m.put(k, serializer.out(v))
}
nullSafeFindOne(name) match {
case None =>
coll.insert(new BasicDBObject().append(KEY, name).append(VALUE, m))
case Some(dbo) => {
// collate the maps
val o = dbo.get(VALUE).asInstanceOf[Map[AnyRef, AnyRef]]
o.putAll(m)
val newdbo = new BasicDBObject().append(KEY, name).append(VALUE, o)
coll.update(new BasicDBObject().append(KEY, name), newdbo, true, false)
def insertMapStorageEntriesFor(name: String, entries: List[(Array[Byte], Array[Byte])]) {
db.safely { db =>
val q: DBObject = MongoDBObject(KEY -> name)
coll.findOne(q) match {
case Some(dbo) =>
entries.foreach { case (k, v) => dbo += new String(k) -> v }
db.safely { db => coll.update(q, dbo, true, false) }
case None =>
val builder = MongoDBObject.newBuilder
builder += KEY -> name
entries.foreach { case (k, v) => builder += new String(k) -> v }
coll += builder.result.asDBObject
}
}
}
def removeMapStorageFor(name: String): Unit = {
val q = new BasicDBObject
q.put(KEY, name)
coll.remove(q)
val q: DBObject = MongoDBObject(KEY -> name)
db.safely { db => coll.remove(q) }
}
def removeMapStorageFor(name: String, key: AnyRef): Unit = {
nullSafeFindOne(name) match {
case None =>
case Some(dbo) => {
val orig = dbo.get(VALUE).asInstanceOf[DBObject].toMap
if (key.isInstanceOf[List[_]]) {
val keys = key.asInstanceOf[List[_]]
keys.foreach(k => orig.remove(k.asInstanceOf[String]))
} else {
orig.remove(key.asInstanceOf[String])
}
// remove existing reference
removeMapStorageFor(name)
// and insert
coll.insert(new BasicDBObject().append(KEY, name).append(VALUE, orig))
}
private def queryFor[T](name: String)(body: (MongoDBObject, Option[DBObject]) => T): T = {
val q = MongoDBObject(KEY -> name)
body(q, coll.findOne(q))
}
def removeMapStorageFor(name: String, key: Array[Byte]): Unit = queryFor(name) { (q, dbo) =>
dbo.foreach { d =>
d -= new String(key)
db.safely { db => coll.update(q, d, true, false) }
}
}
def getMapStorageEntryFor(name: String, key: AnyRef): Option[AnyRef] =
getValueForKey(name, key.asInstanceOf[String])
def getMapStorageSizeFor(name: String): Int = {
nullSafeFindOne(name) match {
case None => 0
case Some(dbo) =>
dbo.get(VALUE).asInstanceOf[JMap[String, AnyRef]].keySet.size
}
def getMapStorageEntryFor(name: String, key: Array[Byte]): Option[Array[Byte]] = queryFor(name) { (q, dbo) =>
dbo.map { d =>
d.getAs[Array[Byte]](new String(key))
}.getOrElse(None)
}
def getMapStorageFor(name: String): List[Tuple2[AnyRef, AnyRef]] = {
val m =
nullSafeFindOne(name) match {
case None =>
throw new NoSuchElementException(name + " not present")
case Some(dbo) =>
dbo.get(VALUE).asInstanceOf[JMap[String, AnyRef]]
}
val n =
List(m.keySet.toArray: _*).asInstanceOf[List[String]]
val vals =
for(s <- n)
yield (s, serializer.in[AnyRef](m.get(s).asInstanceOf[Array[Byte]]))
vals.asInstanceOf[List[Tuple2[String, AnyRef]]]
def getMapStorageSizeFor(name: String): Int = queryFor(name) { (q, dbo) =>
dbo.map { d =>
d.size - 2 // need to exclude object id and our KEY
}.getOrElse(0)
}
def getMapStorageRangeFor(name: String, start: Option[AnyRef],
finish: Option[AnyRef],
count: Int): List[Tuple2[AnyRef, AnyRef]] = {
val m =
nullSafeFindOne(name) match {
case None =>
throw new NoSuchElementException(name + " not present")
case Some(dbo) =>
dbo.get(VALUE).asInstanceOf[JMap[String, AnyRef]]
}
/**
* <tt>count</tt> is the max number of results to return. Start with
* <tt>start</tt> or 0 (if <tt>start</tt> is not defined) and go until
* you hit <tt>finish</tt> or <tt>count</tt>.
*/
val s = if (start.isDefined) start.get.asInstanceOf[Int] else 0
val cnt =
if (finish.isDefined) {
val f = finish.get.asInstanceOf[Int]
if (f >= s) math.min(count, (f - s)) else count
}
else count
val n =
List(m.keySet.toArray: _*).asInstanceOf[List[String]].sortWith((e1, e2) => (e1 compareTo e2) < 0).slice(s, s + cnt)
val vals =
for(s <- n)
yield (s, serializer.in[AnyRef](m.get(s).asInstanceOf[Array[Byte]]))
vals.asInstanceOf[List[Tuple2[String, AnyRef]]]
def getMapStorageFor(name: String): List[(Array[Byte], Array[Byte])] = queryFor(name) { (q, dbo) =>
dbo.map { d =>
for {
(k, v) <- d.toList
if k != "_id" && k != KEY
} yield (k.getBytes, v.asInstanceOf[Array[Byte]])
}.getOrElse(List.empty[(Array[Byte], Array[Byte])])
}
private def getValueForKey(name: String, key: String): Option[AnyRef] = {
try {
nullSafeFindOne(name) match {
case None => None
case Some(dbo) =>
Some(serializer.in[AnyRef](
dbo.get(VALUE)
.asInstanceOf[JMap[String, AnyRef]]
.get(key).asInstanceOf[Array[Byte]]))
}
} catch {
case e =>
throw new NoSuchElementException(e.getMessage)
}
def getMapStorageRangeFor(name: String, start: Option[Array[Byte]],
finish: Option[Array[Byte]],
count: Int): List[(Array[Byte], Array[Byte])] = queryFor(name) { (q, dbo) =>
dbo.map { d =>
// get all keys except the special ones
val keys = d.keys
.filter(k => k != "_id" && k != KEY)
.toList
.sortWith(_ < _)
// if the supplied start is not defined, get the head of keys
val s = start.map(new String(_)).getOrElse(keys.head)
// if the supplied finish is not defined, get the last element of keys
val f = finish.map(new String(_)).getOrElse(keys.last)
// slice from keys: both ends inclusive
val ks = keys.slice(keys.indexOf(s), scala.math.min(count, keys.indexOf(f) + 1))
ks.map(k => (k.getBytes, d.getAs[Array[Byte]](k).get))
}.getOrElse(List.empty[(Array[Byte], Array[Byte])])
}
def insertVectorStorageEntriesFor(name: String, elements: List[AnyRef]) = {
val q = new BasicDBObject
q.put(KEY, name)
val currentList =
coll.findOneNS(q) match {
case None =>
new JArrayList[AnyRef]
case Some(dbo) =>
dbo.get(VALUE).asInstanceOf[JArrayList[AnyRef]]
}
if (!currentList.isEmpty) {
// record exists
// remove before adding
coll.remove(q)
}
// add to the current list
elements.map(serializer.out(_)).foreach(currentList.add(_))
coll.insert(
new BasicDBObject()
.append(KEY, name)
.append(VALUE, currentList)
)
}
def insertVectorStorageEntryFor(name: String, element: AnyRef) = {
def insertVectorStorageEntryFor(name: String, element: Array[Byte]) = {
insertVectorStorageEntriesFor(name, List(element))
}
def getVectorStorageEntryFor(name: String, index: Int): AnyRef = {
try {
val o =
nullSafeFindOne(name) match {
case None =>
throw new NoSuchElementException(name + " not present")
def insertVectorStorageEntriesFor(name: String, elements: List[Array[Byte]]) = {
// lookup with name
val q: DBObject = MongoDBObject(KEY -> name)
case Some(dbo) =>
dbo.get(VALUE).asInstanceOf[JList[AnyRef]]
db.safely { db =>
coll.findOne(q) match {
// exists : need to update
case Some(dbo) =>
dbo -= KEY
dbo -= "_id"
val listBuilder = MongoDBList.newBuilder
// expensive!
listBuilder ++= (elements ++ dbo.toSeq.sortWith((e1, e2) => (e1._1.toInt < e2._1.toInt)).map(_._2))
val builder = MongoDBObject.newBuilder
builder += KEY -> name
builder ++= listBuilder.result
coll.update(q, builder.result.asDBObject, true, false)
// new : just add
case None =>
val listBuilder = MongoDBList.newBuilder
listBuilder ++= elements
val builder = MongoDBObject.newBuilder
builder += KEY -> name
builder ++= listBuilder.result
coll += builder.result.asDBObject
}
serializer.in[AnyRef](
o.get(index).asInstanceOf[Array[Byte]])
} catch {
case e =>
throw new NoSuchElementException(e.getMessage)
}
}
def getVectorStorageRangeFor(name: String,
start: Option[Int], finish: Option[Int], count: Int): List[AnyRef] = {
try {
val o =
nullSafeFindOne(name) match {
case None =>
throw new NoSuchElementException(name + " not present")
def updateVectorStorageEntryFor(name: String, index: Int, elem: Array[Byte]) = queryFor(name) { (q, dbo) =>
dbo.foreach { d =>
d += ((index.toString, elem))
db.safely { db => coll.update(q, d, true, false) }
}
}
case Some(dbo) =>
dbo.get(VALUE).asInstanceOf[JList[AnyRef]]
}
def getVectorStorageEntryFor(name: String, index: Int): Array[Byte] = queryFor(name) { (q, dbo) =>
dbo.map { d =>
d(index.toString).asInstanceOf[Array[Byte]]
}.getOrElse(Array.empty[Byte])
}
val s = if (start.isDefined) start.get else 0
/**
* if <tt>start</tt> and <tt>finish</tt> both are defined, ignore <tt>count</tt> and
* report the range [start, finish)
* if <tt>start</tt> is not defined, assume <tt>start</tt> = 0
* if <tt>start</tt> == 0 and <tt>finish</tt> == 0, return an empty collection
*/
def getVectorStorageRangeFor(name: String, start: Option[Int], finish: Option[Int], count: Int): List[Array[Byte]] = queryFor(name) { (q, dbo) =>
dbo.map { d =>
val ls = d.filter { case (k, v) => k != KEY && k != "_id" }
.toSeq
.sortWith((e1, e2) => (e1._1.toInt < e2._1.toInt))
.map(_._2)
val st = start.getOrElse(0)
val cnt =
if (finish.isDefined) {
val f = finish.get
if (f >= s) (f - s) else count
if (f >= st) (f - st) else count
}
else count
// pick the subrange and make a Scala list
val l =
List(o.subList(s, s + cnt).toArray: _*)
for(e <- l)
yield serializer.in[AnyRef](e.asInstanceOf[Array[Byte]])
} catch {
case e =>
throw new NoSuchElementException(e.getMessage)
}
if (st == 0 && cnt == 0) List()
ls.slice(st, st + cnt).asInstanceOf[List[Array[Byte]]]
}.getOrElse(List.empty[Array[Byte]])
}
def updateVectorStorageEntryFor(name: String, index: Int, elem: AnyRef) = {
val q = new BasicDBObject
q.put(KEY, name)
val dbobj =
coll.findOneNS(q) match {
case None =>
throw new NoSuchElementException(name + " not present")
case Some(dbo) => dbo
}
val currentList = dbobj.get(VALUE).asInstanceOf[JArrayList[AnyRef]]
currentList.set(index, serializer.out(elem))
coll.update(q,
new BasicDBObject().append(KEY, name).append(VALUE, currentList))
def getVectorStorageSizeFor(name: String): Int = queryFor(name) { (q, dbo) =>
dbo.map { d => d.size - 2 }.getOrElse(0)
}
def getVectorStorageSizeFor(name: String): Int = {
nullSafeFindOne(name) match {
case None => 0
case Some(dbo) =>
dbo.get(VALUE).asInstanceOf[JList[AnyRef]].size
}
}
def insertRefStorageFor(name: String, element: Array[Byte]) = {
// lookup with name
val q: DBObject = MongoDBObject(KEY -> name)
private def nullSafeFindOne(name: String): Option[DBObject] = {
val o = new BasicDBObject
o.put(KEY, name)
coll.findOneNS(o)
}
db.safely { db =>
coll.findOne(q) match {
// exists : need to update
case Some(dbo) =>
dbo += ((REF, element))
coll.update(q, dbo, true, false)
def insertRefStorageFor(name: String, element: AnyRef) = {
nullSafeFindOne(name) match {
case None =>
case Some(dbo) => {
val q = new BasicDBObject
q.put(KEY, name)
coll.remove(q)
// not found : make one
case None =>
val builder = MongoDBObject.newBuilder
builder += KEY -> name
builder += REF -> element
coll += builder.result.asDBObject
}
}
coll.insert(
new BasicDBObject()
.append(KEY, name)
.append(VALUE, serializer.out(element)))
}
def getRefStorageFor(name: String): Option[AnyRef] = {
nullSafeFindOne(name) match {
case None => None
case Some(dbo) =>
Some(serializer.in[AnyRef](dbo.get(VALUE).asInstanceOf[Array[Byte]]))
}
def getRefStorageFor(name: String): Option[Array[Byte]] = queryFor(name) { (q, dbo) =>
dbo.map { d =>
d.getAs[Array[Byte]](REF)
}.getOrElse(None)
}
}

View file

@ -1,32 +1,19 @@
package se.scalablesolutions.akka.persistence.mongo
import org.junit.{Test, Before}
import org.junit.Assert._
import org.scalatest.junit.JUnitSuite
import _root_.dispatch.json.{JsNumber, JsValue}
import _root_.dispatch.json.Js._
import org.scalatest.Spec
import org.scalatest.matchers.ShouldMatchers
import org.scalatest.BeforeAndAfterEach
import org.scalatest.junit.JUnitRunner
import org.junit.runner.RunWith
import se.scalablesolutions.akka.actor.{Transactor, Actor, ActorRef}
import Actor._
/**
* A persistent actor based on MongoDB storage.
* <p/>
* Demonstrates a bank account operation consisting of messages that:
* <li>checks balance <tt>Balance</tt></li>
* <li>debits amount<tt>Debit</tt></li>
* <li>debits multiple amounts<tt>MultiDebit</tt></li>
* <li>credits amount<tt>Credit</tt></li>
* <p/>
* Needs a running Mongo server.
* @author <a href="http://debasishg.blogspot.com">Debasish Ghosh</a>
*/
case class Balance(accountNo: String)
case class Debit(accountNo: String, amount: BigInt, failer: ActorRef)
case class MultiDebit(accountNo: String, amounts: List[BigInt], failer: ActorRef)
case class Credit(accountNo: String, amount: BigInt)
case class Debit(accountNo: String, amount: Int, failer: ActorRef)
case class MultiDebit(accountNo: String, amounts: List[Int], failer: ActorRef)
case class Credit(accountNo: String, amount: Int)
case class Log(start: Int, finish: Int)
case object LogSize
@ -35,63 +22,65 @@ class BankAccountActor extends Transactor {
private lazy val accountState = MongoStorage.newMap
private lazy val txnLog = MongoStorage.newVector
import sjson.json.DefaultProtocol._
import sjson.json.JsonSerialization._
def receive: Receive = {
// check balance
case Balance(accountNo) =>
txnLog.add("Balance:" + accountNo)
self.reply(accountState.get(accountNo).get)
txnLog.add(("Balance:" + accountNo).getBytes)
self.reply(
accountState.get(accountNo.getBytes)
.map(frombinary[Int](_))
.getOrElse(0))
// debit amount: can fail
case Debit(accountNo, amount, failer) =>
txnLog.add("Debit:" + accountNo + " " + amount)
txnLog.add(("Debit:" + accountNo + " " + amount).getBytes)
val m = accountState.get(accountNo.getBytes)
.map(frombinary[Int](_))
.getOrElse(0)
accountState.put(accountNo.getBytes, tobinary(m - amount))
if (amount > m) failer !! "Failure"
val m: BigInt =
accountState.get(accountNo) match {
case Some(JsNumber(n)) =>
BigInt(n.asInstanceOf[BigDecimal].intValue)
case None => 0
}
accountState.put(accountNo, (m - amount))
if (amount > m)
failer !! "Failure"
self.reply(m - amount)
// many debits: can fail
// demonstrates true rollback even if multiple puts have been done
case MultiDebit(accountNo, amounts, failer) =>
txnLog.add("MultiDebit:" + accountNo + " " + amounts.map(_.intValue).foldLeft(0)(_ + _))
val sum = amounts.foldRight(0)(_ + _)
txnLog.add(("MultiDebit:" + accountNo + " " + sum).getBytes)
val m: BigInt =
accountState.get(accountNo) match {
case Some(JsNumber(n)) => BigInt(n.toString)
case None => 0
val m = accountState.get(accountNo.getBytes)
.map(frombinary[Int](_))
.getOrElse(0)
var cbal = m
amounts.foreach { amount =>
accountState.put(accountNo.getBytes, tobinary(m - amount))
cbal = cbal - amount
if (cbal < 0) failer !! "Failure"
}
var bal: BigInt = 0
amounts.foreach {amount =>
bal = bal + amount
accountState.put(accountNo, (m - bal))
}
if (bal > m) failer !! "Failure"
self.reply(m - bal)
self.reply(m - sum)
// credit amount
case Credit(accountNo, amount) =>
txnLog.add("Credit:" + accountNo + " " + amount)
txnLog.add(("Credit:" + accountNo + " " + amount).getBytes)
val m = accountState.get(accountNo.getBytes)
.map(frombinary[Int](_))
.getOrElse(0)
accountState.put(accountNo.getBytes, tobinary(m + amount))
val m: BigInt =
accountState.get(accountNo) match {
case Some(JsNumber(n)) =>
BigInt(n.asInstanceOf[BigDecimal].intValue)
case None => 0
}
accountState.put(accountNo, (m + amount))
self.reply(m + amount)
case LogSize =>
self.reply(txnLog.length.asInstanceOf[AnyRef])
self.reply(txnLog.length)
case Log(start, finish) =>
self.reply(txnLog.slice(start, finish))
self.reply(txnLog.slice(start, finish).map(new String(_)))
}
}
@ -102,82 +91,71 @@ class BankAccountActor extends Transactor {
}
}
class MongoPersistentActorSpec extends JUnitSuite {
@Test
def testSuccessfulDebit = {
val bactor = actorOf[BankAccountActor]
bactor.start
val failer = actorOf[PersistentFailerActor]
failer.start
bactor !! Credit("a-123", 5000)
bactor !! Debit("a-123", 3000, failer)
@RunWith(classOf[JUnitRunner])
class MongoPersistentActorSpec extends
Spec with
ShouldMatchers with
BeforeAndAfterEach {
val JsNumber(b) = (bactor !! Balance("a-123")).get.asInstanceOf[JsValue]
assertEquals(BigInt(2000), BigInt(b.intValue))
bactor !! Credit("a-123", 7000)
val JsNumber(b1) = (bactor !! Balance("a-123")).get.asInstanceOf[JsValue]
assertEquals(BigInt(9000), BigInt(b1.intValue))
bactor !! Debit("a-123", 8000, failer)
val JsNumber(b2) = (bactor !! Balance("a-123")).get.asInstanceOf[JsValue]
assertEquals(BigInt(1000), BigInt(b2.intValue))
assert(7 == (bactor !! LogSize).get.asInstanceOf[Int])
import scala.collection.mutable.ArrayBuffer
assert((bactor !! Log(0, 7)).get.asInstanceOf[ArrayBuffer[String]].size == 7)
assert((bactor !! Log(0, 0)).get.asInstanceOf[ArrayBuffer[String]].size == 0)
assert((bactor !! Log(1, 2)).get.asInstanceOf[ArrayBuffer[String]].size == 1)
assert((bactor !! Log(6, 7)).get.asInstanceOf[ArrayBuffer[String]].size == 1)
assert((bactor !! Log(0, 1)).get.asInstanceOf[ArrayBuffer[String]].size == 1)
override def beforeEach {
MongoStorageBackend.drop
}
@Test
def testUnsuccessfulDebit = {
val bactor = actorOf[BankAccountActor]
bactor.start
bactor !! Credit("a-123", 5000)
val JsNumber(b) = (bactor !! Balance("a-123")).get.asInstanceOf[JsValue]
assertEquals(BigInt(5000), BigInt(b.intValue))
val failer = actorOf[PersistentFailerActor]
failer.start
try {
bactor !! Debit("a-123", 7000, failer)
fail("should throw exception")
} catch { case e: RuntimeException => {}}
val JsNumber(b1) = (bactor !! Balance("a-123")).get.asInstanceOf[JsValue]
assertEquals(BigInt(5000), BigInt(b1.intValue))
// should not count the failed one
assert(3 == (bactor !! LogSize).get.asInstanceOf[Int])
override def afterEach {
MongoStorageBackend.drop
}
@Test
def testUnsuccessfulMultiDebit = {
val bactor = actorOf[BankAccountActor]
bactor.start
bactor !! Credit("a-123", 5000)
describe("successful debit") {
it("should debit successfully") {
val bactor = actorOf[BankAccountActor]
bactor.start
val failer = actorOf[PersistentFailerActor]
failer.start
bactor !! Credit("a-123", 5000)
bactor !! Debit("a-123", 3000, failer)
val JsNumber(b) = (bactor !! Balance("a-123")).get.asInstanceOf[JsValue]
assertEquals(BigInt(5000), BigInt(b.intValue))
(bactor !! Balance("a-123")).get.asInstanceOf[Int] should equal(2000)
val failer = actorOf[PersistentFailerActor]
failer.start
try {
bactor !! MultiDebit("a-123", List(500, 2000, 1000, 3000), failer)
fail("should throw exception")
} catch { case e: RuntimeException => {}}
bactor !! Credit("a-123", 7000)
(bactor !! Balance("a-123")).get.asInstanceOf[Int] should equal(9000)
val JsNumber(b1) = (bactor !! Balance("a-123")).get.asInstanceOf[JsValue]
assertEquals(BigInt(5000), BigInt(b1.intValue))
bactor !! Debit("a-123", 8000, failer)
(bactor !! Balance("a-123")).get.asInstanceOf[Int] should equal(1000)
// should not count the failed one
assert(3 == (bactor !! LogSize).get.asInstanceOf[Int])
(bactor !! LogSize).get.asInstanceOf[Int] should equal(7)
(bactor !! Log(0, 7)).get.asInstanceOf[Iterable[String]].size should equal(7)
}
}
describe("unsuccessful debit") {
it("debit should fail") {
val bactor = actorOf[BankAccountActor]
bactor.start
val failer = actorOf[PersistentFailerActor]
failer.start
bactor !! Credit("a-123", 5000)
(bactor !! Balance("a-123")).get.asInstanceOf[Int] should equal(5000)
evaluating {
bactor !! Debit("a-123", 7000, failer)
} should produce [Exception]
(bactor !! Balance("a-123")).get.asInstanceOf[Int] should equal(5000)
(bactor !! LogSize).get.asInstanceOf[Int] should equal(3)
}
}
describe("unsuccessful multidebit") {
it("multidebit should fail") {
val bactor = actorOf[BankAccountActor]
bactor.start
val failer = actorOf[PersistentFailerActor]
failer.start
bactor !! Credit("a-123", 5000)
(bactor !! Balance("a-123")).get.asInstanceOf[Int] should equal(5000)
evaluating {
bactor !! MultiDebit("a-123", List(1000, 2000, 4000), failer)
} should produce [Exception]
(bactor !! Balance("a-123")).get.asInstanceOf[Int] should equal(5000)
(bactor !! LogSize).get.asInstanceOf[Int] should equal(3)
}
}
}

View file

@ -1,364 +1,158 @@
package se.scalablesolutions.akka.persistence.mongo
import org.junit.{Test, Before}
import org.junit.Assert._
import org.scalatest.junit.JUnitSuite
import _root_.dispatch.json._
import _root_.dispatch.json.Js._
import org.scalatest.Spec
import org.scalatest.matchers.ShouldMatchers
import org.scalatest.BeforeAndAfterEach
import org.scalatest.junit.JUnitRunner
import org.junit.runner.RunWith
import java.util.NoSuchElementException
@scala.reflect.BeanInfo case class Foo(no: Int, name: String)
class MongoStorageSpec extends JUnitSuite {
@RunWith(classOf[JUnitRunner])
class MongoStorageSpec extends
Spec with
ShouldMatchers with
BeforeAndAfterEach {
val changeSetV = new scala.collection.mutable.ArrayBuffer[AnyRef]
val changeSetM = new scala.collection.mutable.HashMap[AnyRef, AnyRef]
@Before def initialize() = {
MongoStorageBackend.coll.drop
override def beforeEach {
MongoStorageBackend.drop
}
@Test
def testVectorInsertForTransactionId = {
changeSetV += "debasish" // string
changeSetV += List(1, 2, 3) // Scala List
changeSetV += List(100, 200)
MongoStorageBackend.insertVectorStorageEntriesFor("U-A1", changeSetV.toList)
assertEquals(
3,
MongoStorageBackend.getVectorStorageSizeFor("U-A1"))
changeSetV.clear
// changeSetV should be reinitialized
changeSetV += List(12, 23, 45)
changeSetV += "maulindu"
MongoStorageBackend.insertVectorStorageEntriesFor("U-A1", changeSetV.toList)
assertEquals(
5,
MongoStorageBackend.getVectorStorageSizeFor("U-A1"))
// add more to the same changeSetV
changeSetV += "ramanendu"
changeSetV += Map(1 -> "dg", 2 -> "mc")
// add for a diff transaction
MongoStorageBackend.insertVectorStorageEntriesFor("U-A2", changeSetV.toList)
assertEquals(
4,
MongoStorageBackend.getVectorStorageSizeFor("U-A2"))
// previous transaction change set should remain same
assertEquals(
5,
MongoStorageBackend.getVectorStorageSizeFor("U-A1"))
// test single element entry
MongoStorageBackend.insertVectorStorageEntryFor("U-A1", Map(1->1, 2->4, 3->9))
assertEquals(
6,
MongoStorageBackend.getVectorStorageSizeFor("U-A1"))
override def afterEach {
MongoStorageBackend.drop
}
@Test
def testVectorFetchForKeys = {
describe("persistent maps") {
it("should insert with single key and value") {
import MongoStorageBackend._
// initially everything 0
assertEquals(
0,
MongoStorageBackend.getVectorStorageSizeFor("U-A2"))
assertEquals(
0,
MongoStorageBackend.getVectorStorageSizeFor("U-A1"))
// get some stuff
changeSetV += "debasish"
changeSetV += List(BigDecimal(12), BigDecimal(13), BigDecimal(14))
MongoStorageBackend.insertVectorStorageEntriesFor("U-A1", changeSetV.toList)
assertEquals(
2,
MongoStorageBackend.getVectorStorageSizeFor("U-A1"))
val JsString(str) = MongoStorageBackend.getVectorStorageEntryFor("U-A1", 0).asInstanceOf[JsString]
assertEquals("debasish", str)
val l = MongoStorageBackend.getVectorStorageEntryFor("U-A1", 1).asInstanceOf[JsValue]
val num_list = list ! num
val num_list(l0) = l
assertEquals(List(12, 13, 14), l0)
changeSetV.clear
changeSetV += Map(1->1, 2->4, 3->9)
changeSetV += BigInt(2310)
changeSetV += List(100, 200, 300)
MongoStorageBackend.insertVectorStorageEntriesFor("U-A1", changeSetV.toList)
assertEquals(
5,
MongoStorageBackend.getVectorStorageSizeFor("U-A1"))
val r =
MongoStorageBackend.getVectorStorageRangeFor("U-A1", Some(1), None, 3)
assertEquals(3, r.size)
val lr = r(0).asInstanceOf[JsValue]
val num_list(l1) = lr
assertEquals(List(12, 13, 14), l1)
}
@Test
def testVectorFetchForNonExistentKeys = {
try {
MongoStorageBackend.getVectorStorageEntryFor("U-A1", 1)
fail("should throw an exception")
} catch {case e: NoSuchElementException => {}}
try {
MongoStorageBackend.getVectorStorageRangeFor("U-A1", Some(2), None, 12)
fail("should throw an exception")
} catch {case e: NoSuchElementException => {}}
}
@Test
def testVectorUpdateForTransactionId = {
import MongoStorageBackend._
changeSetV += "debasish" // string
changeSetV += List(1, 2, 3) // Scala List
changeSetV += List(100, 200)
insertVectorStorageEntriesFor("U-A1", changeSetV.toList)
assertEquals(3, getVectorStorageSizeFor("U-A1"))
updateVectorStorageEntryFor("U-A1", 0, "maulindu")
val JsString(str) = getVectorStorageEntryFor("U-A1", 0).asInstanceOf[JsString]
assertEquals("maulindu", str)
updateVectorStorageEntryFor("U-A1", 1, Map("1"->"dg", "2"->"mc"))
val JsObject(m) = getVectorStorageEntryFor("U-A1", 1).asInstanceOf[JsObject]
assertEquals(m.keySet.size, 2)
}
@Test
def testMapInsertForTransactionId = {
fillMap
// add some more to changeSet
changeSetM += "5" -> Foo(12, "dg")
changeSetM += "6" -> java.util.Calendar.getInstance.getTime
// insert all into Mongo
MongoStorageBackend.insertMapStorageEntriesFor("U-M1", changeSetM.toList)
assertEquals(
6,
MongoStorageBackend.getMapStorageSizeFor("U-M1"))
// individual insert api
MongoStorageBackend.insertMapStorageEntryFor("U-M1", "7", "akka")
MongoStorageBackend.insertMapStorageEntryFor("U-M1", "8", List(23, 25))
assertEquals(
8,
MongoStorageBackend.getMapStorageSizeFor("U-M1"))
// add the same changeSet for another transaction
MongoStorageBackend.insertMapStorageEntriesFor("U-M2", changeSetM.toList)
assertEquals(
6,
MongoStorageBackend.getMapStorageSizeFor("U-M2"))
// the first transaction should remain the same
assertEquals(
8,
MongoStorageBackend.getMapStorageSizeFor("U-M1"))
changeSetM.clear
}
@Test
def testMapContents = {
fillMap
MongoStorageBackend.insertMapStorageEntriesFor("U-M1", changeSetM.toList)
MongoStorageBackend.getMapStorageEntryFor("U-M1", "2") match {
case Some(x) => {
val JsString(str) = x.asInstanceOf[JsValue]
assertEquals("peter", str)
}
case None => fail("should fetch peter")
}
MongoStorageBackend.getMapStorageEntryFor("U-M1", "4") match {
case Some(x) => {
val num_list = list ! num
val num_list(l0) = x.asInstanceOf[JsValue]
assertEquals(3, l0.size)
}
case None => fail("should fetch list")
}
MongoStorageBackend.getMapStorageEntryFor("U-M1", "3") match {
case Some(x) => {
val num_list = list ! num
val num_list(l0) = x.asInstanceOf[JsValue]
assertEquals(2, l0.size)
}
case None => fail("should fetch list")
insertMapStorageEntryFor("t1", "odersky".getBytes, "scala".getBytes)
insertMapStorageEntryFor("t1", "gosling".getBytes, "java".getBytes)
insertMapStorageEntryFor("t1", "stroustrup".getBytes, "c++".getBytes)
getMapStorageSizeFor("t1") should equal(3)
new String(getMapStorageEntryFor("t1", "odersky".getBytes).get) should equal("scala")
new String(getMapStorageEntryFor("t1", "gosling".getBytes).get) should equal("java")
new String(getMapStorageEntryFor("t1", "stroustrup".getBytes).get) should equal("c++")
getMapStorageEntryFor("t1", "torvalds".getBytes) should equal(None)
}
// get the entire map
val l: List[Tuple2[AnyRef, AnyRef]] =
MongoStorageBackend.getMapStorageFor("U-M1")
it("should insert with multiple keys and values") {
import MongoStorageBackend._
assertEquals(4, l.size)
assertTrue(l.map(_._1).contains("1"))
assertTrue(l.map(_._1).contains("2"))
assertTrue(l.map(_._1).contains("3"))
assertTrue(l.map(_._1).contains("4"))
val l = List(("stroustrup", "c++"), ("odersky", "scala"), ("gosling", "java"))
insertMapStorageEntriesFor("t1", l.map { case (k, v) => (k.getBytes, v.getBytes) })
getMapStorageSizeFor("t1") should equal(3)
new String(getMapStorageEntryFor("t1", "stroustrup".getBytes).get) should equal("c++")
new String(getMapStorageEntryFor("t1", "gosling".getBytes).get) should equal("java")
new String(getMapStorageEntryFor("t1", "odersky".getBytes).get) should equal("scala")
getMapStorageEntryFor("t1", "torvalds".getBytes) should equal(None)
val JsString(str) = l.filter(_._1 == "2").head._2
assertEquals(str, "peter")
getMapStorageEntryFor("t2", "torvalds".getBytes) should equal(None)
// trying to fetch for a non-existent transaction will throw
try {
MongoStorageBackend.getMapStorageFor("U-M2")
fail("should throw an exception")
} catch {case e: NoSuchElementException => {}}
getMapStorageFor("t1").map { case (k, v) => (new String(k), new String(v)) } should equal (l)
changeSetM.clear
}
removeMapStorageFor("t1", "gosling".getBytes)
getMapStorageSizeFor("t1") should equal(2)
@Test
def testMapContentsByRange = {
fillMap
changeSetM += "5" -> Map(1 -> "dg", 2 -> "mc")
MongoStorageBackend.insertMapStorageEntriesFor("U-M1", changeSetM.toList)
// specify start and count
val l: List[Tuple2[AnyRef, AnyRef]] =
MongoStorageBackend.getMapStorageRangeFor(
"U-M1", Some(Integer.valueOf(2)), None, 3)
assertEquals(3, l.size)
assertEquals("3", l(0)._1.asInstanceOf[String])
val lst = l(0)._2.asInstanceOf[JsValue]
val num_list = list ! num
val num_list(l0) = lst
assertEquals(List(100, 200), l0)
assertEquals("4", l(1)._1.asInstanceOf[String])
val ls = l(1)._2.asInstanceOf[JsValue]
val num_list(l1) = ls
assertEquals(List(10, 20, 30), l1)
// specify start, finish and count where finish - start == count
assertEquals(3,
MongoStorageBackend.getMapStorageRangeFor(
"U-M1", Some(Integer.valueOf(2)), Some(Integer.valueOf(5)), 3).size)
// specify start, finish and count where finish - start > count
assertEquals(3,
MongoStorageBackend.getMapStorageRangeFor(
"U-M1", Some(Integer.valueOf(2)), Some(Integer.valueOf(9)), 3).size)
// do not specify start or finish
assertEquals(3,
MongoStorageBackend.getMapStorageRangeFor(
"U-M1", None, None, 3).size)
// specify finish and count
assertEquals(3,
MongoStorageBackend.getMapStorageRangeFor(
"U-M1", None, Some(Integer.valueOf(3)), 3).size)
// specify start, finish and count where finish < start
assertEquals(3,
MongoStorageBackend.getMapStorageRangeFor(
"U-M1", Some(Integer.valueOf(2)), Some(Integer.valueOf(1)), 3).size)
changeSetM.clear
}
@Test
def testMapStorageRemove = {
fillMap
changeSetM += "5" -> Map(1 -> "dg", 2 -> "mc")
MongoStorageBackend.insertMapStorageEntriesFor("U-M1", changeSetM.toList)
assertEquals(5,
MongoStorageBackend.getMapStorageSizeFor("U-M1"))
// remove key "3"
MongoStorageBackend.removeMapStorageFor("U-M1", "3")
assertEquals(4,
MongoStorageBackend.getMapStorageSizeFor("U-M1"))
try {
MongoStorageBackend.getMapStorageEntryFor("U-M1", "3")
fail("should throw exception")
} catch { case e => {}}
// remove key "4"
MongoStorageBackend.removeMapStorageFor("U-M1", "4")
assertEquals(3,
MongoStorageBackend.getMapStorageSizeFor("U-M1"))
// remove key "2"
MongoStorageBackend.removeMapStorageFor("U-M1", "2")
assertEquals(2,
MongoStorageBackend.getMapStorageSizeFor("U-M1"))
// remove the whole stuff
MongoStorageBackend.removeMapStorageFor("U-M1")
try {
MongoStorageBackend.getMapStorageFor("U-M1")
fail("should throw exception")
} catch { case e: NoSuchElementException => {}}
changeSetM.clear
}
private def fillMap = {
changeSetM += "1" -> "john"
changeSetM += "2" -> "peter"
changeSetM += "3" -> List(100, 200)
changeSetM += "4" -> List(10, 20, 30)
changeSetM
}
@Test
def testRefStorage = {
MongoStorageBackend.getRefStorageFor("U-R1") match {
case None =>
case Some(o) => fail("should be None")
removeMapStorageFor("t1")
getMapStorageSizeFor("t1") should equal(0)
}
val m = Map("1"->1, "2"->4, "3"->9)
MongoStorageBackend.insertRefStorageFor("U-R1", m)
MongoStorageBackend.getRefStorageFor("U-R1") match {
case None => fail("should not be empty")
case Some(r) => {
val a = r.asInstanceOf[JsValue]
val m1 = Symbol("1") ? num
val m2 = Symbol("2") ? num
val m3 = Symbol("3") ? num
it("should do proper range queries") {
import MongoStorageBackend._
val l = List(
("bjarne stroustrup", "c++"),
("martin odersky", "scala"),
("james gosling", "java"),
("yukihiro matsumoto", "ruby"),
("slava pestov", "factor"),
("rich hickey", "clojure"),
("ola bini", "ioke"),
("dennis ritchie", "c"),
("larry wall", "perl"),
("guido van rossum", "python"),
("james strachan", "groovy"))
insertMapStorageEntriesFor("t1", l.map { case (k, v) => (k.getBytes, v.getBytes) })
getMapStorageSizeFor("t1") should equal(l.size)
getMapStorageRangeFor("t1", None, None, 100).map { case (k, v) => (new String(k), new String(v)) } should equal(l.sortWith(_._1 < _._1))
getMapStorageRangeFor("t1", None, None, 5).map { case (k, v) => (new String(k), new String(v)) }.size should equal(5)
}
}
val m1(n1) = a
val m2(n2) = a
val m3(n3) = a
describe("persistent vectors") {
it("should insert a single value") {
import MongoStorageBackend._
assertEquals(n1, 1)
assertEquals(n2, 4)
assertEquals(n3, 9)
}
insertVectorStorageEntryFor("t1", "martin odersky".getBytes)
insertVectorStorageEntryFor("t1", "james gosling".getBytes)
new String(getVectorStorageEntryFor("t1", 0)) should equal("james gosling")
new String(getVectorStorageEntryFor("t1", 1)) should equal("martin odersky")
}
// insert another one
// the previous one should be replaced
val b = List("100", "jonas")
MongoStorageBackend.insertRefStorageFor("U-R1", b)
MongoStorageBackend.getRefStorageFor("U-R1") match {
case None => fail("should not be empty")
case Some(r) => {
val a = r.asInstanceOf[JsValue]
val str_lst = list ! str
val str_lst(l) = a
assertEquals(b, l)
}
it("should insert multiple values") {
import MongoStorageBackend._
insertVectorStorageEntryFor("t1", "martin odersky".getBytes)
insertVectorStorageEntryFor("t1", "james gosling".getBytes)
insertVectorStorageEntriesFor("t1", List("ola bini".getBytes, "james strachan".getBytes, "dennis ritchie".getBytes))
new String(getVectorStorageEntryFor("t1", 0)) should equal("ola bini")
new String(getVectorStorageEntryFor("t1", 1)) should equal("james strachan")
new String(getVectorStorageEntryFor("t1", 2)) should equal("dennis ritchie")
new String(getVectorStorageEntryFor("t1", 3)) should equal("james gosling")
new String(getVectorStorageEntryFor("t1", 4)) should equal("martin odersky")
}
it("should fetch a range of values") {
import MongoStorageBackend._
insertVectorStorageEntryFor("t1", "martin odersky".getBytes)
insertVectorStorageEntryFor("t1", "james gosling".getBytes)
getVectorStorageSizeFor("t1") should equal(2)
insertVectorStorageEntriesFor("t1", List("ola bini".getBytes, "james strachan".getBytes, "dennis ritchie".getBytes))
getVectorStorageRangeFor("t1", None, None, 100).map(new String(_)) should equal(List("ola bini", "james strachan", "dennis ritchie", "james gosling", "martin odersky"))
getVectorStorageRangeFor("t1", Some(0), Some(5), 100).map(new String(_)) should equal(List("ola bini", "james strachan", "dennis ritchie", "james gosling", "martin odersky"))
getVectorStorageRangeFor("t1", Some(2), Some(5), 100).map(new String(_)) should equal(List("dennis ritchie", "james gosling", "martin odersky"))
getVectorStorageSizeFor("t1") should equal(5)
}
it("should insert and query complex structures") {
import MongoStorageBackend._
import sjson.json.DefaultProtocol._
import sjson.json.JsonSerialization._
// a list[AnyRef] should be added successfully
val l = List("ola bini".getBytes, tobinary(List(100, 200, 300)), tobinary(List(1, 2, 3)))
// for id = t1
insertVectorStorageEntriesFor("t1", l)
new String(getVectorStorageEntryFor("t1", 0)) should equal("ola bini")
frombinary[List[Int]](getVectorStorageEntryFor("t1", 1)) should equal(List(100, 200, 300))
frombinary[List[Int]](getVectorStorageEntryFor("t1", 2)) should equal(List(1, 2, 3))
getVectorStorageSizeFor("t1") should equal(3)
// some more for id = t1
val m = List(tobinary(Map(1 -> "dg", 2 -> "mc", 3 -> "nd")), tobinary(List("martin odersky", "james gosling")))
insertVectorStorageEntriesFor("t1", m)
// size should add up
getVectorStorageSizeFor("t1") should equal(5)
// now for a diff id
insertVectorStorageEntriesFor("t2", l)
getVectorStorageSizeFor("t2") should equal(3)
}
}
describe("persistent refs") {
it("should insert a ref") {
import MongoStorageBackend._
insertRefStorageFor("t1", "martin odersky".getBytes)
new String(getRefStorageFor("t1").get) should equal("martin odersky")
insertRefStorageFor("t1", "james gosling".getBytes)
new String(getRefStorageFor("t1").get) should equal("james gosling")
getRefStorageFor("t2") should equal(None)
}
}
}

View file

@ -0,0 +1,347 @@
package se.scalablesolutions.akka.persistence.mongo
import org.scalatest.Spec
import org.scalatest.matchers.ShouldMatchers
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
import org.scalatest.junit.JUnitRunner
import org.junit.runner.RunWith
import se.scalablesolutions.akka.actor.{Actor, ActorRef}
import se.scalablesolutions.akka.config.OneForOneStrategy
import Actor._
import se.scalablesolutions.akka.stm.global._
import se.scalablesolutions.akka.config.ScalaConfig._
import se.scalablesolutions.akka.util.Logging
import MongoStorageBackend._
case class GET(k: String)
case class SET(k: String, v: String)
case class REM(k: String)
case class CONTAINS(k: String)
case object MAP_SIZE
case class MSET(kvs: List[(String, String)])
case class REMOVE_AFTER_PUT(kvsToAdd: List[(String, String)], ksToRem: List[String])
case class CLEAR_AFTER_PUT(kvsToAdd: List[(String, String)])
case class PUT_WITH_SLICE(kvsToAdd: List[(String, String)], start: String, cnt: Int)
case class PUT_REM_WITH_SLICE(kvsToAdd: List[(String, String)], ksToRem: List[String], start: String, cnt: Int)
case class VADD(v: String)
case class VUPD(i: Int, v: String)
case class VUPD_AND_ABORT(i: Int, v: String)
case class VGET(i: Int)
case object VSIZE
case class VGET_AFTER_VADD(vsToAdd: List[String], isToFetch: List[Int])
case class VADD_WITH_SLICE(vsToAdd: List[String], start: Int, cnt: Int)
object Storage {
class MongoSampleMapStorage extends Actor {
self.lifeCycle = Some(LifeCycle(Permanent))
val FOO_MAP = "akka.sample.map"
private var fooMap = atomic { MongoStorage.getMap(FOO_MAP) }
def receive = {
case SET(k, v) =>
atomic {
fooMap += (k.getBytes, v.getBytes)
}
self.reply((k, v))
case GET(k) =>
val v = atomic {
fooMap.get(k.getBytes).map(new String(_)).getOrElse(k + " Not found")
}
self.reply(v)
case REM(k) =>
val v = atomic {
fooMap -= k.getBytes
}
self.reply(k)
case CONTAINS(k) =>
val v = atomic {
fooMap contains k.getBytes
}
self.reply(v)
case MAP_SIZE =>
val v = atomic {
fooMap.size
}
self.reply(v)
case MSET(kvs) => atomic {
kvs.foreach {kv => fooMap += (kv._1.getBytes, kv._2.getBytes) }
}
self.reply(kvs.size)
case REMOVE_AFTER_PUT(kvs2add, ks2rem) => atomic {
kvs2add.foreach {kv =>
fooMap += (kv._1.getBytes, kv._2.getBytes)
}
ks2rem.foreach {k =>
fooMap -= k.getBytes
}}
self.reply(fooMap.size)
case CLEAR_AFTER_PUT(kvs2add) => atomic {
kvs2add.foreach {kv =>
fooMap += (kv._1.getBytes, kv._2.getBytes)
}
fooMap.clear
}
self.reply(true)
case PUT_WITH_SLICE(kvs2add, from, cnt) =>
val v = atomic {
kvs2add.foreach {kv =>
fooMap += (kv._1.getBytes, kv._2.getBytes)
}
fooMap.slice(Some(from.getBytes), cnt)
}
self.reply(v: List[(Array[Byte], Array[Byte])])
case PUT_REM_WITH_SLICE(kvs2add, ks2rem, from, cnt) =>
val v = atomic {
kvs2add.foreach {kv =>
fooMap += (kv._1.getBytes, kv._2.getBytes)
}
ks2rem.foreach {k =>
fooMap -= k.getBytes
}
fooMap.slice(Some(from.getBytes), cnt)
}
self.reply(v: List[(Array[Byte], Array[Byte])])
}
}
class MongoSampleVectorStorage extends Actor {
self.lifeCycle = Some(LifeCycle(Permanent))
val FOO_VECTOR = "akka.sample.vector"
private var fooVector = atomic { MongoStorage.getVector(FOO_VECTOR) }
def receive = {
case VADD(v) =>
val size =
atomic {
fooVector + v.getBytes
fooVector length
}
self.reply(size)
case VGET(index) =>
val ind =
atomic {
fooVector get index
}
self.reply(ind)
case VGET_AFTER_VADD(vs, is) =>
val els =
atomic {
vs.foreach(fooVector + _.getBytes)
(is.foldRight(List[Array[Byte]]())(fooVector.get(_) :: _)).map(new String(_))
}
self.reply(els)
case VUPD_AND_ABORT(index, value) =>
val l =
atomic {
fooVector.update(index, value.getBytes)
// force fail
fooVector get 100
}
self.reply(index)
case VADD_WITH_SLICE(vs, s, c) =>
val l =
atomic {
vs.foreach(fooVector + _.getBytes)
fooVector.slice(Some(s), None, c)
}
self.reply(l.map(new String(_)))
}
}
}
import Storage._
@RunWith(classOf[JUnitRunner])
class MongoTicket343Spec extends
Spec with
ShouldMatchers with
BeforeAndAfterAll with
BeforeAndAfterEach {
override def beforeAll {
MongoStorageBackend.drop
println("** destroyed database")
}
override def beforeEach {
MongoStorageBackend.drop
println("** destroyed database")
}
override def afterEach {
MongoStorageBackend.drop
println("** destroyed database")
}
describe("Ticket 343 Issue #1") {
it("remove after put should work within the same transaction") {
val proc = actorOf[MongoSampleMapStorage]
proc.start
(proc !! SET("debasish", "anshinsoft")).getOrElse("Set failed") should equal(("debasish", "anshinsoft"))
(proc !! GET("debasish")).getOrElse("Get failed") should equal("anshinsoft")
(proc !! MAP_SIZE).getOrElse("Size failed") should equal(1)
(proc !! MSET(List(("dg", "1"), ("mc", "2"), ("nd", "3")))).getOrElse("Mset failed") should equal(3)
(proc !! GET("dg")).getOrElse("Get failed") should equal("1")
(proc !! GET("mc")).getOrElse("Get failed") should equal("2")
(proc !! GET("nd")).getOrElse("Get failed") should equal("3")
(proc !! MAP_SIZE).getOrElse("Size failed") should equal(4)
val add = List(("a", "1"), ("b", "2"), ("c", "3"))
val rem = List("a", "debasish")
(proc !! REMOVE_AFTER_PUT(add, rem)).getOrElse("REMOVE_AFTER_PUT failed") should equal(5)
(proc !! GET("debasish")).getOrElse("debasish not found") should equal("debasish Not found")
(proc !! GET("a")).getOrElse("a not found") should equal("a Not found")
(proc !! GET("b")).getOrElse("b not found") should equal("2")
(proc !! CONTAINS("b")).getOrElse("b not found") should equal(true)
(proc !! CONTAINS("debasish")).getOrElse("debasish not found") should equal(false)
(proc !! MAP_SIZE).getOrElse("Size failed") should equal(5)
proc.stop
}
}
describe("Ticket 343 Issue #2") {
it("clear after put should work within the same transaction") {
val proc = actorOf[MongoSampleMapStorage]
proc.start
(proc !! SET("debasish", "anshinsoft")).getOrElse("Set failed") should equal(("debasish", "anshinsoft"))
(proc !! GET("debasish")).getOrElse("Get failed") should equal("anshinsoft")
(proc !! MAP_SIZE).getOrElse("Size failed") should equal(1)
val add = List(("a", "1"), ("b", "2"), ("c", "3"))
(proc !! CLEAR_AFTER_PUT(add)).getOrElse("CLEAR_AFTER_PUT failed") should equal(true)
(proc !! MAP_SIZE).getOrElse("Size failed") should equal(1)
proc.stop
}
}
describe("Ticket 343 Issue #3") {
it("map size should change after the transaction") {
val proc = actorOf[MongoSampleMapStorage]
proc.start
(proc !! SET("debasish", "anshinsoft")).getOrElse("Set failed") should equal(("debasish", "anshinsoft"))
(proc !! GET("debasish")).getOrElse("Get failed") should equal("anshinsoft")
(proc !! MAP_SIZE).getOrElse("Size failed") should equal(1)
(proc !! MSET(List(("dg", "1"), ("mc", "2"), ("nd", "3")))).getOrElse("Mset failed") should equal(3)
(proc !! MAP_SIZE).getOrElse("Size failed") should equal(4)
(proc !! GET("dg")).getOrElse("Get failed") should equal("1")
(proc !! GET("mc")).getOrElse("Get failed") should equal("2")
(proc !! GET("nd")).getOrElse("Get failed") should equal("3")
proc.stop
}
}
describe("slice test") {
it("should pass") {
val proc = actorOf[MongoSampleMapStorage]
proc.start
(proc !! SET("debasish", "anshinsoft")).getOrElse("Set failed") should equal(("debasish", "anshinsoft"))
(proc !! GET("debasish")).getOrElse("Get failed") should equal("anshinsoft")
// (proc !! MAP_SIZE).getOrElse("Size failed") should equal(1)
(proc !! MSET(List(("dg", "1"), ("mc", "2"), ("nd", "3")))).getOrElse("Mset failed") should equal(3)
(proc !! MAP_SIZE).getOrElse("Size failed") should equal(4)
(proc !! PUT_WITH_SLICE(List(("ec", "1"), ("tb", "2"), ("mc", "10")), "dg", 3)).get.asInstanceOf[List[(Array[Byte], Array[Byte])]].map { case (k, v) => (new String(k), new String(v)) } should equal(List(("dg", "1"), ("ec", "1"), ("mc", "10")))
(proc !! PUT_REM_WITH_SLICE(List(("fc", "1"), ("gb", "2"), ("xy", "10")), List("tb", "fc"), "dg", 5)).get.asInstanceOf[List[(Array[Byte], Array[Byte])]].map { case (k, v) => (new String(k), new String(v)) } should equal(List(("dg", "1"), ("ec", "1"), ("gb", "2"), ("mc", "10"), ("nd", "3")))
proc.stop
}
}
describe("Ticket 343 Issue #4") {
it("vector get should not ignore elements that were in vector before transaction") {
val proc = actorOf[MongoSampleVectorStorage]
proc.start
// add 4 elements in separate transactions
(proc !! VADD("debasish")).getOrElse("VADD failed") should equal(1)
(proc !! VADD("maulindu")).getOrElse("VADD failed") should equal(2)
(proc !! VADD("ramanendu")).getOrElse("VADD failed") should equal(3)
(proc !! VADD("nilanjan")).getOrElse("VADD failed") should equal(4)
new String((proc !! VGET(0)).get.asInstanceOf[Array[Byte]] ) should equal("nilanjan")
new String((proc !! VGET(1)).get.asInstanceOf[Array[Byte]] ) should equal("ramanendu")
new String((proc !! VGET(2)).get.asInstanceOf[Array[Byte]] ) should equal("maulindu")
new String((proc !! VGET(3)).get.asInstanceOf[Array[Byte]] ) should equal("debasish")
// now add 3 more and do gets in the same transaction
(proc !! VGET_AFTER_VADD(List("a", "b", "c"), List(0, 2, 4))).get.asInstanceOf[List[String]] should equal(List("c", "a", "ramanendu"))
proc.stop
}
}
describe("Ticket 343 Issue #6") {
it("vector update should not ignore transaction") {
val proc = actorOf[MongoSampleVectorStorage]
proc.start
// add 4 elements in separate transactions
(proc !! VADD("debasish")).getOrElse("VADD failed") should equal(1)
(proc !! VADD("maulindu")).getOrElse("VADD failed") should equal(2)
(proc !! VADD("ramanendu")).getOrElse("VADD failed") should equal(3)
(proc !! VADD("nilanjan")).getOrElse("VADD failed") should equal(4)
evaluating {
(proc !! VUPD_AND_ABORT(0, "virat")).getOrElse("VUPD_AND_ABORT failed")
} should produce [Exception]
// update aborts and hence values will remain unchanged
new String((proc !! VGET(0)).get.asInstanceOf[Array[Byte]] ) should equal("nilanjan")
proc.stop
}
}
describe("Ticket 343 Issue #5") {
it("vector slice() should not ignore elements added in current transaction") {
val proc = actorOf[MongoSampleVectorStorage]
proc.start
// add 4 elements in separate transactions
(proc !! VADD("debasish")).getOrElse("VADD failed") should equal(1)
(proc !! VADD("maulindu")).getOrElse("VADD failed") should equal(2)
(proc !! VADD("ramanendu")).getOrElse("VADD failed") should equal(3)
(proc !! VADD("nilanjan")).getOrElse("VADD failed") should equal(4)
// slice with no new elements added in current transaction
(proc !! VADD_WITH_SLICE(List(), 2, 2)).getOrElse("VADD_WITH_SLICE failed") should equal(Vector("maulindu", "debasish"))
// slice with new elements added in current transaction
(proc !! VADD_WITH_SLICE(List("a", "b", "c", "d"), 2, 2)).getOrElse("VADD_WITH_SLICE failed") should equal(Vector("b", "a"))
proc.stop
}
}
}

View file

@ -36,7 +36,7 @@ object RedisStorage extends Storage {
*
* @author <a href="http://debasishg.blogspot.com">Debasish Ghosh</a>
*/
class RedisPersistentMap(id: String) extends PersistentMap[Array[Byte], Array[Byte]] {
class RedisPersistentMap(id: String) extends PersistentMapBinary {
val uuid = id
val storage = RedisStorageBackend
}

View file

@ -96,12 +96,12 @@ private [akka] object RedisStorageBackend extends
* <li>both parts of the key need to be based64 encoded since there can be spaces within each of them</li>
*/
private [this] def makeRedisKey(name: String, key: Array[Byte]): String = withErrorHandling {
"%s:%s".format(name, byteArrayToString(key))
"%s:%s".format(name, new String(key))
}
private [this] def makeKeyFromRedisKey(redisKey: String) = withErrorHandling {
val nk = redisKey.split(':')
(nk(0), stringToByteArray(nk(1)))
(nk(0), nk(1).getBytes)
}
private [this] def mset(entries: List[(String, String)]): Unit = withErrorHandling {
@ -124,27 +124,22 @@ private [akka] object RedisStorageBackend extends
}
def getMapStorageEntryFor(name: String, key: Array[Byte]): Option[Array[Byte]] = withErrorHandling {
db.get(makeRedisKey(name, key)) match {
case None =>
throw new NoSuchElementException(new String(key) + " not present")
case Some(s) => Some(stringToByteArray(s))
db.get(makeRedisKey(name, key))
.map(stringToByteArray(_))
.orElse(throw new NoSuchElementException(new String(key) + " not present"))
}
}
def getMapStorageSizeFor(name: String): Int = withErrorHandling {
db.keys("%s:*".format(name)) match {
case None => 0
case Some(keys) => keys.length
}
db.keys("%s:*".format(name)).map(_.length).getOrElse(0)
}
def getMapStorageFor(name: String): List[(Array[Byte], Array[Byte])] = withErrorHandling {
db.keys("%s:*".format(name)) match {
case None =>
throw new NoSuchElementException(name + " not present")
case Some(keys) =>
db.keys("%s:*".format(name))
.map { keys =>
keys.map(key => (makeKeyFromRedisKey(key.get)._2, stringToByteArray(db.get(key.get).get))).toList
}
}.getOrElse {
throw new NoSuchElementException(name + " not present")
}
}
def getMapStorageRangeFor(name: String, start: Option[Array[Byte]],
@ -207,12 +202,11 @@ private [akka] object RedisStorageBackend extends
}
def getVectorStorageEntryFor(name: String, index: Int): Array[Byte] = withErrorHandling {
db.lindex(name, index) match {
case None =>
db.lindex(name, index)
.map(stringToByteArray(_))
.getOrElse {
throw new NoSuchElementException(name + " does not have element at " + index)
case Some(e) =>
stringToByteArray(e)
}
}
}
/**
@ -252,11 +246,11 @@ private [akka] object RedisStorageBackend extends
}
def getRefStorageFor(name: String): Option[Array[Byte]] = withErrorHandling {
db.get(name) match {
case None =>
db.get(name)
.map(stringToByteArray(_))
.orElse {
throw new NoSuchElementException(name + " not present")
case Some(s) => Some(stringToByteArray(s))
}
}
}
// add to the end of the queue
@ -266,11 +260,11 @@ private [akka] object RedisStorageBackend extends
// pop from the front of the queue
def dequeue(name: String): Option[Array[Byte]] = withErrorHandling {
db.lpop(name) match {
case None =>
db.lpop(name)
.map(stringToByteArray(_))
.orElse {
throw new NoSuchElementException(name + " not present")
case Some(s) => Some(stringToByteArray(s))
}
}
}
// get the size of the queue
@ -302,26 +296,19 @@ private [akka] object RedisStorageBackend extends
// completely delete the queue
def remove(name: String): Boolean = withErrorHandling {
db.del(name) match {
case Some(1) => true
case _ => false
}
db.del(name).map { case 1 => true }.getOrElse(false)
}
// add item to sorted set identified by name
def zadd(name: String, zscore: String, item: Array[Byte]): Boolean = withErrorHandling {
db.zadd(name, zscore, byteArrayToString(item)) match {
case Some(1) => true
case _ => false
}
db.zadd(name, zscore, byteArrayToString(item))
.map { case 1 => true }.getOrElse(false)
}
// remove item from sorted set identified by name
def zrem(name: String, item: Array[Byte]): Boolean = withErrorHandling {
db.zrem(name, byteArrayToString(item)) match {
case Some(1) => true
case _ => false
}
db.zrem(name, byteArrayToString(item))
.map { case 1 => true }.getOrElse(false)
}
// cardinality of the set identified by name
@ -330,29 +317,23 @@ private [akka] object RedisStorageBackend extends
}
def zscore(name: String, item: Array[Byte]): Option[Float] = withErrorHandling {
db.zscore(name, byteArrayToString(item)) match {
case Some(s) => Some(s.toFloat)
case None => None
}
db.zscore(name, byteArrayToString(item)).map(_.toFloat)
}
def zrange(name: String, start: Int, end: Int): List[Array[Byte]] = withErrorHandling {
db.zrange(name, start.toString, end.toString, RedisClient.ASC, false) match {
case None =>
db.zrange(name, start.toString, end.toString, RedisClient.ASC, false)
.map(_.map(e => stringToByteArray(e.get)))
.getOrElse {
throw new NoSuchElementException(name + " not present")
case Some(s) =>
s.map(e => stringToByteArray(e.get))
}
}
}
def zrangeWithScore(name: String, start: Int, end: Int): List[(Array[Byte], Float)] = withErrorHandling {
db.zrangeWithScore(
name, start.toString, end.toString, RedisClient.ASC) match {
case None =>
throw new NoSuchElementException(name + " not present")
case Some(l) =>
l.map{ case (elem, score) => (stringToByteArray(elem.get), score.get.toFloat) }
}
db.zrangeWithScore(name, start.toString, end.toString, RedisClient.ASC)
.map(_.map { case (elem, score) => (stringToByteArray(elem.get), score.get.toFloat) })
.getOrElse {
throw new NoSuchElementException(name + " not present")
}
}
def flushDB = withErrorHandling(db.flushdb)

View file

@ -0,0 +1,351 @@
package se.scalablesolutions.akka.persistence.redis
import org.scalatest.Spec
import org.scalatest.matchers.ShouldMatchers
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
import org.scalatest.junit.JUnitRunner
import org.junit.runner.RunWith
import se.scalablesolutions.akka.actor.{Actor}
import se.scalablesolutions.akka.config.OneForOneStrategy
import Actor._
import se.scalablesolutions.akka.persistence.common.PersistentVector
import se.scalablesolutions.akka.stm.global._
import se.scalablesolutions.akka.config.ScalaConfig._
import se.scalablesolutions.akka.util.Logging
import RedisStorageBackend._
case class GET(k: String)
case class SET(k: String, v: String)
case class REM(k: String)
case class CONTAINS(k: String)
case object MAP_SIZE
case class MSET(kvs: List[(String, String)])
case class REMOVE_AFTER_PUT(kvsToAdd: List[(String, String)], ksToRem: List[String])
case class CLEAR_AFTER_PUT(kvsToAdd: List[(String, String)])
case class PUT_WITH_SLICE(kvsToAdd: List[(String, String)], start: String, cnt: Int)
case class PUT_REM_WITH_SLICE(kvsToAdd: List[(String, String)], ksToRem: List[String], start: String, cnt: Int)
case class VADD(v: String)
case class VUPD(i: Int, v: String)
case class VUPD_AND_ABORT(i: Int, v: String)
case class VGET(i: Int)
case object VSIZE
case class VGET_AFTER_VADD(vsToAdd: List[String], isToFetch: List[Int])
case class VADD_WITH_SLICE(vsToAdd: List[String], start: Int, cnt: Int)
object Storage {
class RedisSampleMapStorage extends Actor {
self.lifeCycle = Some(LifeCycle(Permanent))
val FOO_MAP = "akka.sample.map"
private var fooMap = atomic { RedisStorage.getMap(FOO_MAP) }
def receive = {
case SET(k, v) =>
atomic {
fooMap += (k.getBytes, v.getBytes)
}
self.reply((k, v))
case GET(k) =>
val v = atomic {
fooMap.get(k.getBytes)
}
self.reply(v.collect {case byte => new String(byte)}.getOrElse(k + " Not found"))
case REM(k) =>
val v = atomic {
fooMap -= k.getBytes
}
self.reply(k)
case CONTAINS(k) =>
val v = atomic {
fooMap contains k.getBytes
}
self.reply(v)
case MAP_SIZE =>
val v = atomic {
fooMap.size
}
self.reply(v)
case MSET(kvs) =>
atomic {
kvs.foreach {kv =>
fooMap += (kv._1.getBytes, kv._2.getBytes)
}
}
self.reply(kvs.size)
case REMOVE_AFTER_PUT(kvs2add, ks2rem) =>
val v =
atomic {
kvs2add.foreach {kv =>
fooMap += (kv._1.getBytes, kv._2.getBytes)
}
ks2rem.foreach {k =>
fooMap -= k.getBytes
}
fooMap.size
}
self.reply(v)
case CLEAR_AFTER_PUT(kvs2add) =>
atomic {
kvs2add.foreach {kv =>
fooMap += (kv._1.getBytes, kv._2.getBytes)
}
fooMap.clear
}
self.reply(true)
case PUT_WITH_SLICE(kvs2add, from, cnt) =>
val v =
atomic {
kvs2add.foreach {kv =>
fooMap += (kv._1.getBytes, kv._2.getBytes)
}
fooMap.slice(Some(from.getBytes), cnt)
}
self.reply(v: List[(Array[Byte], Array[Byte])])
case PUT_REM_WITH_SLICE(kvs2add, ks2rem, from, cnt) =>
val v =
atomic {
kvs2add.foreach {kv =>
fooMap += (kv._1.getBytes, kv._2.getBytes)
}
ks2rem.foreach {k =>
fooMap -= k.getBytes
}
fooMap.slice(Some(from.getBytes), cnt)
}
self.reply(v: List[(Array[Byte], Array[Byte])])
}
}
class RedisSampleVectorStorage extends Actor {
self.lifeCycle = Some(LifeCycle(Permanent))
val FOO_VECTOR = "akka.sample.vector"
private var fooVector = atomic { RedisStorage.getVector(FOO_VECTOR) }
def receive = {
case VADD(v) =>
val size =
atomic {
fooVector + v.getBytes
fooVector length
}
self.reply(size)
case VGET(index) =>
val ind =
atomic {
fooVector get index
}
self.reply(ind)
case VGET_AFTER_VADD(vs, is) =>
val els =
atomic {
vs.foreach(fooVector + _.getBytes)
(is.foldRight(List[Array[Byte]]())(fooVector.get(_) :: _)).map(new String(_))
}
self.reply(els)
case VUPD_AND_ABORT(index, value) =>
val l =
atomic {
fooVector.update(index, value.getBytes)
// force fail
fooVector get 100
}
self.reply(index)
case VADD_WITH_SLICE(vs, s, c) =>
val l =
atomic {
vs.foreach(fooVector + _.getBytes)
fooVector.slice(Some(s), None, c)
}
self.reply(l.map(new String(_)))
}
}
}
import Storage._
@RunWith(classOf[JUnitRunner])
class RedisTicket343Spec extends
Spec with
ShouldMatchers with
BeforeAndAfterAll with
BeforeAndAfterEach {
override def beforeAll {
flushDB
println("** destroyed database")
}
override def afterEach {
flushDB
println("** destroyed database")
}
describe("Ticket 343 Issue #1") {
it("remove after put should work within the same transaction") {
val proc = actorOf[RedisSampleMapStorage]
proc.start
(proc !! SET("debasish", "anshinsoft")).getOrElse("Set failed") should equal(("debasish", "anshinsoft"))
(proc !! GET("debasish")).getOrElse("Get failed") should equal("anshinsoft")
(proc !! MAP_SIZE).getOrElse("Size failed") should equal(1)
(proc !! MSET(List(("dg", "1"), ("mc", "2"), ("nd", "3")))).getOrElse("Mset failed") should equal(3)
(proc !! GET("dg")).getOrElse("Get failed") should equal("1")
(proc !! GET("mc")).getOrElse("Get failed") should equal("2")
(proc !! GET("nd")).getOrElse("Get failed") should equal("3")
(proc !! MAP_SIZE).getOrElse("Size failed") should equal(4)
val add = List(("a", "1"), ("b", "2"), ("c", "3"))
val rem = List("a", "debasish")
(proc !! REMOVE_AFTER_PUT(add, rem)).getOrElse("REMOVE_AFTER_PUT failed") should equal(5)
(proc !! GET("debasish")).getOrElse("debasish not found") should equal("debasish Not found")
(proc !! GET("a")).getOrElse("a not found") should equal("a Not found")
(proc !! GET("b")).getOrElse("b not found") should equal("2")
(proc !! CONTAINS("b")).getOrElse("b not found") should equal(true)
(proc !! CONTAINS("debasish")).getOrElse("debasish not found") should equal(false)
(proc !! MAP_SIZE).getOrElse("Size failed") should equal(5)
proc.stop
}
}
describe("Ticket 343 Issue #2") {
it("clear after put should work within the same transaction") {
val proc = actorOf[RedisSampleMapStorage]
proc.start
(proc !! SET("debasish", "anshinsoft")).getOrElse("Set failed") should equal(("debasish", "anshinsoft"))
(proc !! GET("debasish")).getOrElse("Get failed") should equal("anshinsoft")
(proc !! MAP_SIZE).getOrElse("Size failed") should equal(1)
val add = List(("a", "1"), ("b", "2"), ("c", "3"))
(proc !! CLEAR_AFTER_PUT(add)).getOrElse("CLEAR_AFTER_PUT failed") should equal(true)
(proc !! MAP_SIZE).getOrElse("Size failed") should equal(1)
proc.stop
}
}
describe("Ticket 343 Issue #3") {
it("map size should change after the transaction") {
val proc = actorOf[RedisSampleMapStorage]
proc.start
(proc !! SET("debasish", "anshinsoft")).getOrElse("Set failed") should equal(("debasish", "anshinsoft"))
(proc !! GET("debasish")).getOrElse("Get failed") should equal("anshinsoft")
(proc !! MAP_SIZE).getOrElse("Size failed") should equal(1)
(proc !! MSET(List(("dg", "1"), ("mc", "2"), ("nd", "3")))).getOrElse("Mset failed") should equal(3)
(proc !! MAP_SIZE).getOrElse("Size failed") should equal(4)
(proc !! GET("dg")).getOrElse("Get failed") should equal("1")
(proc !! GET("mc")).getOrElse("Get failed") should equal("2")
(proc !! GET("nd")).getOrElse("Get failed") should equal("3")
proc.stop
}
}
describe("slice test") {
it("should pass") {
val proc = actorOf[RedisSampleMapStorage]
proc.start
(proc !! SET("debasish", "anshinsoft")).getOrElse("Set failed") should equal(("debasish", "anshinsoft"))
(proc !! GET("debasish")).getOrElse("Get failed") should equal("anshinsoft")
(proc !! MAP_SIZE).getOrElse("Size failed") should equal(1)
(proc !! MSET(List(("dg", "1"), ("mc", "2"), ("nd", "3")))).getOrElse("Mset failed") should equal(3)
(proc !! MAP_SIZE).getOrElse("Size failed") should equal(4)
(proc !! PUT_WITH_SLICE(List(("ec", "1"), ("tb", "2"), ("mc", "10")), "dg", 3)).get.asInstanceOf[List[(Array[Byte], Array[Byte])]].map { case (k, v) => (new String(k), new String(v)) } should equal(List(("dg", "1"), ("ec", "1"), ("mc", "10")))
(proc !! PUT_REM_WITH_SLICE(List(("fc", "1"), ("gb", "2"), ("xy", "10")), List("tb", "fc"), "dg", 5)).get.asInstanceOf[List[(Array[Byte], Array[Byte])]].map { case (k, v) => (new String(k), new String(v)) } should equal(List(("dg", "1"), ("ec", "1"), ("gb", "2"), ("mc", "10"), ("nd", "3")))
proc.stop
}
}
describe("Ticket 343 Issue #4") {
it("vector get should not ignore elements that were in vector before transaction") {
val proc = actorOf[RedisSampleVectorStorage]
proc.start
// add 4 elements in separate transactions
(proc !! VADD("debasish")).getOrElse("VADD failed") should equal(1)
(proc !! VADD("maulindu")).getOrElse("VADD failed") should equal(2)
(proc !! VADD("ramanendu")).getOrElse("VADD failed") should equal(3)
(proc !! VADD("nilanjan")).getOrElse("VADD failed") should equal(4)
new String((proc !! VGET(0)).get.asInstanceOf[Array[Byte]] ) should equal("nilanjan")
new String((proc !! VGET(1)).get.asInstanceOf[Array[Byte]] ) should equal("ramanendu")
new String((proc !! VGET(2)).get.asInstanceOf[Array[Byte]] ) should equal("maulindu")
new String((proc !! VGET(3)).get.asInstanceOf[Array[Byte]] ) should equal("debasish")
// now add 3 more and do gets in the same transaction
(proc !! VGET_AFTER_VADD(List("a", "b", "c"), List(0, 2, 4))).get.asInstanceOf[List[String]] should equal(List("c", "a", "ramanendu"))
proc.stop
}
}
describe("Ticket 343 Issue #6") {
it("vector update should not ignore transaction") {
val proc = actorOf[RedisSampleVectorStorage]
proc.start
// add 4 elements in separate transactions
(proc !! VADD("debasish")).getOrElse("VADD failed") should equal(1)
(proc !! VADD("maulindu")).getOrElse("VADD failed") should equal(2)
(proc !! VADD("ramanendu")).getOrElse("VADD failed") should equal(3)
(proc !! VADD("nilanjan")).getOrElse("VADD failed") should equal(4)
evaluating {
(proc !! VUPD_AND_ABORT(0, "virat")).getOrElse("VUPD_AND_ABORT failed")
} should produce [Exception]
// update aborts and hence values will remain unchanged
new String((proc !! VGET(0)).get.asInstanceOf[Array[Byte]] ) should equal("nilanjan")
proc.stop
}
}
describe("Ticket 343 Issue #5") {
it("vector slice() should not ignore elements added in current transaction") {
val proc = actorOf[RedisSampleVectorStorage]
proc.start
// add 4 elements in separate transactions
(proc !! VADD("debasish")).getOrElse("VADD failed") should equal(1)
(proc !! VADD("maulindu")).getOrElse("VADD failed") should equal(2)
(proc !! VADD("ramanendu")).getOrElse("VADD failed") should equal(3)
(proc !! VADD("nilanjan")).getOrElse("VADD failed") should equal(4)
// slice with no new elements added in current transaction
(proc !! VADD_WITH_SLICE(List(), 2, 2)).getOrElse("VADD_WITH_SLICE failed") should equal(Vector("maulindu", "debasish"))
// slice with new elements added in current transaction
(proc !! VADD_WITH_SLICE(List("a", "b", "c", "d"), 2, 2)).getOrElse("VADD_WITH_SLICE failed") should equal(Vector("b", "a"))
proc.stop
}
}
}

View file

@ -30,6 +30,7 @@ import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable.{HashSet, HashMap}
import scala.reflect.BeanProperty
import se.scalablesolutions.akka.actor._
/**
* Atomic remote request/reply message id generator.
@ -76,8 +77,6 @@ object RemoteClient extends Logging {
private val remoteClients = new HashMap[String, RemoteClient]
private val remoteActors = new HashMap[RemoteServer.Address, HashSet[String]]
// FIXME: simplify overloaded methods when we have Scala 2.8
def actorFor(classNameOrServiceId: String, hostname: String, port: Int): ActorRef =
actorFor(classNameOrServiceId, classNameOrServiceId, 5000L, hostname, port, None)
@ -99,6 +98,27 @@ object RemoteClient extends Logging {
def actorFor(serviceId: String, className: String, timeout: Long, hostname: String, port: Int): ActorRef =
RemoteActorRef(serviceId, className, hostname, port, timeout, None)
def typedActorFor[T](intfClass: Class[T], serviceIdOrClassName: String, hostname: String, port: Int) : T = {
typedActorFor(intfClass, serviceIdOrClassName, serviceIdOrClassName, 5000L, hostname, port, None)
}
def typedActorFor[T](intfClass: Class[T], serviceIdOrClassName: String, timeout: Long, hostname: String, port: Int) : T = {
typedActorFor(intfClass, serviceIdOrClassName, serviceIdOrClassName, timeout, hostname, port, None)
}
def typedActorFor[T](intfClass: Class[T], serviceIdOrClassName: String, timeout: Long, hostname: String, port: Int, loader: ClassLoader) : T = {
typedActorFor(intfClass, serviceIdOrClassName, serviceIdOrClassName, timeout, hostname, port, Some(loader))
}
def typedActorFor[T](intfClass: Class[T], serviceId: String, implClassName: String, timeout: Long, hostname: String, port: Int, loader: ClassLoader) : T = {
typedActorFor(intfClass, serviceId, implClassName, timeout, hostname, port, Some(loader))
}
private[akka] def typedActorFor[T](intfClass: Class[T], serviceId: String, implClassName: String, timeout: Long, hostname: String, port: Int, loader: Option[ClassLoader]) : T = {
val actorRef = RemoteActorRef(serviceId, implClassName, hostname, port, timeout, loader, ActorType.TypedActor)
TypedActor.createProxyForRemoteActorRef(intfClass, actorRef)
}
private[akka] def actorFor(serviceId: String, className: String, timeout: Long, hostname: String, port: Int, loader: ClassLoader): ActorRef =
RemoteActorRef(serviceId, className, hostname, port, timeout, Some(loader))
@ -220,10 +240,10 @@ class RemoteClient private[akka] (
val channel = connection.awaitUninterruptibly.getChannel
openChannels.add(channel)
if (!connection.isSuccess) {
foreachListener(_ ! RemoteClientError(connection.getCause, this))
notifyListeners(RemoteClientError(connection.getCause, this))
log.error(connection.getCause, "Remote client connection to [%s:%s] has failed", hostname, port)
}
foreachListener(_ ! RemoteClientStarted(this))
notifyListeners(RemoteClientStarted(this))
isRunning = true
}
}
@ -232,7 +252,7 @@ class RemoteClient private[akka] (
log.info("Shutting down %s", name)
if (isRunning) {
isRunning = false
foreachListener(_ ! RemoteClientShutdown(this))
notifyListeners(RemoteClientShutdown(this))
timer.stop
timer = null
openChannels.close.awaitUninterruptibly
@ -250,7 +270,7 @@ class RemoteClient private[akka] (
@deprecated("Use removeListener instead")
def deregisterListener(actorRef: ActorRef) = removeListener(actorRef)
override def foreachListener(f: (ActorRef) => Unit): Unit = super.foreachListener(f)
override def notifyListeners(message: => Any): Unit = super.notifyListeners(message)
protected override def manageLifeCycleOfListeners = false
@ -287,7 +307,7 @@ class RemoteClient private[akka] (
} else {
val exception = new RemoteClientException(
"Remote client is not running, make sure you have invoked 'RemoteClient.connect' before using it.", this)
foreachListener(l => l ! RemoteClientError(exception, this))
notifyListeners(RemoteClientError(exception, this))
throw exception
}
@ -403,12 +423,12 @@ class RemoteClientHandler(
futures.remove(reply.getId)
} else {
val exception = new RemoteClientException("Unknown message received in remote client handler: " + result, client)
client.foreachListener(_ ! RemoteClientError(exception, client))
client.notifyListeners(RemoteClientError(exception, client))
throw exception
}
} catch {
case e: Exception =>
client.foreachListener(_ ! RemoteClientError(e, client))
client.notifyListeners(RemoteClientError(e, client))
log.error("Unexpected exception in remote client handler: %s", e)
throw e
}
@ -423,7 +443,7 @@ class RemoteClientHandler(
client.connection = bootstrap.connect(remoteAddress)
client.connection.awaitUninterruptibly // Wait until the connection attempt succeeds or fails.
if (!client.connection.isSuccess) {
client.foreachListener(_ ! RemoteClientError(client.connection.getCause, client))
client.notifyListeners(RemoteClientError(client.connection.getCause, client))
log.error(client.connection.getCause, "Reconnection to [%s] has failed", remoteAddress)
}
}
@ -433,7 +453,7 @@ class RemoteClientHandler(
override def channelConnected(ctx: ChannelHandlerContext, event: ChannelStateEvent) = {
def connect = {
client.foreachListener(_ ! RemoteClientConnected(client))
client.notifyListeners(RemoteClientConnected(client))
log.debug("Remote client connected to [%s]", ctx.getChannel.getRemoteAddress)
client.resetReconnectionTimeWindow
}
@ -450,12 +470,12 @@ class RemoteClientHandler(
}
override def channelDisconnected(ctx: ChannelHandlerContext, event: ChannelStateEvent) = {
client.foreachListener(_ ! RemoteClientDisconnected(client))
client.notifyListeners(RemoteClientDisconnected(client))
log.debug("Remote client disconnected from [%s]", ctx.getChannel.getRemoteAddress)
}
override def exceptionCaught(ctx: ChannelHandlerContext, event: ExceptionEvent) = {
client.foreachListener(_ ! RemoteClientError(event.getCause, client))
client.notifyListeners(RemoteClientError(event.getCause, client))
log.error(event.getCause, "Unexpected exception from downstream in remote client")
event.getChannel.close
}

View file

@ -10,7 +10,7 @@ import java.util.concurrent.{ConcurrentHashMap, Executors}
import java.util.{Map => JMap}
import se.scalablesolutions.akka.actor.{
Actor, TypedActor, ActorRef, LocalActorRef, RemoteActorRef, IllegalActorStateException, RemoteActorSystemMessage}
Actor, TypedActor, ActorRef, IllegalActorStateException, RemoteActorSystemMessage}
import se.scalablesolutions.akka.actor.Actor._
import se.scalablesolutions.akka.util._
import se.scalablesolutions.akka.remote.protocol.RemoteProtocol._
@ -133,8 +133,8 @@ object RemoteServer {
actorsFor(RemoteServer.Address(address.getHostName, address.getPort)).actors.put(uuid, actor)
}
private[akka] def registerTypedActor(address: InetSocketAddress, name: String, typedActor: AnyRef) = guard.withWriteGuard {
actorsFor(RemoteServer.Address(address.getHostName, address.getPort)).typedActors.put(name, typedActor)
private[akka] def registerTypedActor(address: InetSocketAddress, uuid: String, typedActor: AnyRef) = guard.withWriteGuard {
actorsFor(RemoteServer.Address(address.getHostName, address.getPort)).typedActors.put(uuid, typedActor)
}
private[akka] def getOrCreateServer(address: InetSocketAddress): RemoteServer = guard.withWriteGuard {
@ -245,12 +245,12 @@ class RemoteServer extends Logging with ListenerManagement {
openChannels.add(bootstrap.bind(new InetSocketAddress(hostname, port)))
_isRunning = true
Cluster.registerLocalNode(hostname, port)
foreachListener(_ ! RemoteServerStarted(this))
notifyListeners(RemoteServerStarted(this))
}
} catch {
case e =>
log.error(e, "Could not start up remote server")
foreachListener(_ ! RemoteServerError(e, this))
notifyListeners(RemoteServerError(e, this))
}
this
}
@ -263,7 +263,7 @@ class RemoteServer extends Logging with ListenerManagement {
openChannels.close.awaitUninterruptibly
bootstrap.releaseExternalResources
Cluster.deregisterLocalNode(hostname, port)
foreachListener(_ ! RemoteServerShutdown(this))
notifyListeners(RemoteServerShutdown(this))
} catch {
case e: java.nio.channels.ClosedChannelException => {}
case e => log.warning("Could not close remote server channel in a graceful way")
@ -271,12 +271,28 @@ class RemoteServer extends Logging with ListenerManagement {
}
}
// TODO: register typed actor in RemoteServer as well
/**
* Register typed actor by interface name.
*/
def registerTypedActor(intfClass: Class[_], typedActor: AnyRef) : Unit = registerTypedActor(intfClass.getName, typedActor)
/**
* Register Remote Actor by the Actor's 'uuid' field. It starts the Actor if it is not started already.
* Register remote typed actor by a specific id.
* @param id custom actor id
* @param typedActor typed actor to register
*/
def register(actorRef: ActorRef): Unit = register(actorRef.id,actorRef)
def registerTypedActor(id: String, typedActor: AnyRef): Unit = synchronized {
val typedActors = RemoteServer.actorsFor(RemoteServer.Address(hostname, port)).typedActors
if (!typedActors.contains(id)) {
log.debug("Registering server side remote actor [%s] with id [%s] on [%s:%d]", typedActor.getClass.getName, id, hostname, port)
typedActors.put(id, typedActor)
}
}
/**
* Register Remote Actor by the Actor's 'id' field. It starts the Actor if it is not started already.
*/
def register(actorRef: ActorRef): Unit = register(actorRef.id, actorRef)
/**
* Register Remote Actor by a specific 'id' passed as argument.
@ -321,16 +337,25 @@ class RemoteServer extends Logging with ListenerManagement {
}
}
/**
* Unregister Remote Typed Actor by specific 'id'.
* <p/>
* NOTE: You need to call this method if you have registered an actor by a custom ID.
*/
def unregisterTypedActor(id: String):Unit = synchronized {
if (_isRunning) {
log.info("Unregistering server side remote typed actor with id [%s]", id)
val registeredTypedActors = typedActors()
registeredTypedActors.remove(id)
}
}
protected override def manageLifeCycleOfListeners = false
protected[akka] override def foreachListener(f: (ActorRef) => Unit): Unit = super.foreachListener(f)
protected[akka] override def notifyListeners(message: => Any): Unit = super.notifyListeners(message)
private[akka] def actors() : ConcurrentHashMap[String, ActorRef] = {
RemoteServer.actorsFor(address).actors
}
private[akka] def typedActors() : ConcurrentHashMap[String, AnyRef] = {
RemoteServer.actorsFor(address).typedActors
}
private[akka] def actors() = RemoteServer.actorsFor(address).actors
private[akka] def typedActors() = RemoteServer.actorsFor(address).typedActors
}
object RemoteServerSslContext {
@ -413,18 +438,18 @@ class RemoteServerHandler(
def operationComplete(future: ChannelFuture): Unit = {
if (future.isSuccess) {
openChannels.add(future.getChannel)
server.foreachListener(_ ! RemoteServerClientConnected(server))
server.notifyListeners(RemoteServerClientConnected(server))
} else future.getChannel.close
}
})
} else {
server.foreachListener(_ ! RemoteServerClientConnected(server))
server.notifyListeners(RemoteServerClientConnected(server))
}
}
override def channelClosed(ctx: ChannelHandlerContext, event: ChannelStateEvent) = {
log.debug("Remote client disconnected from [%s]", server.name)
server.foreachListener(_ ! RemoteServerClientDisconnected(server))
server.notifyListeners(RemoteServerClientDisconnected(server))
}
override def handleUpstream(ctx: ChannelHandlerContext, event: ChannelEvent) = {
@ -446,7 +471,7 @@ class RemoteServerHandler(
override def exceptionCaught(ctx: ChannelHandlerContext, event: ExceptionEvent) = {
log.error(event.getCause, "Unexpected exception from remote downstream")
event.getChannel.close
server.foreachListener(_ ! RemoteServerError(event.getCause, server))
server.notifyListeners(RemoteServerError(event.getCause, server))
}
private def handleRemoteRequestProtocol(request: RemoteRequestProtocol, channel: Channel) = {
@ -491,7 +516,7 @@ class RemoteServerHandler(
} catch {
case e: Throwable =>
channel.write(createErrorReplyMessage(e, request, true))
server.foreachListener(_ ! RemoteServerError(e, server))
server.notifyListeners(RemoteServerError(e, server))
}
}
}
@ -523,13 +548,39 @@ class RemoteServerHandler(
} catch {
case e: InvocationTargetException =>
channel.write(createErrorReplyMessage(e.getCause, request, false))
server.foreachListener(_ ! RemoteServerError(e, server))
server.notifyListeners(RemoteServerError(e, server))
case e: Throwable =>
channel.write(createErrorReplyMessage(e, request, false))
server.foreachListener(_ ! RemoteServerError(e, server))
server.notifyListeners(RemoteServerError(e, server))
}
}
/**
* Find a registered actor by ID (default) or UUID.
* Actors are registered by id apart from registering during serialization see SerializationProtocol.
*/
private def findActorByIdOrUuid(id: String, uuid: String) : ActorRef = {
val registeredActors = server.actors()
var actorRefOrNull = registeredActors get id
if (actorRefOrNull eq null) {
actorRefOrNull = registeredActors get uuid
}
actorRefOrNull
}
/**
* Find a registered typed actor by ID (default) or UUID.
* Actors are registered by id apart from registering during serialization see SerializationProtocol.
*/
private def findTypedActorByIdOrUUid(id: String, uuid: String) : AnyRef = {
val registeredActors = server.typedActors()
var actorRefOrNull = registeredActors get id
if (actorRefOrNull eq null) {
actorRefOrNull = registeredActors get uuid
}
actorRefOrNull
}
/**
* Creates a new instance of the actor with name, uuid and timeout specified as arguments.
*
@ -538,12 +589,14 @@ class RemoteServerHandler(
* Does not start the actor.
*/
private def createActor(actorInfo: ActorInfoProtocol): ActorRef = {
val uuid = actorInfo.getUuid
val ids = actorInfo.getUuid.split(':')
val uuid = ids(0)
val id = ids(1)
val name = actorInfo.getTarget
val timeout = actorInfo.getTimeout
val registeredActors = server.actors()
val actorRefOrNull = registeredActors get uuid
val actorRefOrNull = findActorByIdOrUuid(id, uuid)
if (actorRefOrNull eq null) {
try {
@ -552,23 +605,26 @@ class RemoteServerHandler(
else Class.forName(name)
val actorRef = Actor.actorOf(clazz.newInstance.asInstanceOf[Actor])
actorRef.uuid = uuid
actorRef.id = id
actorRef.timeout = timeout
actorRef.remoteAddress = None
registeredActors.put(uuid, actorRef)
server.actors.put(id, actorRef) // register by id
actorRef
} catch {
case e =>
log.error(e, "Could not create remote actor instance")
server.foreachListener(_ ! RemoteServerError(e, server))
server.notifyListeners(RemoteServerError(e, server))
throw e
}
} else actorRefOrNull
}
private def createTypedActor(actorInfo: ActorInfoProtocol): AnyRef = {
val uuid = actorInfo.getUuid
val registeredTypedActors = server.typedActors()
val typedActorOrNull = registeredTypedActors get uuid
val ids = actorInfo.getUuid.split(':')
val uuid = ids(0)
val id = ids(1)
val typedActorOrNull = findTypedActorByIdOrUUid(id, uuid)
if (typedActorOrNull eq null) {
val typedActorInfo = actorInfo.getTypedActorInfo
@ -585,12 +641,12 @@ class RemoteServerHandler(
val newInstance = TypedActor.newInstance(
interfaceClass, targetClass.asInstanceOf[Class[_ <: TypedActor]], actorInfo.getTimeout).asInstanceOf[AnyRef]
registeredTypedActors.put(uuid, newInstance)
server.typedActors.put(id, newInstance) // register by id
newInstance
} catch {
case e =>
log.error(e, "Could not create remote typed actor instance")
server.foreachListener(_ ! RemoteServerError(e, server))
server.notifyListeners(RemoteServerError(e, server))
throw e
}
} else typedActorOrNull

View file

@ -230,7 +230,7 @@ object RemoteActorSerialization {
}
RemoteActorRefProtocol.newBuilder
.setUuid(uuid)
.setUuid(uuid + ":" + id)
.setActorClassname(actorClass.getName)
.setHomeAddress(AddressProtocol.newBuilder.setHostname(host).setPort(port).build)
.setTimeout(timeout)
@ -248,7 +248,7 @@ object RemoteActorSerialization {
import actorRef._
val actorInfoBuilder = ActorInfoProtocol.newBuilder
.setUuid(uuid)
.setUuid(uuid + ":" + actorRef.id)
.setTarget(actorClassName)
.setTimeout(timeout)

View file

@ -2,6 +2,7 @@
<system id="akka">
<package name="se.scalablesolutions.akka.actor">
<aspect class="TypedActorAspect" />
<aspect class="ServerManagedTypedActorAspect" />
</package>
</system>
</aspectwerkz>

View file

@ -93,6 +93,7 @@ class ClientInitiatedRemoteActorSpec extends JUnitSuite {
actor.stop
}
@Test
def shouldSendOneWayAndReceiveReply = {
val actor = actorOf[SendOneWayAndReplyReceiverActor]
@ -103,7 +104,7 @@ class ClientInitiatedRemoteActorSpec extends JUnitSuite {
sender.actor.asInstanceOf[SendOneWayAndReplySenderActor].sendTo = actor
sender.start
sender.actor.asInstanceOf[SendOneWayAndReplySenderActor].sendOff
assert(SendOneWayAndReplySenderActor.latch.await(1, TimeUnit.SECONDS))
assert(SendOneWayAndReplySenderActor.latch.await(3, TimeUnit.SECONDS))
assert(sender.actor.asInstanceOf[SendOneWayAndReplySenderActor].state.isDefined === true)
assert("World" === sender.actor.asInstanceOf[SendOneWayAndReplySenderActor].state.get.asInstanceOf[String])
actor.stop
@ -134,6 +135,6 @@ class ClientInitiatedRemoteActorSpec extends JUnitSuite {
assert("Expected exception; to test fault-tolerance" === e.getMessage())
}
actor.stop
}
}
}

View file

@ -4,10 +4,7 @@
package se.scalablesolutions.akka.actor.remote
import org.scalatest.Spec
import org.scalatest.Assertions
import org.scalatest.matchers.ShouldMatchers
import org.scalatest.BeforeAndAfterAll
import org.scalatest.junit.JUnitRunner
import org.junit.runner.RunWith
@ -19,6 +16,7 @@ import se.scalablesolutions.akka.actor._
import se.scalablesolutions.akka.remote.{RemoteServer, RemoteClient}
import java.util.concurrent.{LinkedBlockingQueue, TimeUnit, BlockingQueue}
import org.scalatest.{BeforeAndAfterEach, Spec, Assertions, BeforeAndAfterAll}
object RemoteTypedActorSpec {
val HOSTNAME = "localhost"
@ -40,7 +38,7 @@ object RemoteTypedActorLog {
class RemoteTypedActorSpec extends
Spec with
ShouldMatchers with
BeforeAndAfterAll {
BeforeAndAfterEach with BeforeAndAfterAll {
import RemoteTypedActorLog._
import RemoteTypedActorSpec._
@ -82,6 +80,10 @@ class RemoteTypedActorSpec extends
ActorRegistry.shutdownAll
}
override def afterEach() {
server.typedActors.clear
}
describe("Remote Typed Actor ") {
it("should receive one-way message") {

View file

@ -79,6 +79,7 @@ class ServerInitiatedRemoteActorSpec extends JUnitSuite {
}
}
@Test
def shouldSendWithBang {
val actor = RemoteClient.actorFor(
@ -153,10 +154,29 @@ class ServerInitiatedRemoteActorSpec extends JUnitSuite {
server.register(actorOf[RemoteActorSpecActorUnidirectional])
val actor = RemoteClient.actorFor("se.scalablesolutions.akka.actor.remote.ServerInitiatedRemoteActorSpec$RemoteActorSpecActorUnidirectional", HOSTNAME, PORT)
val numberOfActorsInRegistry = ActorRegistry.actors.length
val result = actor ! "OneWay"
actor ! "OneWay"
assert(RemoteActorSpecActorUnidirectional.latch.await(1, TimeUnit.SECONDS))
assert(numberOfActorsInRegistry === ActorRegistry.actors.length)
actor.stop
}
@Test
def shouldUseServiceNameAsIdForRemoteActorRef {
server.register(actorOf[RemoteActorSpecActorUnidirectional])
server.register("my-service", actorOf[RemoteActorSpecActorUnidirectional])
val actor1 = RemoteClient.actorFor("se.scalablesolutions.akka.actor.remote.ServerInitiatedRemoteActorSpec$RemoteActorSpecActorUnidirectional", HOSTNAME, PORT)
val actor2 = RemoteClient.actorFor("my-service", HOSTNAME, PORT)
val actor3 = RemoteClient.actorFor("my-service", HOSTNAME, PORT)
actor1 ! "OneWay"
actor2 ! "OneWay"
actor3 ! "OneWay"
assert(actor1.uuid != actor2.uuid)
assert(actor1.uuid != actor3.uuid)
assert(actor1.id != actor2.id)
assert(actor2.id == actor3.id)
}
}

View file

@ -0,0 +1,112 @@
/**
* Copyright (C) 2009-2010 Scalable Solutions AB <http://scalablesolutions.se>
*/
package se.scalablesolutions.akka.actor.remote
import org.scalatest.Spec
import org.scalatest.matchers.ShouldMatchers
import org.scalatest.BeforeAndAfterAll
import org.scalatest.junit.JUnitRunner
import org.junit.runner.RunWith
import java.util.concurrent.TimeUnit
import se.scalablesolutions.akka.remote.{RemoteServer, RemoteClient}
import se.scalablesolutions.akka.actor._
import RemoteTypedActorLog._
object ServerInitiatedRemoteTypedActorSpec {
val HOSTNAME = "localhost"
val PORT = 9990
var server: RemoteServer = null
}
@RunWith(classOf[JUnitRunner])
class ServerInitiatedRemoteTypedActorSpec extends
Spec with
ShouldMatchers with
BeforeAndAfterAll {
import ServerInitiatedRemoteTypedActorSpec._
private val unit = TimeUnit.MILLISECONDS
override def beforeAll = {
server = new RemoteServer()
server.start(HOSTNAME, PORT)
val typedActor = TypedActor.newInstance(classOf[RemoteTypedActorOne], classOf[RemoteTypedActorOneImpl], 1000)
server.registerTypedActor("typed-actor-service", typedActor)
Thread.sleep(1000)
}
// make sure the servers shutdown cleanly after the test has finished
override def afterAll = {
try {
server.shutdown
RemoteClient.shutdownAll
Thread.sleep(1000)
} catch {
case e => ()
}
}
describe("Server managed remote typed Actor ") {
it("should receive one-way message") {
clearMessageLogs
val actor = RemoteClient.typedActorFor(classOf[RemoteTypedActorOne], "typed-actor-service", 5000L, HOSTNAME, PORT)
expect("oneway") {
actor.oneWay
oneWayLog.poll(5, TimeUnit.SECONDS)
}
}
it("should respond to request-reply message") {
clearMessageLogs
val actor = RemoteClient.typedActorFor(classOf[RemoteTypedActorOne], "typed-actor-service", 5000L, HOSTNAME, PORT)
expect("pong") {
actor.requestReply("ping")
}
}
it("should not recreate registered actors") {
val actor = RemoteClient.typedActorFor(classOf[RemoteTypedActorOne], "typed-actor-service", 5000L, HOSTNAME, PORT)
val numberOfActorsInRegistry = ActorRegistry.actors.length
expect("oneway") {
actor.oneWay
oneWayLog.poll(5, TimeUnit.SECONDS)
}
assert(numberOfActorsInRegistry === ActorRegistry.actors.length)
}
it("should support multiple variants to get the actor from client side") {
var actor = RemoteClient.typedActorFor(classOf[RemoteTypedActorOne], "typed-actor-service", 5000L, HOSTNAME, PORT)
expect("oneway") {
actor.oneWay
oneWayLog.poll(5, TimeUnit.SECONDS)
}
actor = RemoteClient.typedActorFor(classOf[RemoteTypedActorOne], "typed-actor-service", HOSTNAME, PORT)
expect("oneway") {
actor.oneWay
oneWayLog.poll(5, TimeUnit.SECONDS)
}
actor = RemoteClient.typedActorFor(classOf[RemoteTypedActorOne], "typed-actor-service", 5000L, HOSTNAME, PORT, this.getClass().getClassLoader)
expect("oneway") {
actor.oneWay
oneWayLog.poll(5, TimeUnit.SECONDS)
}
}
it("should register and unregister typed actors") {
val typedActor = TypedActor.newInstance(classOf[RemoteTypedActorOne], classOf[RemoteTypedActorOneImpl], 1000)
server.registerTypedActor("my-test-service", typedActor)
assert(server.typedActors().get("my-test-service") != null)
server.unregisterTypedActor("my-test-service")
assert(server.typedActors().get("my-test-service") == null)
}
}
}

View file

@ -64,6 +64,14 @@
</xsd:restriction>
</xsd:simpleType>
<!-- management type for remote actors: client managed / server managed -->
<xsd:simpleType name="managed-by-type">
<xsd:restriction base="xsd:token">
<xsd:enumeration value="client"/>
<xsd:enumeration value="server"/>
</xsd:restriction>
</xsd:simpleType>
<!-- dispatcher type -->
<xsd:complexType name="dispatcher-type">
@ -105,6 +113,20 @@
</xsd:documentation>
</xsd:annotation>
</xsd:attribute>
<xsd:attribute name="managed-by" type="managed-by-type">
<xsd:annotation>
<xsd:documentation>
Management type for remote actors: client managed or server managed.
</xsd:documentation>
</xsd:annotation>
</xsd:attribute>
<xsd:attribute name="service-name" type="xsd:string">
<xsd:annotation>
<xsd:documentation>
Custom service name for server managed actor.
</xsd:documentation>
</xsd:annotation>
</xsd:attribute>
</xsd:complexType>
<!-- typed actor -->
@ -133,7 +155,7 @@
<xsd:attribute name="timeout" type="xsd:long" use="required">
<xsd:annotation>
<xsd:documentation>
Theh default timeout for '!!' invocations.
The default timeout for '!!' invocations.
</xsd:documentation>
</xsd:annotation>
</xsd:attribute>
@ -227,6 +249,41 @@
</xsd:choice>
</xsd:complexType>
<!-- actor-for -->
<!-- typed actor -->
<xsd:complexType name="actor-for-type">
<xsd:attribute name="id" type="xsd:ID"/>
<xsd:attribute name="host" type="xsd:string" use="required">
<xsd:annotation>
<xsd:documentation>
Name of the remote host.
</xsd:documentation>
</xsd:annotation>
</xsd:attribute>
<xsd:attribute name="port" type="xsd:integer" use="required">
<xsd:annotation>
<xsd:documentation>
Port of the remote host.
</xsd:documentation>
</xsd:annotation>
</xsd:attribute>
<xsd:attribute name="service-name" type="xsd:string" use="required">
<xsd:annotation>
<xsd:documentation>
Custom service name or class name for the server managed actor.
</xsd:documentation>
</xsd:annotation>
</xsd:attribute>
<xsd:attribute name="interface" type="xsd:string">
<xsd:annotation>
<xsd:documentation>
Name of the interface the typed actor implements.
</xsd:documentation>
</xsd:annotation>
</xsd:attribute>
</xsd:complexType>
<!-- Supervisor strategy -->
<xsd:complexType name="strategy-type">
<xsd:sequence>
@ -292,4 +349,7 @@
<!-- CamelService -->
<xsd:element name="camel-service" type="camel-service-type"/>
<!-- ActorFor -->
<xsd:element name="actor-for" type="actor-for-type"/>
</xsd:schema>

View file

@ -0,0 +1,72 @@
/**
* Copyright (C) 2009-2010 Scalable Solutions AB <http://scalablesolutions.se>
*/
package se.scalablesolutions.akka.spring
import org.springframework.beans.factory.support.BeanDefinitionBuilder
import org.springframework.beans.factory.xml.AbstractSingleBeanDefinitionParser
import org.springframework.beans.factory.xml.ParserContext
import AkkaSpringConfigurationTags._
import org.w3c.dom.Element
/**
* Parser for custom namespace configuration.
* @author michaelkober
*/
class TypedActorBeanDefinitionParser extends AbstractSingleBeanDefinitionParser with ActorParser {
/*
* @see org.springframework.beans.factory.xml.AbstractSingleBeanDefinitionParser#doParse(org.w3c.dom.Element, org.springframework.beans.factory.xml.ParserContext, org.springframework.beans.factory.support.BeanDefinitionBuilder)
*/
override def doParse(element: Element, parserContext: ParserContext, builder: BeanDefinitionBuilder) {
val typedActorConf = parseActor(element)
typedActorConf.typed = TYPED_ACTOR_TAG
typedActorConf.setAsProperties(builder)
}
/*
* @see org.springframework.beans.factory.xml.AbstractSingleBeanDefinitionParser#getBeanClass(org.w3c.dom.Element)
*/
override def getBeanClass(element: Element): Class[_] = classOf[ActorFactoryBean]
}
/**
* Parser for custom namespace configuration.
* @author michaelkober
*/
class UntypedActorBeanDefinitionParser extends AbstractSingleBeanDefinitionParser with ActorParser {
/*
* @see org.springframework.beans.factory.xml.AbstractSingleBeanDefinitionParser#doParse(org.w3c.dom.Element, org.springframework.beans.factory.xml.ParserContext, org.springframework.beans.factory.support.BeanDefinitionBuilder)
*/
override def doParse(element: Element, parserContext: ParserContext, builder: BeanDefinitionBuilder) {
val untypedActorConf = parseActor(element)
untypedActorConf.typed = UNTYPED_ACTOR_TAG
untypedActorConf.setAsProperties(builder)
}
/*
* @see org.springframework.beans.factory.xml.AbstractSingleBeanDefinitionParser#getBeanClass(org.w3c.dom.Element)
*/
override def getBeanClass(element: Element): Class[_] = classOf[ActorFactoryBean]
}
/**
* Parser for custom namespace configuration.
* @author michaelkober
*/
class ActorForBeanDefinitionParser extends AbstractSingleBeanDefinitionParser with ActorForParser {
/*
* @see org.springframework.beans.factory.xml.AbstractSingleBeanDefinitionParser#doParse(org.w3c.dom.Element, org.springframework.beans.factory.xml.ParserContext, org.springframework.beans.factory.support.BeanDefinitionBuilder)
*/
override def doParse(element: Element, parserContext: ParserContext, builder: BeanDefinitionBuilder) {
val actorForConf = parseActorFor(element)
actorForConf.setAsProperties(builder)
}
/*
* @see org.springframework.beans.factory.xml.AbstractSingleBeanDefinitionParser#getBeanClass(org.w3c.dom.Element)
*/
override def getBeanClass(element: Element): Class[_] = classOf[ActorForFactoryBean]
}

View file

@ -4,22 +4,19 @@
package se.scalablesolutions.akka.spring
import java.beans.PropertyDescriptor
import java.lang.reflect.Method
import javax.annotation.PreDestroy
import javax.annotation.PostConstruct
import org.springframework.beans.{BeanUtils,BeansException,BeanWrapper,BeanWrapperImpl}
import org.springframework.beans.factory.BeanFactory
import se.scalablesolutions.akka.remote.{RemoteClient, RemoteServer}
//import org.springframework.beans.factory.BeanFactory
import org.springframework.beans.factory.config.AbstractFactoryBean
import org.springframework.context.{ApplicationContext,ApplicationContextAware}
import org.springframework.util.ReflectionUtils
//import org.springframework.util.ReflectionUtils
import org.springframework.util.StringUtils
import se.scalablesolutions.akka.actor.{ActorRef, AspectInitRegistry, TypedActorConfiguration, TypedActor,Actor}
import se.scalablesolutions.akka.dispatch.MessageDispatcher
import se.scalablesolutions.akka.util.{Logging, Duration}
import scala.reflect.BeanProperty
import java.net.InetSocketAddress
/**
* Exception to use when something goes wrong during bean creation.
@ -49,6 +46,8 @@ class ActorFactoryBean extends AbstractFactoryBean[AnyRef] with Logging with App
@BeanProperty var transactional: Boolean = false
@BeanProperty var host: String = ""
@BeanProperty var port: Int = _
@BeanProperty var serverManaged: Boolean = false
@BeanProperty var serviceName: String = ""
@BeanProperty var lifecycle: String = ""
@BeanProperty var dispatcher: DispatcherProperties = _
@BeanProperty var scope: String = VAL_SCOPE_SINGLETON
@ -94,7 +93,16 @@ class ActorFactoryBean extends AbstractFactoryBean[AnyRef] with Logging with App
if (implementation == null || implementation == "") throw new AkkaBeansException(
"The 'implementation' part of the 'akka:typed-actor' element in the Spring config file can't be null or empty string")
TypedActor.newInstance(interface.toClass, implementation.toClass, createConfig)
val typedActor: AnyRef = TypedActor.newInstance(interface.toClass, implementation.toClass, createConfig)
if (isRemote && serverManaged) {
val server = RemoteServer.getOrCreateServer(new InetSocketAddress(host, port))
if (serviceName.isEmpty) {
server.registerTypedActor(interface, typedActor)
} else {
server.registerTypedActor(serviceName, typedActor)
}
}
typedActor
}
/**
@ -111,7 +119,16 @@ class ActorFactoryBean extends AbstractFactoryBean[AnyRef] with Logging with App
actorRef.makeTransactionRequired
}
if (isRemote) {
actorRef.makeRemote(host, port)
if (serverManaged) {
val server = RemoteServer.getOrCreateServer(new InetSocketAddress(host, port))
if (serviceName.isEmpty) {
server.register(actorRef)
} else {
server.register(serviceName, actorRef)
}
} else {
actorRef.makeRemote(host, port)
}
}
if (hasDispatcher) {
if (dispatcher.dispatcherType != THREAD_BASED){
@ -159,7 +176,7 @@ class ActorFactoryBean extends AbstractFactoryBean[AnyRef] with Logging with App
private[akka] def createConfig: TypedActorConfiguration = {
val config = new TypedActorConfiguration().timeout(Duration(timeout, "millis"))
if (transactional) config.makeTransactionRequired
if (isRemote) config.makeRemote(host, port)
if (isRemote && !serverManaged) config.makeRemote(host, port)
if (hasDispatcher) {
if (dispatcher.dispatcherType != THREAD_BASED) {
config.dispatcher(dispatcherInstance())
@ -191,3 +208,39 @@ class ActorFactoryBean extends AbstractFactoryBean[AnyRef] with Logging with App
}
}
}
/**
* Factory bean for remote client actor-for.
*
* @author michaelkober
*/
class ActorForFactoryBean extends AbstractFactoryBean[AnyRef] with Logging with ApplicationContextAware {
import StringReflect._
import AkkaSpringConfigurationTags._
@BeanProperty var interface: String = ""
@BeanProperty var host: String = ""
@BeanProperty var port: Int = _
@BeanProperty var serviceName: String = ""
//@BeanProperty var scope: String = VAL_SCOPE_SINGLETON
@BeanProperty var applicationContext: ApplicationContext = _
override def isSingleton = false
/*
* @see org.springframework.beans.factory.FactoryBean#getObjectType()
*/
def getObjectType: Class[AnyRef] = classOf[AnyRef]
/*
* @see org.springframework.beans.factory.config.AbstractFactoryBean#createInstance()
*/
def createInstance: AnyRef = {
if (interface.isEmpty) {
RemoteClient.actorFor(serviceName, host, port)
} else {
RemoteClient.typedActorFor(interface.toClass, serviceName, host, port)
}
}
}

View file

@ -6,6 +6,7 @@ package se.scalablesolutions.akka.spring
import org.springframework.util.xml.DomUtils
import org.w3c.dom.Element
import scala.collection.JavaConversions._
import se.scalablesolutions.akka.util.Logging
import se.scalablesolutions.akka.actor.IllegalActorStateException
@ -27,11 +28,17 @@ trait ActorParser extends BeanParser with DispatcherParser {
val objectProperties = new ActorProperties()
val remoteElement = DomUtils.getChildElementByTagName(element, REMOTE_TAG);
val dispatcherElement = DomUtils.getChildElementByTagName(element, DISPATCHER_TAG)
val propertyEntries = DomUtils.getChildElementsByTagName(element,PROPERTYENTRY_TAG)
val propertyEntries = DomUtils.getChildElementsByTagName(element, PROPERTYENTRY_TAG)
if (remoteElement != null) {
objectProperties.host = mandatory(remoteElement, HOST)
objectProperties.port = mandatory(remoteElement, PORT).toInt
objectProperties.serverManaged = (remoteElement.getAttribute(MANAGED_BY) != null) && (remoteElement.getAttribute(MANAGED_BY).equals(SERVER_MANAGED))
val serviceName = remoteElement.getAttribute(SERVICE_NAME)
if ((serviceName != null) && (!serviceName.isEmpty)) {
objectProperties.serviceName = serviceName
objectProperties.serverManaged = true
}
}
if (dispatcherElement != null) {
@ -43,7 +50,7 @@ trait ActorParser extends BeanParser with DispatcherParser {
val entry = new PropertyEntry
entry.name = element.getAttribute("name");
entry.value = element.getAttribute("value")
entry.ref = element.getAttribute("ref")
entry.ref = element.getAttribute("ref")
objectProperties.propertyEntries.add(entry)
}
@ -59,15 +66,13 @@ trait ActorParser extends BeanParser with DispatcherParser {
objectProperties.target = mandatory(element, IMPLEMENTATION)
objectProperties.transactional = if (element.getAttribute(TRANSACTIONAL).isEmpty) false else element.getAttribute(TRANSACTIONAL).toBoolean
if (!element.getAttribute(INTERFACE).isEmpty) {
if (element.hasAttribute(INTERFACE)) {
objectProperties.interface = element.getAttribute(INTERFACE)
}
if (!element.getAttribute(LIFECYCLE).isEmpty) {
if (element.hasAttribute(LIFECYCLE)) {
objectProperties.lifecycle = element.getAttribute(LIFECYCLE)
}
if (!element.getAttribute(SCOPE).isEmpty) {
if (element.hasAttribute(SCOPE)) {
objectProperties.scope = element.getAttribute(SCOPE)
}
@ -75,3 +80,158 @@ trait ActorParser extends BeanParser with DispatcherParser {
}
}
/**
* Parser trait for custom namespace configuration for RemoteClient actor-for.
* @author michaelkober
*/
trait ActorForParser extends BeanParser {
import AkkaSpringConfigurationTags._
/**
* Parses the given element and returns a ActorForProperties.
* @param element dom element to parse
* @return configuration for the typed actor
*/
def parseActorFor(element: Element): ActorForProperties = {
val objectProperties = new ActorForProperties()
objectProperties.host = mandatory(element, HOST)
objectProperties.port = mandatory(element, PORT).toInt
objectProperties.serviceName = mandatory(element, SERVICE_NAME)
if (element.hasAttribute(INTERFACE)) {
objectProperties.interface = element.getAttribute(INTERFACE)
}
objectProperties
}
}
/**
* Base trait with utility methods for bean parsing.
*/
trait BeanParser extends Logging {
/**
* Get a mandatory element attribute.
* @param element the element with the mandatory attribute
* @param attribute name of the mandatory attribute
*/
def mandatory(element: Element, attribute: String): String = {
if ((element.getAttribute(attribute) == null) || (element.getAttribute(attribute).isEmpty)) {
throw new IllegalArgumentException("Mandatory attribute missing: " + attribute)
} else {
element.getAttribute(attribute)
}
}
/**
* Get a mandatory child element.
* @param element the parent element
* @param childName name of the mandatory child element
*/
def mandatoryElement(element: Element, childName: String): Element = {
val childElement = DomUtils.getChildElementByTagName(element, childName);
if (childElement == null) {
throw new IllegalArgumentException("Mandatory element missing: '<akka:" + childName + ">'")
} else {
childElement
}
}
}
/**
* Parser trait for custom namespace for Akka dispatcher configuration.
* @author michaelkober
*/
trait DispatcherParser extends BeanParser {
import AkkaSpringConfigurationTags._
/**
* Parses the given element and returns a DispatcherProperties.
* @param element dom element to parse
* @return configuration for the dispatcher
*/
def parseDispatcher(element: Element): DispatcherProperties = {
val properties = new DispatcherProperties()
var dispatcherElement = element
if (hasRef(element)) {
val ref = element.getAttribute(REF)
dispatcherElement = element.getOwnerDocument.getElementById(ref)
if (dispatcherElement == null) {
throw new IllegalArgumentException("Referenced dispatcher not found: '" + ref + "'")
}
}
properties.dispatcherType = mandatory(dispatcherElement, TYPE)
if (properties.dispatcherType == THREAD_BASED) {
val allowedParentNodes = "akka:typed-actor" :: "akka:untyped-actor" :: "typed-actor" :: "untyped-actor" :: Nil
if (!allowedParentNodes.contains(dispatcherElement.getParentNode.getNodeName)) {
throw new IllegalArgumentException("Thread based dispatcher must be nested in 'typed-actor' or 'untyped-actor' element!")
}
}
if (properties.dispatcherType == HAWT) { // no name for HawtDispatcher
properties.name = dispatcherElement.getAttribute(NAME)
if (dispatcherElement.hasAttribute(AGGREGATE)) {
properties.aggregate = dispatcherElement.getAttribute(AGGREGATE).toBoolean
}
} else {
properties.name = mandatory(dispatcherElement, NAME)
}
val threadPoolElement = DomUtils.getChildElementByTagName(dispatcherElement, THREAD_POOL_TAG);
if (threadPoolElement != null) {
if (properties.dispatcherType == THREAD_BASED) {
throw new IllegalArgumentException("Element 'thread-pool' not allowed for this dispatcher type.")
}
val threadPoolProperties = parseThreadPool(threadPoolElement)
properties.threadPool = threadPoolProperties
}
properties
}
/**
* Parses the given element and returns a ThreadPoolProperties.
* @param element dom element to parse
* @return configuration for the thread pool
*/
def parseThreadPool(element: Element): ThreadPoolProperties = {
val properties = new ThreadPoolProperties()
properties.queue = element.getAttribute(QUEUE)
if (element.hasAttribute(CAPACITY)) {
properties.capacity = element.getAttribute(CAPACITY).toInt
}
if (element.hasAttribute(BOUND)) {
properties.bound = element.getAttribute(BOUND).toInt
}
if (element.hasAttribute(FAIRNESS)) {
properties.fairness = element.getAttribute(FAIRNESS).toBoolean
}
if (element.hasAttribute(CORE_POOL_SIZE)) {
properties.corePoolSize = element.getAttribute(CORE_POOL_SIZE).toInt
}
if (element.hasAttribute(MAX_POOL_SIZE)) {
properties.maxPoolSize = element.getAttribute(MAX_POOL_SIZE).toInt
}
if (element.hasAttribute(KEEP_ALIVE)) {
properties.keepAlive = element.getAttribute(KEEP_ALIVE).toLong
}
if (element.hasAttribute(REJECTION_POLICY)) {
properties.rejectionPolicy = element.getAttribute(REJECTION_POLICY)
}
if (element.hasAttribute(MAILBOX_CAPACITY)) {
properties.mailboxCapacity = element.getAttribute(MAILBOX_CAPACITY).toInt
}
properties
}
def hasRef(element: Element): Boolean = {
val ref = element.getAttribute(REF)
(ref != null) && !ref.isEmpty
}
}

View file

@ -8,7 +8,7 @@ import org.springframework.beans.factory.support.BeanDefinitionBuilder
import AkkaSpringConfigurationTags._
/**
* Data container for typed actor configuration data.
* Data container for actor configuration data.
* @author michaelkober
* @author Martin Krasser
*/
@ -20,6 +20,8 @@ class ActorProperties {
var transactional: Boolean = false
var host: String = ""
var port: Int = _
var serverManaged: Boolean = false
var serviceName: String = ""
var lifecycle: String = ""
var scope:String = VAL_SCOPE_SINGLETON
var dispatcher: DispatcherProperties = _
@ -34,6 +36,8 @@ class ActorProperties {
builder.addPropertyValue("typed", typed)
builder.addPropertyValue(HOST, host)
builder.addPropertyValue(PORT, port)
builder.addPropertyValue("serverManaged", serverManaged)
builder.addPropertyValue("serviceName", serviceName)
builder.addPropertyValue(TIMEOUT, timeout)
builder.addPropertyValue(IMPLEMENTATION, target)
builder.addPropertyValue(INTERFACE, interface)
@ -45,3 +49,26 @@ class ActorProperties {
}
}
/**
* Data container for actor configuration data.
* @author michaelkober
*/
class ActorForProperties {
var interface: String = ""
var host: String = ""
var port: Int = _
var serviceName: String = ""
/**
* Sets the properties to the given builder.
* @param builder bean definition builder
*/
def setAsProperties(builder: BeanDefinitionBuilder) {
builder.addPropertyValue(HOST, host)
builder.addPropertyValue(PORT, port)
builder.addPropertyValue("serviceName", serviceName)
builder.addPropertyValue(INTERFACE, interface)
}
}

View file

@ -12,10 +12,11 @@ import AkkaSpringConfigurationTags._
*/
class AkkaNamespaceHandler extends NamespaceHandlerSupport {
def init = {
registerBeanDefinitionParser(TYPED_ACTOR_TAG, new TypedActorBeanDefinitionParser());
registerBeanDefinitionParser(UNTYPED_ACTOR_TAG, new UntypedActorBeanDefinitionParser());
registerBeanDefinitionParser(SUPERVISION_TAG, new SupervisionBeanDefinitionParser());
registerBeanDefinitionParser(DISPATCHER_TAG, new DispatcherBeanDefinitionParser());
registerBeanDefinitionParser(CAMEL_SERVICE_TAG, new CamelServiceBeanDefinitionParser);
registerBeanDefinitionParser(TYPED_ACTOR_TAG, new TypedActorBeanDefinitionParser())
registerBeanDefinitionParser(UNTYPED_ACTOR_TAG, new UntypedActorBeanDefinitionParser())
registerBeanDefinitionParser(SUPERVISION_TAG, new SupervisionBeanDefinitionParser())
registerBeanDefinitionParser(DISPATCHER_TAG, new DispatcherBeanDefinitionParser())
registerBeanDefinitionParser(CAMEL_SERVICE_TAG, new CamelServiceBeanDefinitionParser)
registerBeanDefinitionParser(ACTOR_FOR_TAG, new ActorForBeanDefinitionParser());
}
}

View file

@ -19,6 +19,7 @@ object AkkaSpringConfigurationTags {
val DISPATCHER_TAG = "dispatcher"
val PROPERTYENTRY_TAG = "property"
val CAMEL_SERVICE_TAG = "camel-service"
val ACTOR_FOR_TAG = "actor-for"
// actor sub tags
val REMOTE_TAG = "remote"
@ -45,6 +46,8 @@ object AkkaSpringConfigurationTags {
val TRANSACTIONAL = "transactional"
val HOST = "host"
val PORT = "port"
val MANAGED_BY = "managed-by"
val SERVICE_NAME = "service-name"
val LIFECYCLE = "lifecycle"
val SCOPE = "scope"
@ -101,4 +104,8 @@ object AkkaSpringConfigurationTags {
val THREAD_BASED = "thread-based"
val HAWT = "hawt"
// managed by types
val SERVER_MANAGED = "server"
val CLIENT_MANAGED = "client"
}

View file

@ -1,42 +0,0 @@
/**
* Copyright (C) 2009-2010 Scalable Solutions AB <http://scalablesolutions.se>
*/
package se.scalablesolutions.akka.spring
import se.scalablesolutions.akka.util.Logging
import org.w3c.dom.Element
import org.springframework.util.xml.DomUtils
/**
* Base trait with utility methods for bean parsing.
*/
trait BeanParser extends Logging {
/**
* Get a mandatory element attribute.
* @param element the element with the mandatory attribute
* @param attribute name of the mandatory attribute
*/
def mandatory(element: Element, attribute: String): String = {
if ((element.getAttribute(attribute) == null) || (element.getAttribute(attribute).isEmpty)) {
throw new IllegalArgumentException("Mandatory attribute missing: " + attribute)
} else {
element.getAttribute(attribute)
}
}
/**
* Get a mandatory child element.
* @param element the parent element
* @param childName name of the mandatory child element
*/
def mandatoryElement(element: Element, childName: String): Element = {
val childElement = DomUtils.getChildElementByTagName(element, childName);
if (childElement == null) {
throw new IllegalArgumentException("Mandatory element missing: '<akka:" + childName + ">'")
} else {
childElement
}
}
}

View file

@ -1,100 +0,0 @@
/**
* Copyright (C) 2009-2010 Scalable Solutions AB <http://scalablesolutions.se>
*/
package se.scalablesolutions.akka.spring
import org.w3c.dom.Element
import org.springframework.util.xml.DomUtils
/**
* Parser trait for custom namespace for Akka dispatcher configuration.
* @author michaelkober
*/
trait DispatcherParser extends BeanParser {
import AkkaSpringConfigurationTags._
/**
* Parses the given element and returns a DispatcherProperties.
* @param element dom element to parse
* @return configuration for the dispatcher
*/
def parseDispatcher(element: Element): DispatcherProperties = {
val properties = new DispatcherProperties()
var dispatcherElement = element
if (hasRef(element)) {
val ref = element.getAttribute(REF)
dispatcherElement = element.getOwnerDocument.getElementById(ref)
if (dispatcherElement == null) {
throw new IllegalArgumentException("Referenced dispatcher not found: '" + ref + "'")
}
}
properties.dispatcherType = mandatory(dispatcherElement, TYPE)
if (properties.dispatcherType == THREAD_BASED) {
val allowedParentNodes = "akka:typed-actor" :: "akka:untyped-actor" :: "typed-actor" :: "untyped-actor" :: Nil
if (!allowedParentNodes.contains(dispatcherElement.getParentNode.getNodeName)) {
throw new IllegalArgumentException("Thread based dispatcher must be nested in 'typed-actor' or 'untyped-actor' element!")
}
}
if (properties.dispatcherType == HAWT) { // no name for HawtDispatcher
properties.name = dispatcherElement.getAttribute(NAME)
if (dispatcherElement.hasAttribute(AGGREGATE)) {
properties.aggregate = dispatcherElement.getAttribute(AGGREGATE).toBoolean
}
} else {
properties.name = mandatory(dispatcherElement, NAME)
}
val threadPoolElement = DomUtils.getChildElementByTagName(dispatcherElement, THREAD_POOL_TAG);
if (threadPoolElement != null) {
if (properties.dispatcherType == THREAD_BASED) {
throw new IllegalArgumentException("Element 'thread-pool' not allowed for this dispatcher type.")
}
val threadPoolProperties = parseThreadPool(threadPoolElement)
properties.threadPool = threadPoolProperties
}
properties
}
/**
* Parses the given element and returns a ThreadPoolProperties.
* @param element dom element to parse
* @return configuration for the thread pool
*/
def parseThreadPool(element: Element): ThreadPoolProperties = {
val properties = new ThreadPoolProperties()
properties.queue = element.getAttribute(QUEUE)
if (element.hasAttribute(CAPACITY)) {
properties.capacity = element.getAttribute(CAPACITY).toInt
}
if (element.hasAttribute(BOUND)) {
properties.bound = element.getAttribute(BOUND).toInt
}
if (element.hasAttribute(FAIRNESS)) {
properties.fairness = element.getAttribute(FAIRNESS).toBoolean
}
if (element.hasAttribute(CORE_POOL_SIZE)) {
properties.corePoolSize = element.getAttribute(CORE_POOL_SIZE).toInt
}
if (element.hasAttribute(MAX_POOL_SIZE)) {
properties.maxPoolSize = element.getAttribute(MAX_POOL_SIZE).toInt
}
if (element.hasAttribute(KEEP_ALIVE)) {
properties.keepAlive = element.getAttribute(KEEP_ALIVE).toLong
}
if (element.hasAttribute(REJECTION_POLICY)) {
properties.rejectionPolicy = element.getAttribute(REJECTION_POLICY)
}
if (element.hasAttribute(MAILBOX_CAPACITY)) {
properties.mailboxCapacity = element.getAttribute(MAILBOX_CAPACITY).toInt
}
properties
}
def hasRef(element: Element): Boolean = {
val ref = element.getAttribute(REF)
(ref != null) && !ref.isEmpty
}
}

View file

@ -18,3 +18,19 @@ class PropertyEntries {
entryList.append(entry)
}
}
/**
* Represents a property element
* @author <a href="johan.rask@jayway.com">Johan Rask</a>
*/
class PropertyEntry {
var name: String = _
var value: String = null
var ref: String = null
override def toString(): String = {
format("name = %s,value = %s, ref = %s", name, value, ref)
}
}

View file

@ -1,19 +0,0 @@
/**
* Copyright (C) 2009-2010 Scalable Solutions AB <http://scalablesolutions.se>
*/
package se.scalablesolutions.akka.spring
/**
* Represents a property element
* @author <a href="johan.rask@jayway.com">Johan Rask</a>
*/
class PropertyEntry {
var name: String = _
var value: String = null
var ref: String = null
override def toString(): String = {
format("name = %s,value = %s, ref = %s", name, value, ref)
}
}

View file

@ -1,31 +0,0 @@
/**
* Copyright (C) 2009-2010 Scalable Solutions AB <http://scalablesolutions.se>
*/
package se.scalablesolutions.akka.spring
import org.springframework.beans.factory.support.BeanDefinitionBuilder
import org.springframework.beans.factory.xml.AbstractSingleBeanDefinitionParser
import org.springframework.beans.factory.xml.ParserContext
import AkkaSpringConfigurationTags._
import org.w3c.dom.Element
/**
* Parser for custom namespace configuration.
* @author michaelkober
*/
class TypedActorBeanDefinitionParser extends AbstractSingleBeanDefinitionParser with ActorParser {
/*
* @see org.springframework.beans.factory.xml.AbstractSingleBeanDefinitionParser#doParse(org.w3c.dom.Element, org.springframework.beans.factory.xml.ParserContext, org.springframework.beans.factory.support.BeanDefinitionBuilder)
*/
override def doParse(element: Element, parserContext: ParserContext, builder: BeanDefinitionBuilder) {
val typedActorConf = parseActor(element)
typedActorConf.typed = TYPED_ACTOR_TAG
typedActorConf.setAsProperties(builder)
}
/*
* @see org.springframework.beans.factory.xml.AbstractSingleBeanDefinitionParser#getBeanClass(org.w3c.dom.Element)
*/
override def getBeanClass(element: Element): Class[_] = classOf[ActorFactoryBean]
}

View file

@ -1,31 +0,0 @@
/**
* Copyright (C) 2009-2010 Scalable Solutions AB <http://scalablesolutions.se>
*/
package se.scalablesolutions.akka.spring
import org.springframework.beans.factory.support.BeanDefinitionBuilder
import org.springframework.beans.factory.xml.AbstractSingleBeanDefinitionParser
import org.springframework.beans.factory.xml.ParserContext
import AkkaSpringConfigurationTags._
import org.w3c.dom.Element
/**
* Parser for custom namespace configuration.
* @author michaelkober
*/
class UntypedActorBeanDefinitionParser extends AbstractSingleBeanDefinitionParser with ActorParser {
/*
* @see org.springframework.beans.factory.xml.AbstractSingleBeanDefinitionParser#doParse(org.w3c.dom.Element, org.springframework.beans.factory.xml.ParserContext, org.springframework.beans.factory.support.BeanDefinitionBuilder)
*/
override def doParse(element: Element, parserContext: ParserContext, builder: BeanDefinitionBuilder) {
val untypedActorConf = parseActor(element)
untypedActorConf.typed = UNTYPED_ACTOR_TAG
untypedActorConf.setAsProperties(builder)
}
/*
* @see org.springframework.beans.factory.xml.AbstractSingleBeanDefinitionParser#getBeanClass(org.w3c.dom.Element)
*/
override def getBeanClass(element: Element): Class[_] = classOf[ActorFactoryBean]
}

View file

@ -8,14 +8,12 @@ package se.scalablesolutions.akka.spring.foo;
* To change this template use File | Settings | File Templates.
*/
public interface IMyPojo {
public void oneWay(String message);
public String getFoo();
public String getBar();
public void preRestart();
public void postRestart();
public String longRunning();
}

View file

@ -1,42 +1,34 @@
package se.scalablesolutions.akka.spring.foo;
import se.scalablesolutions.akka.actor.*;
import se.scalablesolutions.akka.actor.TypedActor;
public class MyPojo extends TypedActor implements IMyPojo{
import java.util.concurrent.CountDownLatch;
private String foo;
private String bar;
public class MyPojo extends TypedActor implements IMyPojo {
public static CountDownLatch latch = new CountDownLatch(1);
public static String lastOneWayMessage = null;
private String foo = "foo";
public MyPojo() {
this.foo = "foo";
this.bar = "bar";
}
public MyPojo() {
}
public String getFoo() {
return foo;
}
public String getFoo() {
return foo;
}
public void oneWay(String message) {
lastOneWayMessage = message;
latch.countDown();
}
public String getBar() {
return bar;
}
public void preRestart() {
System.out.println("pre restart");
}
public void postRestart() {
System.out.println("post restart");
}
public String longRunning() {
try {
Thread.sleep(6000);
} catch (InterruptedException e) {
}
return "this took long";
public String longRunning() {
try {
Thread.sleep(6000);
} catch (InterruptedException e) {
}
return "this took long";
}
}

View file

@ -6,6 +6,8 @@ import se.scalablesolutions.akka.actor.ActorRef;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import java.util.concurrent.CountDownLatch;
/**
* test class
@ -14,6 +16,9 @@ public class PingActor extends UntypedActor implements ApplicationContextAware {
private String stringFromVal;
private String stringFromRef;
public static String lastMessage = null;
public static CountDownLatch latch = new CountDownLatch(1);
private boolean gotApplicationContext = false;
@ -42,7 +47,6 @@ public class PingActor extends UntypedActor implements ApplicationContextAware {
stringFromRef = s;
}
private String longRunning() {
try {
Thread.sleep(6000);
@ -53,12 +57,12 @@ public class PingActor extends UntypedActor implements ApplicationContextAware {
public void onReceive(Object message) throws Exception {
if (message instanceof String) {
System.out.println("Ping received String message: " + message);
lastMessage = (String) message;
if (message.equals("longRunning")) {
System.out.println("### starting pong");
ActorRef pongActor = UntypedActor.actorOf(PongActor.class).start();
pongActor.sendRequestReply("longRunning", getContext());
}
latch.countDown();
} else {
throw new IllegalArgumentException("Unknown message: " + message);
}

View file

@ -0,0 +1,57 @@
<?xml version="1.0" encoding="UTF-8"?>
<beans xmlns="http://www.springframework.org/schema/beans"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns:akka="http://www.akkasource.org/schema/akka"
xmlns:beans="http://www.springframework.org/schema/lang"
xsi:schemaLocation="
http://www.springframework.org/schema/beans
http://www.springframework.org/schema/beans/spring-beans-2.0.xsd
http://www.akkasource.org/schema/akka
http://scalablesolutions.se/akka/akka-1.0-SNAPSHOT.xsd">
<akka:untyped-actor id="client-managed-remote-untyped-actor"
implementation="se.scalablesolutions.akka.spring.foo.PingActor">
<akka:remote host="localhost" port="9990" managed-by="client"/>
</akka:untyped-actor>
<akka:untyped-actor id="server-managed-remote-untyped-actor"
implementation="se.scalablesolutions.akka.spring.foo.PingActor">
<akka:remote host="localhost" port="9990" managed-by="server"/>
</akka:untyped-actor>
<akka:untyped-actor id="server-managed-remote-untyped-actor-custom-id"
implementation="se.scalablesolutions.akka.spring.foo.PingActor">
<akka:remote host="localhost" port="9990" service-name="ping-service"/>
</akka:untyped-actor>
<akka:typed-actor id="client-managed-remote-typed-actor"
interface="se.scalablesolutions.akka.spring.foo.IMyPojo"
implementation="se.scalablesolutions.akka.spring.foo.MyPojo"
timeout="2000">
<akka:remote host="localhost" port="9990" managed-by="client"/>
</akka:typed-actor>
<akka:typed-actor id="server-managed-remote-typed-actor"
interface="se.scalablesolutions.akka.spring.foo.IMyPojo"
implementation="se.scalablesolutions.akka.spring.foo.MyPojo"
timeout="2000">
<akka:remote host="localhost" port="9990" managed-by="server"/>
</akka:typed-actor>
<akka:typed-actor id="server-managed-remote-typed-actor-custom-id"
interface="se.scalablesolutions.akka.spring.foo.IMyPojo"
implementation="se.scalablesolutions.akka.spring.foo.MyPojo"
timeout="2000">
<akka:remote host="localhost" port="9990" service-name="mypojo-service"/>
</akka:typed-actor>
<akka:actor-for id="client-1" host="localhost" port="9990" service-name="ping-service"/>
<akka:actor-for id="typed-client-1"
interface="se.scalablesolutions.akka.spring.foo.IMyPojo"
host="localhost"
port="9990"
service-name="mypojo-service"/>
</beans>

View file

@ -37,7 +37,7 @@ http://scalablesolutions.se/akka/akka-1.0-SNAPSHOT.xsd">
implementation="se.scalablesolutions.akka.spring.foo.MyPojo"
timeout="2000"
transactional="true">
<akka:remote host="localhost" port="9999"/>
<akka:remote host="localhost" port="9990"/>
</akka:typed-actor>
<akka:typed-actor id="remote-service1"

View file

@ -24,7 +24,7 @@ http://scalablesolutions.se/akka/akka-1.0-SNAPSHOT.xsd">
<akka:untyped-actor id="remote-untyped-actor"
implementation="se.scalablesolutions.akka.spring.foo.PingActor"
timeout="2000">
<akka:remote host="localhost" port="9999"/>
<akka:remote host="localhost" port="9992"/>
</akka:untyped-actor>
<akka:untyped-actor id="untyped-actor-with-dispatcher"

View file

@ -19,7 +19,7 @@ import org.w3c.dom.Element
class TypedActorBeanDefinitionParserTest extends Spec with ShouldMatchers {
private class Parser extends ActorParser
describe("An TypedActorParser") {
describe("A TypedActorParser") {
val parser = new Parser()
it("should parse the typed actor configuration") {
val xml = <akka:typed-actor id="typed-actor1"
@ -66,6 +66,20 @@ class TypedActorBeanDefinitionParserTest extends Spec with ShouldMatchers {
assert(props != null)
assert(props.host === "com.some.host")
assert(props.port === 9999)
assert(!props.serverManaged)
}
it("should parse remote server managed TypedActors configuration") {
val xml = <akka:typed-actor id="remote typed-actor" implementation="se.scalablesolutions.akka.spring.foo.MyPojo"
timeout="1000">
<akka:remote host="com.some.host" port="9999" service-name="my-service"/>
</akka:typed-actor>
val props = parser.parseActor(dom(xml).getDocumentElement);
assert(props != null)
assert(props.host === "com.some.host")
assert(props.port === 9999)
assert(props.serviceName === "my-service")
assert(props.serverManaged)
}
}
}

View file

@ -4,10 +4,8 @@
package se.scalablesolutions.akka.spring
import foo.{IMyPojo, MyPojo}
import foo.{PingActor, IMyPojo, MyPojo}
import se.scalablesolutions.akka.dispatch.FutureTimeoutException
import se.scalablesolutions.akka.remote.RemoteNode
import org.scalatest.FeatureSpec
import org.scalatest.matchers.ShouldMatchers
import org.scalatest.junit.JUnitRunner
import org.junit.runner.RunWith
@ -16,13 +14,52 @@ import org.springframework.beans.factory.xml.XmlBeanDefinitionReader
import org.springframework.context.ApplicationContext
import org.springframework.context.support.ClassPathXmlApplicationContext
import org.springframework.core.io.{ClassPathResource, Resource}
import org.scalatest.{BeforeAndAfterAll, FeatureSpec}
import se.scalablesolutions.akka.remote.{RemoteClient, RemoteServer, RemoteNode}
import java.util.concurrent.CountDownLatch
import se.scalablesolutions.akka.actor.{TypedActor, RemoteTypedActorOne, Actor}
import se.scalablesolutions.akka.actor.remote.RemoteTypedActorOneImpl
/**
* Tests for spring configuration of typed actors.
* @author michaelkober
*/
@RunWith(classOf[JUnitRunner])
class TypedActorSpringFeatureTest extends FeatureSpec with ShouldMatchers {
class TypedActorSpringFeatureTest extends FeatureSpec with ShouldMatchers with BeforeAndAfterAll {
var server1: RemoteServer = null
var server2: RemoteServer = null
override def beforeAll = {
val actor = Actor.actorOf[PingActor] // FIXME: remove this line when ticket 425 is fixed
server1 = new RemoteServer()
server1.start("localhost", 9990)
server2 = new RemoteServer()
server2.start("localhost", 9992)
val typedActor = TypedActor.newInstance(classOf[RemoteTypedActorOne], classOf[RemoteTypedActorOneImpl], 1000)
server1.registerTypedActor("typed-actor-service", typedActor)
}
// make sure the servers shutdown cleanly after the test has finished
override def afterAll = {
try {
server1.shutdown
server2.shutdown
RemoteClient.shutdownAll
Thread.sleep(1000)
} catch {
case e => ()
}
}
def getTypedActorFromContext(config: String, id: String) : IMyPojo = {
MyPojo.latch = new CountDownLatch(1)
val context = new ClassPathXmlApplicationContext(config)
val myPojo: IMyPojo = context.getBean(id).asInstanceOf[IMyPojo]
myPojo
}
feature("parse Spring application context") {
scenario("akka:typed-actor and akka:supervision and akka:dispatcher can be used as top level elements") {
@ -37,41 +74,79 @@ class TypedActorSpringFeatureTest extends FeatureSpec with ShouldMatchers {
}
scenario("get a typed actor") {
val context = new ClassPathXmlApplicationContext("/typed-actor-config.xml")
val myPojo = context.getBean("simple-typed-actor").asInstanceOf[IMyPojo]
var msg = myPojo.getFoo()
msg += myPojo.getBar()
assert(msg === "foobar")
val myPojo = getTypedActorFromContext("/typed-actor-config.xml", "simple-typed-actor")
assert(myPojo.getFoo() === "foo")
myPojo.oneWay("hello 1")
MyPojo.latch.await
assert(MyPojo.lastOneWayMessage === "hello 1")
}
scenario("FutureTimeoutException when timed out") {
val context = new ClassPathXmlApplicationContext("/typed-actor-config.xml")
val myPojo = context.getBean("simple-typed-actor").asInstanceOf[IMyPojo]
val myPojo = getTypedActorFromContext("/typed-actor-config.xml", "simple-typed-actor")
evaluating {myPojo.longRunning()} should produce[FutureTimeoutException]
}
scenario("typed-actor with timeout") {
val context = new ClassPathXmlApplicationContext("/typed-actor-config.xml")
val myPojo = context.getBean("simple-typed-actor-long-timeout").asInstanceOf[IMyPojo]
val myPojo = getTypedActorFromContext("/typed-actor-config.xml", "simple-typed-actor-long-timeout")
assert(myPojo.longRunning() === "this took long");
}
scenario("transactional typed-actor") {
val context = new ClassPathXmlApplicationContext("/typed-actor-config.xml")
val myPojo = context.getBean("transactional-typed-actor").asInstanceOf[IMyPojo]
var msg = myPojo.getFoo()
msg += myPojo.getBar()
assert(msg === "foobar")
val myPojo = getTypedActorFromContext("/typed-actor-config.xml", "transactional-typed-actor")
assert(myPojo.getFoo() === "foo")
myPojo.oneWay("hello 2")
MyPojo.latch.await
assert(MyPojo.lastOneWayMessage === "hello 2")
}
scenario("get a remote typed-actor") {
RemoteNode.start
Thread.sleep(1000)
val context = new ClassPathXmlApplicationContext("/typed-actor-config.xml")
val myPojo = context.getBean("remote-typed-actor").asInstanceOf[IMyPojo]
assert(myPojo.getFoo === "foo")
val myPojo = getTypedActorFromContext("/typed-actor-config.xml", "remote-typed-actor")
assert(myPojo.getFoo() === "foo")
myPojo.oneWay("hello 3")
MyPojo.latch.await
assert(MyPojo.lastOneWayMessage === "hello 3")
}
scenario("get a client-managed-remote-typed-actor") {
val myPojo = getTypedActorFromContext("/server-managed-config.xml", "client-managed-remote-typed-actor")
assert(myPojo.getFoo() === "foo")
myPojo.oneWay("hello client-managed-remote-typed-actor")
MyPojo.latch.await
assert(MyPojo.lastOneWayMessage === "hello client-managed-remote-typed-actor")
}
scenario("get a server-managed-remote-typed-actor") {
val serverPojo = getTypedActorFromContext("/server-managed-config.xml", "server-managed-remote-typed-actor")
//
val myPojoProxy = RemoteClient.typedActorFor(classOf[IMyPojo], classOf[IMyPojo].getName, 5000L, "localhost", 9990)
assert(myPojoProxy.getFoo() === "foo")
myPojoProxy.oneWay("hello server-managed-remote-typed-actor")
MyPojo.latch.await
assert(MyPojo.lastOneWayMessage === "hello server-managed-remote-typed-actor")
}
scenario("get a server-managed-remote-typed-actor-custom-id") {
val serverPojo = getTypedActorFromContext("/server-managed-config.xml", "server-managed-remote-typed-actor-custom-id")
//
val myPojoProxy = RemoteClient.typedActorFor(classOf[IMyPojo], "mypojo-service", 5000L, "localhost", 9990)
assert(myPojoProxy.getFoo() === "foo")
myPojoProxy.oneWay("hello server-managed-remote-typed-actor 2")
MyPojo.latch.await
assert(MyPojo.lastOneWayMessage === "hello server-managed-remote-typed-actor 2")
}
scenario("get a client proxy for server-managed-remote-typed-actor") {
MyPojo.latch = new CountDownLatch(1)
val context = new ClassPathXmlApplicationContext("/server-managed-config.xml")
val myPojo: IMyPojo = context.getBean("server-managed-remote-typed-actor-custom-id").asInstanceOf[IMyPojo]
// get client proxy from spring context
val myPojoProxy = context.getBean("typed-client-1").asInstanceOf[IMyPojo]
assert(myPojoProxy.getFoo() === "foo")
myPojoProxy.oneWay("hello")
MyPojo.latch.await
}
}
}

View file

@ -6,74 +6,146 @@ package se.scalablesolutions.akka.spring
import foo.PingActor
import se.scalablesolutions.akka.dispatch.ExecutorBasedEventDrivenWorkStealingDispatcher
import se.scalablesolutions.akka.remote.RemoteNode
import se.scalablesolutions.akka.actor.ActorRef
import org.scalatest.FeatureSpec
import org.scalatest.matchers.ShouldMatchers
import org.scalatest.junit.JUnitRunner
import org.junit.runner.RunWith
import org.springframework.context.ApplicationContext
import org.springframework.context.support.ClassPathXmlApplicationContext
import se.scalablesolutions.akka.remote.{RemoteClient, RemoteServer}
import org.scalatest.{BeforeAndAfterAll, FeatureSpec}
import java.util.concurrent.CountDownLatch
import se.scalablesolutions.akka.actor.{RemoteActorRef, ActorRegistry, Actor, ActorRef}
/**
* Tests for spring configuration of typed actors.
* @author michaelkober
*/
@RunWith(classOf[JUnitRunner])
class UntypedActorSpringFeatureTest extends FeatureSpec with ShouldMatchers {
class UntypedActorSpringFeatureTest extends FeatureSpec with ShouldMatchers with BeforeAndAfterAll {
var server1: RemoteServer = null
var server2: RemoteServer = null
override def beforeAll = {
val actor = Actor.actorOf[PingActor] // FIXME: remove this line when ticket 425 is fixed
server1 = new RemoteServer()
server1.start("localhost", 9990)
server2 = new RemoteServer()
server2.start("localhost", 9992)
}
// make sure the servers shutdown cleanly after the test has finished
override def afterAll = {
try {
server1.shutdown
server2.shutdown
RemoteClient.shutdownAll
Thread.sleep(1000)
} catch {
case e => ()
}
}
def getPingActorFromContext(config: String, id: String) : ActorRef = {
PingActor.latch = new CountDownLatch(1)
val context = new ClassPathXmlApplicationContext(config)
val pingActor = context.getBean(id).asInstanceOf[ActorRef]
assert(pingActor.getActorClassName() === "se.scalablesolutions.akka.spring.foo.PingActor")
pingActor.start()
}
feature("parse Spring application context") {
scenario("get an untyped actor") {
val context = new ClassPathXmlApplicationContext("/untyped-actor-config.xml")
val myactor = context.getBean("simple-untyped-actor").asInstanceOf[ActorRef]
assert(myactor.getActorClassName() === "se.scalablesolutions.akka.spring.foo.PingActor")
myactor.start()
scenario("get a untyped actor") {
val myactor = getPingActorFromContext("/untyped-actor-config.xml", "simple-untyped-actor")
myactor.sendOneWay("Hello")
PingActor.latch.await
assert(PingActor.lastMessage === "Hello")
assert(myactor.isDefinedAt("some string message"))
}
scenario("untyped-actor with timeout") {
val context = new ClassPathXmlApplicationContext("/untyped-actor-config.xml")
val myactor = context.getBean("simple-untyped-actor-long-timeout").asInstanceOf[ActorRef]
assert(myactor.getActorClassName() === "se.scalablesolutions.akka.spring.foo.PingActor")
myactor.start()
myactor.sendOneWay("Hello")
val myactor = getPingActorFromContext("/untyped-actor-config.xml", "simple-untyped-actor-long-timeout")
assert(myactor.getTimeout() === 10000)
myactor.sendOneWay("Hello 2")
PingActor.latch.await
assert(PingActor.lastMessage === "Hello 2")
}
scenario("transactional untyped-actor") {
val context = new ClassPathXmlApplicationContext("/untyped-actor-config.xml")
val myactor = context.getBean("transactional-untyped-actor").asInstanceOf[ActorRef]
assert(myactor.getActorClassName() === "se.scalablesolutions.akka.spring.foo.PingActor")
myactor.start()
myactor.sendOneWay("Hello")
assert(myactor.isDefinedAt("some string message"))
val myactor = getPingActorFromContext("/untyped-actor-config.xml", "transactional-untyped-actor")
myactor.sendOneWay("Hello 3")
PingActor.latch.await
assert(PingActor.lastMessage === "Hello 3")
}
scenario("get a remote typed-actor") {
RemoteNode.start
Thread.sleep(1000)
val context = new ClassPathXmlApplicationContext("/untyped-actor-config.xml")
val myactor = context.getBean("remote-untyped-actor").asInstanceOf[ActorRef]
assert(myactor.getActorClassName() === "se.scalablesolutions.akka.spring.foo.PingActor")
myactor.start()
myactor.sendOneWay("Hello")
assert(myactor.isDefinedAt("some string message"))
val myactor = getPingActorFromContext("/untyped-actor-config.xml", "remote-untyped-actor")
myactor.sendOneWay("Hello 4")
assert(myactor.getRemoteAddress().isDefined)
assert(myactor.getRemoteAddress().get.getHostName() === "localhost")
assert(myactor.getRemoteAddress().get.getPort() === 9999)
assert(myactor.getRemoteAddress().get.getPort() === 9992)
PingActor.latch.await
assert(PingActor.lastMessage === "Hello 4")
}
scenario("untyped-actor with custom dispatcher") {
val context = new ClassPathXmlApplicationContext("/untyped-actor-config.xml")
val myactor = context.getBean("untyped-actor-with-dispatcher").asInstanceOf[ActorRef]
assert(myactor.getActorClassName() === "se.scalablesolutions.akka.spring.foo.PingActor")
myactor.start()
myactor.sendOneWay("Hello")
val myactor = getPingActorFromContext("/untyped-actor-config.xml", "untyped-actor-with-dispatcher")
assert(myactor.getTimeout() === 1000)
assert(myactor.getDispatcher.isInstanceOf[ExecutorBasedEventDrivenWorkStealingDispatcher])
myactor.sendOneWay("Hello 5")
PingActor.latch.await
assert(PingActor.lastMessage === "Hello 5")
}
scenario("create client managed remote untyped-actor") {
val myactor = getPingActorFromContext("/server-managed-config.xml", "client-managed-remote-untyped-actor")
myactor.sendOneWay("Hello client managed remote untyped-actor")
PingActor.latch.await
assert(PingActor.lastMessage === "Hello client managed remote untyped-actor")
assert(myactor.getRemoteAddress().isDefined)
assert(myactor.getRemoteAddress().get.getHostName() === "localhost")
assert(myactor.getRemoteAddress().get.getPort() === 9990)
}
scenario("create server managed remote untyped-actor") {
val myactor = getPingActorFromContext("/server-managed-config.xml", "server-managed-remote-untyped-actor")
val nrOfActors = ActorRegistry.actors.length
val actorRef = RemoteClient.actorFor("se.scalablesolutions.akka.spring.foo.PingActor", "localhost", 9990)
actorRef.sendOneWay("Hello server managed remote untyped-actor")
PingActor.latch.await
assert(PingActor.lastMessage === "Hello server managed remote untyped-actor")
assert(ActorRegistry.actors.length === nrOfActors)
}
scenario("create server managed remote untyped-actor with custom service id") {
val myactor = getPingActorFromContext("/server-managed-config.xml", "server-managed-remote-untyped-actor-custom-id")
val nrOfActors = ActorRegistry.actors.length
val actorRef = RemoteClient.actorFor("ping-service", "localhost", 9990)
actorRef.sendOneWay("Hello server managed remote untyped-actor")
PingActor.latch.await
assert(PingActor.lastMessage === "Hello server managed remote untyped-actor")
assert(ActorRegistry.actors.length === nrOfActors)
}
scenario("get client actor for server managed remote untyped-actor") {
PingActor.latch = new CountDownLatch(1)
val context = new ClassPathXmlApplicationContext("/server-managed-config.xml")
val pingActor = context.getBean("server-managed-remote-untyped-actor-custom-id").asInstanceOf[ActorRef]
assert(pingActor.getActorClassName() === "se.scalablesolutions.akka.spring.foo.PingActor")
pingActor.start()
val nrOfActors = ActorRegistry.actors.length
// get client actor ref from spring context
val actorRef = context.getBean("client-1").asInstanceOf[ActorRef]
assert(actorRef.isInstanceOf[RemoteActorRef])
actorRef.sendOneWay("Hello")
PingActor.latch.await
assert(ActorRegistry.actors.length === nrOfActors)
}
}
}

View file

@ -16,9 +16,8 @@ import org.codehaus.aspectwerkz.proxy.Proxy
import org.codehaus.aspectwerkz.annotation.{Aspect, Around}
import java.net.InetSocketAddress
import java.lang.reflect.{InvocationTargetException, Method, Field}
import scala.reflect.BeanProperty
import java.lang.reflect.{Method, Field, InvocationHandler, Proxy => JProxy}
/**
* TypedActor is a type-safe actor made out of a POJO with interface.
@ -390,7 +389,8 @@ object TypedActor extends Logging {
typedActor.initialize(proxy)
if (config._messageDispatcher.isDefined) actorRef.dispatcher = config._messageDispatcher.get
if (config._threadBasedDispatcher.isDefined) actorRef.dispatcher = Dispatchers.newThreadBasedDispatcher(actorRef)
AspectInitRegistry.register(proxy, AspectInit(intfClass, typedActor, actorRef, None, config.timeout))
if (config._host.isDefined) actorRef.makeRemote(config._host.get)
AspectInitRegistry.register(proxy, AspectInit(intfClass, typedActor, actorRef, config._host, config.timeout))
actorRef.start
proxy.asInstanceOf[T]
}
@ -408,24 +408,47 @@ object TypedActor extends Logging {
proxy.asInstanceOf[T]
}
/*
// NOTE: currently not used - but keep it around
private[akka] def newInstance[T <: TypedActor](targetClass: Class[T],
remoteAddress: Option[InetSocketAddress], timeout: Long): T = {
val proxy = {
val instance = Proxy.newInstance(targetClass, true, false)
if (instance.isInstanceOf[TypedActor]) instance.asInstanceOf[TypedActor]
else throw new IllegalActorStateException("Actor [" + targetClass.getName + "] is not a sub class of 'TypedActor'")
/**
* Create a proxy for a RemoteActorRef representing a server managed remote typed actor.
*
*/
private[akka] def createProxyForRemoteActorRef[T](intfClass: Class[T], actorRef: ActorRef): T = {
class MyInvocationHandler extends InvocationHandler {
def invoke(proxy: AnyRef, method: Method, args: Array[AnyRef]): AnyRef = {
// do nothing, this is just a dummy
null
}
}
val context = injectTypedActorContext(proxy)
actorRef.actor.asInstanceOf[Dispatcher].initialize(targetClass, proxy, proxy, context)
actorRef.timeout = timeout
if (remoteAddress.isDefined) actorRef.makeRemote(remoteAddress.get)
AspectInitRegistry.register(proxy, AspectInit(targetClass, proxy, actorRef, remoteAddress, timeout))
actorRef.start
proxy.asInstanceOf[T]
val handler = new MyInvocationHandler()
val interfaces = Array(intfClass, classOf[ServerManagedTypedActor]).asInstanceOf[Array[java.lang.Class[_]]]
val jProxy = JProxy.newProxyInstance(intfClass.getClassLoader(), interfaces, handler)
val awProxy = Proxy.newInstance(interfaces, Array(jProxy, jProxy), true, false)
AspectInitRegistry.register(awProxy, AspectInit(intfClass, null, actorRef, None, 5000L))
awProxy.asInstanceOf[T]
}
*/
/*
// NOTE: currently not used - but keep it around
private[akka] def newInstance[T <: TypedActor](targetClass: Class[T],
remoteAddress: Option[InetSocketAddress], timeout: Long): T = {
val proxy = {
val instance = Proxy.newInstance(targetClass, true, false)
if (instance.isInstanceOf[TypedActor]) instance.asInstanceOf[TypedActor]
else throw new IllegalActorStateException("Actor [" + targetClass.getName + "] is not a sub class of 'TypedActor'")
}
val context = injectTypedActorContext(proxy)
actorRef.actor.asInstanceOf[Dispatcher].initialize(targetClass, proxy, proxy, context)
actorRef.timeout = timeout
if (remoteAddress.isDefined) actorRef.makeRemote(remoteAddress.get)
AspectInitRegistry.register(proxy, AspectInit(targetClass, proxy, actorRef, remoteAddress, timeout))
actorRef.start
proxy.asInstanceOf[T]
}
*/
/**
* Stops the current Typed Actor.
@ -542,6 +565,30 @@ object TypedActor extends Logging {
private[akka] def isJoinPoint(message: Any): Boolean = message.isInstanceOf[JoinPoint]
}
/**
* AspectWerkz Aspect that is turning POJO into proxy to a server managed remote TypedActor.
* <p/>
* Is deployed on a 'perInstance' basis with the pointcut 'execution(* *.*(..))',
* e.g. all methods on the instance.
*
* @author <a href="http://jonasboner.com">Jonas Bon&#233;r</a>
*/
@Aspect("perInstance")
private[akka] sealed class ServerManagedTypedActorAspect extends ActorAspect {
@Around("execution(* *.*(..)) && this(se.scalablesolutions.akka.actor.ServerManagedTypedActor)")
def invoke(joinPoint: JoinPoint): AnyRef = {
if (!isInitialized) initialize(joinPoint)
remoteDispatch(joinPoint)
}
override def initialize(joinPoint: JoinPoint): Unit = {
super.initialize(joinPoint)
remoteAddress = actorRef.remoteAddress
}
}
/**
* AspectWerkz Aspect that is turning POJO into TypedActor.
* <p/>
@ -551,18 +598,9 @@ object TypedActor extends Logging {
* @author <a href="http://jonasboner.com">Jonas Bon&#233;r</a>
*/
@Aspect("perInstance")
private[akka] sealed class TypedActorAspect {
@volatile private var isInitialized = false
@volatile private var isStopped = false
private var interfaceClass: Class[_] = _
private var typedActor: TypedActor = _
private var actorRef: ActorRef = _
private var remoteAddress: Option[InetSocketAddress] = _
private var timeout: Long = _
private var uuid: String = _
@volatile private var instance: TypedActor = _
private[akka] sealed class TypedActorAspect extends ActorAspect {
@Around("execution(* *.*(..))")
@Around("execution(* *.*(..)) && !this(se.scalablesolutions.akka.actor.ServerManagedTypedActor)")
def invoke(joinPoint: JoinPoint): AnyRef = {
if (!isInitialized) initialize(joinPoint)
dispatch(joinPoint)
@ -572,12 +610,26 @@ private[akka] sealed class TypedActorAspect {
if (remoteAddress.isDefined) remoteDispatch(joinPoint)
else localDispatch(joinPoint)
}
}
private def localDispatch(joinPoint: JoinPoint): AnyRef = {
val methodRtti = joinPoint.getRtti.asInstanceOf[MethodRtti]
val isOneWay = TypedActor.isOneWay(methodRtti)
/**
* Base class for TypedActorAspect and ServerManagedTypedActorAspect to reduce code duplication.
*/
private[akka] abstract class ActorAspect {
@volatile protected var isInitialized = false
@volatile protected var isStopped = false
protected var interfaceClass: Class[_] = _
protected var typedActor: TypedActor = _
protected var actorRef: ActorRef = _
protected var timeout: Long = _
protected var uuid: String = _
protected var remoteAddress: Option[InetSocketAddress] = _
protected def localDispatch(joinPoint: JoinPoint): AnyRef = {
val methodRtti = joinPoint.getRtti.asInstanceOf[MethodRtti]
val isOneWay = TypedActor.isOneWay(methodRtti)
val senderActorRef = Some(SenderContextInfo.senderActorRef.value)
val senderProxy = Some(SenderContextInfo.senderProxy.value)
val senderProxy = Some(SenderContextInfo.senderProxy.value)
typedActor.context._sender = senderProxy
if (!actorRef.isRunning && !isStopped) {
@ -598,7 +650,7 @@ private[akka] sealed class TypedActorAspect {
}
}
private def remoteDispatch(joinPoint: JoinPoint): AnyRef = {
protected def remoteDispatch(joinPoint: JoinPoint): AnyRef = {
val methodRtti = joinPoint.getRtti.asInstanceOf[MethodRtti]
val isOneWay = TypedActor.isOneWay(methodRtti)
@ -637,7 +689,7 @@ private[akka] sealed class TypedActorAspect {
(escapedArgs, isEscaped)
}
private def initialize(joinPoint: JoinPoint): Unit = {
protected def initialize(joinPoint: JoinPoint): Unit = {
val init = AspectInitRegistry.initFor(joinPoint.getThis)
interfaceClass = init.interfaceClass
typedActor = init.targetInstance
@ -649,6 +701,7 @@ private[akka] sealed class TypedActorAspect {
}
}
/**
* Internal helper class to help pass the contextual information between threads.
*
@ -670,7 +723,7 @@ private[akka] object AspectInitRegistry extends ListenerManagement {
def register(proxy: AnyRef, init: AspectInit) = {
val res = initializations.put(proxy, init)
foreachListener(_ ! AspectInitRegistered(proxy, init))
notifyListeners(AspectInitRegistered(proxy, init))
res
}
@ -679,7 +732,7 @@ private[akka] object AspectInitRegistry extends ListenerManagement {
*/
def unregister(proxy: AnyRef): AspectInit = {
val init = initializations.remove(proxy)
foreachListener(_ ! AspectInitUnregistered(proxy, init))
notifyListeners(AspectInitUnregistered(proxy, init))
init.actorRef.stop
init
}
@ -700,5 +753,11 @@ private[akka] sealed case class AspectInit(
val timeout: Long) {
def this(interfaceClass: Class[_], targetInstance: TypedActor, actorRef: ActorRef, timeout: Long) =
this(interfaceClass, targetInstance, actorRef, None, timeout)
}
/**
* Marker interface for server manager typed actors.
*/
private[akka] sealed trait ServerManagedTypedActor extends TypedActor

View file

@ -122,7 +122,6 @@ private[akka] class TypedActorGuiceConfigurator extends TypedActorConfiguratorBa
remoteAddress.foreach { address =>
actorRef.makeRemote(remoteAddress.get)
RemoteServerModule.registerTypedActor(address, implementationClass.getName, proxy)
}
AspectInitRegistry.register(

View file

@ -2,6 +2,7 @@
<system id="akka">
<package name="se.scalablesolutions.akka.actor">
<aspect class="TypedActorAspect" />
<aspect class="ServerManagedTypedActorAspect" />
</package>
</system>
</aspectwerkz>

View file

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.redis</groupId>
<artifactId>redisclient</artifactId>
<version>2.8.0-2.0</version>
<packaging>jar</packaging>
</project>

View file

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>org.scala-tools</groupId>
<artifactId>time</artifactId>
<version>2.8.0-0.2-SNAPSHOT</version>
<packaging>jar</packaging>
</project>

View file

@ -49,6 +49,7 @@ class AkkaParentProject(info: ProjectInfo) extends DefaultProject(info) {
lazy val JavaNetRepo = MavenRepository("java.net Repo", "http://download.java.net/maven/2")
lazy val SonatypeSnapshotRepo = MavenRepository("Sonatype OSS Repo", "http://oss.sonatype.org/content/repositories/releases")
lazy val SunJDMKRepo = MavenRepository("Sun JDMK Repo", "http://wp5.e-taxonomy.eu/cdmlib/mavenrepo")
lazy val CasbahRepoReleases = MavenRepository("Casbah Release Repo", "http://repo.bumnetworks.com/releases")
}
// -------------------------------------------------------------------------------------------------------------------
@ -75,6 +76,7 @@ class AkkaParentProject(info: ProjectInfo) extends DefaultProject(info) {
lazy val scalaTestModuleConfig = ModuleConfiguration("org.scalatest", ScalaToolsSnapshots)
lazy val logbackModuleConfig = ModuleConfiguration("ch.qos.logback",sbt.DefaultMavenRepository)
lazy val atomikosModuleConfig = ModuleConfiguration("com.atomikos",sbt.DefaultMavenRepository)
lazy val casbahRelease = ModuleConfiguration("com.novus",CasbahRepoReleases)
lazy val embeddedRepo = EmbeddedRepo // This is the only exception, because the embedded repo is fast!
// -------------------------------------------------------------------------------------------------------------------
@ -166,6 +168,10 @@ class AkkaParentProject(info: ProjectInfo) extends DefaultProject(info) {
lazy val mongo = "org.mongodb" % "mongo-java-driver" % "2.0" % "compile"
lazy val casbah = "com.novus" % "casbah_2.8.0" % "1.0.8.5" % "compile"
lazy val time = "org.scala-tools" % "time" % "2.8.0-SNAPSHOT-0.2-SNAPSHOT" % "compile"
lazy val multiverse = "org.multiverse" % "multiverse-alpha" % MULTIVERSE_VERSION % "compile" intransitive
lazy val netty = "org.jboss.netty" % "netty" % "3.2.2.Final" % "compile"
@ -180,7 +186,7 @@ class AkkaParentProject(info: ProjectInfo) extends DefaultProject(info) {
lazy val sbinary = "sbinary" % "sbinary" % "2.8.0-0.3.1" % "compile"
lazy val sjson = "sjson.json" % "sjson" % "0.7-2.8.0" % "compile"
lazy val sjson = "sjson.json" % "sjson" % "0.8-SNAPSHOT-2.8.0" % "compile"
lazy val slf4j = "org.slf4j" % "slf4j-api" % SLF4J_VERSION % "compile"
@ -483,6 +489,7 @@ class AkkaParentProject(info: ProjectInfo) extends DefaultProject(info) {
class AkkaMongoProject(info: ProjectInfo) extends AkkaDefaultProject(info, distPath) {
val mongo = Dependencies.mongo
val casbah = Dependencies.casbah
override def testOptions = TestFilter((name: String) => name.endsWith("Test")) :: Nil
}