/**
 * Copyright (C) 2009-2012 Typesafe Inc. <http://www.typesafe.com>
 */
|
|
|
|
|
package akka.zeromq
|
|
|
|
|
|
|
|
|
|
import org.zeromq.ZMQ.{ Socket, Poller }
|
|
|
|
|
import org.zeromq.{ ZMQ ⇒ JZMQ }
|
|
|
|
|
import akka.actor._
|
2012-01-14 12:13:46 +01:00
|
|
|
import akka.dispatch.{ Promise, Future }
|
2012-01-16 10:48:23 +01:00
|
|
|
import akka.event.Logging
|
2012-01-19 12:38:36 +01:00
|
|
|
import annotation.tailrec
|
2012-01-19 13:06:22 +01:00
|
|
|
import akka.util.Duration
|
|
|
|
|
import java.util.concurrent.TimeUnit
|
2012-01-14 03:16:39 +01:00
|
|
|
|
2012-01-19 12:04:35 +01:00
|
|
|
/**
 * Companion of [[ConcurrentSocketActor]]: the private self-messaging protocol that drives
 * the receive loop, plus shared defaults.
 */
private[zeromq] object ConcurrentSocketActor {

  /** Tokens the actor sends to itself (directly or via the scheduler) to drive polling. */
  private sealed trait PollMsg

  /** Start a fresh receive attempt with no partially received frames. */
  private case object Poll extends PollMsg

  /**
   * Resume receiving a multipart message whose leading frames were already read
   * before the socket ran out of data.
   *
   * @param frames the frames of the current message received so far
   */
  private case class ContinuePoll(frames: Vector[Frame]) extends PollMsg

  /** Raised when no socket type in the params yields a socket handle. */
  private class NoSocketHandleException() extends Exception("Couldn't create a zeromq socket.")

  /** Context used when no `Context` is supplied among the socket params. */
  private val DefaultContext = Context()
}
|
2012-01-19 00:26:52 +01:00
|
|
|
/**
 * Actor that owns a single ZeroMQ socket.
 *
 * Socket options, connect/bind and subscribe requests, and outgoing messages arrive as actor
 * messages. Incoming messages are read by a self-driven polling loop (`Poll` / `ContinuePoll`
 * tokens) and forwarded, after deserialization, to the optional listener.
 *
 * @param params socket configuration. A socket type is required (see `socketFromParams`);
 *               a `Context`, a `Deserializer`, a `Listener` and a `PollTimeoutDuration`
 *               may optionally be supplied.
 */
