diff --git a/akka-remote/src/main/scala/akka/remote/Endpoint.scala b/akka-remote/src/main/scala/akka/remote/Endpoint.scala index f133453162..9b870e0f46 100644 --- a/akka-remote/src/main/scala/akka/remote/Endpoint.scala +++ b/akka-remote/src/main/scala/akka/remote/Endpoint.scala @@ -15,8 +15,9 @@ import akka.remote.transport.AssociationHandle._ import akka.remote.transport.{ AkkaPduCodec, Transport, AssociationHandle } import akka.serialization.Serialization import akka.util.ByteString -import scala.util.control.{ NoStackTrace, NonFatal } import akka.remote.transport.Transport.InvalidAssociationException +import java.io.NotSerializableException +import scala.util.control.{ NoStackTrace, NonFatal } /** * INTERNAL API @@ -143,19 +144,42 @@ private[remote] class EndpointAssociationException(msg: String, cause: Throwable /** * INTERNAL API */ -private[remote] class EndpointWriter( - handleOrActive: Option[AssociationHandle], +@SerialVersionUID(1L) +private[remote] class OversizedPayloadException(msg: String) extends EndpointException(msg) + +private[remote] abstract class EndpointActor( val localAddress: Address, val remoteAddress: Address, val transport: Transport, val settings: RemoteSettings, - val codec: AkkaPduCodec) extends Actor with Stash with FSM[EndpointWriter.State, Unit] { + val codec: AkkaPduCodec) extends Actor with ActorLogging { + + def inbound: Boolean + + val eventPublisher = new EventPublisher(context.system, log, settings.LogRemoteLifecycleEvents) + + def publishError(reason: Throwable): Unit = { + try + eventPublisher.notifyListeners(AssociationErrorEvent(reason, localAddress, remoteAddress, inbound)) + catch { case NonFatal(e) ⇒ log.error(e, "Unable to publish error event to EventStream.") } + } +} + +/** + * INTERNAL API + */ +private[remote] class EndpointWriter( + handleOrActive: Option[AssociationHandle], + localAddress: Address, + remoteAddress: Address, + transport: Transport, + settings: RemoteSettings, + codec: AkkaPduCodec) extends EndpointActor(localAddress, remoteAddress, transport, settings, codec) with Stash with FSM[EndpointWriter.State, Unit] { import EndpointWriter._ import context.dispatcher val extendedSystem: ExtendedActorSystem = context.system.asInstanceOf[ExtendedActorSystem] - val eventPublisher = new EventPublisher(context.system, log, settings.LogRemoteLifecycleEvents) var reader: Option[ActorRef] = None var handle: Option[AssociationHandle] = handleOrActive // FIXME: refactor into state data @@ -168,12 +192,15 @@ private[remote] class EndpointWriter( var inbound = handle.isDefined private def publishAndThrow(reason: Throwable): Nothing = { - try - eventPublisher.notifyListeners(AssociationErrorEvent(reason, localAddress, remoteAddress, inbound)) - catch { case NonFatal(e) ⇒ log.error(e, "Unable to publish error event to EventStream.") } + publishError(reason) throw reason } + private def publishAndStay(reason: Throwable): State = { + publishError(reason) + stay() + } + override def postRestart(reason: Throwable): Unit = { handle = None // Wipe out the possibly injected handle inbound = false @@ -226,7 +253,11 @@ private[remote] class EndpointWriter( handle match { case Some(h) ⇒ val pdu = codec.constructMessage(recipient.localAddressToUse, recipient, serializeMessage(msg), senderOption) - if (h.write(pdu)) stay() else { + if (pdu.size > transport.maximumPayloadBytes) { + publishAndStay(new OversizedPayloadException(s"Discarding oversized payload sent to ${recipient}: max allowed size ${transport.maximumPayloadBytes} bytes, actual size of encoded ${msg.getClass} was ${pdu.size} bytes.")) + } else if (h.write(pdu)) { + stay() + } else { stash() goto(Buffering) } @@ -234,11 +265,12 @@ private[remote] class EndpointWriter( throw new EndpointException("Internal error: Endpoint is in state Writing, but no association handle is present.") } } catch { - case NonFatal(e: EndpointException) ⇒ publishAndThrow(e) - case NonFatal(e) ⇒ publishAndThrow(new EndpointException("Failed to write message to the transport", e)) + case e: NotSerializableException ⇒ publishAndStay(e) + case e: EndpointException ⇒ publishAndThrow(e) + case NonFatal(e) ⇒ publishAndThrow(new EndpointException("Failed to write message to the transport", e)) } - // We are in Writing state, so stash is emtpy, safe to stop here + // We are in Writing state, so stash is empty, safe to stop here case Event(FlushAndStop, _) ⇒ stop() } @@ -290,19 +322,15 @@ private[remote] class EndpointWriter( } private def startReadEndpoint(handle: AssociationHandle): Some[ActorRef] = { - val readerLocalAddress = handle.localAddress - val readerCodec = codec - val readerDispatcher = msgDispatch val newReader = - context.watch(context.actorOf(Props(new EndpointReader(readerCodec, readerLocalAddress, readerDispatcher)), + context.watch(context.actorOf( + Props(new EndpointReader(localAddress, remoteAddress, transport, settings, codec, msgDispatch, inbound)), "endpointReader-" + AddressUrlEncoder(remoteAddress) + "-" + readerId.next())) handle.readHandlerPromise.success(ActorHandleEventListener(newReader)) Some(newReader) } private def serializeMessage(msg: Any): MessageProtocol = handle match { - // FIXME: Unserializable messages should be dropped without closing the association. Should be logged, - // but without flooding the log. case Some(h) ⇒ Serialization.currentTransportAddress.withValue(h.localAddress) { (MessageSerializer.serialize(extendedSystem, msg.asInstanceOf[AnyRef])) @@ -317,9 +345,13 @@ private[remote] class EndpointWriter( * INTERNAL API */ private[remote] class EndpointReader( - val codec: AkkaPduCodec, - val localAddress: Address, - val msgDispatch: InboundMessageDispatcher) extends Actor { + localAddress: Address, + remoteAddress: Address, + transport: Transport, + settings: RemoteSettings, + codec: AkkaPduCodec, + msgDispatch: InboundMessageDispatcher, + val inbound: Boolean) extends EndpointActor(localAddress, remoteAddress, transport, settings, codec) { val provider = RARP(context.system).provider @@ -327,8 +359,12 @@ private[remote] class EndpointReader( case Disassociated ⇒ context.stop(self) case InboundPayload(p) ⇒ - val msg = decodePdu(p) - msgDispatch.dispatch(msg.recipient, msg.recipientAddress, msg.serializedMessage, msg.senderOption) + if (p.size > transport.maximumPayloadBytes) { + publishError(new OversizedPayloadException(s"Discarding oversized payload received: max allowed size ${transport.maximumPayloadBytes} bytes, actual size ${p.size} bytes.")) + } else { + val msg = decodePdu(p) + msgDispatch.dispatch(msg.recipient, msg.recipientAddress, msg.serializedMessage, msg.senderOption) + } } private def decodePdu(pdu: ByteString): Message = try { diff --git a/akka-remote/src/test/scala/akka/remote/RemotingSpec.scala b/akka-remote/src/test/scala/akka/remote/RemotingSpec.scala index 98b4cd455c..37be9b5553 100644 --- a/akka-remote/src/test/scala/akka/remote/RemotingSpec.scala +++ b/akka-remote/src/test/scala/akka/remote/RemotingSpec.scala @@ -5,12 +5,14 @@ package akka.remote import akka.actor._ import akka.pattern.ask +import akka.remote.transport.AssociationRegistry import akka.testkit._ +import akka.util.ByteString import com.typesafe.config._ +import java.io.NotSerializableException import scala.concurrent.Await import scala.concurrent.Future import scala.concurrent.duration._ -import akka.remote.transport.AssociationRegistry object RemotingSpec { class Echo1 extends Actor { @@ -115,8 +117,9 @@ class RemotingSpec extends AkkaSpec(RemotingSpec.cfg) with ImplicitSender with D val conf = ConfigFactory.parseString( """ - akka.remote { - test.local-address = "test://remote-sys@localhost:12346" + akka.remote.test { + local-address = "test://remote-sys@localhost:12346" + maximum-payload-bytes = 48000 bytes } """).withFallback(system.settings.config).resolve() val otherSystem = ActorSystem("remote-sys", conf) @@ -139,6 +142,38 @@ class RemotingSpec extends AkkaSpec(RemotingSpec.cfg) with ImplicitSender with D val here = system.actorFor("akka.test://remote-sys@localhost:12346/user/echo") + private def verifySend(msg: Any)(afterSend: ⇒ Unit) { + val bigBounceOther = otherSystem.actorOf(Props(new Actor { + def receive = { + case x: Int ⇒ sender ! byteStringOfSize(x) + case x ⇒ sender ! x + } + }), "bigBounce") + val bigBounceHere = system.actorFor("akka.test://remote-sys@localhost:12346/user/bigBounce") + + val eventForwarder = system.actorOf(Props(new Actor { + def receive = { + case x ⇒ testActor ! x + } + })) + system.eventStream.subscribe(eventForwarder, classOf[AssociationErrorEvent]) + system.eventStream.subscribe(eventForwarder, classOf[DisassociatedEvent]) + try { + bigBounceHere ! msg + afterSend + expectNoMsg(500.millis.dilated) + } finally { + system.eventStream.unsubscribe(eventForwarder, classOf[AssociationErrorEvent]) + system.eventStream.unsubscribe(eventForwarder, classOf[DisassociatedEvent]) + system.stop(eventForwarder) + otherSystem.stop(bigBounceOther) + } + } + + private def byteStringOfSize(size: Int) = ByteString.fromArray(Array.fill(size)(42: Byte)) + + val maxPayloadBytes = system.settings.config.getBytes("akka.remote.test.maximum-payload-bytes").toInt + override def afterTermination() { otherSystem.shutdown() AssociationRegistry.clear() @@ -345,6 +380,41 @@ class RemotingSpec extends AkkaSpec(RemotingSpec.cfg) with ImplicitSender with D expectMsg("postStop") } + "drop unserializable messages" in { + object Unserializable + verifySend(Unserializable) { + expectMsgPF(1.second) { + case AssociationErrorEvent(_: NotSerializableException, _, _, _) ⇒ () + } + } + } + + "allow messages up to payload size" in { + val maxProtocolOverhead = 500 // Make sure we're still under size after the message is serialized, etc + val big = byteStringOfSize(maxPayloadBytes - maxProtocolOverhead) + verifySend(big) { + expectMsg(1.second, big) + } + } + + "drop sent messages over payload size" in { + val oversized = byteStringOfSize(maxPayloadBytes + 1) + verifySend(oversized) { + expectMsgPF(1.second) { + case AssociationErrorEvent(e: OversizedPayloadException, _, _, _) if e.getMessage.startsWith("Discarding oversized payload sent") ⇒ () + } + } + } + + "drop received messages over payload size" in { + // Receiver should reply with a message of size maxPayload + 1, which will be dropped and an error logged + verifySend(maxPayloadBytes + 1) { + expectMsgPF(1.second) { + case AssociationErrorEvent(e: OversizedPayloadException, _, _, _) if e.getMessage.startsWith("Discarding oversized payload received") ⇒ () + } + } + } + } override def beforeTermination() {