diff --git a/akka-remote-tests/src/multi-jvm/scala/akka/remote/artery/SurviveInboundStreamRestartWithCompressionInFlightSpec.scala b/akka-remote-tests/src/multi-jvm/scala/akka/remote/artery/SurviveInboundStreamRestartWithCompressionInFlightSpec.scala
index d377b3f118..ae0da300d0 100644
--- a/akka-remote-tests/src/multi-jvm/scala/akka/remote/artery/SurviveInboundStreamRestartWithCompressionInFlightSpec.scala
+++ b/akka-remote-tests/src/multi-jvm/scala/akka/remote/artery/SurviveInboundStreamRestartWithCompressionInFlightSpec.scala
@@ -32,7 +32,8 @@ object SurviveInboundStreamRestartWithCompressionInFlightSpec extends MultiNodeC
       akka.loglevel = INFO
       akka.remote.artery {
         enabled = on
-        advanced {
+        advanced {
+          inbound-lanes = 4
           give-up-system-message-after = 4s
           compression.actor-refs.advertisement-interval = 300ms
           compression.manifests.advertisement-interval = 1 minute
@@ -118,15 +119,20 @@ abstract class SurviveInboundStreamRestartWithCompressionInFlightSpec extends Re
     }
     enterBarrier("inbound-failure-restart-first")
-    // we poke the remote system, awaiting its inbound stream recovery, when it should reply
-    awaitAssert(
-      {
-        sendToB ! "alive-again"
-        expectMsg(300.millis, s"${sendToB.path.name}-alive-again")
-      },
-      max = 5.seconds, interval = 500.millis)
-
     runOn(second) {
+      sendToB.tell("trigger", ActorRef.noSender)
+      // when using inbound-lanes > 1 we can't be sure when the restart is done; another message
+      // (e.g. HandshakeReq) might have triggered the restart
+      Thread.sleep(2000)
+
+      // we poke the remote system, awaiting its inbound stream recovery; once recovered it should reply
+      awaitAssert(
+        {
+          sendToB ! "alive-again"
+          expectMsg(300.millis, s"${sendToB.path.name}-alive-again")
+        },
+        max = 5.seconds, interval = 500.millis)
+      // we continue sending messages using the "old table".
       // if a new table was being built, it would cause the b to be compressed as 1 causing a wrong reply to come back
       1 to 100 foreach { i ⇒
         pingPong(sendToB, s"b$i")
       }
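The test above now runs the inbound side with four lanes. For orientation, here is a minimal sketch of enabling the same Artery setting on a plain ActorSystem; the system name, provider, hostname and port below are illustrative assumptions, not part of this patch:

```scala
import akka.actor.ActorSystem
import com.typesafe.config.ConfigFactory

object InboundLanesConfigSketch extends App {
  // same `advanced.inbound-lanes` knob as in the spec config above;
  // provider/hostname/port are illustrative only
  val config = ConfigFactory.parseString("""
    akka.actor.provider = remote
    akka.remote.artery {
      enabled = on
      canonical.hostname = "127.0.0.1"
      canonical.port = 25520
      advanced {
        inbound-lanes = 4
      }
    }
  """)

  val system = ActorSystem("InboundLanesSketch", config)
}
```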
diff --git a/akka-remote/src/main/scala/akka/remote/artery/ArteryTransport.scala b/akka-remote/src/main/scala/akka/remote/artery/ArteryTransport.scala
index 9d5a7dfe84..26f6550927 100644
--- a/akka-remote/src/main/scala/akka/remote/artery/ArteryTransport.scala
+++ b/akka-remote/src/main/scala/akka/remote/artery/ArteryTransport.scala
@@ -824,13 +824,14 @@ private[remote] class ArteryTransport(_system: ExtendedActorSystem, _provider: R
       }(collection.breakOut)

       import system.dispatcher
-      val completed = Future.sequence(completedValues).map(_ ⇒ Done)

       // tear down the upstream hub part if downstream lane fails
       // lanes are not completed with success by themselves so we don't have to care about onSuccess
-      completed.failed.foreach { reason ⇒ hubKillSwitch.abort(reason) }
+      Future.firstCompletedOf(completedValues).failed.foreach { reason ⇒ hubKillSwitch.abort(reason) }

-      (resourceLife, compressionAccess, completed)
+      val allCompleted = Future.sequence(completedValues).map(_ ⇒ Done)
+
+      (resourceLife, compressionAccess, allCompleted)
     }

     _inboundCompressionAccess = OptionVal(inboundCompressionAccesses)
diff --git a/akka-remote/src/main/scala/akka/remote/artery/Association.scala b/akka-remote/src/main/scala/akka/remote/artery/Association.scala
index c3d39c41d8..a216ab5ec9 100644
--- a/akka-remote/src/main/scala/akka/remote/artery/Association.scala
+++ b/akka-remote/src/main/scala/akka/remote/artery/Association.scala
@@ -613,14 +613,15 @@ private[remote] class Association(
       val (queueValues, compressionAccessValues, laneCompletedValues) = values.unzip3

       import transport.system.dispatcher
-      val completed = Future.sequence(laneCompletedValues).flatMap(_ ⇒ aeronSinkCompleted)

       // tear down all parts if one part fails or completes
-      completed.failed.foreach {
-        reason ⇒ streamKillSwitch.abort(reason)
+      Future.firstCompletedOf(laneCompletedValues).failed.foreach { reason ⇒
+        streamKillSwitch.abort(reason)
       }
       (laneCompletedValues :+ aeronSinkCompleted).foreach(_.foreach { _ ⇒ streamKillSwitch.shutdown() })

+      val allCompleted = Future.sequence(laneCompletedValues).flatMap(_ ⇒ aeronSinkCompleted)
+
       queueValues.zip(wrappers).zipWithIndex.foreach {
         case ((q, w), i) ⇒
           q.inject(w.queue)
@@ -631,7 +632,7 @@ private[remote] class Association(
       outboundCompressionAccess = compressionAccessValues

       attachStreamRestart("Outbound message stream", OrdinaryQueueIndex, queueSize,
-        completed, () ⇒ runOutboundOrdinaryMessagesStream())
+        allCompleted, () ⇒ runOutboundOrdinaryMessagesStream())
     }
   }
diff --git a/akka-remote/src/main/scala/akka/remote/artery/FixedSizePartitionHub.scala b/akka-remote/src/main/scala/akka/remote/artery/FixedSizePartitionHub.scala
index 7426b58db4..b5cd20fe91 100644
--- a/akka-remote/src/main/scala/akka/remote/artery/FixedSizePartitionHub.scala
+++ b/akka-remote/src/main/scala/akka/remote/artery/FixedSizePartitionHub.scala
@@ -15,7 +15,11 @@ import org.agrona.concurrent.ManyToManyConcurrentArrayQueue
 @InternalApi private[akka] class FixedSizePartitionHub[T](
   partitioner: T ⇒ Int,
   lanes: Int,
-  bufferSize: Int) extends PartitionHub[T](() ⇒ (info, elem) ⇒ info.consumerIdByIdx(partitioner(elem)), lanes, bufferSize - 1) {
+  bufferSize: Int) extends PartitionHub[T](
+  // during tear down or restart it's possible that some consumer streams have been removed,
+  // and then we must drop elements (return -1)
+  () ⇒ (info, elem) ⇒ if (info.size < lanes) -1 else info.consumerIdByIdx(partitioner(elem)),
+  lanes, bufferSize - 1) {

   // -1 because of the Completed token
   override def createQueue(): PartitionHub.Internal.PartitionQueue =
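The two transport changes above share one pattern: the per-lane completion futures are raced with `Future.firstCompletedOf` so the first lane failure can abort a shared kill switch, while `Future.sequence` still provides the aggregate completion that the restart logic watches. A standalone sketch of that pattern follows; the object name and the stand-in lane streams are illustrative, not taken from the patch:

```scala
import akka.Done
import akka.actor.ActorSystem
import akka.stream.{ ActorMaterializer, KillSwitches }
import akka.stream.scaladsl.{ Sink, Source }

import scala.concurrent.Future

object LaneTeardownSketch extends App {
  implicit val system = ActorSystem("lane-sketch")
  implicit val mat = ActorMaterializer()
  import system.dispatcher

  val killSwitch = KillSwitches.shared("inbound-lanes")

  // stand-ins for the per-lane streams, all guarded by the same kill switch
  val laneCompleted: Seq[Future[Done]] =
    (1 to 4).map { lane ⇒
      Source.repeat(lane).via(killSwitch.flow).runWith(Sink.ignore)
    }

  // the first lane that fails tears down all sibling lanes via the kill switch;
  // lanes don't complete successfully on their own, so only `failed` matters here
  Future.firstCompletedOf(laneCompleted).failed.foreach { reason ⇒
    killSwitch.abort(reason)
  }

  // the aggregate completion is what a restart mechanism would watch
  val allCompleted: Future[Done] = Future.sequence(laneCompleted).map(_ ⇒ Done)
}
```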
diff --git a/akka-remote/src/main/scala/akka/remote/artery/TestStage.scala b/akka-remote/src/main/scala/akka/remote/artery/TestStage.scala
index a0889011a1..84e30ff01d 100644
--- a/akka-remote/src/main/scala/akka/remote/artery/TestStage.scala
+++ b/akka-remote/src/main/scala/akka/remote/artery/TestStage.scala
@@ -161,7 +161,9 @@ private[remote] class InboundTestStage(inboundContext: InboundContext, state: Sh
       // InHandler
       override def onPush(): Unit = {
         state.getInboundFailureOnce match {
-          case Some(shouldFailEx) ⇒ failStage(shouldFailEx)
+          case Some(shouldFailEx) ⇒
+            log.info("Fail inbound stream from [{}]: {}", classOf[InboundTestStage].getName, shouldFailEx.getMessage)
+            failStage(shouldFailEx)
           case _ ⇒
             val env = grab(in)
             env.association match {
diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/HubSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/HubSpec.scala
index 9897e3a2a4..0df497012e 100644
--- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/HubSpec.scala
+++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/HubSpec.scala
@@ -602,6 +602,15 @@ class HubSpec extends StreamSpec {
       }
     }

+    "drop elements with negative index" in assertAllStagesStopped {
+      val source = Source(0 until 10).runWith(PartitionHub.sink(
+        (size, elem) ⇒ if (elem == 3 || elem == 4) -1 else elem % size, startAfterNrOfConsumers = 2, bufferSize = 8))
+      val result1 = source.runWith(Sink.seq)
+      val result2 = source.runWith(Sink.seq)
+      result1.futureValue should ===((0 to 8 by 2).filterNot(_ == 4))
+      result2.futureValue should ===((1 to 9 by 2).filterNot(_ == 3))
+    }
+
   }
 }
diff --git a/akka-stream/src/main/scala/akka/stream/javadsl/Hub.scala b/akka-stream/src/main/scala/akka/stream/javadsl/Hub.scala
index c183381156..eda0e6034a 100644
--- a/akka-stream/src/main/scala/akka/stream/javadsl/Hub.scala
+++ b/akka-stream/src/main/scala/akka/stream/javadsl/Hub.scala
@@ -173,7 +173,8 @@ object PartitionHub {
    * @param partitioner Function that decides where to route an element. The function takes two parameters;
    *   the first is the number of active consumers and the second is the stream element. The function should
    *   return the index of the selected consumer for the given element, i.e. int greater than or equal to 0
-   *   and less than number of consumers. E.g. `(size, elem) -> Math.abs(elem.hashCode()) % size`.
+   *   and less than number of consumers. E.g. `(size, elem) -> Math.abs(elem.hashCode()) % size`. It's also
+   *   possible to use `-1` to drop the element.
    * @param startAfterNrOfConsumers Elements are buffered until this number of consumers have been connected.
    *   This is only used initially when the stage is starting up, i.e. it is not honored when consumers have
    *   been removed (canceled).
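The new `-1` contract documented above is exactly what the HubSpec test exercises. A minimal self-contained example in the same spirit, with an illustrative object name and a single consumer rather than the test's two:

```scala
import akka.actor.ActorSystem
import akka.stream.ActorMaterializer
import akka.stream.scaladsl.{ PartitionHub, Sink, Source }

object DropWithNegativeIndex extends App {
  implicit val system = ActorSystem("drop-sketch")
  implicit val mat = ActorMaterializer()
  import system.dispatcher

  // odd elements are dropped by returning -1; evens go to the only consumer
  val source = Source(1 to 10).runWith(PartitionHub.sink[Int](
    (size, elem) ⇒ if (elem % 2 == 1) -1 else elem % size,
    startAfterNrOfConsumers = 1, bufferSize = 8))

  source.runWith(Sink.seq).foreach { elems ⇒
    println(elems) // Vector(2, 4, 6, 8, 10)
    system.terminate()
  }
}
```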
diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Hub.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Hub.scala
index 0d0f13a480..96edf32dc2 100644
--- a/akka-stream/src/main/scala/akka/stream/scaladsl/Hub.scala
+++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Hub.scala
@@ -776,7 +776,8 @@ object PartitionHub {
    * @param partitioner Function that decides where to route an element. The function takes two parameters;
    *   the first is the number of active consumers and the second is the stream element. The function should
    *   return the index of the selected consumer for the given element, i.e. int greater than or equal to 0
-   *   and less than number of consumers. E.g. `(size, elem) => math.abs(elem.hashCode) % size`.
+   *   and less than number of consumers. E.g. `(size, elem) => math.abs(elem.hashCode) % size`. It's also
+   *   possible to use `-1` to drop the element.
    * @param startAfterNrOfConsumers Elements are buffered until this number of consumers have been connected.
    *   This is only used initially when the stage is starting up, i.e. it is not honored when consumers have
    *   been removed (canceled).
@@ -785,8 +786,14 @@ object PartitionHub {
    */
   @ApiMayChange
   def sink[T](partitioner: (Int, T) ⇒ Int, startAfterNrOfConsumers: Int,
-              bufferSize: Int = defaultBufferSize): Sink[T, Source[T, NotUsed]] =
-    statefulSink(() ⇒ (info, elem) ⇒ info.consumerIdByIdx(partitioner(info.size, elem)), startAfterNrOfConsumers, bufferSize)
+              bufferSize: Int = defaultBufferSize): Sink[T, Source[T, NotUsed]] = {
+    val fun: (ConsumerInfo, T) ⇒ Long = { (info, elem) ⇒
+      val idx = partitioner(info.size, elem)
+      if (idx < 0) -1L
+      else info.consumerIdByIdx(idx)
+    }
+    statefulSink(() ⇒ fun, startAfterNrOfConsumers, bufferSize)
+  }

   @DoNotInherit @ApiMayChange trait ConsumerInfo extends akka.stream.javadsl.PartitionHub.ConsumerInfo {
@@ -1051,8 +1058,10 @@ object PartitionHub {
           pending :+= elem
         } else {
           val id = materializedPartitioner(consumerInfo, elem)
-          queue.offer(id, elem)
-          wakeup(id)
+          if (id >= 0) { // negative id is a way to drop the element
+            queue.offer(id, elem)
+            wakeup(id)
+          }
         }
       }
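At the `statefulSink` level the partitioner returns a `Long` consumer id, and the patch makes any negative id drop the element; the internal `FixedSizePartitionHub` change earlier relies on exactly that. A rough public-API sketch of the same guard, where `lanes` and the hash-based routing are assumptions for illustration rather than code from this patch:

```scala
import akka.NotUsed
import akka.stream.scaladsl.{ PartitionHub, Sink, Source }

object FixedLanesGuardSketch {
  val lanes = 4

  // mirrors the guard added to FixedSizePartitionHub: during tear down or
  // restart fewer than `lanes` consumers may be attached, and routing by
  // index could hit a missing consumer, so the element is dropped (-1L)
  val sink: Sink[String, Source[String, NotUsed]] =
    PartitionHub.statefulSink[String](
      () ⇒ (info, elem) ⇒
        if (info.size < lanes) -1L
        else info.consumerIdByIdx(math.abs(elem.hashCode) % lanes),
      startAfterNrOfConsumers = lanes)
}
```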