pekko/akka-zeromq/src/main/scala/akka/zeromq/ConcurrentSocketActor.scala

211 lines
8.2 KiB
Scala
Raw Normal View History

/**
 * Copyright (C) 2009-2012 Typesafe Inc. <http://www.typesafe.com>
 */
package akka.zeromq
import org.zeromq.ZMQ.{ Socket, Poller }
import org.zeromq.{ ZMQ JZMQ }
import akka.actor._
import akka.dispatch.{ Promise, Future }
import akka.event.Logging
import annotation.tailrec
2012-01-19 13:06:22 +01:00
import akka.util.Duration
import java.util.concurrent.TimeUnit
/**
 * Companion of the `ConcurrentSocketActor` class: holds the private poll
 * tokens the actor sends to itself and the defaults it falls back to.
 */
private[zeromq] object ConcurrentSocketActor {

  /** Common type of the self-addressed messages that trigger a poll of the socket. */
  private sealed trait PollMsg

  /** Poll token for which the actor performs a non-blocking `recv`. */
  private case object Poll extends PollMsg

  /** Poll token for which the actor first asks the poller and only then calls `recv`. */
  private case object PollCareful extends PollMsg

  /** Failure type carrying a fixed "could not create socket" message. */
  private class NoSocketHandleException extends Exception("Couldn't create a zeromq socket.")

  /** Context used when the socket parameters do not supply one. */
  private val DefaultContext = Context()
}
/**
 * Actor that owns one ZeroMQ socket. It applies socket options, sends
 * outgoing frames, and polls the socket for incoming messages, passing each
 * received message through the deserializer to the optional listener.
 *
 * @param params socket options. A socket type is mandatory (an
 *               `IllegalArgumentException` is thrown during construction
 *               otherwise). A `Context`, `Deserializer`, `Listener` and
 *               `PollTimeoutDuration` are picked up if present; all other
 *               options are applied to the socket when the actor starts.
 */
