diff --git a/akka-bench-jmh/src/main/scala/akka/stream/impl/OutputStreamSourceStageBenchmark.scala b/akka-bench-jmh/src/main/scala/akka/stream/impl/OutputStreamSourceStageBenchmark.scala new file mode 100644 index 0000000000..68f140d0bd --- /dev/null +++ b/akka-bench-jmh/src/main/scala/akka/stream/impl/OutputStreamSourceStageBenchmark.scala @@ -0,0 +1,60 @@ +/* + * Copyright (C) 2019 Lightbend Inc. + */ + +package akka.stream.impl + +import java.io.OutputStream +import java.util.concurrent.TimeUnit + +import akka.Done +import akka.actor.ActorSystem +import akka.stream.ActorMaterializer +import akka.stream.scaladsl.{ Keep, Sink, StreamConverters } +import org.openjdk.jmh.annotations.TearDown + +import scala.concurrent.{ Await, Future } +import scala.concurrent.duration._ +import org.openjdk.jmh.annotations._ + +object OutputStreamSourceStageBenchmark { + final val WritesPerBench = 10000 +} +@State(Scope.Benchmark) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@BenchmarkMode(Array(Mode.Throughput)) +class OutputStreamSourceStageBenchmark { + import OutputStreamSourceStageBenchmark.WritesPerBench + implicit val system = ActorSystem("OutputStreamSourceStageBenchmark") + implicit val materializer = ActorMaterializer() + + private val bytes: Array[Byte] = Array.emptyByteArray + + private var os: OutputStream = _ + private var done: Future[Done] = _ + + @Benchmark + @OperationsPerInvocation(WritesPerBench) + def consumeWrites(): Unit = { + val (os, done) = StreamConverters.asOutputStream() + .toMat(Sink.ignore)(Keep.both) + .run() + new Thread(new Runnable { + def run(): Unit = { + var counter = 0 + while (counter > WritesPerBench) { + os.write(bytes) + counter += 1 + } + os.close() + } + }).start() + Await.result(done, 30.seconds) + } + + @TearDown + def shutdown(): Unit = { + Await.result(system.terminate(), 5.seconds) + } + +} diff --git a/akka-stream-tests/src/test/scala/akka/stream/io/OutputStreamSourceSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/io/OutputStreamSourceSpec.scala index 484d92db14..002a841695 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/io/OutputStreamSourceSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/io/OutputStreamSourceSpec.scala @@ -68,22 +68,17 @@ class OutputStreamSourceSpec extends StreamSpec(UnboundedMailboxConfig) { probe.expectComplete() } - "block flush call until send all buffer to downstream" in assertAllStagesStopped { - val (outputStream, probe) = StreamConverters.asOutputStream().toMat(TestSink.probe[ByteString])(Keep.both).run - val s = probe.expectSubscription() - - outputStream.write(bytesArray) - val f = Future(outputStream.flush()) - - expectTimeout(f, timeout) - probe.expectNoMsg(Zero) - - s.request(1) - expectSuccess(f, ()) - probe.expectNext(byteString) - - outputStream.close() - probe.expectComplete() + // https://github.com/akka/akka/issues/25983 + "not truncate the stream on close" in assertAllStagesStopped { + for (_ ← 1 to 10) { + val (outputStream, result) = + StreamConverters.asOutputStream() + .toMat(Sink.fold[ByteString, ByteString](ByteString.empty)(_ ++ _))(Keep.both) + .run + outputStream.write(bytesArray) + outputStream.close() + result.futureValue should be(ByteString(bytesArray)) + } } "not block flushes when buffer is empty" in assertAllStagesStopped { @@ -134,34 +129,17 @@ class OutputStreamSourceSpec extends StreamSpec(UnboundedMailboxConfig) { the[Exception] thrownBy outputStream.write(bytesArray) shouldBe a[IOException] } - "use dedicated default-blocking-io-dispatcher by default" in assertAllStagesStopped { - val sys = ActorSystem("dispatcher-testing", UnboundedMailboxConfig) - val materializer = ActorMaterializer()(sys) - - try { - StreamConverters.asOutputStream().runWith(TestSink.probe[ByteString])(materializer) - materializer.asInstanceOf[PhasedFusingActorMaterializer].supervisor.tell(StreamSupervisor.GetChildren, testActor) - val ref = expectMsgType[Children].children.find(_.path.toString contains "outputStreamSource").get - assertDispatcher(ref, "akka.stream.default-blocking-io-dispatcher") - } finally shutdown(sys) - - } - "throw IOException when writing to the stream after the subscriber has cancelled the reactive stream" in assertAllStagesStopped { - val sourceProbe = TestProbe() - val (outputStream, probe) = TestSourceStage(new OutputStreamSourceStage(timeout), sourceProbe) + val (outputStream, sink) = StreamConverters.asOutputStream() .toMat(TestSink.probe[ByteString])(Keep.both).run - val s = probe.expectSubscription() + val s = sink.expectSubscription() outputStream.write(bytesArray) s.request(1) - sourceProbe.expectMsg(GraphStageMessages.Pull) - - probe.expectNext(byteString) + sink.expectNext(byteString) s.cancel() - sourceProbe.expectMsg(GraphStageMessages.DownstreamFinish) awaitAssert { the[Exception] thrownBy outputStream.write(bytesArray) shouldBe a[IOException] diff --git a/akka-stream/src/main/mima-filters/2.5.20.backwards.excludes b/akka-stream/src/main/mima-filters/2.5.20.backwards.excludes new file mode 100644 index 0000000000..251f4161a3 --- /dev/null +++ b/akka-stream/src/main/mima-filters/2.5.20.backwards.excludes @@ -0,0 +1,7 @@ +# OutputStreamSourceStage #25983 +ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.stream.impl.io.OutputStreamAdapter.this") +ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.io.OutputStreamSourceStage$Flush$") +ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.impl.io.OutputStreamAdapter.*") +ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.io.OutputStreamSourceStage$DownstreamStatus") +ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.io.OutputStreamSourceStage$Ok$") +ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.io.OutputStreamSourceStage$Canceled$") diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/OutputStreamSourceStage.scala b/akka-stream/src/main/scala/akka/stream/impl/io/OutputStreamSourceStage.scala index 9e626b99dc..5062c8bcfb 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/OutputStreamSourceStage.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/OutputStreamSourceStage.scala @@ -5,30 +5,23 @@ package akka.stream.impl.io import java.io.{ IOException, OutputStream } -import java.util.concurrent.atomic.AtomicReference -import java.util.concurrent.{ BlockingQueue, LinkedBlockingQueue } +import java.util.concurrent.{ Semaphore, TimeUnit } import akka.stream.Attributes.InputBuffer import akka.stream.impl.Stages.DefaultAttributes import akka.stream.impl.io.OutputStreamSourceStage._ import akka.stream.stage._ -import akka.stream.{ ActorMaterializerHelper, Attributes, Outlet, SourceShape } +import akka.stream.{ Attributes, Outlet, SourceShape } import akka.util.ByteString +import scala.concurrent.Await import scala.concurrent.duration.FiniteDuration -import scala.concurrent.{ Await, ExecutionContext, Future, Promise } import scala.util.control.NonFatal -import scala.util.{ Failure, Success, Try } private[stream] object OutputStreamSourceStage { sealed trait AdapterToStageMessage - case object Flush extends AdapterToStageMessage + case class Send(data: ByteString) extends AdapterToStageMessage case object Close extends AdapterToStageMessage - - sealed trait DownstreamStatus - case object Ok extends DownstreamStatus - case object Canceled extends DownstreamStatus - } final private[stream] class OutputStreamSourceStage(writeTimeout: FiniteDuration) extends GraphStageWithMaterializedValue[SourceShape[ByteString], OutputStream] { @@ -41,160 +34,57 @@ final private[stream] class OutputStreamSourceStage(writeTimeout: FiniteDuration require(maxBuffer > 0, "Buffer size must be greater than 0") - val dataQueue = new LinkedBlockingQueue[ByteString](maxBuffer) - val downstreamStatus = new AtomicReference[DownstreamStatus](Ok) + // Semaphore counting the number of elements we are ready to accept. + // Initially we are ready to accept 'maxBuffer' elements, which will be buffered + // by 'emit' if there is no demand yet. + // Semaphore permits are taken out of the pool when inserting data into the + // OutputStream, and new permits are released when downstream signals demand. + val semaphore = new Semaphore(maxBuffer, /* fair =*/ true) final class OutputStreamSourceLogic extends GraphStageLogic(shape) { - var flush: Option[Promise[Unit]] = None - var close: Option[Promise[Unit]] = None - - private var dispatcher: ExecutionContext = null // set in preStart - private val blockingThreadRef: AtomicReference[Thread] = new AtomicReference() // for postStop interrupt - - private val downstreamCallback: AsyncCallback[Try[ByteString]] = - getAsyncCallback { - case Success(elem) ⇒ onPush(elem) - case Failure(ex) ⇒ failStage(ex) - } - - private val upstreamCallback: AsyncCallback[(AdapterToStageMessage, Promise[Unit])] = + val upstreamCallback: AsyncCallback[AdapterToStageMessage] = getAsyncCallback(onAsyncMessage) - def wakeUp(msg: AdapterToStageMessage): Future[Unit] = { - val p = Promise[Unit]() - upstreamCallback.invoke((msg, p)) - p.future - } - - private def onAsyncMessage(event: (AdapterToStageMessage, Promise[Unit])): Unit = - event._1 match { - case Flush ⇒ - flush = Some(event._2) - sendResponseIfNeed() + private def onAsyncMessage(event: AdapterToStageMessage): Unit = { + event match { + case Send(data) ⇒ + emit(out, data, () ⇒ semaphore.release()) case Close ⇒ - close = Some(event._2) - sendResponseIfNeed() + completeStage() } - - private def unblockUpstream(): Boolean = - flush match { - case Some(p) ⇒ - p.complete(Success(())) - flush = None - true - case None ⇒ close match { - case Some(p) ⇒ - downstreamStatus.set(Canceled) - p.complete(Success(())) - close = None - completeStage() - true - case None ⇒ false - } - } - - private def sendResponseIfNeed(): Unit = - if (downstreamStatus.get() == Canceled || dataQueue.isEmpty) unblockUpstream() - - private def onPush(data: ByteString): Unit = - if (downstreamStatus.get() == Ok) { - push(out, data) - sendResponseIfNeed() - } - - override def preStart(): Unit = { - // this stage is running on the blocking IO dispatcher by default, but we also want to schedule futures - // that are blocking, so we need to look it up - val actorMat = ActorMaterializerHelper.downcast(materializer) - dispatcher = actorMat.system.dispatchers.lookup(actorMat.settings.blockingIoDispatcher) } setHandler(out, new OutHandler { override def onPull(): Unit = { - implicit val ec = dispatcher - Future { - val currentThread = Thread.currentThread() - // keep track of the thread for postStop interrupt - blockingThreadRef.compareAndSet(null, currentThread) - try { - dataQueue.take() - } catch { - case _: InterruptedException ⇒ - Thread.interrupted() - ByteString.empty - } finally { - blockingThreadRef.compareAndSet(currentThread, null); - } - }.onComplete(downstreamCallback.invoke) } }) - - override def postStop(): Unit = { - //assuming there can be no further in messages - downstreamStatus.set(Canceled) - dataQueue.clear() - // if blocked reading, make sure the take() completes - dataQueue.put(ByteString.empty) - // interrupt any pending blocking take - val blockingThread = blockingThreadRef.get() - if (blockingThread != null) - blockingThread.interrupt() - super.postStop() - } } val logic = new OutputStreamSourceLogic - (logic, new OutputStreamAdapter(dataQueue, downstreamStatus, logic.wakeUp, writeTimeout)) + (logic, new OutputStreamAdapter(semaphore, logic.upstreamCallback, writeTimeout)) } } private[akka] class OutputStreamAdapter( - dataQueue: BlockingQueue[ByteString], - downstreamStatus: AtomicReference[DownstreamStatus], - sendToStage: (AdapterToStageMessage) ⇒ Future[Unit], - writeTimeout: FiniteDuration) + unfulfilledDemand: Semaphore, + sendToStage: AsyncCallback[AdapterToStageMessage], + writeTimeout: FiniteDuration) extends OutputStream { - var isActive = true - var isPublisherAlive = true - def publisherClosedException = new IOException("Reactive stream is terminated, no writes are possible") - @scala.throws(classOf[IOException]) - private[this] def send(sendAction: () ⇒ Unit): Unit = { - if (isActive) { - if (isPublisherAlive) sendAction() - else throw publisherClosedException - } else throw new IOException("OutputStream is closed") + private[this] def sendData(data: ByteString): Unit = { + if (!unfulfilledDemand.tryAcquire(writeTimeout.toMillis, TimeUnit.MILLISECONDS)) { + throw new IOException("Timed out trying to write data to stream") + } + + try { + Await.result(sendToStage.invokeWithFeedback(Send(data)), writeTimeout) + } catch { + case NonFatal(e) ⇒ throw new IOException(e) + } } - @scala.throws(classOf[IOException]) - private[this] def sendData(data: ByteString): Unit = - send(() ⇒ { - try { - dataQueue.put(data) - } catch { case NonFatal(ex) ⇒ throw new IOException(ex) } - if (downstreamStatus.get() == Canceled) { - isPublisherAlive = false - throw publisherClosedException - } - }) - - @scala.throws(classOf[IOException]) - private[this] def sendMessage(message: AdapterToStageMessage, handleCancelled: Boolean = true) = - send(() ⇒ - try { - Await.ready(sendToStage(message), writeTimeout) - if (downstreamStatus.get() == Canceled && handleCancelled) { - //Publisher considered to be terminated at earliest convenience to minimize messages sending back and forth - isPublisherAlive = false - throw publisherClosedException - } - } catch { - case e: IOException ⇒ throw e - case NonFatal(e) ⇒ throw new IOException(e) - }) - @scala.throws(classOf[IOException]) override def write(b: Int): Unit = { sendData(ByteString(b)) @@ -206,11 +96,18 @@ private[akka] class OutputStreamAdapter( } @scala.throws(classOf[IOException]) - override def flush(): Unit = sendMessage(Flush) + override def flush(): Unit = + // Flushing does nothing: at best we could guarantee that our own buffer + // is empty, but that doesn't mean the element has been accepted downstream, + // so there is little value in that. + () @scala.throws(classOf[IOException]) override def close(): Unit = { - sendMessage(Close, handleCancelled = false) - isActive = false + try { + Await.result(sendToStage.invokeWithFeedback(Close), writeTimeout) + } catch { + case NonFatal(e) ⇒ throw new IOException(e) + } } }