2012-01-14 03:16:39 +01:00
|
|
|
|
/**
|
2012-01-18 21:01:14 +01:00
|
|
|
|
* Copyright (C) 2009-2012 Typesafe Inc. <http://www.typesafe.com>
|
2012-01-14 03:16:39 +01:00
|
|
|
|
*/
|
|
|
|
|
|
package akka.zeromq
|
|
|
|
|
|
|
|
|
|
|
|
import org.zeromq.ZMQ.{ Socket, Poller }
|
|
|
|
|
|
import org.zeromq.{ ZMQ ⇒ JZMQ }
|
|
|
|
|
|
import akka.actor._
|
2012-01-14 12:13:46 +01:00
|
|
|
|
import akka.dispatch.{ Promise, Future }
|
2012-01-16 10:48:23 +01:00
|
|
|
|
import akka.event.Logging
|
2012-01-19 12:38:36 +01:00
|
|
|
|
import annotation.tailrec
|
2012-01-19 13:06:22 +01:00
|
|
|
|
import akka.util.Duration
|
|
|
|
|
|
import java.util.concurrent.TimeUnit
|
2012-01-14 03:16:39 +01:00
|
|
|
|
|
2012-01-19 12:04:35 +01:00
|
|
|
|
/**
 * Companion of [[ConcurrentSocketActor]]: the self-addressed poll tokens the actor uses
 * and the fallback ZeroMQ context.
 */
private[zeromq] object ConcurrentSocketActor {

  // Context used by every socket actor whose parameters do not include one.
  private val DefaultContext = Context()

  private class NoSocketHandleException() extends Exception("Couldn't create a zeromq socket.")

  // Tokens a socket actor sends to itself to drive its receive loop:
  // `Poll` reads with a non-blocking recv, `PollCareful` checks the poller before a blocking recv.
  private sealed trait PollMsg
  private case object Poll extends PollMsg
  private case object PollCareful extends PollMsg
}
|
2012-01-19 00:26:52 +01:00
|
|
|
|
/**
 * Actor that owns one ZeroMQ socket, created from and configured by `params`.
 *
 * Outgoing: `ZMQMessage(frames)` and `Send(frames)` are written to the socket as one multi-part
 * message. Incoming: a self-addressed token (`Poll` or `PollCareful`, chosen from the socket type
 * in `preStart`) drives a receive loop. Each received message is turned into a message by the
 * configured `Deserializer` and forwarded to the `Listener`, if one was supplied.
 *
 * @param params the socket type (required; the first one found is used), plus optional context,
 *               listener, deserializer, poll timeout, connect/bind, subscription and plain
 *               socket options
 * @throws IllegalArgumentException (during construction) if `params` contains no socket type
 */
