Additional records support for async dns (#25492)

* Additional records support for async dns

- Currently just used for SRV A/AAAA records
- Cache SRV records
- Rename Resolved.results to Resolved.records

* Review feedback
This commit is contained in:
Christopher Batey 2018-08-21 04:00:29 +01:00 committed by Konrad `ktoso` Malawski
parent 482eaea122
commit b72d3090af
7 changed files with 112 additions and 65 deletions

View file

@ -98,9 +98,9 @@ class AsyncDnsResolverIntegrationSpec extends AkkaSpec(
val answer = resolve(name, DnsProtocol.Ip(ipv6 = false)) val answer = resolve(name, DnsProtocol.Ip(ipv6 = false))
withClue(answer) { withClue(answer) {
answer.name shouldEqual name answer.name shouldEqual name
answer.results.size shouldEqual 1 answer.records.size shouldEqual 1
answer.results.head.name shouldEqual name answer.records.head.name shouldEqual name
answer.results.head.asInstanceOf[ARecord].ip shouldEqual InetAddress.getByName("192.168.1.20") answer.records.head.asInstanceOf[ARecord].ip shouldEqual InetAddress.getByName("192.168.1.20")
} }
} }
@ -108,7 +108,7 @@ class AsyncDnsResolverIntegrationSpec extends AkkaSpec(
val name = "a-double.akka.test" val name = "a-double.akka.test"
val answer = resolve(name) val answer = resolve(name)
answer.name shouldEqual name answer.name shouldEqual name
answer.results.map(_.asInstanceOf[ARecord].ip).toSet shouldEqual Set( answer.records.map(_.asInstanceOf[ARecord].ip).toSet shouldEqual Set(
InetAddress.getByName("192.168.1.21"), InetAddress.getByName("192.168.1.21"),
InetAddress.getByName("192.168.1.22") InetAddress.getByName("192.168.1.22")
) )
@ -118,14 +118,14 @@ class AsyncDnsResolverIntegrationSpec extends AkkaSpec(
val name = "aaaa-single.akka.test" val name = "aaaa-single.akka.test"
val answer = resolve(name) val answer = resolve(name)
answer.name shouldEqual name answer.name shouldEqual name
answer.results.map(_.asInstanceOf[AAAARecord].ip) shouldEqual Seq(InetAddress.getByName("fd4d:36b2:3eca:a2d8:0:0:0:1")) answer.records.map(_.asInstanceOf[AAAARecord].ip) shouldEqual Seq(InetAddress.getByName("fd4d:36b2:3eca:a2d8:0:0:0:1"))
} }
"resolve double AAAA records" in { "resolve double AAAA records" in {
val name = "aaaa-double.akka.test" val name = "aaaa-double.akka.test"
val answer = resolve(name) val answer = resolve(name)
answer.name shouldEqual name answer.name shouldEqual name
answer.results.map(_.asInstanceOf[AAAARecord].ip).toSet shouldEqual Set( answer.records.map(_.asInstanceOf[AAAARecord].ip).toSet shouldEqual Set(
InetAddress.getByName("fd4d:36b2:3eca:a2d8:0:0:0:2"), InetAddress.getByName("fd4d:36b2:3eca:a2d8:0:0:0:2"),
InetAddress.getByName("fd4d:36b2:3eca:a2d8:0:0:0:3") InetAddress.getByName("fd4d:36b2:3eca:a2d8:0:0:0:3")
) )
@ -136,12 +136,12 @@ class AsyncDnsResolverIntegrationSpec extends AkkaSpec(
val answer = resolve(name) val answer = resolve(name)
answer.name shouldEqual name answer.name shouldEqual name
answer.results.collect { case r: ARecord r.ip }.toSet shouldEqual Set( answer.records.collect { case r: ARecord r.ip }.toSet shouldEqual Set(
InetAddress.getByName("192.168.1.23"), InetAddress.getByName("192.168.1.23"),
InetAddress.getByName("192.168.1.24") InetAddress.getByName("192.168.1.24")
) )
answer.results.collect { case r: AAAARecord r.ip }.toSet shouldEqual Set( answer.records.collect { case r: AAAARecord r.ip }.toSet shouldEqual Set(
InetAddress.getByName("fd4d:36b2:3eca:a2d8:0:0:0:4"), InetAddress.getByName("fd4d:36b2:3eca:a2d8:0:0:0:4"),
InetAddress.getByName("fd4d:36b2:3eca:a2d8:0:0:0:5") InetAddress.getByName("fd4d:36b2:3eca:a2d8:0:0:0:5")
) )
@ -151,10 +151,10 @@ class AsyncDnsResolverIntegrationSpec extends AkkaSpec(
val name = "cname-ext.akka.test" val name = "cname-ext.akka.test"
val answer = (IO(Dns) ? DnsProtocol.Resolve(name)).mapTo[DnsProtocol.Resolved].futureValue val answer = (IO(Dns) ? DnsProtocol.Resolve(name)).mapTo[DnsProtocol.Resolved].futureValue
answer.name shouldEqual name answer.name shouldEqual name
answer.results.collect { case r: CNameRecord r.canonicalName }.toSet shouldEqual Set( answer.records.collect { case r: CNameRecord r.canonicalName }.toSet shouldEqual Set(
"a-single.akka.test2" "a-single.akka.test2"
) )
answer.results.collect { case r: ARecord r.ip }.toSet shouldEqual Set( answer.records.collect { case r: ARecord r.ip }.toSet shouldEqual Set(
InetAddress.getByName("192.168.2.20") InetAddress.getByName("192.168.2.20")
) )
} }
@ -163,10 +163,10 @@ class AsyncDnsResolverIntegrationSpec extends AkkaSpec(
val name = "cname-in.akka.test" val name = "cname-in.akka.test"
val answer = resolve(name) val answer = resolve(name)
answer.name shouldEqual name answer.name shouldEqual name
answer.results.collect { case r: CNameRecord r.canonicalName }.toSet shouldEqual Set( answer.records.collect { case r: CNameRecord r.canonicalName }.toSet shouldEqual Set(
"a-double.akka.test" "a-double.akka.test"
) )
answer.results.collect { case r: ARecord r.ip }.toSet shouldEqual Set( answer.records.collect { case r: ARecord r.ip }.toSet shouldEqual Set(
InetAddress.getByName("192.168.1.21"), InetAddress.getByName("192.168.1.21"),
InetAddress.getByName("192.168.1.22") InetAddress.getByName("192.168.1.22")
) )
@ -177,23 +177,23 @@ class AsyncDnsResolverIntegrationSpec extends AkkaSpec(
val answer = resolve("service.tcp.akka.test", Srv) val answer = resolve("service.tcp.akka.test", Srv)
answer.name shouldEqual name answer.name shouldEqual name
answer.results.collect { case r: SRVRecord r }.toSet shouldEqual Set( answer.records.collect { case r: SRVRecord r }.toSet shouldEqual Set(
SRVRecord("service.tcp.akka.test", 86400, 10, 60, 5060, "a-single.akka.test"), SRVRecord("service.tcp.akka.test", 86400, 10, 60, 5060, "a-single.akka.test"),
SRVRecord("service.tcp.akka.test", 86400, 10, 40, 5070, "a-double.akka.test") SRVRecord("service.tcp.akka.test", 86400, 10, 40, 5070, "a-double.akka.test")
) )
} }
"resolve same address twice" in { "resolve same address twice" in {
resolve("a-single.akka.test").results.map(_.asInstanceOf[ARecord].ip) shouldEqual Seq(InetAddress.getByName("192.168.1.20")) resolve("a-single.akka.test").records.map(_.asInstanceOf[ARecord].ip) shouldEqual Seq(InetAddress.getByName("192.168.1.20"))
resolve("a-single.akka.test").results.map(_.asInstanceOf[ARecord].ip) shouldEqual Seq(InetAddress.getByName("192.168.1.20")) resolve("a-single.akka.test").records.map(_.asInstanceOf[ARecord].ip) shouldEqual Seq(InetAddress.getByName("192.168.1.20"))
} }
"handle nonexistent domains" in { "handle nonexistent domains" in {
val answer = (IO(Dns) ? DnsProtocol.Resolve("nonexistent.akka.test")).mapTo[DnsProtocol.Resolved].futureValue val answer = (IO(Dns) ? DnsProtocol.Resolve("nonexistent.akka.test")).mapTo[DnsProtocol.Resolved].futureValue
answer.results shouldEqual List.empty answer.records shouldEqual List.empty
} }
def resolve(name: String, requestType: RequestType = Ip()) = { def resolve(name: String, requestType: RequestType = Ip()): DnsProtocol.Resolved = {
(IO(Dns) ? DnsProtocol.Resolve(name, requestType)).mapTo[DnsProtocol.Resolved].futureValue (IO(Dns) ? DnsProtocol.Resolve(name, requestType)).mapTo[DnsProtocol.Resolved].futureValue
} }

View file

@ -8,14 +8,15 @@ import java.net.{ Inet6Address, InetAddress }
import akka.actor.Status.Failure import akka.actor.Status.Failure
import akka.actor.{ ActorRef, ExtendedActorSystem, Props } import akka.actor.{ ActorRef, ExtendedActorSystem, Props }
import akka.io.dns.{ AAAARecord, ARecord, DnsSettings } import akka.io.dns.{ AAAARecord, ARecord, DnsSettings, SRVRecord }
import akka.testkit.{ AkkaSpec, ImplicitSender, TestProbe } import akka.testkit.{ AkkaSpec, ImplicitSender, TestProbe }
import com.typesafe.config.ConfigFactory import com.typesafe.config.ConfigFactory
import akka.io.dns.DnsProtocol._ import akka.io.dns.DnsProtocol._
import akka.io.dns.internal.AsyncDnsResolver.ResolveFailedException import akka.io.dns.internal.AsyncDnsResolver.ResolveFailedException
import akka.io.dns.internal.DnsClient.{ Answer, Question4, Question6, SrvQuestion } import akka.io.dns.internal.DnsClient.{ Answer, Question4, Question6, SrvQuestion }
import scala.concurrent.duration._
import scala.collection.immutable import scala.collection.{ immutable im }
class AsyncDnsResolverSpec extends AkkaSpec( class AsyncDnsResolverSpec extends AkkaSpec(
""" """
@ -30,9 +31,9 @@ class AsyncDnsResolverSpec extends AkkaSpec(
val r = resolver(List(dnsClient1.ref, dnsClient2.ref)) val r = resolver(List(dnsClient1.ref, dnsClient2.ref))
r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false)) r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false))
dnsClient1.expectMsg(Question4(1, "cats.com")) dnsClient1.expectMsg(Question4(1, "cats.com"))
dnsClient1.reply(Answer(1, immutable.Seq())) dnsClient1.reply(Answer(1, im.Seq.empty))
dnsClient2.expectNoMessage() dnsClient2.expectNoMessage()
expectMsg(Resolved("cats.com", immutable.Seq())) expectMsg(Resolved("cats.com", im.Seq.empty))
} }
"move to next client if first fails" in { "move to next client if first fails" in {
@ -44,8 +45,8 @@ class AsyncDnsResolverSpec extends AkkaSpec(
dnsClient1.expectMsg(Question4(1, "cats.com")) dnsClient1.expectMsg(Question4(1, "cats.com"))
dnsClient1.reply(Failure(new RuntimeException("Nope"))) dnsClient1.reply(Failure(new RuntimeException("Nope")))
dnsClient2.expectMsg(Question4(2, "cats.com")) dnsClient2.expectMsg(Question4(2, "cats.com"))
dnsClient2.reply(Answer(2, immutable.Seq())) dnsClient2.reply(Answer(2, im.Seq.empty))
expectMsg(Resolved("cats.com", immutable.Seq())) expectMsg(Resolved("cats.com", im.Seq.empty))
} }
"move to next client if first times out" in { "move to next client if first times out" in {
@ -56,8 +57,8 @@ class AsyncDnsResolverSpec extends AkkaSpec(
// first will get ask timeout // first will get ask timeout
dnsClient1.expectMsg(Question4(1, "cats.com")) dnsClient1.expectMsg(Question4(1, "cats.com"))
dnsClient2.expectMsg(Question4(2, "cats.com")) dnsClient2.expectMsg(Question4(2, "cats.com"))
dnsClient2.reply(Answer(2, immutable.Seq())) dnsClient2.reply(Answer(2, im.Seq.empty))
expectMsg(Resolved("cats.com", immutable.Seq())) expectMsg(Resolved("cats.com", im.Seq.empty))
} }
"gets both A and AAAA records if requested" in { "gets both A and AAAA records if requested" in {
@ -66,11 +67,11 @@ class AsyncDnsResolverSpec extends AkkaSpec(
r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = true)) r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = true))
dnsClient1.expectMsg(Question4(1, "cats.com")) dnsClient1.expectMsg(Question4(1, "cats.com"))
val ipv4Record = ARecord("cats.com", 100, InetAddress.getByName("127.0.0.1")) val ipv4Record = ARecord("cats.com", 100, InetAddress.getByName("127.0.0.1"))
dnsClient1.reply(Answer(1, immutable.Seq(ipv4Record))) dnsClient1.reply(Answer(1, im.Seq(ipv4Record)))
dnsClient1.expectMsg(Question6(2, "cats.com")) dnsClient1.expectMsg(Question6(2, "cats.com"))
val ipv6Record = AAAARecord("cats.com", 100, InetAddress.getByName("::1").asInstanceOf[Inet6Address]) val ipv6Record = AAAARecord("cats.com", 100, InetAddress.getByName("::1").asInstanceOf[Inet6Address])
dnsClient1.reply(Answer(2, immutable.Seq(ipv6Record))) dnsClient1.reply(Answer(2, im.Seq(ipv6Record)))
expectMsg(Resolved("cats.com", immutable.Seq(ipv4Record, ipv6Record))) expectMsg(Resolved("cats.com", im.Seq(ipv4Record, ipv6Record)))
} }
"fails if all dns clients timeout" in { "fails if all dns clients timeout" in {
@ -89,35 +90,51 @@ class AsyncDnsResolverSpec extends AkkaSpec(
val r = resolver(List(dnsClient1.ref, dnsClient2.ref)) val r = resolver(List(dnsClient1.ref, dnsClient2.ref))
r ! Resolve("cats.com", Srv) r ! Resolve("cats.com", Srv)
dnsClient1.expectMsg(SrvQuestion(1, "cats.com")) dnsClient1.expectMsg(SrvQuestion(1, "cats.com"))
dnsClient1.reply(Answer(1, immutable.Seq())) dnsClient1.reply(Answer(1, im.Seq.empty))
dnsClient2.expectNoMessage() dnsClient2.expectNoMessage()
expectMsg(Resolved("cats.com", immutable.Seq())) expectMsg(Resolved("cats.com", im.Seq.empty))
} }
"not hang when resolving raw IP address" in { "response immediately IP address" in {
import scala.concurrent.duration._
val name = "127.0.0.1" val name = "127.0.0.1"
val dnsClient1 = TestProbe() val dnsClient1 = TestProbe()
val r = resolver(List(dnsClient1.ref)) val r = resolver(List(dnsClient1.ref))
r ! Resolve(name) r ! Resolve(name)
dnsClient1.expectNoMessage(50.millis) dnsClient1.expectNoMessage(50.millis)
val answer = expectMsgType[Resolved] val answer = expectMsgType[Resolved]
answer.results.collect { case r: ARecord r }.toSet shouldEqual Set( answer.records.collect { case r: ARecord r }.toSet shouldEqual Set(
ARecord("127.0.0.1", Int.MaxValue, InetAddress.getByName("127.0.0.1")) ARecord("127.0.0.1", Int.MaxValue, InetAddress.getByName("127.0.0.1"))
) )
} }
"not hang when resolving raw IPv6 address" in {
import scala.concurrent.duration._ "response immediately for IPv6 address" in {
val name = "1:2:3:0:0:0:0:0" val name = "1:2:3:0:0:0:0:0"
val dnsClient1 = TestProbe() val dnsClient1 = TestProbe()
val r = resolver(List(dnsClient1.ref)) val r = resolver(List(dnsClient1.ref))
r ! Resolve(name) r ! Resolve(name)
dnsClient1.expectNoMessage(50.millis) dnsClient1.expectNoMessage(50.millis)
val answer = expectMsgType[Resolved] val answer = expectMsgType[Resolved]
answer.results.collect { case r: ARecord r }.toSet shouldEqual Set( answer.records.collect { case r: ARecord r }.toSet shouldEqual Set(
ARecord("1:2:3:0:0:0:0:0", Int.MaxValue, InetAddress.getByName("1:2:3:0:0:0:0:0")) ARecord("1:2:3:0:0:0:0:0", Int.MaxValue, InetAddress.getByName("1:2:3:0:0:0:0:0"))
) )
} }
"return additional records for SRV requests" in {
val dnsClient1 = TestProbe()
val dnsClient2 = TestProbe()
val r = resolver(List(dnsClient1.ref, dnsClient2.ref))
r ! Resolve("cats.com", Srv)
dnsClient1.expectMsg(SrvQuestion(1, "cats.com"))
val srvRecs = im.Seq(SRVRecord("cats.com", 5000, 1, 1, 1, "a.cats.com"))
val aRecs = im.Seq(ARecord("a.cats.com", 1, InetAddress.getByName("127.0.0.1")))
dnsClient1.reply(Answer(1, srvRecs, aRecs))
dnsClient2.expectNoMessage(50.millis)
expectMsg(Resolved("cats.com", srvRecs, aRecs))
// cached the second time, don't have the probe reply
r ! Resolve("cats.com", Srv)
expectMsg(Resolved("cats.com", srvRecs, aRecs))
}
} }
def resolver(clients: List[ActorRef]): ActorRef = { def resolver(clients: List[ActorRef]): ActorRef = {

View file

@ -9,7 +9,7 @@ import java.util
import akka.actor.NoSerializationVerificationNeeded import akka.actor.NoSerializationVerificationNeeded
import akka.annotation.ApiMayChange import akka.annotation.ApiMayChange
import scala.collection.immutable import scala.collection.{ immutable im }
import scala.collection.JavaConverters._ import scala.collection.JavaConverters._
/** /**
@ -60,12 +60,34 @@ object DnsProtocol {
def create(name: String, requestType: RequestType): Resolve = Resolve(name, requestType) def create(name: String, requestType: RequestType): Resolve = Resolve(name, requestType)
} }
/**
* @param name of the record
* @param records resource records for the query
* @param additionalRecords records that relate to the query but are not strictly answers
*/
@ApiMayChange @ApiMayChange
final case class Resolved(name: String, results: immutable.Seq[ResourceRecord]) extends NoSerializationVerificationNeeded { final case class Resolved(name: String, records: im.Seq[ResourceRecord], additionalRecords: im.Seq[ResourceRecord]) extends NoSerializationVerificationNeeded {
/** /**
* Java API * Java API
*
* Records for the query
*/ */
def getResults(): util.List[ResourceRecord] = results.asJava def getRecords(): util.List[ResourceRecord] = records.asJava
/**
* Java API
*
* Records that relate to the query but are not strickly answers e.g. A records for the records returned for an SRV query.
*
*/
def getAdditionalRecords(): util.List[ResourceRecord] = additionalRecords.asJava
}
@ApiMayChange
object Resolved {
@ApiMayChange
def apply(name: String, records: im.Seq[ResourceRecord]): Resolved =
new Resolved(name, records, Nil)
} }
} }

