Merge branch 'wip-2069-DirectRoutedRemoteActorMultiJvmSpec-patriknw'

This commit is contained in:
Roland 2012-05-24 11:40:05 +02:00
commit fa5960372c
11 changed files with 461 additions and 242 deletions

View file

@ -624,7 +624,7 @@ object Logging {
import java.text.SimpleDateFormat
import java.util.Date
val dateFormat = new SimpleDateFormat("MM/dd/yyyy HH:mm:ss.S")
val dateFormat = new SimpleDateFormat("MM/dd/yyyy HH:mm:ss.SSS")
def timestamp = dateFormat.format(new Date)

View file

@ -27,12 +27,26 @@ import akka.actor.SupervisorStrategy
import java.util.concurrent.ConcurrentHashMap
import akka.actor.Status
sealed trait Direction
sealed trait Direction {
def includes(other: Direction): Boolean
}
object Direction {
case object Send extends Direction
case object Receive extends Direction
case object Both extends Direction
case object Send extends Direction {
override def includes(other: Direction): Boolean = other match {
case Send true
case _ false
}
}
case object Receive extends Direction {
override def includes(other: Direction): Boolean = other match {
case Receive true
case _ false
}
}
case object Both extends Direction {
override def includes(other: Direction): Boolean = true
}
}
/**
@ -205,14 +219,15 @@ trait Conductor { this: TestConductorExt ⇒
*
* INTERNAL API.
*/
private[akka] class ConductorHandler(system: ActorSystem, controller: ActorRef, log: LoggingAdapter) extends SimpleChannelUpstreamHandler {
private[akka] class ConductorHandler(_createTimeout: Timeout, controller: ActorRef, log: LoggingAdapter) extends SimpleChannelUpstreamHandler {
implicit val createTimeout = _createTimeout
val clients = new ConcurrentHashMap[Channel, ActorRef]()
override def channelConnected(ctx: ChannelHandlerContext, event: ChannelStateEvent) = {
val channel = event.getChannel
log.debug("connection from {}", getAddrString(channel))
val fsm = system.actorOf(Props(new ServerFSM(controller, channel)))
val fsm: ActorRef = Await.result(controller ? Controller.CreateServerFSM(channel) mapTo, Duration.Inf)
clients.put(channel, fsm)
}
@ -332,6 +347,7 @@ private[akka] object Controller {
case class ClientDisconnected(name: RoleName)
case object GetNodes
case object GetSockAddr
case class CreateServerFSM(channel: Channel)
case class NodeInfo(name: RoleName, addr: Address, fsm: ActorRef)
}
@ -349,7 +365,7 @@ private[akka] class Controller(private var initialParticipants: Int, controllerP
val settings = TestConductor().Settings
val connection = RemoteConnection(Server, controllerPort,
new ConductorHandler(context.system, self, Logging(context.system, "ConductorHandler")))
new ConductorHandler(settings.QueryTimeout, self, Logging(context.system, "ConductorHandler")))
/*
* Supervision of the BarrierCoordinator means to catch all his bad emotions
@ -376,8 +392,15 @@ private[akka] class Controller(private var initialParticipants: Int, controllerP
// map keeping unanswered queries for node addresses (enqueued upon GetAddress, serviced upon NodeInfo)
var addrInterest = Map[RoleName, Set[ActorRef]]()
val generation = Iterator from 1
override def receive = LoggingReceive {
case CreateServerFSM(channel)
val (ip, port) = channel.getRemoteAddress match {
case s: InetSocketAddress (s.getHostString, s.getPort)
}
val name = ip + ":" + port + "-server" + generation.next
sender ! context.actorOf(Props(new ServerFSM(self, channel)), name)
case c @ NodeInfo(name, addr, fsm)
barrier forward c
if (nodes contains name) {

View file

@ -10,6 +10,8 @@ import java.util.concurrent.TimeUnit.MILLISECONDS
import akka.actor.ActorRef
import java.util.concurrent.ConcurrentHashMap
import akka.actor.Address
import akka.actor.ActorSystemImpl
import akka.actor.Props
/**
* Access to the [[akka.remote.testconductor.TestConductorExt]] extension:
@ -63,9 +65,9 @@ class TestConductorExt(val system: ExtendedActorSystem) extends Extension with C
/**
* INTERNAL API.
*
* [[akka.remote.testconductor.FailureInjector]]s register themselves here so that
* [[akka.remote.testconductor.NetworkFailureInjector]]s register themselves here so that
* failures can be injected.
*/
private[akka] val failureInjectors = new ConcurrentHashMap[Address, FailureInjector]
private[akka] val failureInjector = system.asInstanceOf[ActorSystemImpl].systemActorOf(Props[FailureInjector], "FailureInjector")
}
}

View file

@ -4,56 +4,141 @@
package akka.remote.testconductor
import java.net.InetSocketAddress
import scala.collection.immutable.Queue
import org.jboss.netty.buffer.ChannelBuffer
import org.jboss.netty.channel.ChannelState.BOUND
import org.jboss.netty.channel.ChannelState.OPEN
import org.jboss.netty.channel.Channel
import org.jboss.netty.channel.ChannelEvent
import org.jboss.netty.channel.ChannelHandlerContext
import org.jboss.netty.channel.ChannelStateEvent
import org.jboss.netty.channel.MessageEvent
import akka.actor.FSM
import akka.actor.Actor
import akka.util.duration.doubleToDurationDouble
import akka.util.Index
import akka.actor.Address
import akka.actor.ActorSystem
import akka.actor.Props
import akka.actor.ActorRef
import akka.event.Logging
import org.jboss.netty.channel.SimpleChannelHandler
import scala.annotation.tailrec
import scala.collection.immutable.Queue
import org.jboss.netty.buffer.ChannelBuffer
import org.jboss.netty.channel.{ SimpleChannelHandler, MessageEvent, Channels, ChannelStateEvent, ChannelHandlerContext, ChannelFutureListener, ChannelFuture }
import akka.actor.{ Props, LoggingFSM, Address, ActorSystem, ActorRef, ActorLogging, Actor, FSM }
import akka.event.Logging
import akka.remote.netty.ChannelAddress
import akka.util.Duration
import akka.actor.LoggingFSM
import org.jboss.netty.channel.Channels
import org.jboss.netty.channel.ChannelFuture
import org.jboss.netty.channel.ChannelFutureListener
import org.jboss.netty.channel.ChannelFuture
import akka.util.duration._
/**
* INTERNAL API.
*/
private[akka] case class FailureInjector(sender: ActorRef, receiver: ActorRef) {
def refs(dir: Direction) = dir match {
case Direction.Send Seq(sender)
case Direction.Receive Seq(receiver)
case Direction.Both Seq(sender, receiver)
private[akka] class FailureInjector extends Actor with ActorLogging {
import ThrottleActor._
import NetworkFailureInjector._
case class ChannelSettings(
ctx: Option[ChannelHandlerContext] = None,
throttleSend: Option[SetRate] = None,
throttleReceive: Option[SetRate] = None)
case class Injectors(sender: ActorRef, receiver: ActorRef)
var channels = Map[ChannelHandlerContext, Injectors]()
var settings = Map[Address, ChannelSettings]()
var generation = Iterator from 1
/**
* Only for a NEW ctx, start ThrottleActors, prime them and update all maps.
*/
def ingestContextAddress(ctx: ChannelHandlerContext, addr: Address): Injectors = {
val gen = generation.next
val name = addr.host.get + ":" + addr.port.get
val thrSend = context.actorOf(Props(new ThrottleActor(ctx)), name + "-snd" + gen)
val thrRecv = context.actorOf(Props(new ThrottleActor(ctx)), name + "-rcv" + gen)
val injectors = Injectors(thrSend, thrRecv)
channels += ctx -> injectors
settings += addr -> (settings get addr map {
case c @ ChannelSettings(prevCtx, ts, tr)
ts foreach (thrSend ! _)
tr foreach (thrRecv ! _)
prevCtx match {
case Some(p) log.warning("installing context {} instead of {} for address {}", ctx, p, addr)
case None // okay
}
c.copy(ctx = Some(ctx))
} getOrElse ChannelSettings(Some(ctx)))
injectors
}
/**
* Retrieve target settings, also if they were sketchy before (i.e. no system name)
*/
def retrieveTargetSettings(target: Address): Option[ChannelSettings] = {
settings get target orElse {
val host = target.host
val port = target.port
settings find {
case (Address("akka", "", `host`, `port`), s) true
case _ false
} map {
case (_, s) settings += target -> s; s
}
}
}
def receive = {
case RemoveContext(ctx)
channels get ctx foreach { inj
context stop inj.sender
context stop inj.receiver
}
channels -= ctx
settings ++= settings collect { case (addr, c @ ChannelSettings(Some(`ctx`), _, _)) (addr, c.copy(ctx = None)) }
case ThrottleMsg(target, dir, rateMBit)
val setting = retrieveTargetSettings(target)
settings += target -> ((setting getOrElse ChannelSettings() match {
case cs @ ChannelSettings(ctx, _, _) if dir includes Direction.Send
ctx foreach (c channels get c foreach (_.sender ! SetRate(rateMBit)))
cs.copy(throttleSend = Some(SetRate(rateMBit)))
case x x
}) match {
case cs @ ChannelSettings(ctx, _, _) if dir includes Direction.Receive
ctx foreach (c channels get c foreach (_.receiver ! SetRate(rateMBit)))
cs.copy(throttleReceive = Some(SetRate(rateMBit)))
case x x
})
sender ! "ok"
case DisconnectMsg(target, abort)
retrieveTargetSettings(target) foreach {
case ChannelSettings(Some(ctx), _, _)
val ch = ctx.getChannel
if (abort) {
ch.getConfig.setOption("soLinger", 0)
log.info("aborting connection {}", ch)
} else log.info("closing connection {}", ch)
ch.close
case _ log.debug("no connection to {} to close or abort", target)
}
sender ! "ok"
case s @ Send(ctx, direction, future, msg)
channels get ctx match {
case Some(Injectors(snd, rcv))
if (direction includes Direction.Send) snd ! s
if (direction includes Direction.Receive) rcv ! s
case None
val (ipaddr, ip, port) = ctx.getChannel.getRemoteAddress match {
case s: InetSocketAddress (s.getAddress, s.getAddress.getHostAddress, s.getPort)
}
val addr = ChannelAddress.get(ctx.getChannel) orElse {
settings collect { case (a @ Address("akka", _, Some(`ip`), Some(`port`)), _) a } headOption
} orElse {
val name = ipaddr.getHostName
if (name == ip) None
else settings collect { case (a @ Address("akka", _, Some(`name`), Some(`port`)), _) a } headOption
} getOrElse Address("akka", "", ip, port) // this will not match later requests directly, but be picked up by retrieveTargetSettings
val inj = ingestContextAddress(ctx, addr)
if (direction includes Direction.Send) inj.sender ! s
if (direction includes Direction.Receive) inj.receiver ! s
}
}
}
/**
* INTERNAL API.
*/
private[akka] object NetworkFailureInjector {
case class SetRate(rateMBit: Float)
case class Disconnect(abort: Boolean)
case class RemoveContext(ctx: ChannelHandlerContext)
}
/**
* Brief overview: all network traffic passes through the `sender`/`receiver` FSMs, which can
* Brief overview: all network traffic passes through the `sender`/`receiver` FSMs managed
* by the FailureInjector of the TestConductor extension. These can
* pass through requests immediately, drop them or throttle to a desired rate. The FSMs are
* registered in the TestConductorExt.failureInjectors so that settings can be applied from
* registered in the TestConductorExt.failureInjector so that settings can be applied from
* the ClientFSMs.
*
* I found that simply forwarding events using ctx.sendUpstream/sendDownstream does not work,
@ -63,195 +148,181 @@ private[akka] object NetworkFailureInjector {
* INTERNAL API.
*/
private[akka] class NetworkFailureInjector(system: ActorSystem) extends SimpleChannelHandler {
import NetworkFailureInjector._
val log = Logging(system, "FailureInjector")
private val log = Logging(system, "FailureInjector")
// everything goes via these Throttle actors to enable easy steering
private val sender = system.actorOf(Props(new Throttle(Direction.Send)))
private val receiver = system.actorOf(Props(new Throttle(Direction.Receive)))
private val packetSplitThreshold = TestConductor(system).Settings.PacketSplitThreshold
/*
* State, Data and Messages for the internal Throttle actor
*/
sealed private trait State
private case object PassThrough extends State
private case object Throttle extends State
private case object Blackhole extends State
private case class Data(lastSent: Long, rateMBit: Float, queue: Queue[Send])
private case class Send(ctx: ChannelHandlerContext, future: Option[ChannelFuture], msg: AnyRef)
private case class SetContext(ctx: ChannelHandlerContext)
private case object Tick
private class Throttle(dir: Direction) extends Actor with LoggingFSM[State, Data] {
import FSM._
var channelContext: ChannelHandlerContext = _
startWith(PassThrough, Data(0, -1, Queue()))
when(PassThrough) {
case Event(s @ Send(_, _, msg), _)
log.debug("sending msg (PassThrough): {}", msg)
send(s)
stay
}
when(Throttle) {
case Event(s: Send, data @ Data(_, _, Queue()))
stay using sendThrottled(data.copy(lastSent = System.nanoTime, queue = Queue(s)))
case Event(s: Send, data)
stay using sendThrottled(data.copy(queue = data.queue.enqueue(s)))
case Event(Tick, data)
stay using sendThrottled(data)
}
onTransition {
case Throttle -> PassThrough
for (s stateData.queue) {
log.debug("sending msg (Transition): {}", s.msg)
send(s)
}
cancelTimer("send")
case Throttle -> Blackhole
cancelTimer("send")
}
when(Blackhole) {
case Event(Send(_, _, msg), _)
log.debug("dropping msg {}", msg)
stay
}
whenUnhandled {
case Event(NetworkFailureInjector.SetRate(rate), d)
sender ! "ok"
if (rate > 0) {
goto(Throttle) using d.copy(lastSent = System.nanoTime, rateMBit = rate, queue = Queue())
} else if (rate == 0) {
goto(Blackhole)
} else {
goto(PassThrough)
}
case Event(SetContext(ctx), _) channelContext = ctx; stay
case Event(NetworkFailureInjector.Disconnect(abort), Data(ctx, _, _))
sender ! "ok"
// TODO implement abort
channelContext.getChannel.disconnect()
stay
}
initialize
private def sendThrottled(d: Data): Data = {
val (data, toSend, toTick) = schedule(d)
for (s toSend) {
log.debug("sending msg (Tick): {}", s.msg)
send(s)
}
if (!timerActive_?("send"))
for (time toTick) {
log.debug("scheduling next Tick in {}", time)
setTimer("send", Tick, time, false)
}
data
}
private def send(s: Send): Unit = dir match {
case Direction.Send Channels.write(s.ctx, s.future getOrElse Channels.future(s.ctx.getChannel), s.msg)
case Direction.Receive Channels.fireMessageReceived(s.ctx, s.msg)
case _
}
private def schedule(d: Data): (Data, Seq[Send], Option[Duration]) = {
val now = System.nanoTime
@tailrec def rec(d: Data, toSend: Seq[Send]): (Data, Seq[Send], Option[Duration]) = {
if (d.queue.isEmpty) (d, toSend, None)
else {
val timeForPacket = d.lastSent + (1000 * size(d.queue.head.msg) / d.rateMBit).toLong
if (timeForPacket <= now) rec(Data(timeForPacket, d.rateMBit, d.queue.tail), toSend :+ d.queue.head)
else {
val splitThreshold = d.lastSent + packetSplitThreshold.toNanos
if (now < splitThreshold) (d, toSend, Some((timeForPacket - now).nanos min (splitThreshold - now).nanos))
else {
val microsToSend = (now - d.lastSent) / 1000
val (s1, s2) = split(d.queue.head, (microsToSend * d.rateMBit / 8).toInt)
(d.copy(queue = s2 +: d.queue.tail), toSend :+ s1, Some((timeForPacket - now).nanos min packetSplitThreshold))
}
}
}
}
rec(d, Seq())
}
private def split(s: Send, bytes: Int): (Send, Send) = {
s.msg match {
case buf: ChannelBuffer
val f = s.future map { f
val newF = Channels.future(s.ctx.getChannel)
newF.addListener(new ChannelFutureListener {
def operationComplete(future: ChannelFuture) {
if (future.isCancelled) f.cancel()
else future.getCause match {
case null
case thr f.setFailure(thr)
}
}
})
newF
}
val b = buf.slice()
b.writerIndex(b.readerIndex + bytes)
buf.readerIndex(buf.readerIndex + bytes)
(Send(s.ctx, f, b), Send(s.ctx, s.future, buf))
}
}
private def size(msg: AnyRef) = msg match {
case b: ChannelBuffer b.readableBytes() * 8
case _ throw new UnsupportedOperationException("NetworkFailureInjector only supports ChannelBuffer messages")
}
}
private var remote: Option[Address] = None
override def messageReceived(ctx: ChannelHandlerContext, msg: MessageEvent) {
log.debug("upstream(queued): {}", msg)
receiver ! Send(ctx, Option(msg.getFuture), msg.getMessage)
}
private val conductor = TestConductor(system)
private var announced = false
override def channelConnected(ctx: ChannelHandlerContext, state: ChannelStateEvent) {
state.getValue match {
case a: InetSocketAddress
val addr = Address("akka", "", a.getHostName, a.getPort)
log.debug("connected to {}", addr)
TestConductor(system).failureInjectors.put(addr, FailureInjector(sender, receiver)) match {
case null // okay
case fi system.log.error("{} already registered for address {}", fi, addr)
}
remote = Some(addr)
sender ! SetContext(ctx)
case x throw new IllegalArgumentException("unknown address type: " + x)
}
}
override def channelDisconnected(ctx: ChannelHandlerContext, state: ChannelStateEvent) {
log.debug("disconnected from {}", remote)
remote = remote flatMap { addr
TestConductor(system).failureInjectors.remove(addr)
system.stop(sender)
system.stop(receiver)
None
}
log.debug("disconnected from {}", state.getChannel)
conductor.failureInjector ! RemoveContext(ctx)
}
override def messageReceived(ctx: ChannelHandlerContext, msg: MessageEvent) {
log.debug("upstream(queued): {}", msg)
conductor.failureInjector ! ThrottleActor.Send(ctx, Direction.Receive, Option(msg.getFuture), msg.getMessage)
}
override def writeRequested(ctx: ChannelHandlerContext, msg: MessageEvent) {
log.debug("downstream(queued): {}", msg)
sender ! Send(ctx, Option(msg.getFuture), msg.getMessage)
conductor.failureInjector ! ThrottleActor.Send(ctx, Direction.Send, Option(msg.getFuture), msg.getMessage)
}
}
/**
* INTERNAL API.
*/
private[akka] object ThrottleActor {
sealed trait State
case object PassThrough extends State
case object Throttle extends State
case object Blackhole extends State
case class Data(lastSent: Long, rateMBit: Float, queue: Queue[Send])
case class Send(ctx: ChannelHandlerContext, direction: Direction, future: Option[ChannelFuture], msg: AnyRef)
case class SetRate(rateMBit: Float)
case object Tick
}
/**
* INTERNAL API.
*/
private[akka] class ThrottleActor(channelContext: ChannelHandlerContext)
extends Actor with LoggingFSM[ThrottleActor.State, ThrottleActor.Data] {
import ThrottleActor._
import FSM._
private val packetSplitThreshold = TestConductor(context.system).Settings.PacketSplitThreshold
startWith(PassThrough, Data(0, -1, Queue()))
when(PassThrough) {
case Event(s @ Send(_, _, _, msg), _)
log.debug("sending msg (PassThrough): {}", msg)
send(s)
stay
}
when(Throttle) {
case Event(s: Send, data @ Data(_, _, Queue()))
stay using sendThrottled(data.copy(lastSent = System.nanoTime, queue = Queue(s)))
case Event(s: Send, data)
stay using sendThrottled(data.copy(queue = data.queue.enqueue(s)))
case Event(Tick, data)
stay using sendThrottled(data)
}
onTransition {
case Throttle -> PassThrough
for (s stateData.queue) {
log.debug("sending msg (Transition): {}", s.msg)
send(s)
}
cancelTimer("send")
case Throttle -> Blackhole
cancelTimer("send")
}
when(Blackhole) {
case Event(Send(_, _, _, msg), _)
log.debug("dropping msg {}", msg)
stay
}
whenUnhandled {
case Event(SetRate(rate), d)
if (rate > 0) {
goto(Throttle) using d.copy(lastSent = System.nanoTime, rateMBit = rate, queue = Queue())
} else if (rate == 0) {
goto(Blackhole)
} else {
goto(PassThrough)
}
}
initialize
private def sendThrottled(d: Data): Data = {
val (data, toSend, toTick) = schedule(d)
for (s toSend) {
log.debug("sending msg (Tick): {}", s.msg)
send(s)
}
if (!timerActive_?("send"))
for (time toTick) {
log.debug("scheduling next Tick in {}", time)
setTimer("send", Tick, time, false)
}
data
}
private def send(s: Send): Unit = s.direction match {
case Direction.Send Channels.write(s.ctx, s.future getOrElse Channels.future(s.ctx.getChannel), s.msg)
case Direction.Receive Channels.fireMessageReceived(s.ctx, s.msg)
case _
}
private def schedule(d: Data): (Data, Seq[Send], Option[Duration]) = {
val now = System.nanoTime
@tailrec def rec(d: Data, toSend: Seq[Send]): (Data, Seq[Send], Option[Duration]) = {
if (d.queue.isEmpty) (d, toSend, None)
else {
val timeForPacket = d.lastSent + (1000 * size(d.queue.head.msg) / d.rateMBit).toLong
if (timeForPacket <= now) rec(Data(timeForPacket, d.rateMBit, d.queue.tail), toSend :+ d.queue.head)
else {
val splitThreshold = d.lastSent + packetSplitThreshold.toNanos
if (now < splitThreshold) (d, toSend, Some((timeForPacket - now).nanos min (splitThreshold - now).nanos))
else {
val microsToSend = (now - d.lastSent) / 1000
val (s1, s2) = split(d.queue.head, (microsToSend * d.rateMBit / 8).toInt)
(d.copy(queue = s2 +: d.queue.tail), toSend :+ s1, Some((timeForPacket - now).nanos min packetSplitThreshold))
}
}
}
}
rec(d, Seq())
}
private def split(s: Send, bytes: Int): (Send, Send) = {
s.msg match {
case buf: ChannelBuffer
val f = s.future map { f
val newF = Channels.future(s.ctx.getChannel)
newF.addListener(new ChannelFutureListener {
def operationComplete(future: ChannelFuture) {
if (future.isCancelled) f.cancel()
else future.getCause match {
case null
case thr f.setFailure(thr)
}
}
})
newF
}
val b = buf.slice()
b.writerIndex(b.readerIndex + bytes)
buf.readerIndex(buf.readerIndex + bytes)
(Send(s.ctx, s.direction, f, b), Send(s.ctx, s.direction, s.future, buf))
}
}
private def size(msg: AnyRef) = msg match {
case b: ChannelBuffer b.readableBytes() * 8
case _ throw new UnsupportedOperationException("NetworkFailureInjector only supports ChannelBuffer messages")
}
}

View file

@ -200,21 +200,13 @@ private[akka] class ClientFSM(name: RoleName, controllerAddr: InetSocketAddress)
log.warning("did not expect {}", op)
}
stay using d.copy(runningOp = None)
case ThrottleMsg(target, dir, rate)
case t: ThrottleMsg
import settings.QueryTimeout
import context.dispatcher
TestConductor().failureInjectors.get(target.copy(system = "")) match {
case null log.warning("cannot throttle unknown address {}", target)
case inj
Future.sequence(inj.refs(dir) map (_ ? NetworkFailureInjector.SetRate(rate))) map (_ ToServer(Done)) pipeTo self
}
TestConductor().failureInjector ? t map (_ ToServer(Done)) pipeTo self
stay
case DisconnectMsg(target, abort)
case d: DisconnectMsg
import settings.QueryTimeout
TestConductor().failureInjectors.get(target.copy(system = "")) match {
case null log.warning("cannot disconnect unknown address {}", target)
case inj inj.sender ? NetworkFailureInjector.Disconnect(abort) map (_ ToServer(Done)) pipeTo self
}
TestConductor().failureInjector ? d map (_ ToServer(Done)) pipeTo self
stay
case TerminateMsg(exit)
System.exit(exit)

View file

@ -0,0 +1,55 @@
/**
* Copyright (C) 2009-2012 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.remote
import akka.actor.Actor
import akka.actor.ActorRef
import akka.actor.Props
import akka.pattern.ask
import akka.remote.testkit.MultiNodeConfig
import akka.remote.testkit.MultiNodeSpec
import akka.testkit._
object SimpleRemoteMultiJvmSpec extends MultiNodeConfig {
class SomeActor extends Actor with Serializable {
def receive = {
case "identify" sender ! self
}
}
commonConfig(debugConfig(on = false))
val master = role("master")
val slave = role("slave")
}
class SimpleRemoteMultiJvmNode1 extends SimpleRemoteSpec
class SimpleRemoteMultiJvmNode2 extends SimpleRemoteSpec
class SimpleRemoteSpec extends MultiNodeSpec(SimpleRemoteMultiJvmSpec)
with ImplicitSender with DefaultTimeout {
import SimpleRemoteMultiJvmSpec._
def initialParticipants = 2
runOn(master) {
system.actorOf(Props[SomeActor], "service-hello")
}
"Remoting" must {
"lookup remote actor" in {
runOn(slave) {
val hello = system.actorFor(node(master) / "user" / "service-hello")
hello.isInstanceOf[RemoteActorRef] must be(true)
val masterAddress = testConductor.getAddressFor(master).await
(hello ? "identify").await.asInstanceOf[ActorRef].path.address must equal(masterAddress)
}
testConductor.enter("done")
}
}
}

View file

@ -0,0 +1,73 @@
/**
* Copyright (C) 2009-2012 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.remote.router
import com.typesafe.config.ConfigFactory
import akka.actor.Actor
import akka.actor.ActorRef
import akka.actor.Props
import akka.pattern.ask
import akka.remote.RemoteActorRef
import akka.remote.testkit.MultiNodeConfig
import akka.remote.testkit.MultiNodeSpec
import akka.testkit._
object DirectRoutedRemoteActorMultiJvmSpec extends MultiNodeConfig {
class SomeActor extends Actor with Serializable {
def receive = {
case "identify" sender ! self
}
}
commonConfig(debugConfig(on = false))
val master = role("master")
val slave = role("slave")
nodeConfig(master, ConfigFactory.parseString("""
akka.actor {
deployment {
/service-hello.remote = "akka://MultiNodeSpec@%s"
}
}
# FIXME When using NettyRemoteTransport instead of TestConductorTransport it works
# akka.remote.transport = "akka.remote.netty.NettyRemoteTransport"
""".format("localhost:2553"))) // FIXME is there a way to avoid hardcoding the host:port here?
nodeConfig(slave, ConfigFactory.parseString("""
akka.remote.netty.port = 2553
"""))
}
class DirectRoutedRemoteActorMultiJvmNode1 extends DirectRoutedRemoteActorSpec
class DirectRoutedRemoteActorMultiJvmNode2 extends DirectRoutedRemoteActorSpec
class DirectRoutedRemoteActorSpec extends MultiNodeSpec(DirectRoutedRemoteActorMultiJvmSpec)
with ImplicitSender with DefaultTimeout {
import DirectRoutedRemoteActorMultiJvmSpec._
def initialParticipants = 2
"A new remote actor configured with a Direct router" must {
"be locally instantiated on a remote node and be able to communicate through its RemoteActorRef" in {
runOn(master) {
val actor = system.actorOf(Props[SomeActor], "service-hello")
actor.isInstanceOf[RemoteActorRef] must be(true)
val slaveAddress = testConductor.getAddressFor(slave).await
(actor ? "identify").await.asInstanceOf[ActorRef].path.address must equal(slaveAddress)
// shut down the actor before we let the other node(s) shut down so we don't try to send
// "Terminate" to a shut down node
system.stop(actor)
}
testConductor.enter("done")
}
}
}

View file

@ -19,8 +19,8 @@ import akka.remote.testkit.MultiNodeSpec
import akka.remote.testkit.MultiNodeConfig
object TestConductorMultiJvmSpec extends MultiNodeConfig {
commonConfig(debugConfig(on = true))
commonConfig(debugConfig(on = false))
val master = role("master")
val slave = role("slave")
}

View file

@ -173,6 +173,7 @@ class ActiveRemoteClient private[akka] (
notifyListeners(RemoteClientError(connection.getCause, netty, remoteAddress))
false
} else {
ChannelAddress.set(connection.getChannel, Some(remoteAddress))
sendSecureCookie(connection)
notifyListeners(RemoteClientStarted(netty, remoteAddress))
true
@ -196,8 +197,10 @@ class ActiveRemoteClient private[akka] (
notifyListeners(RemoteClientShutdown(netty, remoteAddress))
try {
if ((connection ne null) && (connection.getChannel ne null))
if ((connection ne null) && (connection.getChannel ne null)) {
ChannelAddress.remove(connection.getChannel)
connection.getChannel.close()
}
} finally {
try {
if (openChannels ne null) openChannels.close.awaitUninterruptibly()

View file

@ -12,7 +12,7 @@ import java.util.concurrent.Executors
import scala.collection.mutable.HashMap
import org.jboss.netty.channel.group.{ DefaultChannelGroup, ChannelGroupFuture }
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory
import org.jboss.netty.channel.{ ChannelHandlerContext, Channel, StaticChannelPipeline, ChannelHandler, ChannelPipelineFactory }
import org.jboss.netty.channel.{ ChannelHandlerContext, Channel, StaticChannelPipeline, ChannelHandler, ChannelPipelineFactory, ChannelLocal }
import org.jboss.netty.handler.codec.frame.{ LengthFieldPrepender, LengthFieldBasedFrameDecoder }
import org.jboss.netty.handler.codec.protobuf.{ ProtobufEncoder, ProtobufDecoder }
import org.jboss.netty.handler.execution.{ ExecutionHandler, OrderedMemoryAwareThreadPoolExecutor }
@ -25,6 +25,10 @@ import akka.remote.{ RemoteTransportException, RemoteTransport, RemoteSettings,
import akka.util.NonFatal
import akka.actor.{ ExtendedActorSystem, Address, ActorRef }
object ChannelAddress extends ChannelLocal[Option[Address]] {
override def initialValue(ch: Channel): Option[Address] = None
}
/**
* Provides the implementation of the Netty remote support
*/

View file

@ -112,10 +112,6 @@ class RemoteServerHandler(
val openChannels: ChannelGroup,
val netty: NettyRemoteTransport) extends SimpleChannelUpstreamHandler {
val channelAddress = new ChannelLocal[Option[Address]](false) {
override def initialValue(channel: Channel) = None
}
import netty.settings
private var addressToSet = true
@ -139,16 +135,16 @@ class RemoteServerHandler(
override def channelConnected(ctx: ChannelHandlerContext, event: ChannelStateEvent) = ()
override def channelDisconnected(ctx: ChannelHandlerContext, event: ChannelStateEvent) = {
netty.notifyListeners(RemoteServerClientDisconnected(netty, channelAddress.get(ctx.getChannel)))
netty.notifyListeners(RemoteServerClientDisconnected(netty, ChannelAddress.get(ctx.getChannel)))
}
override def channelClosed(ctx: ChannelHandlerContext, event: ChannelStateEvent) = {
val address = channelAddress.get(ctx.getChannel)
val address = ChannelAddress.get(ctx.getChannel)
if (address.isDefined && settings.UsePassiveConnections)
netty.unbindClient(address.get)
netty.notifyListeners(RemoteServerClientClosed(netty, address))
channelAddress.remove(ctx.getChannel)
ChannelAddress.remove(ctx.getChannel)
}
override def messageReceived(ctx: ChannelHandlerContext, event: MessageEvent) = try {
@ -162,7 +158,7 @@ class RemoteServerHandler(
case CommandType.CONNECT
val origin = instruction.getOrigin
val inbound = Address("akka", origin.getSystem, origin.getHostname, origin.getPort)
channelAddress.set(event.getChannel, Option(inbound))
ChannelAddress.set(event.getChannel, Option(inbound))
//If we want to reuse the inbound connections as outbound we need to get busy
if (settings.UsePassiveConnections)