private[zeromq] class ConcurrentSocketActor(params: Seq[SocketOption]) extends Actor {

  import ConcurrentSocketActor._

  // Shared instance substituted for every empty frame that is received.
  private val noBytes = Array[Byte]()
  private val zmqContext = params collectFirst { case c: Context ⇒ c } getOrElse DefaultContext

  private val deserializer = deserializerFromParams

  private val socketType = {
    import SocketType.{ ZMQSocketType ⇒ ST }
    params.collectFirst { case t: ST ⇒ t }.getOrElse(throw new IllegalArgumentException("A socket type is required"))
  }

  private val socket: Socket = zmqContext.socket(socketType)
  private val poller: Poller = zmqContext.poller

  private val log = Logging(context.system, this)

  def receive: Receive = {
    case m: PollMsg         ⇒ doPoll(m)
    case ZMQMessage(frames) ⇒ sendMessage(frames)
    case r: Request         ⇒ handleRequest(r)
    // The only actor this class watches is the listener (see watchListener).
    case Terminated(_)      ⇒ context stop self
  }

  /** Dispatches a `Request`: outgoing frames, a socket option to apply, or an option query. */
  private def handleRequest(msg: Request): Unit = msg match {
    case Send(frames)         ⇒ sendMessage(frames)
    case opt: SocketOption    ⇒ handleSocketOption(opt)
    case q: SocketOptionQuery ⇒ handleSocketOptionQuery(q)
  }

  /** Connects the socket (and tells the listener `Connecting`) or binds it to an endpoint. */
  private def handleConnectOption(msg: SocketConnectOption): Unit = msg match {
    case Connect(endpoint) ⇒ socket.connect(endpoint); notifyListener(Connecting)
    case Bind(endpoint)    ⇒ socket.bind(endpoint)
  }

  /** Adds or removes a topic subscription on the socket. */
  private def handlePubSubOption(msg: PubSubOption): Unit = msg match {
    case Subscribe(topic)   ⇒ socket.subscribe(topic.toArray)
    case Unsubscribe(topic) ⇒ socket.unsubscribe(topic.toArray)
  }

  /**
   * Applies a socket option to the live socket.
   *
   * @throws IllegalStateException for `SocketMeta` options, which may only be given when the
   *                               socket is set up
   */
  private def handleSocketOption(msg: SocketOption): Unit = msg match {
    case x: SocketMeta             ⇒ throw new IllegalStateException("SocketMeta " + x + " only allowed for setting up a socket")
    case c: SocketConnectOption    ⇒ handleConnectOption(c)
    case ps: PubSubOption          ⇒ handlePubSubOption(ps)
    case Linger(value)             ⇒ socket.setLinger(value)
    case ReconnectIVL(value)       ⇒ socket.setReconnectIVL(value)
    case Backlog(value)            ⇒ socket.setBacklog(value)
    case ReconnectIVLMax(value)    ⇒ socket.setReconnectIVLMax(value)
    case MaxMsgSize(value)         ⇒ socket.setMaxMsgSize(value)
    case SendHighWatermark(value)  ⇒ socket.setSndHWM(value)
    case ReceiveHighWatermark(value) ⇒ socket.setRcvHWM(value)
    case HighWatermark(value)      ⇒ socket.setHWM(value)
    case Swap(value)               ⇒ socket.setSwap(value)
    case Affinity(value)           ⇒ socket.setAffinity(value)
    case Identity(value)           ⇒ socket.setIdentity(value)
    case Rate(value)               ⇒ socket.setRate(value)
    case RecoveryInterval(value)   ⇒ socket.setRecoveryInterval(value)
    case MulticastLoop(value)      ⇒ socket.setMulticastLoop(value)
    case MulticastHops(value)      ⇒ socket.setMulticastHops(value)
    case SendBufferSize(value)     ⇒ socket.setSendBufferSize(value)
    case ReceiveBufferSize(value)  ⇒ socket.setReceiveBufferSize(value)
  }

  /** Replies to the sender with the socket's current value for the queried option. */
  private def handleSocketOptionQuery(msg: SocketOptionQuery): Unit =
    sender ! (msg match {
      case Linger               ⇒ socket.getLinger
      case ReconnectIVL         ⇒ socket.getReconnectIVL
      case Backlog              ⇒ socket.getBacklog
      case ReconnectIVLMax      ⇒ socket.getReconnectIVLMax
      case MaxMsgSize           ⇒ socket.getMaxMsgSize
      case SendHighWatermark    ⇒ socket.getSndHWM
      case ReceiveHighWatermark ⇒ socket.getRcvHWM
      case Swap                 ⇒ socket.getSwap
      case Affinity             ⇒ socket.getAffinity
      case Identity             ⇒ socket.getIdentity
      case Rate                 ⇒ socket.getRate
      case RecoveryInterval     ⇒ socket.getRecoveryInterval
      case MulticastLoop        ⇒ socket.hasMulticastLoop
      case MulticastHops        ⇒ socket.getMulticastHops
      case SendBufferSize       ⇒ socket.getSendBufferSize
      case ReceiveBufferSize    ⇒ socket.getReceiveBufferSize
      case FileDescriptor       ⇒ socket.getFD
    })

  /**
   * Watches the listener, queues all configuration as messages to self, registers the socket
   * for POLLIN and starts the receive loop for socket types that read.
   */
  override def preStart: Unit = {
    watchListener()
    setupSocket()
    poller.register(socket, Poller.POLLIN)
    setupConnection()

    import SocketType._
    socketType match {
      case Pub | Push                          ⇒ // don’t poll
      case Sub | Pull | Pair | Dealer | Router ⇒ self ! Poll
      case Req | Rep                           ⇒ self ! PollCareful
    }
  }

  /**
   * Queues connect/bind options, then subscription options, as messages to self. Because
   * `preStart` calls this after `setupSocket`, they are applied after the plain options.
   */
  private def setupConnection(): Unit = {
    params collect { case c: SocketConnectOption ⇒ c } foreach { self ! _ }
    params collect { case ps: PubSubOption ⇒ ps } foreach { self ! _ }
  }

  /** The first `Deserializer` in `params`, or a `ZMQMessageDeserializer` if none is given. */
  private def deserializerFromParams = {
    params collectFirst { case d: Deserializer ⇒ d } getOrElse new ZMQMessageDeserializer
  }

  /** Queues every plain socket option as a message to self; the other option kinds are handled elsewhere. */
  private def setupSocket(): Unit = params foreach {
    case _: SocketConnectOption | _: PubSubOption | _: SocketMeta ⇒ // ignore, handled differently
    case m ⇒ self ! m
  }

  override def preRestart(reason: Throwable, message: Option[Any]): Unit = {
    context.children foreach context.stop //Do not call postStop
  }

  override def postRestart(reason: Throwable): Unit = () //Do nothing

  /**
   * Unregisters and closes the socket. `close` runs even if `unregister` throws, and the listener
   * is always told `Closed`.
   */
  override def postStop: Unit = {
    try {
      if (socket != null) {
        try poller.unregister(socket)
        finally socket.close()
      }
    } finally {
      notifyListener(Closed)
    }
  }

  /** Writes `frames` as one multi-part message: every frame but the last is sent with SNDMORE. */
  private def sendMessage(frames: Seq[Frame]): Unit = {
    def sendBytes(bytes: Seq[Byte], flags: Int) = socket.send(bytes.toArray, flags)
    val iter = frames.iterator
    while (iter.hasNext) {
      val payload = iter.next.payload
      val flags = if (iter.hasNext) JZMQ.SNDMORE else 0
      sendBytes(payload, flags)
    }
  }

  /**
   * What the receive loop does when the socket is empty. A positive timeout blocks this thread
   * in `poller.poll` for that long, then re-sends the token to self. A zero or negative timeout
   * instead schedules the token to self `-duration` into the future.
   */
  private val doPollTimeout: PollMsg ⇒ Unit = {
    val fromConfig = params collectFirst { case PollTimeoutDuration(duration) ⇒ duration }
    val duration = fromConfig getOrElse ZeroMQExtension(context.system).DefaultPollTimeout
    if (duration > Duration.Zero) {
      // Converted once here rather than on every poll iteration.
      val timeoutMicros = duration.toMicros
      (msg: PollMsg) ⇒
        poller.poll(timeoutMicros)
        self ! msg
    } else {
      val d = -duration
      (msg: PollMsg) ⇒
        context.system.scheduler.scheduleOnce(d, self, msg)
        ()
    }
  }

  /**
   * Receives up to `togo + 1` messages and forwards each deserialized message to the listener.
   * When the socket is empty, the loop continues via `doPollTimeout`. When the batch limit is
   * reached, the token is re-sent to self instead of looping further, which lets messages already
   * in the mailbox be processed in between.
   */
  @tailrec private def doPoll(mode: PollMsg, togo: Int = 10): Unit =
    receiveMessage(mode) match {
      case Seq() ⇒ doPollTimeout(mode)
      case frames ⇒
        notifyListener(deserializer(frames))
        if (togo > 0) doPoll(mode, togo - 1)
        else self ! mode
    }

  /**
   * Reads one complete multi-part message from the socket. `Poll` uses a non-blocking `recv`.
   * `PollCareful` calls a blocking `recv` only after `poller.poll(0)` reports readiness.
   * Returns an empty sequence when nothing is available.
   *
   * @throws IllegalStateException if a frame is missing although `hasReceiveMore` announced one
   */
  @tailrec private def receiveMessage(mode: PollMsg, currentFrames: Vector[Frame] = Vector.empty): Seq[Frame] = {
    val result = mode match {
      case Poll        ⇒ socket.recv(JZMQ.NOBLOCK)
      case PollCareful ⇒ if (poller.poll(0) > 0) socket.recv(0) else null
    }
    result match {
      case null ⇒
        if (currentFrames.isEmpty) currentFrames
        else throw new IllegalStateException("no more frames available while socket.hasReceivedMore==true")
      case bytes ⇒
        val frames = currentFrames :+ Frame(if (bytes.length == 0) noBytes else bytes)
        if (socket.hasReceiveMore) receiveMessage(mode, frames) else frames
    }
  }

  private val listenerOpt = params collectFirst { case Listener(l) ⇒ l }
  private def watchListener(): Unit = listenerOpt foreach context.watch
  private def notifyListener(message: Any): Unit = listenerOpt foreach { _ ! message }
}
|