diff --git a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala index 38689642c0..f801242014 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Stages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Stages.scala @@ -153,10 +153,6 @@ private[stream] object Stages { override def create(attr: Attributes): Stage[T, T] = fusing.Log(name, extract, loggingAdapter, supervision(attr)) } - final case class Recover[In, Out >: In](pf: PartialFunction[Throwable, Out], attributes: Attributes = recover) extends SymbolicStage[In, Out] { - override def create(attr: Attributes): Stage[In, Out] = fusing.Recover(pf) - } - final case class Grouped[T](n: Int, attributes: Attributes = grouped) extends SymbolicStage[T, immutable.Seq[T]] { require(n > 0, "n must be greater than 0") override def create(attr: Attributes): Stage[T, immutable.Seq[T]] = fusing.Grouped(n) diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala index 858d2c4985..e16fae74a2 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala @@ -194,29 +194,48 @@ private[stream] final case class Collect[In, Out](pf: PartialFunction[In, Out]) /** * INTERNAL API */ -private[akka] final case class Recover[T](pf: PartialFunction[Throwable, T]) extends PushPullStage[T, T] { - import Collect.NotApplied - var recovered: Option[T] = None +private[akka] final case class Recover[T](pf: PartialFunction[Throwable, T]) extends GraphStage[FlowShape[T, T]] { + val in = Inlet[T]("Recover.in") + val out = Outlet[T]("Recover.out") + override val shape: FlowShape[T, T] = FlowShape(in, out) - override def onPush(elem: T, ctx: Context[T]): SyncDirective = { - ctx.push(elem) - } + override protected val initialAttributes: Attributes = DefaultAttributes.recover - override def onPull(ctx: Context[T]): SyncDirective = - recovered match { - case Some(value) ⇒ ctx.pushAndFinish(value) - case None ⇒ ctx.pull() + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler { + import Collect.NotApplied + + var recovered: Option[T] = None + + override def onPush(): Unit = { + push(out, grab(in)) } - override def onUpstreamFailure(t: Throwable, ctx: Context[T]): TerminationDirective = { - pf.applyOrElse(t, NotApplied) match { - case NotApplied ⇒ ctx.fail(t) - case result: T @unchecked ⇒ - recovered = Some(result) - ctx.absorbTermination() + override def onPull(): Unit = { + recovered match { + case Some(elem) ⇒ { + push(out, elem) + completeStage() + } + case None ⇒ pull(in) + } } - } + override def onUpstreamFailure(ex: Throwable): Unit = { + pf.applyOrElse(ex, NotApplied) match { + case NotApplied ⇒ failStage(ex) + case result: T @unchecked ⇒ { + if (isAvailable(out)) { + push(out, result) + completeStage() + } else { + recovered = Some(result) + } + } + } + } + + setHandlers(in, out, this) + } } /** diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala index 4538b813a2..5c6adafeb5 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Flow.scala @@ -425,7 +425,7 @@ trait FlowOps[+Out, +Mat] { * '''Cancels when''' downstream cancels * */ - def recover[T >: Out](pf: PartialFunction[Throwable, T]): Repr[T] = andThen(Recover(pf)) + def recover[T >: Out](pf: PartialFunction[Throwable, T]): Repr[T] = via(Recover(pf)) /** * RecoverWith allows to switch to alternative Source on flow failure. It will stay in effect after