Fix race condition in StreamConverters.asOutputStream() (#26136)

Fixes #25983
This commit is contained in:
Arnout Engelen 2019-02-07 17:00:38 +01:00 committed by Johannes Rudolph
parent d6ae3f1da9
commit b0b0865e4c
4 changed files with 121 additions and 179 deletions

View file

@ -0,0 +1,60 @@
/*
* Copyright (C) 2019 Lightbend Inc. <https://www.lightbend.com>
*/
package akka.stream.impl
import java.io.OutputStream
import java.util.concurrent.TimeUnit
import akka.Done
import akka.actor.ActorSystem
import akka.stream.ActorMaterializer
import akka.stream.scaladsl.{ Keep, Sink, StreamConverters }
import org.openjdk.jmh.annotations.TearDown
import scala.concurrent.{ Await, Future }
import scala.concurrent.duration._
import org.openjdk.jmh.annotations._
/** Companion object holding the constant shared by the benchmark class and its JMH annotations. */
object OutputStreamSourceStageBenchmark {
// Number of writes the benchmark intends to perform per invocation; it is also the
// argument of @OperationsPerInvocation on consumeWrites(). Keep this a `final val`
// with a literal right-hand side and no type ascription, so that it stays a
// compile-time constant usable as an annotation argument.
final val WritesPerBench = 10000
}
/**
 * JMH benchmark for the `OutputStream` materialized by `StreamConverters.asOutputStream()`.
 *
 * Each invocation materializes a fresh stream that drains into `Sink.ignore`, performs
 * `WritesPerBench` writes from a separate thread, closes the stream and waits for the
 * sink to complete.
 */
@State(Scope.Benchmark)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@BenchmarkMode(Array(Mode.Throughput))
class OutputStreamSourceStageBenchmark {
  import OutputStreamSourceStageBenchmark.WritesPerBench

  implicit val system: ActorSystem = ActorSystem("OutputStreamSourceStageBenchmark")
  implicit val materializer: ActorMaterializer = ActorMaterializer()

  // Payload written on every iteration; an empty array is used for each write.
  private val bytes: Array[Byte] = Array.emptyByteArray

  /**
   * Runs one stream to completion: writes `WritesPerBench` chunks to the materialized
   * `OutputStream`, closes it, and blocks until `Sink.ignore` signals completion.
   *
   * Throws if the stream does not complete within 30 seconds.
   */
  @Benchmark
  @OperationsPerInvocation(WritesPerBench)
  def consumeWrites(): Unit = {
    val (os, done) = StreamConverters.asOutputStream()
      .toMat(Sink.ignore)(Keep.both)
      .run()

    // Writes are issued from a dedicated thread so that this (benchmark) thread
    // can block on the stream's completion below.
    new Thread(new Runnable {
      def run(): Unit = {
        var counter = 0
        // Must be `<`: with `>` the loop body never executed (counter starts at 0),
        // so the benchmark measured no writes at all while still reporting
        // WritesPerBench operations per invocation.
        while (counter < WritesPerBench) {
          os.write(bytes)
          counter += 1
        }
        os.close()
      }
    }).start()

    Await.result(done, 30.seconds)
  }

  /** Terminates the actor system once the benchmark is finished. */
  @TearDown
  def shutdown(): Unit = {
    Await.result(system.terminate(), 5.seconds)
  }
}

View file

@ -68,22 +68,17 @@ class OutputStreamSourceSpec extends StreamSpec(UnboundedMailboxConfig) {
probe.expectComplete() probe.expectComplete()
} }
"block flush call until send all buffer to downstream" in assertAllStagesStopped { // https://github.com/akka/akka/issues/25983
val (outputStream, probe) = StreamConverters.asOutputStream().toMat(TestSink.probe[ByteString])(Keep.both).run "not truncate the stream on close" in assertAllStagesStopped {
val s = probe.expectSubscription() for (_ 1 to 10) {
val (outputStream, result) =
StreamConverters.asOutputStream()
.toMat(Sink.fold[ByteString, ByteString](ByteString.empty)(_ ++ _))(Keep.both)
.run
outputStream.write(bytesArray) outputStream.write(bytesArray)
val f = Future(outputStream.flush())
expectTimeout(f, timeout)
probe.expectNoMsg(Zero)
s.request(1)
expectSuccess(f, ())
probe.expectNext(byteString)
outputStream.close() outputStream.close()
probe.expectComplete() result.futureValue should be(ByteString(bytesArray))
}
} }
"not block flushes when buffer is empty" in assertAllStagesStopped { "not block flushes when buffer is empty" in assertAllStagesStopped {
@ -134,34 +129,17 @@ class OutputStreamSourceSpec extends StreamSpec(UnboundedMailboxConfig) {
the[Exception] thrownBy outputStream.write(bytesArray) shouldBe a[IOException] the[Exception] thrownBy outputStream.write(bytesArray) shouldBe a[IOException]
} }
"use dedicated default-blocking-io-dispatcher by default" in assertAllStagesStopped {
val sys = ActorSystem("dispatcher-testing", UnboundedMailboxConfig)
val materializer = ActorMaterializer()(sys)
try {
StreamConverters.asOutputStream().runWith(TestSink.probe[ByteString])(materializer)
materializer.asInstanceOf[PhasedFusingActorMaterializer].supervisor.tell(StreamSupervisor.GetChildren, testActor)
val ref = expectMsgType[Children].children.find(_.path.toString contains "outputStreamSource").get
assertDispatcher(ref, "akka.stream.default-blocking-io-dispatcher")
} finally shutdown(sys)
}
"throw IOException when writing to the stream after the subscriber has cancelled the reactive stream" in assertAllStagesStopped { "throw IOException when writing to the stream after the subscriber has cancelled the reactive stream" in assertAllStagesStopped {
val sourceProbe = TestProbe() val (outputStream, sink) = StreamConverters.asOutputStream()
val (outputStream, probe) = TestSourceStage(new OutputStreamSourceStage(timeout), sourceProbe)
.toMat(TestSink.probe[ByteString])(Keep.both).run .toMat(TestSink.probe[ByteString])(Keep.both).run
val s = probe.expectSubscription() val s = sink.expectSubscription()
outputStream.write(bytesArray) outputStream.write(bytesArray)
s.request(1) s.request(1)
sourceProbe.expectMsg(GraphStageMessages.Pull)
probe.expectNext(byteString)
sink.expectNext(byteString)
s.cancel() s.cancel()
sourceProbe.expectMsg(GraphStageMessages.DownstreamFinish)
awaitAssert { awaitAssert {
the[Exception] thrownBy outputStream.write(bytesArray) shouldBe a[IOException] the[Exception] thrownBy outputStream.write(bytesArray) shouldBe a[IOException]

View file

@ -0,0 +1,7 @@
# OutputStreamSourceStage #25983
ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.stream.impl.io.OutputStreamAdapter.this")
ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.io.OutputStreamSourceStage$Flush$")
ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.impl.io.OutputStreamAdapter.*")
ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.io.OutputStreamSourceStage$DownstreamStatus")
ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.io.OutputStreamSourceStage$Ok$")
ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.io.OutputStreamSourceStage$Canceled$")

View file

@ -5,30 +5,23 @@
package akka.stream.impl.io package akka.stream.impl.io
import java.io.{ IOException, OutputStream } import java.io.{ IOException, OutputStream }
import java.util.concurrent.atomic.AtomicReference import java.util.concurrent.{ Semaphore, TimeUnit }
import java.util.concurrent.{ BlockingQueue, LinkedBlockingQueue }
import akka.stream.Attributes.InputBuffer import akka.stream.Attributes.InputBuffer
import akka.stream.impl.Stages.DefaultAttributes import akka.stream.impl.Stages.DefaultAttributes
import akka.stream.impl.io.OutputStreamSourceStage._ import akka.stream.impl.io.OutputStreamSourceStage._
import akka.stream.stage._ import akka.stream.stage._
import akka.stream.{ ActorMaterializerHelper, Attributes, Outlet, SourceShape } import akka.stream.{ Attributes, Outlet, SourceShape }
import akka.util.ByteString import akka.util.ByteString
import scala.concurrent.Await
import scala.concurrent.duration.FiniteDuration import scala.concurrent.duration.FiniteDuration
import scala.concurrent.{ Await, ExecutionContext, Future, Promise }
import scala.util.control.NonFatal import scala.util.control.NonFatal
import scala.util.{ Failure, Success, Try }
private[stream] object OutputStreamSourceStage { private[stream] object OutputStreamSourceStage {
sealed trait AdapterToStageMessage sealed trait AdapterToStageMessage
case object Flush extends AdapterToStageMessage case class Send(data: ByteString) extends AdapterToStageMessage
case object Close extends AdapterToStageMessage case object Close extends AdapterToStageMessage
sealed trait DownstreamStatus
case object Ok extends DownstreamStatus
case object Canceled extends DownstreamStatus
} }
final private[stream] class OutputStreamSourceStage(writeTimeout: FiniteDuration) extends GraphStageWithMaterializedValue[SourceShape[ByteString], OutputStream] { final private[stream] class OutputStreamSourceStage(writeTimeout: FiniteDuration) extends GraphStageWithMaterializedValue[SourceShape[ByteString], OutputStream] {
@ -41,159 +34,56 @@ final private[stream] class OutputStreamSourceStage(writeTimeout: FiniteDuration
require(maxBuffer > 0, "Buffer size must be greater than 0") require(maxBuffer > 0, "Buffer size must be greater than 0")
val dataQueue = new LinkedBlockingQueue[ByteString](maxBuffer) // Semaphore counting the number of elements we are ready to accept.
val downstreamStatus = new AtomicReference[DownstreamStatus](Ok) // Initially we are ready to accept 'maxBuffer' elements, which will be buffered
// by 'emit' if there is no demand yet.
// Semaphore permits are taken out of the pool when inserting data into the
// OutputStream, and new permits are released when downstream signals demand.
val semaphore = new Semaphore(maxBuffer, /* fair =*/ true)
final class OutputStreamSourceLogic extends GraphStageLogic(shape) { final class OutputStreamSourceLogic extends GraphStageLogic(shape) {
var flush: Option[Promise[Unit]] = None val upstreamCallback: AsyncCallback[AdapterToStageMessage] =
var close: Option[Promise[Unit]] = None
private var dispatcher: ExecutionContext = null // set in preStart
private val blockingThreadRef: AtomicReference[Thread] = new AtomicReference() // for postStop interrupt
private val downstreamCallback: AsyncCallback[Try[ByteString]] =
getAsyncCallback {
case Success(elem) onPush(elem)
case Failure(ex) failStage(ex)
}
private val upstreamCallback: AsyncCallback[(AdapterToStageMessage, Promise[Unit])] =
getAsyncCallback(onAsyncMessage) getAsyncCallback(onAsyncMessage)
def wakeUp(msg: AdapterToStageMessage): Future[Unit] = { private def onAsyncMessage(event: AdapterToStageMessage): Unit = {
val p = Promise[Unit]() event match {
upstreamCallback.invoke((msg, p)) case Send(data)
p.future emit(out, data, () semaphore.release())
}
private def onAsyncMessage(event: (AdapterToStageMessage, Promise[Unit])): Unit =
event._1 match {
case Flush
flush = Some(event._2)
sendResponseIfNeed()
case Close case Close
close = Some(event._2)
sendResponseIfNeed()
}
private def unblockUpstream(): Boolean =
flush match {
case Some(p)
p.complete(Success(()))
flush = None
true
case None close match {
case Some(p)
downstreamStatus.set(Canceled)
p.complete(Success(()))
close = None
completeStage() completeStage()
true
case None false
} }
} }
private def sendResponseIfNeed(): Unit =
if (downstreamStatus.get() == Canceled || dataQueue.isEmpty) unblockUpstream()
private def onPush(data: ByteString): Unit =
if (downstreamStatus.get() == Ok) {
push(out, data)
sendResponseIfNeed()
}
override def preStart(): Unit = {
// this stage is running on the blocking IO dispatcher by default, but we also want to schedule futures
// that are blocking, so we need to look it up
val actorMat = ActorMaterializerHelper.downcast(materializer)
dispatcher = actorMat.system.dispatchers.lookup(actorMat.settings.blockingIoDispatcher)
}
setHandler(out, new OutHandler { setHandler(out, new OutHandler {
override def onPull(): Unit = { override def onPull(): Unit = {
implicit val ec = dispatcher
Future {
val currentThread = Thread.currentThread()
// keep track of the thread for postStop interrupt
blockingThreadRef.compareAndSet(null, currentThread)
try {
dataQueue.take()
} catch {
case _: InterruptedException
Thread.interrupted()
ByteString.empty
} finally {
blockingThreadRef.compareAndSet(currentThread, null);
}
}.onComplete(downstreamCallback.invoke)
} }
}) })
override def postStop(): Unit = {
//assuming there can be no further in messages
downstreamStatus.set(Canceled)
dataQueue.clear()
// if blocked reading, make sure the take() completes
dataQueue.put(ByteString.empty)
// interrupt any pending blocking take
val blockingThread = blockingThreadRef.get()
if (blockingThread != null)
blockingThread.interrupt()
super.postStop()
}
} }
val logic = new OutputStreamSourceLogic val logic = new OutputStreamSourceLogic
(logic, new OutputStreamAdapter(dataQueue, downstreamStatus, logic.wakeUp, writeTimeout)) (logic, new OutputStreamAdapter(semaphore, logic.upstreamCallback, writeTimeout))
} }
} }
private[akka] class OutputStreamAdapter( private[akka] class OutputStreamAdapter(
dataQueue: BlockingQueue[ByteString], unfulfilledDemand: Semaphore,
downstreamStatus: AtomicReference[DownstreamStatus], sendToStage: AsyncCallback[AdapterToStageMessage],
sendToStage: (AdapterToStageMessage) Future[Unit],
writeTimeout: FiniteDuration) writeTimeout: FiniteDuration)
extends OutputStream { extends OutputStream {
var isActive = true
var isPublisherAlive = true
def publisherClosedException = new IOException("Reactive stream is terminated, no writes are possible")
@scala.throws(classOf[IOException]) @scala.throws(classOf[IOException])
private[this] def send(sendAction: () Unit): Unit = { private[this] def sendData(data: ByteString): Unit = {
if (isActive) { if (!unfulfilledDemand.tryAcquire(writeTimeout.toMillis, TimeUnit.MILLISECONDS)) {
if (isPublisherAlive) sendAction() throw new IOException("Timed out trying to write data to stream")
else throw publisherClosedException
} else throw new IOException("OutputStream is closed")
} }
@scala.throws(classOf[IOException])
private[this] def sendData(data: ByteString): Unit =
send(() {
try { try {
dataQueue.put(data) Await.result(sendToStage.invokeWithFeedback(Send(data)), writeTimeout)
} catch { case NonFatal(ex) throw new IOException(ex) }
if (downstreamStatus.get() == Canceled) {
isPublisherAlive = false
throw publisherClosedException
}
})
@scala.throws(classOf[IOException])
private[this] def sendMessage(message: AdapterToStageMessage, handleCancelled: Boolean = true) =
send(()
try {
Await.ready(sendToStage(message), writeTimeout)
if (downstreamStatus.get() == Canceled && handleCancelled) {
//Publisher considered to be terminated at earliest convenience to minimize messages sending back and forth
isPublisherAlive = false
throw publisherClosedException
}
} catch { } catch {
case e: IOException throw e
case NonFatal(e) throw new IOException(e) case NonFatal(e) throw new IOException(e)
}) }
}
@scala.throws(classOf[IOException]) @scala.throws(classOf[IOException])
override def write(b: Int): Unit = { override def write(b: Int): Unit = {
@ -206,11 +96,18 @@ private[akka] class OutputStreamAdapter(
} }
@scala.throws(classOf[IOException]) @scala.throws(classOf[IOException])
override def flush(): Unit = sendMessage(Flush) override def flush(): Unit =
// Flushing does nothing: at best we could guarantee that our own buffer
// is empty, but that doesn't mean the element has been accepted downstream,
// so there is little value in that.
()
@scala.throws(classOf[IOException]) @scala.throws(classOf[IOException])
override def close(): Unit = { override def close(): Unit = {
sendMessage(Close, handleCancelled = false) try {
isActive = false Await.result(sendToStage.invokeWithFeedback(Close), writeTimeout)
} catch {
case NonFatal(e) throw new IOException(e)
}
} }
} }