diff --git a/akka-actor-tests/src/test/scala/akka/io/dns/internal/MessageSpec.scala b/akka-actor-tests/src/test/scala/akka/io/dns/internal/MessageSpec.scala new file mode 100644 index 0000000000..0acb7fdee6 --- /dev/null +++ b/akka-actor-tests/src/test/scala/akka/io/dns/internal/MessageSpec.scala @@ -0,0 +1,26 @@ +/* + * Copyright (C) 2018 Lightbend Inc. + */ + +package akka.io.dns.internal + +import akka.io.dns.{ RecordClass, RecordType } +import akka.util.ByteString +import org.scalatest.{ Matchers, WordSpec } + +class MessageSpec extends WordSpec with Matchers { + "The Message" should { + "parse a response that is truncated mid-message" in { + val bytes = ByteString(0, 4, -125, -128, 0, 1, 0, 48, 0, 0, 0, 0, 4, 109, 97, 110, 121, 4, 98, 122, 122, 116, 3, 110, 101, 116, 0, 0, 28, 0, 1) + val msg = Message.parse(bytes) + msg.id should be(4) + msg.flags.isTruncated should be(true) + msg.questions.length should be(1) + msg.questions.head should be(Question("many.bzzt.net", RecordType.AAAA, RecordClass.IN)) + msg.answerRecs.length should be(0) + msg.authorityRecs.length should be(0) + msg.additionalRecs.length should be(0) + } + } + +} diff --git a/akka-actor/src/main/scala/akka/io/dns/internal/DnsMessage.scala b/akka-actor/src/main/scala/akka/io/dns/internal/DnsMessage.scala index 197d16d516..186596f0be 100644 --- a/akka-actor/src/main/scala/akka/io/dns/internal/DnsMessage.scala +++ b/akka-actor/src/main/scala/akka/io/dns/internal/DnsMessage.scala @@ -8,7 +8,9 @@ import akka.annotation.InternalApi import akka.io.dns.ResourceRecord import akka.util.{ ByteString, ByteStringBuilder } +import scala.collection.GenTraversableOnce import scala.collection.immutable.Seq +import scala.util.{ Failure, Success, Try } /** * INTERNAL API @@ -128,17 +130,26 @@ private[internal] object Message { def parse(msg: ByteString): Message = { val it = msg.iterator val id = it.getShort - val flags = it.getShort + val flags = new MessageFlags(it.getShort) + val qdCount = it.getShort val anCount = it.getShort val nsCount = it.getShort val arCount = it.getShort - val qs = (0 until qdCount) map { i ⇒ Question.parse(it, msg) } - val ans = (0 until anCount) map { i ⇒ ResourceRecord.parse(it, msg) } - val nss = (0 until nsCount) map { i ⇒ ResourceRecord.parse(it, msg) } - val ars = (0 until arCount) map { i ⇒ ResourceRecord.parse(it, msg) } + val qs = (0 until qdCount) map { i ⇒ Try(Question.parse(it, msg)) } + val ans = (0 until anCount) map { i ⇒ Try(ResourceRecord.parse(it, msg)) } + val nss = (0 until nsCount) map { i ⇒ Try(ResourceRecord.parse(it, msg)) } + val ars = (0 until arCount) map { i ⇒ Try(ResourceRecord.parse(it, msg)) } - new Message(id, new MessageFlags(flags), qs, ans, nss, ars) + import scala.language.implicitConversions + implicit def flattener[T](tried: Try[T]): GenTraversableOnce[T] = + if (flags.isTruncated) tried.toOption + else tried match { + case Success(value) ⇒ Some(value) + case Failure(reason) ⇒ throw reason + } + + new Message(id, flags, qs.flatten, ans.flatten, nss.flatten, ars.flatten) } }