pekko/akka-zeromq/src/main/scala/akka/zeromq/ConcurrentSocketActor.scala

226 lines
7.9 KiB
Scala
Raw Normal View History

/**
 * Copyright (C) 2009-2012 Typesafe Inc. <http://www.typesafe.com>
 */
package akka.zeromq
import org.zeromq.ZMQ.{ Socket, Poller }
import org.zeromq.{ ZMQ => JZMQ }
import akka.actor._
import akka.dispatch.{ Promise, Future }
import akka.event.Logging
import akka.util.duration._
/**
 * Outcome of a single poll of the ZeroMQ socket, used as the result type of
 * the poll future started by `ConcurrentSocketActor`.
 */
private[zeromq] sealed trait PollLifeCycle

/** The poll finished without readable input; the actor schedules another poll. */
private[zeromq] case object NoResults extends PollLifeCycle

/** The poll found readable input; the actor goes on to receive the frames. */
private[zeromq] case object Results extends PollLifeCycle

/** Value the actor completes an outstanding poll with while it is stopping. */
private[zeromq] case object Closing extends PollLifeCycle
/**
 * Actor that owns one ZeroMQ socket configured from `params`.
 *
 * Connection requests (`Connect`, `Bind`, `Subscribe`, `Unsubscribe`) and
 * outgoing messages (`Send`, `ZMQMessage`) are applied to the socket directly
 * inside the message handler. Incoming data is detected by a poll that runs as
 * a `Future` on the poll dispatcher; the outcome is sent back to this actor as
 * a private message, after which the frames are read, deserialized and
 * forwarded to the configured `Listener` (if any).
 *
 * @param params socket options; a `SocketType.ZMQSocketType` is required,
 *               everything else is optional
 */
