diff --git a/akka-remote/src/main/scala/akka/remote/netty/NettyRemoteSupport.scala b/akka-remote/src/main/scala/akka/remote/netty/NettyRemoteSupport.scala index a51022c95f..f3f14ca5e1 100644 --- a/akka-remote/src/main/scala/akka/remote/netty/NettyRemoteSupport.scala +++ b/akka-remote/src/main/scala/akka/remote/netty/NettyRemoteSupport.scala @@ -181,9 +181,15 @@ abstract class RemoteClient private[akka] ( /** * Returns an array with the current pending messages not yet delivered. */ - def pendingMessages: Array[Any] = pendingRequests - .toArray.asInstanceOf[Array[(Boolean, Uuid, RemoteMessageProtocol)]] - .map(req => MessageSerializer.deserialize(req._3.getMessage)) + def pendingMessages: Array[Any] = { + var messages = Vector[Any]() + val iter = pendingRequests.iterator + while (iter.hasNext) { + val (_, _, message) = iter.next + messages = messages :+ MessageSerializer.deserialize(message.getMessage) + } + messages.toArray + } /** * Converts the message to the wireprotocol and sends the message across the wire @@ -222,16 +228,43 @@ abstract class RemoteClient private[akka] ( senderFuture: Option[CompletableFuture[T]]): Option[CompletableFuture[T]] = { if (isRunning) { if (request.getOneWay) { - pendingRequests.add((true, null, request)) - sendPendingMessages() + try { + val future = currentChannel.write(RemoteEncoder.encode(request)) + future.awaitUninterruptibly() + if (!future.isCancelled && !future.isSuccess) { + notifyListeners(RemoteClientWriteFailed(request, future.getCause, module, remoteAddress)) + throw future.getCause + } + } catch { + case e: Throwable => + pendingRequests.add((true, null, request)) // add the request to the tx log after a failing send + notifyListeners(RemoteClientError(e, module, remoteAddress)) + throw e + } None } else { val futureResult = if (senderFuture.isDefined) senderFuture.get else new DefaultCompletableFuture[T](request.getActorInfo.getTimeout) val futureUuid = uuidFrom(request.getUuid.getHigh, request.getUuid.getLow) futures.put(futureUuid, futureResult) // Add future prematurely, remove it if write fails - pendingRequests.add((false, futureUuid, request)) - sendPendingMessages() + + def handleRequestReplyError(future: ChannelFuture) = { + pendingRequests.add((false, futureUuid, request)) // Add the request to the tx log after a failing send + val f = futures.remove(futureUuid) // Clean up future + if (f ne null) f.completeWithException(future.getCause) + notifyListeners(RemoteClientWriteFailed(request, future.getCause, module, remoteAddress)) + } + + var future: ChannelFuture = null + try { + // try to send the original one + future = currentChannel.write(RemoteEncoder.encode(request)) + future.awaitUninterruptibly() + if (future.isCancelled) futures.remove(futureUuid) // Clean up future + else if (!future.isSuccess) handleRequestReplyError(future) + } catch { + case e: Exception => handleRequestReplyError(future) + } Some(futureResult) } } else { @@ -241,10 +274,12 @@ abstract class RemoteClient private[akka] ( } } - private[remote] def sendPendingMessages() = { - var pendingMessage = pendingRequests.peek // try to grab first message - while (pendingMessage ne null) { - val (isOneWay, futureUuid, message) = pendingMessage + private[remote] def sendPendingRequests() = pendingRequests synchronized { // ensure only one thread at a time can flush the log + val nrOfMessages = pendingRequests.size + if (nrOfMessages > 0) EventHandler.info(this, "Resending [%s] previously failed messages after remote client reconnect" format nrOfMessages) + var pendingRequest = pendingRequests.peek + while (pendingRequest ne null) { + val (isOneWay, futureUuid, message) = pendingRequest if (isOneWay) { // sendOneWay val future = currentChannel.write(RemoteEncoder.encode(message)) future.awaitUninterruptibly() @@ -255,16 +290,15 @@ abstract class RemoteClient private[akka] ( } else { // sendRequestReply val future = currentChannel.write(RemoteEncoder.encode(message)) future.awaitUninterruptibly() - if (future.isCancelled) { - futures.remove(futureUuid) // Clean up future - } else if (!future.isSuccess) { + if (future.isCancelled) futures.remove(futureUuid) // Clean up future + else if (!future.isSuccess) { val f = futures.remove(futureUuid) // Clean up future if (f ne null) f.completeWithException(future.getCause) notifyListeners(RemoteClientWriteFailed(message, future.getCause, module, remoteAddress)) } } - pendingRequests.remove(pendingMessage) // message delivered; remove from tx log - pendingMessage = pendingRequests.peek // try to grab next message + pendingRequests.remove(pendingRequest) + pendingRequest = pendingRequests.peek // try to grab next message } } @@ -474,9 +508,16 @@ class ActiveRemoteClientHandler( } override def channelConnected(ctx: ChannelHandlerContext, event: ChannelStateEvent) = { - client.sendPendingMessages() // try to send pending message (still there after client/server crash ard reconnect - client.notifyListeners(RemoteClientConnected(client.module, client.remoteAddress)) - client.resetReconnectionTimeWindow + try { + client.sendPendingRequests() // try to send pending requests (still there after client/server crash ard reconnect + client.notifyListeners(RemoteClientConnected(client.module, client.remoteAddress)) + client.resetReconnectionTimeWindow + } catch { + case e: Throwable => + EventHandler.error(e, this, e.getMessage) + client.notifyListeners(RemoteClientError(e, client.module, client.remoteAddress)) + throw e + } } override def channelDisconnected(ctx: ChannelHandlerContext, event: ChannelStateEvent) = { diff --git a/akka-remote/src/test/scala/remote/AkkaRemoteTest.scala b/akka-remote/src/test/scala/remote/AkkaRemoteTest.scala index 0c7421df0a..19f3313536 100644 --- a/akka-remote/src/test/scala/remote/AkkaRemoteTest.scala +++ b/akka-remote/src/test/scala/remote/AkkaRemoteTest.scala @@ -1,13 +1,13 @@ package akka.actor.remote -import org.scalatest.WordSpec import org.scalatest.matchers.MustMatchers -import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} import org.scalatest.junit.JUnitRunner import org.junit.runner.RunWith import akka.remote.netty.NettyRemoteSupport import akka.actor. {Actor, ActorRegistry} import java.util.concurrent. {TimeUnit, CountDownLatch} +import org.scalatest.{Spec, WordSpec, BeforeAndAfterAll, BeforeAndAfterEach} +import java.util.concurrent.atomic.AtomicBoolean object AkkaRemoteTest { class ReplyHandlerActor(latch: CountDownLatch, expect: String) extends Actor { @@ -59,4 +59,89 @@ class AkkaRemoteTest extends /* Utilities */ def replyHandler(latch: CountDownLatch, expect: String) = Some(Actor.actorOf(new ReplyHandlerActor(latch, expect)).start) +} + +trait NetworkFailureTest { self: WordSpec => + import akka.actor.Actor._ + import akka.util.Duration + + val BYTES_PER_SECOND = "60KByte/s" + val DELAY_MILLIS = "350ms" + val PORT_RANGE = "1024-65535" + + def replyWithTcpResetFor(duration: Duration, dead: AtomicBoolean) = { + spawn { + try { + enableTcpReset() + println("===>>> Reply with [TCP RST] for [" + duration + "]") + Thread.sleep(duration.toMillis) + restoreIP + } catch { + case e => + dead.set(true) + e.printStackTrace + } + } + } + + def throttleNetworkFor(duration: Duration, dead: AtomicBoolean) = { + spawn { + try { + enableNetworkThrottling() + println("===>>> Throttling network with [" + BYTES_PER_SECOND + ", " + DELAY_MILLIS + "] for [" + duration + "]") + Thread.sleep(duration.toMillis) + restoreIP + } catch { + case e => + dead.set(true) + e.printStackTrace + } + } + } + + def dropNetworkFor(duration: Duration, dead: AtomicBoolean) = { + spawn { + try { + enableNetworkDrop() + println("===>>> Blocking network [TCP DENY] for [" + duration + "]") + Thread.sleep(duration.toMillis) + restoreIP + } catch { + case e => + dead.set(true) + e.printStackTrace + } + } + } + + def sleepFor(duration: Duration) = { + println("===>>> Sleeping for [" + duration + "]") + Thread sleep (duration.toMillis) + } + + def enableNetworkThrottling() = { + restoreIP() + assert(new ProcessBuilder("ipfw", "add", "pipe", "1", "ip", "from", "any", "to", "any").start.waitFor == 0) + assert(new ProcessBuilder("ipfw", "add", "pipe", "2", "ip", "from", "any", "to", "any").start.waitFor == 0) + assert(new ProcessBuilder("ipfw", "pipe", "1", "config", "bw", BYTES_PER_SECOND, "delay", DELAY_MILLIS).start.waitFor == 0) + assert(new ProcessBuilder("ipfw", "pipe", "2", "config", "bw", BYTES_PER_SECOND, "delay", DELAY_MILLIS).start.waitFor == 0) + } + + def enableNetworkDrop() = { + restoreIP() + assert(new ProcessBuilder("ipfw", "add", "1", "deny", "tcp", "from", "any", "to", "any", PORT_RANGE).start.waitFor == 0) + } + + def enableTcpReset() = { + restoreIP() + assert(new ProcessBuilder("ipfw", "add", "1", "reset", "tcp", "from", "any", "to", "any", PORT_RANGE).start.waitFor == 0) + } + + def restoreIP() = { + println("===>>> Restoring network") + assert(new ProcessBuilder("ipfw", "del", "pipe", "1").start.waitFor == 0) + assert(new ProcessBuilder("ipfw", "del", "pipe", "2").start.waitFor == 0) + assert(new ProcessBuilder("ipfw", "flush").start.waitFor == 0) + assert(new ProcessBuilder("ipfw", "pipe", "flush").start.waitFor == 0) + } } \ No newline at end of file diff --git a/akka-remote/src/test/scala/remote/RemoteErrorHandlingTest.scala b/akka-remote/src/test/scala/remote/RemoteErrorHandlingTest.scala new file mode 100644 index 0000000000..f3b89172e3 --- /dev/null +++ b/akka-remote/src/test/scala/remote/RemoteErrorHandlingTest.scala @@ -0,0 +1,117 @@ +package akka.actor.remote + +import java.util.concurrent.{CountDownLatch, TimeUnit} + +import akka.actor.Actor._ +import akka.actor.{ActorRef, Actor} +import akka.util.duration._ +import java.util.concurrent.atomic.AtomicBoolean + +object RemoteErrorHandlingTest { + case class Send(actor: ActorRef) + + class RemoteActorSpecActorUnidirectional extends Actor { + self.id = "network-drop:unidirectional" + def receive = { + case "Ping" => self.reply_?("Pong") + } + } + + class Decrementer extends Actor { + def receive = { + case "done" => self.reply_?(false) + case i: Int if i > 0 => + self.reply_?(i - 1) + case i: Int => + self.reply_?(0) + this become { + case "done" => self.reply_?(true) + case _ => //Do Nothing + } + } + } + + class RemoteActorSpecActorBidirectional extends Actor { + + def receive = { + case "Hello" => + self.reply("World") + case "Failure" => + throw new RuntimeException("Expected exception; to test fault-tolerance") + } + } + + class RemoteActorSpecActorAsyncSender(latch: CountDownLatch) extends Actor { + def receive = { + case Send(actor: ActorRef) => + actor ! "Hello" + case "World" => latch.countDown + } + } +} + +class RemoteErrorHandlingTest extends AkkaRemoteTest with NetworkFailureTest { + import RemoteErrorHandlingTest._ + + "Remote actors" should { + + "be able to recover from network drop without loosing any messages" in { + val latch = new CountDownLatch(10) + implicit val sender = replyHandler(latch, "Pong") + val service = actorOf[RemoteActorSpecActorUnidirectional] + remote.register(service.id, service) + val actor = remote.actorFor(service.id, 5000L, host, port) + actor ! "Ping" + actor ! "Ping" + actor ! "Ping" + actor ! "Ping" + actor ! "Ping" + val dead = new AtomicBoolean(false) + dropNetworkFor (10 seconds, dead) // drops the network - in another thread - so async + sleepFor (2 seconds) // wait until network drop is done before sending the other messages + try { actor ! "Ping" } catch { case e => () } // queue up messages + try { actor ! "Ping" } catch { case e => () } // ... + try { actor ! "Ping" } catch { case e => () } // ... + try { actor ! "Ping" } catch { case e => () } // ... + try { actor ! "Ping" } catch { case e => () } // ... + latch.await(15, TimeUnit.SECONDS) must be (true) // network should be restored and the messages delivered + dead.get must be (false) + } + + "be able to recover from TCP RESET without loosing any messages" in { + val latch = new CountDownLatch(10) + implicit val sender = replyHandler(latch, "Pong") + val service = actorOf[RemoteActorSpecActorUnidirectional] + remote.register(service.id, service) + val actor = remote.actorFor(service.id, 5000L, host, port) + actor ! "Ping" + actor ! "Ping" + actor ! "Ping" + actor ! "Ping" + actor ! "Ping" + val dead = new AtomicBoolean(false) + replyWithTcpResetFor (10 seconds, dead) + sleepFor (2 seconds) + try { actor ! "Ping" } catch { case e => () } // queue up messages + try { actor ! "Ping" } catch { case e => () } // ... + try { actor ! "Ping" } catch { case e => () } // ... + try { actor ! "Ping" } catch { case e => () } // ... + try { actor ! "Ping" } catch { case e => () } // ... + latch.await(15, TimeUnit.SECONDS) must be (true) + dead.get must be (false) + } +/* + "sendWithBangAndGetReplyThroughSenderRef" in { + remote.register(actorOf[RemoteActorSpecActorBidirectional]) + implicit val timeout = 500000000L + val actor = remote.actorFor( + "akka.actor.remote.ServerInitiatedRemoteActorSpec$RemoteActorSpecActorBidirectional", timeout, host, port) + val latch = new CountDownLatch(1) + val sender = actorOf( new RemoteActorSpecActorAsyncSender(latch) ).start + sender ! Send(actor) + latch.await(1, TimeUnit.SECONDS) must be (true) + } + */ + } +} + diff --git a/project/build/AkkaProject.scala b/project/build/AkkaProject.scala index 5a04b66627..4b59c401cb 100644 --- a/project/build/AkkaProject.scala +++ b/project/build/AkkaProject.scala @@ -346,6 +346,7 @@ class AkkaParentProject(info: ProjectInfo) extends DefaultProject(info) { val junit = Dependencies.junit val scalatest = Dependencies.scalatest + override def testOptions = createTestFilter( _.endsWith("Spec")) override def bndImportPackage = "javax.transaction;version=1.1" :: super.bndImportPackage.toList }