diff --git a/akka-actor/src/main/resources/reference.conf b/akka-actor/src/main/resources/reference.conf index 2e097c80a1..4a1aabf6b5 100644 --- a/akka-actor/src/main/resources/reference.conf +++ b/akka-actor/src/main/resources/reference.conf @@ -773,6 +773,22 @@ akka { primitive-bytestring = 21 primitive-boolean = 35 } + + } + + serialization.protobuf { + + # Additional classes that are allowed even if they are not defined in `serialization-bindings`. + # It can be exact class name or name of super class or interfaces (one level). + # This is useful when a class is not used for serialization any more and therefore removed + # from `serialization-bindings`, but should still be possible to deserialize. + whitelist-class = [ + "com.google.protobuf.GeneratedMessage", + "com.google.protobuf.GeneratedMessageV3", + "scalapb.GeneratedMessageCompanion", + "akka.protobuf.GeneratedMessage", + "akka.protobufv3.internal.GeneratedMessageV3" + ] } # Used to set the behavior of the scheduler. diff --git a/akka-remote/src/main/scala/akka/remote/serialization/ProtobufSerializer.scala b/akka-remote/src/main/scala/akka/remote/serialization/ProtobufSerializer.scala index 6b87968206..b03e33606b 100644 --- a/akka-remote/src/main/scala/akka/remote/serialization/ProtobufSerializer.scala +++ b/akka-remote/src/main/scala/akka/remote/serialization/ProtobufSerializer.scala @@ -10,8 +10,12 @@ import java.util.concurrent.atomic.AtomicReference import akka.actor.{ ActorRef, ExtendedActorSystem } import akka.remote.WireFormats.ActorRefData import akka.serialization.{ BaseSerializer, Serialization } - import scala.annotation.tailrec +import scala.util.control.NonFatal + +import akka.event.LogMarker +import akka.event.Logging +import akka.serialization.SerializationExtension object ProtobufSerializer { private val ARRAY_OF_BYTE_ARRAY = Array[Class[_]](classOf[Array[Byte]]) @@ -43,6 +47,16 @@ class ProtobufSerializer(val system: ExtendedActorSystem) extends BaseSerializer private val parsingMethodBindingRef = new AtomicReference[Map[Class[_], Method]](Map.empty) private val toByteArrayMethodBindingRef = new AtomicReference[Map[Class[_], Method]](Map.empty) + private val whitelistClassNames: Set[String] = { + import akka.util.ccompat.JavaConverters._ + system.settings.config.getStringList("akka.serialization.protobuf.whitelist-class").asScala.toSet + } + + // This must lazy otherwise it will deadlock the ActorSystem creation + private lazy val serialization = SerializationExtension(system) + + private val log = Logging.withMarker(system, getClass) + override def includeManifest: Boolean = true override def fromBinary(bytes: Array[Byte], manifest: Option[Class[_]]): AnyRef = { @@ -54,6 +68,7 @@ class ProtobufSerializer(val system: ExtendedActorSystem) extends BaseSerializer parsingMethodBinding.get(clazz) match { case Some(cachedParsingMethod) => cachedParsingMethod case None => + checkAllowedClass(clazz) val unCachedParsingMethod = if (method eq null) clazz.getDeclaredMethod("parseFrom", ProtobufSerializer.ARRAY_OF_BYTE_ARRAY: _*) else method @@ -93,4 +108,49 @@ class ProtobufSerializer(val system: ExtendedActorSystem) extends BaseSerializer } toByteArrayMethod().invoke(obj).asInstanceOf[Array[Byte]] } + + private def checkAllowedClass(clazz: Class[_]): Unit = { + if (!isInWhitelist(clazz)) { + val warnMsg = s"Can't deserialize object of type [${clazz.getName}] in [${getClass.getName}]. " + + "Only classes that are whitelisted are allowed for security reasons. " + + "Configure whitelist with akka.actor.serialization-bindings or " + + "akka.serialization.protobuf.whitelist-class" + log.warning(LogMarker.Security, warnMsg) + throw new IllegalArgumentException(warnMsg) + } + } + + /** + * Using the `serialization-bindings` as source for the whitelist. + * Note that the intended usage of serialization-bindings is for lookup of + * serializer when serializing (`toBinary`). For deserialization (`fromBinary`) the serializer-id is + * used for selecting serializer. + * Here we use `serialization-bindings` also when deserializing (fromBinary) + * to check that the manifest class is of a known (registered) type. + * + * If an old class is removed from `serialization-bindings` when it's not used for serialization + * but still used for deserialization (e.g. rolling update with serialization changes) it can + * be allowed by specifying in `akka.protobuf.whitelist-class`. + * + * That is also possible when changing a binding from a ProtobufSerializer to another serializer (e.g. Jackson) + * and still bind with the same class (interface). + */ + private def isInWhitelist(clazz: Class[_]): Boolean = { + isBoundToProtobufSerializer(clazz) || isInWhitelistClassName(clazz) + } + + private def isBoundToProtobufSerializer(clazz: Class[_]): Boolean = { + try { + val boundSerializer = serialization.serializerFor(clazz) + boundSerializer.isInstanceOf[ProtobufSerializer] + } catch { + case NonFatal(_) => false // not bound + } + } + + private def isInWhitelistClassName(clazz: Class[_]): Boolean = { + whitelistClassNames(clazz.getName) || + whitelistClassNames(clazz.getSuperclass.getName) || + clazz.getInterfaces.exists(c => whitelistClassNames(c.getName)) + } } diff --git a/akka-remote/src/test/scala/akka/remote/serialization/ProtobufSerializerSpec.scala b/akka-remote/src/test/scala/akka/remote/serialization/ProtobufSerializerSpec.scala index ed125e7b22..6be42f3da0 100644 --- a/akka-remote/src/test/scala/akka/remote/serialization/ProtobufSerializerSpec.scala +++ b/akka-remote/src/test/scala/akka/remote/serialization/ProtobufSerializerSpec.scala @@ -11,8 +11,51 @@ import akka.remote.ProtobufProtocol.MyMessage import akka.remote.MessageSerializer import akka.actor.ExtendedActorSystem import akka.remote.protobuf.v3.ProtobufProtocolV3.MyMessageV3 +import akka.util.unused -class ProtobufSerializerSpec extends AkkaSpec { +// those must be defined as top level classes, to have static parseFrom +case class MaliciousMessage() {} + +object ProtobufSerializerSpec { + trait AnotherInterface + abstract class AnotherBase +} + +object AnotherMessage { + def parseFrom(@unused bytes: Array[Byte]): AnotherMessage = + new AnotherMessage +} +case class AnotherMessage() {} + +object AnotherMessage2 { + def parseFrom(@unused bytes: Array[Byte]): AnotherMessage2 = + new AnotherMessage2 +} +case class AnotherMessage2() extends ProtobufSerializerSpec.AnotherInterface {} + +object AnotherMessage3 { + def parseFrom(@unused bytes: Array[Byte]): AnotherMessage3 = + new AnotherMessage3 +} +case class AnotherMessage3() extends ProtobufSerializerSpec.AnotherBase {} + +object MaliciousMessage { + def parseFrom(@unused bytes: Array[Byte]): MaliciousMessage = + new MaliciousMessage +} + +class ProtobufSerializerSpec extends AkkaSpec(s""" + akka.serialization.protobuf.whitelist-class = [ + "com.google.protobuf.GeneratedMessage", + "com.google.protobuf.GeneratedMessageV3", + "scalapb.GeneratedMessageCompanion", + "akka.protobuf.GeneratedMessage", + "akka.protobufv3.internal.GeneratedMessageV3", + "${classOf[AnotherMessage].getName}", + "${classOf[ProtobufSerializerSpec.AnotherInterface].getName}", + "${classOf[ProtobufSerializerSpec.AnotherBase].getName}" + ] + """) { val ser = SerializationExtension(system) @@ -44,5 +87,37 @@ class ProtobufSerializerSpec extends AkkaSpec { protobufV3Message should ===(deserialized) } + "disallow deserialization of classes that are not in bindings and not in configured whitelist-class" in { + val originalSerializer = ser.serializerFor(classOf[MyMessage]) + + intercept[IllegalArgumentException] { + ser.deserialize(Array[Byte](), originalSerializer.identifier, classOf[MaliciousMessage].getName).get + } + } + + "allow deserialization of classes in configured whitelist-class" in { + val originalSerializer = ser.serializerFor(classOf[MyMessage]) + + val deserialized = + ser.deserialize(Array[Byte](), originalSerializer.identifier, classOf[AnotherMessage].getName).get + deserialized.getClass should ===(classOf[AnotherMessage]) + } + + "allow deserialization of interfaces in configured whitelist-class" in { + val originalSerializer = ser.serializerFor(classOf[MyMessage]) + + val deserialized = + ser.deserialize(Array[Byte](), originalSerializer.identifier, classOf[AnotherMessage2].getName).get + deserialized.getClass should ===(classOf[AnotherMessage2]) + } + + "allow deserialization of super classes in configured whitelist-class" in { + val originalSerializer = ser.serializerFor(classOf[MyMessage]) + + val deserialized = + ser.deserialize(Array[Byte](), originalSerializer.identifier, classOf[AnotherMessage3].getName).get + deserialized.getClass should ===(classOf[AnotherMessage3]) + } + } } diff --git a/akka-serialization-jackson/src/main/scala/akka/serialization/jackson/JacksonSerializer.scala b/akka-serialization-jackson/src/main/scala/akka/serialization/jackson/JacksonSerializer.scala index e227eedc27..5ffadc2012 100644 --- a/akka-serialization-jackson/src/main/scala/akka/serialization/jackson/JacksonSerializer.scala +++ b/akka-serialization-jackson/src/main/scala/akka/serialization/jackson/JacksonSerializer.scala @@ -320,7 +320,7 @@ import com.fasterxml.jackson.dataformat.cbor.CBORFactory * Here we use `serialization-bindings` also and more importantly when deserializing (fromBinary) * to check that the manifest class is of a known (registered) type. * - * If and old class is removed from `serialization-bindings` when it's not used for serialization + * If an old class is removed from `serialization-bindings` when it's not used for serialization * but still used for deserialization (e.g. rolling update with serialization changes) it can * be allowed by specifying in `whitelist-class-prefix`. *