diff --git a/akka-actor-tests/src/test/scala/akka/io/InetAddressDnsResolverSpec.scala b/akka-actor-tests/src/test/scala/akka/io/InetAddressDnsResolverSpec.scala index 4fa8c2eef4..1cdeb04141 100644 --- a/akka-actor-tests/src/test/scala/akka/io/InetAddressDnsResolverSpec.scala +++ b/akka-actor-tests/src/test/scala/akka/io/InetAddressDnsResolverSpec.scala @@ -68,6 +68,21 @@ class InetAddressDnsResolverSpec extends AkkaSpec(""" } } } + + "use Forever when system Property (or the security property) value is lower than zero" in { + withNewSecurityProperty("networkaddress.cache.negative.ttl", "-1") { + withNewSystemProperty("sun.net.inetaddr.negative.ttl", "") { + dnsResolver.negativeTtl shouldBe Long.MaxValue + } + } + } + "use Never when system Property (or the security property) value is zero" in { + withNewSecurityProperty("networkaddress.cache.negative.ttl", "0") { + withNewSystemProperty("sun.net.inetaddr.negative.ttl", "") { + dnsResolver.negativeTtl shouldBe 0 + } + } + } } private def secondsToMillis(seconds: Int) = TimeUnit.SECONDS.toMillis(seconds) @@ -102,3 +117,31 @@ class InetAddressDnsResolverSpec extends AkkaSpec(""" } } +class InetAddressDnsResolverConfigSpec extends AkkaSpec( + """ + akka.io.dns.inet-address.positive-ttl = forever + akka.io.dns.inet-address.negative-ttl = never + akka.actor.serialize-creators = on + """) { + thisSpecs ⇒ + + "The DNS resolver parsed ttl's" must { + "use ttl=Long.MaxValue if user provides 'forever' " in { + dnsResolver.positiveTtl shouldBe Long.MaxValue + } + + "use ttl=0 if user provides 'never' " in { + dnsResolver.negativeTtl shouldBe 0 + } + + } + + private def dnsResolver = { + val actorRef = TestActorRef[InetAddressDnsResolver](Props( + classOf[InetAddressDnsResolver], + new SimpleDnsCache(), + system.settings.config.getConfig("akka.io.dns.inet-address") + )) + actorRef.underlyingActor + } +} diff --git a/akka-actor-tests/src/test/scala/akka/io/SimpleDnsCacheSpec.scala b/akka-actor-tests/src/test/scala/akka/io/SimpleDnsCacheSpec.scala index ad302dcffe..87eca27cd6 100644 --- a/akka-actor-tests/src/test/scala/akka/io/SimpleDnsCacheSpec.scala +++ b/akka-actor-tests/src/test/scala/akka/io/SimpleDnsCacheSpec.scala @@ -7,8 +7,11 @@ package akka.io import java.net.InetAddress import java.util.concurrent.atomic.AtomicLong +import akka.io.dns.CachePolicy.Ttl import org.scalatest.{ Matchers, WordSpec } +import scala.concurrent.duration._ + class SimpleDnsCacheSpec extends WordSpec with Matchers { "Cache" should { "not reply with expired but not yet swept out entries" in { @@ -17,7 +20,7 @@ class SimpleDnsCacheSpec extends WordSpec with Matchers { override protected def clock() = localClock.get } val cacheEntry = Dns.Resolved("test.local", Seq(InetAddress.getByName("127.0.0.1"))) - cache.put(cacheEntry, 5000) + cache.put(cacheEntry, Ttl.fromPositive(5000.millis)) cache.cached("test.local") should ===(Some(cacheEntry)) localClock.set(4999) @@ -32,7 +35,7 @@ class SimpleDnsCacheSpec extends WordSpec with Matchers { override protected def clock() = localClock.get } val cacheEntry = Dns.Resolved("test.local", Seq(InetAddress.getByName("127.0.0.1"))) - cache.put(cacheEntry, 5000) + cache.put(cacheEntry, Ttl.fromPositive(5000.millis)) cache.cached("test.local") should ===(Some(cacheEntry)) localClock.set(5000) diff --git a/akka-actor-tests/src/test/scala/akka/io/dns/AsyncDnsResolverIntegrationSpec.scala b/akka-actor-tests/src/test/scala/akka/io/dns/AsyncDnsResolverIntegrationSpec.scala index d928ec2bd5..d3dd4a8084 100644 --- a/akka-actor-tests/src/test/scala/akka/io/dns/AsyncDnsResolverIntegrationSpec.scala +++ b/akka-actor-tests/src/test/scala/akka/io/dns/AsyncDnsResolverIntegrationSpec.scala @@ -8,6 +8,7 @@ import java.net.InetAddress import akka.io.dns.DnsProtocol.{ Ip, RequestType, Srv } import akka.io.{ Dns, IO } +import CachePolicy.Ttl import akka.pattern.ask import akka.testkit.{ AkkaSpec, SocketUtil } import akka.util.Timeout @@ -124,8 +125,8 @@ class AsyncDnsResolverIntegrationSpec extends AkkaSpec( answer.name shouldEqual name answer.records.collect { case r: SRVRecord ⇒ r }.toSet shouldEqual Set( - SRVRecord("service.tcp.foo.test", 86400, 10, 65534, 5060, "a-single.foo.test"), - SRVRecord("service.tcp.foo.test", 86400, 65533, 40, 65535, "a-double.foo.test") + SRVRecord("service.tcp.foo.test", Ttl.fromPositive(86400.seconds), 10, 65534, 5060, "a-single.foo.test"), + SRVRecord("service.tcp.foo.test", Ttl.fromPositive(86400.seconds), 65533, 40, 65535, "a-double.foo.test") ) } diff --git a/akka-actor-tests/src/test/scala/akka/io/dns/internal/AsyncDnsManagerSpec.scala b/akka-actor-tests/src/test/scala/akka/io/dns/internal/AsyncDnsManagerSpec.scala index 7ac54d74b0..d52bdc733d 100644 --- a/akka-actor-tests/src/test/scala/akka/io/dns/internal/AsyncDnsManagerSpec.scala +++ b/akka-actor-tests/src/test/scala/akka/io/dns/internal/AsyncDnsManagerSpec.scala @@ -4,13 +4,15 @@ package akka.io.dns.internal -import java.net.{ Inet6Address, InetAddress } +import java.net.InetAddress + +import akka.io.Dns +import akka.io.dns.AAAARecord +import akka.io.dns.DnsProtocol.{ Resolve, Resolved } +import akka.io.dns.CachePolicy.Ttl +import akka.testkit.{ AkkaSpec, ImplicitSender } import scala.collection.immutable.Seq -import akka.io.Dns -import akka.io.dns.{ AAAARecord, ResourceRecord } -import akka.io.dns.DnsProtocol.{ Resolve, Resolved } -import akka.testkit.{ AkkaSpec, ImplicitSender } class AsyncDnsManagerSpec extends AkkaSpec( """ @@ -30,7 +32,7 @@ class AsyncDnsManagerSpec extends AkkaSpec( "support ipv6" in { dns ! Resolve("::1") // ::1 will short circuit the resolution - val Resolved("::1", Seq(AAAARecord("::1", Int.MaxValue, _)), Nil) = expectMsgType[Resolved] + val Resolved("::1", Seq(AAAARecord("::1", Ttl.effectivelyForever, _)), Nil) = expectMsgType[Resolved] } "support ipv6 also using the old protocol" in { diff --git a/akka-actor-tests/src/test/scala/akka/io/dns/internal/AsyncDnsResolverSpec.scala b/akka-actor-tests/src/test/scala/akka/io/dns/internal/AsyncDnsResolverSpec.scala index 687e53f494..0be17cef44 100644 --- a/akka-actor-tests/src/test/scala/akka/io/dns/internal/AsyncDnsResolverSpec.scala +++ b/akka-actor-tests/src/test/scala/akka/io/dns/internal/AsyncDnsResolverSpec.scala @@ -8,15 +8,16 @@ import java.net.{ Inet6Address, InetAddress } import akka.actor.Status.Failure import akka.actor.{ ActorRef, ExtendedActorSystem, Props } +import akka.io.dns.DnsProtocol._ +import akka.io.dns.internal.AsyncDnsResolver.ResolveFailedException +import akka.io.dns.CachePolicy.Ttl +import akka.io.dns.internal.DnsClient.{ Answer, Question4, Question6, SrvQuestion } import akka.io.dns.{ AAAARecord, ARecord, DnsSettings, SRVRecord } import akka.testkit.{ AkkaSpec, ImplicitSender, TestProbe } import com.typesafe.config.ConfigFactory -import akka.io.dns.DnsProtocol._ -import akka.io.dns.internal.AsyncDnsResolver.ResolveFailedException -import akka.io.dns.internal.DnsClient.{ Answer, Question4, Question6, SrvQuestion } -import scala.concurrent.duration._ import scala.collection.{ immutable ⇒ im } +import scala.concurrent.duration._ class AsyncDnsResolverSpec extends AkkaSpec( """ @@ -66,10 +67,11 @@ class AsyncDnsResolverSpec extends AkkaSpec( val r = resolver(List(dnsClient1.ref)) r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = true)) dnsClient1.expectMsg(Question4(1, "cats.com")) - val ipv4Record = ARecord("cats.com", 100, InetAddress.getByName("127.0.0.1")) + val ttl = Ttl.fromPositive(100.seconds) + val ipv4Record = ARecord("cats.com", ttl, InetAddress.getByName("127.0.0.1")) dnsClient1.reply(Answer(1, im.Seq(ipv4Record))) dnsClient1.expectMsg(Question6(2, "cats.com")) - val ipv6Record = AAAARecord("cats.com", 100, InetAddress.getByName("::1").asInstanceOf[Inet6Address]) + val ipv6Record = AAAARecord("cats.com", ttl, InetAddress.getByName("::1").asInstanceOf[Inet6Address]) dnsClient1.reply(Answer(2, im.Seq(ipv6Record))) expectMsg(Resolved("cats.com", im.Seq(ipv4Record, ipv6Record))) } @@ -103,7 +105,7 @@ class AsyncDnsResolverSpec extends AkkaSpec( dnsClient1.expectNoMessage(50.millis) val answer = expectMsgType[Resolved] answer.records.collect { case r: ARecord ⇒ r }.toSet shouldEqual Set( - ARecord("127.0.0.1", Int.MaxValue, InetAddress.getByName("127.0.0.1")) + ARecord("127.0.0.1", Ttl.effectivelyForever, InetAddress.getByName("127.0.0.1")) ) } @@ -114,7 +116,7 @@ class AsyncDnsResolverSpec extends AkkaSpec( r ! Resolve(name) dnsClient1.expectNoMessage(50.millis) val answer = expectMsgType[Resolved] - val Seq(AAAARecord("1:2:3:0:0:0:0:0", Int.MaxValue, _)) = answer.records.collect { case r: AAAARecord ⇒ r } + val Seq(AAAARecord("1:2:3:0:0:0:0:0", Ttl.effectivelyForever, _)) = answer.records.collect { case r: AAAARecord ⇒ r } } "return additional records for SRV requests" in { @@ -123,8 +125,8 @@ class AsyncDnsResolverSpec extends AkkaSpec( val r = resolver(List(dnsClient1.ref, dnsClient2.ref)) r ! Resolve("cats.com", Srv) dnsClient1.expectMsg(SrvQuestion(1, "cats.com")) - val srvRecs = im.Seq(SRVRecord("cats.com", 5000, 1, 1, 1, "a.cats.com")) - val aRecs = im.Seq(ARecord("a.cats.com", 1, InetAddress.getByName("127.0.0.1"))) + val srvRecs = im.Seq(SRVRecord("cats.com", Ttl.fromPositive(5000.seconds), 1, 1, 1, "a.cats.com")) + val aRecs = im.Seq(ARecord("a.cats.com", Ttl.fromPositive(1.seconds), InetAddress.getByName("127.0.0.1"))) dnsClient1.reply(Answer(1, srvRecs, aRecs)) dnsClient2.expectNoMessage(50.millis) expectMsg(Resolved("cats.com", srvRecs, aRecs)) diff --git a/akka-actor/src/main/mima-filters/2.5.17.backwards.excludes b/akka-actor/src/main/mima-filters/2.5.17.backwards.excludes index 858cad6f04..0a83e47279 100644 --- a/akka-actor/src/main/mima-filters/2.5.17.backwards.excludes +++ b/akka-actor/src/main/mima-filters/2.5.17.backwards.excludes @@ -14,4 +14,3 @@ ProblemFilters.exclude[DirectMissingMethodProblem]("akka.io.dns.CNameRecord.writ ProblemFilters.exclude[DirectMissingMethodProblem]("akka.io.dns.AAAARecord.write") ProblemFilters.exclude[DirectMissingMethodProblem]("akka.io.dns.ResourceRecord.write") ProblemFilters.exclude[DirectMissingMethodProblem]("akka.io.dns.SRVRecord.write") - diff --git a/akka-actor/src/main/mima-filters/2.5.18.backwards.excludes b/akka-actor/src/main/mima-filters/2.5.18.backwards.excludes new file mode 100644 index 0000000000..977f3e2703 --- /dev/null +++ b/akka-actor/src/main/mima-filters/2.5.18.backwards.excludes @@ -0,0 +1,49 @@ +# Replaces DNS TTL primitive types with Duration #25850 +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.io.SimpleDnsCache.put") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.io.SimpleDnsCache#Cache.put") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.io.dns.internal.AsyncDnsCache.put") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.io.dns.ResourceRecord.this") +ProblemFilters.exclude[DirectMissingMethodProblem]("akka.io.dns.ResourceRecord.ttlInSeconds") +ProblemFilters.exclude[IncompatibleResultTypeProblem]("akka.io.dns.ResourceRecord.ttl") +# +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.io.dns.AAAARecord.apply") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.io.dns.AAAARecord.copy") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.io.dns.AAAARecord.parseBody") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.io.dns.AAAARecord.this") +ProblemFilters.exclude[IncompatibleResultTypeProblem]("akka.io.dns.AAAARecord.copy$default$2") +ProblemFilters.exclude[DirectMissingMethodProblem]("akka.io.dns.AAAARecord.ttlInSeconds") +ProblemFilters.exclude[IncompatibleResultTypeProblem]("akka.io.dns.AAAARecord.ttl") +# +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.io.dns.ARecord.apply") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.io.dns.ARecord.copy") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.io.dns.ARecord.parseBody") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.io.dns.ARecord.this") +ProblemFilters.exclude[IncompatibleResultTypeProblem]("akka.io.dns.ARecord.copy$default$2") +ProblemFilters.exclude[DirectMissingMethodProblem]("akka.io.dns.ARecord.ttlInSeconds") +ProblemFilters.exclude[IncompatibleResultTypeProblem]("akka.io.dns.ARecord.ttl") +# +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.io.dns.CNameRecord.apply") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.io.dns.CNameRecord.copy") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.io.dns.CNameRecord.parseBody") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.io.dns.CNameRecord.this") +ProblemFilters.exclude[IncompatibleResultTypeProblem]("akka.io.dns.CNameRecord.copy$default$2") +ProblemFilters.exclude[DirectMissingMethodProblem]("akka.io.dns.CNameRecord.ttlInSeconds") +ProblemFilters.exclude[IncompatibleResultTypeProblem]("akka.io.dns.CNameRecord.ttl") +# +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.io.dns.SRVRecord.apply") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.io.dns.SRVRecord.copy") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.io.dns.SRVRecord.parseBody") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.io.dns.SRVRecord.this") +ProblemFilters.exclude[IncompatibleResultTypeProblem]("akka.io.dns.SRVRecord.copy$default$2") +ProblemFilters.exclude[DirectMissingMethodProblem]("akka.io.dns.SRVRecord.ttlInSeconds") +ProblemFilters.exclude[IncompatibleResultTypeProblem]("akka.io.dns.SRVRecord.ttl") +# +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.io.dns.UnknownRecord.apply") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.io.dns.UnknownRecord.copy") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.io.dns.UnknownRecord.parseBody") +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.io.dns.UnknownRecord.this") +ProblemFilters.exclude[IncompatibleResultTypeProblem]("akka.io.dns.UnknownRecord.copy$default$2") +ProblemFilters.exclude[DirectMissingMethodProblem]("akka.io.dns.UnknownRecord.ttlInSeconds") +ProblemFilters.exclude[IncompatibleResultTypeProblem]("akka.io.dns.UnknownRecord.ttl") + + diff --git a/akka-actor/src/main/scala/akka/io/InetAddressDnsResolver.scala b/akka-actor/src/main/scala/akka/io/InetAddressDnsResolver.scala index 6e9ce373c8..8e3881f593 100644 --- a/akka-actor/src/main/scala/akka/io/InetAddressDnsResolver.scala +++ b/akka-actor/src/main/scala/akka/io/InetAddressDnsResolver.scala @@ -8,16 +8,17 @@ import java.net.{ InetAddress, UnknownHostException } import java.security.Security import java.util.concurrent.TimeUnit -import akka.actor.Actor +import akka.io.dns.CachePolicy._ +import akka.actor.{ Actor, ActorLogging } +import akka.util.Helpers.Requiring import com.typesafe.config.Config import scala.collection.immutable -import akka.util.Helpers.Requiring - +import scala.concurrent.duration._ import scala.util.Try /** Respects the settings that can be set on the Java runtime via parameters. */ -class InetAddressDnsResolver(cache: SimpleDnsCache, config: Config) extends Actor { +class InetAddressDnsResolver(cache: SimpleDnsCache, config: Config) extends Actor with ActorLogging { // Controls the cache policy for successful lookups only private final val CachePolicyProp = "networkaddress.cache.ttl" @@ -29,41 +30,61 @@ class InetAddressDnsResolver(cache: SimpleDnsCache, config: Config) extends Acto // Deprecated JVM property key, keeping for legacy compatibility; replaced by NegativeCachePolicyProp private final val NegativeCachePolicyPropFallback = "sun.net.inetaddr.negative.ttl" - // default values (-1 and 0 are magic numbers, trust them) - private final val Forever = -1 - private final val Never = 0 - private final val DefaultPositive = 30 + private final val DefaultPositive = Ttl.fromPositive(30.seconds) - private lazy val cachePolicy: Int = { - val n = Try(Security.getProperty(CachePolicyProp).toInt) + private lazy val defaultCachePolicy: CachePolicy = + Try(Security.getProperty(CachePolicyProp).toInt) .orElse(Try(System.getProperty(CachePolicyPropFallback).toInt)) - .getOrElse(DefaultPositive) // default - if (n < 0) Forever else n - } + .map(parsePolicy) + .getOrElse { + log.warning("No caching TTL defined. Using default value {}.", DefaultPositive) + DefaultPositive + } - private lazy val negativeCachePolicy = { - val n = Try(Security.getProperty(NegativeCachePolicyProp).toInt) + private lazy val defaultNegativeCachePolicy: CachePolicy = + Try(Security.getProperty(NegativeCachePolicyProp).toInt) .orElse(Try(System.getProperty(NegativeCachePolicyPropFallback).toInt)) - .getOrElse(0) // default - if (n < 0) Forever else n + .map(parsePolicy) + .getOrElse { + log.warning("No negative caching TTL defined. Using default value {}.", Never) + Never + } + + private def parsePolicy(n: Int): CachePolicy = { + n match { + case 0 ⇒ Never + case x if x < 0 ⇒ Forever + case x ⇒ Ttl.fromPositive(x.seconds) + } } - private def getTtl(path: String, positive: Boolean): Long = + private def getTtl(path: String, positive: Boolean): CachePolicy = config.getString(path) match { - case "default" ⇒ - (if (positive) cachePolicy else negativeCachePolicy) match { - case Never ⇒ Never - case n if n > 0 ⇒ TimeUnit.SECONDS.toMillis(n) - case _ ⇒ Long.MaxValue // forever if negative - } - case "forever" ⇒ Long.MaxValue + case "default" ⇒ if (positive) defaultCachePolicy else defaultNegativeCachePolicy + case "forever" ⇒ Forever case "never" ⇒ Never - case _ ⇒ config.getDuration(path, TimeUnit.MILLISECONDS) - .requiring(_ > 0, s"akka.io.dns.$path must be 'default', 'forever', 'never' or positive duration") + case _ ⇒ { + val finiteTtl = config + .getDuration(path, TimeUnit.SECONDS) + .requiring(_ > 0, s"akka.io.dns.$path must be 'default', 'forever', 'never' or positive duration") + Ttl.fromPositive(finiteTtl.seconds) + } } - val positiveTtl: Long = getTtl("positive-ttl", positive = true) - val negativeTtl: Long = getTtl("negative-ttl", positive = false) + val positiveCachePolicy: CachePolicy = getTtl("positive-ttl", positive = true) + val negativeCachePolicy: CachePolicy = getTtl("negative-ttl", positive = false) + @deprecated("Use positiveCacheDuration instead", "2.5.17") + val positiveTtl: Long = toLongTtl(positiveCachePolicy) + @deprecated("Use negativeCacheDuration instead", "2.5.17") + val negativeTtl: Long = toLongTtl(negativeCachePolicy) + + private def toLongTtl(cp: CachePolicy): Long = { + cp match { + case Forever ⇒ Long.MaxValue + case Never ⇒ 0 + case Ttl(ttl) ⇒ ttl.toMillis + } + } override def receive = { case Dns.Resolve(name) ⇒ @@ -72,12 +93,12 @@ class InetAddressDnsResolver(cache: SimpleDnsCache, config: Config) extends Acto case None ⇒ try { val answer = Dns.Resolved(name, InetAddress.getAllByName(name)) - if (positiveTtl != Never) cache.put(answer, positiveTtl) + if (positiveCachePolicy != Never) cache.put(answer, positiveCachePolicy) answer } catch { case e: UnknownHostException ⇒ val answer = Dns.Resolved(name, immutable.Seq.empty, immutable.Seq.empty) - if (negativeTtl != Never) cache.put(answer, negativeTtl) + if (negativeCachePolicy != Never) cache.put(answer, negativeCachePolicy) answer } } diff --git a/akka-actor/src/main/scala/akka/io/SimpleDnsCache.scala b/akka-actor/src/main/scala/akka/io/SimpleDnsCache.scala index 8cc6bdf340..bcdf0ba682 100644 --- a/akka-actor/src/main/scala/akka/io/SimpleDnsCache.scala +++ b/akka-actor/src/main/scala/akka/io/SimpleDnsCache.scala @@ -8,6 +8,7 @@ import java.util.concurrent.atomic.AtomicReference import akka.annotation.InternalApi import akka.io.Dns.Resolved +import akka.io.dns.CachePolicy._ import scala.annotation.tailrec import scala.collection.immutable @@ -36,10 +37,10 @@ class SimpleDnsCache extends Dns with PeriodicCacheCleanup { } @tailrec - private[io] final def put(r: Resolved, ttlMillis: Long): Unit = { + private[io] final def put(r: Resolved, ttl: CachePolicy): Unit = { val c = cache.get() - if (!cache.compareAndSet(c, c.put(r.name, r, ttlMillis))) - put(r, ttlMillis) + if (!cache.compareAndSet(c, c.put(r.name, r, ttl))) + put(r, ttl) } @tailrec @@ -64,9 +65,12 @@ object SimpleDnsCache { } yield e.answer } - def put(name: K, answer: V, ttlMillis: Long): Cache[K, V] = { - val until0 = clock() + ttlMillis - val until = if (until0 < 0) Long.MaxValue else until0 + def put(name: K, answer: V, ttl: CachePolicy): Cache[K, V] = { + val until = ttl match { + case Forever ⇒ Long.MaxValue + case Never ⇒ clock() - 1 + case Ttl(ttl) ⇒ clock() + ttl.toMillis + } new Cache[K, V]( queue + new ExpiryEntry[K](name, until), diff --git a/akka-actor/src/main/scala/akka/io/dns/CachePolicy.scala b/akka-actor/src/main/scala/akka/io/dns/CachePolicy.scala new file mode 100644 index 0000000000..27112410ee --- /dev/null +++ b/akka-actor/src/main/scala/akka/io/dns/CachePolicy.scala @@ -0,0 +1,52 @@ +/* + * Copyright (C) 2018 Lightbend Inc. + */ + +package akka.io.dns + +import akka.annotation.ApiMayChange +import akka.util.JavaDurationConverters._ + +import scala.concurrent.duration.{ Duration, FiniteDuration, _ } + +object CachePolicy { + + @ApiMayChange + sealed trait CachePolicy + @ApiMayChange + case object Never extends CachePolicy + @ApiMayChange + case object Forever extends CachePolicy + @ApiMayChange + final class Ttl private (val value: FiniteDuration) extends CachePolicy { + if (value <= Duration.Zero) throw new IllegalArgumentException(s"TTL values must be a positive value.") + import akka.util.JavaDurationConverters._ + def getValue: java.time.Duration = value.asJava + + override def equals(other: Any): Boolean = other match { + case that: Ttl ⇒ value == that.value + case _ ⇒ false + } + + override def hashCode(): Int = value.hashCode() + + override def toString = s"Ttl($value)" + } + @ApiMayChange + object Ttl { + def unapply(ttl: Ttl): Option[FiniteDuration] = Some(ttl.value) + def fromPositive(value: FiniteDuration): Ttl = { + new Ttl(value) + } + def fromPositive(value: java.time.Duration): Ttl = fromPositive(value.asScala) + + // There's places where only a Ttl makes sense (DNS RFC says TTL is a positive 32 bit integer) + // but we know the value can be cached effectively forever (e.g. the Lookup name was the actual IP already) + val effectivelyForever: Ttl = fromPositive(Int.MaxValue.seconds) + + implicit object TtlIsOrdered extends Ordering[Ttl] { + def compare(a: Ttl, b: Ttl) = a.value.compare(b.value) + } + } + +} diff --git a/akka-actor/src/main/scala/akka/io/dns/DnsResourceRecords.scala b/akka-actor/src/main/scala/akka/io/dns/DnsResourceRecords.scala index 388a46b560..e20d317424 100644 --- a/akka-actor/src/main/scala/akka/io/dns/DnsResourceRecords.scala +++ b/akka-actor/src/main/scala/akka/io/dns/DnsResourceRecords.scala @@ -8,19 +8,21 @@ import java.net.{ Inet4Address, Inet6Address, InetAddress } import akka.actor.NoSerializationVerificationNeeded import akka.annotation.{ ApiMayChange, InternalApi } +import CachePolicy._ import akka.io.dns.internal.{ DomainName, _ } -import akka.util.{ ByteIterator, ByteString, ByteStringBuilder } +import akka.util.{ ByteIterator, ByteString } import scala.annotation.switch +import scala.concurrent.duration._ @ApiMayChange -sealed abstract class ResourceRecord(val name: String, val ttlInSeconds: Int, val recType: Short, val recClass: Short) +sealed abstract class ResourceRecord(val name: String, val ttl: Ttl, val recType: Short, val recClass: Short) extends NoSerializationVerificationNeeded { } @ApiMayChange -final case class ARecord(override val name: String, override val ttlInSeconds: Int, - ip: InetAddress) extends ResourceRecord(name, ttlInSeconds, RecordType.A.code, RecordClass.IN.code) { +final case class ARecord(override val name: String, override val ttl: Ttl, + ip: InetAddress) extends ResourceRecord(name, ttl, RecordType.A.code, RecordClass.IN.code) { } /** @@ -28,16 +30,16 @@ final case class ARecord(override val name: String, override val ttlInSeconds: I */ @InternalApi private[dns] object ARecord { - def parseBody(name: String, ttlInSeconds: Int, length: Short, it: ByteIterator): ARecord = { + def parseBody(name: String, ttl: Ttl, length: Short, it: ByteIterator): ARecord = { val addr = Array.ofDim[Byte](4) it.getBytes(addr) - ARecord(name, ttlInSeconds, InetAddress.getByAddress(addr).asInstanceOf[Inet4Address]) + ARecord(name, ttl, InetAddress.getByAddress(addr).asInstanceOf[Inet4Address]) } } @ApiMayChange -final case class AAAARecord(override val name: String, override val ttlInSeconds: Int, - ip: Inet6Address) extends ResourceRecord(name, ttlInSeconds, RecordType.AAAA.code, RecordClass.IN.code) { +final case class AAAARecord(override val name: String, override val ttl: Ttl, + ip: Inet6Address) extends ResourceRecord(name, ttl, RecordType.AAAA.code, RecordClass.IN.code) { } /** @@ -50,16 +52,16 @@ private[dns] object AAAARecord { * INTERNAL API */ @InternalApi - def parseBody(name: String, ttlInSeconds: Int, length: Short, it: ByteIterator): AAAARecord = { + def parseBody(name: String, ttl: Ttl, length: Short, it: ByteIterator): AAAARecord = { val addr = Array.ofDim[Byte](16) it.getBytes(addr) - AAAARecord(name, ttlInSeconds, InetAddress.getByAddress(addr).asInstanceOf[Inet6Address]) + AAAARecord(name, ttl, InetAddress.getByAddress(addr).asInstanceOf[Inet6Address]) } } @ApiMayChange -final case class CNameRecord(override val name: String, override val ttlInSeconds: Int, - canonicalName: String) extends ResourceRecord(name, ttlInSeconds, RecordType.CNAME.code, RecordClass.IN.code) { +final case class CNameRecord(override val name: String, override val ttl: Ttl, + canonicalName: String) extends ResourceRecord(name, ttl, RecordType.CNAME.code, RecordClass.IN.code) { } @InternalApi @@ -68,14 +70,14 @@ private[dns] object CNameRecord { * INTERNAL API */ @InternalApi - def parseBody(name: String, ttlInSeconds: Int, length: Short, it: ByteIterator, msg: ByteString): CNameRecord = { - CNameRecord(name, ttlInSeconds, DomainName.parse(it, msg)) + def parseBody(name: String, ttl: Ttl, length: Short, it: ByteIterator, msg: ByteString): CNameRecord = { + CNameRecord(name, ttl, DomainName.parse(it, msg)) } } @ApiMayChange -final case class SRVRecord(override val name: String, override val ttlInSeconds: Int, - priority: Int, weight: Int, port: Int, target: String) extends ResourceRecord(name, ttlInSeconds, RecordType.SRV.code, RecordClass.IN.code) { +final case class SRVRecord(override val name: String, override val ttl: Ttl, + priority: Int, weight: Int, port: Int, target: String) extends ResourceRecord(name, ttl, RecordType.SRV.code, RecordClass.IN.code) { } /** @@ -87,18 +89,18 @@ private[dns] object SRVRecord { * INTERNAL API */ @InternalApi - def parseBody(name: String, ttlInSeconds: Int, length: Short, it: ByteIterator, msg: ByteString): SRVRecord = { + def parseBody(name: String, ttl: Ttl, length: Short, it: ByteIterator, msg: ByteString): SRVRecord = { val priority = it.getShort.toInt & 0xFFFF val weight = it.getShort.toInt & 0xFFFF val port = it.getShort.toInt & 0xFFFF - SRVRecord(name, ttlInSeconds, priority, weight, port, DomainName.parse(it, msg)) + SRVRecord(name, ttl, priority, weight, port, DomainName.parse(it, msg)) } } @ApiMayChange -final case class UnknownRecord(override val name: String, override val ttlInSeconds: Int, +final case class UnknownRecord(override val name: String, override val ttl: Ttl, override val recType: Short, override val recClass: Short, - data: ByteString) extends ResourceRecord(name, ttlInSeconds, recType, recClass) { + data: ByteString) extends ResourceRecord(name, ttl, recType, recClass) { } /** @@ -110,8 +112,8 @@ private[dns] object UnknownRecord { * INTERNAL API */ @InternalApi - def parseBody(name: String, ttlInSeconds: Int, recType: Short, recClass: Short, length: Short, it: ByteIterator): UnknownRecord = - UnknownRecord(name, ttlInSeconds, recType, recClass, it.toByteString) + def parseBody(name: String, ttl: Ttl, recType: Short, recClass: Short, length: Short, it: ByteIterator): UnknownRecord = + UnknownRecord(name, ttl, recType, recClass, it.toByteString) } /** @@ -127,7 +129,8 @@ private[dns] object ResourceRecord { val name = DomainName.parse(it, msg) val recType = it.getShort val recClass = it.getShort - val ttl = it.getInt + // According to https://www.ietf.org/rfc/rfc1035.txt: "TTL: positive values of a signed 32 bit number." + val ttl = Ttl.fromPositive(it.getInt.seconds) val rdLength = it.getShort val data = it.clone().take(rdLength) it.drop(rdLength) diff --git a/akka-actor/src/main/scala/akka/io/dns/internal/AsyncDnsCache.scala b/akka-actor/src/main/scala/akka/io/dns/internal/AsyncDnsCache.scala index 0f6ac8b9d3..607eac7665 100644 --- a/akka-actor/src/main/scala/akka/io/dns/internal/AsyncDnsCache.scala +++ b/akka-actor/src/main/scala/akka/io/dns/internal/AsyncDnsCache.scala @@ -8,6 +8,7 @@ import java.util.concurrent.atomic.AtomicReference import akka.annotation.InternalApi import akka.io.{ Dns, PeriodicCacheCleanup } +import akka.io.dns.CachePolicy.CachePolicy import scala.collection.immutable import akka.io.SimpleDnsCache._ @@ -16,12 +17,13 @@ import akka.io.dns.internal.DnsClient.Answer import akka.io.dns.{ AAAARecord, ARecord } import scala.annotation.tailrec +import scala.concurrent.duration._ /** * Internal API */ @InternalApi class AsyncDnsCache extends Dns with PeriodicCacheCleanup { - private val cache = new AtomicReference(new Cache[(String, QueryType), Answer]( + private val cacheRef = new AtomicReference(new Cache[(String, QueryType), Answer]( immutable.SortedSet()(expiryEntryOrdering()), immutable.Map(), clock)) @@ -33,8 +35,8 @@ import scala.annotation.tailrec */ override def cached(name: String): Option[Dns.Resolved] = { for { - ipv4 ← cache.get().get((name, Ipv4Type)) - ipv6 ← cache.get().get((name, Ipv6Type)) + ipv4 ← cacheRef.get().get((name, Ipv4Type)) + ipv6 ← cacheRef.get().get((name, Ipv6Type)) } yield { Dns.Resolved(name, (ipv4.rrs ++ ipv6.rrs).collect { case r: ARecord ⇒ r.ip @@ -51,20 +53,20 @@ import scala.annotation.tailrec } private[io] final def get(key: (String, QueryType)): Option[Answer] = { - cache.get().get(key) + cacheRef.get().get(key) } @tailrec - private[io] final def put(key: (String, QueryType), records: Answer, ttlInMillis: Long): Unit = { - val c = cache.get() - if (!cache.compareAndSet(c, c.put(key, records, ttlInMillis))) - put(key, records, ttlInMillis) + private[io] final def put(key: (String, QueryType), records: Answer, ttl: CachePolicy): Unit = { + val cache: Cache[(String, QueryType), Answer] = cacheRef.get() + if (!cacheRef.compareAndSet(cache, cache.put(key, records, ttl))) + put(key, records, ttl) } @tailrec override final def cleanup(): Unit = { - val c = cache.get() - if (!cache.compareAndSet(c, c.cleanup())) + val c = cacheRef.get() + if (!cacheRef.compareAndSet(c, c.cleanup())) cleanup() } } diff --git a/akka-actor/src/main/scala/akka/io/dns/internal/AsyncDnsResolver.scala b/akka-actor/src/main/scala/akka/io/dns/internal/AsyncDnsResolver.scala index f7adc94b76..ff9efcf02f 100644 --- a/akka-actor/src/main/scala/akka/io/dns/internal/AsyncDnsResolver.scala +++ b/akka-actor/src/main/scala/akka/io/dns/internal/AsyncDnsResolver.scala @@ -5,18 +5,17 @@ package akka.io.dns.internal import java.net.{ Inet4Address, Inet6Address, InetAddress, InetSocketAddress } -import java.nio.charset.StandardCharsets -import akka.actor.{ Actor, ActorLogging, ActorRef, ActorRefFactory, Props } +import akka.actor.{ Actor, ActorLogging, ActorRef, ActorRefFactory } import akka.annotation.InternalApi +import akka.io.dns.CachePolicy.Ttl import akka.io.dns.DnsProtocol.{ Ip, RequestType, Srv } import akka.io.dns.internal.DnsClient._ -import akka.io.dns.{ AAAARecord, ARecord, DnsProtocol, DnsSettings, ResourceRecord } +import akka.io.dns._ import akka.pattern.{ ask, pipe } import akka.util.{ Helpers, Timeout } -import scala.collection.immutable.Seq -import scala.collection.{ breakOut, immutable } +import scala.collection.immutable import scala.concurrent.Future import scala.util.Try import scala.util.control.NonFatal @@ -60,8 +59,8 @@ private[io] final class AsyncDnsResolver( Try { val address = InetAddress.getByName(name) // only checks validity, since known to be IP address val record = address match { - case _: Inet4Address ⇒ ARecord(name, Int.MaxValue, address) - case ipv6address: Inet6Address ⇒ AAAARecord(name, Int.MaxValue, ipv6address) + case _: Inet4Address ⇒ ARecord(name, Ttl.effectivelyForever, address) + case ipv6address: Inet6Address ⇒ AAAARecord(name, Ttl.effectivelyForever, ipv6address) } DnsProtocol.Resolved(name, record :: Nil) } @@ -116,13 +115,13 @@ private[io] final class AsyncDnsResolver( ipv4Recs.flatMap(ipv4Records ⇒ { // TODO, do we want config to specify a max for this? if (ipv4Records.rrs.nonEmpty) { - val minTtl4 = ipv4Records.rrs.minBy(_.ttlInSeconds).ttlInSeconds - cache.put((name, Ipv4Type), ipv4Records, minTtl4 * 1000) + val minTtl4 = ipv4Records.rrs.map(_.ttl).min + cache.put((name, Ipv4Type), ipv4Records, minTtl4) } ipv6Recs.map(ipv6Records ⇒ { if (ipv6Records.rrs.nonEmpty) { - val minTtl6 = ipv6Records.rrs.minBy(_.ttlInSeconds).ttlInSeconds - cache.put((name, Ipv6Type), ipv6Records, minTtl6 * 1000) + val minTtl6 = ipv6Records.rrs.map(_.ttl).min + cache.put((name, Ipv6Type), ipv6Records, minTtl6) } ipv4Records.rrs ++ ipv6Records.rrs }).map(recs ⇒ DnsProtocol.Resolved(name, recs)) @@ -136,8 +135,8 @@ private[io] final class AsyncDnsResolver( sendQuestion(resolver, SrvQuestion(nextId(), caseFoldedName)) .map(answer ⇒ { if (answer.rrs.nonEmpty) { - val minttlInSeconds = answer.rrs.minBy(_.ttlInSeconds).ttlInSeconds - cache.put((name, SrvType), answer, minttlInSeconds * 1000) // cache uses ttl in millis + val minTtl = answer.rrs.map(_.ttl).min + cache.put((name, SrvType), answer, minTtl) } DnsProtocol.Resolved(name, answer.rrs, answer.additionalRecs) })