private[zeromq] class ConcurrentSocketActor(params: Seq[SocketOption]) extends Actor {

  import ConcurrentSocketActor._

  // Single shared array used as payload for every zero-length frame received.
  private val noBytes = Array[Byte]()

  // First Context found in the params, otherwise the companion's default.
  private val zmqContext = params collectFirst { case c: Context => c } getOrElse DefaultContext

  private val deserializer = deserializerFromParams

  private val socketType = {
    import SocketType.{ ZMQSocketType => ST }
    params.collectFirst { case t: ST => t }.getOrElse(throw new IllegalArgumentException("A socket type is required"))
  }

  private val socket: Socket = zmqContext.socket(socketType)
  private val poller: Poller = zmqContext.poller
  private val log = Logging(context.system, this)

  /**
   * Message dispatch: poll tokens trigger a receive cycle, `ZMQMessage`s and
   * `Request`s are forwarded to the socket, and a `Terminated` notification
   * (the listener is the only actor watched here) stops this actor.
   */
  def receive: Receive = {
    case m: PollMsg         => doPoll(m)
    case ZMQMessage(frames) => sendMessage(frames)
    case r: Request         => handleRequest(r)
    case Terminated(_)      => context stop self
  }

  /** Routes a request to sending, option setting or option querying. */
  private def handleRequest(msg: Request): Unit = msg match {
    case Send(frames)         => sendMessage(frames)
    case opt: SocketOption    => handleSocketOption(opt)
    case q: SocketOptionQuery => handleSocketOptionQuery(q)
  }

  /** Connects or binds the socket; only a connect is reported to the listener (as `Connecting`). */
  private def handleConnectOption(msg: SocketConnectOption): Unit = msg match {
    case Connect(endpoint) =>
      socket.connect(endpoint)
      notifyListener(Connecting)
    case Bind(endpoint) => socket.bind(endpoint)
  }

  /** Adds or removes a topic subscription on the socket. */
  private def handlePubSubOption(msg: PubSubOption): Unit = msg match {
    case Subscribe(topic)   => socket.subscribe(topic.toArray)
    case Unsubscribe(topic) => socket.unsubscribe(topic.toArray)
  }

  /**
   * Applies a single option to the socket.
   *
   * @throws IllegalStateException for a `SocketMeta` option, which is only
   *                               valid in the constructor parameters
   */
  private def handleSocketOption(msg: SocketOption): Unit = msg match {
    case x: SocketMeta               => throw new IllegalStateException("SocketMeta " + x + " only allowed for setting up a socket")
    case c: SocketConnectOption      => handleConnectOption(c)
    case ps: PubSubOption            => handlePubSubOption(ps)
    case Linger(value)               => socket.setLinger(value)
    case ReconnectIVL(value)         => socket.setReconnectIVL(value)
    case Backlog(value)              => socket.setBacklog(value)
    case ReconnectIVLMax(value)      => socket.setReconnectIVLMax(value)
    case MaxMsgSize(value)           => socket.setMaxMsgSize(value)
    case SendHighWatermark(value)    => socket.setSndHWM(value)
    case ReceiveHighWatermark(value) => socket.setRcvHWM(value)
    case HighWatermark(value)        => socket.setHWM(value)
    case Swap(value)                 => socket.setSwap(value)
    case Affinity(value)             => socket.setAffinity(value)
    case Identity(value)             => socket.setIdentity(value)
    case Rate(value)                 => socket.setRate(value)
    case RecoveryInterval(value)     => socket.setRecoveryInterval(value)
    case MulticastLoop(value)        => socket.setMulticastLoop(value)
    case MulticastHops(value)        => socket.setMulticastHops(value)
    case SendBufferSize(value)       => socket.setSendBufferSize(value)
    case ReceiveBufferSize(value)    => socket.setReceiveBufferSize(value)
  }

  /** Reads the queried option from the socket and replies to the sender with its value. */
  private def handleSocketOptionQuery(msg: SocketOptionQuery): Unit =
    sender ! (msg match {
      case Linger               => socket.getLinger
      case ReconnectIVL         => socket.getReconnectIVL
      case Backlog              => socket.getBacklog
      case ReconnectIVLMax      => socket.getReconnectIVLMax
      case MaxMsgSize           => socket.getMaxMsgSize
      case SendHighWatermark    => socket.getSndHWM
      case ReceiveHighWatermark => socket.getRcvHWM
      case Swap                 => socket.getSwap
      case Affinity             => socket.getAffinity
      case Identity             => socket.getIdentity
      case Rate                 => socket.getRate
      case RecoveryInterval     => socket.getRecoveryInterval
      case MulticastLoop        => socket.hasMulticastLoop
      case MulticastHops        => socket.getMulticastHops
      case SendBufferSize       => socket.getSendBufferSize
      case ReceiveBufferSize    => socket.getReceiveBufferSize
      case FileDescriptor       => socket.getFD
    })

  /**
   * Watches the listener, queues the socket options and connect/subscribe
   * options as messages to self, registers the socket with the poller and,
   * for socket types that can receive, queues the first poll token.
   */
  override def preStart(): Unit = {
    watchListener()
    setupSocket()
    poller.register(socket, Poller.POLLIN)
    setupConnection()

    import SocketType._

    socketType match {
      case Pub | Push                          => // don't poll
      case Sub | Pull | Pair | Dealer | Router => self ! Poll
      case Req | Rep                           => self ! PollCareful
    }
  }

  /** Queues all connect/bind options to self, followed by all subscribe/unsubscribe options. */
  private def setupConnection(): Unit = {
    params collect { case c: SocketConnectOption => c } foreach { self ! _ }
    params collect { case ps: PubSubOption => ps } foreach { self ! _ }
  }

  /** First `Deserializer` found in the params, otherwise a new `ZMQMessageDeserializer`. */
  private def deserializerFromParams = {
    params collectFirst { case d: Deserializer => d } getOrElse new ZMQMessageDeserializer
  }

  /** Queues every plain socket option to self so it is applied through `receive`. */
  private def setupSocket(): Unit =
    params foreach {
      case _: SocketConnectOption | _: PubSubOption | _: SocketMeta => // ignore, handled differently
      case m => self ! m
    }

  // NOTE(review): postStop is deliberately skipped on restart and postRestart
  // does nothing, so this class does not close or re-register the socket when
  // restarted - confirm the supervisor never restarts this actor.
  override def preRestart(reason: Throwable, message: Option[Any]): Unit = {
    context.children foreach context.stop //Do not call postStop
  }

  override def postRestart(reason: Throwable): Unit = {} //Do nothing

  /** Unregisters and closes the socket; the listener is told `Closed` even if that fails. */
  override def postStop(): Unit = {
    try {
      if (socket != null) {
        poller.unregister(socket)
        socket.close()
      }
    } finally {
      notifyListener(Closed)
    }
  }

  /**
   * Sends the frames as one multi-part message: every frame except the last
   * is sent with `SNDMORE`, the last one with flags `0`. An empty sequence
   * sends nothing.
   */
  private def sendMessage(frames: Seq[Frame]): Unit = {
    def sendBytes(bytes: Seq[Byte], flags: Int) = socket.send(bytes.toArray, flags)
    val iter = frames.iterator
    while (iter.hasNext) {
      val payload = iter.next().payload
      val flags = if (iter.hasNext) JZMQ.SNDMORE else 0
      sendBytes(payload, flags)
    }
  }

  // This is a `PollMsg => Unit` which either polls or schedules the poll token,
  // depending on the sign of the timeout. The timeout comes from a
  // PollTimeoutDuration in the params, or else from the extension's default.
  private val doPollTimeout: PollMsg => Unit = {
    val fromConfig = params collectFirst { case PollTimeoutDuration(duration) => duration }
    val duration = fromConfig getOrElse ZeroMQExtension(context.system).DefaultPollTimeout
    if (duration > Duration.Zero) {
      (msg: PollMsg) => {
        // for positive timeout values, do poll (i.e. block this thread)
        poller.poll(duration.toMicros)
        self ! msg
      }
    } else {
      val d = -duration

      (msg: PollMsg) => {
        // for zero or negative timeout values, schedule Poll token -duration into the future
        context.system.scheduler.scheduleOnce(d, self, msg)
        ()
      }
    }
  }

  /**
   * Receives messages and hands each one, deserialized, to the listener.
   * When nothing is available, `doPollTimeout` decides how to wait. After
   * `togo + 1` consecutive messages the poll token is re-sent to self so
   * other mailbox messages get a turn.
   *
   * @param mode the poll token that selects how the socket is read
   * @param togo how many further messages may be read before yielding
   */
  @tailrec private def doPoll(mode: PollMsg, togo: Int = 10): Unit =
    receiveMessage(mode) match {
      case null  => // receiveMessage has already done something special here
      case Seq() => doPollTimeout(mode)
      case frames =>
        notifyListener(deserializer(frames))
        if (togo > 0) doPoll(mode, togo - 1)
        else self ! mode
    }

  /**
   * Reads one complete (possibly multi-part) message from the socket.
   *
   * `Poll` does a non-blocking `recv`; `PollCareful` checks the poller with a
   * zero timeout first and only then calls `recv(0)`.
   *
   * @param mode          the poll token that selects how the socket is read
   * @param currentFrames frames of the current message collected so far
   * @return the frames of the message, or an empty sequence if none was available
   * @throws IllegalStateException if no frame is available after earlier
   *                               frames of the same message were received
   */
  @tailrec private def receiveMessage(mode: PollMsg, currentFrames: Vector[Frame] = Vector.empty): Seq[Frame] = {
    val result = mode match {
      case Poll        => socket.recv(JZMQ.NOBLOCK)
      case PollCareful => if (poller.poll(0) > 0) socket.recv(0) else null
    }
    result match {
      case null =>
        if (currentFrames.isEmpty) currentFrames
        else throw new IllegalStateException("no more frames available while socket.hasReceivedMore==true")
      case bytes =>
        val frames = currentFrames :+ Frame(if (bytes.length == 0) noBytes else bytes)
        if (socket.hasReceiveMore) receiveMessage(mode, frames) else frames
    }
  }

  // The listener is optional: first Listener found in the params, if any.
  private val listenerOpt = params collectFirst { case Listener(l) => l }

  /** Puts the listener, if there is one, under death watch. */
  private def watchListener(): Unit = listenerOpt foreach context.watch

  /** Sends `message` to the listener, if there is one. */
  private def notifyListener(message: Any): Unit = listenerOpt foreach { _ ! message }
}