From d9a8c4e7e68476a99d3438ff031cf8f1c8935241 Mon Sep 17 00:00:00 2001 From: Ivan Porto Carrero Date: Thu, 19 Jan 2012 00:26:52 +0100 Subject: [PATCH] Changes the intialization logic for the actor so that zeromq options are preserved --- .../additional/third-party-integrations.rst | 2 +- .../akka/zeromq/ConcurrentSocketActor.scala | 125 +++++++++++------- .../src/main/scala/akka/zeromq/Context.scala | 20 --- .../main/scala/akka/zeromq/Deserializer.scala | 2 +- .../src/main/scala/akka/zeromq/Requests.scala | 35 ++++- .../main/scala/akka/zeromq/SocketType.scala | 14 -- .../scala/akka/zeromq/ZeroMQExtension.scala | 34 ++--- .../zeromq/ConcurrentSocketActorSpec.scala | 12 +- 8 files changed, 132 insertions(+), 112 deletions(-) delete mode 100644 akka-zeromq/src/main/scala/akka/zeromq/Context.scala delete mode 100644 akka-zeromq/src/main/scala/akka/zeromq/SocketType.scala diff --git a/akka-docs/additional/third-party-integrations.rst b/akka-docs/additional/third-party-integrations.rst index 446c8436b3..93b3b36d2e 100644 --- a/akka-docs/additional/third-party-integrations.rst +++ b/akka-docs/additional/third-party-integrations.rst @@ -13,5 +13,5 @@ Scalatra Scalatra has Akka integration. -Read more here: ``_ +Read more here: ``_ diff --git a/akka-zeromq/src/main/scala/akka/zeromq/ConcurrentSocketActor.scala b/akka-zeromq/src/main/scala/akka/zeromq/ConcurrentSocketActor.scala index ca7cc08e89..07a7b5f888 100644 --- a/akka-zeromq/src/main/scala/akka/zeromq/ConcurrentSocketActor.scala +++ b/akka-zeromq/src/main/scala/akka/zeromq/ConcurrentSocketActor.scala @@ -8,17 +8,22 @@ import org.zeromq.{ ZMQ ⇒ JZMQ } import akka.actor._ import akka.dispatch.{ Promise, Future } import akka.event.Logging +import akka.util.duration._ private[zeromq] sealed trait PollLifeCycle private[zeromq] case object NoResults extends PollLifeCycle private[zeromq] case object Results extends PollLifeCycle private[zeromq] case object Closing extends PollLifeCycle -private[zeromq] class ConcurrentSocketActor(params: SocketParameters) extends Actor { +private[zeromq] class ConcurrentSocketActor(params: Seq[SocketOption]) extends Actor { private val noBytes = Array[Byte]() - private val socket: Socket = params.context.socket(params.socketType) - private val poller: Poller = params.context.poller + private val zmqContext = { + params find (_.isInstanceOf[Context]) map (_.asInstanceOf[Context]) getOrElse new Context(1) + } + private lazy val deserializer = deserializerFromParams + private lazy val socket: Socket = socketFromParams + private lazy val poller: Poller = zmqContext.poller private val log = Logging(context.system, this) private case object Poll @@ -55,45 +60,26 @@ private[zeromq] class ConcurrentSocketActor(params: SocketParameters) extends Ac } private def handleSocketOption: Receive = { - case Linger(value) ⇒ socket.setLinger(value) - case ReconnectIVL(value) ⇒ socket.setReconnectIVL(value) - case Backlog(value) ⇒ socket.setBacklog(value) - case ReconnectIVLMax(value) ⇒ socket.setReconnectIVLMax(value) - case MaxMsgSize(value) ⇒ socket.setMaxMsgSize(value) - case SndHWM(value) ⇒ socket.setSndHWM(value) - case RcvHWM(value) ⇒ socket.setRcvHWM(value) - case HWM(value) ⇒ socket.setHWM(value) - case Swap(value) ⇒ socket.setSwap(value) - case Affinity(value) ⇒ socket.setAffinity(value) - case Identity(value) ⇒ socket.setIdentity(value) - case Rate(value) ⇒ socket.setRate(value) - case RecoveryInterval(value) ⇒ socket.setRecoveryInterval(value) - case MulticastLoop(value) ⇒ socket.setMulticastLoop(value) - case MulticastHops(value) ⇒ socket.setMulticastHops(value) - case ReceiveTimeOut(value) ⇒ socket.setReceiveTimeOut(value) - case SendTimeOut(value) ⇒ socket.setSendTimeOut(value) - case SendBufferSize(value) ⇒ socket.setSendBufferSize(value) - case ReceiveBufferSize(value) ⇒ socket.setReceiveBufferSize(value) - case Linger ⇒ sender ! socket.getLinger - case ReconnectIVL ⇒ sender ! socket.getReconnectIVL - case Backlog ⇒ sender ! socket.getBacklog - case ReconnectIVLMax ⇒ sender ! socket.getReconnectIVLMax - case MaxMsgSize ⇒ sender ! socket.getMaxMsgSize - case SndHWM ⇒ sender ! socket.getSndHWM - case RcvHWM ⇒ sender ! socket.getRcvHWM - case Swap ⇒ sender ! socket.getSwap - case Affinity ⇒ sender ! socket.getAffinity - case Identity ⇒ sender ! socket.getIdentity - case Rate ⇒ sender ! socket.getRate - case RecoveryInterval ⇒ sender ! socket.getRecoveryInterval - case MulticastLoop ⇒ sender ! socket.hasMulticastLoop - case MulticastHops ⇒ sender ! socket.getMulticastHops - case ReceiveTimeOut ⇒ sender ! socket.getReceiveTimeOut - case SendTimeOut ⇒ sender ! socket.getSendTimeOut - case SendBufferSize ⇒ sender ! socket.getSendBufferSize - case ReceiveBufferSize ⇒ sender ! socket.getReceiveBufferSize - case ReceiveMore ⇒ sender ! socket.hasReceiveMore - case FileDescriptor ⇒ sender ! socket.getFD + case Linger ⇒ sender ! socket.getLinger + case ReconnectIVL ⇒ sender ! socket.getReconnectIVL + case Backlog ⇒ sender ! socket.getBacklog + case ReconnectIVLMax ⇒ sender ! socket.getReconnectIVLMax + case MaxMsgSize ⇒ sender ! socket.getMaxMsgSize + case SndHWM ⇒ sender ! socket.getSndHWM + case RcvHWM ⇒ sender ! socket.getRcvHWM + case Swap ⇒ sender ! socket.getSwap + case Affinity ⇒ sender ! socket.getAffinity + case Identity ⇒ sender ! socket.getIdentity + case Rate ⇒ sender ! socket.getRate + case RecoveryInterval ⇒ sender ! socket.getRecoveryInterval + case MulticastLoop ⇒ sender ! socket.hasMulticastLoop + case MulticastHops ⇒ sender ! socket.getMulticastHops + case ReceiveTimeOut ⇒ sender ! socket.getReceiveTimeOut + case SendTimeOut ⇒ sender ! socket.getSendTimeOut + case SendBufferSize ⇒ sender ! socket.getSendBufferSize + case ReceiveBufferSize ⇒ sender ! socket.getReceiveBufferSize + case ReceiveMore ⇒ sender ! socket.hasReceiveMore + case FileDescriptor ⇒ sender ! socket.getFD } private def internalMessage: Receive = { @@ -104,7 +90,7 @@ private[zeromq] class ConcurrentSocketActor(params: SocketParameters) extends Ac case ReceiveFrames ⇒ { receiveFrames() match { case Seq() ⇒ - case frames ⇒ notifyListener(params.deserializer(frames)) + case frames ⇒ notifyListener(deserializer(frames)) } self ! Poll } @@ -118,9 +104,46 @@ private[zeromq] class ConcurrentSocketActor(params: SocketParameters) extends Ac override def receive: Receive = handleConnectionMessages orElse handleSocketOption orElse internalMessage override def preStart { + setupSocket() poller.register(socket, Poller.POLLIN) } + private def socketFromParams() = { + require(ZeroMQExtension.check[SocketType.ZMQSocketType](params), "A socket type is required") + (params + find (_.isInstanceOf[SocketType.ZMQSocketType]) + map (t ⇒ zmqContext.socket(t.asInstanceOf[SocketType.ZMQSocketType])) get) + } + + private def deserializerFromParams = { + params find (_.isInstanceOf[Deserializer]) map (_.asInstanceOf[Deserializer]) getOrElse new ZMQMessageDeserializer + } + + private def setupSocket() = { + params foreach { + case Linger(value) ⇒ socket.setLinger(value) + case ReconnectIVL(value) ⇒ socket.setReconnectIVL(value) + case Backlog(value) ⇒ socket.setBacklog(value) + case ReconnectIVLMax(value) ⇒ socket.setReconnectIVLMax(value) + case MaxMsgSize(value) ⇒ socket.setMaxMsgSize(value) + case SndHWM(value) ⇒ socket.setSndHWM(value) + case RcvHWM(value) ⇒ socket.setRcvHWM(value) + case HWM(value) ⇒ socket.setHWM(value) + case Swap(value) ⇒ socket.setSwap(value) + case Affinity(value) ⇒ socket.setAffinity(value) + case Identity(value) ⇒ socket.setIdentity(value) + case Rate(value) ⇒ socket.setRate(value) + case RecoveryInterval(value) ⇒ socket.setRecoveryInterval(value) + case MulticastLoop(value) ⇒ socket.setMulticastLoop(value) + case MulticastHops(value) ⇒ socket.setMulticastHops(value) + case ReceiveTimeOut(value) ⇒ socket.setReceiveTimeOut(value) + case SendTimeOut(value) ⇒ socket.setSendTimeOut(value) + case SendBufferSize(value) ⇒ socket.setSendBufferSize(value) + case ReceiveBufferSize(value) ⇒ socket.setReceiveBufferSize(value) + case _ ⇒ + } + } + override def postStop { try { poller.unregister(socket) @@ -146,14 +169,22 @@ private[zeromq] class ConcurrentSocketActor(params: SocketParameters) extends Ac if (currentPoll.isEmpty) currentPoll = newEventLoop } - private val eventLoopDispatcher = { - params.pollDispatcher.map(d ⇒ context.system.dispatchers.lookup(d)) getOrElse context.system.dispatcher + private lazy val eventLoopDispatcher = { + val fromConfig = params.find(_.isInstanceOf[PollDispatcher]) map { + option ⇒ context.system.dispatchers.lookup(option.asInstanceOf[PollDispatcher].name) + } + fromConfig getOrElse context.system.dispatcher + } + + private lazy val pollTimeout = { + val fromConfig = params find (_.isInstanceOf[PollTimeoutDuration]) map (_.asInstanceOf[PollTimeoutDuration].duration) + fromConfig getOrElse 100.millis } private def newEventLoop: Option[Promise[PollLifeCycle]] = { implicit val executor = eventLoopDispatcher Some((Future { - if (poller.poll(params.pollTimeoutDuration.toMicros) > 0 && poller.pollin(0)) Results else NoResults + if (poller.poll(pollTimeout.toMicros) > 0 && poller.pollin(0)) Results else NoResults }).asInstanceOf[Promise[PollLifeCycle]] onSuccess { case Results ⇒ self ! ReceiveFrames case NoResults ⇒ self ! Poll @@ -184,7 +215,7 @@ private[zeromq] class ConcurrentSocketActor(params: SocketParameters) extends Ac } private def notifyListener(message: Any) { - params.listener.foreach { listener ⇒ + params find (_.isInstanceOf[Listener]) map (_.asInstanceOf[Listener].listener) foreach { listener ⇒ if (listener.isTerminated) context stop self else diff --git a/akka-zeromq/src/main/scala/akka/zeromq/Context.scala b/akka-zeromq/src/main/scala/akka/zeromq/Context.scala deleted file mode 100644 index 7d768e1492..0000000000 --- a/akka-zeromq/src/main/scala/akka/zeromq/Context.scala +++ /dev/null @@ -1,20 +0,0 @@ -/** - * Copyright (C) 2009-2011 Typesafe Inc. - */ -package akka.zeromq - -import org.zeromq.{ ZMQ ⇒ JZMQ } -import akka.zeromq.SocketType._ - -class Context(numIoThreads: Int) { - private val context = JZMQ.context(numIoThreads) - def socket(socketType: SocketType) = { - context.socket(socketType.id) - } - def poller = { - context.poller - } - def term = { - context.term - } -} diff --git a/akka-zeromq/src/main/scala/akka/zeromq/Deserializer.scala b/akka-zeromq/src/main/scala/akka/zeromq/Deserializer.scala index 6430a5a9c6..daf75f7c4b 100644 --- a/akka-zeromq/src/main/scala/akka/zeromq/Deserializer.scala +++ b/akka-zeromq/src/main/scala/akka/zeromq/Deserializer.scala @@ -6,7 +6,7 @@ package akka.zeromq case class Frame(payload: Seq[Byte]) object Frame { def apply(s: String): Frame = Frame(s.getBytes) } -trait Deserializer { +trait Deserializer extends SocketOption { def apply(frames: Seq[Frame]): Any } diff --git a/akka-zeromq/src/main/scala/akka/zeromq/Requests.scala b/akka-zeromq/src/main/scala/akka/zeromq/Requests.scala index 3cb5388069..0b1a2dad90 100644 --- a/akka-zeromq/src/main/scala/akka/zeromq/Requests.scala +++ b/akka-zeromq/src/main/scala/akka/zeromq/Requests.scala @@ -4,12 +4,45 @@ package akka.zeromq import com.google.protobuf.Message +import org.zeromq.{ ZMQ ⇒ JZMQ } +import akka.actor.ActorRef +import akka.util.duration._ +import akka.util.Duration sealed trait Request -sealed trait SocketOption extends Request +trait SocketOption extends Request sealed trait SocketOptionQuery extends Request case class Connect(endpoint: String) extends Request + +object Context { + def apply(numIoThreads: Int = 1) = new Context(numIoThreads) +} +class Context(numIoThreads: Int) extends SocketOption { + private val context = JZMQ.context(numIoThreads) + def socket(socketType: SocketType.ZMQSocketType) = { + context.socket(socketType.id) + } + def poller = { + context.poller + } + def term = { + context.term + } +} + +object SocketType { + abstract class ZMQSocketType(val id: Int) extends SocketOption + object Pub extends ZMQSocketType(JZMQ.PUB) + object Sub extends ZMQSocketType(JZMQ.SUB) + object Dealer extends ZMQSocketType(JZMQ.DEALER) + object Router extends ZMQSocketType(JZMQ.ROUTER) +} + +case class Listener(listener: ActorRef) extends SocketOption +case class PollDispatcher(name: String) extends SocketOption +case class PollTimeoutDuration(duration: Duration = 100 millis) extends SocketOption + case class Bind(endpoint: String) extends Request private[zeromq] case object Close extends Request diff --git a/akka-zeromq/src/main/scala/akka/zeromq/SocketType.scala b/akka-zeromq/src/main/scala/akka/zeromq/SocketType.scala deleted file mode 100644 index aba0c1608a..0000000000 --- a/akka-zeromq/src/main/scala/akka/zeromq/SocketType.scala +++ /dev/null @@ -1,14 +0,0 @@ -/** - * Copyright (C) 2009-2011 Typesafe Inc. - */ -package akka.zeromq - -import org.zeromq.{ ZMQ ⇒ JZMQ } - -object SocketType extends Enumeration { - type SocketType = Value - val Pub = Value(JZMQ.PUB) - val Sub = Value(JZMQ.SUB) - val Dealer = Value(JZMQ.DEALER) - val Router = Value(JZMQ.ROUTER) -} diff --git a/akka-zeromq/src/main/scala/akka/zeromq/ZeroMQExtension.scala b/akka-zeromq/src/main/scala/akka/zeromq/ZeroMQExtension.scala index e44bc00d06..685e3106bd 100644 --- a/akka-zeromq/src/main/scala/akka/zeromq/ZeroMQExtension.scala +++ b/akka-zeromq/src/main/scala/akka/zeromq/ZeroMQExtension.scala @@ -5,18 +5,10 @@ package akka.zeromq import akka.util.Duration import akka.util.duration._ -import akka.zeromq.SocketType._ import org.zeromq.{ ZMQ ⇒ JZMQ } import akka.actor._ import akka.dispatch.{ Dispatcher, Await } - -case class SocketParameters( - socketType: SocketType, - context: Context, - listener: Option[ActorRef] = None, - pollDispatcher: Option[String] = None, - deserializer: Deserializer = new ZMQMessageDeserializer, - pollTimeoutDuration: Duration = 100 millis) +import collection.mutable.ListBuffer case class ZeroMQVersion(major: Int, minor: Int, patch: Int) { override def toString = "%d.%d.%d".format(major, minor, patch) @@ -28,6 +20,12 @@ object ZeroMQExtension extends ExtensionId[ZeroMQExtension] with ExtensionIdProv private val minVersionString = "2.1.0" private val minVersion = JZMQ.makeVersion(2, 1, 0) + + private[zeromq] def check[TOption <: SocketOption: Manifest](parameters: Seq[SocketOption]) = { + parameters exists { p ⇒ + ClassManifest.singleType(p) <:< manifest[TOption] + } + } } class ZeroMQExtension(system: ActorSystem) extends Extension { @@ -35,23 +33,15 @@ class ZeroMQExtension(system: ActorSystem) extends Extension { ZeroMQVersion(JZMQ.getMajorVersion, JZMQ.getMinorVersion, JZMQ.getPatchVersion) } - lazy val DefaultContext = newContext() - - def newContext(numIoThreads: Int = 1) = { + def newSocketProps(socketParameters: SocketOption*): Props = { verifyZeroMQVersion - new Context(numIoThreads) + require(ZeroMQExtension.check[SocketType.ZMQSocketType](socketParameters), "A socket type is required") + Props(new ConcurrentSocketActor(socketParameters)).withDispatcher("akka.zeromq.socket-dispatcher") } - def newSocket(socketType: SocketType, - listener: Option[ActorRef] = None, - context: Context = DefaultContext, // For most applications you want to use the default context - deserializer: Deserializer = new ZMQMessageDeserializer, - pollDispatcher: Option[String] = None, - pollTimeoutDuration: Duration = 500 millis): ActorRef = { - verifyZeroMQVersion - val params = SocketParameters(socketType, context, listener, pollDispatcher, deserializer, pollTimeoutDuration) + def newSocket(socketParameters: SocketOption*): ActorRef = { implicit val timeout = system.settings.ActorTimeout - val req = (zeromq ? Props(new ConcurrentSocketActor(params)).withDispatcher("akka.zeromq.socket-dispatcher")).mapTo[ActorRef] + val req = (zeromq ? newSocketProps(socketParameters: _*)).mapTo[ActorRef] Await.result(req, timeout.duration) } diff --git a/akka-zeromq/src/test/scala/akka/zeromq/ConcurrentSocketActorSpec.scala b/akka-zeromq/src/test/scala/akka/zeromq/ConcurrentSocketActorSpec.scala index 6c256dd58e..d29579d8cb 100644 --- a/akka-zeromq/src/test/scala/akka/zeromq/ConcurrentSocketActorSpec.scala +++ b/akka-zeromq/src/test/scala/akka/zeromq/ConcurrentSocketActorSpec.scala @@ -37,7 +37,7 @@ class ConcurrentSocketActorSpec "support pub-sub connections" in { checkZeroMQInstallation val (publisherProbe, subscriberProbe) = (TestProbe(), TestProbe()) - val context = zmq.newContext() + val context = Context() val publisher = newPublisher(context, publisherProbe.ref) val subscriber = newSubscriber(context, subscriberProbe.ref) val msgGenerator = newMessageGenerator(publisher) @@ -63,7 +63,7 @@ class ConcurrentSocketActorSpec "support zero-length message frames" in { checkZeroMQInstallation val publisherProbe = TestProbe() - val context = zmq.newContext() + val context = Context() val publisher = newPublisher(context, publisherProbe.ref) try { @@ -77,18 +77,18 @@ class ConcurrentSocketActorSpec } } def newPublisher(context: Context, listener: ActorRef) = { - val publisher = zmq.newSocket(SocketType.Pub, context = context, listener = Some(listener)) + val publisher = zmq.newSocket(SocketType.Pub, context, Listener(listener)) publisher ! Bind(endpoint) publisher } def newSubscriber(context: Context, listener: ActorRef) = { - val subscriber = zmq.newSocket(SocketType.Sub, context = context, listener = Some(listener)) + val subscriber = zmq.newSocket(SocketType.Sub, context, Listener(listener)) subscriber ! Connect(endpoint) subscriber ! Subscribe(Seq()) subscriber } def newMessageGenerator(actorRef: ActorRef) = { - system.actorOf(Props(new MessageGeneratorActor(actorRef)).withTimeout(Timeout(10 millis))) + system.actorOf(Props(new MessageGeneratorActor(actorRef))) } def checkZeroMQInstallation = try { @@ -114,7 +114,7 @@ class ConcurrentSocketActorSpec private var genMessages: Cancellable = null override def preStart() = { - genMessages = system.scheduler.schedule(100 millis, 10 millis, self, 'm) + genMessages = system.scheduler.schedule(100 millis, 10 millis, self, "genMessage") } override def postStop() = {