From c56edca78f775934ca081f7967175960a18cb99c Mon Sep 17 00:00:00 2001 From: Iain Hull Date: Wed, 14 Jun 2023 19:38:20 +0100 Subject: [PATCH] CVE-2023-31442 Address DNS poisoning vulnerability (#385) * CVE-2023-31442 Address DNS poisoning vulnerability (and DNS concurrency bug) * Remove sequential dns id generator * fix scalafmt * Fix bug in isSameQuestion And ensure that DnsClient only removes inflight messages when the questions match * fix up exception message to remove reference to 'sequence' generator * Add tests to failed commands and drop requests --------- Co-authored-by: PJ Fanning --- .../apache/pekko/io/dns/DnsSettingsSpec.scala | 7 +- .../apache/pekko/io/dns/IdGeneratorSpec.scala | 27 +++ .../dns/internal/AsyncDnsResolverSpec.scala | 134 ++++++++--- .../pekko/io/dns/internal/DnsClientSpec.scala | 227 +++++++++++++++++- actor/src/main/resources/reference.conf | 6 + .../org/apache/pekko/io/dns/DnsSettings.scala | 5 + .../org/apache/pekko/io/dns/IdGenerator.scala | 81 +++++++ .../io/dns/internal/AsyncDnsResolver.scala | 40 +-- .../pekko/io/dns/internal/DnsClient.scala | 101 ++++++-- .../pekko/io/dns/internal/Question.scala | 10 + .../pekko/io/dns/IdGeneratorBanchmark.scala | 44 ++++ 11 files changed, 602 insertions(+), 80 deletions(-) create mode 100644 actor-tests/src/test/scala/org/apache/pekko/io/dns/IdGeneratorSpec.scala create mode 100644 actor/src/main/scala/org/apache/pekko/io/dns/IdGenerator.scala create mode 100644 bench-jmh/src/main/scala/org/apache/pekko/io/dns/IdGeneratorBanchmark.scala diff --git a/actor-tests/src/test/scala/org/apache/pekko/io/dns/DnsSettingsSpec.scala b/actor-tests/src/test/scala/org/apache/pekko/io/dns/DnsSettingsSpec.scala index b1f81edce1..f60fdcf86a 100644 --- a/actor-tests/src/test/scala/org/apache/pekko/io/dns/DnsSettingsSpec.scala +++ b/actor-tests/src/test/scala/org/apache/pekko/io/dns/DnsSettingsSpec.scala @@ -34,6 +34,7 @@ class DnsSettingsSpec extends PekkoSpec { ndots = 1 positive-ttl = forever negative-ttl = never + id-generator-policy = thread-local-random """) "DNS settings" must { @@ -137,6 +138,10 @@ class DnsSettingsSpec extends PekkoSpec { dnsSettingsDuration.PositiveCachePolicy shouldEqual CachePolicy.Ttl.fromPositive(10.seconds) dnsSettingsDuration.NegativeCachePolicy shouldEqual CachePolicy.Ttl.fromPositive(10.days) } - } + "parse id-generator-policy" in { + val dnsSettings = new DnsSettings(eas, defaultConfig) + dnsSettings.IdGeneratorPolicy shouldEqual (IdGenerator.Policy.ThreadLocalRandom) + } + } } diff --git a/actor-tests/src/test/scala/org/apache/pekko/io/dns/IdGeneratorSpec.scala b/actor-tests/src/test/scala/org/apache/pekko/io/dns/IdGeneratorSpec.scala new file mode 100644 index 0000000000..3ea4a495be --- /dev/null +++ b/actor-tests/src/test/scala/org/apache/pekko/io/dns/IdGeneratorSpec.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * license agreements; and to You under the Apache License, version 2.0: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * This file is part of the Apache Pekko project, derived from Akka. + */ + +package org.apache.pekko.io.dns + +import org.apache.pekko.testkit.PekkoSpec + +class IdGeneratorSpec extends PekkoSpec { + + "IdGenerator" must { + "provide a thread-local-random" in { + val gen = IdGenerator(IdGenerator.Policy.ThreadLocalRandom) + gen.nextId() should be < Short.MaxValue + } + + "provide a secure-random" in { + val gen = IdGenerator(IdGenerator.Policy.SecureRandom) + gen.nextId() should be < Short.MaxValue + } + } +} diff --git a/actor-tests/src/test/scala/org/apache/pekko/io/dns/internal/AsyncDnsResolverSpec.scala b/actor-tests/src/test/scala/org/apache/pekko/io/dns/internal/AsyncDnsResolverSpec.scala index a4cb4f9560..b3db4c629c 100644 --- a/actor-tests/src/test/scala/org/apache/pekko/io/dns/internal/AsyncDnsResolverSpec.scala +++ b/actor-tests/src/test/scala/org/apache/pekko/io/dns/internal/AsyncDnsResolverSpec.scala @@ -19,16 +19,15 @@ import scala.collection.{ immutable => im } import scala.concurrent.duration._ import com.typesafe.config.{ Config, ConfigFactory, ConfigValueFactory } - import org.apache.pekko import pekko.actor.{ ActorRef, ExtendedActorSystem, Props } import pekko.actor.Status.Failure import pekko.io.SimpleDnsCache -import pekko.io.dns.{ AAAARecord, ARecord, DnsSettings, SRVRecord } +import pekko.io.dns.{ AAAARecord, ARecord, DnsSettings, IdGenerator, SRVRecord } import pekko.io.dns.CachePolicy.Ttl import pekko.io.dns.DnsProtocol._ import pekko.io.dns.internal.AsyncDnsResolver.ResolveFailedException -import pekko.io.dns.internal.DnsClient.{ Answer, Question4, Question6, SrvQuestion } +import pekko.io.dns.internal.DnsClient.{ Answer, DuplicateId, Question4, Question6, SrvQuestion } import pekko.testkit.{ PekkoSpec, TestProbe, WithLogCapturing } class AsyncDnsResolverSpec extends PekkoSpec(""" @@ -56,8 +55,11 @@ class AsyncDnsResolverSpec extends PekkoSpec(""" "Async DNS Resolver" must { "use dns clients in order" in new Setup { r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false)) - dnsClient1.expectMsg(Question4(1, "cats.com")) - dnsClient1.reply(Answer(1, im.Seq.empty)) + val id = dnsClient1.expectMsgPF() { + case q4: Question4 if q4.name == "cats.com" => + q4.id + } + dnsClient1.reply(Answer(id, im.Seq.empty)) dnsClient2.expectNoMessage() senderProbe.expectMsg(Resolved("cats.com", im.Seq.empty)) } @@ -65,31 +67,65 @@ class AsyncDnsResolverSpec extends PekkoSpec(""" "move to next client if first fails" in new Setup { r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false)) // first will get ask timeout - dnsClient1.expectMsg(Question4(1, "cats.com")) + val firstId = dnsClient1.expectMsgPF() { + case q4: Question4 if q4.name == "cats.com" => + q4.id + } dnsClient1.reply(Failure(new RuntimeException("Nope"))) - dnsClient2.expectMsg(Question4(2, "cats.com")) - dnsClient2.reply(Answer(2, im.Seq.empty)) + val secondId = dnsClient2.expectMsgPF() { + case q4: Question4 if q4.name == "cats.com" && q4.id != firstId => + q4.id + } + dnsClient2.reply(Answer(secondId, im.Seq.empty)) senderProbe.expectMsg(Resolved("cats.com", im.Seq.empty)) } "move to next client if first times out" in new Setup { r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false)) // first will get ask timeout - dnsClient1.expectMsg(Question4(1, "cats.com")) - dnsClient2.expectMsg(Question4(2, "cats.com")) - dnsClient2.reply(Answer(2, im.Seq.empty)) + val firstId = dnsClient1.expectMsgPF() { + case q4: Question4 if q4.name == "cats.com" => + q4.id + } + val secondId = dnsClient2.expectMsgPF() { + case q4: Question4 if q4.name == "cats.com" && q4.id != firstId => + q4.id + } + dnsClient2.reply(Answer(secondId, im.Seq.empty)) + senderProbe.expectMsg(Resolved("cats.com", im.Seq.empty)) + } + + "handle duplicate Ids in dnsClient" in new Setup { + r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false)) + val firstId = dnsClient1.expectMsgPF() { + case q4: Question4 if q4.name == "cats.com" => + q4.id + } + dnsClient1.reply(DuplicateId(firstId)) + val secondId = dnsClient1.expectMsgPF() { + case q4: Question4 if q4.name == "cats.com" && q4.id != firstId => + q4.id + } + dnsClient1.reply(Answer(secondId, im.Seq.empty)) + dnsClient2.expectNoMessage() senderProbe.expectMsg(Resolved("cats.com", im.Seq.empty)) } "gets both A and AAAA records if requested" in new Setup { r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = true)) - dnsClient1.expectMsg(Question4(1, "cats.com")) + val firstId = dnsClient1.expectMsgPF() { + case q4: Question4 if q4.name == "cats.com" => + q4.id + } val ttl = Ttl.fromPositive(100.seconds) val ipv4Record = ARecord("cats.com", ttl, InetAddress.getByName("127.0.0.1")) - dnsClient1.reply(Answer(1, im.Seq(ipv4Record))) - dnsClient1.expectMsg(Question6(2, "cats.com")) + dnsClient1.reply(Answer(firstId, im.Seq(ipv4Record))) + val secondId = dnsClient1.expectMsgPF() { + case q6: Question6 if q6.name == "cats.com" && q6.id != firstId => + q6.id + } val ipv6Record = AAAARecord("cats.com", ttl, InetAddress.getByName("::1").asInstanceOf[Inet6Address]) - dnsClient1.reply(Answer(2, im.Seq(ipv6Record))) + dnsClient1.reply(Answer(secondId, im.Seq(ipv6Record))) senderProbe.expectMsg(Resolved("cats.com", im.Seq(ipv4Record, ipv6Record))) } @@ -102,9 +138,15 @@ class AsyncDnsResolverSpec extends PekkoSpec(""" "fails if all dns clients fail" in new Setup { r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false)) - dnsClient1.expectMsg(Question4(1, "cats.com")) + val firstId = dnsClient1.expectMsgPF() { + case q4: Question4 if q4.name == "cats.com" => + q4.id + } dnsClient1.reply(Failure(new RuntimeException("Fail"))) - dnsClient2.expectMsg(Question4(2, "cats.com")) + dnsClient2.expectMsgPF() { + case q4: Question4 if q4.name == "cats.com" && q4.id != firstId => + q4.id + } dnsClient2.reply(Failure(new RuntimeException("Yet another fail"))) senderProbe.expectMsgPF(remainingOrDefault) { case Failure(ResolveFailedException(_)) => @@ -113,8 +155,11 @@ class AsyncDnsResolverSpec extends PekkoSpec(""" "gets SRV records if requested" in new Setup { r ! Resolve("cats.com", Srv) - dnsClient1.expectMsg(SrvQuestion(1, "cats.com")) - dnsClient1.reply(Answer(1, im.Seq.empty)) + val firstId = dnsClient1.expectMsgPF() { + case srvQuestion: SrvQuestion if srvQuestion.name == "cats.com" => + srvQuestion.id + } + dnsClient1.reply(Answer(firstId, im.Seq.empty)) dnsClient2.expectNoMessage() senderProbe.expectMsg(Resolved("cats.com", im.Seq.empty)) } @@ -145,10 +190,13 @@ class AsyncDnsResolverSpec extends PekkoSpec(""" "return additional records for SRV requests" in new Setup { r ! Resolve("cats.com", Srv) - dnsClient1.expectMsg(SrvQuestion(1, "cats.com")) + val firstId = dnsClient1.expectMsgPF() { + case srvQuestion: SrvQuestion if srvQuestion.name == "cats.com" => + srvQuestion.id + } val srvRecs = im.Seq(SRVRecord("cats.com", Ttl.fromPositive(5000.seconds), 1, 1, 1, "a.cats.com")) val aRecs = im.Seq(ARecord("a.cats.com", Ttl.fromPositive(1.seconds), InetAddress.getByName("127.0.0.1"))) - dnsClient1.reply(Answer(1, srvRecs, aRecs)) + dnsClient1.reply(Answer(firstId, srvRecs, aRecs)) dnsClient2.expectNoMessage(50.millis) senderProbe.expectMsg(Resolved("cats.com", srvRecs, aRecs)) @@ -162,8 +210,11 @@ class AsyncDnsResolverSpec extends PekkoSpec(""" override val r = resolver(List(dnsClient1.ref), configWithSmallTtl) r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false)) - dnsClient1.expectMsg(Question4(1, "cats.com")) - dnsClient1.reply(Answer(1, im.Seq.empty)) + val firstId = dnsClient1.expectMsgPF() { + case q4: Question4 if q4.name == "cats.com" => + q4.id + } + dnsClient1.reply(Answer(firstId, im.Seq.empty)) senderProbe.expectMsg(Resolved("cats.com", im.Seq())) @@ -178,8 +229,11 @@ class AsyncDnsResolverSpec extends PekkoSpec(""" val ipv4Record = ARecord("cats.com", recordTtl, InetAddress.getByName("127.0.0.1")) r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false)) - dnsClient1.expectMsg(Question4(1, "cats.com")) - dnsClient1.reply(Answer(1, im.Seq(ipv4Record))) + val firstId = dnsClient1.expectMsgPF() { + case q4: Question4 if q4.name == "cats.com" => + q4.id + } + dnsClient1.reply(Answer(firstId, im.Seq(ipv4Record))) senderProbe.expectMsg(Resolved("cats.com", im.Seq(ipv4Record))) @@ -197,14 +251,20 @@ class AsyncDnsResolverSpec extends PekkoSpec(""" val ipv4Record = ARecord("cats.com", recordTtl, InetAddress.getByName("127.0.0.1")) r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false)) - dnsClient1.expectMsg(Question4(1, "cats.com")) - dnsClient1.reply(Answer(1, im.Seq(ipv4Record))) + val firstId = dnsClient1.expectMsgPF() { + case q4: Question4 if q4.name == "cats.com" => + q4.id + } + dnsClient1.reply(Answer(firstId, im.Seq(ipv4Record))) senderProbe.expectMsg(Resolved("cats.com", im.Seq(ipv4Record))) r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false)) - dnsClient1.expectMsg(Question4(2, "cats.com")) - dnsClient1.reply(Answer(2, im.Seq(ipv4Record))) + val secondId = dnsClient1.expectMsgPF() { + case q4: Question4 if q4.name == "cats.com" => + q4.id + } + dnsClient1.reply(Answer(secondId, im.Seq(ipv4Record))) senderProbe.expectMsg(Resolved("cats.com", im.Seq(ipv4Record))) } @@ -217,8 +277,11 @@ class AsyncDnsResolverSpec extends PekkoSpec(""" val ipv4Record = ARecord("cats.com", recordTtl, InetAddress.getByName("127.0.0.1")) r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false)) - dnsClient1.expectMsg(Question4(1, "cats.com")) - dnsClient1.reply(Answer(1, im.Seq(ipv4Record))) + val firstId = dnsClient1.expectMsgPF() { + case q4: Question4 if q4.name == "cats.com" => + q4.id + } + dnsClient1.reply(Answer(firstId, im.Seq(ipv4Record))) senderProbe.expectMsg(Resolved("cats.com", im.Seq(ipv4Record))) r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false)) @@ -228,8 +291,11 @@ class AsyncDnsResolverSpec extends PekkoSpec(""" Thread.sleep(200) r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false)) - dnsClient1.expectMsg(Question4(2, "cats.com")) - dnsClient1.reply(Answer(2, im.Seq(ipv4Record))) + val secondId = dnsClient1.expectMsgPF() { + case q4: Question4 if q4.name == "cats.com" => + q4.id + } + dnsClient1.reply(Answer(secondId, im.Seq(ipv4Record))) senderProbe.expectMsg(Resolved("cats.com", im.Seq(ipv4Record))) } @@ -240,6 +306,6 @@ class AsyncDnsResolverSpec extends PekkoSpec(""" system.actorOf(Props(new AsyncDnsResolver(settings, new SimpleDnsCache(), (_, _) => { clients - }))) + }, IdGenerator()))) } } diff --git a/actor-tests/src/test/scala/org/apache/pekko/io/dns/internal/DnsClientSpec.scala b/actor-tests/src/test/scala/org/apache/pekko/io/dns/internal/DnsClientSpec.scala index 3b89db65e0..d5636c48ca 100644 --- a/actor-tests/src/test/scala/org/apache/pekko/io/dns/internal/DnsClientSpec.scala +++ b/actor-tests/src/test/scala/org/apache/pekko/io/dns/internal/DnsClientSpec.scala @@ -13,24 +13,23 @@ package org.apache.pekko.io.dns.internal -import java.net.InetSocketAddress +import java.net.{ InetAddress, InetSocketAddress } import java.util.concurrent.atomic.AtomicBoolean - -import scala.collection.immutable.Seq - +import scala.collection.{ immutable => im } import org.apache.pekko +import org.apache.pekko.actor.Status.Failure import pekko.actor.Props import pekko.io.Udp -import pekko.io.dns.{ RecordClass, RecordType } -import pekko.io.dns.internal.DnsClient.{ Answer, Question4 } +import pekko.io.dns.{ ARecord, CachePolicy, RecordClass, RecordType } +import pekko.io.dns.internal.DnsClient.{ Answer, DropRequest, DuplicateId, Question4 } import pekko.testkit.{ ImplicitSender, PekkoSpec, TestProbe } class DnsClientSpec extends PekkoSpec with ImplicitSender { "The async DNS client" should { val exampleRequest = Question4(42, "pekko.io") - val exampleRequestMessage = - Message(42, MessageFlags(), questions = Seq(Question("pekko.io", RecordType.A, RecordClass.IN))) - val exampleResponseMessage = Message(42, MessageFlags(answer = true)) + val exampleQuestion = Question("pekko.io", RecordType.A, RecordClass.IN) + val exampleRequestMessage = Message(42, MessageFlags(), questions = im.Seq(exampleQuestion)) + val exampleResponseMessage = Message(42, MessageFlags(answer = true), questions = im.Seq(exampleQuestion)) val exampleResponse = Answer(42, Nil) val dnsServerAddress = InetSocketAddress.createUnresolved("foo", 53) @@ -83,5 +82,215 @@ class DnsClientSpec extends PekkoSpec with ImplicitSender { expectMsg(exampleResponse) } + + "Do not accept duplicate transaction ids" in { + val udpExtensionProbe = TestProbe() + val udpSocketProbe = TestProbe() + val tcpClientProbe = TestProbe() + val goodSenderProbe = TestProbe() + val badSenderProbe = TestProbe() + + val client = system.actorOf(Props(new DnsClient(dnsServerAddress) { + override val udp = udpExtensionProbe.ref + + override def createTcpClient() = tcpClientProbe.ref + })) + val badRequest = Question4(exampleRequest.id, "not." + exampleRequest.name) + + udpExtensionProbe.expectMsgType[Udp.Bind] + udpSocketProbe.send( + udpExtensionProbe.lastSender, + Udp.Bound(InetSocketAddress.createUnresolved("localhost", 41325))) + + goodSenderProbe.send(client, exampleRequest) + + val udpSend = udpSocketProbe.expectMsgType[Udp.Send] + udpSend.payload shouldBe exampleRequestMessage.write() + + badSenderProbe.send(client, badRequest) + badSenderProbe.expectMsg(DuplicateId(exampleRequest.id)) + + val answer = Answer(exampleRequest.id, im.Seq(), im.Seq()) + udpSocketProbe.reply(answer) + goodSenderProbe.expectMsg(answer) + + udpSocketProbe.expectNoMessage() + } + + "Verify original question when processing UDP replies (DNS poisoning)" in { + val udpExtensionProbe = TestProbe() + val udpSocketProbe = TestProbe() + val tcpClientProbe = TestProbe() + val goodSenderProbe = TestProbe() + + val socket = InetSocketAddress.createUnresolved("localhost", 41325) + val goodRecord = ARecord(exampleRequest.name, CachePolicy.Ttl.never, InetAddress.getLocalHost()) + + val client = system.actorOf(Props(new DnsClient(dnsServerAddress) { + override val udp = udpExtensionProbe.ref + + override def createTcpClient() = tcpClientProbe.ref + })) + + udpExtensionProbe.expectMsgType[Udp.Bind] + udpSocketProbe.send(udpExtensionProbe.lastSender, Udp.Bound(socket)) + + goodSenderProbe.send(client, exampleRequest) + + val udpSend = udpSocketProbe.expectMsgType[Udp.Send] + udpSend.payload shouldBe exampleRequestMessage.write() + + val flags = MessageFlags(true, authoritativeAnswer = true) + val badId = exampleRequestMessage.copy(id = 999, flags = flags, answerRecs = im.Seq(goodRecord)) + val badQuestion = exampleRequestMessage.copy( + flags = flags, + questions = im.Seq(exampleQuestion.copy(name = "not.com")), + answerRecs = im.Seq(goodRecord)) + val goodAnswer = exampleRequestMessage.copy(flags = flags, answerRecs = im.Seq(goodRecord)) + + udpSocketProbe.reply(Udp.Received(internal.ByteResponse(badId), socket)) + udpSocketProbe.reply(Udp.Received(internal.ByteResponse(badQuestion), socket)) + udpSocketProbe.reply(Udp.Received(internal.ByteResponse(goodAnswer), socket)) + + val answer = Answer(exampleRequest.id, im.Seq(goodRecord), im.Seq()) + goodSenderProbe.expectMsg(answer) + + udpSocketProbe.expectNoMessage() + } + + "Verify original question when processing DropRequest" in { + val udpExtensionProbe = TestProbe() + val udpSocketProbe = TestProbe() + val tcpClientProbe = TestProbe() + val goodSenderProbe = TestProbe() + + val socket = InetSocketAddress.createUnresolved("localhost", 41325) + val client = system.actorOf(Props(new DnsClient(dnsServerAddress) { + override val udp = udpExtensionProbe.ref + + override def createTcpClient() = tcpClientProbe.ref + })) + + udpExtensionProbe.expectMsgType[Udp.Bind] + udpSocketProbe.send(udpExtensionProbe.lastSender, Udp.Bound(socket)) + + goodSenderProbe.send(client, exampleRequest) + + val udpSend = udpSocketProbe.expectMsgType[Udp.Send] + udpSend.payload shouldBe exampleRequestMessage.write() + + goodSenderProbe.send(client, DropRequest(exampleRequest.copy(id = 999))) + + // duplicate shows inflight message not deleted + goodSenderProbe.send(client, exampleRequest) + goodSenderProbe.expectMsg(DuplicateId(exampleRequest.id)) + + goodSenderProbe.send(client, DropRequest(exampleRequest.copy(name = "not.com"))) + + // duplicate shows inflight message not deleted + goodSenderProbe.send(client, exampleRequest) + goodSenderProbe.expectMsg(DuplicateId(exampleRequest.id)) + + goodSenderProbe.send(client, DropRequest(exampleRequest)) + + // no duplicate shows inflight message was deleted + goodSenderProbe.send(client, exampleRequest) + goodSenderProbe.expectNoMessage() + } + + "Verify original question when processing UDP Failures" in { + val udpExtensionProbe = TestProbe() + val udpSocketProbe = TestProbe() + val tcpClientProbe = TestProbe() + val goodSenderProbe = TestProbe() + + val socket = InetSocketAddress.createUnresolved("localhost", 41325) + val client = system.actorOf(Props(new DnsClient(dnsServerAddress) { + override val udp = udpExtensionProbe.ref + + override def createTcpClient() = tcpClientProbe.ref + })) + + udpExtensionProbe.expectMsgType[Udp.Bind] + udpSocketProbe.send(udpExtensionProbe.lastSender, Udp.Bound(socket)) + + goodSenderProbe.send(client, exampleRequest) + + val udpSend = udpSocketProbe.expectMsgType[Udp.Send] + udpSend.payload shouldBe exampleRequestMessage.write() + + val badId = exampleRequestMessage.copy(id = 999) + val badQuestion = exampleRequestMessage.copy( + questions = im.Seq(exampleQuestion.copy(name = "not.com"))) + val goodQuestion = exampleRequestMessage + + udpSocketProbe.reply(Udp.CommandFailed(Udp.Send(internal.ByteResponse(badId), socket))) + + // duplicate shows inflight message not deleted + goodSenderProbe.send(client, exampleRequest) + goodSenderProbe.expectMsg(DuplicateId(exampleRequest.id)) + + udpSocketProbe.reply(Udp.CommandFailed(Udp.Send(internal.ByteResponse(badQuestion), socket))) + + // duplicate shows inflight message not deleted + goodSenderProbe.send(client, exampleRequest) + goodSenderProbe.expectMsg(DuplicateId(exampleRequest.id)) + + udpSocketProbe.reply(Udp.CommandFailed(Udp.Send(internal.ByteResponse(goodQuestion), socket))) + goodSenderProbe.expectMsgType[Failure] + + // no duplicate shows inflight message was deleted + goodSenderProbe.send(client, exampleRequest) + goodSenderProbe.expectNoMessage() + } + } +} + +/** + * The main code only knows how to write the questions not the responses, ByteResponse + * implements just enough of the writing logic that the main code can parse the answers + * messages used in tests. + * + * Message is only available to the internal package. + */ +package internal { + + import org.apache.pekko.io.dns.ResourceRecord + import org.apache.pekko.util.{ ByteString, ByteStringBuilder } + + object ByteResponse { + + def apply(msg: Message): ByteString = { + val ret = ByteString.newBuilder + write(msg, ret) + ret.result() + } + + def write(msg: Message, ret: ByteStringBuilder): Unit = { + ret + .putShort(msg.id) + .putShort(msg.flags.flags) + .putShort(msg.questions.size) + .putShort(msg.answerRecs.size) + .putShort(0) + .putShort(0) + + msg.questions.foreach(_.write(ret)) + msg.answerRecs.foreach(write(_, ret)) + } + + def write(msg: ResourceRecord, ret: ByteStringBuilder): Unit = { + msg match { + case ARecord(name, ttl, ip) => + DomainName.write(ret, name) + ret.putShort(RecordType.A.code) + ret.putShort(RecordClass.IN.code) + ret.putInt(ttl.value.toSeconds.toInt) + ret.putShort(4) + ret.putBytes(ip.getAddress, 0, 4) + case _ => + throw new IllegalStateException(s"Tests cannot write messages of type ${msg.getClass}") + } + } } } diff --git a/actor/src/main/resources/reference.conf b/actor/src/main/resources/reference.conf index 53212baf92..e266a7eee5 100644 --- a/actor/src/main/resources/reference.conf +++ b/actor/src/main/resources/reference.conf @@ -1144,6 +1144,12 @@ pekko { # Defaults to a system dependent lookup (on Unix like OSes, will attempt to parse /etc/resolv.conf, on # other platforms, will default to 1). ndots = default + + # The policy used to generate dns transaction ids. Options are thread-local-random or secure-random. + # Defaults to thread-local-random similar to Netty, secure-random produces FIPS compliant random numbers but + # could block looking for entropy (these are short integers so are easy to bruit-force, use thread-local-random + # unless you really require FIPS compliant random numbers). + id-generator-policy = thread-local-random } } } diff --git a/actor/src/main/scala/org/apache/pekko/io/dns/DnsSettings.scala b/actor/src/main/scala/org/apache/pekko/io/dns/DnsSettings.scala index 98bf8f870f..fc6f3fc56d 100644 --- a/actor/src/main/scala/org/apache/pekko/io/dns/DnsSettings.scala +++ b/actor/src/main/scala/org/apache/pekko/io/dns/DnsSettings.scala @@ -67,6 +67,11 @@ private[dns] final class DnsSettings(system: ExtendedActorSystem, c: Config) { val PositiveCachePolicy: CachePolicy = getTtl("positive-ttl") val NegativeCachePolicy: CachePolicy = getTtl("negative-ttl") + lazy val IdGeneratorPolicy: IdGenerator.Policy = IdGenerator + .Policy(c.getString("id-generator-policy")) + .getOrElse(throw new IllegalArgumentException("id-generator-policy must be 'thread-local-random' or " + + s"'secure-random' value was '${c.getString("id-generator-policy")}'")) + private def getTtl(path: String): CachePolicy = c.getString(path) match { case "forever" => Forever diff --git a/actor/src/main/scala/org/apache/pekko/io/dns/IdGenerator.scala b/actor/src/main/scala/org/apache/pekko/io/dns/IdGenerator.scala new file mode 100644 index 0000000000..5bb3191f14 --- /dev/null +++ b/actor/src/main/scala/org/apache/pekko/io/dns/IdGenerator.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * license agreements; and to You under the Apache License, version 2.0: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * This file is part of the Apache Pekko project, derived from Akka. + */ + +package org.apache.pekko.io.dns + +import org.apache.pekko.annotation.InternalApi + +import java.security.SecureRandom +import java.util.concurrent.ThreadLocalRandom +import java.util.concurrent.atomic.AtomicInteger +import scala.annotation.tailrec + +/** + * INTERNAL API + * + * These are called by an actor, however they are called inside composed futures so need to be + * nextId needs to be thread safe. + */ +@InternalApi +private[pekko] trait IdGenerator { + def nextId(): Short +} + +/** + * INTERNAL API + */ +@InternalApi +private[pekko] object IdGenerator { + sealed trait Policy + + object Policy { + case object ThreadLocalRandom extends Policy + case object SecureRandom extends Policy + val Default: Policy = ThreadLocalRandom + + def apply(name: String): Option[Policy] = name.toLowerCase match { + case "thread-local-random" => Some(ThreadLocalRandom) + case "secure-random" => Some(SecureRandom) + case _ => Some(ThreadLocalRandom) + } + } + + def apply(policy: Policy): IdGenerator = policy match { + case Policy.ThreadLocalRandom => random(ThreadLocalRandom.current()) + case Policy.SecureRandom => random(new SecureRandom()) + } + + def apply(): IdGenerator = random(ThreadLocalRandom.current()) + + /** + * @return a random sequence of ids for production + */ + def random(rand: java.util.Random): IdGenerator = new IdGenerator { + override def nextId(): Short = rand.nextInt(Short.MaxValue).toShort + } + + /** + * @return a predictable sequence of ids for tests + */ + def sequence(): IdGenerator = new IdGenerator { + val requestId: AtomicInteger = new AtomicInteger(0) + + @tailrec + override final def nextId(): Short = { + val oldId = requestId.get() + val newId = (oldId + 1) % Short.MaxValue + + if (requestId.compareAndSet(oldId, newId.intValue())) { + newId.toShort + } else { + nextId() + } + } + } +} diff --git a/actor/src/main/scala/org/apache/pekko/io/dns/internal/AsyncDnsResolver.scala b/actor/src/main/scala/org/apache/pekko/io/dns/internal/AsyncDnsResolver.scala index 6172de7b6e..7ac4179634 100644 --- a/actor/src/main/scala/org/apache/pekko/io/dns/internal/AsyncDnsResolver.scala +++ b/actor/src/main/scala/org/apache/pekko/io/dns/internal/AsyncDnsResolver.scala @@ -19,6 +19,8 @@ import scala.collection.immutable import scala.concurrent.ExecutionContextExecutor import scala.concurrent.Future import scala.util.Try +import scala.util.Success +import scala.util.Failure import scala.util.control.NonFatal import org.apache.pekko @@ -41,10 +43,17 @@ import pekko.util.PrettyDuration._ private[io] final class AsyncDnsResolver( settings: DnsSettings, cache: SimpleDnsCache, - clientFactory: (ActorRefFactory, List[InetSocketAddress]) => List[ActorRef]) + clientFactory: (ActorRefFactory, List[InetSocketAddress]) => List[ActorRef], + idGenerator: IdGenerator) extends Actor with ActorLogging { + def this( + settings: DnsSettings, + cache: SimpleDnsCache, + clientFactory: (ActorRefFactory, List[InetSocketAddress]) => List[ActorRef]) = + this(settings, cache, clientFactory, IdGenerator(settings.IdGeneratorPolicy)) + import AsyncDnsResolver._ implicit val ec: ExecutionContextExecutor = context.dispatcher @@ -85,13 +94,6 @@ private[io] final class AsyncDnsResolver( settings.SearchDomains, settings.NDots) - private var requestId: Short = 0 - - private def nextId(): Short = { - requestId = (requestId + 1).toShort - requestId - } - private val resolvers: List[ActorRef] = clientFactory(context, nameServers) // only supports DnsProtocol, not the deprecated Dns protocol @@ -151,11 +153,19 @@ private[io] final class AsyncDnsResolver( } private def sendQuestion(resolver: ActorRef, message: DnsQuestion): Future[Answer] = { - val result = (resolver ? message).mapTo[Answer] - result.failed.foreach { _ => - resolver ! DropRequest(message.id) + (resolver ? message).transformWith { + case Success(result: Answer) => + Future.successful(result) + case Success(DuplicateId(_)) => + sendQuestion(resolver, message.withId(idGenerator.nextId())) + case Failure(t) => + resolver ! DropRequest(message) + Future.failed(t) + case Success(a) => + resolver ! DropRequest(message) + Future.failed( + new IllegalArgumentException("Unexpected response " + a.toString + " of type " + a.getClass.toString)) } - result } private def resolveWithSearch( @@ -208,13 +218,13 @@ private[io] final class AsyncDnsResolver( case Ip(ipv4, ipv6) => val ipv4Recs: Future[Answer] = if (ipv4) - sendQuestion(resolver, Question4(nextId(), caseFoldedName)) + sendQuestion(resolver, Question4(idGenerator.nextId(), caseFoldedName)) else Empty val ipv6Recs = if (ipv6) - sendQuestion(resolver, Question6(nextId(), caseFoldedName)) + sendQuestion(resolver, Question6(idGenerator.nextId(), caseFoldedName)) else Empty @@ -224,7 +234,7 @@ private[io] final class AsyncDnsResolver( } yield DnsProtocol.Resolved(name, ipv4.rrs ++ ipv6.rrs, ipv4.additionalRecs ++ ipv6.additionalRecs) case Srv => - sendQuestion(resolver, SrvQuestion(nextId(), caseFoldedName)).map(answer => { + sendQuestion(resolver, SrvQuestion(idGenerator.nextId(), caseFoldedName)).map(answer => { DnsProtocol.Resolved(name, answer.rrs, answer.additionalRecs) }) } diff --git a/actor/src/main/scala/org/apache/pekko/io/dns/internal/DnsClient.scala b/actor/src/main/scala/org/apache/pekko/io/dns/internal/DnsClient.scala index 0b397020ee..47743d27e9 100644 --- a/actor/src/main/scala/org/apache/pekko/io/dns/internal/DnsClient.scala +++ b/actor/src/main/scala/org/apache/pekko/io/dns/internal/DnsClient.scala @@ -14,13 +14,10 @@ package org.apache.pekko.io.dns.internal import java.net.{ InetAddress, InetSocketAddress } - import scala.collection.{ immutable => im } import scala.concurrent.duration._ import scala.util.Try - -import scala.annotation.nowarn - +import scala.annotation.{ nowarn, tailrec } import org.apache.pekko import pekko.actor.{ Actor, ActorLogging, ActorRef, NoSerializationVerificationNeeded, Props, Stash } import pekko.actor.Status.Failure @@ -35,13 +32,23 @@ import pekko.pattern.{ BackoffOpts, BackoffSupervisor } @InternalApi private[pekko] object DnsClient { sealed trait DnsQuestion { def id: Short + def name: String + def withId(newId: Short): DnsQuestion = { + this match { + case SrvQuestion(_, name) => SrvQuestion(newId, name) + case Question4(_, name) => Question4(newId, name) + case Question6(_, name) => Question6(newId, name) + } + } } final case class SrvQuestion(id: Short, name: String) extends DnsQuestion final case class Question4(id: Short, name: String) extends DnsQuestion final case class Question6(id: Short, name: String) extends DnsQuestion final case class Answer(id: Short, rrs: im.Seq[ResourceRecord], additionalRecs: im.Seq[ResourceRecord] = Nil) extends NoSerializationVerificationNeeded - final case class DropRequest(id: Short) + + final case class DuplicateId(id: Short) extends NoSerializationVerificationNeeded + final case class DropRequest(question: DnsQuestion) } /** @@ -85,30 +92,40 @@ import pekko.pattern.{ BackoffOpts, BackoffSupervisor } */ @nowarn() def ready(socket: ActorRef): Receive = { - case DropRequest(id) => - log.debug("Dropping request [{}]", id) - inflightRequests -= id + case DropRequest(msg) => + inflightRequests.get(msg.id).foreach { + case (_, orig) if Seq(msg.name) == orig.questions.map(_.name) => + log.debug("Dropping request [{}]", msg.id) + inflightRequests -= msg.id + case (_, orig) => + log.warning("Cannot drop inflight DNS request the question [{}] does not match [{}]", + msg.name, + orig.questions.map(_.name).mkString(",")) + } case Question4(id, name) => log.debug("Resolving [{}] (A)", name) val msg = message(name, id, RecordType.A) - inflightRequests += (id -> (sender() -> msg)) - log.debug("Message [{}] to [{}]: [{}]", id, ns, msg) - socket ! Udp.Send(msg.write(), ns) + newInflightRequests(msg, sender()) { + log.debug("Message [{}] to [{}]: [{}]", id, ns, msg) + socket ! Udp.Send(msg.write(), ns) + } case Question6(id, name) => log.debug("Resolving [{}] (AAAA)", name) val msg = message(name, id, RecordType.AAAA) - inflightRequests += (id -> (sender() -> msg)) - log.debug("Message to [{}]: [{}]", ns, msg) - socket ! Udp.Send(msg.write(), ns) + newInflightRequests(msg, sender()) { + log.debug("Message [{}] to [{}]: [{}]", id, ns, msg) + socket ! Udp.Send(msg.write(), ns) + } case SrvQuestion(id, name) => log.debug("Resolving [{}] (SRV)", name) val msg = message(name, id, RecordType.SRV) - inflightRequests += (id -> (sender() -> msg)) - log.debug("Message to [{}]: [{}]", ns, msg) - socket ! Udp.Send(msg.write(), ns) + newInflightRequests(msg, sender()) { + log.debug("Message [{}] to [{}]: [{}]", id, ns, msg) + socket ! Udp.Send(msg.write(), ns) + } case Udp.CommandFailed(cmd) => log.debug("Command failed [{}]", cmd) @@ -118,9 +135,13 @@ import pekko.pattern.{ BackoffOpts, BackoffSupervisor } Try { val msg = Message.parse(send.payload) inflightRequests.get(msg.id).foreach { - case (s, _) => + case (s, orig) if isSameQuestion(msg.questions, orig.questions) => s ! Failure(new RuntimeException("Send failed to nameserver")) inflightRequests -= msg.id + case (_, orig) => + log.warning("Cannot command failed question [{}] does not match [{}]", + msg.questions.mkString(","), + orig.questions.mkString(",")) } } case _ => @@ -140,9 +161,24 @@ import pekko.pattern.{ BackoffOpts, BackoffSupervisor } log.debug("Client for id {} not found. Discarding unsuccessful response.", msg.id) } } else { - val (recs, additionalRecs) = - if (msg.flags.responseCode == ResponseCode.SUCCESS) (msg.answerRecs, msg.additionalRecs) else (Nil, Nil) - self ! Answer(msg.id, recs, additionalRecs) + inflightRequests.get(msg.id) match { + case Some((_, orig)) if !isSameQuestion(msg.questions, orig.questions) => + log.warning( + "Unexpected DNS response id {} question [{}] does not match question asked [{}]", + msg.id, + msg.questions.mkString(","), + orig.questions.mkString(",")) + case Some((_, orig)) => + log.warning("DNS response id {} question [{}] question asked [{}]", + msg.id, + msg.questions.mkString(","), + orig.questions.mkString(",")) + val (recs, additionalRecs) = + if (msg.flags.responseCode == ResponseCode.SUCCESS) (msg.answerRecs, msg.additionalRecs) else (Nil, Nil) + self ! Answer(msg.id, recs, additionalRecs) + case None => + log.warning("Unexpected DNS response invalid id {}", msg.id) + } } case response: Answer => inflightRequests.get(response.id) match { @@ -156,6 +192,29 @@ import pekko.pattern.{ BackoffOpts, BackoffSupervisor } case Udp.Unbound => context.stop(self) } + private def newInflightRequests(msg: Message, theSender: ActorRef)(func: => Unit): Unit = { + if (!inflightRequests.contains(msg.id)) { + inflightRequests += (msg.id -> (theSender -> msg)) + func + } else { + log.warning("Received duplicate message [{}] with id [{}]", msg, msg.id) + theSender ! DuplicateId(msg.id) + } + } + + private def isSameQuestion(q1s: Seq[Question], q2s: Seq[Question]): Boolean = { + @tailrec + def impl(q1s: List[Question], q2s: List[Question]): Boolean = { + (q1s, q2s) match { + case (Nil, Nil) => true + case (h1 :: t1, h2 :: t2) => h1.isSame(h2) && impl(t1, t2) + case _ => false + } + } + + impl(q1s.sortBy(_.name).toList, q2s.sortBy(_.name).toList) + } + def createTcpClient() = { context.actorOf( BackoffSupervisor.props( diff --git a/actor/src/main/scala/org/apache/pekko/io/dns/internal/Question.scala b/actor/src/main/scala/org/apache/pekko/io/dns/internal/Question.scala index 534e4b5c50..e8079b50e1 100644 --- a/actor/src/main/scala/org/apache/pekko/io/dns/internal/Question.scala +++ b/actor/src/main/scala/org/apache/pekko/io/dns/internal/Question.scala @@ -28,6 +28,16 @@ private[pekko] final case class Question(name: String, qType: RecordType, qClass RecordTypeSerializer.write(out, qType) RecordClassSerializer.write(out, qClass) } + + /** + * DomainName.parse adds a '.' to the end of domain names we have to allow for this checking the domain names + * @return true if this the questions are the same (allowing for an trailing dot). + */ + def isSame(that: Question): Boolean = { + def addDot(name: String): String = if (name.nonEmpty && !name.endsWith(".")) name.concat(".") else name + + addDot(this.name) == addDot(that.name) && this.qType == that.qType && this.qClass == that.qClass + } } /** diff --git a/bench-jmh/src/main/scala/org/apache/pekko/io/dns/IdGeneratorBanchmark.scala b/bench-jmh/src/main/scala/org/apache/pekko/io/dns/IdGeneratorBanchmark.scala new file mode 100644 index 0000000000..40233bff31 --- /dev/null +++ b/bench-jmh/src/main/scala/org/apache/pekko/io/dns/IdGeneratorBanchmark.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * license agreements; and to You under the Apache License, version 2.0: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * This file is part of the Apache Pekko project, derived from Akka. + */ + +package org.apache.pekko.io.dns + +import org.openjdk.jmh.annotations.{ + Benchmark, + BenchmarkMode, + Fork, + Measurement, + Mode, + OutputTimeUnit, + Scope, + State, + Threads, + Warmup +} + +import java.util.concurrent.{ ThreadLocalRandom, TimeUnit } +import java.security.SecureRandom + +@BenchmarkMode(Array(Mode.Throughput)) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +@Warmup(iterations = 3, time = 5, timeUnit = TimeUnit.SECONDS) +@Measurement(iterations = 3, time = 5) +@Threads(8) +@Fork(1) +@State(Scope.Benchmark) +class IdGeneratorBanchmark { + val threadLocalRandom = IdGenerator.random(ThreadLocalRandom.current()) + val secureRandom = IdGenerator.random(new SecureRandom()) + + @Benchmark + def measureThreadLocalRandom(): Short = threadLocalRandom.nextId() + + @Benchmark + def measureSecureRandom(): Short = secureRandom.nextId() +}