diff --git a/akka-actor-tests/src/test/scala/akka/util/ByteStringSpec.scala b/akka-actor-tests/src/test/scala/akka/util/ByteStringSpec.scala index 1514635c8a..a6a614475c 100644 --- a/akka-actor-tests/src/test/scala/akka/util/ByteStringSpec.scala +++ b/akka-actor-tests/src/test/scala/akka/util/ByteStringSpec.scala @@ -308,6 +308,19 @@ class ByteStringSpec extends WordSpec with Matchers with Checkers { ByteString1.fromString("ab").drop(2) should ===(ByteString("")) ByteString1.fromString("ab").drop(3) should ===(ByteString("")) } + "take" in { + ByteString1.empty.take(-1) should ===(ByteString("")) + ByteString1.empty.take(0) should ===(ByteString("")) + ByteString1.empty.take(1) should ===(ByteString("")) + ByteString1.fromString("a").take(1) should ===(ByteString("a")) + ByteString1.fromString("ab").take(-1) should ===(ByteString("")) + ByteString1.fromString("ab").take(0) should ===(ByteString("")) + ByteString1.fromString("ab").take(1) should ===(ByteString("a")) + ByteString1.fromString("ab").take(2) should ===(ByteString("ab")) + ByteString1.fromString("ab").take(3) should ===(ByteString("ab")) + ByteString1.fromString("0123456789").take(3).drop(1) should ===(ByteString("12")) + ByteString1.fromString("0123456789").take(10).take(8).drop(3).take(5) should ===(ByteString("34567")) + } } "ByteString1C" must { "drop(0)" in { @@ -415,6 +428,9 @@ class ByteStringSpec extends WordSpec with Matchers with Checkers { ByteStrings(ByteString1.fromString("a"), ByteString1.fromString("bc")).dropRight(3) should ===(ByteString("")) } "take" in { + ByteString.empty.take(-1) should ===(ByteString("")) + ByteString.empty.take(0) should ===(ByteString("")) + ByteString.empty.take(1) should ===(ByteString("")) ByteStrings(ByteString1.fromString("a"), ByteString1.fromString("bc")).drop(1).take(0) should ===(ByteString("")) ByteStrings(ByteString1.fromString("a"), ByteString1.fromString("bc")).drop(1).take(-1) should ===(ByteString("")) ByteStrings(ByteString1.fromString("a"), ByteString1.fromString("bc")).drop(1).take(-2) should ===(ByteString("")) diff --git a/akka-actor/src/main/scala/akka/util/ByteString.scala b/akka-actor/src/main/scala/akka/util/ByteString.scala index d2600320d8..8196495d7c 100644 --- a/akka-actor/src/main/scala/akka/util/ByteString.scala +++ b/akka-actor/src/main/scala/akka/util/ByteString.scala @@ -249,8 +249,11 @@ object ByteString { } override def take(n: Int): ByteString = - if (n <= 0) ByteString.empty - else ByteString1(bytes, startIndex, Math.min(n, length)) + if (n <= 0) ByteString.empty else take1(n) + + private[akka] def take1(n: Int): ByteString1 = + if (n >= length) this + else ByteString1(bytes, startIndex, n) override def slice(from: Int, until: Int): ByteString = drop(from).take(until - Math.max(0, from)) @@ -432,18 +435,23 @@ object ByteString { bytestrings.foreach(_.writeToOutputStream(os)) } - override def take(n: Int): ByteString = { - @tailrec def take0(n: Int, b: ByteStringBuilder, bs: Vector[ByteString1]): ByteString = - if (bs.isEmpty || n <= 0) b.result - else { - val head = bs.head - if (n <= head.length) b.append(head.take(n)).result - else take0(n - head.length, b.append(head), bs.tail) - } - + override def take(n: Int): ByteString = if (n <= 0) ByteString.empty else if (n >= length) this - else take0(n, ByteString.newBuilder, bytestrings) + else take0(n) + + private[akka] def take0(n: Int): ByteString = { + @tailrec def go(last: Int, restToTake: Int): (Int, Int) = { + val bs = bytestrings(last) + if (bs.length > restToTake) (last, restToTake) + else go(last + 1, restToTake - bs.length) + } + + val (last, restToTake) = go(0, n) + + if (last == 0) bytestrings(last).take(restToTake) + else if (restToTake == 0) new ByteStrings(bytestrings.take(last), n) + else new ByteStrings(bytestrings.take(last) :+ bytestrings(last).take1(restToTake), n) } override def dropRight(n: Int): ByteString = @@ -469,7 +477,7 @@ object ByteString { override def drop(n: Int): ByteString = if (n <= 0) this - else if (n > length) ByteString.empty + else if (n >= length) ByteString.empty else drop0(n) private def drop0(n: Int): ByteString = {