From 17f712a76beca1fac8fe054b8be07b3e6e5638f6 Mon Sep 17 00:00:00 2001 From: Patrik Nordwall Date: Sat, 11 Nov 2017 10:19:57 +0100 Subject: [PATCH] Pass HandshakeReq in all inbound lanes, #23527 (#23842) * Pass HandshakeReq in all inbound lanes, #23527 The HandshakeReq message must be passed in each inbound lane to ensure that it arrives before any application message. Otherwise there is a risk that an application message arrives in the InboundHandshake stage before the handshake is completed and then it would be dropped. * mima --- .../mima-filters/2.5.6.backwards.excludes | 5 + .../akka/remote/artery/ArteryTransport.scala | 5 +- .../scala/akka/remote/artery/BufferPool.scala | 15 ++- .../scala/akka/remote/artery/Codecs.scala | 66 ++++++++++- .../akka/remote/artery/InboundEnvelope.scala | 21 +++- .../artery/DuplicateHandshakeSpec.scala | 104 ++++++++++++++++++ 6 files changed, 211 insertions(+), 5 deletions(-) create mode 100644 akka-remote/src/main/mima-filters/2.5.6.backwards.excludes create mode 100644 akka-remote/src/test/scala/akka/remote/artery/DuplicateHandshakeSpec.scala diff --git a/akka-remote/src/main/mima-filters/2.5.6.backwards.excludes b/akka-remote/src/main/mima-filters/2.5.6.backwards.excludes new file mode 100644 index 0000000000..76d5479d2f --- /dev/null +++ b/akka-remote/src/main/mima-filters/2.5.6.backwards.excludes @@ -0,0 +1,5 @@ +# #23527 HandshakeReq in inbound lanes +ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.remote.artery.InboundEnvelope.lane") +ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.remote.artery.InboundEnvelope.copyForLane") +ProblemFilters.exclude[DirectMissingMethodProblem]("akka.remote.artery.ReusableInboundEnvelope.init") + diff --git a/akka-remote/src/main/scala/akka/remote/artery/ArteryTransport.scala b/akka-remote/src/main/scala/akka/remote/artery/ArteryTransport.scala index 51ad9e0419..9d5a7dfe84 100644 --- a/akka-remote/src/main/scala/akka/remote/artery/ArteryTransport.scala +++ b/akka-remote/src/main/scala/akka/remote/artery/ArteryTransport.scala @@ -792,6 +792,7 @@ private[remote] class ArteryTransport(_system: ExtendedActorSystem, _provider: R aeronSource(ordinaryStreamId, envelopeBufferPool) .via(hubKillSwitch.flow) .viaMat(inboundFlow(settings, _inboundCompressions))(Keep.both) + .via(Flow.fromGraph(new DuplicateHandshakeReq(inboundLanes, this, system, envelopeBufferPool))) // Select lane based on destination to preserve message order, // Also include the uid of the sending system in the hash to spread @@ -804,7 +805,9 @@ private[remote] class ArteryTransport(_system: ExtendedActorSystem, _provider: R val hashA = 23 + a val hash: Int = 23 * hashA + java.lang.Long.hashCode(b) math.abs(hash) % inboundLanes - case OptionVal.None ⇒ 0 + case OptionVal.None ⇒ + // the lane is set by the DuplicateHandshakeReq stage, otherwise 0 + env.lane } } diff --git a/akka-remote/src/main/scala/akka/remote/artery/BufferPool.scala b/akka-remote/src/main/scala/akka/remote/artery/BufferPool.scala index 4e3460cec5..7f91747efb 100644 --- a/akka-remote/src/main/scala/akka/remote/artery/BufferPool.scala +++ b/akka-remote/src/main/scala/akka/remote/artery/BufferPool.scala @@ -37,7 +37,8 @@ private[remote] class EnvelopeBufferPool(maximumPayload: Int, maximumBuffers: In } } - def release(buffer: EnvelopeBuffer) = if (!availableBuffers.offer(buffer)) buffer.tryCleanDirectByteBuffer() + def release(buffer: EnvelopeBuffer) = + if (buffer.byteBuffer.isDirect && !availableBuffers.offer(buffer)) buffer.tryCleanDirectByteBuffer() } @@ -499,4 +500,16 @@ private[remote] final class EnvelopeBuffer(val byteBuffer: ByteBuffer) { } def tryCleanDirectByteBuffer(): Unit = DirectByteBufferPool.tryCleanDirectByteBuffer(byteBuffer) + + def copy(): EnvelopeBuffer = { + val p = byteBuffer.position() + byteBuffer.rewind() + val bytes = new Array[Byte](byteBuffer.remaining) + byteBuffer.get(bytes) + val newByteBuffer = ByteBuffer.wrap(bytes) + newByteBuffer.position(p) + byteBuffer.position(p) + new EnvelopeBuffer(newByteBuffer) + } + } diff --git a/akka-remote/src/main/scala/akka/remote/artery/Codecs.scala b/akka-remote/src/main/scala/akka/remote/artery/Codecs.scala index 77062bce0d..e65aba091b 100644 --- a/akka-remote/src/main/scala/akka/remote/artery/Codecs.scala +++ b/akka-remote/src/main/scala/akka/remote/artery/Codecs.scala @@ -21,6 +21,8 @@ import akka.util.{ OptionVal, Unsafe } import scala.concurrent.duration._ import scala.concurrent.{ Future, Promise } import scala.util.control.NonFatal +import akka.remote.artery.OutboundHandshake.HandshakeReq +import akka.serialization.SerializerWithStringManifest /** * INTERNAL API @@ -482,7 +484,8 @@ private[remote] class Decoder( classManifest, headerBuilder.flags, envelope, - association) + association, + lane = 0) if (recipient.isEmpty && !headerBuilder.isNoRecipient) { @@ -649,3 +652,64 @@ private[remote] class Deserializer( setHandlers(in, out, this) } } + +/** + * INTERNAL API: The HandshakeReq message must be passed in each inbound lane to + * ensure that it arrives before any application message. Otherwise there is a risk + * that an application message arrives in the InboundHandshake stage before the + * handshake is completed and then it would be dropped. + */ +private[remote] class DuplicateHandshakeReq( + numberOfLanes: Int, + inboundContext: InboundContext, + system: ExtendedActorSystem, + bufferPool: EnvelopeBufferPool) extends GraphStage[FlowShape[InboundEnvelope, InboundEnvelope]] { + + val in: Inlet[InboundEnvelope] = Inlet("Artery.DuplicateHandshakeReq.in") + val out: Outlet[InboundEnvelope] = Outlet("Artery.DuplicateHandshakeReq.out") + val shape: FlowShape[InboundEnvelope, InboundEnvelope] = FlowShape(in, out) + + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) with InHandler with OutHandler { + private val (serializerId, manifest) = { + val serialization = SerializationExtension(system) + val ser = serialization.serializerFor(classOf[HandshakeReq]) + val m = ser match { + case s: SerializerWithStringManifest ⇒ + s.manifest(HandshakeReq(inboundContext.localAddress, inboundContext.localAddress.address)) + case _ ⇒ "" + } + (ser.identifier, m) + } + var currentIterator: Iterator[InboundEnvelope] = Iterator.empty + + override def onPush(): Unit = { + val envelope = grab(in) + if (envelope.association.isEmpty && envelope.serializer == serializerId && envelope.classManifest == manifest) { + // only need to duplicate HandshakeReq before handshake is completed + try { + currentIterator = Vector.tabulate(numberOfLanes)(i ⇒ envelope.copyForLane(i)).iterator + push(out, currentIterator.next()) + } finally { + val buf = envelope.envelopeBuffer + if (buf != null) { + envelope.releaseEnvelopeBuffer() + bufferPool.release(buf) + } + } + } else + push(out, envelope) + } + + override def onPull(): Unit = { + if (currentIterator.isEmpty) + pull(in) + else { + push(out, currentIterator.next()) + if (currentIterator.isEmpty) currentIterator = Iterator.empty // GC friendly + } + } + + setHandlers(in, out, this) + } +} diff --git a/akka-remote/src/main/scala/akka/remote/artery/InboundEnvelope.scala b/akka-remote/src/main/scala/akka/remote/artery/InboundEnvelope.scala index 14f555619a..7c7d564b47 100644 --- a/akka-remote/src/main/scala/akka/remote/artery/InboundEnvelope.scala +++ b/akka-remote/src/main/scala/akka/remote/artery/InboundEnvelope.scala @@ -22,7 +22,7 @@ private[remote] object InboundEnvelope { originUid: Long, association: OptionVal[OutboundContext]): InboundEnvelope = { val env = new ReusableInboundEnvelope - env.init(recipient, sender, originUid, -1, "", 0, null, association) + env.init(recipient, sender, originUid, -1, "", 0, null, association, lane = 0) .withMessage(message) } @@ -50,6 +50,9 @@ private[remote] trait InboundEnvelope { def releaseEnvelopeBuffer(): InboundEnvelope def withRecipient(ref: InternalActorRef): InboundEnvelope + + def lane: Int + def copyForLane(lane: Int): InboundEnvelope } /** @@ -72,6 +75,7 @@ private[remote] final class ReusableInboundEnvelope extends InboundEnvelope { private var _serializer: Int = -1 private var _classManifest: String = null private var _flags: Byte = 0 + private var _lane: Int = 0 private var _message: AnyRef = null private var _envelopeBuffer: EnvelopeBuffer = null @@ -87,6 +91,8 @@ private[remote] final class ReusableInboundEnvelope extends InboundEnvelope { override def flags: Byte = _flags override def flag(byteFlag: ByteFlag): Boolean = byteFlag.isEnabled(_flags) + override def lane: Int = _lane + override def withMessage(message: AnyRef): InboundEnvelope = { _message = message this @@ -108,6 +114,7 @@ private[remote] final class ReusableInboundEnvelope extends InboundEnvelope { _sender = OptionVal.None _originUid = 0L _association = OptionVal.None + _lane = 0 } def init( @@ -118,7 +125,8 @@ private[remote] final class ReusableInboundEnvelope extends InboundEnvelope { classManifest: String, flags: Byte, envelopeBuffer: EnvelopeBuffer, - association: OptionVal[OutboundContext]): InboundEnvelope = { + association: OptionVal[OutboundContext], + lane: Int): InboundEnvelope = { _recipient = recipient _sender = sender _originUid = originUid @@ -127,6 +135,7 @@ private[remote] final class ReusableInboundEnvelope extends InboundEnvelope { _flags = flags _envelopeBuffer = envelopeBuffer _association = association + _lane = lane this } @@ -135,6 +144,14 @@ private[remote] final class ReusableInboundEnvelope extends InboundEnvelope { this } + override def copyForLane(lane: Int): InboundEnvelope = { + val buf = if (envelopeBuffer eq null) null else envelopeBuffer.copy() + val env = new ReusableInboundEnvelope + env.init(recipient, sender, originUid, serializer, classManifest, flags, buf, association, lane) + .withMessage(message) + + } + override def toString: String = s"InboundEnvelope($recipient, $message, $sender, $originUid, $association)" } diff --git a/akka-remote/src/test/scala/akka/remote/artery/DuplicateHandshakeSpec.scala b/akka-remote/src/test/scala/akka/remote/artery/DuplicateHandshakeSpec.scala new file mode 100644 index 0000000000..7d55c7216d --- /dev/null +++ b/akka-remote/src/test/scala/akka/remote/artery/DuplicateHandshakeSpec.scala @@ -0,0 +1,104 @@ +/** + * Copyright (C) 2017 Lightbend Inc. + */ +package akka.remote.artery + +import scala.concurrent.duration._ + +import akka.actor.Address +import akka.actor.ExtendedActorSystem +import akka.remote.UniqueAddress +import akka.remote.artery.OutboundHandshake.HandshakeReq +import akka.stream.ActorMaterializer +import akka.stream.ActorMaterializerSettings +import akka.stream.scaladsl.Keep +import akka.stream.testkit.TestPublisher +import akka.stream.testkit.TestSubscriber +import akka.stream.testkit.scaladsl.TestSink +import akka.stream.testkit.scaladsl.TestSource +import akka.testkit.AkkaSpec +import akka.testkit.ImplicitSender +import akka.util.OptionVal +import akka.serialization.SerializationExtension +import akka.serialization.SerializerWithStringManifest + +class DuplicateHandshakeSpec extends AkkaSpec with ImplicitSender { + + val matSettings = ActorMaterializerSettings(system).withFuzzing(true) + implicit val mat = ActorMaterializer(matSettings)(system) + val pool = new EnvelopeBufferPool(1034 * 1024, 128) + val serialization = SerializationExtension(system) + + val addressA = UniqueAddress(Address("akka", "sysA", "hostA", 1001), 1) + val addressB = UniqueAddress(Address("akka", "sysB", "hostB", 1002), 2) + + private def setupStream(inboundContext: InboundContext, timeout: FiniteDuration = 5.seconds): (TestPublisher.Probe[AnyRef], TestSubscriber.Probe[Any]) = { + TestSource.probe[AnyRef] + .map { msg ⇒ + val association = inboundContext.association(addressA.uid) + val ser = serialization.serializerFor(msg.getClass) + val serializerId = ser.identifier + val manifest = ser match { + case s: SerializerWithStringManifest ⇒ s.manifest(msg) + case _ ⇒ "" + } + + val env = new ReusableInboundEnvelope + env.init(recipient = OptionVal.None, sender = OptionVal.None, originUid = addressA.uid, + serializerId, manifest, flags = 0, envelopeBuffer = null, association, lane = 0) + .withMessage(msg) + env + } + .via(new DuplicateHandshakeReq(numberOfLanes = 3, inboundContext, system.asInstanceOf[ExtendedActorSystem], pool)) + .map { case env: InboundEnvelope ⇒ (env.message -> env.lane) } + .toMat(TestSink.probe[Any])(Keep.both) + .run() + } + + "DuplicateHandshake stage" must { + + "duplicate initial HandshakeReq" in { + val inboundContext = new TestInboundContext(addressB, controlProbe = None) + val (upstream, downstream) = setupStream(inboundContext) + + downstream.request(10) + val req = HandshakeReq(addressA, addressB.address) + upstream.sendNext(req) + upstream.sendNext("msg1") + downstream.expectNext((req, 0)) + downstream.expectNext((req, 1)) + downstream.expectNext((req, 2)) + downstream.expectNext(("msg1", 0)) + upstream.sendNext(req) + downstream.expectNext((req, 0)) + downstream.expectNext((req, 1)) + downstream.expectNext((req, 2)) + downstream.cancel() + } + + "not duplicate after handshake completed" in { + val inboundContext = new TestInboundContext(addressB, controlProbe = None) + val (upstream, downstream) = setupStream(inboundContext) + + downstream.request(10) + val req = HandshakeReq(addressA, addressB.address) + upstream.sendNext(req) + downstream.expectNext((req, 0)) + downstream.expectNext((req, 1)) + downstream.expectNext((req, 2)) + upstream.sendNext("msg1") + downstream.expectNext(("msg1", 0)) + + inboundContext.completeHandshake(addressA) + upstream.sendNext("msg2") + downstream.expectNext(("msg2", 0)) + upstream.sendNext(req) + downstream.expectNext((req, 0)) + upstream.sendNext("msg3") + downstream.expectNext(("msg3", 0)) + downstream.cancel() + } + + } + +}