diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala index 179d1c83ea..91a02beb81 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala @@ -78,7 +78,7 @@ private[http] object OutgoingConnectionBlueprint { val terminationFanout = b.add(Broadcast[HttpResponse](2)) val terminationMerge = b.add(TerminationMerge) - val logger = b.add(Flow[ByteString].transform(() ⇒ errorHandling((t: Throwable) ⇒ log.error(t, "Outgoing request stream error"))).named("errorLogger")) + val logger = b.add(ErrorHandling[ByteString]((t: Throwable) ⇒ log.error(t, "Outgoing request stream error")).named("errorLogger")) val wrapTls = b.add(Flow[ByteString].map(SendBytes)) terminationMerge.out ~> requestRendering ~> logger ~> wrapTls diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala index 868b6be588..caff5e35ca 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala @@ -248,7 +248,7 @@ private[http] object HttpServerBluePrint { Flow[ResponseRenderingContext] .via(responseRendererFactory.renderer.named("renderer")) - .via(Flow[ResponseRenderingOutput].transform(() ⇒ errorHandling(errorHandler)).named("errorLogger")) + .via(ErrorHandling[ResponseRenderingOutput](errorHandler).named("errorLogger")) } class RequestTimeoutSupport(initialTimeout: FiniteDuration) diff --git a/akka-http-core/src/main/scala/akka/http/impl/util/package.scala b/akka-http-core/src/main/scala/akka/http/impl/util/package.scala index 6edebededb..76f108d695 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/util/package.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/util/package.scala @@ -5,6 +5,7 @@ package akka.http.impl import akka.NotUsed +import akka.stream.{ Attributes, Outlet, Inlet, FlowShape } import language.implicitConversions import java.nio.charset.Charset @@ -88,15 +89,6 @@ package object util { } } - private[http] def errorHandling[T](handler: Throwable ⇒ Unit): PushStage[T, T] = - new PushStage[T, T] { - override def onPush(element: T, ctx: Context[T]): SyncDirective = ctx.push(element) - override def onUpstreamFailure(cause: Throwable, ctx: Context[T]): TerminationDirective = { - handler(cause) - super.onUpstreamFailure(cause, ctx) - } - } - private[http] def humanReadableByteCount(bytes: Long, si: Boolean): String = { val unit = if (si) 1000 else 1024 if (bytes >= unit) { @@ -110,9 +102,26 @@ package object util { package util { import akka.http.scaladsl.model.{ ContentType, HttpEntity } + import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage import akka.stream.{ Attributes, Outlet, Inlet, FlowShape } import scala.concurrent.duration.FiniteDuration + private[http] final case class ErrorHandling[T](handler: Throwable ⇒ Unit) extends SimpleLinearGraphStage[T] { + override def createLogic(attr: Attributes) = + new GraphStageLogic(shape) with InHandler with OutHandler { + override def onPush(): Unit = push(out, grab(in)) + + override def onUpstreamFailure(ex: Throwable): Unit = { + handler(ex) + super.onUpstreamFailure(ex) + } + + override def onPull(): Unit = pull(in) + + setHandlers(in, out, this) + } + } + private[http] class ToStrict(timeout: FiniteDuration, contentType: ContentType) extends GraphStage[FlowShape[ByteString, HttpEntity.Strict]] {