+act #15502: Pluggable DNS resolution in akka-io

This commit is contained in:
Ilya Epifanov 2014-06-21 16:20:20 +04:00
parent b88c964bd4
commit 5954302658
11 changed files with 446 additions and 24 deletions

View file

@ -0,0 +1,45 @@
package akka.io
import java.net.InetAddress
import java.util.concurrent.atomic.AtomicLong
import org.scalatest.{ ShouldMatchers, WordSpec }
class SimpleDnsCacheSpec extends WordSpec with ShouldMatchers {
"Cache" should {
"not reply with expired but not yet swept out entries" in {
val localClock = new AtomicLong(0)
val cache: SimpleDnsCache = new SimpleDnsCache() {
override protected def clock() = localClock.get
}
val cacheEntry = Dns.Resolved("test.local", Seq(InetAddress.getByName("127.0.0.1")))
cache.put(cacheEntry, 5000)
cache.cached("test.local") should equal(Some(cacheEntry))
localClock.set(4999)
cache.cached("test.local") should equal(Some(cacheEntry))
localClock.set(5000)
cache.cached("test.local") should equal(None)
}
"sweep out expired entries on cleanup()" in {
val localClock = new AtomicLong(0)
val cache: SimpleDnsCache = new SimpleDnsCache() {
override protected def clock() = localClock.get
}
val cacheEntry = Dns.Resolved("test.local", Seq(InetAddress.getByName("127.0.0.1")))
cache.put(cacheEntry, 5000)
cache.cached("test.local") should equal(Some(cacheEntry))
localClock.set(5000)
cache.cached("test.local") should equal(None)
localClock.set(0)
cache.cached("test.local") should equal(Some(cacheEntry))
localClock.set(5000)
cache.cleanup()
cache.cached("test.local") should equal(None)
localClock.set(0)
cache.cached("test.local") should equal(None)
}
}
}

View file

@ -226,6 +226,12 @@ akka {
messages-per-resize = 10
}
}
/IO-DNS/inet-address {
mailbox = "unbounded"
router = "consistent-hashing-pool"
nr-of-instances = 4
}
}
default-dispatcher {
@ -729,6 +735,28 @@ akka {
management-dispatcher = "akka.actor.default-dispatcher"
}
dns {
# Fully qualified config path which holds the dispatcher configuration
# for the manager and resolver router actors.
# For actual router configuration see akka.actor.deployment./IO-DNS/*
dispatcher = "akka.actor.default-dispatcher"
# Name of the subconfig at path akka.io.dns, see inet-address below
resolver = "inet-address"
inet-address {
# Must implement akka.io.DnsProvider
provider-object = "akka.io.InetAddressDnsProvider"
# These TTLs are set to default java 6 values
positive-ttl = 30s
negative-ttl = 10s
# How often to sweep out expired cache entries.
# Note that this interval has nothing to do with TTLs
cache-cleanup-interval = 120s
}
}
}

View file

