CVE-2023-31442 Address DNS poisoning vulnerability (#385)
* CVE-2023-31442 Address DNS poisoning vulnerability (and DNS concurrency bug) * Remove sequential dns id generator * fix scalafmt * Fix bug in isSameQuestion And ensure that DnsClient only removes inflight messages when the questions match * fix up exception message to remove reference to 'sequence' generator * Add tests to failed commands and drop requests --------- Co-authored-by: PJ Fanning <pjfanning@users.noreply.github.com>
This commit is contained in:
parent
fbf923fc68
commit
c56edca78f
11 changed files with 602 additions and 80 deletions
|
|
@ -34,6 +34,7 @@ class DnsSettingsSpec extends PekkoSpec {
|
|||
ndots = 1
|
||||
positive-ttl = forever
|
||||
negative-ttl = never
|
||||
id-generator-policy = thread-local-random
|
||||
""")
|
||||
|
||||
"DNS settings" must {
|
||||
|
|
@ -137,6 +138,10 @@ class DnsSettingsSpec extends PekkoSpec {
|
|||
dnsSettingsDuration.PositiveCachePolicy shouldEqual CachePolicy.Ttl.fromPositive(10.seconds)
|
||||
dnsSettingsDuration.NegativeCachePolicy shouldEqual CachePolicy.Ttl.fromPositive(10.days)
|
||||
}
|
||||
}
|
||||
|
||||
"parse id-generator-policy" in {
|
||||
val dnsSettings = new DnsSettings(eas, defaultConfig)
|
||||
dnsSettings.IdGeneratorPolicy shouldEqual (IdGenerator.Policy.ThreadLocalRandom)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,27 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* license agreements; and to You under the Apache License, version 2.0:
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* This file is part of the Apache Pekko project, derived from Akka.
|
||||
*/
|
||||
|
||||
package org.apache.pekko.io.dns
|
||||
|
||||
import org.apache.pekko.testkit.PekkoSpec
|
||||
|
||||
class IdGeneratorSpec extends PekkoSpec {
|
||||
|
||||
"IdGenerator" must {
|
||||
"provide a thread-local-random" in {
|
||||
val gen = IdGenerator(IdGenerator.Policy.ThreadLocalRandom)
|
||||
gen.nextId() should be < Short.MaxValue
|
||||
}
|
||||
|
||||
"provide a secure-random" in {
|
||||
val gen = IdGenerator(IdGenerator.Policy.SecureRandom)
|
||||
gen.nextId() should be < Short.MaxValue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -19,16 +19,15 @@ import scala.collection.{ immutable => im }
|
|||
import scala.concurrent.duration._
|
||||
|
||||
import com.typesafe.config.{ Config, ConfigFactory, ConfigValueFactory }
|
||||
|
||||
import org.apache.pekko
|
||||
import pekko.actor.{ ActorRef, ExtendedActorSystem, Props }
|
||||
import pekko.actor.Status.Failure
|
||||
import pekko.io.SimpleDnsCache
|
||||
import pekko.io.dns.{ AAAARecord, ARecord, DnsSettings, SRVRecord }
|
||||
import pekko.io.dns.{ AAAARecord, ARecord, DnsSettings, IdGenerator, SRVRecord }
|
||||
import pekko.io.dns.CachePolicy.Ttl
|
||||
import pekko.io.dns.DnsProtocol._
|
||||
import pekko.io.dns.internal.AsyncDnsResolver.ResolveFailedException
|
||||
import pekko.io.dns.internal.DnsClient.{ Answer, Question4, Question6, SrvQuestion }
|
||||
import pekko.io.dns.internal.DnsClient.{ Answer, DuplicateId, Question4, Question6, SrvQuestion }
|
||||
import pekko.testkit.{ PekkoSpec, TestProbe, WithLogCapturing }
|
||||
|
||||
class AsyncDnsResolverSpec extends PekkoSpec("""
|
||||
|
|
@ -56,8 +55,11 @@ class AsyncDnsResolverSpec extends PekkoSpec("""
|
|||
"Async DNS Resolver" must {
|
||||
"use dns clients in order" in new Setup {
|
||||
r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false))
|
||||
dnsClient1.expectMsg(Question4(1, "cats.com"))
|
||||
dnsClient1.reply(Answer(1, im.Seq.empty))
|
||||
val id = dnsClient1.expectMsgPF() {
|
||||
case q4: Question4 if q4.name == "cats.com" =>
|
||||
q4.id
|
||||
}
|
||||
dnsClient1.reply(Answer(id, im.Seq.empty))
|
||||
dnsClient2.expectNoMessage()
|
||||
senderProbe.expectMsg(Resolved("cats.com", im.Seq.empty))
|
||||
}
|
||||
|
|
@ -65,31 +67,65 @@ class AsyncDnsResolverSpec extends PekkoSpec("""
|
|||
"move to next client if first fails" in new Setup {
|
||||
r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false))
|
||||
// first will get ask timeout
|
||||
dnsClient1.expectMsg(Question4(1, "cats.com"))
|
||||
val firstId = dnsClient1.expectMsgPF() {
|
||||
case q4: Question4 if q4.name == "cats.com" =>
|
||||
q4.id
|
||||
}
|
||||
dnsClient1.reply(Failure(new RuntimeException("Nope")))
|
||||
dnsClient2.expectMsg(Question4(2, "cats.com"))
|
||||
dnsClient2.reply(Answer(2, im.Seq.empty))
|
||||
val secondId = dnsClient2.expectMsgPF() {
|
||||
case q4: Question4 if q4.name == "cats.com" && q4.id != firstId =>
|
||||
q4.id
|
||||
}
|
||||
dnsClient2.reply(Answer(secondId, im.Seq.empty))
|
||||
senderProbe.expectMsg(Resolved("cats.com", im.Seq.empty))
|
||||
}
|
||||
|
||||
"move to next client if first times out" in new Setup {
|
||||
r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false))
|
||||
// first will get ask timeout
|
||||
dnsClient1.expectMsg(Question4(1, "cats.com"))
|
||||
dnsClient2.expectMsg(Question4(2, "cats.com"))
|
||||
dnsClient2.reply(Answer(2, im.Seq.empty))
|
||||
val firstId = dnsClient1.expectMsgPF() {
|
||||
case q4: Question4 if q4.name == "cats.com" =>
|
||||
q4.id
|
||||
}
|
||||
val secondId = dnsClient2.expectMsgPF() {
|
||||
case q4: Question4 if q4.name == "cats.com" && q4.id != firstId =>
|
||||
q4.id
|
||||
}
|
||||
dnsClient2.reply(Answer(secondId, im.Seq.empty))
|
||||
senderProbe.expectMsg(Resolved("cats.com", im.Seq.empty))
|
||||
}
|
||||
|
||||
"handle duplicate Ids in dnsClient" in new Setup {
|
||||
r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false))
|
||||
val firstId = dnsClient1.expectMsgPF() {
|
||||
case q4: Question4 if q4.name == "cats.com" =>
|
||||
q4.id
|
||||
}
|
||||
dnsClient1.reply(DuplicateId(firstId))
|
||||
val secondId = dnsClient1.expectMsgPF() {
|
||||
case q4: Question4 if q4.name == "cats.com" && q4.id != firstId =>
|
||||
q4.id
|
||||
}
|
||||
dnsClient1.reply(Answer(secondId, im.Seq.empty))
|
||||
dnsClient2.expectNoMessage()
|
||||
senderProbe.expectMsg(Resolved("cats.com", im.Seq.empty))
|
||||
}
|
||||
|
||||
"gets both A and AAAA records if requested" in new Setup {
|
||||
r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = true))
|
||||
dnsClient1.expectMsg(Question4(1, "cats.com"))
|
||||
val firstId = dnsClient1.expectMsgPF() {
|
||||
case q4: Question4 if q4.name == "cats.com" =>
|
||||
q4.id
|
||||
}
|
||||
val ttl = Ttl.fromPositive(100.seconds)
|
||||
val ipv4Record = ARecord("cats.com", ttl, InetAddress.getByName("127.0.0.1"))
|
||||
dnsClient1.reply(Answer(1, im.Seq(ipv4Record)))
|
||||
dnsClient1.expectMsg(Question6(2, "cats.com"))
|
||||
dnsClient1.reply(Answer(firstId, im.Seq(ipv4Record)))
|
||||
val secondId = dnsClient1.expectMsgPF() {
|
||||
case q6: Question6 if q6.name == "cats.com" && q6.id != firstId =>
|
||||
q6.id
|
||||
}
|
||||
val ipv6Record = AAAARecord("cats.com", ttl, InetAddress.getByName("::1").asInstanceOf[Inet6Address])
|
||||
dnsClient1.reply(Answer(2, im.Seq(ipv6Record)))
|
||||
dnsClient1.reply(Answer(secondId, im.Seq(ipv6Record)))
|
||||
senderProbe.expectMsg(Resolved("cats.com", im.Seq(ipv4Record, ipv6Record)))
|
||||
}
|
||||
|
||||
|
|
@ -102,9 +138,15 @@ class AsyncDnsResolverSpec extends PekkoSpec("""
|
|||
|
||||
"fails if all dns clients fail" in new Setup {
|
||||
r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false))
|
||||
dnsClient1.expectMsg(Question4(1, "cats.com"))
|
||||
val firstId = dnsClient1.expectMsgPF() {
|
||||
case q4: Question4 if q4.name == "cats.com" =>
|
||||
q4.id
|
||||
}
|
||||
dnsClient1.reply(Failure(new RuntimeException("Fail")))
|
||||
dnsClient2.expectMsg(Question4(2, "cats.com"))
|
||||
dnsClient2.expectMsgPF() {
|
||||
case q4: Question4 if q4.name == "cats.com" && q4.id != firstId =>
|
||||
q4.id
|
||||
}
|
||||
dnsClient2.reply(Failure(new RuntimeException("Yet another fail")))
|
||||
senderProbe.expectMsgPF(remainingOrDefault) {
|
||||
case Failure(ResolveFailedException(_)) =>
|
||||
|
|
@ -113,8 +155,11 @@ class AsyncDnsResolverSpec extends PekkoSpec("""
|
|||
|
||||
"gets SRV records if requested" in new Setup {
|
||||
r ! Resolve("cats.com", Srv)
|
||||
dnsClient1.expectMsg(SrvQuestion(1, "cats.com"))
|
||||
dnsClient1.reply(Answer(1, im.Seq.empty))
|
||||
val firstId = dnsClient1.expectMsgPF() {
|
||||
case srvQuestion: SrvQuestion if srvQuestion.name == "cats.com" =>
|
||||
srvQuestion.id
|
||||
}
|
||||
dnsClient1.reply(Answer(firstId, im.Seq.empty))
|
||||
dnsClient2.expectNoMessage()
|
||||
senderProbe.expectMsg(Resolved("cats.com", im.Seq.empty))
|
||||
}
|
||||
|
|
@ -145,10 +190,13 @@ class AsyncDnsResolverSpec extends PekkoSpec("""
|
|||
|
||||
"return additional records for SRV requests" in new Setup {
|
||||
r ! Resolve("cats.com", Srv)
|
||||
dnsClient1.expectMsg(SrvQuestion(1, "cats.com"))
|
||||
val firstId = dnsClient1.expectMsgPF() {
|
||||
case srvQuestion: SrvQuestion if srvQuestion.name == "cats.com" =>
|
||||
srvQuestion.id
|
||||
}
|
||||
val srvRecs = im.Seq(SRVRecord("cats.com", Ttl.fromPositive(5000.seconds), 1, 1, 1, "a.cats.com"))
|
||||
val aRecs = im.Seq(ARecord("a.cats.com", Ttl.fromPositive(1.seconds), InetAddress.getByName("127.0.0.1")))
|
||||
dnsClient1.reply(Answer(1, srvRecs, aRecs))
|
||||
dnsClient1.reply(Answer(firstId, srvRecs, aRecs))
|
||||
dnsClient2.expectNoMessage(50.millis)
|
||||
senderProbe.expectMsg(Resolved("cats.com", srvRecs, aRecs))
|
||||
|
||||
|
|
@ -162,8 +210,11 @@ class AsyncDnsResolverSpec extends PekkoSpec("""
|
|||
override val r = resolver(List(dnsClient1.ref), configWithSmallTtl)
|
||||
|
||||
r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false))
|
||||
dnsClient1.expectMsg(Question4(1, "cats.com"))
|
||||
dnsClient1.reply(Answer(1, im.Seq.empty))
|
||||
val firstId = dnsClient1.expectMsgPF() {
|
||||
case q4: Question4 if q4.name == "cats.com" =>
|
||||
q4.id
|
||||
}
|
||||
dnsClient1.reply(Answer(firstId, im.Seq.empty))
|
||||
|
||||
senderProbe.expectMsg(Resolved("cats.com", im.Seq()))
|
||||
|
||||
|
|
@ -178,8 +229,11 @@ class AsyncDnsResolverSpec extends PekkoSpec("""
|
|||
val ipv4Record = ARecord("cats.com", recordTtl, InetAddress.getByName("127.0.0.1"))
|
||||
|
||||
r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false))
|
||||
dnsClient1.expectMsg(Question4(1, "cats.com"))
|
||||
dnsClient1.reply(Answer(1, im.Seq(ipv4Record)))
|
||||
val firstId = dnsClient1.expectMsgPF() {
|
||||
case q4: Question4 if q4.name == "cats.com" =>
|
||||
q4.id
|
||||
}
|
||||
dnsClient1.reply(Answer(firstId, im.Seq(ipv4Record)))
|
||||
|
||||
senderProbe.expectMsg(Resolved("cats.com", im.Seq(ipv4Record)))
|
||||
|
||||
|
|
@ -197,14 +251,20 @@ class AsyncDnsResolverSpec extends PekkoSpec("""
|
|||
val ipv4Record = ARecord("cats.com", recordTtl, InetAddress.getByName("127.0.0.1"))
|
||||
|
||||
r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false))
|
||||
dnsClient1.expectMsg(Question4(1, "cats.com"))
|
||||
dnsClient1.reply(Answer(1, im.Seq(ipv4Record)))
|
||||
val firstId = dnsClient1.expectMsgPF() {
|
||||
case q4: Question4 if q4.name == "cats.com" =>
|
||||
q4.id
|
||||
}
|
||||
dnsClient1.reply(Answer(firstId, im.Seq(ipv4Record)))
|
||||
|
||||
senderProbe.expectMsg(Resolved("cats.com", im.Seq(ipv4Record)))
|
||||
|
||||
r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false))
|
||||
dnsClient1.expectMsg(Question4(2, "cats.com"))
|
||||
dnsClient1.reply(Answer(2, im.Seq(ipv4Record)))
|
||||
val secondId = dnsClient1.expectMsgPF() {
|
||||
case q4: Question4 if q4.name == "cats.com" =>
|
||||
q4.id
|
||||
}
|
||||
dnsClient1.reply(Answer(secondId, im.Seq(ipv4Record)))
|
||||
|
||||
senderProbe.expectMsg(Resolved("cats.com", im.Seq(ipv4Record)))
|
||||
}
|
||||
|
|
@ -217,8 +277,11 @@ class AsyncDnsResolverSpec extends PekkoSpec("""
|
|||
val ipv4Record = ARecord("cats.com", recordTtl, InetAddress.getByName("127.0.0.1"))
|
||||
|
||||
r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false))
|
||||
dnsClient1.expectMsg(Question4(1, "cats.com"))
|
||||
dnsClient1.reply(Answer(1, im.Seq(ipv4Record)))
|
||||
val firstId = dnsClient1.expectMsgPF() {
|
||||
case q4: Question4 if q4.name == "cats.com" =>
|
||||
q4.id
|
||||
}
|
||||
dnsClient1.reply(Answer(firstId, im.Seq(ipv4Record)))
|
||||
senderProbe.expectMsg(Resolved("cats.com", im.Seq(ipv4Record)))
|
||||
|
||||
r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false))
|
||||
|
|
@ -228,8 +291,11 @@ class AsyncDnsResolverSpec extends PekkoSpec("""
|
|||
|
||||
Thread.sleep(200)
|
||||
r ! Resolve("cats.com", Ip(ipv4 = true, ipv6 = false))
|
||||
dnsClient1.expectMsg(Question4(2, "cats.com"))
|
||||
dnsClient1.reply(Answer(2, im.Seq(ipv4Record)))
|
||||
val secondId = dnsClient1.expectMsgPF() {
|
||||
case q4: Question4 if q4.name == "cats.com" =>
|
||||
q4.id
|
||||
}
|
||||
dnsClient1.reply(Answer(secondId, im.Seq(ipv4Record)))
|
||||
|
||||
senderProbe.expectMsg(Resolved("cats.com", im.Seq(ipv4Record)))
|
||||
}
|
||||
|
|
@ -240,6 +306,6 @@ class AsyncDnsResolverSpec extends PekkoSpec("""
|
|||
system.actorOf(Props(new AsyncDnsResolver(settings, new SimpleDnsCache(),
|
||||
(_, _) => {
|
||||
clients
|
||||
})))
|
||||
}, IdGenerator())))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,24 +13,23 @@
|
|||
|
||||
package org.apache.pekko.io.dns.internal
|
||||
|
||||
import java.net.InetSocketAddress
|
||||
import java.net.{ InetAddress, InetSocketAddress }
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
|
||||
import scala.collection.immutable.Seq
|
||||
|
||||
import scala.collection.{ immutable => im }
|
||||
import org.apache.pekko
|
||||
import org.apache.pekko.actor.Status.Failure
|
||||
import pekko.actor.Props
|
||||
import pekko.io.Udp
|
||||
import pekko.io.dns.{ RecordClass, RecordType }
|
||||
import pekko.io.dns.internal.DnsClient.{ Answer, Question4 }
|
||||
import pekko.io.dns.{ ARecord, CachePolicy, RecordClass, RecordType }
|
||||
import pekko.io.dns.internal.DnsClient.{ Answer, DropRequest, DuplicateId, Question4 }
|
||||
import pekko.testkit.{ ImplicitSender, PekkoSpec, TestProbe }
|
||||
|
||||
class DnsClientSpec extends PekkoSpec with ImplicitSender {
|
||||
"The async DNS client" should {
|
||||
val exampleRequest = Question4(42, "pekko.io")
|
||||
val exampleRequestMessage =
|
||||
Message(42, MessageFlags(), questions = Seq(Question("pekko.io", RecordType.A, RecordClass.IN)))
|
||||
val exampleResponseMessage = Message(42, MessageFlags(answer = true))
|
||||
val exampleQuestion = Question("pekko.io", RecordType.A, RecordClass.IN)
|
||||
val exampleRequestMessage = Message(42, MessageFlags(), questions = im.Seq(exampleQuestion))
|
||||
val exampleResponseMessage = Message(42, MessageFlags(answer = true), questions = im.Seq(exampleQuestion))
|
||||
val exampleResponse = Answer(42, Nil)
|
||||
val dnsServerAddress = InetSocketAddress.createUnresolved("foo", 53)
|
||||
|
||||
|
|
@ -83,5 +82,215 @@ class DnsClientSpec extends PekkoSpec with ImplicitSender {
|
|||
|
||||
expectMsg(exampleResponse)
|
||||
}
|
||||
|
||||
"Do not accept duplicate transaction ids" in {
|
||||
val udpExtensionProbe = TestProbe()
|
||||
val udpSocketProbe = TestProbe()
|
||||
val tcpClientProbe = TestProbe()
|
||||
val goodSenderProbe = TestProbe()
|
||||
val badSenderProbe = TestProbe()
|
||||
|
||||
val client = system.actorOf(Props(new DnsClient(dnsServerAddress) {
|
||||
override val udp = udpExtensionProbe.ref
|
||||
|
||||
override def createTcpClient() = tcpClientProbe.ref
|
||||
}))
|
||||
val badRequest = Question4(exampleRequest.id, "not." + exampleRequest.name)
|
||||
|
||||
udpExtensionProbe.expectMsgType[Udp.Bind]
|
||||
udpSocketProbe.send(
|
||||
udpExtensionProbe.lastSender,
|
||||
Udp.Bound(InetSocketAddress.createUnresolved("localhost", 41325)))
|
||||
|
||||
goodSenderProbe.send(client, exampleRequest)
|
||||
|
||||
val udpSend = udpSocketProbe.expectMsgType[Udp.Send]
|
||||
udpSend.payload shouldBe exampleRequestMessage.write()
|
||||
|
||||
badSenderProbe.send(client, badRequest)
|
||||
badSenderProbe.expectMsg(DuplicateId(exampleRequest.id))
|
||||
|
||||
val answer = Answer(exampleRequest.id, im.Seq(), im.Seq())
|
||||
udpSocketProbe.reply(answer)
|
||||
goodSenderProbe.expectMsg(answer)
|
||||
|
||||
udpSocketProbe.expectNoMessage()
|
||||
}
|
||||
|
||||
"Verify original question when processing UDP replies (DNS poisoning)" in {
|
||||
val udpExtensionProbe = TestProbe()
|
||||
val udpSocketProbe = TestProbe()
|
||||
val tcpClientProbe = TestProbe()
|
||||
val goodSenderProbe = TestProbe()
|
||||
|
||||
val socket = InetSocketAddress.createUnresolved("localhost", 41325)
|
||||
val goodRecord = ARecord(exampleRequest.name, CachePolicy.Ttl.never, InetAddress.getLocalHost())
|
||||
|
||||
val client = system.actorOf(Props(new DnsClient(dnsServerAddress) {
|
||||
override val udp = udpExtensionProbe.ref
|
||||
|
||||
override def createTcpClient() = tcpClientProbe.ref
|
||||
}))
|
||||
|
||||
udpExtensionProbe.expectMsgType[Udp.Bind]
|
||||
udpSocketProbe.send(udpExtensionProbe.lastSender, Udp.Bound(socket))
|
||||
|
||||
goodSenderProbe.send(client, exampleRequest)
|
||||
|
||||
val udpSend = udpSocketProbe.expectMsgType[Udp.Send]
|
||||
udpSend.payload shouldBe exampleRequestMessage.write()
|
||||
|
||||
val flags = MessageFlags(true, authoritativeAnswer = true)
|
||||
val badId = exampleRequestMessage.copy(id = 999, flags = flags, answerRecs = im.Seq(goodRecord))
|
||||
val badQuestion = exampleRequestMessage.copy(
|
||||
flags = flags,
|
||||
questions = im.Seq(exampleQuestion.copy(name = "not.com")),
|
||||
answerRecs = im.Seq(goodRecord))
|
||||
val goodAnswer = exampleRequestMessage.copy(flags = flags, answerRecs = im.Seq(goodRecord))
|
||||
|
||||
udpSocketProbe.reply(Udp.Received(internal.ByteResponse(badId), socket))
|
||||
udpSocketProbe.reply(Udp.Received(internal.ByteResponse(badQuestion), socket))
|
||||
udpSocketProbe.reply(Udp.Received(internal.ByteResponse(goodAnswer), socket))
|
||||
|
||||
val answer = Answer(exampleRequest.id, im.Seq(goodRecord), im.Seq())
|
||||
goodSenderProbe.expectMsg(answer)
|
||||
|
||||
udpSocketProbe.expectNoMessage()
|
||||
}
|
||||
|
||||
"Verify original question when processing DropRequest" in {
|
||||
val udpExtensionProbe = TestProbe()
|
||||
val udpSocketProbe = TestProbe()
|
||||
val tcpClientProbe = TestProbe()
|
||||
val goodSenderProbe = TestProbe()
|
||||
|
||||
val socket = InetSocketAddress.createUnresolved("localhost", 41325)
|
||||
val client = system.actorOf(Props(new DnsClient(dnsServerAddress) {
|
||||
override val udp = udpExtensionProbe.ref
|
||||
|
||||
override def createTcpClient() = tcpClientProbe.ref
|
||||
}))
|
||||
|
||||
udpExtensionProbe.expectMsgType[Udp.Bind]
|
||||
udpSocketProbe.send(udpExtensionProbe.lastSender, Udp.Bound(socket))
|
||||
|
||||
goodSenderProbe.send(client, exampleRequest)
|
||||
|
||||
val udpSend = udpSocketProbe.expectMsgType[Udp.Send]
|
||||
udpSend.payload shouldBe exampleRequestMessage.write()
|
||||
|
||||
goodSenderProbe.send(client, DropRequest(exampleRequest.copy(id = 999)))
|
||||
|
||||
// duplicate shows inflight message not deleted
|
||||
goodSenderProbe.send(client, exampleRequest)
|
||||
goodSenderProbe.expectMsg(DuplicateId(exampleRequest.id))
|
||||
|
||||
goodSenderProbe.send(client, DropRequest(exampleRequest.copy(name = "not.com")))
|
||||
|
||||
// duplicate shows inflight message not deleted
|
||||
goodSenderProbe.send(client, exampleRequest)
|
||||
goodSenderProbe.expectMsg(DuplicateId(exampleRequest.id))
|
||||
|
||||
goodSenderProbe.send(client, DropRequest(exampleRequest))
|
||||
|
||||
// no duplicate shows inflight message was deleted
|
||||
goodSenderProbe.send(client, exampleRequest)
|
||||
goodSenderProbe.expectNoMessage()
|
||||
}
|
||||
|
||||
"Verify original question when processing UDP Failures" in {
|
||||
val udpExtensionProbe = TestProbe()
|
||||
val udpSocketProbe = TestProbe()
|
||||
val tcpClientProbe = TestProbe()
|
||||
val goodSenderProbe = TestProbe()
|
||||
|
||||
val socket = InetSocketAddress.createUnresolved("localhost", 41325)
|
||||
val client = system.actorOf(Props(new DnsClient(dnsServerAddress) {
|
||||
override val udp = udpExtensionProbe.ref
|
||||
|
||||
override def createTcpClient() = tcpClientProbe.ref
|
||||
}))
|
||||
|
||||
udpExtensionProbe.expectMsgType[Udp.Bind]
|
||||
udpSocketProbe.send(udpExtensionProbe.lastSender, Udp.Bound(socket))
|
||||
|
||||
goodSenderProbe.send(client, exampleRequest)
|
||||
|
||||
val udpSend = udpSocketProbe.expectMsgType[Udp.Send]
|
||||
udpSend.payload shouldBe exampleRequestMessage.write()
|
||||
|
||||
val badId = exampleRequestMessage.copy(id = 999)
|
||||
val badQuestion = exampleRequestMessage.copy(
|
||||
questions = im.Seq(exampleQuestion.copy(name = "not.com")))
|
||||
val goodQuestion = exampleRequestMessage
|
||||
|
||||
udpSocketProbe.reply(Udp.CommandFailed(Udp.Send(internal.ByteResponse(badId), socket)))
|
||||
|
||||
// duplicate shows inflight message not deleted
|
||||
goodSenderProbe.send(client, exampleRequest)
|
||||
goodSenderProbe.expectMsg(DuplicateId(exampleRequest.id))
|
||||
|
||||
udpSocketProbe.reply(Udp.CommandFailed(Udp.Send(internal.ByteResponse(badQuestion), socket)))
|
||||
|
||||
// duplicate shows inflight message not deleted
|
||||
goodSenderProbe.send(client, exampleRequest)
|
||||
goodSenderProbe.expectMsg(DuplicateId(exampleRequest.id))
|
||||
|
||||
udpSocketProbe.reply(Udp.CommandFailed(Udp.Send(internal.ByteResponse(goodQuestion), socket)))
|
||||
goodSenderProbe.expectMsgType[Failure]
|
||||
|
||||
// no duplicate shows inflight message was deleted
|
||||
goodSenderProbe.send(client, exampleRequest)
|
||||
goodSenderProbe.expectNoMessage()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* The main code only knows how to write the questions not the responses, ByteResponse
|
||||
* implements just enough of the writing logic that the main code can parse the answers
|
||||
* messages used in tests.
|
||||
*
|
||||
* Message is only available to the internal package.
|
||||
*/
|
||||
package internal {
|
||||
|
||||
import org.apache.pekko.io.dns.ResourceRecord
|
||||
import org.apache.pekko.util.{ ByteString, ByteStringBuilder }
|
||||
|
||||
object ByteResponse {
|
||||
|
||||
def apply(msg: Message): ByteString = {
|
||||
val ret = ByteString.newBuilder
|
||||
write(msg, ret)
|
||||
ret.result()
|
||||
}
|
||||
|
||||
def write(msg: Message, ret: ByteStringBuilder): Unit = {
|
||||
ret
|
||||
.putShort(msg.id)
|
||||
.putShort(msg.flags.flags)
|
||||
.putShort(msg.questions.size)
|
||||
.putShort(msg.answerRecs.size)
|
||||
.putShort(0)
|
||||
.putShort(0)
|
||||
|
||||
msg.questions.foreach(_.write(ret))
|
||||
msg.answerRecs.foreach(write(_, ret))
|
||||
}
|
||||
|
||||
def write(msg: ResourceRecord, ret: ByteStringBuilder): Unit = {
|
||||
msg match {
|
||||
case ARecord(name, ttl, ip) =>
|
||||
DomainName.write(ret, name)
|
||||
ret.putShort(RecordType.A.code)
|
||||
ret.putShort(RecordClass.IN.code)
|
||||
ret.putInt(ttl.value.toSeconds.toInt)
|
||||
ret.putShort(4)
|
||||
ret.putBytes(ip.getAddress, 0, 4)
|
||||
case _ =>
|
||||
throw new IllegalStateException(s"Tests cannot write messages of type ${msg.getClass}")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1144,6 +1144,12 @@ pekko {
|
|||
# Defaults to a system dependent lookup (on Unix like OSes, will attempt to parse /etc/resolv.conf, on
|
||||
# other platforms, will default to 1).
|
||||
ndots = default
|
||||
|
||||
# The policy used to generate dns transaction ids. Options are thread-local-random or secure-random.
|
||||
# Defaults to thread-local-random similar to Netty, secure-random produces FIPS compliant random numbers but
|
||||
# could block looking for entropy (these are short integers so are easy to bruit-force, use thread-local-random
|
||||
# unless you really require FIPS compliant random numbers).
|
||||
id-generator-policy = thread-local-random
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -67,6 +67,11 @@ private[dns] final class DnsSettings(system: ExtendedActorSystem, c: Config) {
|
|||
val PositiveCachePolicy: CachePolicy = getTtl("positive-ttl")
|
||||
val NegativeCachePolicy: CachePolicy = getTtl("negative-ttl")
|
||||
|
||||
lazy val IdGeneratorPolicy: IdGenerator.Policy = IdGenerator
|
||||
.Policy(c.getString("id-generator-policy"))
|
||||
.getOrElse(throw new IllegalArgumentException("id-generator-policy must be 'thread-local-random' or " +
|
||||
s"'secure-random' value was '${c.getString("id-generator-policy")}'"))
|
||||
|
||||
private def getTtl(path: String): CachePolicy =
|
||||
c.getString(path) match {
|
||||
case "forever" => Forever
|
||||
|
|
|
|||
|
|
@ -0,0 +1,81 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* license agreements; and to You under the Apache License, version 2.0:
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* This file is part of the Apache Pekko project, derived from Akka.
|
||||
*/
|
||||
|
||||
package org.apache.pekko.io.dns
|
||||
|
||||
import org.apache.pekko.annotation.InternalApi
|
||||
|
||||
import java.security.SecureRandom
|
||||
import java.util.concurrent.ThreadLocalRandom
|
||||
import java.util.concurrent.atomic.AtomicInteger
|
||||
import scala.annotation.tailrec
|
||||
|
||||
/**
|
||||
* INTERNAL API
|
||||
*
|
||||
* These are called by an actor, however they are called inside composed futures so need to be
|
||||
* nextId needs to be thread safe.
|
||||
*/
|
||||
@InternalApi
|
||||
private[pekko] trait IdGenerator {
|
||||
def nextId(): Short
|
||||
}
|
||||
|
||||
/**
|
||||
* INTERNAL API
|
||||
*/
|
||||
@InternalApi
|
||||
private[pekko] object IdGenerator {
|
||||
sealed trait Policy
|
||||
|
||||
object Policy {
|
||||
case object ThreadLocalRandom extends Policy
|
||||
case object SecureRandom extends Policy
|
||||
val Default: Policy = ThreadLocalRandom
|
||||
|
||||
def apply(name: String): Option[Policy] = name.toLowerCase match {
|
||||
case "thread-local-random" => Some(ThreadLocalRandom)
|
||||
case "secure-random" => Some(SecureRandom)
|
||||
case _ => Some(ThreadLocalRandom)
|
||||
}
|
||||
}
|
||||
|
||||
def apply(policy: Policy): IdGenerator = policy match {
|
||||
case Policy.ThreadLocalRandom => random(ThreadLocalRandom.current())
|
||||
case Policy.SecureRandom => random(new SecureRandom())
|
||||
}
|
||||
|
||||
def apply(): IdGenerator = random(ThreadLocalRandom.current())
|
||||
|
||||
/**
|
||||
* @return a random sequence of ids for production
|
||||
*/
|
||||
def random(rand: java.util.Random): IdGenerator = new IdGenerator {
|
||||
override def nextId(): Short = rand.nextInt(Short.MaxValue).toShort
|
||||
}
|
||||
|
||||
/**
|
||||
* @return a predictable sequence of ids for tests
|
||||
*/
|
||||
def sequence(): IdGenerator = new IdGenerator {
|
||||
val requestId: AtomicInteger = new AtomicInteger(0)
|
||||
|
||||
@tailrec
|
||||
override final def nextId(): Short = {
|
||||
val oldId = requestId.get()
|
||||
val newId = (oldId + 1) % Short.MaxValue
|
||||
|
||||
if (requestId.compareAndSet(oldId, newId.intValue())) {
|
||||
newId.toShort
|
||||
} else {
|
||||
nextId()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -19,6 +19,8 @@ import scala.collection.immutable
|
|||
import scala.concurrent.ExecutionContextExecutor
|
||||
import scala.concurrent.Future
|
||||
import scala.util.Try
|
||||
import scala.util.Success
|
||||
import scala.util.Failure
|
||||
import scala.util.control.NonFatal
|
||||
|
||||
import org.apache.pekko
|
||||
|
|
@ -41,10 +43,17 @@ import pekko.util.PrettyDuration._
|
|||
private[io] final class AsyncDnsResolver(
|
||||
settings: DnsSettings,
|
||||
cache: SimpleDnsCache,
|
||||
clientFactory: (ActorRefFactory, List[InetSocketAddress]) => List[ActorRef])
|
||||
clientFactory: (ActorRefFactory, List[InetSocketAddress]) => List[ActorRef],
|
||||
idGenerator: IdGenerator)
|
||||
extends Actor
|
||||
with ActorLogging {
|
||||
|
||||
def this(
|
||||
settings: DnsSettings,
|
||||
cache: SimpleDnsCache,
|
||||
clientFactory: (ActorRefFactory, List[InetSocketAddress]) => List[ActorRef]) =
|
||||
this(settings, cache, clientFactory, IdGenerator(settings.IdGeneratorPolicy))
|
||||
|
||||
import AsyncDnsResolver._
|
||||
|
||||
implicit val ec: ExecutionContextExecutor = context.dispatcher
|
||||
|
|
@ -85,13 +94,6 @@ private[io] final class AsyncDnsResolver(
|
|||
settings.SearchDomains,
|
||||
settings.NDots)
|
||||
|
||||
private var requestId: Short = 0
|
||||
|
||||
private def nextId(): Short = {
|
||||
requestId = (requestId + 1).toShort
|
||||
requestId
|
||||
}
|
||||
|
||||
private val resolvers: List[ActorRef] = clientFactory(context, nameServers)
|
||||
|
||||
// only supports DnsProtocol, not the deprecated Dns protocol
|
||||
|
|
@ -151,11 +153,19 @@ private[io] final class AsyncDnsResolver(
|
|||
}
|
||||
|
||||
private def sendQuestion(resolver: ActorRef, message: DnsQuestion): Future[Answer] = {
|
||||
val result = (resolver ? message).mapTo[Answer]
|
||||
result.failed.foreach { _ =>
|
||||
resolver ! DropRequest(message.id)
|
||||
(resolver ? message).transformWith {
|
||||
case Success(result: Answer) =>
|
||||
Future.successful(result)
|
||||
case Success(DuplicateId(_)) =>
|
||||
sendQuestion(resolver, message.withId(idGenerator.nextId()))
|
||||
case Failure(t) =>
|
||||
resolver ! DropRequest(message)
|
||||
Future.failed(t)
|
||||
case Success(a) =>
|
||||
resolver ! DropRequest(message)
|
||||
Future.failed(
|
||||
new IllegalArgumentException("Unexpected response " + a.toString + " of type " + a.getClass.toString))
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
private def resolveWithSearch(
|
||||
|
|
@ -208,13 +218,13 @@ private[io] final class AsyncDnsResolver(
|
|||
case Ip(ipv4, ipv6) =>
|
||||
val ipv4Recs: Future[Answer] =
|
||||
if (ipv4)
|
||||
sendQuestion(resolver, Question4(nextId(), caseFoldedName))
|
||||
sendQuestion(resolver, Question4(idGenerator.nextId(), caseFoldedName))
|
||||
else
|
||||
Empty
|
||||
|
||||
val ipv6Recs =
|
||||
if (ipv6)
|
||||
sendQuestion(resolver, Question6(nextId(), caseFoldedName))
|
||||
sendQuestion(resolver, Question6(idGenerator.nextId(), caseFoldedName))
|
||||
else
|
||||
Empty
|
||||
|
||||
|
|
@ -224,7 +234,7 @@ private[io] final class AsyncDnsResolver(
|
|||
} yield DnsProtocol.Resolved(name, ipv4.rrs ++ ipv6.rrs, ipv4.additionalRecs ++ ipv6.additionalRecs)
|
||||
|
||||
case Srv =>
|
||||
sendQuestion(resolver, SrvQuestion(nextId(), caseFoldedName)).map(answer => {
|
||||
sendQuestion(resolver, SrvQuestion(idGenerator.nextId(), caseFoldedName)).map(answer => {
|
||||
DnsProtocol.Resolved(name, answer.rrs, answer.additionalRecs)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,13 +14,10 @@
|
|||
package org.apache.pekko.io.dns.internal
|
||||
|
||||
import java.net.{ InetAddress, InetSocketAddress }
|
||||
|
||||
import scala.collection.{ immutable => im }
|
||||
import scala.concurrent.duration._
|
||||
import scala.util.Try
|
||||
|
||||
import scala.annotation.nowarn
|
||||
|
||||
import scala.annotation.{ nowarn, tailrec }
|
||||
import org.apache.pekko
|
||||
import pekko.actor.{ Actor, ActorLogging, ActorRef, NoSerializationVerificationNeeded, Props, Stash }
|
||||
import pekko.actor.Status.Failure
|
||||
|
|
@ -35,13 +32,23 @@ import pekko.pattern.{ BackoffOpts, BackoffSupervisor }
|
|||
@InternalApi private[pekko] object DnsClient {
|
||||
sealed trait DnsQuestion {
|
||||
def id: Short
|
||||
def name: String
|
||||
def withId(newId: Short): DnsQuestion = {
|
||||
this match {
|
||||
case SrvQuestion(_, name) => SrvQuestion(newId, name)
|
||||
case Question4(_, name) => Question4(newId, name)
|
||||
case Question6(_, name) => Question6(newId, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
final case class SrvQuestion(id: Short, name: String) extends DnsQuestion
|
||||
final case class Question4(id: Short, name: String) extends DnsQuestion
|
||||
final case class Question6(id: Short, name: String) extends DnsQuestion
|
||||
final case class Answer(id: Short, rrs: im.Seq[ResourceRecord], additionalRecs: im.Seq[ResourceRecord] = Nil)
|
||||
extends NoSerializationVerificationNeeded
|
||||
final case class DropRequest(id: Short)
|
||||
|
||||
final case class DuplicateId(id: Short) extends NoSerializationVerificationNeeded
|
||||
final case class DropRequest(question: DnsQuestion)
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -85,30 +92,40 @@ import pekko.pattern.{ BackoffOpts, BackoffSupervisor }
|
|||
*/
|
||||
@nowarn()
|
||||
def ready(socket: ActorRef): Receive = {
|
||||
case DropRequest(id) =>
|
||||
log.debug("Dropping request [{}]", id)
|
||||
inflightRequests -= id
|
||||
case DropRequest(msg) =>
|
||||
inflightRequests.get(msg.id).foreach {
|
||||
case (_, orig) if Seq(msg.name) == orig.questions.map(_.name) =>
|
||||
log.debug("Dropping request [{}]", msg.id)
|
||||
inflightRequests -= msg.id
|
||||
case (_, orig) =>
|
||||
log.warning("Cannot drop inflight DNS request the question [{}] does not match [{}]",
|
||||
msg.name,
|
||||
orig.questions.map(_.name).mkString(","))
|
||||
}
|
||||
|
||||
case Question4(id, name) =>
|
||||
log.debug("Resolving [{}] (A)", name)
|
||||
val msg = message(name, id, RecordType.A)
|
||||
inflightRequests += (id -> (sender() -> msg))
|
||||
log.debug("Message [{}] to [{}]: [{}]", id, ns, msg)
|
||||
socket ! Udp.Send(msg.write(), ns)
|
||||
newInflightRequests(msg, sender()) {
|
||||
log.debug("Message [{}] to [{}]: [{}]", id, ns, msg)
|
||||
socket ! Udp.Send(msg.write(), ns)
|
||||
}
|
||||
|
||||
case Question6(id, name) =>
|
||||
log.debug("Resolving [{}] (AAAA)", name)
|
||||
val msg = message(name, id, RecordType.AAAA)
|
||||
inflightRequests += (id -> (sender() -> msg))
|
||||
log.debug("Message to [{}]: [{}]", ns, msg)
|
||||
socket ! Udp.Send(msg.write(), ns)
|
||||
newInflightRequests(msg, sender()) {
|
||||
log.debug("Message [{}] to [{}]: [{}]", id, ns, msg)
|
||||
socket ! Udp.Send(msg.write(), ns)
|
||||
}
|
||||
|
||||
case SrvQuestion(id, name) =>
|
||||
log.debug("Resolving [{}] (SRV)", name)
|
||||
val msg = message(name, id, RecordType.SRV)
|
||||
inflightRequests += (id -> (sender() -> msg))
|
||||
log.debug("Message to [{}]: [{}]", ns, msg)
|
||||
socket ! Udp.Send(msg.write(), ns)
|
||||
newInflightRequests(msg, sender()) {
|
||||
log.debug("Message [{}] to [{}]: [{}]", id, ns, msg)
|
||||
socket ! Udp.Send(msg.write(), ns)
|
||||
}
|
||||
|
||||
case Udp.CommandFailed(cmd) =>
|
||||
log.debug("Command failed [{}]", cmd)
|
||||
|
|
@ -118,9 +135,13 @@ import pekko.pattern.{ BackoffOpts, BackoffSupervisor }
|
|||
Try {
|
||||
val msg = Message.parse(send.payload)
|
||||
inflightRequests.get(msg.id).foreach {
|
||||
case (s, _) =>
|
||||
case (s, orig) if isSameQuestion(msg.questions, orig.questions) =>
|
||||
s ! Failure(new RuntimeException("Send failed to nameserver"))
|
||||
inflightRequests -= msg.id
|
||||
case (_, orig) =>
|
||||
log.warning("Cannot command failed question [{}] does not match [{}]",
|
||||
msg.questions.mkString(","),
|
||||
orig.questions.mkString(","))
|
||||
}
|
||||
}
|
||||
case _ =>
|
||||
|
|
@ -140,9 +161,24 @@ import pekko.pattern.{ BackoffOpts, BackoffSupervisor }
|
|||
log.debug("Client for id {} not found. Discarding unsuccessful response.", msg.id)
|
||||
}
|
||||
} else {
|
||||
val (recs, additionalRecs) =
|
||||
if (msg.flags.responseCode == ResponseCode.SUCCESS) (msg.answerRecs, msg.additionalRecs) else (Nil, Nil)
|
||||
self ! Answer(msg.id, recs, additionalRecs)
|
||||
inflightRequests.get(msg.id) match {
|
||||
case Some((_, orig)) if !isSameQuestion(msg.questions, orig.questions) =>
|
||||
log.warning(
|
||||
"Unexpected DNS response id {} question [{}] does not match question asked [{}]",
|
||||
msg.id,
|
||||
msg.questions.mkString(","),
|
||||
orig.questions.mkString(","))
|
||||
case Some((_, orig)) =>
|
||||
log.warning("DNS response id {} question [{}] question asked [{}]",
|
||||
msg.id,
|
||||
msg.questions.mkString(","),
|
||||
orig.questions.mkString(","))
|
||||
val (recs, additionalRecs) =
|
||||
if (msg.flags.responseCode == ResponseCode.SUCCESS) (msg.answerRecs, msg.additionalRecs) else (Nil, Nil)
|
||||
self ! Answer(msg.id, recs, additionalRecs)
|
||||
case None =>
|
||||
log.warning("Unexpected DNS response invalid id {}", msg.id)
|
||||
}
|
||||
}
|
||||
case response: Answer =>
|
||||
inflightRequests.get(response.id) match {
|
||||
|
|
@ -156,6 +192,29 @@ import pekko.pattern.{ BackoffOpts, BackoffSupervisor }
|
|||
case Udp.Unbound => context.stop(self)
|
||||
}
|
||||
|
||||
private def newInflightRequests(msg: Message, theSender: ActorRef)(func: => Unit): Unit = {
|
||||
if (!inflightRequests.contains(msg.id)) {
|
||||
inflightRequests += (msg.id -> (theSender -> msg))
|
||||
func
|
||||
} else {
|
||||
log.warning("Received duplicate message [{}] with id [{}]", msg, msg.id)
|
||||
theSender ! DuplicateId(msg.id)
|
||||
}
|
||||
}
|
||||
|
||||
private def isSameQuestion(q1s: Seq[Question], q2s: Seq[Question]): Boolean = {
|
||||
@tailrec
|
||||
def impl(q1s: List[Question], q2s: List[Question]): Boolean = {
|
||||
(q1s, q2s) match {
|
||||
case (Nil, Nil) => true
|
||||
case (h1 :: t1, h2 :: t2) => h1.isSame(h2) && impl(t1, t2)
|
||||
case _ => false
|
||||
}
|
||||
}
|
||||
|
||||
impl(q1s.sortBy(_.name).toList, q2s.sortBy(_.name).toList)
|
||||
}
|
||||
|
||||
def createTcpClient() = {
|
||||
context.actorOf(
|
||||
BackoffSupervisor.props(
|
||||
|
|
|
|||
|
|
@ -28,6 +28,16 @@ private[pekko] final case class Question(name: String, qType: RecordType, qClass
|
|||
RecordTypeSerializer.write(out, qType)
|
||||
RecordClassSerializer.write(out, qClass)
|
||||
}
|
||||
|
||||
/**
|
||||
* DomainName.parse adds a '.' to the end of domain names we have to allow for this checking the domain names
|
||||
* @return true if this the questions are the same (allowing for an trailing dot).
|
||||
*/
|
||||
def isSame(that: Question): Boolean = {
|
||||
def addDot(name: String): String = if (name.nonEmpty && !name.endsWith(".")) name.concat(".") else name
|
||||
|
||||
addDot(this.name) == addDot(that.name) && this.qType == that.qType && this.qClass == that.qClass
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -0,0 +1,44 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* license agreements; and to You under the Apache License, version 2.0:
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* This file is part of the Apache Pekko project, derived from Akka.
|
||||
*/
|
||||
|
||||
package org.apache.pekko.io.dns
|
||||
|
||||
import org.openjdk.jmh.annotations.{
|
||||
Benchmark,
|
||||
BenchmarkMode,
|
||||
Fork,
|
||||
Measurement,
|
||||
Mode,
|
||||
OutputTimeUnit,
|
||||
Scope,
|
||||
State,
|
||||
Threads,
|
||||
Warmup
|
||||
}
|
||||
|
||||
import java.util.concurrent.{ ThreadLocalRandom, TimeUnit }
|
||||
import java.security.SecureRandom
|
||||
|
||||
@BenchmarkMode(Array(Mode.Throughput))
|
||||
@OutputTimeUnit(TimeUnit.NANOSECONDS)
|
||||
@Warmup(iterations = 3, time = 5, timeUnit = TimeUnit.SECONDS)
|
||||
@Measurement(iterations = 3, time = 5)
|
||||
@Threads(8)
|
||||
@Fork(1)
|
||||
@State(Scope.Benchmark)
|
||||
class IdGeneratorBanchmark {
|
||||
val threadLocalRandom = IdGenerator.random(ThreadLocalRandom.current())
|
||||
val secureRandom = IdGenerator.random(new SecureRandom())
|
||||
|
||||
@Benchmark
|
||||
def measureThreadLocalRandom(): Short = threadLocalRandom.nextId()
|
||||
|
||||
@Benchmark
|
||||
def measureSecureRandom(): Short = secureRandom.nextId()
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue