diff --git a/akka-stream-tests/src/test/scala/akka/stream/io/OutputStreamSourceSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/io/OutputStreamSourceSpec.scala index 806f31be51..4b3f8c4c8e 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/io/OutputStreamSourceSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/io/OutputStreamSourceSpec.scala @@ -6,24 +6,29 @@ package akka.stream.io import java.io.IOException import java.lang.management.ManagementFactory import java.util.concurrent.TimeoutException + import akka.actor.ActorSystem -import akka.stream._ import akka.stream.Attributes.inputBuffer +import akka.stream._ +import akka.stream.impl.ActorMaterializerImpl +import akka.stream.impl.StreamSupervisor import akka.stream.impl.StreamSupervisor.Children import akka.stream.impl.io.OutputStreamSourceStage -import akka.stream.impl.{ ActorMaterializerImpl, StreamSupervisor } -import akka.stream.scaladsl.{ Keep, StreamConverters, Sink } +import akka.stream.scaladsl.{ Keep, Sink, Source, StreamConverters } import akka.stream.testkit.Utils._ import akka.stream.testkit._ import akka.stream.testkit.scaladsl.TestSink import akka.testkit.TestProbe import akka.util.ByteString + +import scala.concurrent.Await +import scala.concurrent.Future import scala.concurrent.duration.Duration.Zero import scala.concurrent.duration._ -import scala.concurrent.{ Await, Future } import scala.util.Random class OutputStreamSourceSpec extends StreamSpec(UnboundedMailboxConfig) { + import system.dispatcher val settings = ActorMaterializerSettings(system).withDispatcher("akka.actor.default-dispatcher") @@ -204,5 +209,27 @@ class OutputStreamSourceSpec extends StreamSpec(UnboundedMailboxConfig) { assertNoBlockedThreads() } + + "correctly complete the stage after close" in assertAllStagesStopped { + // actually this was a race, so it only happened in at least one of 20 runs + + val bufSize = 4 + val sourceProbe = TestProbe() + val (outputStream, probe) = StreamConverters.asOutputStream(timeout) + .addAttributes(Attributes.inputBuffer(bufSize, bufSize)) + .toMat(TestSink.probe[ByteString])(Keep.both).run + + // fill the buffer up + (1 to (bufSize - 1)).foreach(outputStream.write) + Future { + outputStream.close() + } + // here is the race, has the elements reached the stage buffer yet? + Thread.sleep(500) + probe.request(bufSize - 1) + probe.expectNextN(bufSize - 1) + probe.expectComplete() + } + } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/OutputStreamSourceStage.scala b/akka-stream/src/main/scala/akka/stream/impl/io/OutputStreamSourceStage.scala index 21832b4832..fbae436eff 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/OutputStreamSourceStage.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/OutputStreamSourceStage.scala @@ -32,9 +32,6 @@ private[stream] object OutputStreamSourceStage { case object Ok extends DownstreamStatus case object Canceled extends DownstreamStatus - sealed trait StageWithCallback { - def wakeUp(msg: AdapterToStageMessage): Future[Unit] - } } final private[stream] class OutputStreamSourceStage(writeTimeout: FiniteDuration) extends GraphStageWithMaterializedValue[SourceShape[ByteString], OutputStream] { @@ -52,7 +49,7 @@ final private[stream] class OutputStreamSourceStage(writeTimeout: FiniteDuration val dataQueue = new LinkedBlockingQueue[ByteString](maxBuffer) val downstreamStatus = new AtomicReference[DownstreamStatus](Ok) - val logic = new GraphStageLogic(shape) with StageWithCallback { + val logic = new GraphStageLogic(shape) with CallbackWrapper[(AdapterToStageMessage, Promise[Unit])] { var flush: Option[Promise[Unit]] = None var close: Option[Promise[Unit]] = None @@ -68,9 +65,9 @@ final private[stream] class OutputStreamSourceStage(writeTimeout: FiniteDuration private val upstreamCallback: AsyncCallback[(AdapterToStageMessage, Promise[Unit])] = getAsyncCallback(onAsyncMessage) - override def wakeUp(msg: AdapterToStageMessage): Future[Unit] = { + def wakeUp(msg: AdapterToStageMessage): Future[Unit] = { val p = Promise[Unit]() - upstreamCallback.invoke((msg, p)) + this.invoke((msg, p)) p.future } @@ -81,11 +78,7 @@ final private[stream] class OutputStreamSourceStage(writeTimeout: FiniteDuration sendResponseIfNeed() case Close ⇒ close = Some(event._2) - if (dataQueue.isEmpty) { - downstreamStatus.set(Canceled) - completeStage() - unblockUpstream() - } else sendResponseIfNeed() + sendResponseIfNeed() } private def unblockUpstream(): Boolean = @@ -96,8 +89,10 @@ final private[stream] class OutputStreamSourceStage(writeTimeout: FiniteDuration true case None ⇒ close match { case Some(p) ⇒ + downstreamStatus.set(Canceled) p.complete(Success(())) close = None + completeStage() true case None ⇒ false } @@ -115,6 +110,7 @@ final private[stream] class OutputStreamSourceStage(writeTimeout: FiniteDuration override def preStart(): Unit = { dispatcher = ActorMaterializerHelper.downcast(materializer).system.dispatchers.lookup(dispatcherId) super.preStart() + initCallback(upstreamCallback.invoke) } setHandler(out, new OutHandler { @@ -151,6 +147,7 @@ final private[stream] class OutputStreamSourceStage(writeTimeout: FiniteDuration super.postStop() } } + (logic, new OutputStreamAdapter(dataQueue, downstreamStatus, logic.wakeUp, writeTimeout)) } }