@ -0,0 +1,91 @@
package akka.io
import java.net.{ Inet4Address, Inet6Address, InetAddress, UnknownHostException }
import akka.actor._
import akka.routing.ConsistentHashingRouter.ConsistentHashable
import com.typesafe.config.Config
import scala.collection.{ breakOut, immutable }
abstract class Dns {
def cached(name: String): Option[Dns.Resolved] = None
def resolve(name: String)(system: ActorSystem, sender: ActorRef): Option[Dns.Resolved] = {
val ret = cached(name)
if (ret.isEmpty)
IO(Dns)(system).tell(Dns.Resolve(name), sender)
ret
}
}
object Dns extends ExtensionId[DnsExt] with ExtensionIdProvider {
sealed trait Command
case class Resolve(name: String) extends Command with ConsistentHashable {
override def consistentHashKey = name
}
case class Resolved(name: String, ipv4: immutable.Seq[Inet4Address], ipv6: immutable.Seq[Inet6Address]) extends Command {
val addrOption: Option[InetAddress] = ipv4.headOption orElse ipv6.headOption
@throws[UnknownHostException]
def addr: InetAddress = addrOption match {
case Some(addr) addr
case None throw new UnknownHostException(name)
}
}
object Resolved {
def apply(name: String, addresses: Iterable[InetAddress]): Resolved = {
val ipv4: immutable.Seq[Inet4Address] = addresses.collect({
case a: Inet4Address a
})(breakOut)
val ipv6: immutable.Seq[Inet6Address] = addresses.collect({
case a: Inet6Address a
})(breakOut)
Resolved(name, ipv4, ipv6)
}
}
def cached(name: String)(system: ActorSystem): Option[Resolved] = {
Dns(system).cache.cached(name)
}
def resolve(name: String)(system: ActorSystem, sender: ActorRef): Option[Resolved] = {
Dns(system).cache.resolve(name)(system, sender)
}
override def lookup() = Dns
override def createExtension(system: ExtendedActorSystem): DnsExt = new DnsExt(system)
/**
* Java API: retrieve the Udp extension for the given system.
*/
override def get(system: ActorSystem): DnsExt = super.get(system)
}
class DnsExt(system: ExtendedActorSystem) extends IO.Extension {
val Settings = new Settings(system.settings.config.getConfig("akka.io.dns"))
class Settings private[DnsExt] (_config: Config) {
import _config._
val Dispatcher: String = getString("dispatcher")
val Resolver: String = getString("resolver")
val ResolverConfig: Config = getConfig(Resolver)
val ProviderObjectName: String = ResolverConfig.getString("provider-object")
}
val provider: DnsProvider = system.dynamicAccess.getClassFor[DnsProvider](Settings.ProviderObjectName).get.newInstance()
val cache: Dns = provider.cache
val manager: ActorRef = {
system.systemActorOf(
props = Props(classOf[SimpleDnsManager], this).withDeploy(Deploy.local).withDispatcher(Settings.Dispatcher),
name = "IO-DNS")
}
def getResolver: ActorRef = manager
}

View file

@ -0,0 +1,9 @@
package akka.io
import akka.actor.Actor
trait DnsProvider {
def cache: Dns
def actorClass: Class[_ <: Actor]
def managerClass: Class[_ <: Actor]
}

View file

@ -0,0 +1,7 @@
package akka.io
class InetAddressDnsProvider extends DnsProvider {
override def cache: Dns = new SimpleDnsCache()
override def actorClass = classOf[InetAddressDnsResolver]
override def managerClass = classOf[SimpleDnsManager]
}

View file

@ -0,0 +1,33 @@
package akka.io
import java.net.{ UnknownHostException, InetAddress }
import java.util.concurrent.TimeUnit
import akka.actor.Actor
import com.typesafe.config.Config
import scala.collection.immutable
class InetAddressDnsResolver(cache: SimpleDnsCache, config: Config) extends Actor {
val positiveTtl = config.getDuration("positive-ttl", TimeUnit.MILLISECONDS)
val negativeTtl = config.getDuration("negative-ttl", TimeUnit.MILLISECONDS)
override def receive = {
case Dns.Resolve(name)
val answer = cache.cached(name) match {
case Some(a) a
case None
try {
val answer = Dns.Resolved(name, InetAddress.getAllByName(name))
cache.put(answer, positiveTtl)
answer
} catch {
case e: UnknownHostException
val answer = Dns.Resolved(name, immutable.Seq.empty, immutable.Seq.empty)
cache.put(answer, negativeTtl)
answer
}
}
sender() ! answer
}
}

View file

