Fix race condition in StreamConverters.asOutputStream() (#26136)
Fixes #25983
This commit is contained in:
parent
d6ae3f1da9
commit
b0b0865e4c
4 changed files with 121 additions and 179 deletions
|
|
@ -0,0 +1,60 @@
|
|||
/*
|
||||
* Copyright (C) 2019 Lightbend Inc. <https://www.lightbend.com>
|
||||
*/
|
||||
|
||||
package akka.stream.impl
|
||||
|
||||
import java.io.OutputStream
|
||||
import java.util.concurrent.TimeUnit
|
||||
|
||||
import akka.Done
|
||||
import akka.actor.ActorSystem
|
||||
import akka.stream.ActorMaterializer
|
||||
import akka.stream.scaladsl.{ Keep, Sink, StreamConverters }
|
||||
import org.openjdk.jmh.annotations.TearDown
|
||||
|
||||
import scala.concurrent.{ Await, Future }
|
||||
import scala.concurrent.duration._
|
||||
import org.openjdk.jmh.annotations._
|
||||
|
||||
object OutputStreamSourceStageBenchmark {
|
||||
final val WritesPerBench = 10000
|
||||
}
|
||||
@State(Scope.Benchmark)
|
||||
@OutputTimeUnit(TimeUnit.MILLISECONDS)
|
||||
@BenchmarkMode(Array(Mode.Throughput))
|
||||
class OutputStreamSourceStageBenchmark {
|
||||
import OutputStreamSourceStageBenchmark.WritesPerBench
|
||||
implicit val system = ActorSystem("OutputStreamSourceStageBenchmark")
|
||||
implicit val materializer = ActorMaterializer()
|
||||
|
||||
private val bytes: Array[Byte] = Array.emptyByteArray
|
||||
|
||||
private var os: OutputStream = _
|
||||
private var done: Future[Done] = _
|
||||
|
||||
@Benchmark
|
||||
@OperationsPerInvocation(WritesPerBench)
|
||||
def consumeWrites(): Unit = {
|
||||
val (os, done) = StreamConverters.asOutputStream()
|
||||
.toMat(Sink.ignore)(Keep.both)
|
||||
.run()
|
||||
new Thread(new Runnable {
|
||||
def run(): Unit = {
|
||||
var counter = 0
|
||||
while (counter > WritesPerBench) {
|
||||
os.write(bytes)
|
||||
counter += 1
|
||||
}
|
||||
os.close()
|
||||
}
|
||||
}).start()
|
||||
Await.result(done, 30.seconds)
|
||||
}
|
||||
|
||||
@TearDown
|
||||
def shutdown(): Unit = {
|
||||
Await.result(system.terminate(), 5.seconds)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -68,22 +68,17 @@ class OutputStreamSourceSpec extends StreamSpec(UnboundedMailboxConfig) {
|
|||
probe.expectComplete()
|
||||
}
|
||||
|
||||
"block flush call until send all buffer to downstream" in assertAllStagesStopped {
|
||||
val (outputStream, probe) = StreamConverters.asOutputStream().toMat(TestSink.probe[ByteString])(Keep.both).run
|
||||
val s = probe.expectSubscription()
|
||||
|
||||
outputStream.write(bytesArray)
|
||||
val f = Future(outputStream.flush())
|
||||
|
||||
expectTimeout(f, timeout)
|
||||
probe.expectNoMsg(Zero)
|
||||
|
||||
s.request(1)
|
||||
expectSuccess(f, ())
|
||||
probe.expectNext(byteString)
|
||||
|
||||
outputStream.close()
|
||||
probe.expectComplete()
|
||||
// https://github.com/akka/akka/issues/25983
|
||||
"not truncate the stream on close" in assertAllStagesStopped {
|
||||
for (_ ← 1 to 10) {
|
||||
val (outputStream, result) =
|
||||
StreamConverters.asOutputStream()
|
||||
.toMat(Sink.fold[ByteString, ByteString](ByteString.empty)(_ ++ _))(Keep.both)
|
||||
.run
|
||||
outputStream.write(bytesArray)
|
||||
outputStream.close()
|
||||
result.futureValue should be(ByteString(bytesArray))
|
||||
}
|
||||
}
|
||||
|
||||
"not block flushes when buffer is empty" in assertAllStagesStopped {
|
||||
|
|
@ -134,34 +129,17 @@ class OutputStreamSourceSpec extends StreamSpec(UnboundedMailboxConfig) {
|
|||
the[Exception] thrownBy outputStream.write(bytesArray) shouldBe a[IOException]
|
||||
}
|
||||
|
||||
"use dedicated default-blocking-io-dispatcher by default" in assertAllStagesStopped {
|
||||
val sys = ActorSystem("dispatcher-testing", UnboundedMailboxConfig)
|
||||
val materializer = ActorMaterializer()(sys)
|
||||
|
||||
try {
|
||||
StreamConverters.asOutputStream().runWith(TestSink.probe[ByteString])(materializer)
|
||||
materializer.asInstanceOf[PhasedFusingActorMaterializer].supervisor.tell(StreamSupervisor.GetChildren, testActor)
|
||||
val ref = expectMsgType[Children].children.find(_.path.toString contains "outputStreamSource").get
|
||||
assertDispatcher(ref, "akka.stream.default-blocking-io-dispatcher")
|
||||
} finally shutdown(sys)
|
||||
|
||||
}
|
||||
|
||||
"throw IOException when writing to the stream after the subscriber has cancelled the reactive stream" in assertAllStagesStopped {
|
||||
val sourceProbe = TestProbe()
|
||||
val (outputStream, probe) = TestSourceStage(new OutputStreamSourceStage(timeout), sourceProbe)
|
||||
val (outputStream, sink) = StreamConverters.asOutputStream()
|
||||
.toMat(TestSink.probe[ByteString])(Keep.both).run
|
||||
|
||||
val s = probe.expectSubscription()
|
||||
val s = sink.expectSubscription()
|
||||
|
||||
outputStream.write(bytesArray)
|
||||
s.request(1)
|
||||
sourceProbe.expectMsg(GraphStageMessages.Pull)
|
||||
|
||||
probe.expectNext(byteString)
|
||||
|
||||
sink.expectNext(byteString)
|
||||
s.cancel()
|
||||
sourceProbe.expectMsg(GraphStageMessages.DownstreamFinish)
|
||||
|
||||
awaitAssert {
|
||||
the[Exception] thrownBy outputStream.write(bytesArray) shouldBe a[IOException]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,7 @@
|
|||
# OutputStreamSourceStage #25983
|
||||
ProblemFilters.exclude[IncompatibleMethTypeProblem]("akka.stream.impl.io.OutputStreamAdapter.this")
|
||||
ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.io.OutputStreamSourceStage$Flush$")
|
||||
ProblemFilters.exclude[DirectMissingMethodProblem]("akka.stream.impl.io.OutputStreamAdapter.*")
|
||||
ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.io.OutputStreamSourceStage$DownstreamStatus")
|
||||
ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.io.OutputStreamSourceStage$Ok$")
|
||||
ProblemFilters.exclude[MissingClassProblem]("akka.stream.impl.io.OutputStreamSourceStage$Canceled$")
|
||||
|
|
@ -5,30 +5,23 @@
|
|||
package akka.stream.impl.io
|
||||
|
||||
import java.io.{ IOException, OutputStream }
|
||||
import java.util.concurrent.atomic.AtomicReference
|
||||
import java.util.concurrent.{ BlockingQueue, LinkedBlockingQueue }
|
||||
import java.util.concurrent.{ Semaphore, TimeUnit }
|
||||
|
||||
import akka.stream.Attributes.InputBuffer
|
||||
import akka.stream.impl.Stages.DefaultAttributes
|
||||
import akka.stream.impl.io.OutputStreamSourceStage._
|
||||
import akka.stream.stage._
|
||||
import akka.stream.{ ActorMaterializerHelper, Attributes, Outlet, SourceShape }
|
||||
import akka.stream.{ Attributes, Outlet, SourceShape }
|
||||
import akka.util.ByteString
|
||||
|
||||
import scala.concurrent.Await
|
||||
import scala.concurrent.duration.FiniteDuration
|
||||
import scala.concurrent.{ Await, ExecutionContext, Future, Promise }
|
||||
import scala.util.control.NonFatal
|
||||
import scala.util.{ Failure, Success, Try }
|
||||
|
||||
private[stream] object OutputStreamSourceStage {
|
||||
sealed trait AdapterToStageMessage
|
||||
case object Flush extends AdapterToStageMessage
|
||||
case class Send(data: ByteString) extends AdapterToStageMessage
|
||||
case object Close extends AdapterToStageMessage
|
||||
|
||||
sealed trait DownstreamStatus
|
||||
case object Ok extends DownstreamStatus
|
||||
case object Canceled extends DownstreamStatus
|
||||
|
||||
}
|
||||
|
||||
final private[stream] class OutputStreamSourceStage(writeTimeout: FiniteDuration) extends GraphStageWithMaterializedValue[SourceShape[ByteString], OutputStream] {
|
||||
|
|
@ -41,160 +34,57 @@ final private[stream] class OutputStreamSourceStage(writeTimeout: FiniteDuration
|
|||
|
||||
require(maxBuffer > 0, "Buffer size must be greater than 0")
|
||||
|
||||
val dataQueue = new LinkedBlockingQueue[ByteString](maxBuffer)
|
||||
val downstreamStatus = new AtomicReference[DownstreamStatus](Ok)
|
||||
// Semaphore counting the number of elements we are ready to accept.
|
||||
// Initially we are ready to accept 'maxBuffer' elements, which will be buffered
|
||||
// by 'emit' if there is no demand yet.
|
||||
// Semaphore permits are taken out of the pool when inserting data into the
|
||||
// OutputStream, and new permits are released when downstream signals demand.
|
||||
val semaphore = new Semaphore(maxBuffer, /* fair =*/ true)
|
||||
|
||||
final class OutputStreamSourceLogic extends GraphStageLogic(shape) {
|
||||
|
||||
var flush: Option[Promise[Unit]] = None
|
||||
var close: Option[Promise[Unit]] = None
|
||||
|
||||
private var dispatcher: ExecutionContext = null // set in preStart
|
||||
private val blockingThreadRef: AtomicReference[Thread] = new AtomicReference() // for postStop interrupt
|
||||
|
||||
private val downstreamCallback: AsyncCallback[Try[ByteString]] =
|
||||
getAsyncCallback {
|
||||
case Success(elem) ⇒ onPush(elem)
|
||||
case Failure(ex) ⇒ failStage(ex)
|
||||
}
|
||||
|
||||
private val upstreamCallback: AsyncCallback[(AdapterToStageMessage, Promise[Unit])] =
|
||||
val upstreamCallback: AsyncCallback[AdapterToStageMessage] =
|
||||
getAsyncCallback(onAsyncMessage)
|
||||
|
||||
def wakeUp(msg: AdapterToStageMessage): Future[Unit] = {
|
||||
val p = Promise[Unit]()
|
||||
upstreamCallback.invoke((msg, p))
|
||||
p.future
|
||||
}
|
||||
|
||||
private def onAsyncMessage(event: (AdapterToStageMessage, Promise[Unit])): Unit =
|
||||
event._1 match {
|
||||
case Flush ⇒
|
||||
flush = Some(event._2)
|
||||
sendResponseIfNeed()
|
||||
private def onAsyncMessage(event: AdapterToStageMessage): Unit = {
|
||||
event match {
|
||||
case Send(data) ⇒
|
||||
emit(out, data, () ⇒ semaphore.release())
|
||||
case Close ⇒
|
||||
close = Some(event._2)
|
||||
sendResponseIfNeed()
|
||||
completeStage()
|
||||
}
|
||||
|
||||
private def unblockUpstream(): Boolean =
|
||||
flush match {
|
||||
case Some(p) ⇒
|
||||
p.complete(Success(()))
|
||||
flush = None
|
||||
true
|
||||
case None ⇒ close match {
|
||||
case Some(p) ⇒
|
||||
downstreamStatus.set(Canceled)
|
||||
p.complete(Success(()))
|
||||
close = None
|
||||
completeStage()
|
||||
true
|
||||
case None ⇒ false
|
||||
}
|
||||
}
|
||||
|
||||
private def sendResponseIfNeed(): Unit =
|
||||
if (downstreamStatus.get() == Canceled || dataQueue.isEmpty) unblockUpstream()
|
||||
|
||||
private def onPush(data: ByteString): Unit =
|
||||
if (downstreamStatus.get() == Ok) {
|
||||
push(out, data)
|
||||
sendResponseIfNeed()
|
||||
}
|
||||
|
||||
override def preStart(): Unit = {
|
||||
// this stage is running on the blocking IO dispatcher by default, but we also want to schedule futures
|
||||
// that are blocking, so we need to look it up
|
||||
val actorMat = ActorMaterializerHelper.downcast(materializer)
|
||||
dispatcher = actorMat.system.dispatchers.lookup(actorMat.settings.blockingIoDispatcher)
|
||||
}
|
||||
|
||||
setHandler(out, new OutHandler {
|
||||
override def onPull(): Unit = {
|
||||
implicit val ec = dispatcher
|
||||
Future {
|
||||
val currentThread = Thread.currentThread()
|
||||
// keep track of the thread for postStop interrupt
|
||||
blockingThreadRef.compareAndSet(null, currentThread)
|
||||
try {
|
||||
dataQueue.take()
|
||||
} catch {
|
||||
case _: InterruptedException ⇒
|
||||
Thread.interrupted()
|
||||
ByteString.empty
|
||||
} finally {
|
||||
blockingThreadRef.compareAndSet(currentThread, null);
|
||||
}
|
||||
}.onComplete(downstreamCallback.invoke)
|
||||
}
|
||||
})
|
||||
|
||||
override def postStop(): Unit = {
|
||||
//assuming there can be no further in messages
|
||||
downstreamStatus.set(Canceled)
|
||||
dataQueue.clear()
|
||||
// if blocked reading, make sure the take() completes
|
||||
dataQueue.put(ByteString.empty)
|
||||
// interrupt any pending blocking take
|
||||
val blockingThread = blockingThreadRef.get()
|
||||
if (blockingThread != null)
|
||||
blockingThread.interrupt()
|
||||
super.postStop()
|
||||
}
|
||||
}
|
||||
|
||||
val logic = new OutputStreamSourceLogic
|
||||
(logic, new OutputStreamAdapter(dataQueue, downstreamStatus, logic.wakeUp, writeTimeout))
|
||||
(logic, new OutputStreamAdapter(semaphore, logic.upstreamCallback, writeTimeout))
|
||||
}
|
||||
}
|
||||
|
||||
private[akka] class OutputStreamAdapter(
|
||||
dataQueue: BlockingQueue[ByteString],
|
||||
downstreamStatus: AtomicReference[DownstreamStatus],
|
||||
sendToStage: (AdapterToStageMessage) ⇒ Future[Unit],
|
||||
writeTimeout: FiniteDuration)
|
||||
unfulfilledDemand: Semaphore,
|
||||
sendToStage: AsyncCallback[AdapterToStageMessage],
|
||||
writeTimeout: FiniteDuration)
|
||||
extends OutputStream {
|
||||
|
||||
var isActive = true
|
||||
var isPublisherAlive = true
|
||||
def publisherClosedException = new IOException("Reactive stream is terminated, no writes are possible")
|
||||
|
||||
@scala.throws(classOf[IOException])
|
||||
private[this] def send(sendAction: () ⇒ Unit): Unit = {
|
||||
if (isActive) {
|
||||
if (isPublisherAlive) sendAction()
|
||||
else throw publisherClosedException
|
||||
} else throw new IOException("OutputStream is closed")
|
||||
private[this] def sendData(data: ByteString): Unit = {
|
||||
if (!unfulfilledDemand.tryAcquire(writeTimeout.toMillis, TimeUnit.MILLISECONDS)) {
|
||||
throw new IOException("Timed out trying to write data to stream")
|
||||
}
|
||||
|
||||
try {
|
||||
Await.result(sendToStage.invokeWithFeedback(Send(data)), writeTimeout)
|
||||
} catch {
|
||||
case NonFatal(e) ⇒ throw new IOException(e)
|
||||
}
|
||||
}
|
||||
|
||||
@scala.throws(classOf[IOException])
|
||||
private[this] def sendData(data: ByteString): Unit =
|
||||
send(() ⇒ {
|
||||
try {
|
||||
dataQueue.put(data)
|
||||
} catch { case NonFatal(ex) ⇒ throw new IOException(ex) }
|
||||
if (downstreamStatus.get() == Canceled) {
|
||||
isPublisherAlive = false
|
||||
throw publisherClosedException
|
||||
}
|
||||
})
|
||||
|
||||
@scala.throws(classOf[IOException])
|
||||
private[this] def sendMessage(message: AdapterToStageMessage, handleCancelled: Boolean = true) =
|
||||
send(() ⇒
|
||||
try {
|
||||
Await.ready(sendToStage(message), writeTimeout)
|
||||
if (downstreamStatus.get() == Canceled && handleCancelled) {
|
||||
//Publisher considered to be terminated at earliest convenience to minimize messages sending back and forth
|
||||
isPublisherAlive = false
|
||||
throw publisherClosedException
|
||||
}
|
||||
} catch {
|
||||
case e: IOException ⇒ throw e
|
||||
case NonFatal(e) ⇒ throw new IOException(e)
|
||||
})
|
||||
|
||||
@scala.throws(classOf[IOException])
|
||||
override def write(b: Int): Unit = {
|
||||
sendData(ByteString(b))
|
||||
|
|
@ -206,11 +96,18 @@ private[akka] class OutputStreamAdapter(
|
|||
}
|
||||
|
||||
@scala.throws(classOf[IOException])
|
||||
override def flush(): Unit = sendMessage(Flush)
|
||||
override def flush(): Unit =
|
||||
// Flushing does nothing: at best we could guarantee that our own buffer
|
||||
// is empty, but that doesn't mean the element has been accepted downstream,
|
||||
// so there is little value in that.
|
||||
()
|
||||
|
||||
@scala.throws(classOf[IOException])
|
||||
override def close(): Unit = {
|
||||
sendMessage(Close, handleCancelled = false)
|
||||
isActive = false
|
||||
try {
|
||||
Await.result(sendToStage.invokeWithFeedback(Close), writeTimeout)
|
||||
} catch {
|
||||
case NonFatal(e) ⇒ throw new IOException(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue