diff --git a/akka-stream-tests/src/test/java/akka/stream/io/OutputStreamSinkTest.java b/akka-stream-tests/src/test/java/akka/stream/io/OutputStreamSinkTest.java index 45ff7ef1fe..af9be0391a 100644 --- a/akka-stream-tests/src/test/java/akka/stream/io/OutputStreamSinkTest.java +++ b/akka-stream-tests/src/test/java/akka/stream/io/OutputStreamSinkTest.java @@ -29,7 +29,7 @@ public class OutputStreamSinkTest extends StreamTest { } @ClassRule - public static AkkaJUnitActorSystemResource actorSystemResource = new AkkaJUnitActorSystemResource("OutputStreamSink", + public static AkkaJUnitActorSystemResource actorSystemResource = new AkkaJUnitActorSystemResource("OutputStreamSinkTest", Utils.UnboundedMailboxConfig()); @Test public void mustSignalFailureViaIoResult() throws Exception { @@ -44,7 +44,7 @@ public class OutputStreamSinkTest extends StreamTest { } }; final CompletionStage resultFuture = Source.single(ByteString.fromString("123456")).runWith(StreamConverters.fromOutputStream(() -> os), materializer); - final IOResult result = resultFuture.toCompletableFuture().get(300, TimeUnit.MILLISECONDS); + final IOResult result = resultFuture.toCompletableFuture().get(3000, TimeUnit.MILLISECONDS); assertFalse(result.wasSuccessful()); assertTrue(result.getError().getMessage().equals("Can't accept more data.")); diff --git a/akka-stream-tests/src/test/java/akka/stream/io/OutputStreamSourceTest.java b/akka-stream-tests/src/test/java/akka/stream/io/OutputStreamSourceTest.java index 85a84d5e37..f2d0847077 100644 --- a/akka-stream-tests/src/test/java/akka/stream/io/OutputStreamSourceTest.java +++ b/akka-stream-tests/src/test/java/akka/stream/io/OutputStreamSourceTest.java @@ -29,11 +29,11 @@ public class OutputStreamSourceTest extends StreamTest { } @ClassRule - public static AkkaJUnitActorSystemResource actorSystemResource = new AkkaJUnitActorSystemResource("OutputStreamSource", + public static AkkaJUnitActorSystemResource actorSystemResource = new AkkaJUnitActorSystemResource("OutputStreamSourceTest2", Utils.UnboundedMailboxConfig()); @Test public void mustSendEventsViaOutputStream() throws Exception { - final FiniteDuration timeout = FiniteDuration.create(300, TimeUnit.MILLISECONDS); + final FiniteDuration timeout = FiniteDuration.create(3000, TimeUnit.MILLISECONDS); final JavaTestKit probe = new JavaTestKit(system); final Source source = StreamConverters.asOutputStream(timeout); @@ -45,6 +45,8 @@ public class OutputStreamSourceTest extends StreamTest { })).run(materializer); s.write("a".getBytes()); + + assertEquals(ByteString.fromString("a"), probe.receiveOne(timeout)); s.close(); diff --git a/akka-stream-tests/src/test/scala/akka/stream/io/OutputStreamSourceSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/io/OutputStreamSourceSpec.scala index a6c11c7f03..56c5c2dd9d 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/io/OutputStreamSourceSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/io/OutputStreamSourceSpec.scala @@ -31,7 +31,7 @@ class OutputStreamSourceSpec extends AkkaSpec(UnboundedMailboxConfig) { val settings = ActorMaterializerSettings(system).withDispatcher("akka.actor.default-dispatcher") implicit val materializer = ActorMaterializer(settings) - val timeout = 300.milliseconds + val timeout = 3.seconds val bytesArray = Array.fill[Byte](3)(Random.nextInt(1024).asInstanceOf[Byte]) val byteString = ByteString(bytesArray) @@ -41,6 +41,16 @@ class OutputStreamSourceSpec extends AkkaSpec(UnboundedMailboxConfig) { def expectSuccess[T](f: Future[T], value: T) = Await.result(f, remainingOrDefault) should be(value) + def assertNoBlockedThreads(): Unit = { + def threadsBlocked = + ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).toSeq + .filter(t ⇒ t.getThreadName.startsWith("OutputStreamSourceSpec") && + t.getLockName != null && + t.getLockName.startsWith("java.util.concurrent.locks.AbstractQueuedSynchronizer")) + + awaitAssert(threadsBlocked should ===(Seq()), 3.seconds) + } + "OutputStreamSource" must { "read bytes from OutputStream" in assertAllStagesStopped { val (outputStream, probe) = StreamConverters.asOutputStream().toMat(TestSink.probe[ByteString])(Keep.both).run @@ -156,11 +166,11 @@ class OutputStreamSourceSpec extends AkkaSpec(UnboundedMailboxConfig) { .withAttributes(inputBuffer(0, 0)) .runWith(Sink.head) /* - With Sink.head we test the code path in which the source - itself throws an exception when being materialized. If - Sink.ignore is used, the same exception is thrown by - Materializer. - */ + With Sink.head we test the code path in which the source + itself throws an exception when being materialized. If + Sink.ignore is used, the same exception is thrown by + Materializer. + */ } } @@ -175,13 +185,22 @@ class OutputStreamSourceSpec extends AkkaSpec(UnboundedMailboxConfig) { sub.request(1) sub.cancel() - def threadsBlocked = - ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).toSeq - .filter(t ⇒ t.getThreadName.startsWith("OutputStreamSourceSpec") && - t.getLockName != null && - t.getLockName.startsWith("java.util.concurrent.locks.AbstractQueuedSynchronizer")) + assertNoBlockedThreads() + } - awaitAssert(threadsBlocked should ===(Seq()), 3.seconds) + "not leave blocked threads when materializer shutdown" in { + val materializer2 = ActorMaterializer(settings) + val (outputStream, probe) = StreamConverters.asOutputStream(timeout) + .toMat(TestSink.probe[ByteString])(Keep.both).run()(materializer2) + + val sub = probe.expectSubscription() + + // triggers a blocking read on the queue + // and then shutdown the materializer before we got anything + sub.request(1) + materializer2.shutdown() + + assertNoBlockedThreads() } } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/OutputStreamSourceStage.scala b/akka-stream/src/main/scala/akka/stream/impl/io/OutputStreamSourceStage.scala index 7888f51227..15b0a03d1b 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/OutputStreamSourceStage.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/OutputStreamSourceStage.scala @@ -6,18 +6,21 @@ package akka.stream.impl.io import java.io.{ IOException, OutputStream } import java.util.concurrent.atomic.AtomicReference import java.util.concurrent.{ BlockingQueue, LinkedBlockingQueue } - import akka.stream.{ Outlet, SourceShape, Attributes } import akka.stream.Attributes.InputBuffer import akka.stream.impl.Stages.DefaultAttributes import akka.stream.impl.io.OutputStreamSourceStage._ import akka.stream.stage._ import akka.util.ByteString - import scala.concurrent.duration.FiniteDuration import scala.concurrent.{ Await, Future, Promise } import scala.util.control.NonFatal import scala.util.{ Failure, Success, Try } +import akka.stream.ActorAttributes +import akka.stream.impl.Stages.DefaultAttributes.IODispatcher +import akka.stream.ActorAttributes.Dispatcher +import scala.concurrent.ExecutionContext +import akka.stream.ActorMaterializer private[stream] object OutputStreamSourceStage { sealed trait AdapterToStageMessage @@ -40,6 +43,9 @@ final private[stream] class OutputStreamSourceStage(writeTimeout: FiniteDuration override def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, OutputStream) = { val maxBuffer = inheritedAttributes.getAttribute(classOf[InputBuffer], InputBuffer(16, 16)).max + + val dispatcherId = inheritedAttributes.get[Dispatcher](IODispatcher).dispatcher + require(maxBuffer > 0, "Buffer size must be greater than 0") val dataQueue = new LinkedBlockingQueue[ByteString](maxBuffer) @@ -49,6 +55,9 @@ final private[stream] class OutputStreamSourceStage(writeTimeout: FiniteDuration var flush: Option[Promise[Unit]] = None var close: Option[Promise[Unit]] = None + private var dispatcher: ExecutionContext = null // set in preStart + private var blockingThread: Thread = null // for postStop interrupt + private val downstreamCallback: AsyncCallback[Try[ByteString]] = getAsyncCallback { case Success(elem) ⇒ onPush(elem) @@ -102,6 +111,11 @@ final private[stream] class OutputStreamSourceStage(writeTimeout: FiniteDuration sendResponseIfNeed() } + override def preStart(): Unit = { + dispatcher = ActorMaterializer.downcast(materializer).system.dispatchers.lookup(dispatcherId) + super.preStart() + } + setHandler(out, new OutHandler { override def onDownstreamFinish(): Unit = { //assuming there can be no further in messages @@ -112,10 +126,29 @@ final private[stream] class OutputStreamSourceStage(writeTimeout: FiniteDuration completeStage() } override def onPull(): Unit = { - implicit val ex = interpreter.materializer.executionContext - Future(dataQueue.take()).onComplete(downstreamCallback.invoke) + implicit val ec = dispatcher + Future { + // keep track of the thread for postStop interrupt + blockingThread = Thread.currentThread() + try { + dataQueue.take() + } catch { + case _: InterruptedException ⇒ + Thread.interrupted() + ByteString() + } finally { + blockingThread = null + } + }.onComplete(downstreamCallback.invoke) } }) + + override def postStop(): Unit = { + // interrupt any pending blocking take + if (blockingThread != null) + blockingThread.interrupt() + super.postStop() + } } (logic, new OutputStreamAdapter(dataQueue, downstreamStatus, logic.wakeUp, writeTimeout)) }