diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala index 2216cb5b43..c0337777bb 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala @@ -371,6 +371,11 @@ private[http] object HttpServerBluePrint { def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) { import akka.http.impl.engine.rendering.ResponseRenderingOutput._ + /* + * These handlers are in charge until a switch command comes in, then they + * are replaced. + */ + setHandler(fromHttp, new InHandler { override def onPush(): Unit = grab(fromHttp) match { @@ -381,21 +386,22 @@ private[http] object HttpServerBluePrint { cancel(fromHttp) switchToWebsocket(handlerFlow) } + override def onUpstreamFinish(): Unit = complete(toNet) + override def onUpstreamFailure(ex: Throwable): Unit = fail(toNet, ex) }) setHandler(toNet, new OutHandler { override def onPull(): Unit = pull(fromHttp) + override def onDownstreamFinish(): Unit = completeStage() }) setHandler(fromNet, new InHandler { - def onPush(): Unit = push(toHttp, grab(fromNet)) - - // propagate error but don't close stage yet to prevent fromHttp/fromWs being cancelled - // too eagerly + override def onPush(): Unit = push(toHttp, grab(fromNet)) + override def onUpstreamFinish(): Unit = complete(toHttp) override def onUpstreamFailure(ex: Throwable): Unit = fail(toHttp, ex) }) setHandler(toHttp, new OutHandler { override def onPull(): Unit = pull(fromNet) - override def onDownstreamFinish(): Unit = () + override def onDownstreamFinish(): Unit = cancel(fromNet) }) private var activeTimers = 0 @@ -438,13 +444,22 @@ private[http] object HttpServerBluePrint { sinkIn.setHandler(new InHandler { override def onPush(): Unit = push(toNet, sinkIn.grab()) + override def onUpstreamFinish(): Unit = complete(toNet) + override def onUpstreamFailure(ex: Throwable): Unit = fail(toNet, ex) }) setHandler(toNet, new OutHandler { override def onPull(): Unit = sinkIn.pull() + override def onDownstreamFinish(): Unit = { + completeStage() + sinkIn.cancel() + sourceOut.complete() + } }) setHandler(fromNet, new InHandler { override def onPush(): Unit = sourceOut.push(grab(fromNet).bytes) + override def onUpstreamFinish(): Unit = sourceOut.complete() + override def onUpstreamFailure(ex: Throwable): Unit = sourceOut.fail(ex) }) sourceOut.setHandler(new OutHandler { override def onPull(): Unit = { @@ -454,6 +469,7 @@ private[http] object HttpServerBluePrint { override def onPull(): Unit = if (!hasBeenPulled(fromNet)) pull(fromNet) }) } + override def onDownstreamFinish(): Unit = cancel(fromNet) }) Websocket.framing.join(frameHandler).runWith(sourceOut.source, sinkIn.sink)(subFusingMaterializer) diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/BypassRouterSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/BypassRouterSpec.scala deleted file mode 100644 index 89be18a5da..0000000000 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/BypassRouterSpec.scala +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright (C) 2015 Typesafe Inc. - */ -package akka.http.impl.engine.ws - -import akka.stream.testkit.AkkaSpec -import scala.concurrent.Await -import com.typesafe.config.ConfigFactory -import com.typesafe.config.Config -import akka.actor.ActorSystem -import akka.http.scaladsl.model.HttpRequest -import akka.http.scaladsl.model.ws._ -import akka.http.scaladsl._ -import akka.stream.scaladsl._ -import akka.stream._ -import scala.concurrent.duration._ -import org.scalatest.concurrent.ScalaFutures -import org.scalactic.ConversionCheckedTripleEquals -import akka.stream.testkit.Utils - -class BypassRouterSpec extends AkkaSpec("akka.stream.materializer.debug.fuzzing-mode = off") with ScalaFutures with ConversionCheckedTripleEquals { - - implicit val patience = PatienceConfig(3.seconds) - import system.dispatcher - implicit val materializer = ActorMaterializer() - - "BypassRouter" must { - - "work without double pull-ing some ports" in Utils.assertAllStagesStopped { - val bindingFuture = Http().bindAndHandleSync({ - case HttpRequest(_, _, headers, _, _) ⇒ - val upgrade = headers.collectFirst { case u: UpgradeToWebsocket ⇒ u }.get - upgrade.handleMessages(Flow.apply, None) - }, interface = "localhost", port = 8080) - val binding = Await.result(bindingFuture, 3.seconds) - - val N = 100 - val (response, count) = Http().singleWebsocketRequest( - WebsocketRequest("ws://127.0.0.1:8080"), - Flow.fromSinkAndSourceMat( - Sink.fold(0)((n, _: Message) ⇒ n + 1), - Source.repeat(TextMessage("hello")).take(N))(Keep.left)) - - count.futureValue should ===(N) - binding.unbind() - } - - } - -} diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketIntegrationSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketIntegrationSpec.scala new file mode 100644 index 0000000000..f389f88c72 --- /dev/null +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/ws/WebsocketIntegrationSpec.scala @@ -0,0 +1,101 @@ +/** + * Copyright (C) 2015 Typesafe Inc. + */ +package akka.http.impl.engine.ws + +import scala.concurrent.Await +import scala.concurrent.duration.DurationInt +import org.scalactic.ConversionCheckedTripleEquals +import org.scalatest.concurrent.ScalaFutures +import org.scalatest.time.Span.convertDurationToSpan +import akka.http.scaladsl.Http +import akka.http.scaladsl.model.HttpRequest +import akka.http.scaladsl.model.Uri.apply +import akka.http.scaladsl.model.ws._ +import akka.stream._ +import akka.stream.scaladsl._ +import akka.stream.testkit._ +import akka.stream.scaladsl.GraphDSL.Implicits._ +import org.scalatest.concurrent.Eventually +import akka.stream.io.SslTlsPlacebo +import java.net.InetSocketAddress +import akka.stream.impl.fusing.GraphStages +import akka.util.ByteString + +class WebsocketIntegrationSpec extends AkkaSpec("akka.stream.materializer.debug.fuzzing-mode=off") + with ScalaFutures with ConversionCheckedTripleEquals with Eventually { + + implicit val patience = PatienceConfig(3.seconds) + import system.dispatcher + implicit val materializer = ActorMaterializer() + + "A Websocket server" must { + + "echo 100 elements and then shut down without error" in Utils.assertAllStagesStopped { + val bindingFuture = Http().bindAndHandleSync({ + case HttpRequest(_, _, headers, _, _) ⇒ + val upgrade = headers.collectFirst { case u: UpgradeToWebsocket ⇒ u }.get + upgrade.handleMessages(Flow.apply, None) + }, interface = "localhost", port = 8080) + val binding = Await.result(bindingFuture, 3.seconds) + + val N = 100 + val (response, count) = Http().singleWebsocketRequest( + WebsocketRequest("ws://127.0.0.1:8080"), + Flow.fromSinkAndSourceMat( + Sink.fold(0)((n, _: Message) ⇒ n + 1), + Source.repeat(TextMessage("hello")).take(N))(Keep.left)) + + count.futureValue should ===(N) + binding.unbind() + } + + "send back 100 elements and then terminate without error even when not ordinarily closed" in Utils.assertAllStagesStopped { + val N = 100 + + val handler = Flow.fromGraph(GraphDSL.create() { implicit b ⇒ + val merge = b.add(Merge[Int](2)) + + // convert to int so we can connect to merge + val mapMsgToInt = b.add(Flow[Message].map(_ ⇒ -1)) + val mapIntToMsg = b.add(Flow[Int].map(x ⇒ TextMessage.Strict(s"Sending: $x"))) + + // source we want to use to send message to the connected websocket sink + val rangeSource = b.add(Source(1 to N)) + + mapMsgToInt ~> merge // this part of the merge will never provide msgs + rangeSource ~> merge ~> mapIntToMsg + + FlowShape(mapMsgToInt.in, mapIntToMsg.out) + }) + + val bindingFuture = Http().bindAndHandleSync({ + case HttpRequest(_, _, headers, _, _) ⇒ + val upgrade = headers.collectFirst { case u: UpgradeToWebsocket ⇒ u }.get + upgrade.handleMessages(handler, None) + }, interface = "localhost", port = 8080) + val binding = Await.result(bindingFuture, 3.seconds) + + @volatile var messages = 0 + val (breaker, completion) = + Source.maybe + .viaMat { + Http().websocketClientLayer(WebsocketRequest("ws://localhost:8080")) + .atop(SslTlsPlacebo.forScala) + // the resource leak of #19398 existed only for severed websocket connections + .atopMat(GraphStages.bidiBreaker[ByteString, ByteString])(Keep.right) + .join(Tcp().outgoingConnection(new InetSocketAddress("localhost", 8080), halfClose = true)) + }(Keep.right) + .toMat(Sink.foreach(_ ⇒ messages += 1))(Keep.both) + .run() + eventually(messages should ===(N)) + // breaker should have been fulfilled long ago + breaker.value.get.get.complete() + completion.futureValue + + binding.unbind() + } + + } + +} diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala index ea70aaf273..359a46c7ef 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/GraphStages.scala @@ -105,6 +105,89 @@ object GraphStages { private val _detacher = new Detacher[Any] def detacher[T]: GraphStage[FlowShape[T, T]] = _detacher.asInstanceOf[GraphStage[FlowShape[T, T]]] + final class Breaker(callback: Breaker.Operation ⇒ Unit) { + import Breaker._ + def complete(): Unit = callback(Complete) + def fail(ex: Throwable): Unit = callback(Fail(ex)) + } + + object Breaker extends GraphStageWithMaterializedValue[FlowShape[Any, Any], Future[Breaker]] { + sealed trait Operation + case object Complete extends Operation + case class Fail(ex: Throwable) extends Operation + + override val initialAttributes = Attributes.name("breaker") + override val shape = FlowShape(Inlet[Any]("breaker.in"), Outlet[Any]("breaker.out")) + + override def createLogicAndMaterializedValue(attr: Attributes) = { + val promise = Promise[Breaker] + + val logic = new GraphStageLogic(shape) { + + passAlong(shape.in, shape.out) + setHandler(shape.out, eagerTerminateOutput) + + override def preStart(): Unit = { + pull(shape.in) + promise.success(new Breaker(getAsyncCallback[Operation] { + case Complete ⇒ completeStage() + case Fail(ex) ⇒ failStage(ex) + }.invoke)) + } + } + + (logic, promise.future) + } + } + + def breaker[T]: Graph[FlowShape[T, T], Future[Breaker]] = Breaker.asInstanceOf[Graph[FlowShape[T, T], Future[Breaker]]] + + object BidiBreaker extends GraphStageWithMaterializedValue[BidiShape[Any, Any, Any, Any], Future[Breaker]] { + import Breaker._ + + override val initialAttributes = Attributes.name("breaker") + override val shape = BidiShape( + Inlet[Any]("breaker.in1"), Outlet[Any]("breaker.out1"), + Inlet[Any]("breaker.in2"), Outlet[Any]("breaker.out2")) + + override def createLogicAndMaterializedValue(attr: Attributes) = { + val promise = Promise[Breaker] + + val logic = new GraphStageLogic(shape) { + + setHandler(shape.in1, new InHandler { + override def onPush(): Unit = push(shape.out1, grab(shape.in1)) + override def onUpstreamFinish(): Unit = complete(shape.out1) + override def onUpstreamFailure(ex: Throwable): Unit = fail(shape.out1, ex) + }) + setHandler(shape.in2, new InHandler { + override def onPush(): Unit = push(shape.out2, grab(shape.in2)) + override def onUpstreamFinish(): Unit = complete(shape.out2) + override def onUpstreamFailure(ex: Throwable): Unit = fail(shape.out2, ex) + }) + setHandler(shape.out1, new OutHandler { + override def onPull(): Unit = pull(shape.in1) + override def onDownstreamFinish(): Unit = cancel(shape.in1) + }) + setHandler(shape.out2, new OutHandler { + override def onPull(): Unit = pull(shape.in2) + override def onDownstreamFinish(): Unit = cancel(shape.in2) + }) + + override def preStart(): Unit = { + promise.success(new Breaker(getAsyncCallback[Operation] { + case Complete ⇒ completeStage() + case Fail(ex) ⇒ failStage(ex) + }.invoke)) + } + } + + (logic, promise.future) + } + } + + def bidiBreaker[T1, T2]: Graph[BidiShape[T1, T1, T2, T2], Future[Breaker]] = BidiBreaker.asInstanceOf[Graph[BidiShape[T1, T1, T2, T2], Future[Breaker]]] + private object TickSource { class TickSourceCancellable(cancelled: AtomicBoolean) extends Cancellable { private val cancelPromise = Promise[Unit]() diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala b/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala index 18b324244a..97b562beae 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala @@ -261,11 +261,13 @@ private[stream] object TcpConnectionStage { // (or half-close is turned off) if (isClosed(bytesOut) || !role.halfClose) connection ! Close // We still read, so we only close the write side - else connection ! ConfirmedClose + else if (connection != null) connection ! ConfirmedClose + else completeStage() } override def onUpstreamFailure(ex: Throwable): Unit = { - connection ! Abort + if (connection != null) connection ! Abort + else failStage(ex) } })