private[zeromq] class ConcurrentSocketActor(params: Seq[SocketOption]) extends Actor {

  import ConcurrentSocketActor._

  private val noBytes = Array[Byte]()
  // Must be initialized before `socket` and `poller`, which are both created from it.
  private val zmqContext = params collectFirst { case c: Context ⇒ c } getOrElse DefaultContext

  private val deserializer = deserializerFromParams
  private val socket: Socket = socketFromParams
  private val poller: Poller = zmqContext.poller
  private val log = Logging(context.system, this)

  def receive = {
    case Poll                 ⇒ doPoll()
    case ContinuePoll(frames) ⇒ doPoll(currentFrames = frames)
    case ZMQMessage(frames)   ⇒ sendMessage(frames)
    case r: Request           ⇒ handleRequest(r)
    // The only actor this class watches is the listener; follow it into termination.
    case Terminated(_)        ⇒ context stop self
  }

  /** Dispatches a user request: send frames, apply an option, or answer an option query. */
  private def handleRequest(msg: Request): Unit = msg match {
    case Send(frames)         ⇒ sendMessage(frames)
    case opt: SocketOption    ⇒ handleSocketOption(opt)
    case q: SocketOptionQuery ⇒ handleSocketOptionQuery(q)
  }

  /** Connects or binds the socket; only `Connect` notifies the listener (with `Connecting`). */
  private def handleConnectOption(msg: SocketConnectOption): Unit = msg match {
    case Connect(endpoint) ⇒
      socket.connect(endpoint)
      notifyListener(Connecting)
    case Bind(endpoint) ⇒ socket.bind(endpoint)
  }

  /** Adds or removes a subscription topic on the socket. */
  private def handlePubSubOption(msg: PubSubOption): Unit = msg match {
    case Subscribe(topic)   ⇒ socket.subscribe(topic.toArray)
    case Unsubscribe(topic) ⇒ socket.unsubscribe(topic.toArray)
  }

  /**
   * Applies a socket option to the live socket.
   *
   * @throws IllegalStateException if a `SocketMeta` option is sent after setup
   */
  private def handleSocketOption(msg: SocketOption): Unit = msg match {
    case x: SocketMeta             ⇒ throw new IllegalStateException("SocketMeta " + x + " only allowed for setting up a socket")
    case c: SocketConnectOption    ⇒ handleConnectOption(c)
    case ps: PubSubOption          ⇒ handlePubSubOption(ps)
    case Linger(value)             ⇒ socket.setLinger(value)
    case ReconnectIVL(value)       ⇒ socket.setReconnectIVL(value)
    case Backlog(value)            ⇒ socket.setBacklog(value)
    case ReconnectIVLMax(value)    ⇒ socket.setReconnectIVLMax(value)
    case MaxMsgSize(value)         ⇒ socket.setMaxMsgSize(value)
    case SendHighWatermark(value)  ⇒ socket.setSndHWM(value)
    case ReceiveHighWatermark(value) ⇒ socket.setRcvHWM(value)
    case HighWatermark(value)      ⇒ socket.setHWM(value)
    case Swap(value)               ⇒ socket.setSwap(value)
    case Affinity(value)           ⇒ socket.setAffinity(value)
    case Identity(value)           ⇒ socket.setIdentity(value)
    case Rate(value)               ⇒ socket.setRate(value)
    case RecoveryInterval(value)   ⇒ socket.setRecoveryInterval(value)
    case MulticastLoop(value)      ⇒ socket.setMulticastLoop(value)
    case MulticastHops(value)      ⇒ socket.setMulticastHops(value)
    case SendBufferSize(value)     ⇒ socket.setSendBufferSize(value)
    case ReceiveBufferSize(value)  ⇒ socket.setReceiveBufferSize(value)
  }

  /** Replies to the sender with the socket's current value for the queried option. */
  private def handleSocketOptionQuery(msg: SocketOptionQuery): Unit = msg match {
    case Linger               ⇒ sender ! socket.getLinger
    case ReconnectIVL         ⇒ sender ! socket.getReconnectIVL
    case Backlog              ⇒ sender ! socket.getBacklog
    case ReconnectIVLMax      ⇒ sender ! socket.getReconnectIVLMax
    case MaxMsgSize           ⇒ sender ! socket.getMaxMsgSize
    case SendHighWatermark    ⇒ sender ! socket.getSndHWM
    case ReceiveHighWatermark ⇒ sender ! socket.getRcvHWM
    case Swap                 ⇒ sender ! socket.getSwap
    case Affinity             ⇒ sender ! socket.getAffinity
    case Identity             ⇒ sender ! socket.getIdentity
    case Rate                 ⇒ sender ! socket.getRate
    case RecoveryInterval     ⇒ sender ! socket.getRecoveryInterval
    case MulticastLoop        ⇒ sender ! socket.hasMulticastLoop
    case MulticastHops        ⇒ sender ! socket.getMulticastHops
    case SendBufferSize       ⇒ sender ! socket.getSendBufferSize
    case ReceiveBufferSize    ⇒ sender ! socket.getReceiveBufferSize
    case FileDescriptor       ⇒ sender ! socket.getFD
  }

  /**
   * Watches the listener, then queues setup work as messages to self.
   *
   * Because the mailbox is processed in order, plain socket options are applied first,
   * then connect/bind and subscriptions, then the first `Poll` starts the receive loop.
   */
  override def preStart {
    watchListener()
    setupSocket()
    poller.register(socket, Poller.POLLIN)
    setupConnection()
    self ! Poll
  }

  /** Queues connect/bind options, then pub/sub options, as messages to self. */
  private def setupConnection() {
    params filter (_.isInstanceOf[SocketConnectOption]) foreach { self ! _ }
    params filter (_.isInstanceOf[PubSubOption]) foreach { self ! _ }
  }

  /**
   * Creates the socket from the first socket type found in `params`.
   *
   * @throws IllegalArgumentException if no socket type is present
   * @throws NoSocketHandleException if no socket type yields a socket
   */
  private def socketFromParams() = {
    require(ZeroMQExtension.check[SocketType.ZMQSocketType](params), "A socket type is required")
    (params
      collectFirst { case t: SocketType.ZMQSocketType ⇒ zmqContext.socket(t) }
      getOrElse (throw new NoSocketHandleException))
  }

  /** The first `Deserializer` in `params`, or a new `ZMQMessageDeserializer`. */
  private def deserializerFromParams = {
    params collectFirst { case d: Deserializer ⇒ d } getOrElse new ZMQMessageDeserializer
  }

  /** Queues every plain socket option as a message to self; other kinds are handled elsewhere. */
  private def setupSocket() = {
    params foreach {
      case _: SocketConnectOption | _: PubSubOption | _: SocketMeta ⇒ // ignore, handled differently
      case m ⇒ self ! m
    }
  }

  /**
   * Releases the socket and always notifies the listener with `Closed`.
   *
   * Unregistering and closing are guarded separately, so a failing `unregister` cannot
   * prevent the socket from being closed.
   */
  override def postStop {
    try {
      if (socket != null) {
        try poller.unregister(socket)
        finally socket.close
      }
    } finally {
      notifyListener(Closed)
    }
  }

  /** Sends the frames as one (possibly multipart) message; all but the last carry SNDMORE. */
  private def sendMessage(frames: Seq[Frame]) {
    def sendBytes(bytes: Seq[Byte], flags: Int) = socket.send(bytes.toArray, flags)
    val iter = frames.iterator
    while (iter.hasNext) {
      val payload = iter.next.payload
      val flags = if (iter.hasNext) JZMQ.SNDMORE else 0
      sendBytes(payload, flags)
    }
  }

  /**
   * Re-arms the polling loop with the given token.
   *
   * The poll timeout comes from a `PollTimeoutDuration` in params, or else from the
   * extension's `DefaultPollTimeout`:
   *  - positive: block this thread in `poller.poll` for that long, then send the token to self;
   *  - otherwise: schedule the token to self after `-duration`, without blocking.
   */
  private val doPollTimeout: PollMsg ⇒ Unit = {
    val fromConfig = params collectFirst { case PollTimeoutDuration(duration) ⇒ duration }
    val duration = fromConfig getOrElse ZeroMQExtension(context.system).DefaultPollTimeout
    if (duration > Duration.Zero) { (msg: PollMsg) ⇒
      // for positive timeout values, do poll (i.e. block this thread)
      poller.poll(duration.toMicros)
      self ! msg
    } else {
      val d = -duration
      (msg: PollMsg) ⇒ {
        // for negative timeout values, schedule Poll token -duration into the future
        context.system.scheduler.scheduleOnce(d, self, msg)
        ()
      }
    }
  }

  /**
   * Drains complete messages from the socket and forwards them to the listener.
   *
   * @param togo          how many more complete messages may be handled in this invocation
   *                      before yielding to the mailbox by sending `Poll` to self
   * @param currentFrames frames of a partially received multipart message, if any
   */
  @tailrec private def doPoll(togo: Int = 10, currentFrames: Vector[Frame] = Vector.empty): Unit =
    receiveMessage(currentFrames) match {
      case None ⇒ // partial message: receiveMessage has already re-armed with ContinuePoll
      case Some(frames) if frames.isEmpty ⇒ doPollTimeout(Poll)
      case Some(frames) ⇒
        notifyListener(deserializer(frames))
        if (togo > 0) doPoll(togo - 1)
        else self ! Poll
    }

  /**
   * Reads frames without blocking, appending to `currentFrames`, until the message is complete.
   *
   * @return `Some(frames)` with a complete message; `Some(empty)` when nothing was available;
   *         `None` when the socket ran dry mid-message. In the `None` case a `ContinuePoll`
   *         carrying the partial frames has already been handed to `doPollTimeout`.
   */
  @tailrec private def receiveMessage(currentFrames: Vector[Frame]): Option[Seq[Frame]] =
    socket.recv(JZMQ.NOBLOCK) match {
      case null ⇒ // JZMQ returns null when no frame is available in NOBLOCK mode
        if (currentFrames.isEmpty) Some(currentFrames)
        else {
          doPollTimeout(ContinuePoll(currentFrames))
          None
        }
      case bytes ⇒
        val frames = currentFrames :+ Frame(if (bytes.length == 0) noBytes else bytes)
        if (socket.hasReceiveMore) receiveMessage(frames) else Some(frames)
    }

  /** The first `Listener` in `params`, if any; it receives deserialized messages and events. */
  private val listenerOpt = params collectFirst { case Listener(l) ⇒ l }

  /** Death-watches the listener, so its termination stops this actor (see `receive`). */
  private def watchListener() {
    listenerOpt foreach context.watch
  }

  /** Sends `message` to the listener, if one was configured; otherwise does nothing. */
  private def notifyListener(message: Any) {
    listenerOpt foreach { _ ! message }
  }
}
|