@ -0,0 +1,94 @@
package akka.io
import java.util.concurrent.atomic.AtomicReference
import akka.io.Dns.Resolved
import scala.annotation.tailrec
import scala.collection.immutable
private[io] sealed trait PeriodicCacheCleanup {
def cleanup(): Unit
}
class SimpleDnsCache extends Dns with PeriodicCacheCleanup {
import akka.io.SimpleDnsCache._
private val cache = new AtomicReference(new Cache(
immutable.SortedSet()(ExpiryEntryOrdering),
immutable.Map(), clock))
private val nanoBase = System.nanoTime()
override def cached(name: String): Option[Resolved] = {
cache.get().get(name)
}
protected def clock(): Long = {
val now = System.nanoTime()
if (now - nanoBase < 0) 0
else (now - nanoBase) / 1000000
}
@tailrec
private[io] final def put(r: Resolved, ttlMillis: Long): Unit = {
val c = cache.get()
if (!cache.compareAndSet(c, c.put(r, ttlMillis)))
put(r, ttlMillis)
}
@tailrec
override final def cleanup(): Unit = {
val c = cache.get()
if (!cache.compareAndSet(c, c.cleanup()))
cleanup()
}
}
object SimpleDnsCache {
private class Cache(queue: immutable.SortedSet[ExpiryEntry], cache: immutable.Map[String, CacheEntry], clock: () Long) {
def get(name: String): Option[Resolved] = {
for {
e cache.get(name)
if e.isValid(clock())
} yield e.answer
}
def put(answer: Resolved, ttlMillis: Long): Cache = {
val until = clock() + ttlMillis
new Cache(
queue + new ExpiryEntry(answer.name, until),
cache + (answer.name -> CacheEntry(answer, until)),
clock)
}
def cleanup(): Cache = {
val now = clock()
var q = queue
var c = cache
while (q.nonEmpty && !q.head.isValid(now)) {
val minEntry = q.head
val name = minEntry.name
q -= minEntry
if (c.get(name).filterNot(_.isValid(now)).isDefined)
c -= name
}
new Cache(q, c, clock)
}
}
private case class CacheEntry(answer: Dns.Resolved, until: Long) {
def isValid(clock: Long): Boolean = clock < until
}
private class ExpiryEntry(val name: String, val until: Long) extends Ordered[ExpiryEntry] {
def isValid(clock: Long): Boolean = clock < until
override def compare(that: ExpiryEntry): Int = -until.compareTo(that.until)
}
private object ExpiryEntryOrdering extends Ordering[ExpiryEntry] {
override def compare(x: ExpiryEntry, y: ExpiryEntry): Int = {
x.until.compareTo(y.until)
}
}
}

View file

@ -0,0 +1,42 @@
package akka.io
import java.util.concurrent.TimeUnit
import akka.actor.{ ActorLogging, Actor, Deploy, Props }
import akka.dispatch.{ RequiresMessageQueue, UnboundedMessageQueueSemantics }
import akka.routing.FromConfig
import scala.concurrent.duration.Duration
class SimpleDnsManager(val ext: DnsExt) extends Actor with RequiresMessageQueue[UnboundedMessageQueueSemantics] with ActorLogging {
import context._
private val resolver = actorOf(FromConfig.props(Props(ext.provider.actorClass, ext.cache, ext.Settings.ResolverConfig).withDeploy(Deploy.local).withDispatcher(ext.Settings.Dispatcher)), ext.Settings.Resolver)
private val cacheCleanup = ext.cache match {
case cleanup: PeriodicCacheCleanup Some(cleanup)
case _ None
}
private val cleanupTimer = cacheCleanup map { _
val interval = Duration(ext.Settings.ResolverConfig.getDuration("cache-cleanup-interval", TimeUnit.MILLISECONDS), TimeUnit.MILLISECONDS)
system.scheduler.schedule(interval, interval, self, SimpleDnsManager.CacheCleanup)
}
override def receive = {
case r @ Dns.Resolve(name)
log.debug("Resolution request for {} from {}", name, sender())
resolver.forward(r)
case SimpleDnsManager.CacheCleanup
for (c cacheCleanup)
c.cleanup()
}
override def postStop(): Unit = {
for (t cleanupTimer) t.cancel()
}
}
object SimpleDnsManager {
private case object CacheCleanup
}

View file

@ -4,6 +4,7 @@
package akka.io
import java.net.InetSocketAddress
import java.nio.channels.{ SelectionKey, SocketChannel }
import scala.util.control.NonFatal
import scala.collection.immutable
@ -26,6 +27,7 @@ private[io] class TcpOutgoingConnection(_tcp: TcpExt,
connect: Connect)
extends TcpConnection(_tcp, SocketChannel.open().configureBlocking(false).asInstanceOf[SocketChannel], connect.pullMode) {
import context._
import connect._
context.watch(commander) // sign death pact
@ -49,17 +51,40 @@ private[io] class TcpOutgoingConnection(_tcp: TcpExt,
def receive: Receive = {
case registration: ChannelRegistration
log.debug("Attempting connection to [{}]", remoteAddress)
reportConnectFailure {
if (channel.connect(remoteAddress))
completeConnect(registration, commander, options)
else {
registration.enableInterest(SelectionKey.OP_CONNECT)
context.become(connecting(registration, tcp.Settings.FinishConnectRetries))
if (remoteAddress.isUnresolved) {
log.debug("Resolving {} before connecting", remoteAddress.getHostName)
Dns.resolve(remoteAddress.getHostName)(system, self) match {
case None
context.become(resolving(registration))
case Some(resolved)
register(new InetSocketAddress(resolved.addr, remoteAddress.getPort), registration)
}
} else {
register(remoteAddress, registration)
}
}
}
def resolving(registration: ChannelRegistration): Receive = {
case resolved: Dns.Resolved
reportConnectFailure {
register(new InetSocketAddress(resolved.addr, remoteAddress.getPort), registration)
}
}
def register(address: InetSocketAddress, registration: ChannelRegistration): Unit = {
reportConnectFailure {
log.debug("Attempting connection to [{}]", address)
if (channel.connect(address))
completeConnect(registration, commander, options)
else {
registration.enableInterest(SelectionKey.OP_CONNECT)
context.become(connecting(registration, tcp.Settings.FinishConnectRetries))
}
}
}
def connecting(registration: ChannelRegistration, remainingFinishConnectRetries: Int): Receive = {
{
case ChannelConnectable

View file

@ -3,6 +3,7 @@
*/
package akka.io
import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.channels.DatagramChannel
import java.nio.channels.SelectionKey._
@ -31,25 +32,38 @@ private[io] class UdpConnection(udpConn: UdpConnectedExt,
def writePending = pendingSend ne null
context.watch(handler) // sign death pact
val channel = {
val datagramChannel = DatagramChannel.open
datagramChannel.configureBlocking(false)
val socket = datagramChannel.socket
options.foreach(_.beforeBind(datagramChannel))
try {
localAddress foreach socket.bind
datagramChannel.connect(remoteAddress)
} catch {
case NonFatal(e)
log.debug("Failure while connecting UDP channel to remote address [{}] local address [{}]: {}",
remoteAddress, localAddress.getOrElse("undefined"), e)
commander ! CommandFailed(connect)
context.stop(self)
var channel: DatagramChannel = null
if (remoteAddress.isUnresolved) {
Dns.resolve(remoteAddress.getHostName)(context.system, self) match {
case Some(r)
doConnect(new InetSocketAddress(r.addr, remoteAddress.getPort))
case None
context.become(resolving(), discardOld = true)
}
datagramChannel
} else {
doConnect(remoteAddress)
}
def resolving(): Receive = {
case r: Dns.Resolved
reportConnectFailure {
doConnect(new InetSocketAddress(r.addr, remoteAddress.getPort))
}
}
def doConnect(address: InetSocketAddress): Unit = {
reportConnectFailure {
channel = DatagramChannel.open
channel.configureBlocking(false)
val socket = channel.socket
options.foreach(_.beforeBind(channel))
localAddress foreach socket.bind
channel.connect(remoteAddress)
channelRegistry.register(channel, OP_READ)
}
log.debug("Successfully connected to [{}]", remoteAddress)
}
channelRegistry.register(channel, OP_READ)
log.debug("Successfully connected to [{}]", remoteAddress)
def receive = {
case registration: ChannelRegistration
@ -130,4 +144,16 @@ private[io] class UdpConnection(udpConn: UdpConnectedExt,
case NonFatal(e) log.debug("Error closing DatagramChannel: {}", e)
}
}
private def reportConnectFailure(thunk: Unit): Unit = {
try {
thunk
} catch {
case NonFatal(e)
log.debug("Failure while connecting UDP channel to remote address [{}] local address [{}]: {}",
remoteAddress, localAddress.getOrElse("undefined"), e)
commander ! CommandFailed(connect)
context.stop(self)
}
}
}

View file

@ -3,11 +3,14 @@
*/
package akka.io
import java.net.InetSocketAddress
import java.nio.channels.{ SelectionKey, DatagramChannel }
import akka.actor.{ ActorRef, ActorLogging, Actor }
import akka.io.Udp.{ CommandFailed, Send }
import akka.io.SelectionHandler._
import scala.util.control.NonFatal
/**
* INTERNAL API
*/
@ -39,7 +42,26 @@ private[io] trait WithUdpSend {
case send: Send
pendingSend = send
pendingCommander = sender()
doSend(registration)
if (send.target.isUnresolved) {
Dns.resolve(send.target.getHostName)(context.system, self) match {
case Some(r)
try {
pendingSend = pendingSend.copy(target = new InetSocketAddress(r.addr, pendingSend.target.getPort))
doSend(registration)
} catch {
case NonFatal(e)
sender() ! CommandFailed(send)
log.debug("Failure while sending UDP datagram to remote address [{}]: {}",
send.target, e)
retriedSend = false
pendingSend = null
pendingCommander = null
}
case None
}
} else {
doSend(registration)
}
case ChannelWritable if (hasWritePending) doSend(registration)
}