diff --git a/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/HubSpec.scala b/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/HubSpec.scala index 0a334c3729..e1a8ef77f3 100644 --- a/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/HubSpec.scala +++ b/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/HubSpec.scala @@ -561,6 +561,48 @@ class HubSpec extends StreamSpec { out.expectComplete() } + "handle unregistration concurrent with registration" in { + + var sinkProbe1: TestSubscriber.Probe[Int] = null + + def registerConsumerCallback(id: Long): Unit = { + if (id == 1) { + sinkProbe1.cancel() + } + } + + val in = TestPublisher.probe[Int]() + val hubSource = Source + .fromPublisher(in) + .runWith(Sink.fromGraph(new BroadcastHub[Int](0, 2, registerConsumerCallback))) + + // Put one element into the buffer + in.sendNext(15) + + // add a consumer to receive the first element + val sinkProbe0 = hubSource.runWith(TestSink.probe[Int]) + sinkProbe0.request(1) + sinkProbe0.expectNext(15) + sinkProbe0.cancel() + + // put more elements into the buffer + in.sendNext(16) + in.sendNext(17) + in.sendNext(18) + + // Add another consumer and kill it during registration + + sinkProbe1 = hubSource.runWith(TestSink.probe[Int]) + Thread.sleep(100) + + // Make sure that the element 16 isn't lost by reading it with a third consumer + val sinkProbe2 = hubSource.runWith(TestSink.probe[Int]) + sinkProbe2.request(1) + sinkProbe2.expectNext(16) + + in.sendComplete() + sinkProbe2.cancel() + } } "PartitionHub" must { diff --git a/stream/src/main/scala/org/apache/pekko/stream/scaladsl/Hub.scala b/stream/src/main/scala/org/apache/pekko/stream/scaladsl/Hub.scala index 23be24d7f4..a1043040ef 100644 --- a/stream/src/main/scala/org/apache/pekko/stream/scaladsl/Hub.scala +++ b/stream/src/main/scala/org/apache/pekko/stream/scaladsl/Hub.scala @@ -18,19 +18,18 @@ import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.{ AtomicLong, AtomicReference } import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.atomic.AtomicReferenceArray - import scala.annotation.tailrec import scala.collection.immutable import scala.collection.immutable.Queue import scala.collection.mutable.LongMap import scala.concurrent.{ Future, Promise } import scala.util.{ Failure, Success, Try } - import org.apache.pekko import pekko.NotUsed import pekko.annotation.DoNotInherit import pekko.annotation.InternalApi import pekko.dispatch.AbstractNodeQueue +import pekko.util.ConstantFun import pekko.stream._ import pekko.stream.Attributes.LogLevels import pekko.stream.impl.ActorPublisher @@ -475,13 +474,19 @@ object BroadcastHub { /** * INTERNAL API + * + * @param registrationPendingCallback Called during the `RegistrationPending` event of a consumer with the consumer's internal ID. + * This is useful for controlling the interleaving in tests. */ -private[pekko] class BroadcastHub[T](startAfterNrOfConsumers: Int, bufferSize: Int) +private[pekko] class BroadcastHub[T](startAfterNrOfConsumers: Int, bufferSize: Int, + registrationPendingCallback: Long => Unit) extends GraphStageWithMaterializedValue[SinkShape[T], Source[T, NotUsed]] { require(startAfterNrOfConsumers >= 0, "startAfterNrOfConsumers must >= 0") require(bufferSize > 0, "Buffer size must be positive") require(bufferSize < 4096, "Buffer size larger then 4095 is not allowed") require((bufferSize & bufferSize - 1) == 0, "Buffer size must be a power of two") + def this(startAfterNrOfConsumers: Int, bufferSize: Int) = + this(startAfterNrOfConsumers, bufferSize, ConstantFun.scalaAnyToUnit) def this(bufferSize: Int) = this(0, bufferSize) private val Mask = bufferSize - 1 @@ -585,6 +590,8 @@ private[pekko] class BroadcastHub[T](startAfterNrOfConsumers: Int, bufferSize: I val startFrom = head activeConsumers += 1 addConsumer(consumer, startFrom) + // add a callback hook so that we can control the interleaving in tests + registrationPendingCallback(consumer.id) // in case the consumer is already stopped we need to undo registration implicit val ec = materializer.executionContext consumer.callback.invokeWithFeedback(Initialize(startFrom)).failed.foreach { @@ -854,7 +861,12 @@ private[pekko] class BroadcastHub[T](startAfterNrOfConsumers: Int, bufferSize: I } override def postStop(): Unit = { - if (hubCallback ne null) + // If `postStop` is called before the consumer has processed the `RegistrationPending`'s `Initialize` event, + // then the `Initialize` message will fail with a `StreamDetachedException`, + // upon which the `RegistrationPending` logic itself unregisters this consumer. + // In particular, this client must not send the `Unregister` event itself because the values in + // `previousPublishedOffset` and `offset` are wrong. + if ((hubCallback ne null) && offsetInitialized) hubCallback.invoke(UnRegister(id, previousPublishedOffset, offset)) }