diff --git a/akka-testkit/src/main/scala/akka/testkit/TestKit.scala b/akka-testkit/src/main/scala/akka/testkit/TestKit.scala index bcac5c24cf..cbcfc2a77d 100644 --- a/akka-testkit/src/main/scala/akka/testkit/TestKit.scala +++ b/akka-testkit/src/main/scala/akka/testkit/TestKit.scala @@ -69,7 +69,7 @@ class TestActor(queue: BlockingDeque[TestActor.Message]) extends Actor { *
* class Test extends TestKit(ActorSystem()) {
* try {
- *
+ *
* val test = system.actorOf(Props[SomeActor]
*
* within (1 second) {
@@ -77,7 +77,7 @@ class TestActor(queue: BlockingDeque[TestActor.Message]) extends Actor {
* expectMsg(Result1) // bounded to 1 second
* expectMsg(Result2) // bounded to the remainder of the 1 second
* }
- *
+ *
* } finally {
* system.shutdown()
* }
@@ -86,7 +86,7 @@ class TestActor(queue: BlockingDeque[TestActor.Message]) extends Actor {
*
* Beware of two points:
*
- * - the ActorSystem passed into the constructor needs to be shutdown,
+ * - the ActorSystem passed into the constructor needs to be shutdown,
* otherwise thread pools and memory will be leaked
* - this trait is not thread-safe (only one actor with one queue, one stack
* of `within` blocks); it is expected that the code is executed from a
diff --git a/akka-zeromq/src/main/scala/akka/zeromq/ConcurrentSocketActor.scala b/akka-zeromq/src/main/scala/akka/zeromq/ConcurrentSocketActor.scala
index 82a07d7aa3..6626576089 100644
--- a/akka-zeromq/src/main/scala/akka/zeromq/ConcurrentSocketActor.scala
+++ b/akka-zeromq/src/main/scala/akka/zeromq/ConcurrentSocketActor.scala
@@ -12,16 +12,12 @@ import annotation.tailrec
import akka.util.Duration
import java.util.concurrent.TimeUnit
-private[zeromq] sealed trait PollLifeCycle
-private[zeromq] case object NoResults extends PollLifeCycle
-private[zeromq] case object Results extends PollLifeCycle
-private[zeromq] case object Closing extends PollLifeCycle
-
private[zeromq] object ConcurrentSocketActor {
- private case object Poll
- private case object ReceiveFrames
- private case object ClearPoll
- private case class PollError(ex: Throwable)
+ private trait PollMsg
+ private case object Poll extends PollMsg
+ private case class ContinuePoll(frames: Vector[Frame]) extends PollMsg {
+ println("continue")
+ }
private class NoSocketHandleException() extends Exception("Couldn't create a zeromq socket.")
@@ -31,45 +27,41 @@ private[zeromq] class ConcurrentSocketActor(params: Seq[SocketOption]) extends A
import ConcurrentSocketActor._
private val noBytes = Array[Byte]()
- private val zmqContext = {
- params collectFirst { case c: Context ⇒ c } getOrElse DefaultContext
- }
+ private val zmqContext = params collectFirst { case c: Context ⇒ c } getOrElse DefaultContext
private val deserializer = deserializerFromParams
private val socket: Socket = socketFromParams
private val poller: Poller = zmqContext.poller
private val log = Logging(context.system, this)
- private def handleConnectionMessages: Receive = {
- case Send(frames) ⇒ {
- sendFrames(frames)
- pollAndReceiveFrames()
- }
- case ZMQMessage(frames) ⇒ {
- sendFrames(frames)
- pollAndReceiveFrames()
- }
- case Connect(endpoint) ⇒ {
- socket.connect(endpoint)
- notifyListener(Connecting)
- pollAndReceiveFrames()
- }
- case Bind(endpoint) ⇒ {
- socket.bind(endpoint)
- pollAndReceiveFrames()
- }
- case Subscribe(topic) ⇒ {
- socket.subscribe(topic.toArray)
- pollAndReceiveFrames()
- }
- case Unsubscribe(topic) ⇒ {
- socket.unsubscribe(topic.toArray)
- pollAndReceiveFrames()
- }
- case Terminated(_) ⇒ context stop self
+ def receive = {
+ case Poll ⇒ doPoll()
+ case ContinuePoll(frames) ⇒ doPoll(currentFrames = frames)
+ case ZMQMessage(frames) ⇒ sendMessage(frames)
+ case r: Request ⇒ handleRequest(r)
+ case Terminated(_) ⇒ context stop self
}
- private def handleSocketOption: Receive = {
+ private def handleRequest(msg: Request): Unit = msg match {
+ case Send(frames) ⇒ sendMessage(frames)
+ case opt: SocketOption ⇒ handleSocketOption(opt)
+ case q: SocketOptionQuery ⇒ handleSocketOptionQuery(q)
+ }
+
+ private def handleConnectOption(msg: SocketConnectOption): Unit = msg match {
+ case Connect(endpoint) ⇒ socket.connect(endpoint); notifyListener(Connecting)
+ case Bind(endpoint) ⇒ socket.bind(endpoint)
+ }
+
+ private def handlePubSubOption(msg: PubSubOption): Unit = msg match {
+ case Subscribe(topic) ⇒ socket.subscribe(topic.toArray)
+ case Unsubscribe(topic) ⇒ socket.unsubscribe(topic.toArray)
+ }
+
+ private def handleSocketOption(msg: SocketOption): Unit = msg match {
+ case x: SocketMeta ⇒ throw new IllegalStateException("SocketMeta " + x + " only allowed for setting up a socket")
+ case c: SocketConnectOption ⇒ handleConnectOption(c)
+ case ps: PubSubOption ⇒ handlePubSubOption(ps)
case Linger(value) ⇒ socket.setLinger(value)
case ReconnectIVL(value) ⇒ socket.setReconnectIVL(value)
case Backlog(value) ⇒ socket.setBacklog(value)
@@ -87,51 +79,34 @@ private[zeromq] class ConcurrentSocketActor(params: Seq[SocketOption]) extends A
case MulticastHops(value) ⇒ socket.setMulticastHops(value)
case SendBufferSize(value) ⇒ socket.setSendBufferSize(value)
case ReceiveBufferSize(value) ⇒ socket.setReceiveBufferSize(value)
- case Linger ⇒ sender ! socket.getLinger
- case ReconnectIVL ⇒ sender ! socket.getReconnectIVL
- case Backlog ⇒ sender ! socket.getBacklog
- case ReconnectIVLMax ⇒ sender ! socket.getReconnectIVLMax
- case MaxMsgSize ⇒ sender ! socket.getMaxMsgSize
- case SendHighWatermark ⇒ sender ! socket.getSndHWM
- case ReceiveHighWatermark ⇒ sender ! socket.getRcvHWM
- case Swap ⇒ sender ! socket.getSwap
- case Affinity ⇒ sender ! socket.getAffinity
- case Identity ⇒ sender ! socket.getIdentity
- case Rate ⇒ sender ! socket.getRate
- case RecoveryInterval ⇒ sender ! socket.getRecoveryInterval
- case MulticastLoop ⇒ sender ! socket.hasMulticastLoop
- case MulticastHops ⇒ sender ! socket.getMulticastHops
- case SendBufferSize ⇒ sender ! socket.getSendBufferSize
- case ReceiveBufferSize ⇒ sender ! socket.getReceiveBufferSize
- case FileDescriptor ⇒ sender ! socket.getFD
}
- private def internalMessage: Receive = {
- case Poll ⇒ {
- currentPoll = None
- pollAndReceiveFrames()
- }
- case ReceiveFrames ⇒ {
- receiveFrames() match {
- case Seq() ⇒
- case frames ⇒ notifyListener(deserializer(frames))
- }
- self ! Poll
- }
- case ClearPoll ⇒ currentPoll = None
- case PollError(ex) ⇒ {
- log.error(ex, "There was a problem polling the zeromq socket")
- self ! Poll
- }
+ private def handleSocketOptionQuery(msg: SocketOptionQuery): Unit = msg match {
+ case Linger ⇒ sender ! socket.getLinger
+ case ReconnectIVL ⇒ sender ! socket.getReconnectIVL
+ case Backlog ⇒ sender ! socket.getBacklog
+ case ReconnectIVLMax ⇒ sender ! socket.getReconnectIVLMax
+ case MaxMsgSize ⇒ sender ! socket.getMaxMsgSize
+ case SendHighWatermark ⇒ sender ! socket.getSndHWM
+ case ReceiveHighWatermark ⇒ sender ! socket.getRcvHWM
+ case Swap ⇒ sender ! socket.getSwap
+ case Affinity ⇒ sender ! socket.getAffinity
+ case Identity ⇒ sender ! socket.getIdentity
+ case Rate ⇒ sender ! socket.getRate
+ case RecoveryInterval ⇒ sender ! socket.getRecoveryInterval
+ case MulticastLoop ⇒ sender ! socket.hasMulticastLoop
+ case MulticastHops ⇒ sender ! socket.getMulticastHops
+ case SendBufferSize ⇒ sender ! socket.getSendBufferSize
+ case ReceiveBufferSize ⇒ sender ! socket.getReceiveBufferSize
+ case FileDescriptor ⇒ sender ! socket.getFD
}
- override def receive: Receive = handleConnectionMessages orElse handleSocketOption orElse internalMessage
-
override def preStart {
watchListener()
setupSocket()
poller.register(socket, Poller.POLLIN)
setupConnection()
+ self ! Poll
}
private def setupConnection() {
@@ -159,7 +134,6 @@ private[zeromq] class ConcurrentSocketActor(params: Seq[SocketOption]) extends A
override def postStop {
try {
- currentPoll foreach { _ complete Right(Closing) }
poller.unregister(socket)
if (socket != null) socket.close
} finally {
@@ -167,52 +141,58 @@ private[zeromq] class ConcurrentSocketActor(params: Seq[SocketOption]) extends A
}
}
- private def sendFrames(frames: Seq[Frame]) {
+ private def sendMessage(frames: Seq[Frame]) {
def sendBytes(bytes: Seq[Byte], flags: Int) = socket.send(bytes.toArray, flags)
val iter = frames.iterator
while (iter.hasNext) {
val payload = iter.next.payload
val flags = if (iter.hasNext) JZMQ.SNDMORE else 0
sendBytes(payload, flags)
+ Thread.sleep(5)
}
}
- private var currentPoll: Option[Promise[PollLifeCycle]] = None
- private def pollAndReceiveFrames() {
- if (currentPoll.isEmpty) currentPoll = newEventLoop
- }
-
- private val eventLoopDispatcher = {
- val fromConfig = params collectFirst { case PollDispatcher(name) ⇒ context.system.dispatchers.lookup(name) }
- fromConfig getOrElse context.system.dispatcher
- }
-
- private val pollTimeout = {
+ // this is a “PollMsg=>Unit” which either polls or schedules Poll, depending on the sign of the timeout
+ private val doPollTimeout = {
val fromConfig = params collectFirst { case PollTimeoutDuration(duration) ⇒ duration }
- fromConfig getOrElse ZeroMQExtension(context.system).DefaultPollTimeout
+ val duration = fromConfig getOrElse ZeroMQExtension(context.system).DefaultPollTimeout
+ if (duration > Duration.Zero) { (msg: PollMsg) ⇒
+ // for positive timeout values, do poll (i.e. block this thread)
+ poller.poll(duration.toMicros)
+ self ! msg
+ } else {
+ val d = -duration
+
+ { (msg: PollMsg) ⇒
+ // for negative timeout values, schedule Poll token -duration into the future
+ context.system.scheduler.scheduleOnce(d, self, msg)
+ ()
+ }
+ }
}
- private def newEventLoop: Option[Promise[PollLifeCycle]] = {
- implicit val executor = eventLoopDispatcher
- Some((Future {
- if (poller.poll(pollTimeout.toMicros) > 0 && poller.pollin(0)) Results else NoResults
- }).asInstanceOf[Promise[PollLifeCycle]] onSuccess {
- case Results ⇒ self ! ReceiveFrames
- case NoResults ⇒ self ! Poll
- case _ ⇒ self ! ClearPoll
- } onFailure {
- case ex ⇒ self ! PollError(ex)
- })
- }
-
- private def receiveFrames(): Seq[Frame] = {
- @tailrec def receiveBytes(next: Array[Byte], currentFrames: Vector[Frame] = Vector.empty): Seq[Frame] = {
- val nwBytes = if (next != null && next.nonEmpty) next else noBytes
- val frames = currentFrames :+ Frame(nwBytes)
- if (socket.hasReceiveMore) receiveBytes(socket.recv(0), frames) else frames
+ @tailrec private def doPoll(togo: Int = 10, currentFrames: Vector[Frame] = Vector.empty): Unit =
+ receiveMessage(currentFrames) match {
+ case null ⇒ // receiveMessage has already done something special here
+ case Seq() ⇒ doPollTimeout(Poll)
+ case frames ⇒
+ notifyListener(deserializer(frames))
+ if (togo > 0) doPoll(togo - 1)
+ else self ! Poll
}
- receiveBytes(socket.recv(0))
+ @tailrec private def receiveMessage(currentFrames: Vector[Frame]): Seq[Frame] = {
+ socket.recv(JZMQ.NOBLOCK) match {
+ case null ⇒
+ if (currentFrames.isEmpty) currentFrames
+ else {
+ doPollTimeout(ContinuePoll(currentFrames))
+ null
+ }
+ case bytes ⇒
+ val frames = currentFrames :+ Frame(if (bytes.length == 0) noBytes else bytes)
+ if (socket.hasReceiveMore) receiveMessage(frames) else frames
+ }
}
private val listenerOpt = params collectFirst { case Listener(l) ⇒ l }
diff --git a/akka-zeromq/src/main/scala/akka/zeromq/SocketOption.scala b/akka-zeromq/src/main/scala/akka/zeromq/SocketOption.scala
index d3b824bca1..1e4c83bcef 100644
--- a/akka-zeromq/src/main/scala/akka/zeromq/SocketOption.scala
+++ b/akka-zeromq/src/main/scala/akka/zeromq/SocketOption.scala
@@ -160,7 +160,6 @@ case class PollTimeoutDuration(duration: Duration = 100 millis) extends SocketMe
* @param endpoint
*/
case class Bind(endpoint: String) extends SocketConnectOption
-private[zeromq] case object Close extends Request
/**
* The [[akka.zeromq.Subscribe]] option shall establish a new message filter on a [[akka.zeromq.SocketType.Pub]] socket.