diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/StreamRefsSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/StreamRefsSpec.scala index cf3620147e..003f487a4a 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/StreamRefsSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/StreamRefsSpec.scala @@ -433,6 +433,41 @@ class StreamRefsSpec extends AkkaSpec(StreamRefsSpec.config()) { } } + "not die to a slow and eager subscriber" in { + import akka.stream.impl.streamref.StreamRefsProtocol._ + + // GIVEN: remoteActor delivers 2 elements "hello", "world" + val remoteProbe = TestProbe()(remoteSystem) + remoteActor.tell("give", remoteProbe.ref) + val sourceRefImpl = remoteProbe.expectMsgType[SourceRefImpl[String]] + + val sourceRefStageProbe = TestProbe("sourceRefStageProbe") + + // WHEN: SourceRefStage sends a first CumulativeDemand with enough demand to consume the whole stream + sourceRefStageProbe.send(sourceRefImpl.initialPartnerRef, CumulativeDemand(10)) + + // THEN: stream established with OnSubscribeHandshake + val onSubscribeHandshake = sourceRefStageProbe.expectMsgType[OnSubscribeHandshake] + val sinkRefStageActorRef = watch(onSubscribeHandshake.targetRef) + + // THEN: all elements are streamed to SourceRefStage + sourceRefStageProbe.expectMsg(SequencedOnNext(0, "hello")) + sourceRefStageProbe.expectMsg(SequencedOnNext(1, "world")) + sourceRefStageProbe.expectMsg(RemoteStreamCompleted(2)) + + // WHEN: SinkRefStage receives another CumulativeDemand, due to latency in network or slowness of sourceRefStage + sourceRefStageProbe.send(sinkRefStageActorRef, CumulativeDemand(10)) + + // THEN: SinkRefStage should not terminate + expectNoMessage() + + // WHEN: SourceRefStage terminates + system.stop(sourceRefStageProbe.ref) + + // THEN: SinkRefStage should terminate + expectTerminated(sinkRefStageActorRef) + } + } "A SinkRef" must { diff --git a/akka-stream/src/main/scala/akka/stream/impl/streamref/SinkRefImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/streamref/SinkRefImpl.scala index 5e79888820..297ff5a1af 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/streamref/SinkRefImpl.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/streamref/SinkRefImpl.scala @@ -196,7 +196,7 @@ private[stream] final class SinkRefStageImpl[In] private[akka] (val initialPartn } private def tryPull(): Unit = - if (remoteCumulativeDemandConsumed < remoteCumulativeDemandReceived && !hasBeenPulled(in)) { + if (remoteCumulativeDemandConsumed < remoteCumulativeDemandReceived && !hasBeenPulled(in) && !isClosed(in)) { pull(in) }