diff --git a/akka-testkit/src/main/scala/akka/testkit/TestKit.scala b/akka-testkit/src/main/scala/akka/testkit/TestKit.scala index bcac5c24cf..cbcfc2a77d 100644 --- a/akka-testkit/src/main/scala/akka/testkit/TestKit.scala +++ b/akka-testkit/src/main/scala/akka/testkit/TestKit.scala @@ -69,7 +69,7 @@ class TestActor(queue: BlockingDeque[TestActor.Message]) extends Actor { *
  * class Test extends TestKit(ActorSystem()) {
  *     try {
- *     
+ *
  *       val test = system.actorOf(Props[SomeActor]
  *
  *       within (1 second) {
@@ -77,7 +77,7 @@ class TestActor(queue: BlockingDeque[TestActor.Message]) extends Actor {
  *         expectMsg(Result1) // bounded to 1 second
  *         expectMsg(Result2) // bounded to the remainder of the 1 second
  *       }
- *     
+ *
  *     } finally {
  *       system.shutdown()
  *     }
@@ -86,7 +86,7 @@ class TestActor(queue: BlockingDeque[TestActor.Message]) extends Actor {
  *
  * Beware of two points:
  *
- *  - the ActorSystem passed into the constructor needs to be shutdown, 
+ *  - the ActorSystem passed into the constructor needs to be shutdown,
  *    otherwise thread pools and memory will be leaked
  *  - this trait is not thread-safe (only one actor with one queue, one stack
  *    of `within` blocks); it is expected that the code is executed from a
diff --git a/akka-zeromq/src/main/scala/akka/zeromq/ConcurrentSocketActor.scala b/akka-zeromq/src/main/scala/akka/zeromq/ConcurrentSocketActor.scala
index 82a07d7aa3..6626576089 100644
--- a/akka-zeromq/src/main/scala/akka/zeromq/ConcurrentSocketActor.scala
+++ b/akka-zeromq/src/main/scala/akka/zeromq/ConcurrentSocketActor.scala
@@ -12,16 +12,12 @@ import annotation.tailrec
 import akka.util.Duration
 import java.util.concurrent.TimeUnit
 
-private[zeromq] sealed trait PollLifeCycle
-private[zeromq] case object NoResults extends PollLifeCycle
-private[zeromq] case object Results extends PollLifeCycle
-private[zeromq] case object Closing extends PollLifeCycle
-
 private[zeromq] object ConcurrentSocketActor {
-  private case object Poll
-  private case object ReceiveFrames
-  private case object ClearPoll
-  private case class PollError(ex: Throwable)
+  private trait PollMsg
+  private case object Poll extends PollMsg
+  private case class ContinuePoll(frames: Vector[Frame]) extends PollMsg {
+    println("continue")
+  }
 
   private class NoSocketHandleException() extends Exception("Couldn't create a zeromq socket.")
 
@@ -31,45 +27,41 @@ private[zeromq] class ConcurrentSocketActor(params: Seq[SocketOption]) extends A
 
   import ConcurrentSocketActor._
   private val noBytes = Array[Byte]()
-  private val zmqContext = {
-    params collectFirst { case c: Context ⇒ c } getOrElse DefaultContext
-  }
+  private val zmqContext = params collectFirst { case c: Context ⇒ c } getOrElse DefaultContext
 
   private val deserializer = deserializerFromParams
   private val socket: Socket = socketFromParams
   private val poller: Poller = zmqContext.poller
   private val log = Logging(context.system, this)
 
-  private def handleConnectionMessages: Receive = {
-    case Send(frames) ⇒ {
-      sendFrames(frames)
-      pollAndReceiveFrames()
-    }
-    case ZMQMessage(frames) ⇒ {
-      sendFrames(frames)
-      pollAndReceiveFrames()
-    }
-    case Connect(endpoint) ⇒ {
-      socket.connect(endpoint)
-      notifyListener(Connecting)
-      pollAndReceiveFrames()
-    }
-    case Bind(endpoint) ⇒ {
-      socket.bind(endpoint)
-      pollAndReceiveFrames()
-    }
-    case Subscribe(topic) ⇒ {
-      socket.subscribe(topic.toArray)
-      pollAndReceiveFrames()
-    }
-    case Unsubscribe(topic) ⇒ {
-      socket.unsubscribe(topic.toArray)
-      pollAndReceiveFrames()
-    }
-    case Terminated(_) ⇒ context stop self
+  def receive = {
+    case Poll                 ⇒ doPoll()
+    case ContinuePoll(frames) ⇒ doPoll(currentFrames = frames)
+    case ZMQMessage(frames)   ⇒ sendMessage(frames)
+    case r: Request           ⇒ handleRequest(r)
+    case Terminated(_)        ⇒ context stop self
   }
 
-  private def handleSocketOption: Receive = {
+  private def handleRequest(msg: Request): Unit = msg match {
+    case Send(frames)         ⇒ sendMessage(frames)
+    case opt: SocketOption    ⇒ handleSocketOption(opt)
+    case q: SocketOptionQuery ⇒ handleSocketOptionQuery(q)
+  }
+
+  private def handleConnectOption(msg: SocketConnectOption): Unit = msg match {
+    case Connect(endpoint) ⇒ socket.connect(endpoint); notifyListener(Connecting)
+    case Bind(endpoint)    ⇒ socket.bind(endpoint)
+  }
+
+  private def handlePubSubOption(msg: PubSubOption): Unit = msg match {
+    case Subscribe(topic)   ⇒ socket.subscribe(topic.toArray)
+    case Unsubscribe(topic) ⇒ socket.unsubscribe(topic.toArray)
+  }
+
+  private def handleSocketOption(msg: SocketOption): Unit = msg match {
+    case x: SocketMeta               ⇒ throw new IllegalStateException("SocketMeta " + x + " only allowed for setting up a socket")
+    case c: SocketConnectOption      ⇒ handleConnectOption(c)
+    case ps: PubSubOption            ⇒ handlePubSubOption(ps)
     case Linger(value)               ⇒ socket.setLinger(value)
     case ReconnectIVL(value)         ⇒ socket.setReconnectIVL(value)
     case Backlog(value)              ⇒ socket.setBacklog(value)
@@ -87,51 +79,34 @@ private[zeromq] class ConcurrentSocketActor(params: Seq[SocketOption]) extends A
     case MulticastHops(value)        ⇒ socket.setMulticastHops(value)
     case SendBufferSize(value)       ⇒ socket.setSendBufferSize(value)
     case ReceiveBufferSize(value)    ⇒ socket.setReceiveBufferSize(value)
-    case Linger                      ⇒ sender ! socket.getLinger
-    case ReconnectIVL                ⇒ sender ! socket.getReconnectIVL
-    case Backlog                     ⇒ sender ! socket.getBacklog
-    case ReconnectIVLMax             ⇒ sender ! socket.getReconnectIVLMax
-    case MaxMsgSize                  ⇒ sender ! socket.getMaxMsgSize
-    case SendHighWatermark           ⇒ sender ! socket.getSndHWM
-    case ReceiveHighWatermark        ⇒ sender ! socket.getRcvHWM
-    case Swap                        ⇒ sender ! socket.getSwap
-    case Affinity                    ⇒ sender ! socket.getAffinity
-    case Identity                    ⇒ sender ! socket.getIdentity
-    case Rate                        ⇒ sender ! socket.getRate
-    case RecoveryInterval            ⇒ sender ! socket.getRecoveryInterval
-    case MulticastLoop               ⇒ sender ! socket.hasMulticastLoop
-    case MulticastHops               ⇒ sender ! socket.getMulticastHops
-    case SendBufferSize              ⇒ sender ! socket.getSendBufferSize
-    case ReceiveBufferSize           ⇒ sender ! socket.getReceiveBufferSize
-    case FileDescriptor              ⇒ sender ! socket.getFD
   }
 
-  private def internalMessage: Receive = {
-    case Poll ⇒ {
-      currentPoll = None
-      pollAndReceiveFrames()
-    }
-    case ReceiveFrames ⇒ {
-      receiveFrames() match {
-        case Seq()  ⇒
-        case frames ⇒ notifyListener(deserializer(frames))
-      }
-      self ! Poll
-    }
-    case ClearPoll ⇒ currentPoll = None
-    case PollError(ex) ⇒ {
-      log.error(ex, "There was a problem polling the zeromq socket")
-      self ! Poll
-    }
+  private def handleSocketOptionQuery(msg: SocketOptionQuery): Unit = msg match {
+    case Linger               ⇒ sender ! socket.getLinger
+    case ReconnectIVL         ⇒ sender ! socket.getReconnectIVL
+    case Backlog              ⇒ sender ! socket.getBacklog
+    case ReconnectIVLMax      ⇒ sender ! socket.getReconnectIVLMax
+    case MaxMsgSize           ⇒ sender ! socket.getMaxMsgSize
+    case SendHighWatermark    ⇒ sender ! socket.getSndHWM
+    case ReceiveHighWatermark ⇒ sender ! socket.getRcvHWM
+    case Swap                 ⇒ sender ! socket.getSwap
+    case Affinity             ⇒ sender ! socket.getAffinity
+    case Identity             ⇒ sender ! socket.getIdentity
+    case Rate                 ⇒ sender ! socket.getRate
+    case RecoveryInterval     ⇒ sender ! socket.getRecoveryInterval
+    case MulticastLoop        ⇒ sender ! socket.hasMulticastLoop
+    case MulticastHops        ⇒ sender ! socket.getMulticastHops
+    case SendBufferSize       ⇒ sender ! socket.getSendBufferSize
+    case ReceiveBufferSize    ⇒ sender ! socket.getReceiveBufferSize
+    case FileDescriptor       ⇒ sender ! socket.getFD
   }
 
-  override def receive: Receive = handleConnectionMessages orElse handleSocketOption orElse internalMessage
-
   override def preStart {
     watchListener()
     setupSocket()
     poller.register(socket, Poller.POLLIN)
     setupConnection()
+    self ! Poll
   }
 
   private def setupConnection() {
@@ -159,7 +134,6 @@ private[zeromq] class ConcurrentSocketActor(params: Seq[SocketOption]) extends A
 
   override def postStop {
     try {
-      currentPoll foreach { _ complete Right(Closing) }
       poller.unregister(socket)
       if (socket != null) socket.close
     } finally {
@@ -167,52 +141,58 @@ private[zeromq] class ConcurrentSocketActor(params: Seq[SocketOption]) extends A
     }
   }
 
-  private def sendFrames(frames: Seq[Frame]) {
+  private def sendMessage(frames: Seq[Frame]) {
     def sendBytes(bytes: Seq[Byte], flags: Int) = socket.send(bytes.toArray, flags)
     val iter = frames.iterator
     while (iter.hasNext) {
       val payload = iter.next.payload
       val flags = if (iter.hasNext) JZMQ.SNDMORE else 0
       sendBytes(payload, flags)
+      Thread.sleep(5)
     }
   }
 
-  private var currentPoll: Option[Promise[PollLifeCycle]] = None
-  private def pollAndReceiveFrames() {
-    if (currentPoll.isEmpty) currentPoll = newEventLoop
-  }
-
-  private val eventLoopDispatcher = {
-    val fromConfig = params collectFirst { case PollDispatcher(name) ⇒ context.system.dispatchers.lookup(name) }
-    fromConfig getOrElse context.system.dispatcher
-  }
-
-  private val pollTimeout = {
+  // this is a “PollMsg=>Unit” which either polls or schedules Poll, depending on the sign of the timeout
+  private val doPollTimeout = {
     val fromConfig = params collectFirst { case PollTimeoutDuration(duration) ⇒ duration }
-    fromConfig getOrElse ZeroMQExtension(context.system).DefaultPollTimeout
+    val duration = fromConfig getOrElse ZeroMQExtension(context.system).DefaultPollTimeout
+    if (duration > Duration.Zero) { (msg: PollMsg) ⇒
+      // for positive timeout values, do poll (i.e. block this thread)
+      poller.poll(duration.toMicros)
+      self ! msg
+    } else {
+      val d = -duration
+
+      { (msg: PollMsg) ⇒
+        // for negative timeout values, schedule Poll token -duration into the future
+        context.system.scheduler.scheduleOnce(d, self, msg)
+        ()
+      }
+    }
   }
 
-  private def newEventLoop: Option[Promise[PollLifeCycle]] = {
-    implicit val executor = eventLoopDispatcher
-    Some((Future {
-      if (poller.poll(pollTimeout.toMicros) > 0 && poller.pollin(0)) Results else NoResults
-    }).asInstanceOf[Promise[PollLifeCycle]] onSuccess {
-      case Results   ⇒ self ! ReceiveFrames
-      case NoResults ⇒ self ! Poll
-      case _         ⇒ self ! ClearPoll
-    } onFailure {
-      case ex ⇒ self ! PollError(ex)
-    })
-  }
-
-  private def receiveFrames(): Seq[Frame] = {
-    @tailrec def receiveBytes(next: Array[Byte], currentFrames: Vector[Frame] = Vector.empty): Seq[Frame] = {
-      val nwBytes = if (next != null && next.nonEmpty) next else noBytes
-      val frames = currentFrames :+ Frame(nwBytes)
-      if (socket.hasReceiveMore) receiveBytes(socket.recv(0), frames) else frames
+  @tailrec private def doPoll(togo: Int = 10, currentFrames: Vector[Frame] = Vector.empty): Unit =
+    receiveMessage(currentFrames) match {
+      case null  ⇒ // receiveMessage has already done something special here
+      case Seq() ⇒ doPollTimeout(Poll)
+      case frames ⇒
+        notifyListener(deserializer(frames))
+        if (togo > 0) doPoll(togo - 1)
+        else self ! Poll
     }
 
-    receiveBytes(socket.recv(0))
+  @tailrec private def receiveMessage(currentFrames: Vector[Frame]): Seq[Frame] = {
+    socket.recv(JZMQ.NOBLOCK) match {
+      case null ⇒
+        if (currentFrames.isEmpty) currentFrames
+        else {
+          doPollTimeout(ContinuePoll(currentFrames))
+          null
+        }
+      case bytes ⇒
+        val frames = currentFrames :+ Frame(if (bytes.length == 0) noBytes else bytes)
+        if (socket.hasReceiveMore) receiveMessage(frames) else frames
+    }
   }
 
   private val listenerOpt = params collectFirst { case Listener(l) ⇒ l }
diff --git a/akka-zeromq/src/main/scala/akka/zeromq/SocketOption.scala b/akka-zeromq/src/main/scala/akka/zeromq/SocketOption.scala
index d3b824bca1..1e4c83bcef 100644
--- a/akka-zeromq/src/main/scala/akka/zeromq/SocketOption.scala
+++ b/akka-zeromq/src/main/scala/akka/zeromq/SocketOption.scala
@@ -160,7 +160,6 @@ case class PollTimeoutDuration(duration: Duration = 100 millis) extends SocketMe
  * @param endpoint
  */
 case class Bind(endpoint: String) extends SocketConnectOption
-private[zeromq] case object Close extends Request
 
 /**
  * The [[akka.zeromq.Subscribe]] option shall establish a new message filter on a [[akka.zeromq.SocketType.Pub]] socket.