diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/UnfoldResourceAsyncSourceSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/UnfoldResourceAsyncSourceSpec.scala index b1d8ed5a6b..41889e3fc6 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/UnfoldResourceAsyncSourceSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/UnfoldResourceAsyncSourceSpec.scala @@ -3,7 +3,7 @@ */ package akka.stream.scaladsl -import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.atomic.{ AtomicBoolean, AtomicInteger } import akka.Done import akka.actor.ActorSystem @@ -245,6 +245,64 @@ class UnfoldResourceAsyncSourceSpec extends StreamSpec(UnboundedMailboxConfig) { startCount.get should ===(2) } + "fail stream when restarting and close throws" in assertAllStagesStopped { + val out = TestSubscriber.probe[Int]() + Source.unfoldResourceAsync[Int, Iterator[Int]]( + () ⇒ Future.successful(List(1, 2, 3).iterator), + reader ⇒ throw TE("read-error"), + _ ⇒ throw new TE("close-error") + ).withAttributes(ActorAttributes.supervisionStrategy(Supervision.restartingDecider)) + .runWith(Sink.fromSubscriber(out)) + + out.request(1) + out.expectError().getMessage should ===("close-error") + } + + "fail stream when restarting and close returns failed future" in assertAllStagesStopped { + val out = TestSubscriber.probe[Int]() + Source.unfoldResourceAsync[Int, Iterator[Int]]( + () ⇒ Future.successful(List(1, 2, 3).iterator), + reader ⇒ throw TE("read-error"), + _ ⇒ Future.failed(new TE("close-error")) + ).withAttributes(ActorAttributes.supervisionStrategy(Supervision.restartingDecider)) + .runWith(Sink.fromSubscriber(out)) + + out.request(1) + out.expectError().getMessage should ===("close-error") + } + + "fail stream when restarting and start throws" in assertAllStagesStopped { + val startCounter = new AtomicInteger(0) + val out = TestSubscriber.probe[Int]() + Source.unfoldResourceAsync[Int, Iterator[Int]]( + () ⇒ + if (startCounter.incrementAndGet() < 2) Future.successful(List(1, 2, 3).iterator) + else throw TE("start-error"), + reader ⇒ throw TE("read-error"), + _ ⇒ Future.successful(Done) + ).withAttributes(ActorAttributes.supervisionStrategy(Supervision.restartingDecider)) + .runWith(Sink.fromSubscriber(out)) + + out.request(1) + out.expectError().getMessage should ===("start-error") + } + + "fail stream when restarting and start returns failed future" in assertAllStagesStopped { + val startCounter = new AtomicInteger(0) + val out = TestSubscriber.probe[Int]() + Source.unfoldResourceAsync[Int, Iterator[Int]]( + () ⇒ + if (startCounter.incrementAndGet() < 2) Future.successful(List(1, 2, 3).iterator) + else Future.failed(TE("start-error")), + reader ⇒ throw TE("read-error"), + _ ⇒ Future.successful(Done) + ).withAttributes(ActorAttributes.supervisionStrategy(Supervision.restartingDecider)) + .runWith(Sink.fromSubscriber(out)) + + out.request(1) + out.expectError().getMessage should ===("start-error") + } + "use dedicated blocking-io-dispatcher by default" in assertAllStagesStopped { val sys = ActorSystem("dispatcher-testing", UnboundedMailboxConfig) val materializer = ActorMaterializer()(sys) diff --git a/akka-stream/src/main/scala/akka/stream/impl/UnfoldResourceSourceAsync.scala b/akka-stream/src/main/scala/akka/stream/impl/UnfoldResourceSourceAsync.scala index 4c9aee91a8..dfc8f0e03a 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/UnfoldResourceSourceAsync.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/UnfoldResourceSourceAsync.scala @@ -5,14 +5,13 @@ package akka.stream.impl import akka.Done import akka.annotation.InternalApi -import akka.dispatch.ExecutionContexts import akka.stream.ActorAttributes.SupervisionStrategy import akka.stream._ import akka.stream.impl.Stages.DefaultAttributes import akka.stream.stage._ -import scala.concurrent.{ Future, Promise } -import scala.util.Try +import scala.concurrent.Future +import scala.util.{ Failure, Success, Try } import scala.util.control.NonFatal /** @@ -28,82 +27,85 @@ import scala.util.control.NonFatal def createLogic(inheritedAttributes: Attributes) = new GraphStageLogic(shape) with OutHandler { lazy val decider = inheritedAttributes.mandatoryAttribute[SupervisionStrategy].decider - var resource = Promise[S]() - var open = false - implicit val context = ExecutionContexts.sameThreadExecutionContext + private implicit def ec = ActorMaterializerHelper.downcast(materializer).system.dispatcher + private var state: Option[S] = None - setHandler(out, this) + private val createdCallback = getAsyncCallback[Try[S]] { + case Success(resource) ⇒ + state = Some(resource) + if (isAvailable(out)) onPull() + case Failure(t) ⇒ failStage(t) + }.invoke _ - override def preStart(): Unit = createStream(false) - - private def createStream(withPull: Boolean): Unit = { - val createdCallback = getAsyncCallback[Try[S]] { - case scala.util.Success(res) ⇒ - open = true - resource.success(res) - if (withPull) onPull() - case scala.util.Failure(t) ⇒ failStage(t) - } - try { - create().onComplete(createdCallback.invoke) - } catch { - case NonFatal(ex) ⇒ failStage(ex) - } - } - - private def onResourceReady(f: (S) ⇒ Unit): Unit = resource.future.foreach(f) - - val errorHandler: PartialFunction[Throwable, Unit] = { + private val errorHandler: PartialFunction[Throwable, Unit] = { case NonFatal(ex) ⇒ decider(ex) match { case Supervision.Stop ⇒ - onResourceReady(close(_)) failStage(ex) - case Supervision.Restart ⇒ restartState() + case Supervision.Restart ⇒ restartResource() case Supervision.Resume ⇒ onPull() } } - val readCallback = getAsyncCallback[Try[Option[T]]] { - case scala.util.Success(data) ⇒ data match { + private val readCallback = getAsyncCallback[Try[Option[T]]] { + case Success(data) ⇒ data match { case Some(d) ⇒ push(out, d) - case None ⇒ closeStage() + case None ⇒ + // end of resource reached, lets close it + state match { + case Some(resource) ⇒ + close(resource).onComplete(getAsyncCallback[Try[Done]] { + case Success(Done) ⇒ completeStage() + case Failure(ex) ⇒ failStage(ex) + }.invoke) + state = None + + case None ⇒ + // cannot happen, but for good measure + throw new IllegalStateException("Reached end of data but there is no open resource") + } } - case scala.util.Failure(t) ⇒ errorHandler(t) + case Failure(t) ⇒ errorHandler(t) }.invoke _ - final override def onPull(): Unit = - onResourceReady { resource ⇒ - try { readData(resource).onComplete(readCallback) } catch errorHandler + override def preStart(): Unit = createResource() + + override def onPull(): Unit = + state match { + case Some(resource) ⇒ + try { + readData(resource).onComplete(readCallback) + } catch errorHandler + case None ⇒ + // we got a pull but there is no open resource, we are either + // currently creating/restarting then the read will be triggered when creating the + // resource completes, or shutting down and then the pull does not matter anyway } - override def onDownstreamFinish(): Unit = closeStage() - - private def closeAndThen(f: () ⇒ Unit): Unit = { - setKeepGoing(true) - val closedCallback = getAsyncCallback[Try[Done]] { - case scala.util.Success(_) ⇒ - open = false - f() - case scala.util.Failure(t) ⇒ - open = false - failStage(t) - } - - onResourceReady(res ⇒ - try { close(res).onComplete(closedCallback.invoke) } catch { - case NonFatal(ex) ⇒ failStage(ex) - }) - } - private def restartState(): Unit = closeAndThen(() ⇒ { - resource = Promise[S]() - createStream(true) - }) - private def closeStage(): Unit = closeAndThen(completeStage) - override def postStop(): Unit = { - if (open) closeStage() + state.foreach(r ⇒ close(r)) } + private def restartResource(): Unit = { + state match { + case Some(resource) ⇒ + // wait for the resource to close before restarting + close(resource).onComplete(getAsyncCallback[Try[Done]] { + case Success(Done) ⇒ + createResource() + case Failure(ex) ⇒ failStage(ex) + }.invoke) + state = None + case None ⇒ + createResource() + } + } + + private def createResource(): Unit = { + create().onComplete(createdCallback) + } + + setHandler(out, this) + } override def toString = "UnfoldResourceSourceAsync"