Fix race condition in StreamConverters.asOutputStream() (#26136)
Fixes #25983
parent d6ae3f1da9
commit b0b0865e4c

4 changed files with 121 additions and 179 deletions
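For orientation: StreamConverters.asOutputStream() materializes a plain java.io.OutputStream whose writes are emitted downstream as ByteString elements. The race fixed here could let close() overtake data still being handed to the stage and truncate the stream, which is what the new regression test in this commit checks. A minimal usage sketch of that scenario, assuming a throwaway ActorSystem and the same fold sink the new test uses (illustrative only, not part of the commit):

    import akka.actor.ActorSystem
    import akka.stream.ActorMaterializer
    import akka.stream.scaladsl.{ Keep, Sink, StreamConverters }
    import akka.util.ByteString
    import scala.concurrent.Await
    import scala.concurrent.duration._

    object AsOutputStreamSketch extends App {
      implicit val system = ActorSystem("as-output-stream-sketch") // assumed name
      implicit val materializer = ActorMaterializer()

      // Materialize the OutputStream together with the folded stream result.
      val (out, result) = StreamConverters.asOutputStream()
        .toMat(Sink.fold[ByteString, ByteString](ByteString.empty)(_ ++ _))(Keep.both)
        .run()

      out.write("payload".getBytes("UTF-8"))
      // Before this fix, close() could race with buffered data and `result`
      // might miss the bytes written above; afterwards it must contain them all.
      out.close()
      println(Await.result(result, 5.seconds).utf8String)
      system.terminate()
    }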
@@ -0,0 +1,60 @@
+/*
+ * Copyright (C) 2019 Lightbend Inc. <https://www.lightbend.com>
+ */
+
+package akka.stream.impl
+
+import java.io.OutputStream
+import java.util.concurrent.TimeUnit
+
+import akka.Done
+import akka.actor.ActorSystem
+import akka.stream.ActorMaterializer
+import akka.stream.scaladsl.{ Keep, Sink, StreamConverters }
+import org.openjdk.jmh.annotations.TearDown
+
+import scala.concurrent.{ Await, Future }
+import scala.concurrent.duration._
+import org.openjdk.jmh.annotations._
+
+object OutputStreamSourceStageBenchmark {
+  final val WritesPerBench = 10000
+}
+@State(Scope.Benchmark)
+@OutputTimeUnit(TimeUnit.MILLISECONDS)
+@BenchmarkMode(Array(Mode.Throughput))
+class OutputStreamSourceStageBenchmark {
+  import OutputStreamSourceStageBenchmark.WritesPerBench
+  implicit val system = ActorSystem("OutputStreamSourceStageBenchmark")
+  implicit val materializer = ActorMaterializer()
+
+  private val bytes: Array[Byte] = Array.emptyByteArray
+
+  private var os: OutputStream = _
+  private var done: Future[Done] = _
+
+  @Benchmark
+  @OperationsPerInvocation(WritesPerBench)
+  def consumeWrites(): Unit = {
+    val (os, done) = StreamConverters.asOutputStream()
+      .toMat(Sink.ignore)(Keep.both)
+      .run()
+    new Thread(new Runnable {
+      def run(): Unit = {
+        var counter = 0
+        while (counter < WritesPerBench) {
+          os.write(bytes)
+          counter += 1
+        }
+        os.close()
+      }
+    }).start()
+    Await.result(done, 30.seconds)
+  }
+
+  @TearDown
+  def shutdown(): Unit = {
+    Await.result(system.terminate(), 5.seconds)
+  }
+
+}
@@ -68,22 +68,17 @@ class OutputStreamSourceSpec extends StreamSpec(UnboundedMailboxConfig) {
       probe.expectComplete()
     }

-    "block flush call until send all buffer to downstream" in assertAllStagesStopped {
-      val (outputStream, probe) = StreamConverters.asOutputStream().toMat(TestSink.probe[ByteString])(Keep.both).run
-      val s = probe.expectSubscription()
-
-      outputStream.write(bytesArray)
-      val f = Future(outputStream.flush())
-
-      expectTimeout(f, timeout)
-      probe.expectNoMsg(Zero)
-
-      s.request(1)
-      expectSuccess(f, ())
-      probe.expectNext(byteString)
-
-      outputStream.close()
-      probe.expectComplete()
-    }
+    // https://github.com/akka/akka/issues/25983
+    "not truncate the stream on close" in assertAllStagesStopped {
+      for (_ ← 1 to 10) {
+        val (outputStream, result) =
+          StreamConverters.asOutputStream()
+            .toMat(Sink.fold[ByteString, ByteString](ByteString.empty)(_ ++ _))(Keep.both)
+            .run
+        outputStream.write(bytesArray)
+        outputStream.close()
+        result.futureValue should be(ByteString(bytesArray))
+      }
+    }

     "not block flushes when buffer is empty" in assertAllStagesStopped {
@@ -134,34 +129,17 @@ class OutputStreamSourceSpec extends StreamSpec(UnboundedMailboxConfig) {
         the[Exception] thrownBy outputStream.write(bytesArray) shouldBe a[IOException]
       }
     }

-    "use dedicated default-blocking-io-dispatcher by default" in assertAllStagesStopped {
-      val sys = ActorSystem("dispatcher-testing", UnboundedMailboxConfig)
-      val materializer = ActorMaterializer()(sys)
-
-      try {
-        StreamConverters.asOutputStream().runWith(TestSink.probe[ByteString])(materializer)
-        materializer.asInstanceOf[PhasedFusingActorMaterializer].supervisor.tell(StreamSupervisor.GetChildren, testActor)
-        val ref = expectMsgType[Children].children.find(_.path.toString contains "outputStreamSource").get
-        assertDispatcher(ref, "akka.stream.default-blocking-io-dispatcher")
-      } finally shutdown(sys)
-    }
-
     "throw IOException when writing to the stream after the subscriber has cancelled the reactive stream" in assertAllStagesStopped {
-      val sourceProbe = TestProbe()
-      val (outputStream, probe) = TestSourceStage(new OutputStreamSourceStage(timeout), sourceProbe)
+      val (outputStream, sink) = StreamConverters.asOutputStream()
         .toMat(TestSink.probe[ByteString])(Keep.both).run

-      val s = probe.expectSubscription()
+      val s = sink.expectSubscription()

       outputStream.write(bytesArray)
       s.request(1)
-      sourceProbe.expectMsg(GraphStageMessages.Pull)
-
-      probe.expectNext(byteString)

+      sink.expectNext(byteString)
       s.cancel()
-      sourceProbe.expectMsg(GraphStageMessages.DownstreamFinish)

       awaitAssert {
         the[Exception] thrownBy outputStream.write(bytesArray) shouldBe a[IOException]
@@ -0,0 +1,7 @@
+# OutputStreamSourceStage #25983
+ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.stream.impl.io.OutputStreamAdapter.this")
+ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.io.OutputStreamSourceStage$Flush$")
+ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.impl.io.OutputStreamAdapter.*")
+ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.io.OutputStreamSourceStage$DownstreamStatus")
+ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.io.OutputStreamSourceStage$Ok$")
+ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.io.OutputStreamSourceStage$Canceled$")
@@ -5,30 +5,23 @@
 package akka.stream.impl.io

 import java.io.{ IOException, OutputStream }
-import java.util.concurrent.atomic.AtomicReference
-import java.util.concurrent.{ BlockingQueue, LinkedBlockingQueue }
+import java.util.concurrent.{ Semaphore, TimeUnit }

 import akka.stream.Attributes.InputBuffer
 import akka.stream.impl.Stages.DefaultAttributes
 import akka.stream.impl.io.OutputStreamSourceStage._
 import akka.stream.stage._
-import akka.stream.{ ActorMaterializerHelper, Attributes, Outlet, SourceShape }
+import akka.stream.{ Attributes, Outlet, SourceShape }
 import akka.util.ByteString

+import scala.concurrent.Await
 import scala.concurrent.duration.FiniteDuration
-import scala.concurrent.{ Await, ExecutionContext, Future, Promise }
 import scala.util.control.NonFatal
-import scala.util.{ Failure, Success, Try }

 private[stream] object OutputStreamSourceStage {
   sealed trait AdapterToStageMessage
-  case object Flush extends AdapterToStageMessage
+  case class Send(data: ByteString) extends AdapterToStageMessage
   case object Close extends AdapterToStageMessage
-
-  sealed trait DownstreamStatus
-  case object Ok extends DownstreamStatus
-  case object Canceled extends DownstreamStatus
-
 }

 final private[stream] class OutputStreamSourceStage(writeTimeout: FiniteDuration) extends GraphStageWithMaterializedValue[SourceShape[ByteString], OutputStream] {
@@ -41,159 +34,56 @@ final private[stream] class OutputStreamSourceStage(writeTimeout: FiniteDuration

     require(maxBuffer > 0, "Buffer size must be greater than 0")

-    val dataQueue = new LinkedBlockingQueue[ByteString](maxBuffer)
-    val downstreamStatus = new AtomicReference[DownstreamStatus](Ok)
+    // Semaphore counting the number of elements we are ready to accept.
+    // Initially we are ready to accept 'maxBuffer' elements, which will be buffered
+    // by 'emit' if there is no demand yet.
+    // Semaphore permits are taken out of the pool when inserting data into the
+    // OutputStream, and new permits are released when downstream signals demand.
+    val semaphore = new Semaphore(maxBuffer, /* fair =*/ true)

     final class OutputStreamSourceLogic extends GraphStageLogic(shape) {

-      var flush: Option[Promise[Unit]] = None
-      var close: Option[Promise[Unit]] = None
-
-      private var dispatcher: ExecutionContext = null // set in preStart
-      private val blockingThreadRef: AtomicReference[Thread] = new AtomicReference() // for postStop interrupt
-
-      private val downstreamCallback: AsyncCallback[Try[ByteString]] =
-        getAsyncCallback {
-          case Success(elem) ⇒ onPush(elem)
-          case Failure(ex)   ⇒ failStage(ex)
-        }
-
-      private val upstreamCallback: AsyncCallback[(AdapterToStageMessage, Promise[Unit])] =
+      val upstreamCallback: AsyncCallback[AdapterToStageMessage] =
         getAsyncCallback(onAsyncMessage)

-      def wakeUp(msg: AdapterToStageMessage): Future[Unit] = {
-        val p = Promise[Unit]()
-        upstreamCallback.invoke((msg, p))
-        p.future
-      }
-
-      private def onAsyncMessage(event: (AdapterToStageMessage, Promise[Unit])): Unit =
-        event._1 match {
-          case Flush ⇒
-            flush = Some(event._2)
-            sendResponseIfNeed()
-          case Close ⇒
-            close = Some(event._2)
-            sendResponseIfNeed()
-        }
-
-      private def unblockUpstream(): Boolean =
-        flush match {
-          case Some(p) ⇒
-            p.complete(Success(()))
-            flush = None
-            true
-          case None ⇒ close match {
-            case Some(p) ⇒
-              downstreamStatus.set(Canceled)
-              p.complete(Success(()))
-              close = None
-              completeStage()
-              true
-            case None ⇒ false
-          }
-        }
-
-      private def sendResponseIfNeed(): Unit =
-        if (downstreamStatus.get() == Canceled || dataQueue.isEmpty) unblockUpstream()
-
-      private def onPush(data: ByteString): Unit =
-        if (downstreamStatus.get() == Ok) {
-          push(out, data)
-          sendResponseIfNeed()
-        }
-
-      override def preStart(): Unit = {
-        // this stage is running on the blocking IO dispatcher by default, but we also want to schedule futures
-        // that are blocking, so we need to look it up
-        val actorMat = ActorMaterializerHelper.downcast(materializer)
-        dispatcher = actorMat.system.dispatchers.lookup(actorMat.settings.blockingIoDispatcher)
-      }
+      private def onAsyncMessage(event: AdapterToStageMessage): Unit = {
+        event match {
+          case Send(data) ⇒
+            emit(out, data, () ⇒ semaphore.release())
+          case Close ⇒
+            completeStage()
+        }
+      }

       setHandler(out, new OutHandler {
         override def onPull(): Unit = {
-          implicit val ec = dispatcher
-          Future {
-            val currentThread = Thread.currentThread()
-            // keep track of the thread for postStop interrupt
-            blockingThreadRef.compareAndSet(null, currentThread)
-            try {
-              dataQueue.take()
-            } catch {
-              case _: InterruptedException ⇒
-                Thread.interrupted()
-                ByteString.empty
-            } finally {
-              blockingThreadRef.compareAndSet(currentThread, null);
-            }
-          }.onComplete(downstreamCallback.invoke)
         }
       })
-
-      override def postStop(): Unit = {
-        //assuming there can be no further in messages
-        downstreamStatus.set(Canceled)
-        dataQueue.clear()
-        // if blocked reading, make sure the take() completes
-        dataQueue.put(ByteString.empty)
-        // interrupt any pending blocking take
-        val blockingThread = blockingThreadRef.get()
-        if (blockingThread != null)
-          blockingThread.interrupt()
-        super.postStop()
-      }
     }

     val logic = new OutputStreamSourceLogic
-    (logic, new OutputStreamAdapter(dataQueue, downstreamStatus, logic.wakeUp, writeTimeout))
+    (logic, new OutputStreamAdapter(semaphore, logic.upstreamCallback, writeTimeout))
   }
 }

 private[akka] class OutputStreamAdapter(
-    dataQueue: BlockingQueue[ByteString],
-    downstreamStatus: AtomicReference[DownstreamStatus],
-    sendToStage: (AdapterToStageMessage) ⇒ Future[Unit],
+    unfulfilledDemand: Semaphore,
+    sendToStage: AsyncCallback[AdapterToStageMessage],
     writeTimeout: FiniteDuration)
   extends OutputStream {

-  var isActive = true
-  var isPublisherAlive = true
-  def publisherClosedException = new IOException("Reactive stream is terminated, no writes are possible")
-
   @scala.throws(classOf[IOException])
-  private[this] def send(sendAction: () ⇒ Unit): Unit = {
-    if (isActive) {
-      if (isPublisherAlive) sendAction()
-      else throw publisherClosedException
-    } else throw new IOException("OutputStream is closed")
-  }
-
-  @scala.throws(classOf[IOException])
-  private[this] def sendData(data: ByteString): Unit =
-    send(() ⇒ {
-      try {
-        dataQueue.put(data)
-      } catch { case NonFatal(ex) ⇒ throw new IOException(ex) }
-      if (downstreamStatus.get() == Canceled) {
-        isPublisherAlive = false
-        throw publisherClosedException
-      }
-    })
-
-  @scala.throws(classOf[IOException])
-  private[this] def sendMessage(message: AdapterToStageMessage, handleCancelled: Boolean = true) =
-    send(() ⇒
-      try {
-        Await.ready(sendToStage(message), writeTimeout)
-        if (downstreamStatus.get() == Canceled && handleCancelled) {
-          //Publisher considered to be terminated at earliest convenience to minimize messages sending back and forth
-          isPublisherAlive = false
-          throw publisherClosedException
-        }
-      } catch {
-        case e: IOException ⇒ throw e
-        case NonFatal(e)    ⇒ throw new IOException(e)
-      })
+  private[this] def sendData(data: ByteString): Unit = {
+    if (!unfulfilledDemand.tryAcquire(writeTimeout.toMillis, TimeUnit.MILLISECONDS)) {
+      throw new IOException("Timed out trying to write data to stream")
+    }
+
+    try {
+      Await.result(sendToStage.invokeWithFeedback(Send(data)), writeTimeout)
+    } catch {
+      case NonFatal(e) ⇒ throw new IOException(e)
+    }
+  }

   @scala.throws(classOf[IOException])
   override def write(b: Int): Unit = {
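The comments in the hunk above describe the new backpressure mechanism: the adapter must take a semaphore permit before handing an element to the stage, and a permit is only returned once an element has actually left the buffer. A stripped-down sketch of that discipline outside Akka, with hypothetical names and a plain queue standing in for the stage (in the real code the release happens in emit's callback when downstream takes the element):

    import java.io.IOException
    import java.util.concurrent.{ LinkedBlockingQueue, Semaphore, TimeUnit }

    object PermitHandoffSketch {
      val maxBuffer = 16
      // One permit per element the writer may run ahead of the consumer.
      val permits = new Semaphore(maxBuffer, /* fair = */ true)
      val handoff = new LinkedBlockingQueue[String]()

      def write(element: String, timeoutMillis: Long): Unit = {
        // Mirror sendData: block at most the write timeout, then fail with IOException.
        if (!permits.tryAcquire(timeoutMillis, TimeUnit.MILLISECONDS))
          throw new IOException("Timed out trying to write data to stream")
        handoff.put(element)
      }

      def consumeOne(): String = {
        val element = handoff.take()
        permits.release() // an element left the buffer, so the writer may submit another
        element
      }
    }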
@@ -206,11 +96,18 @@ private[akka] class OutputStreamAdapter(
   }

   @scala.throws(classOf[IOException])
-  override def flush(): Unit = sendMessage(Flush)
+  override def flush(): Unit =
+    // Flushing does nothing: at best we could guarantee that our own buffer
+    // is empty, but that doesn't mean the element has been accepted downstream,
+    // so there is little value in that.
+    ()

   @scala.throws(classOf[IOException])
   override def close(): Unit = {
-    sendMessage(Close, handleCancelled = false)
-    isActive = false
+    try {
+      Await.result(sendToStage.invokeWithFeedback(Close), writeTimeout)
+    } catch {
+      case NonFatal(e) ⇒ throw new IOException(e)
+    }
   }
 }
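From the caller's side, the hunk above changes the contract of the materialized OutputStream: flush() is now a no-op with no end-to-end delivery guarantee, write() blocks for at most the write timeout before failing with an IOException, and close() returns once the stage has processed the Close message. A hedged usage sketch of those semantics, assuming a throwaway ActorSystem and sink that are not part of the commit:

    import java.io.IOException
    import scala.concurrent.duration._
    import akka.actor.ActorSystem
    import akka.stream.ActorMaterializer
    import akka.stream.scaladsl.{ Sink, StreamConverters }
    import akka.util.ByteString

    object WriteTimeoutSketch extends App {
      implicit val system = ActorSystem("write-timeout-sketch") // assumed name
      implicit val materializer = ActorMaterializer()

      // With the semaphore-based stage, write() waits for downstream demand
      // for at most writeTimeout and then throws; flush() returns immediately.
      val out = StreamConverters.asOutputStream(writeTimeout = 1.second)
        .to(Sink.ignore) // any sink; a slow consumer is what would surface the timeout
        .run()

      try {
        out.write(ByteString("data").toArray)
        out.flush() // no-op after this commit, no delivery guarantee
      } catch {
        case e: IOException ⇒ system.log.warning("write failed: {}", e.getMessage)
      } finally out.close() // hands Close to the stage, which completes the stream

      system.terminate()
    }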