diff --git a/akka-remote/src/main/java/akka/remote/testconductor/TestConductorProtocol.java b/akka-remote/src/main/java/akka/remote/testconductor/TestConductorProtocol.java index 0b2950018f..f112a1b0c2 100644 --- a/akka-remote/src/main/java/akka/remote/testconductor/TestConductorProtocol.java +++ b/akka-remote/src/main/java/akka/remote/testconductor/TestConductorProtocol.java @@ -87,10 +87,12 @@ public final class TestConductorProtocol { implements com.google.protobuf.ProtocolMessageEnum { Send(0, 1), Receive(1, 2), + Both(2, 3), ; public static final int Send_VALUE = 1; public static final int Receive_VALUE = 2; + public static final int Both_VALUE = 3; public final int getNumber() { return value; } @@ -99,6 +101,7 @@ public final class TestConductorProtocol { switch (value) { case 1: return Send; case 2: return Receive; + case 3: return Both; default: return null; } } @@ -129,7 +132,7 @@ public final class TestConductorProtocol { } private static final Direction[] VALUES = { - Send, Receive, + Send, Receive, Both, }; public static Direction valueOf( @@ -169,6 +172,10 @@ public final class TestConductorProtocol { boolean hasFailure(); akka.remote.testconductor.TestConductorProtocol.InjectFailure getFailure(); akka.remote.testconductor.TestConductorProtocol.InjectFailureOrBuilder getFailureOrBuilder(); + + // optional string done = 4; + boolean hasDone(); + String getDone(); } public static final class Wrapper extends com.google.protobuf.GeneratedMessage @@ -238,10 +245,43 @@ public final class TestConductorProtocol { return failure_; } + // optional string done = 4; + public static final int DONE_FIELD_NUMBER = 4; + private java.lang.Object done_; + public boolean hasDone() { + return ((bitField0_ & 0x00000008) == 0x00000008); + } + public String getDone() { + java.lang.Object ref = done_; + if (ref instanceof String) { + return (String) ref; + } else { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + String s = bs.toStringUtf8(); + if (com.google.protobuf.Internal.isValidUtf8(bs)) { + done_ = s; + } + return s; + } + } + private com.google.protobuf.ByteString getDoneBytes() { + java.lang.Object ref = done_; + if (ref instanceof String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8((String) ref); + done_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + private void initFields() { hello_ = akka.remote.testconductor.TestConductorProtocol.Hello.getDefaultInstance(); barrier_ = akka.remote.testconductor.TestConductorProtocol.EnterBarrier.getDefaultInstance(); failure_ = akka.remote.testconductor.TestConductorProtocol.InjectFailure.getDefaultInstance(); + done_ = ""; } private byte memoizedIsInitialized = -1; public final boolean isInitialized() { @@ -282,6 +322,9 @@ public final class TestConductorProtocol { if (((bitField0_ & 0x00000004) == 0x00000004)) { output.writeMessage(3, failure_); } + if (((bitField0_ & 0x00000008) == 0x00000008)) { + output.writeBytes(4, getDoneBytes()); + } getUnknownFields().writeTo(output); } @@ -303,6 +346,10 @@ public final class TestConductorProtocol { size += com.google.protobuf.CodedOutputStream .computeMessageSize(3, failure_); } + if (((bitField0_ & 0x00000008) == 0x00000008)) { + size += com.google.protobuf.CodedOutputStream + .computeBytesSize(4, getDoneBytes()); + } size += getUnknownFields().getSerializedSize(); memoizedSerializedSize = size; return size; @@ -448,6 +495,8 @@ public final class TestConductorProtocol { failureBuilder_.clear(); } bitField0_ = (bitField0_ & ~0x00000004); + done_ = ""; + bitField0_ = (bitField0_ & ~0x00000008); return this; } @@ -510,6 +559,10 @@ public final class TestConductorProtocol { } else { result.failure_ = failureBuilder_.build(); } + if (((from_bitField0_ & 0x00000008) == 0x00000008)) { + to_bitField0_ |= 0x00000008; + } + result.done_ = done_; result.bitField0_ = to_bitField0_; onBuilt(); return result; @@ -535,6 +588,9 @@ public final class TestConductorProtocol { if (other.hasFailure()) { mergeFailure(other.getFailure()); } + if (other.hasDone()) { + setDone(other.getDone()); + } this.mergeUnknownFields(other.getUnknownFields()); return this; } @@ -611,6 +667,11 @@ public final class TestConductorProtocol { setFailure(subBuilder.buildPartial()); break; } + case 34: { + bitField0_ |= 0x00000008; + done_ = input.readBytes(); + break; + } } } } @@ -887,6 +948,42 @@ public final class TestConductorProtocol { return failureBuilder_; } + // optional string done = 4; + private java.lang.Object done_ = ""; + public boolean hasDone() { + return ((bitField0_ & 0x00000008) == 0x00000008); + } + public String getDone() { + java.lang.Object ref = done_; + if (!(ref instanceof String)) { + String s = ((com.google.protobuf.ByteString) ref).toStringUtf8(); + done_ = s; + return s; + } else { + return (String) ref; + } + } + public Builder setDone(String value) { + if (value == null) { + throw new NullPointerException(); + } + bitField0_ |= 0x00000008; + done_ = value; + onChanged(); + return this; + } + public Builder clearDone() { + bitField0_ = (bitField0_ & ~0x00000008); + done_ = getDefaultInstance().getDone(); + onChanged(); + return this; + } + void setDone(com.google.protobuf.ByteString value) { + bitField0_ |= 0x00000008; + done_ = value; + onChanged(); + } + // @@protoc_insertion_point(builder_scope:Wrapper) } @@ -3199,20 +3296,21 @@ public final class TestConductorProtocol { descriptor; static { java.lang.String[] descriptorData = { - "\n\033TestConductorProtocol.proto\"a\n\007Wrapper" + + "\n\033TestConductorProtocol.proto\"o\n\007Wrapper" + "\022\025\n\005hello\030\001 \001(\0132\006.Hello\022\036\n\007barrier\030\002 \001(\013" + "2\r.EnterBarrier\022\037\n\007failure\030\003 \001(\0132\016.Injec" + - "tFailure\"0\n\005Hello\022\014\n\004name\030\001 \002(\t\022\031\n\007addre" + - "ss\030\002 \002(\0132\010.Address\"\034\n\014EnterBarrier\022\014\n\004na" + - "me\030\001 \002(\t\"G\n\007Address\022\020\n\010protocol\030\001 \002(\t\022\016\n" + - "\006system\030\002 \002(\t\022\014\n\004host\030\003 \002(\t\022\014\n\004port\030\004 \002(" + - "\005\"\212\001\n\rInjectFailure\022\032\n\007failure\030\001 \002(\0162\t.F" + - "ailType\022\035\n\tdirection\030\002 \001(\0162\n.Direction\022\031" + - "\n\007address\030\003 \001(\0132\010.Address\022\020\n\010rateMBit\030\006 ", - "\001(\002\022\021\n\texitValue\030\007 \001(\005*A\n\010FailType\022\014\n\010Th" + - "rottle\020\001\022\016\n\nDisconnect\020\002\022\t\n\005Abort\020\003\022\014\n\010S" + - "hutdown\020\004*\"\n\tDirection\022\010\n\004Send\020\001\022\013\n\007Rece" + - "ive\020\002B\035\n\031akka.remote.testconductorH\001" + "tFailure\022\014\n\004done\030\004 \001(\t\"0\n\005Hello\022\014\n\004name\030" + + "\001 \002(\t\022\031\n\007address\030\002 \002(\0132\010.Address\"\034\n\014Ente" + + "rBarrier\022\014\n\004name\030\001 \002(\t\"G\n\007Address\022\020\n\010pro" + + "tocol\030\001 \002(\t\022\016\n\006system\030\002 \002(\t\022\014\n\004host\030\003 \002(" + + "\t\022\014\n\004port\030\004 \002(\005\"\212\001\n\rInjectFailure\022\032\n\007fai" + + "lure\030\001 \002(\0162\t.FailType\022\035\n\tdirection\030\002 \001(\016" + + "2\n.Direction\022\031\n\007address\030\003 \001(\0132\010.Address\022", + "\020\n\010rateMBit\030\006 \001(\002\022\021\n\texitValue\030\007 \001(\005*A\n\010" + + "FailType\022\014\n\010Throttle\020\001\022\016\n\nDisconnect\020\002\022\t" + + "\n\005Abort\020\003\022\014\n\010Shutdown\020\004*,\n\tDirection\022\010\n\004" + + "Send\020\001\022\013\n\007Receive\020\002\022\010\n\004Both\020\003B\035\n\031akka.re" + + "mote.testconductorH\001" }; com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner assigner = new com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner() { @@ -3224,7 +3322,7 @@ public final class TestConductorProtocol { internal_static_Wrapper_fieldAccessorTable = new com.google.protobuf.GeneratedMessage.FieldAccessorTable( internal_static_Wrapper_descriptor, - new java.lang.String[] { "Hello", "Barrier", "Failure", }, + new java.lang.String[] { "Hello", "Barrier", "Failure", "Done", }, akka.remote.testconductor.TestConductorProtocol.Wrapper.class, akka.remote.testconductor.TestConductorProtocol.Wrapper.Builder.class); internal_static_Hello_descriptor = diff --git a/akka-remote/src/main/protocol/TestConductorProtocol.proto b/akka-remote/src/main/protocol/TestConductorProtocol.proto index 213820e687..e483bf4f01 100644 --- a/akka-remote/src/main/protocol/TestConductorProtocol.proto +++ b/akka-remote/src/main/protocol/TestConductorProtocol.proto @@ -15,6 +15,7 @@ message Wrapper { optional Hello hello = 1; optional EnterBarrier barrier = 2; optional InjectFailure failure = 3; + optional string done = 4; } message Hello { @@ -42,6 +43,7 @@ enum FailType { enum Direction { Send = 1; Receive = 2; + Both = 3; } message InjectFailure { required FailType failure = 1; diff --git a/akka-remote/src/main/scala/akka/remote/netty/Client.scala b/akka-remote/src/main/scala/akka/remote/netty/Client.scala index a0e91398fc..cf143650bc 100644 --- a/akka-remote/src/main/scala/akka/remote/netty/Client.scala +++ b/akka-remote/src/main/scala/akka/remote/netty/Client.scala @@ -112,8 +112,6 @@ class ActiveRemoteClient private[akka] ( private var connection: ChannelFuture = _ @volatile private[remote] var openChannels: DefaultChannelGroup = _ - @volatile - private var executionHandler: ExecutionHandler = _ @volatile private var reconnectionTimeWindowStart = 0L @@ -156,9 +154,8 @@ class ActiveRemoteClient private[akka] ( runSwitch switchOn { openChannels = new DefaultDisposableChannelGroup(classOf[RemoteClient].getName) - executionHandler = new ExecutionHandler(netty.executor) val b = new ClientBootstrap(netty.clientChannelFactory) - b.setPipelineFactory(new ActiveRemoteClientPipelineFactory(name, b, executionHandler, remoteAddress, localAddress, this)) + b.setPipelineFactory(netty.mkPipeline(new ActiveRemoteClientHandler(name, b, remoteAddress, localAddress, netty.timer, this), true)) b.setOption("tcpNoDelay", true) b.setOption("keepAlive", true) b.setOption("connectTimeoutMillis", settings.ConnectionTimeout.toMillis) @@ -206,7 +203,6 @@ class ActiveRemoteClient private[akka] ( if (openChannels ne null) openChannels.close.awaitUninterruptibly() } finally { connection = null - executionHandler = null } } @@ -319,31 +315,6 @@ class ActiveRemoteClientHandler( } } -class ActiveRemoteClientPipelineFactory( - name: String, - bootstrap: ClientBootstrap, - executionHandler: ExecutionHandler, - remoteAddress: Address, - localAddress: Address, - client: ActiveRemoteClient) extends ChannelPipelineFactory { - - import client.netty.settings - - def getPipeline: ChannelPipeline = { - val timeout = new IdleStateHandler(client.netty.timer, - settings.ReadTimeout.toSeconds.toInt, - settings.WriteTimeout.toSeconds.toInt, - settings.AllTimeout.toSeconds.toInt) - val lenDec = new LengthFieldBasedFrameDecoder(settings.MessageFrameSize, 0, 4, 0, 4) - val lenPrep = new LengthFieldPrepender(4) - val messageDec = new RemoteMessageDecoder - val messageEnc = new RemoteMessageEncoder(client.netty) - val remoteClient = new ActiveRemoteClientHandler(name, bootstrap, remoteAddress, localAddress, client.netty.timer, client) - - new StaticChannelPipeline(timeout, lenDec, messageDec, lenPrep, messageEnc, executionHandler, remoteClient) - } -} - class PassiveRemoteClient(val currentChannel: Channel, netty: NettyRemoteTransport, remoteAddress: Address) diff --git a/akka-remote/src/main/scala/akka/remote/netty/NettyRemoteSupport.scala b/akka-remote/src/main/scala/akka/remote/netty/NettyRemoteSupport.scala index c3a41f8275..35ef3bf7fd 100644 --- a/akka-remote/src/main/scala/akka/remote/netty/NettyRemoteSupport.scala +++ b/akka-remote/src/main/scala/akka/remote/netty/NettyRemoteSupport.scala @@ -22,6 +22,13 @@ import akka.event.Logging import akka.remote.RemoteProtocol.AkkaRemoteProtocol import akka.remote.{ RemoteTransportException, RemoteTransport, RemoteSettings, RemoteMarshallingOps, RemoteActorRefProvider, RemoteActorRef, RemoteServerStarted } import akka.util.NonFatal +import org.jboss.netty.channel.StaticChannelPipeline +import org.jboss.netty.channel.ChannelHandler +import org.jboss.netty.handler.codec.frame.LengthFieldPrepender +import org.jboss.netty.handler.codec.frame.LengthFieldBasedFrameDecoder +import org.jboss.netty.handler.timeout.IdleStateHandler +import org.jboss.netty.channel.ChannelPipelineFactory +import org.jboss.netty.handler.execution.ExecutionHandler /** * Provides the implementation of the Netty remote support @@ -34,20 +41,54 @@ class NettyRemoteTransport(val remoteSettings: RemoteSettings, val system: Actor // TODO replace by system.scheduler val timer: HashedWheelTimer = new HashedWheelTimer(system.threadFactory) - // TODO make configurable - lazy val executor = new OrderedMemoryAwareThreadPoolExecutor( - settings.ExecutionPoolSize, - settings.MaxChannelMemorySize, - settings.MaxTotalMemorySize, - settings.ExecutionPoolKeepalive.length, - settings.ExecutionPoolKeepalive.unit, - system.threadFactory) - // TODO make configurable/shareable with server socket factory val clientChannelFactory = new NioClientSocketChannelFactory( Executors.newCachedThreadPool(system.threadFactory), Executors.newCachedThreadPool(system.threadFactory)) + object PipelineFactory { + def apply(handlers: Seq[ChannelHandler]): StaticChannelPipeline = new StaticChannelPipeline(handlers: _*) + def apply(endpoint: ⇒ Seq[ChannelHandler], withTimeout: Boolean): ChannelPipelineFactory = + new ChannelPipelineFactory { + def getPipeline = apply(defaultStack(withTimeout) ++ endpoint) + } + + def defaultStack(withTimeout: Boolean): Seq[ChannelHandler] = + (if (withTimeout) timeout :: Nil else Nil) ::: + msgFormat ::: + authenticator ::: + executionHandler :: + Nil + + def timeout = new IdleStateHandler(timer, + settings.ReadTimeout.toSeconds.toInt, + settings.WriteTimeout.toSeconds.toInt, + settings.AllTimeout.toSeconds.toInt) + + def msgFormat = new LengthFieldBasedFrameDecoder(settings.MessageFrameSize, 0, 4, 0, 4) :: + new LengthFieldPrepender(4) :: + new RemoteMessageDecoder :: + new RemoteMessageEncoder(NettyRemoteTransport.this) :: + Nil + + val executionHandler = new ExecutionHandler(new OrderedMemoryAwareThreadPoolExecutor( + settings.ExecutionPoolSize, + settings.MaxChannelMemorySize, + settings.MaxTotalMemorySize, + settings.ExecutionPoolKeepalive.length, + settings.ExecutionPoolKeepalive.unit, + system.threadFactory)) + + def authenticator = if (settings.RequireCookie) new RemoteServerAuthenticationHandler(settings.SecureCookie) :: Nil else Nil + } + + /** + * This method is factored out to provide an extension point in case the + * pipeline shall be changed. It is recommended to use + */ + def mkPipeline(endpoint: ⇒ ChannelHandler, withTimeout: Boolean): ChannelPipelineFactory = + PipelineFactory(Seq(endpoint), withTimeout) + private val remoteClients = new HashMap[Address, RemoteClient] private val clientsLock = new ReentrantReadWriteLock @@ -105,11 +146,7 @@ class NettyRemoteTransport(val remoteSettings: RemoteSettings, val system: Actor try { timer.stop() } finally { - try { - clientChannelFactory.releaseExternalResources() - } finally { - executor.shutdown() - } + clientChannelFactory.releaseExternalResources() } } } diff --git a/akka-remote/src/main/scala/akka/remote/netty/Server.scala b/akka-remote/src/main/scala/akka/remote/netty/Server.scala index ac4289e8ae..f9d4ede1d8 100644 --- a/akka-remote/src/main/scala/akka/remote/netty/Server.scala +++ b/akka-remote/src/main/scala/akka/remote/netty/Server.scala @@ -30,14 +30,12 @@ class NettyRemoteServer(val netty: NettyRemoteTransport) { Executors.newCachedThreadPool(netty.system.threadFactory), Executors.newCachedThreadPool(netty.system.threadFactory)) - private val executionHandler = new ExecutionHandler(netty.executor) - // group of open channels, used for clean-up private val openChannels: ChannelGroup = new DefaultDisposableChannelGroup("akka-remote-server") private val bootstrap = { val b = new ServerBootstrap(factory) - b.setPipelineFactory(makePipeline()) + b.setPipelineFactory(netty.mkPipeline(new RemoteServerHandler(openChannels, netty), false)) b.setOption("backlog", settings.Backlog) b.setOption("tcpNoDelay", true) b.setOption("child.keepAlive", true) @@ -45,8 +43,6 @@ class NettyRemoteServer(val netty: NettyRemoteTransport) { b } - protected def makePipeline(): ChannelPipelineFactory = new RemoteServerPipelineFactory(openChannels, executionHandler, netty) - @volatile private[akka] var channel: Channel = _ @@ -79,26 +75,6 @@ class NettyRemoteServer(val netty: NettyRemoteTransport) { } } -class RemoteServerPipelineFactory( - val openChannels: ChannelGroup, - val executionHandler: ExecutionHandler, - val netty: NettyRemoteTransport) extends ChannelPipelineFactory { - - import netty.settings - - def getPipeline: ChannelPipeline = { - val lenDec = new LengthFieldBasedFrameDecoder(settings.MessageFrameSize, 0, 4, 0, 4) - val lenPrep = new LengthFieldPrepender(4) - val messageDec = new RemoteMessageDecoder - val messageEnc = new RemoteMessageEncoder(netty) - - val authenticator = if (settings.RequireCookie) new RemoteServerAuthenticationHandler(settings.SecureCookie) :: Nil else Nil - val remoteServer = new RemoteServerHandler(openChannels, netty) - val stages: List[ChannelHandler] = lenDec :: messageDec :: lenPrep :: messageEnc :: executionHandler :: authenticator ::: remoteServer :: Nil - new StaticChannelPipeline(stages: _*) - } -} - @ChannelHandler.Sharable class RemoteServerAuthenticationHandler(secureCookie: Option[String]) extends SimpleChannelUpstreamHandler { val authenticated = new AnyRef diff --git a/akka-remote/src/main/scala/akka/remote/testconductor/Conductor.scala b/akka-remote/src/main/scala/akka/remote/testconductor/Conductor.scala index c46e22eb9f..c9cbeadf83 100644 --- a/akka-remote/src/main/scala/akka/remote/testconductor/Conductor.scala +++ b/akka-remote/src/main/scala/akka/remote/testconductor/Conductor.scala @@ -21,6 +21,7 @@ import scala.util.control.NoStackTrace import akka.event.LoggingReceive import akka.actor.Address import java.net.InetSocketAddress +import akka.dispatch.Future trait Conductor extends RunControl with FailureInject { this: TestConductorExt ⇒ @@ -32,55 +33,63 @@ trait Conductor extends RunControl with FailureInject { this: TestConductorExt case x ⇒ x } - override def startController() { + override def startController(): Future[Int] = { if (_controller ne null) throw new RuntimeException("TestConductorServer was already started") _controller = system.actorOf(Props[Controller], "controller") import Settings.BarrierTimeout - startClient(Await.result(controller ? GetPort mapTo, Duration.Inf)) + controller ? GetPort flatMap { case port: Int ⇒ startClient(port) map (_ ⇒ port) } } - override def port: Int = { + override def port: Future[Int] = { import Settings.QueryTimeout - Await.result(controller ? GetPort mapTo, Duration.Inf) + controller ? GetPort mapTo } - override def throttle(node: String, target: String, direction: Direction, rateMBit: Float) { - controller ! Throttle(node, target, direction, rateMBit) - } - - override def blackhole(node: String, target: String, direction: Direction) { - controller ! Throttle(node, target, direction, 0f) - } - - override def disconnect(node: String, target: String) { - controller ! Disconnect(node, target, false) - } - - override def abort(node: String, target: String) { - controller ! Disconnect(node, target, true) - } - - override def shutdown(node: String, exitValue: Int) { - controller ! Terminate(node, exitValue) - } - - override def kill(node: String) { - controller ! Terminate(node, -1) - } - - override def getNodes = { + override def throttle(node: String, target: String, direction: Direction, rateMBit: Double): Future[Done] = { import Settings.QueryTimeout - Await.result(controller ? GetNodes mapTo manifest[List[String]], Duration.Inf) + controller ? Throttle(node, target, direction, rateMBit.toFloat) mapTo } - override def removeNode(node: String) { - controller ! Remove(node) + override def blackhole(node: String, target: String, direction: Direction): Future[Done] = { + import Settings.QueryTimeout + controller ? Throttle(node, target, direction, 0f) mapTo + } + + override def disconnect(node: String, target: String): Future[Done] = { + import Settings.QueryTimeout + controller ? Disconnect(node, target, false) mapTo + } + + override def abort(node: String, target: String): Future[Done] = { + import Settings.QueryTimeout + controller ? Disconnect(node, target, true) mapTo + } + + override def shutdown(node: String, exitValue: Int): Future[Done] = { + import Settings.QueryTimeout + controller ? Terminate(node, exitValue) mapTo + } + + override def kill(node: String): Future[Done] = { + import Settings.QueryTimeout + controller ? Terminate(node, -1) mapTo + } + + override def getNodes: Future[List[String]] = { + import Settings.QueryTimeout + controller ? GetNodes mapTo + } + + override def removeNode(node: String): Future[Done] = { + import Settings.QueryTimeout + controller ? Remove(node) mapTo } } class ConductorHandler(system: ActorSystem, controller: ActorRef, log: LoggingAdapter) extends SimpleChannelUpstreamHandler { + @volatile var clients = Map[Channel, ActorRef]() override def channelConnected(ctx: ChannelHandlerContext, event: ChannelStateEvent) = { @@ -102,7 +111,7 @@ class ConductorHandler(system: ActorSystem, controller: ActorRef, log: LoggingAd val channel = event.getChannel log.debug("message from {}: {}", getAddrString(channel), event.getMessage) event.getMessage match { - case msg: Wrapper if msg.getAllFields.size == 1 ⇒ + case msg: NetworkOp ⇒ clients(channel) ! msg case msg ⇒ log.info("client {} sent garbage '{}', disconnecting", getAddrString(channel), msg) @@ -116,28 +125,26 @@ object ServerFSM { sealed trait State case object Initial extends State case object Ready extends State - - case class Send(msg: Wrapper) } -class ServerFSM(val controller: ActorRef, val channel: Channel) extends Actor with LoggingFSM[ServerFSM.State, Null] { +class ServerFSM(val controller: ActorRef, val channel: Channel) extends Actor with LoggingFSM[ServerFSM.State, Option[ActorRef]] { import ServerFSM._ import akka.actor.FSM._ import Controller._ - startWith(Initial, null) + startWith(Initial, None) when(Initial, stateTimeout = 10 seconds) { - case Event(msg: Wrapper, _) ⇒ - if (msg.hasHello) { - val hello = msg.getHello - controller ! ClientConnected(hello.getName, hello.getAddress) - goto(Ready) - } else { - log.warning("client {} sent no Hello in first message, disconnecting", getAddrString(channel)) - channel.close() - stop() - } + case Event(Hello(name, addr), _) ⇒ + controller ! ClientConnected(name, addr) + goto(Ready) + case Event(x: NetworkOp, _) ⇒ + log.warning("client {} sent no Hello in first message (instead {}), disconnecting", getAddrString(channel), x) + channel.close() + stop() + case Event(Send(msg), _) ⇒ + log.warning("cannot send {} in state Initial", msg) + stay case Event(StateTimeout, _) ⇒ log.info("closing channel to {} because of Hello timeout", getAddrString(channel)) channel.close() @@ -145,20 +152,24 @@ class ServerFSM(val controller: ActorRef, val channel: Channel) extends Actor wi } when(Ready) { - case Event(msg: Wrapper, _) ⇒ - if (msg.hasBarrier) { - val barrier = msg.getBarrier - controller ! EnterBarrier(barrier.getName) - } else { - log.warning("client {} sent unsupported message {}", getAddrString(channel), msg) - } + case Event(msg: EnterBarrier, _) ⇒ + controller ! msg stay - case Event(Send(msg), _) ⇒ + case Event(d: Done, Some(s)) ⇒ + s ! d + stay using None + case Event(msg: NetworkOp, _) ⇒ + log.warning("client {} sent unsupported message {}", getAddrString(channel), msg) + channel.close() + stop() + case Event(Send(msg: EnterBarrier), _) ⇒ channel.write(msg) stay - case Event(EnterBarrier(name), _) ⇒ - val barrier = TestConductorProtocol.EnterBarrier.newBuilder.setName(name).build - channel.write(Wrapper.newBuilder.setBarrier(barrier).build) + case Event(Send(msg), None) ⇒ + channel.write(msg) + stay using Some(sender) + case Event(Send(msg), _) ⇒ + log.warning("cannot send {} while waiting for previous ACK", msg) stay } @@ -185,7 +196,6 @@ class Controller extends Actor { var nodes = Map[String, NodeInfo]() override def receive = LoggingReceive { - case "ready?" ⇒ sender ! "yes" case ClientConnected(name, addr) ⇒ nodes += name -> NodeInfo(name, addr, sender) barrier forward ClientConnected @@ -198,28 +208,15 @@ class Controller extends Actor { barrier forward e case Throttle(node, target, direction, rateMBit) ⇒ val t = nodes(target) - val throttle = - InjectFailure.newBuilder - .setFailure(FailType.Throttle) - .setDirection(TestConductorProtocol.Direction.valueOf(direction.toString)) - .setAddress(t.addr) - .setRateMBit(rateMBit) - .build - nodes(node).fsm ! ServerFSM.Send(Wrapper.newBuilder.setFailure(throttle).build) + nodes(node).fsm forward Send(ThrottleMsg(t.addr, direction, rateMBit)) case Disconnect(node, target, abort) ⇒ val t = nodes(target) - val disconnect = - InjectFailure.newBuilder - .setFailure(if (abort) FailType.Abort else FailType.Disconnect) - .setAddress(t.addr) - .build - nodes(node).fsm ! ServerFSM.Send(Wrapper.newBuilder.setFailure(disconnect).build) + nodes(node).fsm forward Send(DisconnectMsg(t.addr, abort)) case Terminate(node, exitValueOrKill) ⇒ if (exitValueOrKill < 0) { // TODO: kill via SBT } else { - val shutdown = InjectFailure.newBuilder.setFailure(FailType.Shutdown).setExitValue(exitValueOrKill).build - nodes(node).fsm ! ServerFSM.Send(Wrapper.newBuilder.setFailure(shutdown).build) + nodes(node).fsm forward Send(TerminateMsg(exitValueOrKill)) } // TODO: properly remove node from BarrierCoordinator // case Remove(node) => @@ -269,7 +266,7 @@ class BarrierCoordinator extends Actor with LoggingFSM[BarrierCoordinator.State, if (name != barrier) throw new IllegalStateException("trying enter barrier '" + name + "' while barrier '" + barrier + "' is active") val together = sender :: arrived if (together.size == num) { - together foreach (_ ! e) + together foreach (_ ! Send(e)) goto(Idle) using Data(num, "", Nil) } else { stay using d.copy(arrived = together) @@ -280,7 +277,7 @@ class BarrierCoordinator extends Actor with LoggingFSM[BarrierCoordinator.State, val expected = num - 1 if (arrived.size == expected) { val e = EnterBarrier(barrier) - sender :: arrived foreach (_ ! e) + sender :: arrived foreach (_ ! Send(e)) goto(Idle) using Data(expected, "", Nil) } else { stay using d.copy(clients = expected) diff --git a/akka-remote/src/main/scala/akka/remote/testconductor/DataTypes.scala b/akka-remote/src/main/scala/akka/remote/testconductor/DataTypes.scala index 2b54ea1018..90d7eeccd5 100644 --- a/akka-remote/src/main/scala/akka/remote/testconductor/DataTypes.scala +++ b/akka-remote/src/main/scala/akka/remote/testconductor/DataTypes.scala @@ -3,11 +3,82 @@ */ package akka.remote.testconductor -sealed trait ClientOp -sealed trait ServerOp +import org.jboss.netty.handler.codec.oneone.OneToOneEncoder +import org.jboss.netty.channel.ChannelHandlerContext +import org.jboss.netty.channel.Channel +import akka.remote.testconductor.{ TestConductorProtocol ⇒ TCP } +import com.google.protobuf.Message +import akka.actor.Address +import org.jboss.netty.handler.codec.oneone.OneToOneDecoder -case class EnterBarrier(name: String) extends ClientOp with ServerOp +case class Send(msg: NetworkOp) + +sealed trait ClientOp // messages sent to Player FSM +sealed trait ServerOp // messages sent to Conductor FSM +sealed trait NetworkOp // messages sent over the wire + +case class Hello(name: String, addr: Address) extends NetworkOp +case class EnterBarrier(name: String) extends ClientOp with ServerOp with NetworkOp case class Throttle(node: String, target: String, direction: Direction, rateMBit: Float) extends ServerOp +case class ThrottleMsg(target: Address, direction: Direction, rateMBit: Float) extends NetworkOp case class Disconnect(node: String, target: String, abort: Boolean) extends ServerOp +case class DisconnectMsg(target: Address, abort: Boolean) extends NetworkOp case class Terminate(node: String, exitValueOrKill: Int) extends ServerOp +case class TerminateMsg(exitValue: Int) extends NetworkOp +abstract class Done extends NetworkOp +case object Done extends Done { + def getInstance: Done = this +} + case class Remove(node: String) extends ServerOp + +class MsgEncoder extends OneToOneEncoder { + def encode(ctx: ChannelHandlerContext, ch: Channel, msg: AnyRef): AnyRef = msg match { + case x: NetworkOp ⇒ + val w = TCP.Wrapper.newBuilder + x match { + case Hello(name, addr) ⇒ + w.setHello(TCP.Hello.newBuilder.setName(name).setAddress(addr)) + case EnterBarrier(name) ⇒ + w.setBarrier(TCP.EnterBarrier.newBuilder.setName(name)) + case ThrottleMsg(target, dir, rate) ⇒ + w.setFailure(TCP.InjectFailure.newBuilder.setAddress(target) + .setFailure(TCP.FailType.Throttle).setDirection(dir).setRateMBit(rate)) + case DisconnectMsg(target, abort) ⇒ + w.setFailure(TCP.InjectFailure.newBuilder.setAddress(target) + .setFailure(if (abort) TCP.FailType.Abort else TCP.FailType.Disconnect)) + case TerminateMsg(exitValue) ⇒ + w.setFailure(TCP.InjectFailure.newBuilder.setFailure(TCP.FailType.Shutdown).setExitValue(exitValue)) + case _: Done ⇒ + w.setDone("") + } + w.build + case _ ⇒ throw new IllegalArgumentException("wrong message " + msg) + } +} + +class MsgDecoder extends OneToOneDecoder { + def decode(ctx: ChannelHandlerContext, ch: Channel, msg: AnyRef): AnyRef = msg match { + case w: TCP.Wrapper if w.getAllFields.size == 1 ⇒ + if (w.hasHello) { + val h = w.getHello + Hello(h.getName, h.getAddress) + } else if (w.hasBarrier) { + EnterBarrier(w.getBarrier.getName) + } else if (w.hasFailure) { + val f = w.getFailure + import TCP.{ FailType ⇒ FT } + f.getFailure match { + case FT.Throttle ⇒ ThrottleMsg(f.getAddress, f.getDirection, f.getRateMBit) + case FT.Abort ⇒ DisconnectMsg(f.getAddress, true) + case FT.Disconnect ⇒ DisconnectMsg(f.getAddress, false) + case FT.Shutdown ⇒ TerminateMsg(f.getExitValue) + } + } else if (w.hasDone) { + Done + } else { + throw new IllegalArgumentException("unknown message " + msg) + } + case _ ⇒ throw new IllegalArgumentException("wrong message " + msg) + } +} diff --git a/akka-remote/src/main/scala/akka/remote/testconductor/Extension.scala b/akka-remote/src/main/scala/akka/remote/testconductor/Extension.scala index 94847664c9..bffa84847f 100644 --- a/akka-remote/src/main/scala/akka/remote/testconductor/Extension.scala +++ b/akka-remote/src/main/scala/akka/remote/testconductor/Extension.scala @@ -7,9 +7,14 @@ import akka.remote.RemoteActorRefProvider import akka.actor.ActorContext import akka.util.{ Duration, Timeout } import java.util.concurrent.TimeUnit.MILLISECONDS +import akka.actor.ActorRef +import java.util.concurrent.ConcurrentHashMap +import akka.actor.Address object TestConductor extends ExtensionKey[TestConductorExt] { + def apply()(implicit ctx: ActorContext): TestConductorExt = apply(ctx.system) + } class TestConductorExt(val system: ExtendedActorSystem) extends Extension with Conductor with Player { @@ -28,4 +33,6 @@ class TestConductorExt(val system: ExtendedActorSystem) extends Extension with C val transport = system.provider.asInstanceOf[RemoteActorRefProvider].transport val address = transport.address + val failureInjectors = new ConcurrentHashMap[Address, FailureInjector] + } \ No newline at end of file diff --git a/akka-remote/src/main/scala/akka/remote/testconductor/Features.scala b/akka-remote/src/main/scala/akka/remote/testconductor/Features.scala index 930be600c2..b94f205726 100644 --- a/akka-remote/src/main/scala/akka/remote/testconductor/Features.scala +++ b/akka-remote/src/main/scala/akka/remote/testconductor/Features.scala @@ -3,6 +3,8 @@ */ package akka.remote.testconductor +import akka.dispatch.Future + trait BarrierSync { /** * Enter all given barriers in the order in which they were given. @@ -11,9 +13,12 @@ trait BarrierSync { } sealed trait Direction -case object Send extends Direction -case object Receive extends Direction -case object Both extends Direction + +object Direction { + case object Send extends Direction + case object Receive extends Direction + case object Both extends Direction +} trait FailureInject { @@ -21,7 +26,7 @@ trait FailureInject { * Make the remoting pipeline on the node throttle data sent to or received * from the given remote peer. */ - def throttle(node: String, target: String, direction: Direction, rateMBit: Float): Unit + def throttle(node: String, target: String, direction: Direction, rateMBit: Double): Future[Done] /** * Switch the Netty pipeline of the remote support into blackhole mode for @@ -29,56 +34,56 @@ trait FailureInject { * submitting them to the Socket or right after receiving them from the * Socket. */ - def blackhole(node: String, target: String, direction: Direction): Unit + def blackhole(node: String, target: String, direction: Direction): Future[Done] /** * Tell the remote support to shutdown the connection to the given remote * peer. It works regardless of whether the recipient was initiator or * responder. */ - def disconnect(node: String, target: String): Unit + def disconnect(node: String, target: String): Future[Done] /** * Tell the remote support to TCP_RESET the connection to the given remote * peer. It works regardless of whether the recipient was initiator or * responder. */ - def abort(node: String, target: String): Unit + def abort(node: String, target: String): Future[Done] } trait RunControl { /** - * Start the server port. + * Start the server port, returns the port number. */ - def startController(): Unit + def startController(): Future[Int] /** * Get the actual port used by the server. */ - def port: Int + def port: Future[Int] /** * Tell the remote node to shut itself down using System.exit with the given * exitValue. */ - def shutdown(node: String, exitValue: Int): Unit + def shutdown(node: String, exitValue: Int): Future[Done] /** * Tell the SBT plugin to forcibly terminate the given remote node using Process.destroy. */ - def kill(node: String): Unit + def kill(node: String): Future[Done] /** * Obtain the list of remote host names currently registered. */ - def getNodes: List[String] + def getNodes: Future[List[String]] /** * Remove a remote host from the list, so that the remaining nodes may still * pass subsequent barriers. */ - def removeNode(node: String): Unit + def removeNode(node: String): Future[Done] } diff --git a/akka-remote/src/main/scala/akka/remote/testconductor/NetworkFailureInjector.scala b/akka-remote/src/main/scala/akka/remote/testconductor/NetworkFailureInjector.scala index 6569d81acc..30e5308979 100644 --- a/akka-remote/src/main/scala/akka/remote/testconductor/NetworkFailureInjector.scala +++ b/akka-remote/src/main/scala/akka/remote/testconductor/NetworkFailureInjector.scala @@ -1,5 +1,5 @@ /** - * Copyright (C) 2009-2011 Typesafe Inc. + * Copyright (C) 2009-2012 Typesafe Inc. */ package akka.remote.testconductor @@ -9,11 +9,9 @@ import org.jboss.netty.buffer.ChannelBuffer import org.jboss.netty.channel.ChannelState.BOUND import org.jboss.netty.channel.ChannelState.OPEN import org.jboss.netty.channel.Channel -import org.jboss.netty.channel.ChannelDownstreamHandler import org.jboss.netty.channel.ChannelEvent import org.jboss.netty.channel.ChannelHandlerContext import org.jboss.netty.channel.ChannelStateEvent -import org.jboss.netty.channel.ChannelUpstreamHandler import org.jboss.netty.channel.MessageEvent import akka.actor.FSM import akka.actor.Actor @@ -22,23 +20,26 @@ import akka.util.Index import akka.actor.Address import akka.actor.ActorSystem import akka.actor.Props +import akka.actor.ActorRef +import akka.event.Logging +import org.jboss.netty.channel.SimpleChannelHandler -object NetworkFailureInjector { - - val channels = new Index[Address, Channel](16, (c1, c2) ⇒ c1 compareTo c2) - - def close(remote: Address): Unit = { - // channels will be cleaned up by the handler - for (chs ← channels.remove(remote); c ← chs) c.close() +case class FailureInjector(sender: ActorRef, receiver: ActorRef) { + def refs(dir: Direction) = dir match { + case Direction.Send ⇒ Seq(sender) + case Direction.Receive ⇒ Seq(receiver) + case Direction.Both ⇒ Seq(sender, receiver) } } -class NetworkFailureInjector(system: ActorSystem) extends ChannelUpstreamHandler with ChannelDownstreamHandler { +object NetworkFailureInjector { + case class SetRate(rateMBit: Float) + case class Disconnect(abort: Boolean) +} - import NetworkFailureInjector._ +class NetworkFailureInjector(system: ActorSystem) extends SimpleChannelHandler { - // local cache of remote address - private var remote: Option[Address] = None + val log = Logging(system, "FailureInjector") // everything goes via these Throttle actors to enable easy steering private val sender = system.actorOf(Props(new Throttle(_.sendDownstream(_)))) @@ -54,8 +55,8 @@ class NetworkFailureInjector(system: ActorSystem) extends ChannelUpstreamHandler private case class Data(ctx: ChannelHandlerContext, rateMBit: Float, queue: Queue[MessageEvent]) - private case class SetRate(rateMBit: Float) private case class Send(ctx: ChannelHandlerContext, msg: MessageEvent) + private case class SetContext(ctx: ChannelHandlerContext) private case object Tick private class Throttle(send: (ChannelHandlerContext, MessageEvent) ⇒ Unit) extends Actor with FSM[State, Data] { @@ -65,6 +66,7 @@ class NetworkFailureInjector(system: ActorSystem) extends ChannelUpstreamHandler when(PassThrough) { case Event(Send(ctx, msg), d) ⇒ + log.debug("sending msg (PassThrough): {}", msg) send(ctx, msg) stay } @@ -77,26 +79,37 @@ class NetworkFailureInjector(system: ActorSystem) extends ChannelUpstreamHandler stay using d.copy(ctx = ctx, queue = d.queue.enqueue(msg)) case Event(Tick, d) ⇒ val (msg, queue) = d.queue.dequeue + log.debug("sending msg (Tick, {}/{} left): {}", d.queue.size, queue.size, msg) send(d.ctx, msg) - if (queue.nonEmpty) setTimer("send", Tick, (size(queue.head) / d.rateMBit) microseconds, false) + if (queue.nonEmpty) { + val time = (size(queue.head) / d.rateMBit).microseconds + log.debug("scheduling next Tick in {}", time) + setTimer("send", Tick, time, false) + } stay using d.copy(queue = queue) } onTransition { case Throttle -> PassThrough ⇒ - stateData.queue foreach (send(stateData.ctx, _)) + stateData.queue foreach { msg ⇒ + log.debug("sending msg (Transition): {}") + send(stateData.ctx, msg) + } cancelTimer("send") case Throttle -> Blackhole ⇒ cancelTimer("send") } when(Blackhole) { - case Event(Send(_, _), _) ⇒ + case Event(Send(_, msg), _) ⇒ + log.debug("dropping msg {}", msg) stay } whenUnhandled { - case Event(SetRate(rate), d) ⇒ + case Event(SetContext(ctx), d) ⇒ stay using d.copy(ctx = ctx) + case Event(NetworkFailureInjector.SetRate(rate), d) ⇒ + sender ! "ok" if (rate > 0) { goto(Throttle) using d.copy(rateMBit = rate, queue = Queue()) } else if (rate == 0) { @@ -104,6 +117,11 @@ class NetworkFailureInjector(system: ActorSystem) extends ChannelUpstreamHandler } else { goto(PassThrough) } + case Event(NetworkFailureInjector.Disconnect(abort), Data(ctx, _, _)) ⇒ + sender ! "ok" + // TODO implement abort + ctx.getChannel.disconnect() + stay } initialize @@ -114,46 +132,42 @@ class NetworkFailureInjector(system: ActorSystem) extends ChannelUpstreamHandler } } - def throttleSend(rateMBit: Float) { - sender ! SetRate(rateMBit) + private var remote: Option[Address] = None + + override def messageReceived(ctx: ChannelHandlerContext, msg: MessageEvent) { + log.debug("upstream(queued): {}", msg) + receiver ! Send(ctx, msg) } - def throttleReceive(rateMBit: Float) { - receiver ! SetRate(rateMBit) - } - - override def handleUpstream(ctx: ChannelHandlerContext, evt: ChannelEvent) { - evt match { - case msg: MessageEvent ⇒ - receiver ! Send(ctx, msg) - case state: ChannelStateEvent ⇒ - state.getState match { - case BOUND ⇒ - state.getValue match { - case null ⇒ - remote = remote flatMap { a ⇒ channels.remove(a, state.getChannel); None } - case a: InetSocketAddress ⇒ - val addr = Address("akka", "XXX", a.getHostName, a.getPort) - channels.put(addr, state.getChannel) - remote = Some(addr) - } - case OPEN if state.getValue == false ⇒ - remote = remote flatMap { a ⇒ channels.remove(a, state.getChannel); None } + override def channelConnected(ctx: ChannelHandlerContext, state: ChannelStateEvent) { + state.getValue match { + case a: InetSocketAddress ⇒ + val addr = Address("akka", "", a.getHostName, a.getPort) + log.debug("connected to {}", addr) + TestConductor(system).failureInjectors.put(addr, FailureInjector(sender, receiver)) match { + case null ⇒ // okay + case fi ⇒ system.log.error("{} already registered for address {}", fi, addr) } - ctx.sendUpstream(evt) - case _ ⇒ - ctx.sendUpstream(evt) + remote = Some(addr) + sender ! SetContext(ctx) + case x ⇒ throw new IllegalArgumentException("unknown address type: " + x) } } - override def handleDownstream(ctx: ChannelHandlerContext, evt: ChannelEvent) { - evt match { - case msg: MessageEvent ⇒ - sender ! Send(ctx, msg) - case _ ⇒ - ctx.sendUpstream(evt) + override def channelDisconnected(ctx: ChannelHandlerContext, state: ChannelStateEvent) { + log.debug("disconnected from {}", remote) + remote = remote flatMap { addr ⇒ + TestConductor(system).failureInjectors.remove(addr) + system.stop(sender) + system.stop(receiver) + None } } + override def writeRequested(ctx: ChannelHandlerContext, msg: MessageEvent) { + log.debug("downstream(queued): {}", msg) + sender ! Send(ctx, msg) + } + } diff --git a/akka-remote/src/main/scala/akka/remote/testconductor/Player.scala b/akka-remote/src/main/scala/akka/remote/testconductor/Player.scala index 93aa6bc33d..72b15922f3 100644 --- a/akka-remote/src/main/scala/akka/remote/testconductor/Player.scala +++ b/akka-remote/src/main/scala/akka/remote/testconductor/Player.scala @@ -6,20 +6,20 @@ package akka.remote.testconductor import akka.actor.{ Actor, ActorRef, ActorSystem, LoggingFSM, Props } import RemoteConnection.getAddrString import akka.util.duration._ -import TestConductorProtocol._ import org.jboss.netty.channel.{ Channel, SimpleChannelUpstreamHandler, ChannelHandlerContext, ChannelStateEvent, MessageEvent } import com.eaio.uuid.UUID import com.typesafe.config.ConfigFactory import akka.util.Timeout import akka.util.Duration import java.util.concurrent.TimeUnit.MILLISECONDS -import akka.pattern.ask +import akka.pattern.{ ask, pipe } import akka.dispatch.Await import scala.util.control.NoStackTrace import akka.actor.Status import akka.event.LoggingAdapter import akka.actor.PoisonPill import akka.event.Logging +import akka.dispatch.Future trait Player extends BarrierSync { this: TestConductorExt ⇒ @@ -29,7 +29,7 @@ trait Player extends BarrierSync { this: TestConductorExt ⇒ case x ⇒ x } - def startClient(port: Int) { + def startClient(port: Int): Future[Done] = { import ClientFSM._ import akka.actor.FSM._ import Settings.BarrierTimeout @@ -40,21 +40,21 @@ trait Player extends BarrierSync { this: TestConductorExt ⇒ var waiting: ActorRef = _ def receive = { case fsm: ActorRef ⇒ waiting = sender; fsm ! SubscribeTransitionCallBack(self) - case Transition(_, Connecting, Connected) ⇒ waiting ! "okay" + case Transition(_, Connecting, Connected) ⇒ waiting ! Done case t: Transition[_] ⇒ waiting ! Status.Failure(new RuntimeException("unexpected transition: " + t)) - case CurrentState(_, Connected) ⇒ waiting ! "okay" + case CurrentState(_, Connected) ⇒ waiting ! Done case _: CurrentState[_] ⇒ } })) - Await.result(a ? client, Duration.Inf) + a ? client mapTo } override def enter(name: String*) { system.log.debug("entering barriers " + name.mkString("(", ", ", ")")) name foreach { b ⇒ import Settings.BarrierTimeout - Await.result(client ? EnterBarrier(b), Duration.Inf) + Await.result(client ? Send(EnterBarrier(b)), Duration.Inf) system.log.debug("passed barrier {}", b) } } @@ -84,8 +84,7 @@ class ClientFSM(port: Int) extends Actor with LoggingFSM[ClientFSM.State, Client case Event(msg: ClientOp, _) ⇒ stay replying Status.Failure(new IllegalStateException("not connected yet")) case Event(Connected, d @ Data(channel, _)) ⇒ - val hello = Hello.newBuilder.setName(settings.name).setAddress(TestConductor().address).build - channel.write(Wrapper.newBuilder.setHello(hello).build) + channel.write(Hello(settings.name, TestConductor().address)) goto(Connected) case Event(_: ConnectionFailure, _) ⇒ // System.exit(1) @@ -100,19 +99,41 @@ class ClientFSM(port: Int) extends Actor with LoggingFSM[ClientFSM.State, Client case Event(Disconnected, _) ⇒ log.info("disconnected from TestConductor") throw new ConnectionFailure("disconnect") - case Event(msg: EnterBarrier, Data(channel, _)) ⇒ - sendMsg(channel)(msg) + case Event(Send(msg: EnterBarrier), Data(channel, None)) ⇒ + channel.write(msg) stay using Data(channel, Some(msg.name, sender)) - case Event(msg: Wrapper, Data(channel, Some((barrier, sender)))) if msg.getAllFields.size == 1 ⇒ - if (msg.hasBarrier) { - val b = msg.getBarrier.getName - if (b != barrier) { - sender ! Status.Failure(new RuntimeException("wrong barrier " + b + " received while waiting for " + barrier)) - } else { - sender ! b - } + case Event(Send(d: Done), Data(channel, _)) ⇒ + channel.write(d) + stay + case Event(Send(x), _) ⇒ + log.warning("cannot send message {}", x) + stay + case Event(EnterBarrier(b), Data(channel, Some((barrier, sender)))) ⇒ + if (b != barrier) { + sender ! Status.Failure(new RuntimeException("wrong barrier " + b + " received while waiting for " + barrier)) + } else { + sender ! b } stay using Data(channel, None) + case Event(ThrottleMsg(target, dir, rate), _) ⇒ + import settings.QueryTimeout + import context.dispatcher + TestConductor().failureInjectors.get(target.copy(system = "")) match { + case null ⇒ log.warning("cannot throttle unknown address {}", target) + case inj ⇒ + Future.sequence(inj.refs(dir) map (_ ? NetworkFailureInjector.SetRate(rate))) map (_ ⇒ Send(Done)) pipeTo self + } + stay + case Event(DisconnectMsg(target, abort), _) ⇒ + import settings.QueryTimeout + TestConductor().failureInjectors.get(target.copy(system = "")) match { + case null ⇒ log.warning("cannot disconnect unknown address {}", target) + case inj ⇒ inj.sender ? NetworkFailureInjector.Disconnect(abort) map (_ ⇒ Send(Done)) pipeTo self + } + stay + case Event(TerminateMsg(exit), _) ⇒ + System.exit(exit) + stay // needed because Java doesn’t have Nothing } onTermination { @@ -122,14 +143,6 @@ class ClientFSM(port: Int) extends Actor with LoggingFSM[ClientFSM.State, Client initialize - private def sendMsg(channel: Channel)(msg: ClientOp) { - msg match { - case EnterBarrier(name) ⇒ - val enter = TestConductorProtocol.EnterBarrier.newBuilder.setName(name).build - channel.write(Wrapper.newBuilder.setBarrier(enter).build) - } - } - } class PlayerHandler(fsm: ActorRef, log: LoggingAdapter) extends SimpleChannelUpstreamHandler { @@ -152,7 +165,7 @@ class PlayerHandler(fsm: ActorRef, log: LoggingAdapter) extends SimpleChannelUps val channel = event.getChannel log.debug("message from {}: {}", getAddrString(channel), event.getMessage) event.getMessage match { - case msg: Wrapper if msg.getAllFields.size == 1 ⇒ + case msg: NetworkOp ⇒ fsm ! msg case msg ⇒ log.info("server {} sent garbage '{}', disconnecting", getAddrString(channel), msg) diff --git a/akka-remote/src/main/scala/akka/remote/testconductor/RemoteConnection.scala b/akka-remote/src/main/scala/akka/remote/testconductor/RemoteConnection.scala index a92b6295e2..b2f4baebbb 100644 --- a/akka-remote/src/main/scala/akka/remote/testconductor/RemoteConnection.scala +++ b/akka-remote/src/main/scala/akka/remote/testconductor/RemoteConnection.scala @@ -17,7 +17,8 @@ class TestConductorPipelineFactory(handler: ChannelUpstreamHandler) extends Chan def getPipeline: ChannelPipeline = { val encap = List(new LengthFieldPrepender(4), new LengthFieldBasedFrameDecoder(10000, 0, 4, 0, 4)) val proto = List(new ProtobufEncoder, new ProtobufDecoder(TestConductorProtocol.Wrapper.getDefaultInstance)) - new StaticChannelPipeline(encap ::: proto ::: handler :: Nil: _*) + val msg = List(new MsgEncoder, new MsgDecoder) + new StaticChannelPipeline(encap ::: proto ::: msg ::: handler :: Nil: _*) } } diff --git a/akka-remote/src/main/scala/akka/remote/testconductor/package.scala b/akka-remote/src/main/scala/akka/remote/testconductor/package.scala index 8ebeea90a9..b24279dbf6 100644 --- a/akka-remote/src/main/scala/akka/remote/testconductor/package.scala +++ b/akka-remote/src/main/scala/akka/remote/testconductor/package.scala @@ -16,4 +16,16 @@ package object testconductor { implicit def address2scala(addr: TCP.Address): Address = Address(addr.getProtocol, addr.getSystem, addr.getHost, addr.getPort) + implicit def direction2proto(dir: Direction): TCP.Direction = dir match { + case Direction.Send ⇒ TCP.Direction.Send + case Direction.Receive ⇒ TCP.Direction.Receive + case Direction.Both ⇒ TCP.Direction.Both + } + + implicit def direction2scala(dir: TCP.Direction): Direction = dir match { + case TCP.Direction.Send ⇒ Direction.Send + case TCP.Direction.Receive ⇒ Direction.Receive + case TCP.Direction.Both ⇒ Direction.Both + } + } \ No newline at end of file diff --git a/akka-remote/src/multi-jvm/scala/akka/remote/AbstractRemoteActorMultiJvmSpec.scala b/akka-remote/src/multi-jvm/scala/akka/remote/AbstractRemoteActorMultiJvmSpec.scala index ab8bdadae6..ca4313b56b 100644 --- a/akka-remote/src/multi-jvm/scala/akka/remote/AbstractRemoteActorMultiJvmSpec.scala +++ b/akka-remote/src/multi-jvm/scala/akka/remote/AbstractRemoteActorMultiJvmSpec.scala @@ -1,6 +1,7 @@ package akka.remote import com.typesafe.config.{Config, ConfigFactory} +import akka.actor.Address trait AbstractRemoteActorMultiJvmSpec { def NrOfNodes: Int @@ -8,7 +9,6 @@ trait AbstractRemoteActorMultiJvmSpec { def PortRangeStart = 1990 def NodeRange = 1 to NrOfNodes - def PortRange = PortRangeStart to NrOfNodes private[this] val remotes: IndexedSeq[String] = { val nodesOpt = Option(AkkaRemoteSpec.testNodes).map(_.split(",").toIndexedSeq) diff --git a/akka-remote/src/multi-jvm/scala/akka/remote/testconductor/TestConductorSpec.scala b/akka-remote/src/multi-jvm/scala/akka/remote/testconductor/TestConductorSpec.scala index cae2917577..096d4c5a89 100644 --- a/akka-remote/src/multi-jvm/scala/akka/remote/testconductor/TestConductorSpec.scala +++ b/akka-remote/src/multi-jvm/scala/akka/remote/testconductor/TestConductorSpec.scala @@ -3,12 +3,24 @@ package akka.remote.testconductor import akka.remote.AkkaRemoteSpec import com.typesafe.config.ConfigFactory import akka.remote.AbstractRemoteActorMultiJvmSpec +import akka.actor.Props +import akka.actor.Actor +import akka.dispatch.Await +import akka.dispatch.Await.Awaitable +import akka.util.Duration +import akka.util.duration._ +import akka.testkit.ImplicitSender object TestConductorMultiJvmSpec extends AbstractRemoteActorMultiJvmSpec { override def NrOfNodes = 2 override def commonConfig = ConfigFactory.parseString(""" akka.loglevel = DEBUG akka.actor.provider = akka.remote.RemoteActorRefProvider + akka.remote { + transport = akka.remote.testconductor.TestConductorTransport + log-received-messages = on + log-sent-messages = on + } akka.actor.debug { receive = on fsm = on @@ -19,34 +31,87 @@ object TestConductorMultiJvmSpec extends AbstractRemoteActorMultiJvmSpec { } """) def nameConfig(n: Int) = ConfigFactory.parseString("akka.testconductor.name = node" + n).withFallback(nodeConfigs(n)) + + implicit def awaitHelper[T](w: Awaitable[T]) = new AwaitHelper(w) + class AwaitHelper[T](w: Awaitable[T]) { + def await: T = Await.result(w, Duration.Inf) + } } -import TestConductorMultiJvmSpec._ +class TestConductorMultiJvmNode1 extends AkkaRemoteSpec(TestConductorMultiJvmSpec.nameConfig(0)) { -class TestConductorMultiJvmNode1 extends AkkaRemoteSpec(nameConfig(0)) { + import TestConductorMultiJvmSpec._ - val nodes = TestConductorMultiJvmSpec.NrOfNodes + val nodes = NrOfNodes - "running a test" in { - val tc = TestConductor(system) - tc.startController() + val tc = TestConductor(system) + + val echo = system.actorOf(Props(new Actor { + def receive = { + case x ⇒ testActor ! x; sender ! x + } + }), "echo") + + "running a test with barrier" in { + tc.startController().await barrier("start") barrier("first") tc.enter("begin") barrier("end") } + + "throttling" in { + expectMsg("start") + tc.throttle("node1", "node0", Direction.Send, 0.016).await + tc.enter("throttled_send") + within(1 second, 2 seconds) { + receiveN(10) must be(0 to 9) + } + tc.enter("throttled_send2") + tc.throttle("node1", "node0", Direction.Send, -1).await + + tc.throttle("node1", "node0", Direction.Receive, 0.016).await + tc.enter("throttled_recv") + receiveN(10, 500 millis) must be(10 to 19) + tc.enter("throttled_recv2") + tc.throttle("node1", "node0", Direction.Receive, -1).await + } } -class TestConductorMultiJvmNode2 extends AkkaRemoteSpec(nameConfig(1)) { +class TestConductorMultiJvmNode2 extends AkkaRemoteSpec(TestConductorMultiJvmSpec.nameConfig(1)) with ImplicitSender { - val nodes = TestConductorMultiJvmSpec.NrOfNodes + import TestConductorMultiJvmSpec._ - "running a test" in { + val nodes = NrOfNodes + + val tc = TestConductor(system) + + val echo = system.actorFor("akka://" + akkaSpec(0) + "/user/echo") + + "running a test with barrier" in { barrier("start") - val tc = TestConductor(system) - tc.startClient(4712) + tc.startClient(4712).await barrier("first") tc.enter("begin") barrier("end") } + + "throttling" in { + echo ! "start" + expectMsg("start") + tc.enter("throttled_send") + for (i <- 0 to 9) echo ! i + expectMsg(500 millis, 0) + within(1 second, 2 seconds) { + receiveN(9) must be(1 to 9) + } + tc.enter("throttled_send2", "throttled_recv") + for (i <- 10 to 19) echo ! i + expectMsg(500 millis, 10) + within(1 second, 2 seconds) { + receiveN(9) must be(11 to 19) + } + tc.enter("throttled_recv2") + } + }