View file

@ -12,7 +12,8 @@ import akka.io.{ Dns, PeriodicCacheCleanup }
import scala.collection.immutable import scala.collection.immutable
import akka.io.SimpleDnsCache._ import akka.io.SimpleDnsCache._
import akka.io.dns.internal.AsyncDnsResolver.{ Ipv4Type, Ipv6Type, QueryType } import akka.io.dns.internal.AsyncDnsResolver.{ Ipv4Type, Ipv6Type, QueryType }
import akka.io.dns.{ AAAARecord, ARecord, ResourceRecord } import akka.io.dns.internal.DnsClient.Answer
import akka.io.dns.{ AAAARecord, ARecord }
import scala.annotation.tailrec import scala.annotation.tailrec
@ -20,7 +21,7 @@ import scala.annotation.tailrec
* Internal API * Internal API
*/ */
@InternalApi class AsyncDnsCache extends Dns with PeriodicCacheCleanup { @InternalApi class AsyncDnsCache extends Dns with PeriodicCacheCleanup {
private val cache = new AtomicReference(new Cache[(String, QueryType), immutable.Seq[ResourceRecord]]( private val cache = new AtomicReference(new Cache[(String, QueryType), Answer](
immutable.SortedSet()(expiryEntryOrdering()), immutable.SortedSet()(expiryEntryOrdering()),
immutable.Map(), clock)) immutable.Map(), clock))
@ -35,7 +36,7 @@ import scala.annotation.tailrec
ipv4 cache.get().get((name, Ipv4Type)) ipv4 cache.get().get((name, Ipv4Type))
ipv6 cache.get().get((name, Ipv6Type)) ipv6 cache.get().get((name, Ipv6Type))
} yield { } yield {
Dns.Resolved(name, (ipv4 ++ ipv6).collect { Dns.Resolved(name, (ipv4.rrs ++ ipv6.rrs).collect {
case r: ARecord r.ip case r: ARecord r.ip
case r: AAAARecord r.ip case r: AAAARecord r.ip
}) })
@ -49,12 +50,12 @@ import scala.annotation.tailrec
else (now - nanoBase) / 1000000 else (now - nanoBase) / 1000000
} }
private[io] final def get(key: (String, QueryType)): Option[immutable.Seq[ResourceRecord]] = { private[io] final def get(key: (String, QueryType)): Option[Answer] = {
cache.get().get(key) cache.get().get(key)
} }
@tailrec @tailrec
private[io] final def put(key: (String, QueryType), records: immutable.Seq[ResourceRecord], ttlMillis: Long): Unit = { private[io] final def put(key: (String, QueryType), records: Answer, ttlMillis: Long): Unit = {
val c = cache.get() val c = cache.get()
if (!cache.compareAndSet(c, c.put(key, records, ttlMillis))) if (!cache.compareAndSet(c, c.put(key, records, ttlMillis)))
put(key, records, ttlMillis) put(key, records, ttlMillis)

View file

@ -74,7 +74,7 @@ private[io] final class AsyncDnsManager(val ext: DnsExt) extends Actor
val adapted = DnsProtocol.Resolve(name) val adapted = DnsProtocol.Resolve(name)
val reply = (resolver ? adapted).mapTo[DnsProtocol.Resolved] val reply = (resolver ? adapted).mapTo[DnsProtocol.Resolved]
.map { asyncResolved .map { asyncResolved
val ips = asyncResolved.results.collect { case a: ARecord a.ip } val ips = asyncResolved.records.collect { case a: ARecord a.ip }
Dns.Resolved(asyncResolved.name, ips) Dns.Resolved(asyncResolved.name, ips)
} }
reply pipeTo sender reply pipeTo sender

View file

@ -75,8 +75,8 @@ private[io] final class AsyncDnsResolver(
} }
} }
private def sendQuestion(resolver: ActorRef, message: DnsQuestion): Future[Seq[ResourceRecord]] = { private def sendQuestion(resolver: ActorRef, message: DnsQuestion): Future[Answer] = {
val result = (resolver ? message).mapTo[Answer].map(_.rrs) val result = (resolver ? message).mapTo[Answer]
result.onFailure { result.onFailure {
case NonFatal(_) resolver ! DropRequest(message.id) case NonFatal(_) resolver ! DropRequest(message.id)
} }
@ -86,9 +86,9 @@ private[io] final class AsyncDnsResolver(
private def resolve(name: String, requestType: RequestType, resolver: ActorRef): Future[DnsProtocol.Resolved] = { private def resolve(name: String, requestType: RequestType, resolver: ActorRef): Future[DnsProtocol.Resolved] = {
log.debug("Attempting to resolve {} with {}", name, resolver) log.debug("Attempting to resolve {} with {}", name, resolver)
val caseFoldedName = Helpers.toRootLowerCase(name) val caseFoldedName = Helpers.toRootLowerCase(name)
val recs: Future[Seq[ResourceRecord]] = requestType match { requestType match {
case Ip(ipv4, ipv6) case Ip(ipv4, ipv6)
val ipv4Recs = if (ipv4) val ipv4Recs: Future[Answer] = if (ipv4)
cache.get((name, Ipv4Type)) match { cache.get((name, Ipv4Type)) match {
case Some(r) case Some(r)
log.debug("Ipv4 cached {}", r) log.debug("Ipv4 cached {}", r)
@ -112,28 +112,35 @@ private[io] final class AsyncDnsResolver(
ipv4Recs.flatMap(ipv4Records { ipv4Recs.flatMap(ipv4Records {
// TODO, do we want config to specify a max for this? // TODO, do we want config to specify a max for this?
if (ipv4Records.nonEmpty) { if (ipv4Records.rrs.nonEmpty) {
val minTtl4 = ipv4Records.minBy(_.ttl).ttl val minTtl4 = ipv4Records.rrs.minBy(_.ttl).ttl
cache.put((name, Ipv4Type), ipv4Records, minTtl4) cache.put((name, Ipv4Type), ipv4Records, minTtl4)
} }
ipv6Recs.map(ipv6Records { ipv6Recs.map(ipv6Records {
if (ipv6Records.nonEmpty) { if (ipv6Records.rrs.nonEmpty) {
val minTtl6 = ipv6Records.minBy(_.ttl).ttl val minTtl6 = ipv6Records.rrs.minBy(_.ttl).ttl
cache.put((name, Ipv6Type), ipv6Records, minTtl6) cache.put((name, Ipv6Type), ipv6Records, minTtl6)
} }
ipv4Records ++ ipv6Records ipv4Records.rrs ++ ipv6Records.rrs
}) }).map(recs DnsProtocol.Resolved(name, recs))
}) })
case Srv case Srv
cache.get((name, Ipv4Type)) match { cache.get((name, SrvType)) match {
case Some(r) Future.successful(r) case Some(r)
Future.successful(DnsProtocol.Resolved(name, r.rrs, r.additionalRecs))
case None case None
sendQuestion(resolver, SrvQuestion(nextId(), caseFoldedName)) sendQuestion(resolver, SrvQuestion(nextId(), caseFoldedName))
.map(r {
if (r.rrs.nonEmpty) {
val minTtl = r.rrs.minBy(_.ttl).ttl
cache.put((name, SrvType), r, minTtl)
}
DnsProtocol.Resolved(name, r.rrs, r.additionalRecs)
})
} }
} }
recs.map(result DnsProtocol.Resolved(name, result))
} }
} }
@ -154,7 +161,7 @@ private[io] object AsyncDnsResolver {
ipv4Address.findAllMatchIn(name).nonEmpty || ipv4Address.findAllMatchIn(name).nonEmpty ||
ipv6Address.findAllMatchIn(name).nonEmpty ipv6Address.findAllMatchIn(name).nonEmpty
private val Empty = Future.successful(immutable.Seq.empty[ResourceRecord]) private val Empty = Future.successful(Answer(-1, immutable.Seq.empty[ResourceRecord], immutable.Seq.empty[ResourceRecord]))
sealed trait QueryType sealed trait QueryType
final case object Ipv4Type extends QueryType final case object Ipv4Type extends QueryType

View file

@ -12,7 +12,7 @@ import akka.annotation.InternalApi
import akka.io.dns.{ RecordClass, RecordType, ResourceRecord } import akka.io.dns.{ RecordClass, RecordType, ResourceRecord }
import akka.io.{ IO, Udp } import akka.io.{ IO, Udp }
import scala.collection.immutable import scala.collection.{ immutable im }
import scala.util.Try import scala.util.Try
/** /**
@ -25,7 +25,7 @@ import scala.util.Try
final case class SrvQuestion(id: Short, name: String) extends DnsQuestion final case class SrvQuestion(id: Short, name: String) extends DnsQuestion
final case class Question4(id: Short, name: String) extends DnsQuestion final case class Question4(id: Short, name: String) extends DnsQuestion
final case class Question6(id: Short, name: String) extends DnsQuestion final case class Question6(id: Short, name: String) extends DnsQuestion
final case class Answer(id: Short, rrs: immutable.Seq[ResourceRecord]) extends NoSerializationVerificationNeeded final case class Answer(id: Short, rrs: im.Seq[ResourceRecord], additionalRecs: im.Seq[ResourceRecord] = Nil) extends NoSerializationVerificationNeeded
final case class DropRequest(id: Short) final case class DropRequest(id: Short)
} }
@ -56,7 +56,7 @@ import scala.util.Try
} }
private def message(name: String, id: Short, recordType: RecordType): Message = { private def message(name: String, id: Short, recordType: RecordType): Message = {
Message(id, MessageFlags(), immutable.Seq(Question(name, recordType, RecordClass.IN))) Message(id, MessageFlags(), im.Seq(Question(name, recordType, RecordClass.IN)))
} }
def ready(socket: ActorRef): Receive = { def ready(socket: ActorRef): Receive = {
@ -103,8 +103,8 @@ import scala.util.Try
log.debug(s"Received message from [{}]: [{}]", remote, data) log.debug(s"Received message from [{}]: [{}]", remote, data)
val msg = Message.parse(data) val msg = Message.parse(data)
log.debug(s"Decoded: $msg") log.debug(s"Decoded: $msg")
val recs = if (msg.flags.responseCode == ResponseCode.SUCCESS) msg.answerRecs else immutable.Seq.empty val (recs, additionalRecs) = if (msg.flags.responseCode == ResponseCode.SUCCESS) (msg.answerRecs, msg.additionalRecs) else (Nil, Nil)
val response = Answer(msg.id, recs) val response = Answer(msg.id, recs, additionalRecs)
inflightRequests.get(response.id) match { inflightRequests.get(response.id) match {
case Some(reply) case Some(reply)
reply ! response reply ! response