diff --git a/akka-http-core/src/main/scala/akka/http/impl/util/StreamUtils.scala b/akka-http-core/src/main/scala/akka/http/impl/util/StreamUtils.scala index f27570c71a..629c5c3ba5 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/util/StreamUtils.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/util/StreamUtils.scala @@ -63,17 +63,25 @@ private[http] object StreamUtils { def captureTermination[T, Mat](source: Source[T, Mat]): (Source[T, Mat], Future[Unit]) = { val promise = Promise[Unit]() - val transformer = new PushStage[T, T] { - def onPush(element: T, ctx: Context[T]) = ctx.push(element) - override def onUpstreamFailure(cause: Throwable, ctx: Context[T]) = { - promise.failure(cause) - ctx.fail(cause) - } - override def postStop(): Unit = { - promise.trySuccess(()) + val transformer = new SimpleLinearGraphStage[T] { + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler { + override def onPush(): Unit = push(out, grab(in)) + + override def onPull(): Unit = pull(in) + + override def onUpstreamFailure(ex: Throwable): Unit = { + promise.failure(ex) + failStage(ex) + } + + override def postStop(): Unit = { + promise.trySuccess(()) + } + + setHandlers(in, out, this) } } - source.transform(() ⇒ transformer) -> promise.future + source.via(transformer) -> promise.future } def sliceBytesTransformer(start: Long, length: Long): Flow[ByteString, ByteString, NotUsed] = {