diff --git a/akka-actor/src/main/scala/akka/event/Logging.scala b/akka-actor/src/main/scala/akka/event/Logging.scala index bf4fc7996d..5eb483c873 100644 --- a/akka-actor/src/main/scala/akka/event/Logging.scala +++ b/akka-actor/src/main/scala/akka/event/Logging.scala @@ -624,7 +624,7 @@ object Logging { import java.text.SimpleDateFormat import java.util.Date - val dateFormat = new SimpleDateFormat("MM/dd/yyyy HH:mm:ss.S") + val dateFormat = new SimpleDateFormat("MM/dd/yyyy HH:mm:ss.SSS") def timestamp = dateFormat.format(new Date) diff --git a/akka-remote-tests/src/main/scala/akka/remote/testconductor/Conductor.scala b/akka-remote-tests/src/main/scala/akka/remote/testconductor/Conductor.scala index 1ec172e9ce..89fa807762 100644 --- a/akka-remote-tests/src/main/scala/akka/remote/testconductor/Conductor.scala +++ b/akka-remote-tests/src/main/scala/akka/remote/testconductor/Conductor.scala @@ -27,12 +27,26 @@ import akka.actor.SupervisorStrategy import java.util.concurrent.ConcurrentHashMap import akka.actor.Status -sealed trait Direction +sealed trait Direction { + def includes(other: Direction): Boolean +} object Direction { - case object Send extends Direction - case object Receive extends Direction - case object Both extends Direction + case object Send extends Direction { + override def includes(other: Direction): Boolean = other match { + case Send ⇒ true + case _ ⇒ false + } + } + case object Receive extends Direction { + override def includes(other: Direction): Boolean = other match { + case Receive ⇒ true + case _ ⇒ false + } + } + case object Both extends Direction { + override def includes(other: Direction): Boolean = true + } } /** @@ -205,14 +219,15 @@ trait Conductor { this: TestConductorExt ⇒ * * INTERNAL API. */ -private[akka] class ConductorHandler(system: ActorSystem, controller: ActorRef, log: LoggingAdapter) extends SimpleChannelUpstreamHandler { +private[akka] class ConductorHandler(_createTimeout: Timeout, controller: ActorRef, log: LoggingAdapter) extends SimpleChannelUpstreamHandler { + implicit val createTimeout = _createTimeout val clients = new ConcurrentHashMap[Channel, ActorRef]() override def channelConnected(ctx: ChannelHandlerContext, event: ChannelStateEvent) = { val channel = event.getChannel log.debug("connection from {}", getAddrString(channel)) - val fsm = system.actorOf(Props(new ServerFSM(controller, channel))) + val fsm: ActorRef = Await.result(controller ? Controller.CreateServerFSM(channel) mapTo, Duration.Inf) clients.put(channel, fsm) } @@ -332,6 +347,7 @@ private[akka] object Controller { case class ClientDisconnected(name: RoleName) case object GetNodes case object GetSockAddr + case class CreateServerFSM(channel: Channel) case class NodeInfo(name: RoleName, addr: Address, fsm: ActorRef) } @@ -349,7 +365,7 @@ private[akka] class Controller(private var initialParticipants: Int, controllerP val settings = TestConductor().Settings val connection = RemoteConnection(Server, controllerPort, - new ConductorHandler(context.system, self, Logging(context.system, "ConductorHandler"))) + new ConductorHandler(settings.QueryTimeout, self, Logging(context.system, "ConductorHandler"))) /* * Supervision of the BarrierCoordinator means to catch all his bad emotions @@ -376,8 +392,15 @@ private[akka] class Controller(private var initialParticipants: Int, controllerP // map keeping unanswered queries for node addresses (enqueued upon GetAddress, serviced upon NodeInfo) var addrInterest = Map[RoleName, Set[ActorRef]]() + val generation = Iterator from 1 override def receive = LoggingReceive { + case CreateServerFSM(channel) ⇒ + val (ip, port) = channel.getRemoteAddress match { + case s: InetSocketAddress ⇒ (s.getHostString, s.getPort) + } + val name = ip + ":" + port + "-server" + generation.next + sender ! context.actorOf(Props(new ServerFSM(self, channel)), name) case c @ NodeInfo(name, addr, fsm) ⇒ barrier forward c if (nodes contains name) { diff --git a/akka-remote-tests/src/main/scala/akka/remote/testconductor/Extension.scala b/akka-remote-tests/src/main/scala/akka/remote/testconductor/Extension.scala index 6800253ae0..48f3983a78 100644 --- a/akka-remote-tests/src/main/scala/akka/remote/testconductor/Extension.scala +++ b/akka-remote-tests/src/main/scala/akka/remote/testconductor/Extension.scala @@ -10,6 +10,8 @@ import java.util.concurrent.TimeUnit.MILLISECONDS import akka.actor.ActorRef import java.util.concurrent.ConcurrentHashMap import akka.actor.Address +import akka.actor.ActorSystemImpl +import akka.actor.Props /** * Access to the [[akka.remote.testconductor.TestConductorExt]] extension: @@ -63,9 +65,9 @@ class TestConductorExt(val system: ExtendedActorSystem) extends Extension with C /** * INTERNAL API. * - * [[akka.remote.testconductor.FailureInjector]]s register themselves here so that + * [[akka.remote.testconductor.NetworkFailureInjector]]s register themselves here so that * failures can be injected. */ - private[akka] val failureInjectors = new ConcurrentHashMap[Address, FailureInjector] + private[akka] val failureInjector = system.asInstanceOf[ActorSystemImpl].systemActorOf(Props[FailureInjector], "FailureInjector") -} \ No newline at end of file +} diff --git a/akka-remote-tests/src/main/scala/akka/remote/testconductor/NetworkFailureInjector.scala b/akka-remote-tests/src/main/scala/akka/remote/testconductor/NetworkFailureInjector.scala index 629a15d51f..ba8f8d1285 100644 --- a/akka-remote-tests/src/main/scala/akka/remote/testconductor/NetworkFailureInjector.scala +++ b/akka-remote-tests/src/main/scala/akka/remote/testconductor/NetworkFailureInjector.scala @@ -4,56 +4,141 @@ package akka.remote.testconductor import java.net.InetSocketAddress -import scala.collection.immutable.Queue -import org.jboss.netty.buffer.ChannelBuffer -import org.jboss.netty.channel.ChannelState.BOUND -import org.jboss.netty.channel.ChannelState.OPEN -import org.jboss.netty.channel.Channel -import org.jboss.netty.channel.ChannelEvent -import org.jboss.netty.channel.ChannelHandlerContext -import org.jboss.netty.channel.ChannelStateEvent -import org.jboss.netty.channel.MessageEvent -import akka.actor.FSM -import akka.actor.Actor -import akka.util.duration.doubleToDurationDouble -import akka.util.Index -import akka.actor.Address -import akka.actor.ActorSystem -import akka.actor.Props -import akka.actor.ActorRef -import akka.event.Logging -import org.jboss.netty.channel.SimpleChannelHandler + import scala.annotation.tailrec +import scala.collection.immutable.Queue + +import org.jboss.netty.buffer.ChannelBuffer +import org.jboss.netty.channel.{ SimpleChannelHandler, MessageEvent, Channels, ChannelStateEvent, ChannelHandlerContext, ChannelFutureListener, ChannelFuture } + +import akka.actor.{ Props, LoggingFSM, Address, ActorSystem, ActorRef, ActorLogging, Actor, FSM } +import akka.event.Logging +import akka.remote.netty.ChannelAddress import akka.util.Duration -import akka.actor.LoggingFSM -import org.jboss.netty.channel.Channels -import org.jboss.netty.channel.ChannelFuture -import org.jboss.netty.channel.ChannelFutureListener -import org.jboss.netty.channel.ChannelFuture +import akka.util.duration._ /** * INTERNAL API. */ -private[akka] case class FailureInjector(sender: ActorRef, receiver: ActorRef) { - def refs(dir: Direction) = dir match { - case Direction.Send ⇒ Seq(sender) - case Direction.Receive ⇒ Seq(receiver) - case Direction.Both ⇒ Seq(sender, receiver) +private[akka] class FailureInjector extends Actor with ActorLogging { + import ThrottleActor._ + import NetworkFailureInjector._ + + case class ChannelSettings( + ctx: Option[ChannelHandlerContext] = None, + throttleSend: Option[SetRate] = None, + throttleReceive: Option[SetRate] = None) + case class Injectors(sender: ActorRef, receiver: ActorRef) + + var channels = Map[ChannelHandlerContext, Injectors]() + var settings = Map[Address, ChannelSettings]() + var generation = Iterator from 1 + + /** + * Only for a NEW ctx, start ThrottleActors, prime them and update all maps. + */ + def ingestContextAddress(ctx: ChannelHandlerContext, addr: Address): Injectors = { + val gen = generation.next + val name = addr.host.get + ":" + addr.port.get + val thrSend = context.actorOf(Props(new ThrottleActor(ctx)), name + "-snd" + gen) + val thrRecv = context.actorOf(Props(new ThrottleActor(ctx)), name + "-rcv" + gen) + val injectors = Injectors(thrSend, thrRecv) + channels += ctx -> injectors + settings += addr -> (settings get addr map { + case c @ ChannelSettings(prevCtx, ts, tr) ⇒ + ts foreach (thrSend ! _) + tr foreach (thrRecv ! _) + prevCtx match { + case Some(p) ⇒ log.warning("installing context {} instead of {} for address {}", ctx, p, addr) + case None ⇒ // okay + } + c.copy(ctx = Some(ctx)) + } getOrElse ChannelSettings(Some(ctx))) + injectors + } + + /** + * Retrieve target settings, also if they were sketchy before (i.e. no system name) + */ + def retrieveTargetSettings(target: Address): Option[ChannelSettings] = { + settings get target orElse { + val host = target.host + val port = target.port + settings find { + case (Address("akka", "", `host`, `port`), s) ⇒ true + case _ ⇒ false + } map { + case (_, s) ⇒ settings += target -> s; s + } + } + } + + def receive = { + case RemoveContext(ctx) ⇒ + channels get ctx foreach { inj ⇒ + context stop inj.sender + context stop inj.receiver + } + channels -= ctx + settings ++= settings collect { case (addr, c @ ChannelSettings(Some(`ctx`), _, _)) ⇒ (addr, c.copy(ctx = None)) } + case ThrottleMsg(target, dir, rateMBit) ⇒ + val setting = retrieveTargetSettings(target) + settings += target -> ((setting getOrElse ChannelSettings() match { + case cs @ ChannelSettings(ctx, _, _) if dir includes Direction.Send ⇒ + ctx foreach (c ⇒ channels get c foreach (_.sender ! SetRate(rateMBit))) + cs.copy(throttleSend = Some(SetRate(rateMBit))) + case x ⇒ x + }) match { + case cs @ ChannelSettings(ctx, _, _) if dir includes Direction.Receive ⇒ + ctx foreach (c ⇒ channels get c foreach (_.receiver ! SetRate(rateMBit))) + cs.copy(throttleReceive = Some(SetRate(rateMBit))) + case x ⇒ x + }) + sender ! "ok" + case DisconnectMsg(target, abort) ⇒ + retrieveTargetSettings(target) foreach { + case ChannelSettings(Some(ctx), _, _) ⇒ + val ch = ctx.getChannel + if (abort) { + ch.getConfig.setOption("soLinger", 0) + log.info("aborting connection {}", ch) + } else log.info("closing connection {}", ch) + ch.close + case _ ⇒ log.debug("no connection to {} to close or abort", target) + } + sender ! "ok" + case s @ Send(ctx, direction, future, msg) ⇒ + channels get ctx match { + case Some(Injectors(snd, rcv)) ⇒ + if (direction includes Direction.Send) snd ! s + if (direction includes Direction.Receive) rcv ! s + case None ⇒ + val (ipaddr, ip, port) = ctx.getChannel.getRemoteAddress match { + case s: InetSocketAddress ⇒ (s.getAddress, s.getAddress.getHostAddress, s.getPort) + } + val addr = ChannelAddress.get(ctx.getChannel) orElse { + settings collect { case (a @ Address("akka", _, Some(`ip`), Some(`port`)), _) ⇒ a } headOption + } orElse { + val name = ipaddr.getHostName + if (name == ip) None + else settings collect { case (a @ Address("akka", _, Some(`name`), Some(`port`)), _) ⇒ a } headOption + } getOrElse Address("akka", "", ip, port) // this will not match later requests directly, but be picked up by retrieveTargetSettings + val inj = ingestContextAddress(ctx, addr) + if (direction includes Direction.Send) inj.sender ! s + if (direction includes Direction.Receive) inj.receiver ! s + } } } -/** - * INTERNAL API. - */ private[akka] object NetworkFailureInjector { - case class SetRate(rateMBit: Float) - case class Disconnect(abort: Boolean) + case class RemoveContext(ctx: ChannelHandlerContext) } /** - * Brief overview: all network traffic passes through the `sender`/`receiver` FSMs, which can + * Brief overview: all network traffic passes through the `sender`/`receiver` FSMs managed + * by the FailureInjector of the TestConductor extension. These can * pass through requests immediately, drop them or throttle to a desired rate. The FSMs are - * registered in the TestConductorExt.failureInjectors so that settings can be applied from + * registered in the TestConductorExt.failureInjector so that settings can be applied from * the ClientFSMs. * * I found that simply forwarding events using ctx.sendUpstream/sendDownstream does not work, @@ -63,195 +148,181 @@ private[akka] object NetworkFailureInjector { * INTERNAL API. */ private[akka] class NetworkFailureInjector(system: ActorSystem) extends SimpleChannelHandler { + import NetworkFailureInjector._ - val log = Logging(system, "FailureInjector") + private val log = Logging(system, "FailureInjector") - // everything goes via these Throttle actors to enable easy steering - private val sender = system.actorOf(Props(new Throttle(Direction.Send))) - private val receiver = system.actorOf(Props(new Throttle(Direction.Receive))) - - private val packetSplitThreshold = TestConductor(system).Settings.PacketSplitThreshold - - /* - * State, Data and Messages for the internal Throttle actor - */ - sealed private trait State - private case object PassThrough extends State - private case object Throttle extends State - private case object Blackhole extends State - - private case class Data(lastSent: Long, rateMBit: Float, queue: Queue[Send]) - - private case class Send(ctx: ChannelHandlerContext, future: Option[ChannelFuture], msg: AnyRef) - private case class SetContext(ctx: ChannelHandlerContext) - private case object Tick - - private class Throttle(dir: Direction) extends Actor with LoggingFSM[State, Data] { - import FSM._ - - var channelContext: ChannelHandlerContext = _ - - startWith(PassThrough, Data(0, -1, Queue())) - - when(PassThrough) { - case Event(s @ Send(_, _, msg), _) ⇒ - log.debug("sending msg (PassThrough): {}", msg) - send(s) - stay - } - - when(Throttle) { - case Event(s: Send, data @ Data(_, _, Queue())) ⇒ - stay using sendThrottled(data.copy(lastSent = System.nanoTime, queue = Queue(s))) - case Event(s: Send, data) ⇒ - stay using sendThrottled(data.copy(queue = data.queue.enqueue(s))) - case Event(Tick, data) ⇒ - stay using sendThrottled(data) - } - - onTransition { - case Throttle -> PassThrough ⇒ - for (s ← stateData.queue) { - log.debug("sending msg (Transition): {}", s.msg) - send(s) - } - cancelTimer("send") - case Throttle -> Blackhole ⇒ - cancelTimer("send") - } - - when(Blackhole) { - case Event(Send(_, _, msg), _) ⇒ - log.debug("dropping msg {}", msg) - stay - } - - whenUnhandled { - case Event(NetworkFailureInjector.SetRate(rate), d) ⇒ - sender ! "ok" - if (rate > 0) { - goto(Throttle) using d.copy(lastSent = System.nanoTime, rateMBit = rate, queue = Queue()) - } else if (rate == 0) { - goto(Blackhole) - } else { - goto(PassThrough) - } - case Event(SetContext(ctx), _) ⇒ channelContext = ctx; stay - case Event(NetworkFailureInjector.Disconnect(abort), Data(ctx, _, _)) ⇒ - sender ! "ok" - // TODO implement abort - channelContext.getChannel.disconnect() - stay - } - - initialize - - private def sendThrottled(d: Data): Data = { - val (data, toSend, toTick) = schedule(d) - for (s ← toSend) { - log.debug("sending msg (Tick): {}", s.msg) - send(s) - } - if (!timerActive_?("send")) - for (time ← toTick) { - log.debug("scheduling next Tick in {}", time) - setTimer("send", Tick, time, false) - } - data - } - - private def send(s: Send): Unit = dir match { - case Direction.Send ⇒ Channels.write(s.ctx, s.future getOrElse Channels.future(s.ctx.getChannel), s.msg) - case Direction.Receive ⇒ Channels.fireMessageReceived(s.ctx, s.msg) - case _ ⇒ - } - - private def schedule(d: Data): (Data, Seq[Send], Option[Duration]) = { - val now = System.nanoTime - @tailrec def rec(d: Data, toSend: Seq[Send]): (Data, Seq[Send], Option[Duration]) = { - if (d.queue.isEmpty) (d, toSend, None) - else { - val timeForPacket = d.lastSent + (1000 * size(d.queue.head.msg) / d.rateMBit).toLong - if (timeForPacket <= now) rec(Data(timeForPacket, d.rateMBit, d.queue.tail), toSend :+ d.queue.head) - else { - val splitThreshold = d.lastSent + packetSplitThreshold.toNanos - if (now < splitThreshold) (d, toSend, Some((timeForPacket - now).nanos min (splitThreshold - now).nanos)) - else { - val microsToSend = (now - d.lastSent) / 1000 - val (s1, s2) = split(d.queue.head, (microsToSend * d.rateMBit / 8).toInt) - (d.copy(queue = s2 +: d.queue.tail), toSend :+ s1, Some((timeForPacket - now).nanos min packetSplitThreshold)) - } - } - } - } - rec(d, Seq()) - } - - private def split(s: Send, bytes: Int): (Send, Send) = { - s.msg match { - case buf: ChannelBuffer ⇒ - val f = s.future map { f ⇒ - val newF = Channels.future(s.ctx.getChannel) - newF.addListener(new ChannelFutureListener { - def operationComplete(future: ChannelFuture) { - if (future.isCancelled) f.cancel() - else future.getCause match { - case null ⇒ - case thr ⇒ f.setFailure(thr) - } - } - }) - newF - } - val b = buf.slice() - b.writerIndex(b.readerIndex + bytes) - buf.readerIndex(buf.readerIndex + bytes) - (Send(s.ctx, f, b), Send(s.ctx, s.future, buf)) - } - } - - private def size(msg: AnyRef) = msg match { - case b: ChannelBuffer ⇒ b.readableBytes() * 8 - case _ ⇒ throw new UnsupportedOperationException("NetworkFailureInjector only supports ChannelBuffer messages") - } - } - - private var remote: Option[Address] = None - - override def messageReceived(ctx: ChannelHandlerContext, msg: MessageEvent) { - log.debug("upstream(queued): {}", msg) - receiver ! Send(ctx, Option(msg.getFuture), msg.getMessage) - } + private val conductor = TestConductor(system) + private var announced = false override def channelConnected(ctx: ChannelHandlerContext, state: ChannelStateEvent) { state.getValue match { case a: InetSocketAddress ⇒ val addr = Address("akka", "", a.getHostName, a.getPort) log.debug("connected to {}", addr) - TestConductor(system).failureInjectors.put(addr, FailureInjector(sender, receiver)) match { - case null ⇒ // okay - case fi ⇒ system.log.error("{} already registered for address {}", fi, addr) - } - remote = Some(addr) - sender ! SetContext(ctx) case x ⇒ throw new IllegalArgumentException("unknown address type: " + x) } } override def channelDisconnected(ctx: ChannelHandlerContext, state: ChannelStateEvent) { - log.debug("disconnected from {}", remote) - remote = remote flatMap { addr ⇒ - TestConductor(system).failureInjectors.remove(addr) - system.stop(sender) - system.stop(receiver) - None - } + log.debug("disconnected from {}", state.getChannel) + conductor.failureInjector ! RemoveContext(ctx) + } + + override def messageReceived(ctx: ChannelHandlerContext, msg: MessageEvent) { + log.debug("upstream(queued): {}", msg) + conductor.failureInjector ! ThrottleActor.Send(ctx, Direction.Receive, Option(msg.getFuture), msg.getMessage) } override def writeRequested(ctx: ChannelHandlerContext, msg: MessageEvent) { log.debug("downstream(queued): {}", msg) - sender ! Send(ctx, Option(msg.getFuture), msg.getMessage) + conductor.failureInjector ! ThrottleActor.Send(ctx, Direction.Send, Option(msg.getFuture), msg.getMessage) } } +/** + * INTERNAL API. + */ +private[akka] object ThrottleActor { + sealed trait State + case object PassThrough extends State + case object Throttle extends State + case object Blackhole extends State + + case class Data(lastSent: Long, rateMBit: Float, queue: Queue[Send]) + + case class Send(ctx: ChannelHandlerContext, direction: Direction, future: Option[ChannelFuture], msg: AnyRef) + case class SetRate(rateMBit: Float) + case object Tick +} + +/** + * INTERNAL API. + */ +private[akka] class ThrottleActor(channelContext: ChannelHandlerContext) + extends Actor with LoggingFSM[ThrottleActor.State, ThrottleActor.Data] { + + import ThrottleActor._ + import FSM._ + + private val packetSplitThreshold = TestConductor(context.system).Settings.PacketSplitThreshold + + startWith(PassThrough, Data(0, -1, Queue())) + + when(PassThrough) { + case Event(s @ Send(_, _, _, msg), _) ⇒ + log.debug("sending msg (PassThrough): {}", msg) + send(s) + stay + } + + when(Throttle) { + case Event(s: Send, data @ Data(_, _, Queue())) ⇒ + stay using sendThrottled(data.copy(lastSent = System.nanoTime, queue = Queue(s))) + case Event(s: Send, data) ⇒ + stay using sendThrottled(data.copy(queue = data.queue.enqueue(s))) + case Event(Tick, data) ⇒ + stay using sendThrottled(data) + } + + onTransition { + case Throttle -> PassThrough ⇒ + for (s ← stateData.queue) { + log.debug("sending msg (Transition): {}", s.msg) + send(s) + } + cancelTimer("send") + case Throttle -> Blackhole ⇒ + cancelTimer("send") + } + + when(Blackhole) { + case Event(Send(_, _, _, msg), _) ⇒ + log.debug("dropping msg {}", msg) + stay + } + + whenUnhandled { + case Event(SetRate(rate), d) ⇒ + if (rate > 0) { + goto(Throttle) using d.copy(lastSent = System.nanoTime, rateMBit = rate, queue = Queue()) + } else if (rate == 0) { + goto(Blackhole) + } else { + goto(PassThrough) + } + } + + initialize + + private def sendThrottled(d: Data): Data = { + val (data, toSend, toTick) = schedule(d) + for (s ← toSend) { + log.debug("sending msg (Tick): {}", s.msg) + send(s) + } + if (!timerActive_?("send")) + for (time ← toTick) { + log.debug("scheduling next Tick in {}", time) + setTimer("send", Tick, time, false) + } + data + } + + private def send(s: Send): Unit = s.direction match { + case Direction.Send ⇒ Channels.write(s.ctx, s.future getOrElse Channels.future(s.ctx.getChannel), s.msg) + case Direction.Receive ⇒ Channels.fireMessageReceived(s.ctx, s.msg) + case _ ⇒ + } + + private def schedule(d: Data): (Data, Seq[Send], Option[Duration]) = { + val now = System.nanoTime + @tailrec def rec(d: Data, toSend: Seq[Send]): (Data, Seq[Send], Option[Duration]) = { + if (d.queue.isEmpty) (d, toSend, None) + else { + val timeForPacket = d.lastSent + (1000 * size(d.queue.head.msg) / d.rateMBit).toLong + if (timeForPacket <= now) rec(Data(timeForPacket, d.rateMBit, d.queue.tail), toSend :+ d.queue.head) + else { + val splitThreshold = d.lastSent + packetSplitThreshold.toNanos + if (now < splitThreshold) (d, toSend, Some((timeForPacket - now).nanos min (splitThreshold - now).nanos)) + else { + val microsToSend = (now - d.lastSent) / 1000 + val (s1, s2) = split(d.queue.head, (microsToSend * d.rateMBit / 8).toInt) + (d.copy(queue = s2 +: d.queue.tail), toSend :+ s1, Some((timeForPacket - now).nanos min packetSplitThreshold)) + } + } + } + } + rec(d, Seq()) + } + + private def split(s: Send, bytes: Int): (Send, Send) = { + s.msg match { + case buf: ChannelBuffer ⇒ + val f = s.future map { f ⇒ + val newF = Channels.future(s.ctx.getChannel) + newF.addListener(new ChannelFutureListener { + def operationComplete(future: ChannelFuture) { + if (future.isCancelled) f.cancel() + else future.getCause match { + case null ⇒ + case thr ⇒ f.setFailure(thr) + } + } + }) + newF + } + val b = buf.slice() + b.writerIndex(b.readerIndex + bytes) + buf.readerIndex(buf.readerIndex + bytes) + (Send(s.ctx, s.direction, f, b), Send(s.ctx, s.direction, s.future, buf)) + } + } + + private def size(msg: AnyRef) = msg match { + case b: ChannelBuffer ⇒ b.readableBytes() * 8 + case _ ⇒ throw new UnsupportedOperationException("NetworkFailureInjector only supports ChannelBuffer messages") + } +} + diff --git a/akka-remote-tests/src/main/scala/akka/remote/testconductor/Player.scala b/akka-remote-tests/src/main/scala/akka/remote/testconductor/Player.scala index 2a4eeb6ad1..09699edc34 100644 --- a/akka-remote-tests/src/main/scala/akka/remote/testconductor/Player.scala +++ b/akka-remote-tests/src/main/scala/akka/remote/testconductor/Player.scala @@ -200,21 +200,13 @@ private[akka] class ClientFSM(name: RoleName, controllerAddr: InetSocketAddress) log.warning("did not expect {}", op) } stay using d.copy(runningOp = None) - case ThrottleMsg(target, dir, rate) ⇒ + case t: ThrottleMsg ⇒ import settings.QueryTimeout - import context.dispatcher - TestConductor().failureInjectors.get(target.copy(system = "")) match { - case null ⇒ log.warning("cannot throttle unknown address {}", target) - case inj ⇒ - Future.sequence(inj.refs(dir) map (_ ? NetworkFailureInjector.SetRate(rate))) map (_ ⇒ ToServer(Done)) pipeTo self - } + TestConductor().failureInjector ? t map (_ ⇒ ToServer(Done)) pipeTo self stay - case DisconnectMsg(target, abort) ⇒ + case d: DisconnectMsg ⇒ import settings.QueryTimeout - TestConductor().failureInjectors.get(target.copy(system = "")) match { - case null ⇒ log.warning("cannot disconnect unknown address {}", target) - case inj ⇒ inj.sender ? NetworkFailureInjector.Disconnect(abort) map (_ ⇒ ToServer(Done)) pipeTo self - } + TestConductor().failureInjector ? d map (_ ⇒ ToServer(Done)) pipeTo self stay case TerminateMsg(exit) ⇒ System.exit(exit) diff --git a/akka-remote-tests/src/multi-jvm/scala/akka/remote/SimpleRemoteSpec.scala b/akka-remote-tests/src/multi-jvm/scala/akka/remote/SimpleRemoteSpec.scala new file mode 100644 index 0000000000..dcc4b60526 --- /dev/null +++ b/akka-remote-tests/src/multi-jvm/scala/akka/remote/SimpleRemoteSpec.scala @@ -0,0 +1,55 @@ +/** + * Copyright (C) 2009-2012 Typesafe Inc. + */ +package akka.remote + +import akka.actor.Actor +import akka.actor.ActorRef +import akka.actor.Props +import akka.pattern.ask +import akka.remote.testkit.MultiNodeConfig +import akka.remote.testkit.MultiNodeSpec +import akka.testkit._ + +object SimpleRemoteMultiJvmSpec extends MultiNodeConfig { + + class SomeActor extends Actor with Serializable { + def receive = { + case "identify" ⇒ sender ! self + } + } + + commonConfig(debugConfig(on = false)) + + val master = role("master") + val slave = role("slave") + +} + +class SimpleRemoteMultiJvmNode1 extends SimpleRemoteSpec +class SimpleRemoteMultiJvmNode2 extends SimpleRemoteSpec + +class SimpleRemoteSpec extends MultiNodeSpec(SimpleRemoteMultiJvmSpec) + with ImplicitSender with DefaultTimeout { + import SimpleRemoteMultiJvmSpec._ + + def initialParticipants = 2 + + runOn(master) { + system.actorOf(Props[SomeActor], "service-hello") + } + + "Remoting" must { + "lookup remote actor" in { + runOn(slave) { + val hello = system.actorFor(node(master) / "user" / "service-hello") + hello.isInstanceOf[RemoteActorRef] must be(true) + val masterAddress = testConductor.getAddressFor(master).await + (hello ? "identify").await.asInstanceOf[ActorRef].path.address must equal(masterAddress) + } + testConductor.enter("done") + } + } + +} + diff --git a/akka-remote-tests/src/multi-jvm/scala/akka/remote/router/DirectRoutedRemoteActorMultiJvmSpec.scala b/akka-remote-tests/src/multi-jvm/scala/akka/remote/router/DirectRoutedRemoteActorMultiJvmSpec.scala new file mode 100644 index 0000000000..3f23f60b37 --- /dev/null +++ b/akka-remote-tests/src/multi-jvm/scala/akka/remote/router/DirectRoutedRemoteActorMultiJvmSpec.scala @@ -0,0 +1,73 @@ +/** + * Copyright (C) 2009-2012 Typesafe Inc. + */ +package akka.remote.router + +import com.typesafe.config.ConfigFactory + +import akka.actor.Actor +import akka.actor.ActorRef +import akka.actor.Props +import akka.pattern.ask +import akka.remote.RemoteActorRef +import akka.remote.testkit.MultiNodeConfig +import akka.remote.testkit.MultiNodeSpec +import akka.testkit._ + +object DirectRoutedRemoteActorMultiJvmSpec extends MultiNodeConfig { + + class SomeActor extends Actor with Serializable { + def receive = { + case "identify" ⇒ sender ! self + } + } + + commonConfig(debugConfig(on = false)) + + val master = role("master") + val slave = role("slave") + + nodeConfig(master, ConfigFactory.parseString(""" + akka.actor { + deployment { + /service-hello.remote = "akka://MultiNodeSpec@%s" + } + } + # FIXME When using NettyRemoteTransport instead of TestConductorTransport it works + # akka.remote.transport = "akka.remote.netty.NettyRemoteTransport" + """.format("localhost:2553"))) // FIXME is there a way to avoid hardcoding the host:port here? + + nodeConfig(slave, ConfigFactory.parseString(""" + akka.remote.netty.port = 2553 + """)) + +} + +class DirectRoutedRemoteActorMultiJvmNode1 extends DirectRoutedRemoteActorSpec +class DirectRoutedRemoteActorMultiJvmNode2 extends DirectRoutedRemoteActorSpec + +class DirectRoutedRemoteActorSpec extends MultiNodeSpec(DirectRoutedRemoteActorMultiJvmSpec) + with ImplicitSender with DefaultTimeout { + import DirectRoutedRemoteActorMultiJvmSpec._ + + def initialParticipants = 2 + + "A new remote actor configured with a Direct router" must { + "be locally instantiated on a remote node and be able to communicate through its RemoteActorRef" in { + + runOn(master) { + val actor = system.actorOf(Props[SomeActor], "service-hello") + actor.isInstanceOf[RemoteActorRef] must be(true) + + val slaveAddress = testConductor.getAddressFor(slave).await + (actor ? "identify").await.asInstanceOf[ActorRef].path.address must equal(slaveAddress) + + // shut down the actor before we let the other node(s) shut down so we don't try to send + // "Terminate" to a shut down node + system.stop(actor) + } + + testConductor.enter("done") + } + } +} diff --git a/akka-remote-tests/src/multi-jvm/scala/akka/remote/testconductor/TestConductorSpec.scala b/akka-remote-tests/src/multi-jvm/scala/akka/remote/testconductor/TestConductorSpec.scala index 5ff19a806b..df6388d562 100644 --- a/akka-remote-tests/src/multi-jvm/scala/akka/remote/testconductor/TestConductorSpec.scala +++ b/akka-remote-tests/src/multi-jvm/scala/akka/remote/testconductor/TestConductorSpec.scala @@ -19,8 +19,8 @@ import akka.remote.testkit.MultiNodeSpec import akka.remote.testkit.MultiNodeConfig object TestConductorMultiJvmSpec extends MultiNodeConfig { - commonConfig(debugConfig(on = true)) - + commonConfig(debugConfig(on = false)) + val master = role("master") val slave = role("slave") } diff --git a/akka-remote/src/main/scala/akka/remote/netty/Client.scala b/akka-remote/src/main/scala/akka/remote/netty/Client.scala index 3dcb2cc5a8..14544a6c54 100644 --- a/akka-remote/src/main/scala/akka/remote/netty/Client.scala +++ b/akka-remote/src/main/scala/akka/remote/netty/Client.scala @@ -173,6 +173,7 @@ class ActiveRemoteClient private[akka] ( notifyListeners(RemoteClientError(connection.getCause, netty, remoteAddress)) false } else { + ChannelAddress.set(connection.getChannel, Some(remoteAddress)) sendSecureCookie(connection) notifyListeners(RemoteClientStarted(netty, remoteAddress)) true @@ -196,8 +197,10 @@ class ActiveRemoteClient private[akka] ( notifyListeners(RemoteClientShutdown(netty, remoteAddress)) try { - if ((connection ne null) && (connection.getChannel ne null)) + if ((connection ne null) && (connection.getChannel ne null)) { + ChannelAddress.remove(connection.getChannel) connection.getChannel.close() + } } finally { try { if (openChannels ne null) openChannels.close.awaitUninterruptibly() diff --git a/akka-remote/src/main/scala/akka/remote/netty/NettyRemoteSupport.scala b/akka-remote/src/main/scala/akka/remote/netty/NettyRemoteSupport.scala index d81b197f70..59670d2106 100644 --- a/akka-remote/src/main/scala/akka/remote/netty/NettyRemoteSupport.scala +++ b/akka-remote/src/main/scala/akka/remote/netty/NettyRemoteSupport.scala @@ -12,7 +12,7 @@ import java.util.concurrent.Executors import scala.collection.mutable.HashMap import org.jboss.netty.channel.group.{ DefaultChannelGroup, ChannelGroupFuture } import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory -import org.jboss.netty.channel.{ ChannelHandlerContext, Channel, StaticChannelPipeline, ChannelHandler, ChannelPipelineFactory } +import org.jboss.netty.channel.{ ChannelHandlerContext, Channel, StaticChannelPipeline, ChannelHandler, ChannelPipelineFactory, ChannelLocal } import org.jboss.netty.handler.codec.frame.{ LengthFieldPrepender, LengthFieldBasedFrameDecoder } import org.jboss.netty.handler.codec.protobuf.{ ProtobufEncoder, ProtobufDecoder } import org.jboss.netty.handler.execution.{ ExecutionHandler, OrderedMemoryAwareThreadPoolExecutor } @@ -25,6 +25,10 @@ import akka.remote.{ RemoteTransportException, RemoteTransport, RemoteSettings, import akka.util.NonFatal import akka.actor.{ ExtendedActorSystem, Address, ActorRef } +object ChannelAddress extends ChannelLocal[Option[Address]] { + override def initialValue(ch: Channel): Option[Address] = None +} + /** * Provides the implementation of the Netty remote support */ diff --git a/akka-remote/src/main/scala/akka/remote/netty/Server.scala b/akka-remote/src/main/scala/akka/remote/netty/Server.scala index 4ed1007bbd..22408a668d 100644 --- a/akka-remote/src/main/scala/akka/remote/netty/Server.scala +++ b/akka-remote/src/main/scala/akka/remote/netty/Server.scala @@ -112,10 +112,6 @@ class RemoteServerHandler( val openChannels: ChannelGroup, val netty: NettyRemoteTransport) extends SimpleChannelUpstreamHandler { - val channelAddress = new ChannelLocal[Option[Address]](false) { - override def initialValue(channel: Channel) = None - } - import netty.settings private var addressToSet = true @@ -139,16 +135,16 @@ class RemoteServerHandler( override def channelConnected(ctx: ChannelHandlerContext, event: ChannelStateEvent) = () override def channelDisconnected(ctx: ChannelHandlerContext, event: ChannelStateEvent) = { - netty.notifyListeners(RemoteServerClientDisconnected(netty, channelAddress.get(ctx.getChannel))) + netty.notifyListeners(RemoteServerClientDisconnected(netty, ChannelAddress.get(ctx.getChannel))) } override def channelClosed(ctx: ChannelHandlerContext, event: ChannelStateEvent) = { - val address = channelAddress.get(ctx.getChannel) + val address = ChannelAddress.get(ctx.getChannel) if (address.isDefined && settings.UsePassiveConnections) netty.unbindClient(address.get) netty.notifyListeners(RemoteServerClientClosed(netty, address)) - channelAddress.remove(ctx.getChannel) + ChannelAddress.remove(ctx.getChannel) } override def messageReceived(ctx: ChannelHandlerContext, event: MessageEvent) = try { @@ -162,7 +158,7 @@ class RemoteServerHandler( case CommandType.CONNECT ⇒ val origin = instruction.getOrigin val inbound = Address("akka", origin.getSystem, origin.getHostname, origin.getPort) - channelAddress.set(event.getChannel, Option(inbound)) + ChannelAddress.set(event.getChannel, Option(inbound)) //If we want to reuse the inbound connections as outbound we need to get busy if (settings.UsePassiveConnections)