make failure injection idempotent

- instead of creating local top-level actors per pipeline, just create
  one system actor through which everything is sent
- this enables storing settings (like what to throttle how) within this
  actor and applying settings when connections come up later
- it also gets rid of the blocking actor creation from
  NetworkFailureInjector, fixing the dead-lock
- moved also the ServerFSMs to be children of the Controller
- all actors have proper names now for easier debugging
This commit is contained in:
Roland 2012-05-24 10:56:32 +02:00
parent 12ff07f025
commit e054816047
11 changed files with 325 additions and 241 deletions

View file

@ -601,7 +601,7 @@ object Logging {
import java.text.SimpleDateFormat
import java.util.Date
val dateFormat = new SimpleDateFormat("MM/dd/yyyy HH:mm:ss.S")
val dateFormat = new SimpleDateFormat("MM/dd/yyyy HH:mm:ss.SSS")
def timestamp = dateFormat.format(new Date)

View file

@ -27,12 +27,26 @@ import akka.actor.SupervisorStrategy
import java.util.concurrent.ConcurrentHashMap
import akka.actor.Status
sealed trait Direction
sealed trait Direction {
def includes(other: Direction): Boolean
}
object Direction {
case object Send extends Direction
case object Receive extends Direction
case object Both extends Direction
case object Send extends Direction {
override def includes(other: Direction): Boolean = other match {
case Send true
case _ false
}
}
case object Receive extends Direction {
override def includes(other: Direction): Boolean = other match {
case Receive true
case _ false
}
}
case object Both extends Direction {
override def includes(other: Direction): Boolean = true
}
}
/**
@ -202,14 +216,15 @@ trait Conductor { this: TestConductorExt ⇒
* purpose is to dispatch incoming messages to the right ServerFSM actor. There is
* one shared instance of this class for all connections accepted by one Controller.
*/
class ConductorHandler(system: ActorSystem, controller: ActorRef, log: LoggingAdapter) extends SimpleChannelUpstreamHandler {
class ConductorHandler(_createTimeout: Timeout, controller: ActorRef, log: LoggingAdapter) extends SimpleChannelUpstreamHandler {
implicit val createTimeout = _createTimeout
val clients = new ConcurrentHashMap[Channel, ActorRef]()
override def channelConnected(ctx: ChannelHandlerContext, event: ChannelStateEvent) = {
val channel = event.getChannel
log.debug("connection from {}", getAddrString(channel))
val fsm = system.actorOf(Props(new ServerFSM(controller, channel)))
val fsm: ActorRef = Await.result(controller ? Controller.CreateServerFSM(channel) mapTo, Duration.Inf)
clients.put(channel, fsm)
}
@ -321,6 +336,7 @@ object Controller {
case class ClientDisconnected(name: RoleName)
case object GetNodes
case object GetSockAddr
case class CreateServerFSM(channel: Channel)
case class NodeInfo(name: RoleName, addr: Address, fsm: ActorRef)
}
@ -336,7 +352,7 @@ class Controller(private var initialParticipants: Int, controllerPort: InetSocke
val settings = TestConductor().Settings
val connection = RemoteConnection(Server, controllerPort,
new ConductorHandler(context.system, self, Logging(context.system, "ConductorHandler")))
new ConductorHandler(settings.QueryTimeout, self, Logging(context.system, "ConductorHandler")))
/*
* Supervision of the BarrierCoordinator means to catch all his bad emotions
@ -363,8 +379,15 @@ class Controller(private var initialParticipants: Int, controllerPort: InetSocke
// map keeping unanswered queries for node addresses (enqueued upon GetAddress, serviced upon NodeInfo)
var addrInterest = Map[RoleName, Set[ActorRef]]()
val generation = Iterator from 1
override def receive = LoggingReceive {
case CreateServerFSM(channel)
val (ip, port) = channel.getRemoteAddress match {
case s: InetSocketAddress (s.getHostString, s.getPort)
}
val name = ip + ":" + port + "-server" + generation.next
sender ! context.actorOf(Props(new ServerFSM(self, channel)), name)
case c @ NodeInfo(name, addr, fsm)
barrier forward c
if (nodes contains name) {

View file

@ -10,6 +10,8 @@ import java.util.concurrent.TimeUnit.MILLISECONDS
import akka.actor.ActorRef
import java.util.concurrent.ConcurrentHashMap
import akka.actor.Address
import akka.actor.ActorSystemImpl
import akka.actor.Props
/**
* Access to the [[akka.remote.testconductor.TestConductorExt]] extension:
@ -50,6 +52,6 @@ class TestConductorExt(val system: ExtendedActorSystem) extends Extension with C
val transport = system.provider.asInstanceOf[RemoteActorRefProvider].transport
val address = transport.address
val failureInjectors = new ConcurrentHashMap[Address, FailureInjector]
val failureInjector = system.asInstanceOf[ActorSystemImpl].systemActorOf(Props[FailureInjector], "FailureInjector")
}

View file

@ -4,236 +4,303 @@
package akka.remote.testconductor
import java.net.InetSocketAddress
import scala.collection.immutable.Queue
import org.jboss.netty.buffer.ChannelBuffer
import org.jboss.netty.channel.ChannelState.BOUND
import org.jboss.netty.channel.ChannelState.OPEN
import org.jboss.netty.channel.Channel
import org.jboss.netty.channel.ChannelEvent
import org.jboss.netty.channel.ChannelHandlerContext
import org.jboss.netty.channel.ChannelStateEvent
import org.jboss.netty.channel.MessageEvent
import akka.actor.FSM
import akka.actor.Actor
import akka.util.duration.doubleToDurationDouble
import akka.util.Index
import akka.actor.Address
import akka.actor.ActorSystem
import akka.actor.Props
import akka.actor.ActorRef
import akka.event.Logging
import org.jboss.netty.channel.SimpleChannelHandler
import scala.annotation.tailrec
import akka.util.Duration
import akka.actor.LoggingFSM
import org.jboss.netty.channel.Channels
import org.jboss.netty.channel.ChannelFuture
import org.jboss.netty.channel.ChannelFutureListener
import org.jboss.netty.channel.ChannelFuture
case class FailureInjector(sender: ActorRef, receiver: ActorRef) {
def refs(dir: Direction) = dir match {
case Direction.Send Seq(sender)
case Direction.Receive Seq(receiver)
case Direction.Both Seq(sender, receiver)
import scala.annotation.tailrec
import scala.collection.immutable.Queue
import org.jboss.netty.buffer.ChannelBuffer
import org.jboss.netty.channel.{ SimpleChannelHandler, MessageEvent, Channels, ChannelStateEvent, ChannelHandlerContext, ChannelFutureListener, ChannelFuture }
import akka.actor.{ Props, LoggingFSM, Address, ActorSystem, ActorRef, ActorLogging, Actor, FSM }
import akka.event.Logging
import akka.remote.netty.ChannelAddress
import akka.util.Duration
import akka.util.duration._
class FailureInjector extends Actor with ActorLogging {
import ThrottleActor._
import NetworkFailureInjector._
case class ChannelSettings(
ctx: Option[ChannelHandlerContext] = None,
throttleSend: Option[SetRate] = None,
throttleReceive: Option[SetRate] = None)
case class Injectors(sender: ActorRef, receiver: ActorRef)
var channels = Map[ChannelHandlerContext, Injectors]()
var settings = Map[Address, ChannelSettings]()
var generation = Iterator from 1
/**
* Only for a NEW ctx, start ThrottleActors, prime them and update all maps.
*/
def ingestContextAddress(ctx: ChannelHandlerContext, addr: Address): Injectors = {
val gen = generation.next
val name = addr.host.get + ":" + addr.port.get
val thrSend = context.actorOf(Props(new ThrottleActor(ctx)), name + "-snd" + gen)
val thrRecv = context.actorOf(Props(new ThrottleActor(ctx)), name + "-rcv" + gen)
val injectors = Injectors(thrSend, thrRecv)
channels += ctx -> injectors
settings += addr -> (settings get addr map {
case c @ ChannelSettings(prevCtx, ts, tr)
ts foreach (thrSend ! _)
tr foreach (thrRecv ! _)
prevCtx match {
case Some(p) log.warning("installing context {} instead of {} for address {}", ctx, p, addr)
case None // okay
}
c.copy(ctx = Some(ctx))
} getOrElse ChannelSettings(Some(ctx)))
injectors
}
/**
* Retrieve target settings, also if they were sketchy before (i.e. no system name)
*/
def retrieveTargetSettings(target: Address): Option[ChannelSettings] = {
settings get target orElse {
val host = target.host
val port = target.port
settings find {
case (Address("akka", "", `host`, `port`), s) true
case _ false
} map {
case (_, s) settings += target -> s; s
}
}
}
def receive = {
case RemoveContext(ctx)
channels get ctx foreach { inj
context stop inj.sender
context stop inj.receiver
}
channels -= ctx
settings ++= settings collect { case (addr, c @ ChannelSettings(Some(`ctx`), _, _)) (addr, c.copy(ctx = None)) }
case ThrottleMsg(target, dir, rateMBit)
val setting = retrieveTargetSettings(target)
settings += target -> ((setting getOrElse ChannelSettings() match {
case cs @ ChannelSettings(ctx, _, _) if dir includes Direction.Send
ctx foreach (c channels get c foreach (_.sender ! SetRate(rateMBit)))
cs.copy(throttleSend = Some(SetRate(rateMBit)))
case x x
}) match {
case cs @ ChannelSettings(ctx, _, _) if dir includes Direction.Receive
ctx foreach (c channels get c foreach (_.receiver ! SetRate(rateMBit)))
cs.copy(throttleReceive = Some(SetRate(rateMBit)))
case x x
})
sender ! "ok"
case DisconnectMsg(target, abort)
retrieveTargetSettings(target) foreach {
case ChannelSettings(Some(ctx), _, _)
val ch = ctx.getChannel
if (abort) {
ch.getConfig.setOption("soLinger", 0)
log.info("aborting connection {}", ch)
} else log.info("closing connection {}", ch)
ch.close
case _ log.debug("no connection to {} to close or abort", target)
}
sender ! "ok"
case s @ Send(ctx, direction, future, msg)
channels get ctx match {
case Some(Injectors(snd, rcv))
if (direction includes Direction.Send) snd ! s
if (direction includes Direction.Receive) rcv ! s
case None
val (ipaddr, ip, port) = ctx.getChannel.getRemoteAddress match {
case s: InetSocketAddress (s.getAddress, s.getAddress.getHostAddress, s.getPort)
}
val addr = ChannelAddress.get(ctx.getChannel) orElse {
settings collect { case (a @ Address("akka", _, Some(`ip`), Some(`port`)), _) a } headOption
} orElse {
val name = ipaddr.getHostName
if (name == ip) None
else settings collect { case (a @ Address("akka", _, Some(`name`), Some(`port`)), _) a } headOption
} getOrElse Address("akka", "", ip, port) // this will not match later requests directly, but be picked up by retrieveTargetSettings
val inj = ingestContextAddress(ctx, addr)
if (direction includes Direction.Send) inj.sender ! s
if (direction includes Direction.Receive) inj.receiver ! s
}
}
}
object NetworkFailureInjector {
case class SetRate(rateMBit: Float)
case class Disconnect(abort: Boolean)
case class RemoveContext(ctx: ChannelHandlerContext)
}
class NetworkFailureInjector(system: ActorSystem) extends SimpleChannelHandler {
import NetworkFailureInjector._
val log = Logging(system, "FailureInjector")
private val log = Logging(system, "FailureInjector")
// everything goes via these Throttle actors to enable easy steering
private val sender = system.actorOf(Props(new Throttle(Direction.Send)))
private val receiver = system.actorOf(Props(new Throttle(Direction.Receive)))
private val packetSplitThreshold = TestConductor(system).Settings.PacketSplitThreshold
/*
* State, Data and Messages for the internal Throttle actor
*/
sealed private trait State
private case object PassThrough extends State
private case object Throttle extends State
private case object Blackhole extends State
private case class Data(lastSent: Long, rateMBit: Float, queue: Queue[Send])
private case class Send(ctx: ChannelHandlerContext, future: Option[ChannelFuture], msg: AnyRef)
private case class SetContext(ctx: ChannelHandlerContext)
private case object Tick
private class Throttle(dir: Direction) extends Actor with LoggingFSM[State, Data] {
import FSM._
var channelContext: ChannelHandlerContext = _
startWith(PassThrough, Data(0, -1, Queue()))
when(PassThrough) {
case Event(s @ Send(_, _, msg), _)
log.debug("sending msg (PassThrough): {}", msg)
send(s)
stay
}
when(Throttle) {
case Event(s: Send, data @ Data(_, _, Queue()))
stay using sendThrottled(data.copy(lastSent = System.nanoTime, queue = Queue(s)))
case Event(s: Send, data)
stay using sendThrottled(data.copy(queue = data.queue.enqueue(s)))
case Event(Tick, data)
stay using sendThrottled(data)
}
onTransition {
case Throttle -> PassThrough
for (s stateData.queue) {
log.debug("sending msg (Transition): {}", s.msg)
send(s)
}
cancelTimer("send")
case Throttle -> Blackhole
cancelTimer("send")
}
when(Blackhole) {
case Event(Send(_, _, msg), _)
log.debug("dropping msg {}", msg)
stay
}
whenUnhandled {
case Event(NetworkFailureInjector.SetRate(rate), d)
sender ! "ok"
if (rate > 0) {
goto(Throttle) using d.copy(lastSent = System.nanoTime, rateMBit = rate, queue = Queue())
} else if (rate == 0) {
goto(Blackhole)
} else {
goto(PassThrough)
}
case Event(SetContext(ctx), _) channelContext = ctx; stay
case Event(NetworkFailureInjector.Disconnect(abort), Data(ctx, _, _))
sender ! "ok"
// TODO implement abort
channelContext.getChannel.disconnect()
stay
}
initialize
private def sendThrottled(d: Data): Data = {
val (data, toSend, toTick) = schedule(d)
for (s toSend) {
log.debug("sending msg (Tick): {}", s.msg)
send(s)
}
if (!timerActive_?("send"))
for (time toTick) {
log.debug("scheduling next Tick in {}", time)
setTimer("send", Tick, time, false)
}
data
}
private def send(s: Send): Unit = dir match {
case Direction.Send Channels.write(s.ctx, s.future getOrElse Channels.future(s.ctx.getChannel), s.msg)
case Direction.Receive Channels.fireMessageReceived(s.ctx, s.msg)
case _
}
private def schedule(d: Data): (Data, Seq[Send], Option[Duration]) = {
val now = System.nanoTime
@tailrec def rec(d: Data, toSend: Seq[Send]): (Data, Seq[Send], Option[Duration]) = {
if (d.queue.isEmpty) (d, toSend, None)
else {
val timeForPacket = d.lastSent + (1000 * size(d.queue.head.msg) / d.rateMBit).toLong
if (timeForPacket <= now) rec(Data(timeForPacket, d.rateMBit, d.queue.tail), toSend :+ d.queue.head)
else {
val splitThreshold = d.lastSent + packetSplitThreshold.toNanos
if (now < splitThreshold) (d, toSend, Some((timeForPacket - now).nanos min (splitThreshold - now).nanos))
else {
val microsToSend = (now - d.lastSent) / 1000
val (s1, s2) = split(d.queue.head, (microsToSend * d.rateMBit / 8).toInt)
(d.copy(queue = s2 +: d.queue.tail), toSend :+ s1, Some((timeForPacket - now).nanos min packetSplitThreshold))
}
}
}
}
rec(d, Seq())
}
private def split(s: Send, bytes: Int): (Send, Send) = {
s.msg match {
case buf: ChannelBuffer
val f = s.future map { f
val newF = Channels.future(s.ctx.getChannel)
newF.addListener(new ChannelFutureListener {
def operationComplete(future: ChannelFuture) {
if (future.isCancelled) f.cancel()
else future.getCause match {
case null
case thr f.setFailure(thr)
}
}
})
newF
}
val b = buf.slice()
b.writerIndex(b.readerIndex + bytes)
buf.readerIndex(buf.readerIndex + bytes)
(Send(s.ctx, f, b), Send(s.ctx, s.future, buf))
}
}
private def size(msg: AnyRef) = msg match {
case b: ChannelBuffer b.readableBytes() * 8
case _ throw new UnsupportedOperationException("NetworkFailureInjector only supports ChannelBuffer messages")
}
}
private var remote: Option[Address] = None
override def messageReceived(ctx: ChannelHandlerContext, msg: MessageEvent) {
log.debug("upstream(queued): {}", msg)
receiver ! Send(ctx, Option(msg.getFuture), msg.getMessage)
}
private val conductor = TestConductor(system)
private var announced = false
override def channelConnected(ctx: ChannelHandlerContext, state: ChannelStateEvent) {
state.getValue match {
case a: InetSocketAddress
val addr = Address("akka", "", a.getHostName, a.getPort)
log.debug("connected to {}", addr)
TestConductor(system).failureInjectors.put(addr, FailureInjector(sender, receiver)) match {
case null // okay
case fi system.log.error("{} already registered for address {}", fi, addr)
}
remote = Some(addr)
sender ! SetContext(ctx)
case x throw new IllegalArgumentException("unknown address type: " + x)
}
}
override def channelDisconnected(ctx: ChannelHandlerContext, state: ChannelStateEvent) {
log.debug("disconnected from {}", remote)
remote = remote flatMap { addr
TestConductor(system).failureInjectors.remove(addr)
system.stop(sender)
system.stop(receiver)
None
}
log.debug("disconnected from {}", state.getChannel)
conductor.failureInjector ! RemoveContext(ctx)
}
override def messageReceived(ctx: ChannelHandlerContext, msg: MessageEvent) {
log.debug("upstream(queued): {}", msg)
conductor.failureInjector ! ThrottleActor.Send(ctx, Direction.Receive, Option(msg.getFuture), msg.getMessage)
}
override def writeRequested(ctx: ChannelHandlerContext, msg: MessageEvent) {
log.debug("downstream(queued): {}", msg)
sender ! Send(ctx, Option(msg.getFuture), msg.getMessage)
conductor.failureInjector ! ThrottleActor.Send(ctx, Direction.Send, Option(msg.getFuture), msg.getMessage)
}
}
private[akka] object ThrottleActor {
sealed trait State
case object PassThrough extends State
case object Throttle extends State
case object Blackhole extends State
case class Data(lastSent: Long, rateMBit: Float, queue: Queue[Send])
case class Send(ctx: ChannelHandlerContext, direction: Direction, future: Option[ChannelFuture], msg: AnyRef)
case class SetRate(rateMBit: Float)
case object Tick
}
private[akka] class ThrottleActor(channelContext: ChannelHandlerContext)
extends Actor with LoggingFSM[ThrottleActor.State, ThrottleActor.Data] {
import ThrottleActor._
import FSM._
private val packetSplitThreshold = TestConductor(context.system).Settings.PacketSplitThreshold
startWith(PassThrough, Data(0, -1, Queue()))
when(PassThrough) {
case Event(s @ Send(_, _, _, msg), _)
log.debug("sending msg (PassThrough): {}", msg)
send(s)
stay
}
when(Throttle) {
case Event(s: Send, data @ Data(_, _, Queue()))
stay using sendThrottled(data.copy(lastSent = System.nanoTime, queue = Queue(s)))
case Event(s: Send, data)
stay using sendThrottled(data.copy(queue = data.queue.enqueue(s)))
case Event(Tick, data)
stay using sendThrottled(data)
}
onTransition {
case Throttle -> PassThrough
for (s stateData.queue) {
log.debug("sending msg (Transition): {}", s.msg)
send(s)
}
cancelTimer("send")
case Throttle -> Blackhole
cancelTimer("send")
}
when(Blackhole) {
case Event(Send(_, _, _, msg), _)
log.debug("dropping msg {}", msg)
stay
}
whenUnhandled {
case Event(SetRate(rate), d)
if (rate > 0) {
goto(Throttle) using d.copy(lastSent = System.nanoTime, rateMBit = rate, queue = Queue())
} else if (rate == 0) {
goto(Blackhole)
} else {
goto(PassThrough)
}
}
initialize
private def sendThrottled(d: Data): Data = {
val (data, toSend, toTick) = schedule(d)
for (s toSend) {
log.debug("sending msg (Tick): {}", s.msg)
send(s)
}
if (!timerActive_?("send"))
for (time toTick) {
log.debug("scheduling next Tick in {}", time)
setTimer("send", Tick, time, false)
}
data
}
private def send(s: Send): Unit = s.direction match {
case Direction.Send Channels.write(s.ctx, s.future getOrElse Channels.future(s.ctx.getChannel), s.msg)
case Direction.Receive Channels.fireMessageReceived(s.ctx, s.msg)
case _
}
private def schedule(d: Data): (Data, Seq[Send], Option[Duration]) = {
val now = System.nanoTime
@tailrec def rec(d: Data, toSend: Seq[Send]): (Data, Seq[Send], Option[Duration]) = {
if (d.queue.isEmpty) (d, toSend, None)
else {
val timeForPacket = d.lastSent + (1000 * size(d.queue.head.msg) / d.rateMBit).toLong
if (timeForPacket <= now) rec(Data(timeForPacket, d.rateMBit, d.queue.tail), toSend :+ d.queue.head)
else {
val splitThreshold = d.lastSent + packetSplitThreshold.toNanos
if (now < splitThreshold) (d, toSend, Some((timeForPacket - now).nanos min (splitThreshold - now).nanos))
else {
val microsToSend = (now - d.lastSent) / 1000
val (s1, s2) = split(d.queue.head, (microsToSend * d.rateMBit / 8).toInt)
(d.copy(queue = s2 +: d.queue.tail), toSend :+ s1, Some((timeForPacket - now).nanos min packetSplitThreshold))
}
}
}
}
rec(d, Seq())
}
private def split(s: Send, bytes: Int): (Send, Send) = {
s.msg match {
case buf: ChannelBuffer
val f = s.future map { f
val newF = Channels.future(s.ctx.getChannel)
newF.addListener(new ChannelFutureListener {
def operationComplete(future: ChannelFuture) {
if (future.isCancelled) f.cancel()
else future.getCause match {
case null
case thr f.setFailure(thr)
}
}
})
newF
}
val b = buf.slice()
b.writerIndex(b.readerIndex + bytes)
buf.readerIndex(buf.readerIndex + bytes)
(Send(s.ctx, s.direction, f, b), Send(s.ctx, s.direction, s.future, buf))
}
}
private def size(msg: AnyRef) = msg match {
case b: ChannelBuffer b.readableBytes() * 8
case _ throw new UnsupportedOperationException("NetworkFailureInjector only supports ChannelBuffer messages")
}
}

View file

@ -195,21 +195,13 @@ class ClientFSM(name: RoleName, controllerAddr: InetSocketAddress) extends Actor
log.warning("did not expect {}", op)
}
stay using d.copy(runningOp = None)
case ThrottleMsg(target, dir, rate)
case t: ThrottleMsg
import settings.QueryTimeout
import context.dispatcher
TestConductor().failureInjectors.get(target.copy(system = "")) match {
case null log.warning("cannot throttle unknown address {}", target)
case inj
Future.sequence(inj.refs(dir) map (_ ? NetworkFailureInjector.SetRate(rate))) map (_ ToServer(Done)) pipeTo self
}
TestConductor().failureInjector ? t map (_ ToServer(Done)) pipeTo self
stay
case DisconnectMsg(target, abort)
case d: DisconnectMsg
import settings.QueryTimeout
TestConductor().failureInjectors.get(target.copy(system = "")) match {
case null log.warning("cannot disconnect unknown address {}", target)
case inj inj.sender ? NetworkFailureInjector.Disconnect(abort) map (_ ToServer(Done)) pipeTo self
}
TestConductor().failureInjector ? d map (_ ToServer(Done)) pipeTo self
stay
case TerminateMsg(exit)
System.exit(exit)

View file

@ -22,7 +22,7 @@ object SimpleRemoteMultiJvmSpec extends MultiNodeConfig {
}
commonConfig(ConfigFactory.parseString("""
akka.loglevel = DEBUG
# akka.loglevel = DEBUG
akka.remote {
log-received-messages = on
log-sent-messages = on

View file

@ -24,7 +24,7 @@ object DirectRoutedRemoteActorMultiJvmSpec extends MultiNodeConfig {
import com.typesafe.config.ConfigFactory
commonConfig(ConfigFactory.parseString("""
akka.loglevel = DEBUG
# akka.loglevel = DEBUG
akka.remote {
log-received-messages = on
log-sent-messages = on

View file

@ -17,7 +17,7 @@ import akka.remote.testkit.MultiNodeConfig
object TestConductorMultiJvmSpec extends MultiNodeConfig {
commonConfig(ConfigFactory.parseString("""
akka.loglevel = DEBUG
# akka.loglevel = DEBUG
akka.remote {
log-received-messages = on
log-sent-messages = on

View file

@ -173,6 +173,7 @@ class ActiveRemoteClient private[akka] (
notifyListeners(RemoteClientError(connection.getCause, netty, remoteAddress))
false
} else {
ChannelAddress.set(connection.getChannel, Some(remoteAddress))
sendSecureCookie(connection)
notifyListeners(RemoteClientStarted(netty, remoteAddress))
true
@ -196,8 +197,10 @@ class ActiveRemoteClient private[akka] (
notifyListeners(RemoteClientShutdown(netty, remoteAddress))
try {
if ((connection ne null) && (connection.getChannel ne null))
if ((connection ne null) && (connection.getChannel ne null)) {
ChannelAddress.remove(connection.getChannel)
connection.getChannel.close()
}
} finally {
try {
if (openChannels ne null) openChannels.close.awaitUninterruptibly()

View file

@ -29,6 +29,11 @@ import org.jboss.netty.handler.codec.frame.LengthFieldBasedFrameDecoder
import org.jboss.netty.handler.timeout.IdleStateHandler
import org.jboss.netty.channel.ChannelPipelineFactory
import org.jboss.netty.handler.execution.ExecutionHandler
import org.jboss.netty.channel.ChannelLocal
object ChannelAddress extends ChannelLocal[Option[Address]] {
override def initialValue(ch: Channel): Option[Address] = None
}
/**
* Provides the implementation of the Netty remote support

View file

@ -102,19 +102,11 @@ class RemoteServerAuthenticationHandler(secureCookie: Option[String]) extends Si
}
}
object ChannelLocalSystem extends ChannelLocal[ActorSystemImpl] {
override def initialValue(ch: Channel): ActorSystemImpl = null
}
@ChannelHandler.Sharable
class RemoteServerHandler(
val openChannels: ChannelGroup,
val netty: NettyRemoteTransport) extends SimpleChannelUpstreamHandler {
val channelAddress = new ChannelLocal[Option[Address]](false) {
override def initialValue(channel: Channel) = None
}
import netty.settings
private var addressToSet = true
@ -138,16 +130,16 @@ class RemoteServerHandler(
override def channelConnected(ctx: ChannelHandlerContext, event: ChannelStateEvent) = ()
override def channelDisconnected(ctx: ChannelHandlerContext, event: ChannelStateEvent) = {
netty.notifyListeners(RemoteServerClientDisconnected(netty, channelAddress.get(ctx.getChannel)))
netty.notifyListeners(RemoteServerClientDisconnected(netty, ChannelAddress.get(ctx.getChannel)))
}
override def channelClosed(ctx: ChannelHandlerContext, event: ChannelStateEvent) = {
val address = channelAddress.get(ctx.getChannel)
val address = ChannelAddress.get(ctx.getChannel)
if (address.isDefined && settings.UsePassiveConnections)
netty.unbindClient(address.get)
netty.notifyListeners(RemoteServerClientClosed(netty, address))
channelAddress.remove(ctx.getChannel)
ChannelAddress.remove(ctx.getChannel)
}
override def messageReceived(ctx: ChannelHandlerContext, event: MessageEvent) = try {
@ -161,7 +153,7 @@ class RemoteServerHandler(
case CommandType.CONNECT
val origin = instruction.getOrigin
val inbound = Address("akka", origin.getSystem, origin.getHostname, origin.getPort)
channelAddress.set(event.getChannel, Option(inbound))
ChannelAddress.set(event.getChannel, Option(inbound))
//If we want to reuse the inbound connections as outbound we need to get busy
if (settings.UsePassiveConnections)