diff --git a/akka-actor-tests/src/test/scala/akka/util/ByteStringSpec.scala b/akka-actor-tests/src/test/scala/akka/util/ByteStringSpec.scala index 0f10d7ae52..b1edd72d76 100644 --- a/akka-actor-tests/src/test/scala/akka/util/ByteStringSpec.scala +++ b/akka-actor-tests/src/test/scala/akka/util/ByteStringSpec.scala @@ -4,6 +4,8 @@ package akka.util +import java.io.{ ByteArrayInputStream, ObjectInputStream, ObjectOutputStream, ByteArrayOutputStream } + import org.scalatest.WordSpec import org.scalatest.Matchers import org.scalatest.prop.Checkers @@ -11,6 +13,8 @@ import org.scalacheck.Arbitrary import org.scalacheck.Arbitrary.arbitrary import org.scalacheck.Gen +import org.apache.commons.codec.binary.Hex.{ encodeHex, decodeHex } + import scala.collection.mutable.Builder import java.nio.{ ByteBuffer } @@ -56,6 +60,23 @@ class ByteStringSpec extends WordSpec with Matchers with Checkers { } yield (xs, from, until) } + def testSer(obj: AnyRef) = { + val os = new ByteArrayOutputStream + val bos = new ObjectOutputStream(os) + bos.writeObject(obj) + val arr = os.toByteArray + val is = new ObjectInputStream(new ByteArrayInputStream(arr)) + + is.readObject == obj + } + + def hexFromSer(obj: AnyRef) = { + val os = new ByteArrayOutputStream + val bos = new ObjectOutputStream(os) + bos.writeObject(obj) + String valueOf encodeHex(os.toByteArray) + } + val arbitraryByteArray: Arbitrary[Array[Byte]] = Arbitrary { Gen.sized { n ⇒ Gen.containerOfN[Array, Byte](n, arbitrary[Byte]) } } implicit val arbitraryByteArraySlice: Arbitrary[ArraySlice[Byte]] = arbSlice(arbitraryByteArray) val arbitraryShortArray: Arbitrary[Array[Short]] = Arbitrary { Gen.sized { n ⇒ Gen.containerOfN[Array, Short](n, arbitrary[Short]) } } @@ -370,6 +391,22 @@ class ByteStringSpec extends WordSpec with Matchers with Checkers { } } } + + "serialize correctly" when { + "parsing regular ByteString1C as compat" in { + val oldSerd = "aced000573720021616b6b612e7574696c2e42797465537472696e672442797465537472696e67314336e9eed0afcfe4a40200015b000562797465737400025b427872001b616b6b612e7574696c2e436f6d7061637442797465537472696e67fa2925150f93468f0200007870757200025b42acf317f8060854e002000078700000000a74657374737472696e67" + val bs = ByteString("teststring", "UTF8") + val str = hexFromSer(bs) + + require(oldSerd == str) + } + + "given all types of ByteString" in { + check { bs: ByteString ⇒ + testSer(bs) + } + } + } } "A ByteStringIterator" must { diff --git a/akka-actor/src/main/scala/akka/util/ByteString.scala b/akka-actor/src/main/scala/akka/util/ByteString.scala index 9e72959db9..e4ed4fc505 100644 --- a/akka-actor/src/main/scala/akka/util/ByteString.scala +++ b/akka-actor/src/main/scala/akka/util/ByteString.scala @@ -4,6 +4,7 @@ package akka.util +import java.io.{ ObjectInputStream, ObjectOutputStream } import java.nio.{ ByteBuffer, ByteOrder } import java.lang.{ Iterable ⇒ JIterable } @@ -93,8 +94,16 @@ object ByteString { def apply(): ByteStringBuilder = newBuilder } - private[akka] object ByteString1C { + private[akka] object ByteString1C extends Companion { def apply(bytes: Array[Byte]): ByteString1C = new ByteString1C(bytes) + val SerializationIdentity = 1.toByte + + def readFromInputStream(is: ObjectInputStream): ByteString1C = { + val length = is.readInt() + val arr = new Array[Byte](length) + is.read(arr, 0, length) + ByteString1C(arr) + } } /** @@ -110,6 +119,8 @@ object ByteString { private[akka] def toByteString1: ByteString1 = ByteString1(bytes) + private[akka] def byteStringCompanion = ByteString1C + def asByteBuffer: ByteBuffer = toByteString1.asByteBuffer def asByteBuffers: scala.collection.immutable.Iterable[ByteBuffer] = List(asByteBuffer) @@ -125,19 +136,27 @@ object ByteString { override def slice(from: Int, until: Int): ByteString = if ((from != 0) || (until != length)) toByteString1.slice(from, until) else this + + private[akka] def writeToOutputStream(os: ObjectOutputStream): Unit = + toByteString1.writeToOutputStream(os) } - private[akka] object ByteString1 { + private[akka] object ByteString1 extends Companion { val empty: ByteString1 = new ByteString1(Array.empty[Byte]) def apply(bytes: Array[Byte]): ByteString1 = ByteString1(bytes, 0, bytes.length) def apply(bytes: Array[Byte], startIndex: Int, length: Int): ByteString1 = if (length == 0) empty else new ByteString1(bytes, startIndex, length) + + val SerializationIdentity = 0.toByte + + def readFromInputStream(is: ObjectInputStream): ByteString1 = + ByteString1C.readFromInputStream(is).toByteString1 } /** * An unfragmented ByteString. */ - final class ByteString1 private (private val bytes: Array[Byte], private val startIndex: Int, val length: Int) extends ByteString { + final class ByteString1 private (private val bytes: Array[Byte], private val startIndex: Int, val length: Int) extends ByteString with Serializable { private def this(bytes: Array[Byte]) = this(bytes, 0, bytes.length) @@ -153,8 +172,15 @@ object ByteString { throw new IndexOutOfBoundsException(index.toString) } + private[akka] def writeToOutputStream(os: ObjectOutputStream): Unit = { + os.writeInt(length) + os.write(bytes, startIndex, length) + } + def isCompact: Boolean = (length == bytes.length) + private[akka] def byteStringCompanion = ByteString1 + def compact: CompactByteString = if (isCompact) ByteString1C(bytes) else ByteString1C(toArray) @@ -181,9 +207,11 @@ object ByteString { case bs: ByteStrings ⇒ ByteStrings(this, bs) } } + + protected def writeReplace(): AnyRef = new SerializationProxy(this) } - private[akka] object ByteStrings { + private[akka] object ByteStrings extends Companion { def apply(bytestrings: Vector[ByteString1]): ByteString = new ByteStrings(bytestrings, (0 /: bytestrings)(_ + _.length)) def apply(bytestrings: Vector[ByteString1], length: Int): ByteString = new ByteStrings(bytestrings, length) @@ -222,12 +250,30 @@ object ByteString { if (b2.isEmpty) 0 else 2 else if (b2.isEmpty) 1 else 3 + val SerializationIdentity = 2.toByte + + def readFromInputStream(is: ObjectInputStream): ByteStrings = { + val nByteStrings = is.readInt() + + val builder = new VectorBuilder[ByteString1] + var length = 0 + + builder.sizeHint(nByteStrings) + + for (_ ← 0 until nByteStrings) { + val bs = ByteString1.readFromInputStream(is) + builder += bs + length += bs.length + } + + new ByteStrings(builder.result(), length) + } } /** * A ByteString with 2 or more fragments. */ - final class ByteStrings private (private[akka] val bytestrings: Vector[ByteString1], val length: Int) extends ByteString { + final class ByteStrings private (private[akka] val bytestrings: Vector[ByteString1], val length: Int) extends ByteString with Serializable { if (bytestrings.isEmpty) throw new IllegalArgumentException("bytestrings must not be empty") def apply(idx: Int): Byte = @@ -254,6 +300,8 @@ object ByteString { } } + private[akka] def byteStringCompanion = ByteStrings + def isCompact: Boolean = if (bytestrings.length == 1) bytestrings.head.isCompact else false def compact: CompactByteString = { @@ -274,8 +322,43 @@ object ByteString { def asByteBuffers: scala.collection.immutable.Iterable[ByteBuffer] = bytestrings map { _.asByteBuffer } def decodeString(charset: String): String = compact.decodeString(charset) + + private[akka] def writeToOutputStream(os: ObjectOutputStream): Unit = { + os.writeInt(bytestrings.length) + bytestrings.foreach(_.writeToOutputStream(os)) + } + + protected def writeReplace(): AnyRef = new SerializationProxy(this) } + @SerialVersionUID(1L) + private class SerializationProxy(@transient private var orig: ByteString) extends Serializable { + private def writeObject(out: ObjectOutputStream) { + out.writeByte(orig.byteStringCompanion.SerializationIdentity) + orig.writeToOutputStream(out) + } + + private def readObject(in: ObjectInputStream) { + val serializationId = in.readByte() + + orig = Companion(from = serializationId).readFromInputStream(in) + } + + private def readResolve(): AnyRef = orig + } + + private[akka] object Companion { + private val companionMap = Seq(ByteString1, ByteString1C, ByteStrings). + map(x ⇒ x.SerializationIdentity -> x).toMap. + withDefault(x ⇒ throw new IllegalArgumentException("Invalid serialization id " + x)) + + def apply(from: Byte): Companion = companionMap(from) + } + + private[akka] sealed trait Companion { + def SerializationIdentity: Byte + def readFromInputStream(is: ObjectInputStream): ByteString + } } /** @@ -288,6 +371,7 @@ object ByteString { */ sealed abstract class ByteString extends IndexedSeq[Byte] with IndexedSeqOptimized[Byte, ByteString] { def apply(idx: Int): Byte + private[akka] def byteStringCompanion: ByteString.Companion override protected[this] def newBuilder: ByteStringBuilder = ByteString.newBuilder @@ -333,6 +417,8 @@ sealed abstract class ByteString extends IndexedSeq[Byte] with IndexedSeqOptimiz override def foreach[@specialized U](f: Byte ⇒ U): Unit = iterator foreach f + private[akka] def writeToOutputStream(os: ObjectOutputStream): Unit + /** * Efficiently concatenate another ByteString. */