Fix parsing truncated DNS messages (#25691)

* Fix parsing truncated DNS messages

* Additional validations
This commit is contained in:
Arnout Engelen 2018-10-02 16:52:28 +02:00 committed by Patrik Nordwall
parent 68074c3deb
commit 0101740825
2 changed files with 43 additions and 6 deletions

View file

@ -0,0 +1,26 @@
/*
* Copyright (C) 2018 Lightbend Inc. <https://www.lightbend.com>
*/
package akka.io.dns.internal
import akka.io.dns.{ RecordClass, RecordType }
import akka.util.ByteString
import org.scalatest.{ Matchers, WordSpec }
class MessageSpec extends WordSpec with Matchers {
"The Message" should {
"parse a response that is truncated mid-message" in {
val bytes = ByteString(0, 4, -125, -128, 0, 1, 0, 48, 0, 0, 0, 0, 4, 109, 97, 110, 121, 4, 98, 122, 122, 116, 3, 110, 101, 116, 0, 0, 28, 0, 1)
val msg = Message.parse(bytes)
msg.id should be(4)
msg.flags.isTruncated should be(true)
msg.questions.length should be(1)
msg.questions.head should be(Question("many.bzzt.net", RecordType.AAAA, RecordClass.IN))
msg.answerRecs.length should be(0)
msg.authorityRecs.length should be(0)
msg.additionalRecs.length should be(0)
}
}
}

View file

@ -8,7 +8,9 @@ import akka.annotation.InternalApi
import akka.io.dns.ResourceRecord
import akka.util.{ ByteString, ByteStringBuilder }
import scala.collection.GenTraversableOnce
import scala.collection.immutable.Seq
import scala.util.{ Failure, Success, Try }
/**
* INTERNAL API
@ -128,17 +130,26 @@ private[internal] object Message {
def parse(msg: ByteString): Message = {
val it = msg.iterator
val id = it.getShort
val flags = it.getShort
val flags = new MessageFlags(it.getShort)
val qdCount = it.getShort
val anCount = it.getShort
val nsCount = it.getShort
val arCount = it.getShort
val qs = (0 until qdCount) map { i Question.parse(it, msg) }
val ans = (0 until anCount) map { i ResourceRecord.parse(it, msg) }
val nss = (0 until nsCount) map { i ResourceRecord.parse(it, msg) }
val ars = (0 until arCount) map { i ResourceRecord.parse(it, msg) }
val qs = (0 until qdCount) map { i Try(Question.parse(it, msg)) }
val ans = (0 until anCount) map { i Try(ResourceRecord.parse(it, msg)) }
val nss = (0 until nsCount) map { i Try(ResourceRecord.parse(it, msg)) }
val ars = (0 until arCount) map { i Try(ResourceRecord.parse(it, msg)) }
new Message(id, new MessageFlags(flags), qs, ans, nss, ars)
import scala.language.implicitConversions
implicit def flattener[T](tried: Try[T]): GenTraversableOnce[T] =
if (flags.isTruncated) tried.toOption
else tried match {
case Success(value) Some(value)
case Failure(reason) throw reason
}
new Message(id, flags, qs.flatten, ans.flatten, nss.flatten, ars.flatten)
}
}