From 5954302658024b33829331dfd692551918481ece Mon Sep 17 00:00:00 2001 From: Ilya Epifanov Date: Sat, 21 Jun 2014 16:20:20 +0400 Subject: [PATCH] +act #15502: Pluggable DNS resolution in akka-io --- .../scala/akka/io/SimpleDnsCacheSpec.scala | 45 +++++++++ akka-actor/src/main/resources/reference.conf | 28 ++++++ akka-actor/src/main/scala/akka/io/Dns.scala | 91 ++++++++++++++++++ .../src/main/scala/akka/io/DnsProvider.scala | 9 ++ .../akka/io/InetAddressDnsProvider.scala | 7 ++ .../akka/io/InetAddressDnsResolver.scala | 33 +++++++ .../main/scala/akka/io/SimpleDnsCache.scala | 94 +++++++++++++++++++ .../main/scala/akka/io/SimpleDnsManager.scala | 42 +++++++++ .../scala/akka/io/TcpOutgoingConnection.scala | 37 ++++++-- .../main/scala/akka/io/UdpConnection.scala | 60 ++++++++---- .../src/main/scala/akka/io/WithUdpSend.scala | 24 ++++- 11 files changed, 446 insertions(+), 24 deletions(-) create mode 100644 akka-actor-tests/src/test/scala/akka/io/SimpleDnsCacheSpec.scala create mode 100644 akka-actor/src/main/scala/akka/io/Dns.scala create mode 100644 akka-actor/src/main/scala/akka/io/DnsProvider.scala create mode 100644 akka-actor/src/main/scala/akka/io/InetAddressDnsProvider.scala create mode 100644 akka-actor/src/main/scala/akka/io/InetAddressDnsResolver.scala create mode 100644 akka-actor/src/main/scala/akka/io/SimpleDnsCache.scala create mode 100644 akka-actor/src/main/scala/akka/io/SimpleDnsManager.scala diff --git a/akka-actor-tests/src/test/scala/akka/io/SimpleDnsCacheSpec.scala b/akka-actor-tests/src/test/scala/akka/io/SimpleDnsCacheSpec.scala new file mode 100644 index 0000000000..af796c6e0b --- /dev/null +++ b/akka-actor-tests/src/test/scala/akka/io/SimpleDnsCacheSpec.scala @@ -0,0 +1,45 @@ +package akka.io + +import java.net.InetAddress +import java.util.concurrent.atomic.AtomicLong + +import org.scalatest.{ ShouldMatchers, WordSpec } + +class SimpleDnsCacheSpec extends WordSpec with ShouldMatchers { + "Cache" should { + "not reply with expired but not yet swept out entries" in { + val localClock = new AtomicLong(0) + val cache: SimpleDnsCache = new SimpleDnsCache() { + override protected def clock() = localClock.get + } + val cacheEntry = Dns.Resolved("test.local", Seq(InetAddress.getByName("127.0.0.1"))) + cache.put(cacheEntry, 5000) + + cache.cached("test.local") should equal(Some(cacheEntry)) + localClock.set(4999) + cache.cached("test.local") should equal(Some(cacheEntry)) + localClock.set(5000) + cache.cached("test.local") should equal(None) + } + + "sweep out expired entries on cleanup()" in { + val localClock = new AtomicLong(0) + val cache: SimpleDnsCache = new SimpleDnsCache() { + override protected def clock() = localClock.get + } + val cacheEntry = Dns.Resolved("test.local", Seq(InetAddress.getByName("127.0.0.1"))) + cache.put(cacheEntry, 5000) + + cache.cached("test.local") should equal(Some(cacheEntry)) + localClock.set(5000) + cache.cached("test.local") should equal(None) + localClock.set(0) + cache.cached("test.local") should equal(Some(cacheEntry)) + localClock.set(5000) + cache.cleanup() + cache.cached("test.local") should equal(None) + localClock.set(0) + cache.cached("test.local") should equal(None) + } + } +} diff --git a/akka-actor/src/main/resources/reference.conf b/akka-actor/src/main/resources/reference.conf index ed76edc77e..65fbf4facc 100644 --- a/akka-actor/src/main/resources/reference.conf +++ b/akka-actor/src/main/resources/reference.conf @@ -226,6 +226,12 @@ akka { messages-per-resize = 10 } } + + /IO-DNS/inet-address { + mailbox = "unbounded" + router = "consistent-hashing-pool" + nr-of-instances = 4 + } } default-dispatcher { @@ -729,6 +735,28 @@ akka { management-dispatcher = "akka.actor.default-dispatcher" } + dns { + # Fully qualified config path which holds the dispatcher configuration + # for the manager and resolver router actors. + # For actual router configuration see akka.actor.deployment./IO-DNS/* + dispatcher = "akka.actor.default-dispatcher" + + # Name of the subconfig at path akka.io.dns, see inet-address below + resolver = "inet-address" + + inet-address { + # Must implement akka.io.DnsProvider + provider-object = "akka.io.InetAddressDnsProvider" + + # These TTLs are set to default java 6 values + positive-ttl = 30s + negative-ttl = 10s + + # How often to sweep out expired cache entries. + # Note that this interval has nothing to do with TTLs + cache-cleanup-interval = 120s + } + } } diff --git a/akka-actor/src/main/scala/akka/io/Dns.scala b/akka-actor/src/main/scala/akka/io/Dns.scala new file mode 100644 index 0000000000..fb621334e4 --- /dev/null +++ b/akka-actor/src/main/scala/akka/io/Dns.scala @@ -0,0 +1,91 @@ +package akka.io + +import java.net.{ Inet4Address, Inet6Address, InetAddress, UnknownHostException } + +import akka.actor._ +import akka.routing.ConsistentHashingRouter.ConsistentHashable +import com.typesafe.config.Config + +import scala.collection.{ breakOut, immutable } + +abstract class Dns { + def cached(name: String): Option[Dns.Resolved] = None + def resolve(name: String)(system: ActorSystem, sender: ActorRef): Option[Dns.Resolved] = { + val ret = cached(name) + if (ret.isEmpty) + IO(Dns)(system).tell(Dns.Resolve(name), sender) + ret + } +} + +object Dns extends ExtensionId[DnsExt] with ExtensionIdProvider { + sealed trait Command + + case class Resolve(name: String) extends Command with ConsistentHashable { + override def consistentHashKey = name + } + + case class Resolved(name: String, ipv4: immutable.Seq[Inet4Address], ipv6: immutable.Seq[Inet6Address]) extends Command { + val addrOption: Option[InetAddress] = ipv4.headOption orElse ipv6.headOption + + @throws[UnknownHostException] + def addr: InetAddress = addrOption match { + case Some(addr) ⇒ addr + case None ⇒ throw new UnknownHostException(name) + } + } + + object Resolved { + def apply(name: String, addresses: Iterable[InetAddress]): Resolved = { + val ipv4: immutable.Seq[Inet4Address] = addresses.collect({ + case a: Inet4Address ⇒ a + })(breakOut) + val ipv6: immutable.Seq[Inet6Address] = addresses.collect({ + case a: Inet6Address ⇒ a + })(breakOut) + Resolved(name, ipv4, ipv6) + } + } + + def cached(name: String)(system: ActorSystem): Option[Resolved] = { + Dns(system).cache.cached(name) + } + + def resolve(name: String)(system: ActorSystem, sender: ActorRef): Option[Resolved] = { + Dns(system).cache.resolve(name)(system, sender) + } + + override def lookup() = Dns + + override def createExtension(system: ExtendedActorSystem): DnsExt = new DnsExt(system) + + /** + * Java API: retrieve the Udp extension for the given system. + */ + override def get(system: ActorSystem): DnsExt = super.get(system) +} + +class DnsExt(system: ExtendedActorSystem) extends IO.Extension { + val Settings = new Settings(system.settings.config.getConfig("akka.io.dns")) + + class Settings private[DnsExt] (_config: Config) { + + import _config._ + + val Dispatcher: String = getString("dispatcher") + val Resolver: String = getString("resolver") + val ResolverConfig: Config = getConfig(Resolver) + val ProviderObjectName: String = ResolverConfig.getString("provider-object") + } + + val provider: DnsProvider = system.dynamicAccess.getClassFor[DnsProvider](Settings.ProviderObjectName).get.newInstance() + val cache: Dns = provider.cache + + val manager: ActorRef = { + system.systemActorOf( + props = Props(classOf[SimpleDnsManager], this).withDeploy(Deploy.local).withDispatcher(Settings.Dispatcher), + name = "IO-DNS") + } + + def getResolver: ActorRef = manager +} diff --git a/akka-actor/src/main/scala/akka/io/DnsProvider.scala b/akka-actor/src/main/scala/akka/io/DnsProvider.scala new file mode 100644 index 0000000000..9fe094bce0 --- /dev/null +++ b/akka-actor/src/main/scala/akka/io/DnsProvider.scala @@ -0,0 +1,9 @@ +package akka.io + +import akka.actor.Actor + +trait DnsProvider { + def cache: Dns + def actorClass: Class[_ <: Actor] + def managerClass: Class[_ <: Actor] +} diff --git a/akka-actor/src/main/scala/akka/io/InetAddressDnsProvider.scala b/akka-actor/src/main/scala/akka/io/InetAddressDnsProvider.scala new file mode 100644 index 0000000000..c31fdf457f --- /dev/null +++ b/akka-actor/src/main/scala/akka/io/InetAddressDnsProvider.scala @@ -0,0 +1,7 @@ +package akka.io + +class InetAddressDnsProvider extends DnsProvider { + override def cache: Dns = new SimpleDnsCache() + override def actorClass = classOf[InetAddressDnsResolver] + override def managerClass = classOf[SimpleDnsManager] +} diff --git a/akka-actor/src/main/scala/akka/io/InetAddressDnsResolver.scala b/akka-actor/src/main/scala/akka/io/InetAddressDnsResolver.scala new file mode 100644 index 0000000000..a288b56b0a --- /dev/null +++ b/akka-actor/src/main/scala/akka/io/InetAddressDnsResolver.scala @@ -0,0 +1,33 @@ +package akka.io + +import java.net.{ UnknownHostException, InetAddress } +import java.util.concurrent.TimeUnit + +import akka.actor.Actor +import com.typesafe.config.Config + +import scala.collection.immutable + +class InetAddressDnsResolver(cache: SimpleDnsCache, config: Config) extends Actor { + val positiveTtl = config.getDuration("positive-ttl", TimeUnit.MILLISECONDS) + val negativeTtl = config.getDuration("negative-ttl", TimeUnit.MILLISECONDS) + + override def receive = { + case Dns.Resolve(name) ⇒ + val answer = cache.cached(name) match { + case Some(a) ⇒ a + case None ⇒ + try { + val answer = Dns.Resolved(name, InetAddress.getAllByName(name)) + cache.put(answer, positiveTtl) + answer + } catch { + case e: UnknownHostException ⇒ + val answer = Dns.Resolved(name, immutable.Seq.empty, immutable.Seq.empty) + cache.put(answer, negativeTtl) + answer + } + } + sender() ! answer + } +} diff --git a/akka-actor/src/main/scala/akka/io/SimpleDnsCache.scala b/akka-actor/src/main/scala/akka/io/SimpleDnsCache.scala new file mode 100644 index 0000000000..6ee9aab3ca --- /dev/null +++ b/akka-actor/src/main/scala/akka/io/SimpleDnsCache.scala @@ -0,0 +1,94 @@ +package akka.io + +import java.util.concurrent.atomic.AtomicReference +import akka.io.Dns.Resolved + +import scala.annotation.tailrec +import scala.collection.immutable + +private[io] sealed trait PeriodicCacheCleanup { + def cleanup(): Unit +} + +class SimpleDnsCache extends Dns with PeriodicCacheCleanup { + import akka.io.SimpleDnsCache._ + + private val cache = new AtomicReference(new Cache( + immutable.SortedSet()(ExpiryEntryOrdering), + immutable.Map(), clock)) + + private val nanoBase = System.nanoTime() + + override def cached(name: String): Option[Resolved] = { + cache.get().get(name) + } + + protected def clock(): Long = { + val now = System.nanoTime() + if (now - nanoBase < 0) 0 + else (now - nanoBase) / 1000000 + } + + @tailrec + private[io] final def put(r: Resolved, ttlMillis: Long): Unit = { + val c = cache.get() + if (!cache.compareAndSet(c, c.put(r, ttlMillis))) + put(r, ttlMillis) + } + + @tailrec + override final def cleanup(): Unit = { + val c = cache.get() + if (!cache.compareAndSet(c, c.cleanup())) + cleanup() + } +} + +object SimpleDnsCache { + private class Cache(queue: immutable.SortedSet[ExpiryEntry], cache: immutable.Map[String, CacheEntry], clock: () ⇒ Long) { + def get(name: String): Option[Resolved] = { + for { + e ← cache.get(name) + if e.isValid(clock()) + } yield e.answer + } + + def put(answer: Resolved, ttlMillis: Long): Cache = { + val until = clock() + ttlMillis + + new Cache( + queue + new ExpiryEntry(answer.name, until), + cache + (answer.name -> CacheEntry(answer, until)), + clock) + } + + def cleanup(): Cache = { + val now = clock() + var q = queue + var c = cache + while (q.nonEmpty && !q.head.isValid(now)) { + val minEntry = q.head + val name = minEntry.name + q -= minEntry + if (c.get(name).filterNot(_.isValid(now)).isDefined) + c -= name + } + new Cache(q, c, clock) + } + } + + private case class CacheEntry(answer: Dns.Resolved, until: Long) { + def isValid(clock: Long): Boolean = clock < until + } + + private class ExpiryEntry(val name: String, val until: Long) extends Ordered[ExpiryEntry] { + def isValid(clock: Long): Boolean = clock < until + override def compare(that: ExpiryEntry): Int = -until.compareTo(that.until) + } + + private object ExpiryEntryOrdering extends Ordering[ExpiryEntry] { + override def compare(x: ExpiryEntry, y: ExpiryEntry): Int = { + x.until.compareTo(y.until) + } + } +} diff --git a/akka-actor/src/main/scala/akka/io/SimpleDnsManager.scala b/akka-actor/src/main/scala/akka/io/SimpleDnsManager.scala new file mode 100644 index 0000000000..856fb99191 --- /dev/null +++ b/akka-actor/src/main/scala/akka/io/SimpleDnsManager.scala @@ -0,0 +1,42 @@ +package akka.io + +import java.util.concurrent.TimeUnit + +import akka.actor.{ ActorLogging, Actor, Deploy, Props } +import akka.dispatch.{ RequiresMessageQueue, UnboundedMessageQueueSemantics } +import akka.routing.FromConfig + +import scala.concurrent.duration.Duration + +class SimpleDnsManager(val ext: DnsExt) extends Actor with RequiresMessageQueue[UnboundedMessageQueueSemantics] with ActorLogging { + + import context._ + + private val resolver = actorOf(FromConfig.props(Props(ext.provider.actorClass, ext.cache, ext.Settings.ResolverConfig).withDeploy(Deploy.local).withDispatcher(ext.Settings.Dispatcher)), ext.Settings.Resolver) + private val cacheCleanup = ext.cache match { + case cleanup: PeriodicCacheCleanup ⇒ Some(cleanup) + case _ ⇒ None + } + + private val cleanupTimer = cacheCleanup map { _ ⇒ + val interval = Duration(ext.Settings.ResolverConfig.getDuration("cache-cleanup-interval", TimeUnit.MILLISECONDS), TimeUnit.MILLISECONDS) + system.scheduler.schedule(interval, interval, self, SimpleDnsManager.CacheCleanup) + } + + override def receive = { + case r @ Dns.Resolve(name) ⇒ + log.debug("Resolution request for {} from {}", name, sender()) + resolver.forward(r) + case SimpleDnsManager.CacheCleanup ⇒ + for (c ← cacheCleanup) + c.cleanup() + } + + override def postStop(): Unit = { + for (t ← cleanupTimer) t.cancel() + } +} + +object SimpleDnsManager { + private case object CacheCleanup +} diff --git a/akka-actor/src/main/scala/akka/io/TcpOutgoingConnection.scala b/akka-actor/src/main/scala/akka/io/TcpOutgoingConnection.scala index 8c87dee287..ab83b067be 100644 --- a/akka-actor/src/main/scala/akka/io/TcpOutgoingConnection.scala +++ b/akka-actor/src/main/scala/akka/io/TcpOutgoingConnection.scala @@ -4,6 +4,7 @@ package akka.io +import java.net.InetSocketAddress import java.nio.channels.{ SelectionKey, SocketChannel } import scala.util.control.NonFatal import scala.collection.immutable @@ -26,6 +27,7 @@ private[io] class TcpOutgoingConnection(_tcp: TcpExt, connect: Connect) extends TcpConnection(_tcp, SocketChannel.open().configureBlocking(false).asInstanceOf[SocketChannel], connect.pullMode) { + import context._ import connect._ context.watch(commander) // sign death pact @@ -49,17 +51,40 @@ private[io] class TcpOutgoingConnection(_tcp: TcpExt, def receive: Receive = { case registration: ChannelRegistration ⇒ - log.debug("Attempting connection to [{}]", remoteAddress) reportConnectFailure { - if (channel.connect(remoteAddress)) - completeConnect(registration, commander, options) - else { - registration.enableInterest(SelectionKey.OP_CONNECT) - context.become(connecting(registration, tcp.Settings.FinishConnectRetries)) + if (remoteAddress.isUnresolved) { + log.debug("Resolving {} before connecting", remoteAddress.getHostName) + Dns.resolve(remoteAddress.getHostName)(system, self) match { + case None ⇒ + context.become(resolving(registration)) + case Some(resolved) ⇒ + register(new InetSocketAddress(resolved.addr, remoteAddress.getPort), registration) + } + } else { + register(remoteAddress, registration) } } } + def resolving(registration: ChannelRegistration): Receive = { + case resolved: Dns.Resolved ⇒ + reportConnectFailure { + register(new InetSocketAddress(resolved.addr, remoteAddress.getPort), registration) + } + } + + def register(address: InetSocketAddress, registration: ChannelRegistration): Unit = { + reportConnectFailure { + log.debug("Attempting connection to [{}]", address) + if (channel.connect(address)) + completeConnect(registration, commander, options) + else { + registration.enableInterest(SelectionKey.OP_CONNECT) + context.become(connecting(registration, tcp.Settings.FinishConnectRetries)) + } + } + } + def connecting(registration: ChannelRegistration, remainingFinishConnectRetries: Int): Receive = { { case ChannelConnectable ⇒ diff --git a/akka-actor/src/main/scala/akka/io/UdpConnection.scala b/akka-actor/src/main/scala/akka/io/UdpConnection.scala index dc63992123..7111742971 100644 --- a/akka-actor/src/main/scala/akka/io/UdpConnection.scala +++ b/akka-actor/src/main/scala/akka/io/UdpConnection.scala @@ -3,6 +3,7 @@ */ package akka.io +import java.net.InetSocketAddress import java.nio.ByteBuffer import java.nio.channels.DatagramChannel import java.nio.channels.SelectionKey._ @@ -31,25 +32,38 @@ private[io] class UdpConnection(udpConn: UdpConnectedExt, def writePending = pendingSend ne null context.watch(handler) // sign death pact - val channel = { - val datagramChannel = DatagramChannel.open - datagramChannel.configureBlocking(false) - val socket = datagramChannel.socket - options.foreach(_.beforeBind(datagramChannel)) - try { - localAddress foreach socket.bind - datagramChannel.connect(remoteAddress) - } catch { - case NonFatal(e) ⇒ - log.debug("Failure while connecting UDP channel to remote address [{}] local address [{}]: {}", - remoteAddress, localAddress.getOrElse("undefined"), e) - commander ! CommandFailed(connect) - context.stop(self) + var channel: DatagramChannel = null + + if (remoteAddress.isUnresolved) { + Dns.resolve(remoteAddress.getHostName)(context.system, self) match { + case Some(r) ⇒ + doConnect(new InetSocketAddress(r.addr, remoteAddress.getPort)) + case None ⇒ + context.become(resolving(), discardOld = true) } - datagramChannel + } else { + doConnect(remoteAddress) + } + + def resolving(): Receive = { + case r: Dns.Resolved ⇒ + reportConnectFailure { + doConnect(new InetSocketAddress(r.addr, remoteAddress.getPort)) + } + } + + def doConnect(address: InetSocketAddress): Unit = { + reportConnectFailure { + channel = DatagramChannel.open + channel.configureBlocking(false) + val socket = channel.socket + options.foreach(_.beforeBind(channel)) + localAddress foreach socket.bind + channel.connect(remoteAddress) + channelRegistry.register(channel, OP_READ) + } + log.debug("Successfully connected to [{}]", remoteAddress) } - channelRegistry.register(channel, OP_READ) - log.debug("Successfully connected to [{}]", remoteAddress) def receive = { case registration: ChannelRegistration ⇒ @@ -130,4 +144,16 @@ private[io] class UdpConnection(udpConn: UdpConnectedExt, case NonFatal(e) ⇒ log.debug("Error closing DatagramChannel: {}", e) } } + + private def reportConnectFailure(thunk: ⇒ Unit): Unit = { + try { + thunk + } catch { + case NonFatal(e) ⇒ + log.debug("Failure while connecting UDP channel to remote address [{}] local address [{}]: {}", + remoteAddress, localAddress.getOrElse("undefined"), e) + commander ! CommandFailed(connect) + context.stop(self) + } + } } diff --git a/akka-actor/src/main/scala/akka/io/WithUdpSend.scala b/akka-actor/src/main/scala/akka/io/WithUdpSend.scala index 6d5c1b5bbd..1c0a2b9f26 100644 --- a/akka-actor/src/main/scala/akka/io/WithUdpSend.scala +++ b/akka-actor/src/main/scala/akka/io/WithUdpSend.scala @@ -3,11 +3,14 @@ */ package akka.io +import java.net.InetSocketAddress import java.nio.channels.{ SelectionKey, DatagramChannel } import akka.actor.{ ActorRef, ActorLogging, Actor } import akka.io.Udp.{ CommandFailed, Send } import akka.io.SelectionHandler._ +import scala.util.control.NonFatal + /** * INTERNAL API */ @@ -39,7 +42,26 @@ private[io] trait WithUdpSend { case send: Send ⇒ pendingSend = send pendingCommander = sender() - doSend(registration) + if (send.target.isUnresolved) { + Dns.resolve(send.target.getHostName)(context.system, self) match { + case Some(r) ⇒ + try { + pendingSend = pendingSend.copy(target = new InetSocketAddress(r.addr, pendingSend.target.getPort)) + doSend(registration) + } catch { + case NonFatal(e) ⇒ + sender() ! CommandFailed(send) + log.debug("Failure while sending UDP datagram to remote address [{}]: {}", + send.target, e) + retriedSend = false + pendingSend = null + pendingCommander = null + } + case None ⇒ + } + } else { + doSend(registration) + } case ChannelWritable ⇒ if (hasWritePending) doSend(registration) }