private[zeromq] class ConcurrentSocketActor(params: Seq[SocketOption]) extends Actor {

  // Context taken from the options, or a new Context(1) when none was supplied.
  private val zmqContext = params collectFirst { case c: Context => c } getOrElse new Context(1)

  private lazy val deserializer = deserializerFromParams
  private lazy val socket: Socket = socketFromParams()
  private lazy val poller: Poller = zmqContext.poller
  private val log = Logging(context.system, this)

  // The poll currently in flight, if any. A new poll is only started when this is empty.
  private var currentPoll: Option[Promise[PollLifeCycle]] = None

  // Messages this actor sends to itself to drive the poll/receive cycle.
  private case object Poll
  private case object ReceiveFrames
  private case object ClearPoll
  private case class PollError(ex: Throwable)

  /** Handles requests that send data or change what the socket is connected/subscribed to. */
  private def handleConnectionMessages: Receive = {
    case Send(frames) => {
      sendFrames(frames)
      pollAndReceiveFrames()
    }
    case ZMQMessage(frames) => {
      sendFrames(frames)
      pollAndReceiveFrames()
    }
    case Connect(endpoint) => {
      socket.connect(endpoint)
      notifyListener(Connecting)
      pollAndReceiveFrames()
    }
    case Bind(endpoint) => {
      socket.bind(endpoint)
      pollAndReceiveFrames()
    }
    case Subscribe(topic) => {
      socket.subscribe(topic.toArray)
      pollAndReceiveFrames()
    }
    case Unsubscribe(topic) => {
      socket.unsubscribe(topic.toArray)
      pollAndReceiveFrames()
    }
  }

  /** Answers socket option queries by replying to the sender with the socket's current value. */
  private def handleSocketOption: Receive = {
    case Linger            => sender ! socket.getLinger
    case ReconnectIVL      => sender ! socket.getReconnectIVL
    case Backlog           => sender ! socket.getBacklog
    case ReconnectIVLMax   => sender ! socket.getReconnectIVLMax
    case MaxMsgSize        => sender ! socket.getMaxMsgSize
    case SndHWM            => sender ! socket.getSndHWM
    case RcvHWM            => sender ! socket.getRcvHWM
    case Swap              => sender ! socket.getSwap
    case Affinity          => sender ! socket.getAffinity
    case Identity          => sender ! socket.getIdentity
    case Rate              => sender ! socket.getRate
    case RecoveryInterval  => sender ! socket.getRecoveryInterval
    case MulticastLoop     => sender ! socket.hasMulticastLoop
    case MulticastHops     => sender ! socket.getMulticastHops
    case ReceiveTimeOut    => sender ! socket.getReceiveTimeOut
    case SendTimeOut       => sender ! socket.getSendTimeOut
    case SendBufferSize    => sender ! socket.getSendBufferSize
    case ReceiveBufferSize => sender ! socket.getReceiveBufferSize
    case ReceiveMore       => sender ! socket.hasReceiveMore
    case FileDescriptor    => sender ! socket.getFD
  }

  /** Handles the private messages that keep the poll/receive cycle going. */
  private def internalMessage: Receive = {
    case Poll => {
      currentPoll = None
      pollAndReceiveFrames()
    }
    case ReceiveFrames => {
      val frames = receiveFrames()
      if (frames.nonEmpty) notifyListener(deserializer(frames))
      self ! Poll
    }
    case ClearPoll => currentPoll = None
    case PollError(ex) => {
      log.error(ex, "There was a problem polling the zeromq socket")
      self ! Poll
    }
  }

  override def receive: Receive = handleConnectionMessages orElse handleSocketOption orElse internalMessage

  /** Applies the configured options to the socket and registers it with the poller for input. */
  override def preStart(): Unit = {
    setupSocket()
    poller.register(socket, Poller.POLLIN)
  }

  /**
   * Creates the socket for the socket type found in `params`.
   *
   * @throws IllegalArgumentException if `params` contains no socket type
   */
  private def socketFromParams() = {
    require(ZeroMQExtension.check[SocketType.ZMQSocketType](params), "A socket type is required")
    params collectFirst {
      case socketType: SocketType.ZMQSocketType => zmqContext.socket(socketType)
    } getOrElse {
      // Only reachable if the check above and the lookup here ever disagree.
      throw new IllegalArgumentException("A socket type is required")
    }
  }

  /** The deserializer given in `params`, or a `ZMQMessageDeserializer` when none was supplied. */
  private def deserializerFromParams =
    params collectFirst { case d: Deserializer => d } getOrElse new ZMQMessageDeserializer

  /** Copies every socket-level option in `params` onto the socket; other options are ignored here. */
  private def setupSocket() = {
    params foreach {
      case Linger(value)            => socket.setLinger(value)
      case ReconnectIVL(value)      => socket.setReconnectIVL(value)
      case Backlog(value)           => socket.setBacklog(value)
      case ReconnectIVLMax(value)   => socket.setReconnectIVLMax(value)
      case MaxMsgSize(value)        => socket.setMaxMsgSize(value)
      case SndHWM(value)            => socket.setSndHWM(value)
      case RcvHWM(value)            => socket.setRcvHWM(value)
      case HWM(value)               => socket.setHWM(value)
      case Swap(value)              => socket.setSwap(value)
      case Affinity(value)          => socket.setAffinity(value)
      case Identity(value)          => socket.setIdentity(value)
      case Rate(value)              => socket.setRate(value)
      case RecoveryInterval(value)  => socket.setRecoveryInterval(value)
      case MulticastLoop(value)     => socket.setMulticastLoop(value)
      case MulticastHops(value)     => socket.setMulticastHops(value)
      case ReceiveTimeOut(value)    => socket.setReceiveTimeOut(value)
      case SendTimeOut(value)       => socket.setSendTimeOut(value)
      case SendBufferSize(value)    => socket.setSendBufferSize(value)
      case ReceiveBufferSize(value) => socket.setReceiveBufferSize(value)
      case _                        =>
    }
  }

  /**
   * Releases the socket and tells the listener that the connection is closed.
   *
   * The socket is closed even if unregistering it from the poller or
   * completing the outstanding poll fails, and the listener is notified even
   * if closing fails.
   */
  override def postStop(): Unit = {
    try {
      try {
        poller.unregister(socket)
        currentPoll foreach { _ complete Right(Closing) }
      } finally {
        if (socket != null) socket.close()
      }
    } finally {
      notifyListener(Closed)
    }
  }

  /**
   * Sends `frames` as one message: every frame but the last is flagged with
   * `SNDMORE`. Nothing is sent when `frames` is empty.
   */
  private def sendFrames(frames: Seq[Frame]): Unit = {
    val remaining = frames.iterator
    while (remaining.hasNext) {
      val payload = remaining.next().payload
      val flags = if (remaining.hasNext) JZMQ.SNDMORE else 0
      socket.send(payload.toArray, flags)
    }
  }

  /** Starts a poll unless one is already in flight. */
  private def pollAndReceiveFrames(): Unit = {
    if (currentPoll.isEmpty) currentPoll = newEventLoop
  }

  // Dispatcher the poll runs on: the one named by a PollDispatcher option, else the system default.
  private lazy val eventLoopDispatcher = {
    val fromConfig = params collectFirst {
      case option: PollDispatcher => context.system.dispatchers.lookup(option.name)
    }
    fromConfig getOrElse context.system.dispatcher
  }

  // How long a single poll may wait: the PollTimeoutDuration option, else 100 milliseconds.
  private lazy val pollTimeout = {
    val fromConfig = params collectFirst { case option: PollTimeoutDuration => option.duration }
    fromConfig getOrElse 100.millis
  }

  /**
   * Runs one poll on the poll dispatcher and reports the outcome back to this
   * actor as `ReceiveFrames`, `Poll`, `ClearPoll` or `PollError`.
   *
   * @return the promise behind the running poll, so that `postStop` can complete it
   */
  private def newEventLoop: Option[Promise[PollLifeCycle]] = {
    implicit val executor = eventLoopDispatcher
    Some((Future {
      if (poller.poll(pollTimeout.toMicros) > 0 && poller.pollin(0)) Results else NoResults
    }).asInstanceOf[Promise[PollLifeCycle]] onSuccess {
      case Results   => self ! ReceiveFrames
      case NoResults => self ! Poll
      case _         => self ! ClearPoll
    } onFailure {
      case ex => self ! PollError(ex)
    })
  }

  /**
   * Reads one complete (possibly multipart) message from the socket.
   *
   * Zero-length frames are kept: they are legitimate message parts, so
   * dropping them would change the message the listener sees.
   *
   * @return the frames of the message in order, or an empty sequence when the
   *         first `recv` yields `null`
   */
  private def receiveFrames(): Seq[Frame] = {
    // A null from recv is taken to mean "no frame was read" (see the JZMQ docs);
    // a non-null array, even an empty one, is a real frame.
    def receiveFrame(): Option[Frame] = Option(socket.recv(0)) map (bytes => Frame(bytes))

    receiveFrame() match {
      case None => Vector.empty
      case Some(first) =>
        var frames = Vector(first)
        while (socket.hasReceiveMore) receiveFrame() foreach { frame => frames :+= frame }
        frames
    }
  }

  /**
   * Forwards `message` to the listener given in `params`, if there is one.
   * When that listener has terminated, this actor stops itself instead.
   */
  private def notifyListener(message: Any): Unit = {
    params collectFirst { case option: Listener => option.listener } foreach { listener =>
      if (listener.isTerminated)
        context stop self
      else
        listener ! message
    }
  }
}