Additional records support for async dns (#25492)

* Additional records support for async dns

- Currently just used for SRV A/AAAA records
- Cache SRV records
- Rename Resolved.results to Resolved.records

* Review feedback
This commit is contained in:
Christopher Batey 2018-08-21 04:00:29 +01:00 committed by Konrad `ktoso` Malawski
parent 482eaea122
commit b72d3090af
7 changed files with 112 additions and 65 deletions

View file

@ -98,9 +98,9 @@ class AsyncDnsResolverIntegrationSpec extends AkkaSpec(
val answer = resolve(name, DnsProtocol.Ip(ipv6 = false))
withClue(answer) {
answer.name shouldEqual name
answer.results.size shouldEqual 1
answer.results.head.name shouldEqual name
answer.results.head.asInstanceOf[ARecord].ip shouldEqual InetAddress.getByName("192.168.1.20")
answer.records.size shouldEqual 1
answer.records.head.name shouldEqual name
answer.records.head.asInstanceOf[ARecord].ip shouldEqual InetAddress.getByName("192.168.1.20")
}
}
@ -108,7 +108,7 @@ class AsyncDnsResolverIntegrationSpec extends AkkaSpec(
val name = "a-double.akka.test"
val answer = resolve(name)
answer.name shouldEqual name
answer.results.map(_.asInstanceOf[ARecord].ip).toSet shouldEqual Set(
answer.records.map(_.asInstanceOf[ARecord].ip).toSet shouldEqual Set(
InetAddress.getByName("192.168.1.21"),
InetAddress.getByName("192.168.1.22")
)
@ -118,14 +118,14 @@ class AsyncDnsResolverIntegrationSpec extends AkkaSpec(
val name = "aaaa-single.akka.test"
val answer = resolve(name)
answer.name shouldEqual name
answer.results.map(_.asInstanceOf[AAAARecord].ip) shouldEqual Seq(InetAddress.getByName("fd4d:36b2:3eca:a2d8:0:0:0:1"))
answer.records.map(_.asInstanceOf[AAAARecord].ip) shouldEqual Seq(InetAddress.getByName("fd4d:36b2:3eca:a2d8:0:0:0:1"))
}
"resolve double AAAA records" in {
val name = "aaaa-double.akka.test"
val answer = resolve(name)
answer.name shouldEqual name
answer.results.map(_.asInstanceOf[AAAARecord].ip).toSet shouldEqual Set(
answer.records.map(_.asInstanceOf[AAAARecord].ip).toSet shouldEqual Set(
InetAddress.getByName("fd4d:36b2:3eca:a2d8:0:0:0:2"),
InetAddress.getByName("fd4d:36b2:3eca:a2d8:0:0:0:3")
)
@ -136,12 +136,12 @@ class AsyncDnsResolverIntegrationSpec extends AkkaSpec(
val answer = resolve(name)
answer.name shouldEqual name
answer.results.collect { case r: ARecord r.ip }.toSet shouldEqual Set(
answer.records.collect { case r: ARecord r.ip }.toSet shouldEqual Set(
InetAddress.getByName("192.168.1.23"),
InetAddress.getByName("192.168.1.24")
)
answer.results.collect { case r: AAAARecord r.ip }.toSet shouldEqual Set(
answer.records.collect { case r: AAAARecord r.ip }.toSet shouldEqual Set(
InetAddress.getByName("fd4d:36b2:3eca:a2d8:0:0:0:4"),
InetAddress.getByName("fd4d:36b2:3eca:a2d8:0:0:0:5")
)
@ -151,10 +151,10 @@ class AsyncDnsResolverIntegrationSpec extends AkkaSpec(
val name = "cname-ext.akka.test"
val answer = (IO(Dns) ? DnsProtocol.Resolve(name)).mapTo[DnsProtocol.Resolved].futureValue
answer.name shouldEqual name
answer.results.collect { case r: CNameRecord r.canonicalName }.toSet shouldEqual Set(
answer.records.collect { case r: CNameRecord r.canonicalName }.toSet shouldEqual Set(
"a-single.akka.test2"
)
answer.results.collect { case r: ARecord r.ip }.toSet shouldEqual Set(
answer.records.collect { case r: ARecord r.ip }.toSet shouldEqual Set(
InetAddress.getByName("192.168.2.20")
)
}
@ -163,10 +163,10 @@ class AsyncDnsResolverIntegrationSpec extends AkkaSpec(
val name = "cname-in.akka.test"
val answer = resolve(name)
answer.name shouldEqual name
answer.results.collect { case r: CNameRecord r.canonicalName }.toSet shouldEqual Set(
answer.records.collect { case r: CNameRecord r.canonicalName }.toSet shouldEqual Set(
"a-double.akka.test"
)
answer.results.collect { case r: ARecord r.ip }.toSet shouldEqual Set(
answer.records.collect { case r: ARecord r.ip }.toSet shouldEqual Set(
InetAddress.getByName("192.168.1.21"),
InetAddress.getByName("192.168.1.22")
)
@ -177,23 +177,23 @@ class AsyncDnsResolverIntegrationSpec extends AkkaSpec(
val answer = resolve("service.tcp.akka.test", Srv)
answer.name shouldEqual name
answer.results.collect { case r: SRVRecord r }.toSet shouldEqual Set(
answer.records.collect { case r: SRVRecord r }.toSet shouldEqual Set(
SRVRecord("service.tcp.akka.test", 86400, 10, 60, 5060, "a-single.akka.test"),
SRVRecord("service.tcp.akka.test", 86400, 10, 40, 5070, "a-double.akka.test")
)
}
"resolve same address twice" in {
resolve("a-single.akka.test").results.map(_.asInstanceOf[ARecord].ip) shouldEqual Seq(InetAddress.getByName("192.168.1.20"))
resolve("a-single.akka.test").results.map(_.asInstanceOf[ARecord].ip) shouldEqual Seq(InetAddress.getByName("192.168.1.20"))
resolve("a-single.akka.test").records.map(_.asInstanceOf[ARecord].ip) shouldEqual Seq(InetAddress.getByName("192.168.1.20"))
resolve("a-single.akka.test").records.map(_.asInstanceOf[ARecord].ip) shouldEqual Seq(InetAddress.getByName("192.168.1.20"))
}
"handle nonexistent domains" in {
val answer = (IO(Dns) ? DnsProtocol.Resolve("nonexistent.akka.test")).mapTo[DnsProtocol.Resolved].futureValue
answer.results shouldEqual List.empty
answer.records shouldEqual List.empty
}
def resolve(name: String, requestType: RequestType = Ip()) = {
def resolve(name: String, requestType: RequestType = Ip()): DnsProtocol.Resolved = {
(IO(Dns) ? DnsProtocol.Resolve(name, requestType)).mapTo[DnsProtocol.Resolved].futureValue
}

View file

@ -8,14 +8,15 @@ import java.net.{ Inet6Address, InetAddress }
import akka.actor.Status.Failure
import akka.actor.{ ActorRef, ExtendedActorSystem, Props }
import akka.io.dns.{ AAAARecord, ARecord, DnsSettings }
import akka.io.dns.{ AAAARecord, ARecord, DnsSettings, SRVRecord }
import akka.testkit.{ AkkaSpec, ImplicitSender, TestProbe }
import com.typesafe.config.ConfigFactory
import akka.io.dns.DnsProtocol._
import akka.io.dns.internal.AsyncDnsResolver.ResolveFailedException
import akka.io.dns.internal.DnsClient.{ Answer, Question4, Question6, SrvQuestion }
import scala.concurrent.duration._
import scala.collection.immutable
import scala.collection.{ immutable im }
class AsyncDnsResolverSpec extends AkkaSpec(
"""
@ -30,9 +31,9 @@ class AsyncDnsResolverSpec extends AkkaSpec(
val r = resolver(List(dnsClient1.ref, dnsClient2.ref))
r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false))
dnsClient1.expectMsg(Question4(1, "cats.com"))
dnsClient1.reply(Answer(1, immutable.Seq()))
dnsClient1.reply(Answer(1, im.Seq.empty))
dnsClient2.expectNoMessage()
expectMsg(Resolved("cats.com", immutable.Seq()))
expectMsg(Resolved("cats.com", im.Seq.empty))
}
"move to next client if first fails" in {
@ -44,8 +45,8 @@ class AsyncDnsResolverSpec extends AkkaSpec(
dnsClient1.expectMsg(Question4(1, "cats.com"))
dnsClient1.reply(Failure(new RuntimeException("Nope")))
dnsClient2.expectMsg(Question4(2, "cats.com"))
dnsClient2.reply(Answer(2, immutable.Seq()))
expectMsg(Resolved("cats.com", immutable.Seq()))
dnsClient2.reply(Answer(2, im.Seq.empty))
expectMsg(Resolved("cats.com", im.Seq.empty))
}
"move to next client if first times out" in {
@ -56,8 +57,8 @@ class AsyncDnsResolverSpec extends AkkaSpec(
// first will get ask timeout
dnsClient1.expectMsg(Question4(1, "cats.com"))
dnsClient2.expectMsg(Question4(2, "cats.com"))
dnsClient2.reply(Answer(2, immutable.Seq()))
expectMsg(Resolved("cats.com", immutable.Seq()))
dnsClient2.reply(Answer(2, im.Seq.empty))
expectMsg(Resolved("cats.com", im.Seq.empty))
}
"gets both A and AAAA records if requested" in {
@ -66,11 +67,11 @@ class AsyncDnsResolverSpec extends AkkaSpec(
r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = true))
dnsClient1.expectMsg(Question4(1, "cats.com"))
val ipv4Record = ARecord("cats.com", 100, InetAddress.getByName("127.0.0.1"))
dnsClient1.reply(Answer(1, immutable.Seq(ipv4Record)))
dnsClient1.reply(Answer(1, im.Seq(ipv4Record)))
dnsClient1.expectMsg(Question6(2, "cats.com"))
val ipv6Record = AAAARecord("cats.com", 100, InetAddress.getByName("::1").asInstanceOf[Inet6Address])
dnsClient1.reply(Answer(2, immutable.Seq(ipv6Record)))
expectMsg(Resolved("cats.com", immutable.Seq(ipv4Record, ipv6Record)))
dnsClient1.reply(Answer(2, im.Seq(ipv6Record)))
expectMsg(Resolved("cats.com", im.Seq(ipv4Record, ipv6Record)))
}
"fails if all dns clients timeout" in {
@ -89,35 +90,51 @@ class AsyncDnsResolverSpec extends AkkaSpec(
val r = resolver(List(dnsClient1.ref, dnsClient2.ref))
r ! Resolve("cats.com", Srv)
dnsClient1.expectMsg(SrvQuestion(1, "cats.com"))
dnsClient1.reply(Answer(1, immutable.Seq()))
dnsClient1.reply(Answer(1, im.Seq.empty))
dnsClient2.expectNoMessage()
expectMsg(Resolved("cats.com", immutable.Seq()))
expectMsg(Resolved("cats.com", im.Seq.empty))
}
"not hang when resolving raw IP address" in {
import scala.concurrent.duration._
"response immediately IP address" in {
val name = "127.0.0.1"
val dnsClient1 = TestProbe()
val r = resolver(List(dnsClient1.ref))
r ! Resolve(name)
dnsClient1.expectNoMessage(50.millis)
val answer = expectMsgType[Resolved]
answer.results.collect { case r: ARecord r }.toSet shouldEqual Set(
answer.records.collect { case r: ARecord r }.toSet shouldEqual Set(
ARecord("127.0.0.1", Int.MaxValue, InetAddress.getByName("127.0.0.1"))
)
}
"not hang when resolving raw IPv6 address" in {
import scala.concurrent.duration._
"response immediately for IPv6 address" in {
val name = "1:2:3:0:0:0:0:0"
val dnsClient1 = TestProbe()
val r = resolver(List(dnsClient1.ref))
r ! Resolve(name)
dnsClient1.expectNoMessage(50.millis)
val answer = expectMsgType[Resolved]
answer.results.collect { case r: ARecord r }.toSet shouldEqual Set(
answer.records.collect { case r: ARecord r }.toSet shouldEqual Set(
ARecord("1:2:3:0:0:0:0:0", Int.MaxValue, InetAddress.getByName("1:2:3:0:0:0:0:0"))
)
}
"return additional records for SRV requests" in {
val dnsClient1 = TestProbe()
val dnsClient2 = TestProbe()
val r = resolver(List(dnsClient1.ref, dnsClient2.ref))
r ! Resolve("cats.com", Srv)
dnsClient1.expectMsg(SrvQuestion(1, "cats.com"))
val srvRecs = im.Seq(SRVRecord("cats.com", 5000, 1, 1, 1, "a.cats.com"))
val aRecs = im.Seq(ARecord("a.cats.com", 1, InetAddress.getByName("127.0.0.1")))
dnsClient1.reply(Answer(1, srvRecs, aRecs))
dnsClient2.expectNoMessage(50.millis)
expectMsg(Resolved("cats.com", srvRecs, aRecs))
// cached the second time, don't have the probe reply
r ! Resolve("cats.com", Srv)
expectMsg(Resolved("cats.com", srvRecs, aRecs))
}
}
def resolver(clients: List[ActorRef]): ActorRef = {

View file

@ -9,7 +9,7 @@ import java.util
import akka.actor.NoSerializationVerificationNeeded
import akka.annotation.ApiMayChange
import scala.collection.immutable
import scala.collection.{ immutable im }
import scala.collection.JavaConverters._
/**
@ -60,12 +60,34 @@ object DnsProtocol {
def create(name: String, requestType: RequestType): Resolve = Resolve(name, requestType)
}
/**
* @param name of the record
* @param records resource records for the query
* @param additionalRecords records that relate to the query but are not strictly answers
*/
@ApiMayChange
final case class Resolved(name: String, results: immutable.Seq[ResourceRecord]) extends NoSerializationVerificationNeeded {
final case class Resolved(name: String, records: im.Seq[ResourceRecord], additionalRecords: im.Seq[ResourceRecord]) extends NoSerializationVerificationNeeded {
/**
* Java API
*
* Records for the query
*/
def getResults(): util.List[ResourceRecord] = results.asJava
def getRecords(): util.List[ResourceRecord] = records.asJava
/**
* Java API
*
* Records that relate to the query but are not strickly answers e.g. A records for the records returned for an SRV query.
*
*/
def getAdditionalRecords(): util.List[ResourceRecord] = additionalRecords.asJava
}
@ApiMayChange
object Resolved {
@ApiMayChange
def apply(name: String, records: im.Seq[ResourceRecord]): Resolved =
new Resolved(name, records, Nil)
}
}

View file

@ -12,7 +12,8 @@ import akka.io.{ Dns, PeriodicCacheCleanup }
import scala.collection.immutable
import akka.io.SimpleDnsCache._
import akka.io.dns.internal.AsyncDnsResolver.{ Ipv4Type, Ipv6Type, QueryType }
import akka.io.dns.{ AAAARecord, ARecord, ResourceRecord }
import akka.io.dns.internal.DnsClient.Answer
import akka.io.dns.{ AAAARecord, ARecord }
import scala.annotation.tailrec
@ -20,7 +21,7 @@ import scala.annotation.tailrec
* Internal API
*/
@InternalApi class AsyncDnsCache extends Dns with PeriodicCacheCleanup {
private val cache = new AtomicReference(new Cache[(String, QueryType), immutable.Seq[ResourceRecord]](
private val cache = new AtomicReference(new Cache[(String, QueryType), Answer](
immutable.SortedSet()(expiryEntryOrdering()),
immutable.Map(), clock))
@ -35,7 +36,7 @@ import scala.annotation.tailrec
ipv4 cache.get().get((name, Ipv4Type))
ipv6 cache.get().get((name, Ipv6Type))
} yield {
Dns.Resolved(name, (ipv4 ++ ipv6).collect {
Dns.Resolved(name, (ipv4.rrs ++ ipv6.rrs).collect {
case r: ARecord r.ip
case r: AAAARecord r.ip
})
@ -49,12 +50,12 @@ import scala.annotation.tailrec
else (now - nanoBase) / 1000000
}
private[io] final def get(key: (String, QueryType)): Option[immutable.Seq[ResourceRecord]] = {
private[io] final def get(key: (String, QueryType)): Option[Answer] = {
cache.get().get(key)
}
@tailrec
private[io] final def put(key: (String, QueryType), records: immutable.Seq[ResourceRecord], ttlMillis: Long): Unit = {
private[io] final def put(key: (String, QueryType), records: Answer, ttlMillis: Long): Unit = {
val c = cache.get()
if (!cache.compareAndSet(c, c.put(key, records, ttlMillis)))
put(key, records, ttlMillis)

View file

@ -74,7 +74,7 @@ private[io] final class AsyncDnsManager(val ext: DnsExt) extends Actor
val adapted = DnsProtocol.Resolve(name)
val reply = (resolver ? adapted).mapTo[DnsProtocol.Resolved]
.map { asyncResolved
val ips = asyncResolved.results.collect { case a: ARecord a.ip }
val ips = asyncResolved.records.collect { case a: ARecord a.ip }
Dns.Resolved(asyncResolved.name, ips)
}
reply pipeTo sender

View file

@ -75,8 +75,8 @@ private[io] final class AsyncDnsResolver(
}
}
private def sendQuestion(resolver: ActorRef, message: DnsQuestion): Future[Seq[ResourceRecord]] = {
val result = (resolver ? message).mapTo[Answer].map(_.rrs)
private def sendQuestion(resolver: ActorRef, message: DnsQuestion): Future[Answer] = {
val result = (resolver ? message).mapTo[Answer]
result.onFailure {
case NonFatal(_) resolver ! DropRequest(message.id)
}
@ -86,9 +86,9 @@ private[io] final class AsyncDnsResolver(
private def resolve(name: String, requestType: RequestType, resolver: ActorRef): Future[DnsProtocol.Resolved] = {
log.debug("Attempting to resolve {} with {}", name, resolver)
val caseFoldedName = Helpers.toRootLowerCase(name)
val recs: Future[Seq[ResourceRecord]] = requestType match {
requestType match {
case Ip(ipv4, ipv6)
val ipv4Recs = if (ipv4)
val ipv4Recs: Future[Answer] = if (ipv4)
cache.get((name, Ipv4Type)) match {
case Some(r)
log.debug("Ipv4 cached {}", r)
@ -112,28 +112,35 @@ private[io] final class AsyncDnsResolver(
ipv4Recs.flatMap(ipv4Records {
// TODO, do we want config to specify a max for this?
if (ipv4Records.nonEmpty) {
val minTtl4 = ipv4Records.minBy(_.ttl).ttl
if (ipv4Records.rrs.nonEmpty) {
val minTtl4 = ipv4Records.rrs.minBy(_.ttl).ttl
cache.put((name, Ipv4Type), ipv4Records, minTtl4)
}
ipv6Recs.map(ipv6Records {
if (ipv6Records.nonEmpty) {
val minTtl6 = ipv6Records.minBy(_.ttl).ttl
if (ipv6Records.rrs.nonEmpty) {
val minTtl6 = ipv6Records.rrs.minBy(_.ttl).ttl
cache.put((name, Ipv6Type), ipv6Records, minTtl6)
}
ipv4Records ++ ipv6Records
})
ipv4Records.rrs ++ ipv6Records.rrs
}).map(recs DnsProtocol.Resolved(name, recs))
})
case Srv
cache.get((name, Ipv4Type)) match {
case Some(r) Future.successful(r)
cache.get((name, SrvType)) match {
case Some(r)
Future.successful(DnsProtocol.Resolved(name, r.rrs, r.additionalRecs))
case None
sendQuestion(resolver, SrvQuestion(nextId(), caseFoldedName))
.map(r {
if (r.rrs.nonEmpty) {
val minTtl = r.rrs.minBy(_.ttl).ttl
cache.put((name, SrvType), r, minTtl)
}
DnsProtocol.Resolved(name, r.rrs, r.additionalRecs)
})
}
}
recs.map(result DnsProtocol.Resolved(name, result))
}
}
@ -154,7 +161,7 @@ private[io] object AsyncDnsResolver {
ipv4Address.findAllMatchIn(name).nonEmpty ||
ipv6Address.findAllMatchIn(name).nonEmpty
private val Empty = Future.successful(immutable.Seq.empty[ResourceRecord])
private val Empty = Future.successful(Answer(-1, immutable.Seq.empty[ResourceRecord], immutable.Seq.empty[ResourceRecord]))
sealed trait QueryType
final case object Ipv4Type extends QueryType

View file

@ -12,7 +12,7 @@ import akka.annotation.InternalApi
import akka.io.dns.{ RecordClass, RecordType, ResourceRecord }
import akka.io.{ IO, Udp }
import scala.collection.immutable
import scala.collection.{ immutable im }
import scala.util.Try
/**
@ -25,7 +25,7 @@ import scala.util.Try
final case class SrvQuestion(id: Short, name: String) extends DnsQuestion
final case class Question4(id: Short, name: String) extends DnsQuestion
final case class Question6(id: Short, name: String) extends DnsQuestion
final case class Answer(id: Short, rrs: immutable.Seq[ResourceRecord]) extends NoSerializationVerificationNeeded
final case class Answer(id: Short, rrs: im.Seq[ResourceRecord], additionalRecs: im.Seq[ResourceRecord] = Nil) extends NoSerializationVerificationNeeded
final case class DropRequest(id: Short)
}
@ -56,7 +56,7 @@ import scala.util.Try
}
private def message(name: String, id: Short, recordType: RecordType): Message = {
Message(id, MessageFlags(), immutable.Seq(Question(name, recordType, RecordClass.IN)))
Message(id, MessageFlags(), im.Seq(Question(name, recordType, RecordClass.IN)))
}
def ready(socket: ActorRef): Receive = {
@ -103,8 +103,8 @@ import scala.util.Try
log.debug(s"Received message from [{}]: [{}]", remote, data)
val msg = Message.parse(data)
log.debug(s"Decoded: $msg")
val recs = if (msg.flags.responseCode == ResponseCode.SUCCESS) msg.answerRecs else immutable.Seq.empty
val response = Answer(msg.id, recs)
val (recs, additionalRecs) = if (msg.flags.responseCode == ResponseCode.SUCCESS) (msg.answerRecs, msg.additionalRecs) else (Nil, Nil)
val response = Answer(msg.id, recs, additionalRecs)
inflightRequests.get(response.id) match {
case Some(reply)
reply ! response