CVE-2023-31442 Address DNS poisoning vulnerability (#385)

* CVE-2023-31442 Address DNS poisoning vulnerability (and DNS concurrency bug)

* Remove sequential dns id generator

* fix scalafmt

* Fix bug in isSameQuestion
And ensure that DnsClient only removes inflight messages when the
questions match

* fix up exception message to remove reference to 'sequence' generator

* Add tests to failed commands and drop requests

---------

Co-authored-by: PJ Fanning <pjfanning@users.noreply.github.com>
This commit is contained in:
Iain Hull 2023-06-14 19:38:20 +01:00 committed by GitHub
parent fbf923fc68
commit c56edca78f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 602 additions and 80 deletions

View file

@ -34,6 +34,7 @@ class DnsSettingsSpec extends PekkoSpec {
ndots = 1
positive-ttl = forever
negative-ttl = never
id-generator-policy = thread-local-random
""")
"DNS settings" must {
@ -137,6 +138,10 @@ class DnsSettingsSpec extends PekkoSpec {
dnsSettingsDuration.PositiveCachePolicy shouldEqual CachePolicy.Ttl.fromPositive(10.seconds)
dnsSettingsDuration.NegativeCachePolicy shouldEqual CachePolicy.Ttl.fromPositive(10.days)
}
}
"parse id-generator-policy" in {
val dnsSettings = new DnsSettings(eas, defaultConfig)
dnsSettings.IdGeneratorPolicy shouldEqual (IdGenerator.Policy.ThreadLocalRandom)
}
}
}

View file

@ -0,0 +1,27 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* license agreements; and to You under the Apache License, version 2.0:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* This file is part of the Apache Pekko project, derived from Akka.
*/
package org.apache.pekko.io.dns
import org.apache.pekko.testkit.PekkoSpec
class IdGeneratorSpec extends PekkoSpec {
"IdGenerator" must {
"provide a thread-local-random" in {
val gen = IdGenerator(IdGenerator.Policy.ThreadLocalRandom)
gen.nextId() should be < Short.MaxValue
}
"provide a secure-random" in {
val gen = IdGenerator(IdGenerator.Policy.SecureRandom)
gen.nextId() should be < Short.MaxValue
}
}
}

View file

@ -19,16 +19,15 @@ import scala.collection.{ immutable => im }
import scala.concurrent.duration._
import com.typesafe.config.{ Config, ConfigFactory, ConfigValueFactory }
import org.apache.pekko
import pekko.actor.{ ActorRef, ExtendedActorSystem, Props }
import pekko.actor.Status.Failure
import pekko.io.SimpleDnsCache
import pekko.io.dns.{ AAAARecord, ARecord, DnsSettings, SRVRecord }
import pekko.io.dns.{ AAAARecord, ARecord, DnsSettings, IdGenerator, SRVRecord }
import pekko.io.dns.CachePolicy.Ttl
import pekko.io.dns.DnsProtocol._
import pekko.io.dns.internal.AsyncDnsResolver.ResolveFailedException
import pekko.io.dns.internal.DnsClient.{ Answer, Question4, Question6, SrvQuestion }
import pekko.io.dns.internal.DnsClient.{ Answer, DuplicateId, Question4, Question6, SrvQuestion }
import pekko.testkit.{ PekkoSpec, TestProbe, WithLogCapturing }
class AsyncDnsResolverSpec extends PekkoSpec("""
@ -56,8 +55,11 @@ class AsyncDnsResolverSpec extends PekkoSpec("""
"Async DNS Resolver" must {
"use dns clients in order" in new Setup {
r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false))
dnsClient1.expectMsg(Question4(1, "cats.com"))
dnsClient1.reply(Answer(1, im.Seq.empty))
val id = dnsClient1.expectMsgPF() {
case q4: Question4 if q4.name == "cats.com" =>
q4.id
}
dnsClient1.reply(Answer(id, im.Seq.empty))
dnsClient2.expectNoMessage()
senderProbe.expectMsg(Resolved("cats.com", im.Seq.empty))
}
@ -65,31 +67,65 @@ class AsyncDnsResolverSpec extends PekkoSpec("""
"move to next client if first fails" in new Setup {
r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false))
// first will get ask timeout
dnsClient1.expectMsg(Question4(1, "cats.com"))
val firstId = dnsClient1.expectMsgPF() {
case q4: Question4 if q4.name == "cats.com" =>
q4.id
}
dnsClient1.reply(Failure(new RuntimeException("Nope")))
dnsClient2.expectMsg(Question4(2, "cats.com"))
dnsClient2.reply(Answer(2, im.Seq.empty))
val secondId = dnsClient2.expectMsgPF() {
case q4: Question4 if q4.name == "cats.com" && q4.id != firstId =>
q4.id
}
dnsClient2.reply(Answer(secondId, im.Seq.empty))
senderProbe.expectMsg(Resolved("cats.com", im.Seq.empty))
}
"move to next client if first times out" in new Setup {
r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false))
// first will get ask timeout
dnsClient1.expectMsg(Question4(1, "cats.com"))
dnsClient2.expectMsg(Question4(2, "cats.com"))
dnsClient2.reply(Answer(2, im.Seq.empty))
val firstId = dnsClient1.expectMsgPF() {
case q4: Question4 if q4.name == "cats.com" =>
q4.id
}
val secondId = dnsClient2.expectMsgPF() {
case q4: Question4 if q4.name == "cats.com" && q4.id != firstId =>
q4.id
}
dnsClient2.reply(Answer(secondId, im.Seq.empty))
senderProbe.expectMsg(Resolved("cats.com", im.Seq.empty))
}
"handle duplicate Ids in dnsClient" in new Setup {
r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false))
val firstId = dnsClient1.expectMsgPF() {
case q4: Question4 if q4.name == "cats.com" =>
q4.id
}
dnsClient1.reply(DuplicateId(firstId))
val secondId = dnsClient1.expectMsgPF() {
case q4: Question4 if q4.name == "cats.com" && q4.id != firstId =>
q4.id
}
dnsClient1.reply(Answer(secondId, im.Seq.empty))
dnsClient2.expectNoMessage()
senderProbe.expectMsg(Resolved("cats.com", im.Seq.empty))
}
"gets both A and AAAA records if requested" in new Setup {
r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = true))
dnsClient1.expectMsg(Question4(1, "cats.com"))
val firstId = dnsClient1.expectMsgPF() {
case q4: Question4 if q4.name == "cats.com" =>
q4.id
}
val ttl = Ttl.fromPositive(100.seconds)
val ipv4Record = ARecord("cats.com", ttl, InetAddress.getByName("127.0.0.1"))
dnsClient1.reply(Answer(1, im.Seq(ipv4Record)))
dnsClient1.expectMsg(Question6(2, "cats.com"))
dnsClient1.reply(Answer(firstId, im.Seq(ipv4Record)))
val secondId = dnsClient1.expectMsgPF() {
case q6: Question6 if q6.name == "cats.com" && q6.id != firstId =>
q6.id
}
val ipv6Record = AAAARecord("cats.com", ttl, InetAddress.getByName("::1").asInstanceOf[Inet6Address])
dnsClient1.reply(Answer(2, im.Seq(ipv6Record)))
dnsClient1.reply(Answer(secondId, im.Seq(ipv6Record)))
senderProbe.expectMsg(Resolved("cats.com", im.Seq(ipv4Record, ipv6Record)))
}
@ -102,9 +138,15 @@ class AsyncDnsResolverSpec extends PekkoSpec("""
"fails if all dns clients fail" in new Setup {
r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false))
dnsClient1.expectMsg(Question4(1, "cats.com"))
val firstId = dnsClient1.expectMsgPF() {
case q4: Question4 if q4.name == "cats.com" =>
q4.id
}
dnsClient1.reply(Failure(new RuntimeException("Fail")))
dnsClient2.expectMsg(Question4(2, "cats.com"))
dnsClient2.expectMsgPF() {
case q4: Question4 if q4.name == "cats.com" && q4.id != firstId =>
q4.id
}
dnsClient2.reply(Failure(new RuntimeException("Yet another fail")))
senderProbe.expectMsgPF(remainingOrDefault) {
case Failure(ResolveFailedException(_)) =>
@ -113,8 +155,11 @@ class AsyncDnsResolverSpec extends PekkoSpec("""
"gets SRV records if requested" in new Setup {
r ! Resolve("cats.com", Srv)
dnsClient1.expectMsg(SrvQuestion(1, "cats.com"))
dnsClient1.reply(Answer(1, im.Seq.empty))
val firstId = dnsClient1.expectMsgPF() {
case srvQuestion: SrvQuestion if srvQuestion.name == "cats.com" =>
srvQuestion.id
}
dnsClient1.reply(Answer(firstId, im.Seq.empty))
dnsClient2.expectNoMessage()
senderProbe.expectMsg(Resolved("cats.com", im.Seq.empty))
}
@ -145,10 +190,13 @@ class AsyncDnsResolverSpec extends PekkoSpec("""
"return additional records for SRV requests" in new Setup {
r ! Resolve("cats.com", Srv)
dnsClient1.expectMsg(SrvQuestion(1, "cats.com"))
val firstId = dnsClient1.expectMsgPF() {
case srvQuestion: SrvQuestion if srvQuestion.name == "cats.com" =>
srvQuestion.id
}
val srvRecs = im.Seq(SRVRecord("cats.com", Ttl.fromPositive(5000.seconds), 1, 1, 1, "a.cats.com"))
val aRecs = im.Seq(ARecord("a.cats.com", Ttl.fromPositive(1.seconds), InetAddress.getByName("127.0.0.1")))
dnsClient1.reply(Answer(1, srvRecs, aRecs))
dnsClient1.reply(Answer(firstId, srvRecs, aRecs))
dnsClient2.expectNoMessage(50.millis)
senderProbe.expectMsg(Resolved("cats.com", srvRecs, aRecs))
@ -162,8 +210,11 @@ class AsyncDnsResolverSpec extends PekkoSpec("""
override val r = resolver(List(dnsClient1.ref), configWithSmallTtl)
r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false))
dnsClient1.expectMsg(Question4(1, "cats.com"))
dnsClient1.reply(Answer(1, im.Seq.empty))
val firstId = dnsClient1.expectMsgPF() {
case q4: Question4 if q4.name == "cats.com" =>
q4.id
}
dnsClient1.reply(Answer(firstId, im.Seq.empty))
senderProbe.expectMsg(Resolved("cats.com", im.Seq()))
@ -178,8 +229,11 @@ class AsyncDnsResolverSpec extends PekkoSpec("""
val ipv4Record = ARecord("cats.com", recordTtl, InetAddress.getByName("127.0.0.1"))
r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false))
dnsClient1.expectMsg(Question4(1, "cats.com"))
dnsClient1.reply(Answer(1, im.Seq(ipv4Record)))
val firstId = dnsClient1.expectMsgPF() {
case q4: Question4 if q4.name == "cats.com" =>
q4.id
}
dnsClient1.reply(Answer(firstId, im.Seq(ipv4Record)))
senderProbe.expectMsg(Resolved("cats.com", im.Seq(ipv4Record)))
@ -197,14 +251,20 @@ class AsyncDnsResolverSpec extends PekkoSpec("""
val ipv4Record = ARecord("cats.com", recordTtl, InetAddress.getByName("127.0.0.1"))
r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false))
dnsClient1.expectMsg(Question4(1, "cats.com"))
dnsClient1.reply(Answer(1, im.Seq(ipv4Record)))
val firstId = dnsClient1.expectMsgPF() {
case q4: Question4 if q4.name == "cats.com" =>
q4.id
}
dnsClient1.reply(Answer(firstId, im.Seq(ipv4Record)))
senderProbe.expectMsg(Resolved("cats.com", im.Seq(ipv4Record)))
r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false))
dnsClient1.expectMsg(Question4(2, "cats.com"))
dnsClient1.reply(Answer(2, im.Seq(ipv4Record)))
val secondId = dnsClient1.expectMsgPF() {
case q4: Question4 if q4.name == "cats.com" =>
q4.id
}
dnsClient1.reply(Answer(secondId, im.Seq(ipv4Record)))
senderProbe.expectMsg(Resolved("cats.com", im.Seq(ipv4Record)))
}
@ -217,8 +277,11 @@ class AsyncDnsResolverSpec extends PekkoSpec("""
val ipv4Record = ARecord("cats.com", recordTtl, InetAddress.getByName("127.0.0.1"))
r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false))
dnsClient1.expectMsg(Question4(1, "cats.com"))
dnsClient1.reply(Answer(1, im.Seq(ipv4Record)))
val firstId = dnsClient1.expectMsgPF() {
case q4: Question4 if q4.name == "cats.com" =>
q4.id
}
dnsClient1.reply(Answer(firstId, im.Seq(ipv4Record)))
senderProbe.expectMsg(Resolved("cats.com", im.Seq(ipv4Record)))
r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false))
@ -228,8 +291,11 @@ class AsyncDnsResolverSpec extends PekkoSpec("""
Thread.sleep(200)
r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false))
dnsClient1.expectMsg(Question4(2, "cats.com"))
dnsClient1.reply(Answer(2, im.Seq(ipv4Record)))
val secondId = dnsClient1.expectMsgPF() {
case q4: Question4 if q4.name == "cats.com" =>
q4.id
}
dnsClient1.reply(Answer(secondId, im.Seq(ipv4Record)))
senderProbe.expectMsg(Resolved("cats.com", im.Seq(ipv4Record)))
}
@ -240,6 +306,6 @@ class AsyncDnsResolverSpec extends PekkoSpec("""
system.actorOf(Props(new AsyncDnsResolver(settings, new SimpleDnsCache(),
(_, _) => {
clients
})))
}, IdGenerator())))
}
}

View file

@ -13,24 +13,23 @@
package org.apache.pekko.io.dns.internal
import java.net.InetSocketAddress
import java.net.{ InetAddress, InetSocketAddress }
import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.immutable.Seq
import scala.collection.{ immutable => im }
import org.apache.pekko
import org.apache.pekko.actor.Status.Failure
import pekko.actor.Props
import pekko.io.Udp
import pekko.io.dns.{ RecordClass, RecordType }
import pekko.io.dns.internal.DnsClient.{ Answer, Question4 }
import pekko.io.dns.{ ARecord, CachePolicy, RecordClass, RecordType }
import pekko.io.dns.internal.DnsClient.{ Answer, DropRequest, DuplicateId, Question4 }
import pekko.testkit.{ ImplicitSender, PekkoSpec, TestProbe }
class DnsClientSpec extends PekkoSpec with ImplicitSender {
"The async DNS client" should {
val exampleRequest = Question4(42, "pekko.io")
val exampleRequestMessage =
Message(42, MessageFlags(), questions = Seq(Question("pekko.io", RecordType.A, RecordClass.IN)))
val exampleResponseMessage = Message(42, MessageFlags(answer = true))
val exampleQuestion = Question("pekko.io", RecordType.A, RecordClass.IN)
val exampleRequestMessage = Message(42, MessageFlags(), questions = im.Seq(exampleQuestion))
val exampleResponseMessage = Message(42, MessageFlags(answer = true), questions = im.Seq(exampleQuestion))
val exampleResponse = Answer(42, Nil)
val dnsServerAddress = InetSocketAddress.createUnresolved("foo", 53)
@ -83,5 +82,215 @@ class DnsClientSpec extends PekkoSpec with ImplicitSender {
expectMsg(exampleResponse)
}
"Do not accept duplicate transaction ids" in {
val udpExtensionProbe = TestProbe()
val udpSocketProbe = TestProbe()
val tcpClientProbe = TestProbe()
val goodSenderProbe = TestProbe()
val badSenderProbe = TestProbe()
val client = system.actorOf(Props(new DnsClient(dnsServerAddress) {
override val udp = udpExtensionProbe.ref
override def createTcpClient() = tcpClientProbe.ref
}))
val badRequest = Question4(exampleRequest.id, "not." + exampleRequest.name)
udpExtensionProbe.expectMsgType[Udp.Bind]
udpSocketProbe.send(
udpExtensionProbe.lastSender,
Udp.Bound(InetSocketAddress.createUnresolved("localhost", 41325)))
goodSenderProbe.send(client, exampleRequest)
val udpSend = udpSocketProbe.expectMsgType[Udp.Send]
udpSend.payload shouldBe exampleRequestMessage.write()
badSenderProbe.send(client, badRequest)
badSenderProbe.expectMsg(DuplicateId(exampleRequest.id))
val answer = Answer(exampleRequest.id, im.Seq(), im.Seq())
udpSocketProbe.reply(answer)
goodSenderProbe.expectMsg(answer)
udpSocketProbe.expectNoMessage()
}
"Verify original question when processing UDP replies (DNS poisoning)" in {
val udpExtensionProbe = TestProbe()
val udpSocketProbe = TestProbe()
val tcpClientProbe = TestProbe()
val goodSenderProbe = TestProbe()
val socket = InetSocketAddress.createUnresolved("localhost", 41325)
val goodRecord = ARecord(exampleRequest.name, CachePolicy.Ttl.never, InetAddress.getLocalHost())
val client = system.actorOf(Props(new DnsClient(dnsServerAddress) {
override val udp = udpExtensionProbe.ref
override def createTcpClient() = tcpClientProbe.ref
}))
udpExtensionProbe.expectMsgType[Udp.Bind]
udpSocketProbe.send(udpExtensionProbe.lastSender, Udp.Bound(socket))
goodSenderProbe.send(client, exampleRequest)
val udpSend = udpSocketProbe.expectMsgType[Udp.Send]
udpSend.payload shouldBe exampleRequestMessage.write()
val flags = MessageFlags(true, authoritativeAnswer = true)
val badId = exampleRequestMessage.copy(id = 999, flags = flags, answerRecs = im.Seq(goodRecord))
val badQuestion = exampleRequestMessage.copy(
flags = flags,
questions = im.Seq(exampleQuestion.copy(name = "not.com")),
answerRecs = im.Seq(goodRecord))
val goodAnswer = exampleRequestMessage.copy(flags = flags, answerRecs = im.Seq(goodRecord))
udpSocketProbe.reply(Udp.Received(internal.ByteResponse(badId), socket))
udpSocketProbe.reply(Udp.Received(internal.ByteResponse(badQuestion), socket))
udpSocketProbe.reply(Udp.Received(internal.ByteResponse(goodAnswer), socket))
val answer = Answer(exampleRequest.id, im.Seq(goodRecord), im.Seq())
goodSenderProbe.expectMsg(answer)
udpSocketProbe.expectNoMessage()
}
"Verify original question when processing DropRequest" in {
val udpExtensionProbe = TestProbe()
val udpSocketProbe = TestProbe()
val tcpClientProbe = TestProbe()
val goodSenderProbe = TestProbe()
val socket = InetSocketAddress.createUnresolved("localhost", 41325)
val client = system.actorOf(Props(new DnsClient(dnsServerAddress) {
override val udp = udpExtensionProbe.ref
override def createTcpClient() = tcpClientProbe.ref
}))
udpExtensionProbe.expectMsgType[Udp.Bind]
udpSocketProbe.send(udpExtensionProbe.lastSender, Udp.Bound(socket))
goodSenderProbe.send(client, exampleRequest)
val udpSend = udpSocketProbe.expectMsgType[Udp.Send]
udpSend.payload shouldBe exampleRequestMessage.write()
goodSenderProbe.send(client, DropRequest(exampleRequest.copy(id = 999)))
// duplicate shows inflight message not deleted
goodSenderProbe.send(client, exampleRequest)
goodSenderProbe.expectMsg(DuplicateId(exampleRequest.id))
goodSenderProbe.send(client, DropRequest(exampleRequest.copy(name = "not.com")))
// duplicate shows inflight message not deleted
goodSenderProbe.send(client, exampleRequest)
goodSenderProbe.expectMsg(DuplicateId(exampleRequest.id))
goodSenderProbe.send(client, DropRequest(exampleRequest))
// no duplicate shows inflight message was deleted
goodSenderProbe.send(client, exampleRequest)
goodSenderProbe.expectNoMessage()
}
"Verify original question when processing UDP Failures" in {
val udpExtensionProbe = TestProbe()
val udpSocketProbe = TestProbe()
val tcpClientProbe = TestProbe()
val goodSenderProbe = TestProbe()
val socket = InetSocketAddress.createUnresolved("localhost", 41325)
val client = system.actorOf(Props(new DnsClient(dnsServerAddress) {
override val udp = udpExtensionProbe.ref
override def createTcpClient() = tcpClientProbe.ref
}))
udpExtensionProbe.expectMsgType[Udp.Bind]
udpSocketProbe.send(udpExtensionProbe.lastSender, Udp.Bound(socket))
goodSenderProbe.send(client, exampleRequest)
val udpSend = udpSocketProbe.expectMsgType[Udp.Send]
udpSend.payload shouldBe exampleRequestMessage.write()
val badId = exampleRequestMessage.copy(id = 999)
val badQuestion = exampleRequestMessage.copy(
questions = im.Seq(exampleQuestion.copy(name = "not.com")))
val goodQuestion = exampleRequestMessage
udpSocketProbe.reply(Udp.CommandFailed(Udp.Send(internal.ByteResponse(badId), socket)))
// duplicate shows inflight message not deleted
goodSenderProbe.send(client, exampleRequest)
goodSenderProbe.expectMsg(DuplicateId(exampleRequest.id))
udpSocketProbe.reply(Udp.CommandFailed(Udp.Send(internal.ByteResponse(badQuestion), socket)))
// duplicate shows inflight message not deleted
goodSenderProbe.send(client, exampleRequest)
goodSenderProbe.expectMsg(DuplicateId(exampleRequest.id))
udpSocketProbe.reply(Udp.CommandFailed(Udp.Send(internal.ByteResponse(goodQuestion), socket)))
goodSenderProbe.expectMsgType[Failure]
// no duplicate shows inflight message was deleted
goodSenderProbe.send(client, exampleRequest)
goodSenderProbe.expectNoMessage()
}
}
}
/**
* The main code only knows how to write the questions not the responses, ByteResponse
* implements just enough of the writing logic that the main code can parse the answers
* messages used in tests.
*
* Message is only available to the internal package.
*/
package internal {
import org.apache.pekko.io.dns.ResourceRecord
import org.apache.pekko.util.{ ByteString, ByteStringBuilder }
object ByteResponse {
def apply(msg: Message): ByteString = {
val ret = ByteString.newBuilder
write(msg, ret)
ret.result()
}
def write(msg: Message, ret: ByteStringBuilder): Unit = {
ret
.putShort(msg.id)
.putShort(msg.flags.flags)
.putShort(msg.questions.size)
.putShort(msg.answerRecs.size)
.putShort(0)
.putShort(0)
msg.questions.foreach(_.write(ret))
msg.answerRecs.foreach(write(_, ret))
}
def write(msg: ResourceRecord, ret: ByteStringBuilder): Unit = {
msg match {
case ARecord(name, ttl, ip) =>
DomainName.write(ret, name)
ret.putShort(RecordType.A.code)
ret.putShort(RecordClass.IN.code)
ret.putInt(ttl.value.toSeconds.toInt)
ret.putShort(4)
ret.putBytes(ip.getAddress, 0, 4)
case _ =>
throw new IllegalStateException(s"Tests cannot write messages of type ${msg.getClass}")
}
}
}
}

View file

@ -1144,6 +1144,12 @@ pekko {
# Defaults to a system dependent lookup (on Unix like OSes, will attempt to parse /etc/resolv.conf, on
# other platforms, will default to 1).
ndots = default
# The policy used to generate dns transaction ids. Options are thread-local-random or secure-random.
# Defaults to thread-local-random similar to Netty, secure-random produces FIPS compliant random numbers but
# could block looking for entropy (these are short integers so are easy to bruit-force, use thread-local-random
# unless you really require FIPS compliant random numbers).
id-generator-policy = thread-local-random
}
}
}

View file

@ -67,6 +67,11 @@ private[dns] final class DnsSettings(system: ExtendedActorSystem, c: Config) {
val PositiveCachePolicy: CachePolicy = getTtl("positive-ttl")
val NegativeCachePolicy: CachePolicy = getTtl("negative-ttl")
lazy val IdGeneratorPolicy: IdGenerator.Policy = IdGenerator
.Policy(c.getString("id-generator-policy"))
.getOrElse(throw new IllegalArgumentException("id-generator-policy must be 'thread-local-random' or " +
s"'secure-random' value was '${c.getString("id-generator-policy")}'"))
private def getTtl(path: String): CachePolicy =
c.getString(path) match {
case "forever" => Forever

View file

@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* license agreements; and to You under the Apache License, version 2.0:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* This file is part of the Apache Pekko project, derived from Akka.
*/
package org.apache.pekko.io.dns
import org.apache.pekko.annotation.InternalApi
import java.security.SecureRandom
import java.util.concurrent.ThreadLocalRandom
import java.util.concurrent.atomic.AtomicInteger
import scala.annotation.tailrec
/**
* INTERNAL API
*
* These are called by an actor, however they are called inside composed futures so need to be
* nextId needs to be thread safe.
*/
@InternalApi
private[pekko] trait IdGenerator {
def nextId(): Short
}
/**
* INTERNAL API
*/
@InternalApi
private[pekko] object IdGenerator {
sealed trait Policy
object Policy {
case object ThreadLocalRandom extends Policy
case object SecureRandom extends Policy
val Default: Policy = ThreadLocalRandom
def apply(name: String): Option[Policy] = name.toLowerCase match {
case "thread-local-random" => Some(ThreadLocalRandom)
case "secure-random" => Some(SecureRandom)
case _ => Some(ThreadLocalRandom)
}
}
def apply(policy: Policy): IdGenerator = policy match {
case Policy.ThreadLocalRandom => random(ThreadLocalRandom.current())
case Policy.SecureRandom => random(new SecureRandom())
}
def apply(): IdGenerator = random(ThreadLocalRandom.current())
/**
* @return a random sequence of ids for production
*/
def random(rand: java.util.Random): IdGenerator = new IdGenerator {
override def nextId(): Short = rand.nextInt(Short.MaxValue).toShort
}
/**
* @return a predictable sequence of ids for tests
*/
def sequence(): IdGenerator = new IdGenerator {
val requestId: AtomicInteger = new AtomicInteger(0)
@tailrec
override final def nextId(): Short = {
val oldId = requestId.get()
val newId = (oldId + 1) % Short.MaxValue
if (requestId.compareAndSet(oldId, newId.intValue())) {
newId.toShort
} else {
nextId()
}
}
}
}

View file

@ -19,6 +19,8 @@ import scala.collection.immutable
import scala.concurrent.ExecutionContextExecutor
import scala.concurrent.Future
import scala.util.Try
import scala.util.Success
import scala.util.Failure
import scala.util.control.NonFatal
import org.apache.pekko
@ -41,10 +43,17 @@ import pekko.util.PrettyDuration._
private[io] final class AsyncDnsResolver(
settings: DnsSettings,
cache: SimpleDnsCache,
clientFactory: (ActorRefFactory, List[InetSocketAddress]) => List[ActorRef])
clientFactory: (ActorRefFactory, List[InetSocketAddress]) => List[ActorRef],
idGenerator: IdGenerator)
extends Actor
with ActorLogging {
def this(
settings: DnsSettings,
cache: SimpleDnsCache,
clientFactory: (ActorRefFactory, List[InetSocketAddress]) => List[ActorRef]) =
this(settings, cache, clientFactory, IdGenerator(settings.IdGeneratorPolicy))
import AsyncDnsResolver._
implicit val ec: ExecutionContextExecutor = context.dispatcher
@ -85,13 +94,6 @@ private[io] final class AsyncDnsResolver(
settings.SearchDomains,
settings.NDots)
private var requestId: Short = 0
private def nextId(): Short = {
requestId = (requestId + 1).toShort
requestId
}
private val resolvers: List[ActorRef] = clientFactory(context, nameServers)
// only supports DnsProtocol, not the deprecated Dns protocol
@ -151,11 +153,19 @@ private[io] final class AsyncDnsResolver(
}
private def sendQuestion(resolver: ActorRef, message: DnsQuestion): Future[Answer] = {
val result = (resolver ? message).mapTo[Answer]
result.failed.foreach { _ =>
resolver ! DropRequest(message.id)
(resolver ? message).transformWith {
case Success(result: Answer) =>
Future.successful(result)
case Success(DuplicateId(_)) =>
sendQuestion(resolver, message.withId(idGenerator.nextId()))
case Failure(t) =>
resolver ! DropRequest(message)
Future.failed(t)
case Success(a) =>
resolver ! DropRequest(message)
Future.failed(
new IllegalArgumentException("Unexpected response " + a.toString + " of type " + a.getClass.toString))
}
result
}
private def resolveWithSearch(
@ -208,13 +218,13 @@ private[io] final class AsyncDnsResolver(
case Ip(ipv4, ipv6) =>
val ipv4Recs: Future[Answer] =
if (ipv4)
sendQuestion(resolver, Question4(nextId(), caseFoldedName))
sendQuestion(resolver, Question4(idGenerator.nextId(), caseFoldedName))
else
Empty
val ipv6Recs =
if (ipv6)
sendQuestion(resolver, Question6(nextId(), caseFoldedName))
sendQuestion(resolver, Question6(idGenerator.nextId(), caseFoldedName))
else
Empty
@ -224,7 +234,7 @@ private[io] final class AsyncDnsResolver(
} yield DnsProtocol.Resolved(name, ipv4.rrs ++ ipv6.rrs, ipv4.additionalRecs ++ ipv6.additionalRecs)
case Srv =>
sendQuestion(resolver, SrvQuestion(nextId(), caseFoldedName)).map(answer => {
sendQuestion(resolver, SrvQuestion(idGenerator.nextId(), caseFoldedName)).map(answer => {
DnsProtocol.Resolved(name, answer.rrs, answer.additionalRecs)
})
}

View file

@ -14,13 +14,10 @@
package org.apache.pekko.io.dns.internal
import java.net.{ InetAddress, InetSocketAddress }
import scala.collection.{ immutable => im }
import scala.concurrent.duration._
import scala.util.Try
import scala.annotation.nowarn
import scala.annotation.{ nowarn, tailrec }
import org.apache.pekko
import pekko.actor.{ Actor, ActorLogging, ActorRef, NoSerializationVerificationNeeded, Props, Stash }
import pekko.actor.Status.Failure
@ -35,13 +32,23 @@ import pekko.pattern.{ BackoffOpts, BackoffSupervisor }
@InternalApi private[pekko] object DnsClient {
sealed trait DnsQuestion {
def id: Short
def name: String
def withId(newId: Short): DnsQuestion = {
this match {
case SrvQuestion(_, name) => SrvQuestion(newId, name)
case Question4(_, name) => Question4(newId, name)
case Question6(_, name) => Question6(newId, name)
}
}
}
final case class SrvQuestion(id: Short, name: String) extends DnsQuestion
final case class Question4(id: Short, name: String) extends DnsQuestion
final case class Question6(id: Short, name: String) extends DnsQuestion
final case class Answer(id: Short, rrs: im.Seq[ResourceRecord], additionalRecs: im.Seq[ResourceRecord] = Nil)
extends NoSerializationVerificationNeeded
final case class DropRequest(id: Short)
final case class DuplicateId(id: Short) extends NoSerializationVerificationNeeded
final case class DropRequest(question: DnsQuestion)
}
/**
@ -85,30 +92,40 @@ import pekko.pattern.{ BackoffOpts, BackoffSupervisor }
*/
@nowarn()
def ready(socket: ActorRef): Receive = {
case DropRequest(id) =>
log.debug("Dropping request [{}]", id)
inflightRequests -= id
case DropRequest(msg) =>
inflightRequests.get(msg.id).foreach {
case (_, orig) if Seq(msg.name) == orig.questions.map(_.name) =>
log.debug("Dropping request [{}]", msg.id)
inflightRequests -= msg.id
case (_, orig) =>
log.warning("Cannot drop inflight DNS request the question [{}] does not match [{}]",
msg.name,
orig.questions.map(_.name).mkString(","))
}
case Question4(id, name) =>
log.debug("Resolving [{}] (A)", name)
val msg = message(name, id, RecordType.A)
inflightRequests += (id -> (sender() -> msg))
log.debug("Message [{}] to [{}]: [{}]", id, ns, msg)
socket ! Udp.Send(msg.write(), ns)
newInflightRequests(msg, sender()) {
log.debug("Message [{}] to [{}]: [{}]", id, ns, msg)
socket ! Udp.Send(msg.write(), ns)
}
case Question6(id, name) =>
log.debug("Resolving [{}] (AAAA)", name)
val msg = message(name, id, RecordType.AAAA)
inflightRequests += (id -> (sender() -> msg))
log.debug("Message to [{}]: [{}]", ns, msg)
socket ! Udp.Send(msg.write(), ns)
newInflightRequests(msg, sender()) {
log.debug("Message [{}] to [{}]: [{}]", id, ns, msg)
socket ! Udp.Send(msg.write(), ns)
}
case SrvQuestion(id, name) =>
log.debug("Resolving [{}] (SRV)", name)
val msg = message(name, id, RecordType.SRV)
inflightRequests += (id -> (sender() -> msg))
log.debug("Message to [{}]: [{}]", ns, msg)
socket ! Udp.Send(msg.write(), ns)
newInflightRequests(msg, sender()) {
log.debug("Message [{}] to [{}]: [{}]", id, ns, msg)
socket ! Udp.Send(msg.write(), ns)
}
case Udp.CommandFailed(cmd) =>
log.debug("Command failed [{}]", cmd)
@ -118,9 +135,13 @@ import pekko.pattern.{ BackoffOpts, BackoffSupervisor }
Try {
val msg = Message.parse(send.payload)
inflightRequests.get(msg.id).foreach {
case (s, _) =>
case (s, orig) if isSameQuestion(msg.questions, orig.questions) =>
s ! Failure(new RuntimeException("Send failed to nameserver"))
inflightRequests -= msg.id
case (_, orig) =>
log.warning("Cannot command failed question [{}] does not match [{}]",
msg.questions.mkString(","),
orig.questions.mkString(","))
}
}
case _ =>
@ -140,9 +161,24 @@ import pekko.pattern.{ BackoffOpts, BackoffSupervisor }
log.debug("Client for id {} not found. Discarding unsuccessful response.", msg.id)
}
} else {
val (recs, additionalRecs) =
if (msg.flags.responseCode == ResponseCode.SUCCESS) (msg.answerRecs, msg.additionalRecs) else (Nil, Nil)
self ! Answer(msg.id, recs, additionalRecs)
inflightRequests.get(msg.id) match {
case Some((_, orig)) if !isSameQuestion(msg.questions, orig.questions) =>
log.warning(
"Unexpected DNS response id {} question [{}] does not match question asked [{}]",
msg.id,
msg.questions.mkString(","),
orig.questions.mkString(","))
case Some((_, orig)) =>
log.warning("DNS response id {} question [{}] question asked [{}]",
msg.id,
msg.questions.mkString(","),
orig.questions.mkString(","))
val (recs, additionalRecs) =
if (msg.flags.responseCode == ResponseCode.SUCCESS) (msg.answerRecs, msg.additionalRecs) else (Nil, Nil)
self ! Answer(msg.id, recs, additionalRecs)
case None =>
log.warning("Unexpected DNS response invalid id {}", msg.id)
}
}
case response: Answer =>
inflightRequests.get(response.id) match {
@ -156,6 +192,29 @@ import pekko.pattern.{ BackoffOpts, BackoffSupervisor }
case Udp.Unbound => context.stop(self)
}
private def newInflightRequests(msg: Message, theSender: ActorRef)(func: => Unit): Unit = {
if (!inflightRequests.contains(msg.id)) {
inflightRequests += (msg.id -> (theSender -> msg))
func
} else {
log.warning("Received duplicate message [{}] with id [{}]", msg, msg.id)
theSender ! DuplicateId(msg.id)
}
}
private def isSameQuestion(q1s: Seq[Question], q2s: Seq[Question]): Boolean = {
@tailrec
def impl(q1s: List[Question], q2s: List[Question]): Boolean = {
(q1s, q2s) match {
case (Nil, Nil) => true
case (h1 :: t1, h2 :: t2) => h1.isSame(h2) && impl(t1, t2)
case _ => false
}
}
impl(q1s.sortBy(_.name).toList, q2s.sortBy(_.name).toList)
}
def createTcpClient() = {
context.actorOf(
BackoffSupervisor.props(

View file

@ -28,6 +28,16 @@ private[pekko] final case class Question(name: String, qType: RecordType, qClass
RecordTypeSerializer.write(out, qType)
RecordClassSerializer.write(out, qClass)
}
/**
* DomainName.parse adds a '.' to the end of domain names we have to allow for this checking the domain names
* @return true if this the questions are the same (allowing for an trailing dot).
*/
def isSame(that: Question): Boolean = {
def addDot(name: String): String = if (name.nonEmpty && !name.endsWith(".")) name.concat(".") else name
addDot(this.name) == addDot(that.name) && this.qType == that.qType && this.qClass == that.qClass
}
}
/**

View file

@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* license agreements; and to You under the Apache License, version 2.0:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* This file is part of the Apache Pekko project, derived from Akka.
*/
package org.apache.pekko.io.dns
import org.openjdk.jmh.annotations.{
Benchmark,
BenchmarkMode,
Fork,
Measurement,
Mode,
OutputTimeUnit,
Scope,
State,
Threads,
Warmup
}
import java.util.concurrent.{ ThreadLocalRandom, TimeUnit }
import java.security.SecureRandom
@BenchmarkMode(Array(Mode.Throughput))
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@Warmup(iterations = 3, time = 5, timeUnit = TimeUnit.SECONDS)
@Measurement(iterations = 3, time = 5)
@Threads(8)
@Fork(1)
@State(Scope.Benchmark)
class IdGeneratorBanchmark {
val threadLocalRandom = IdGenerator.random(ThreadLocalRandom.current())
val secureRandom = IdGenerator.random(new SecureRandom())
@Benchmark
def measureThreadLocalRandom(): Short = threadLocalRandom.nextId()
@Benchmark
def measureSecureRandom(): Short = secureRandom.nextId()
}