Merge pull request #19488 from akka/wip-19398-forward-port-RK

multiple HTTP logic fixes
This commit is contained in:
drewhk 2016-01-19 12:29:17 +01:00
commit 4de3f63b93
9 changed files with 397 additions and 121 deletions

View file

@ -167,7 +167,6 @@ private[http] object OutgoingConnectionBlueprint {
// each connection uses a single (private) response parser instance for all its responses
// which builds a cache of all header instances seen on that connection
val parser = rootParser.createShallowCopy()
var methodBypassCompleted = false
var waitingForMethod = true
setHandler(methodBypassInput, new InHandler {
@ -179,7 +178,6 @@ private[http] object OutgoingConnectionBlueprint {
}
override def onUpstreamFinish(): Unit =
if (waitingForMethod) completeStage()
else methodBypassCompleted = true
})
setHandler(dataInput, new InHandler {
@ -201,17 +199,16 @@ private[http] object OutgoingConnectionBlueprint {
setHandler(out, eagerTerminateOutput)
val getNextMethod = ()
if (methodBypassCompleted) completeStage()
else {
pull(methodBypassInput)
waitingForMethod = true
}
val getNextMethod = () {
waitingForMethod = true
if (isClosed(methodBypassInput)) completeStage()
else pull(methodBypassInput)
}
val getNextData = () {
waitingForMethod = false
if (!isClosed(dataInput)) pull(dataInput)
else completeStage()
if (isClosed(dataInput)) completeStage()
else pull(dataInput)
}
@tailrec def drainParser(current: ResponseOutput, b: ListBuffer[ResponseOutput] = ListBuffer.empty): Unit = {
@ -219,13 +216,10 @@ private[http] object OutgoingConnectionBlueprint {
if (output.nonEmpty) emit(out, output, andThen)
else andThen()
current match {
case NeedNextRequestMethod
e(b.result(), getNextMethod)
case StreamEnd
e(b.result(), () completeStage())
case NeedMoreData
e(b.result(), getNextData)
case x drainParser(parser.onPull(), b += x)
case NeedNextRequestMethod e(b.result(), getNextMethod)
case StreamEnd e(b.result(), () completeStage())
case NeedMoreData e(b.result(), getNextData)
case x drainParser(parser.onPull(), b += x)
}
}

View file

@ -62,13 +62,13 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser
def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
new GraphStageLogic(shape) {
private[this] var closeMode: CloseMode = DontClose // signals what to do after the current response
private[this] def close: Boolean = closeMode != DontClose
private[this] def closeIf(cond: Boolean): Unit =
if (cond) closeMode = CloseConnection
var closeMode: CloseMode = DontClose // signals what to do after the current response
def close: Boolean = closeMode != DontClose
def closeIf(cond: Boolean): Unit = if (cond) closeMode = CloseConnection
var transferring = false
setHandler(in, new InHandler {
def onPush(): Unit =
override def onPush(): Unit =
render(grab(in)) match {
case Strict(outElement)
push(out, outElement)
@ -76,23 +76,36 @@ private[http] class HttpResponseRendererFactory(serverHeader: Option[headers.Ser
case Streamed(outStream) transfer(outStream)
}
override def onUpstreamFinish(): Unit = closeMode = CloseConnection
override def onUpstreamFinish(): Unit =
if (transferring) closeMode = CloseConnection
else completeStage()
})
val waitForDemandHandler = new OutHandler {
def onPull(): Unit = if (close) completeStage() else pull(in)
def onPull(): Unit = pull(in)
}
setHandler(out, waitForDemandHandler)
def transfer(outStream: Source[ResponseRenderingOutput, Any]): Unit = {
transferring = true
val sinkIn = new SubSinkInlet[ResponseRenderingOutput]("RenderingSink")
sinkIn.setHandler(new InHandler {
def onPush(): Unit = push(out, sinkIn.grab())
override def onUpstreamFinish(): Unit = if (close) completeStage() else setHandler(out, waitForDemandHandler)
override def onPush(): Unit = push(out, sinkIn.grab())
override def onUpstreamFinish(): Unit =
if (close) completeStage()
else {
transferring = false
setHandler(out, waitForDemandHandler)
if (isAvailable(out)) pull(in)
}
})
setHandler(out, new OutHandler {
def onPull(): Unit = sinkIn.pull()
override def onPull(): Unit = sinkIn.pull()
override def onDownstreamFinish(): Unit = {
completeStage()
sinkIn.cancel()
}
})
sinkIn.pull()
Source.fromGraph(outStream).runWith(sinkIn.sink)(interpreter.subFusingMaterializer)
outStream.runWith(sinkIn.sink)(interpreter.subFusingMaterializer)
}
def render(ctx: ResponseRenderingContext): StrictOrStreamed = {

View file

@ -371,6 +371,11 @@ private[http] object HttpServerBluePrint {
def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) {
import akka.http.impl.engine.rendering.ResponseRenderingOutput._
/*
* These handlers are in charge until a switch command comes in, then they
* are replaced.
*/
setHandler(fromHttp, new InHandler {
override def onPush(): Unit =
grab(fromHttp) match {
@ -381,21 +386,22 @@ private[http] object HttpServerBluePrint {
cancel(fromHttp)
switchToWebsocket(handlerFlow)
}
override def onUpstreamFinish(): Unit = complete(toNet)
override def onUpstreamFailure(ex: Throwable): Unit = fail(toNet, ex)
})
setHandler(toNet, new OutHandler {
override def onPull(): Unit = pull(fromHttp)
override def onDownstreamFinish(): Unit = completeStage()
})
setHandler(fromNet, new InHandler {
def onPush(): Unit = push(toHttp, grab(fromNet))
// propagate error but don't close stage yet to prevent fromHttp/fromWs being cancelled
// too eagerly
override def onPush(): Unit = push(toHttp, grab(fromNet))
override def onUpstreamFinish(): Unit = complete(toHttp)
override def onUpstreamFailure(ex: Throwable): Unit = fail(toHttp, ex)
})
setHandler(toHttp, new OutHandler {
override def onPull(): Unit = pull(fromNet)
override def onDownstreamFinish(): Unit = ()
override def onDownstreamFinish(): Unit = cancel(fromNet)
})
private var activeTimers = 0
@ -427,36 +433,60 @@ private[http] object HttpServerBluePrint {
case Right(messageHandler)
Websocket.stack(serverSide = true, maskingRandomFactory = settings.websocketRandomFactory, log = log).join(messageHandler)
}
val sinkIn = new SubSinkInlet[ByteString]("FrameSink")
val sourceOut = new SubSourceOutlet[ByteString]("FrameSource")
val timeoutKey = SubscriptionTimeout(() {
sourceOut.timeout(timeout)
if (sourceOut.isClosed) completeStage()
})
addTimeout(timeoutKey)
sinkIn.setHandler(new InHandler {
override def onPush(): Unit = push(toNet, sinkIn.grab())
})
setHandler(toNet, new OutHandler {
override def onPull(): Unit = sinkIn.pull()
override def onUpstreamFinish(): Unit = complete(toNet)
override def onUpstreamFailure(ex: Throwable): Unit = fail(toNet, ex)
})
setHandler(fromNet, new InHandler {
override def onPush(): Unit = sourceOut.push(grab(fromNet).bytes)
})
sourceOut.setHandler(new OutHandler {
override def onPull(): Unit = {
if (!hasBeenPulled(fromNet)) pull(fromNet)
cancelTimeout(timeoutKey)
sourceOut.setHandler(new OutHandler {
override def onPull(): Unit = if (!hasBeenPulled(fromNet)) pull(fromNet)
})
}
})
if (isClosed(fromNet)) {
setHandler(toNet, new OutHandler {
override def onPull(): Unit = sinkIn.pull()
override def onDownstreamFinish(): Unit = {
completeStage()
sinkIn.cancel()
}
})
Websocket.framing.join(frameHandler).runWith(Source.empty, sinkIn.sink)(subFusingMaterializer)
} else {
val sourceOut = new SubSourceOutlet[ByteString]("FrameSource")
Websocket.framing.join(frameHandler).runWith(sourceOut.source, sinkIn.sink)(subFusingMaterializer)
val timeoutKey = SubscriptionTimeout(() {
sourceOut.timeout(timeout)
if (sourceOut.isClosed) completeStage()
})
addTimeout(timeoutKey)
setHandler(toNet, new OutHandler {
override def onPull(): Unit = sinkIn.pull()
override def onDownstreamFinish(): Unit = {
completeStage()
sinkIn.cancel()
sourceOut.complete()
}
})
setHandler(fromNet, new InHandler {
override def onPush(): Unit = sourceOut.push(grab(fromNet).bytes)
override def onUpstreamFinish(): Unit = sourceOut.complete()
override def onUpstreamFailure(ex: Throwable): Unit = sourceOut.fail(ex)
})
sourceOut.setHandler(new OutHandler {
override def onPull(): Unit = {
if (!hasBeenPulled(fromNet)) pull(fromNet)
cancelTimeout(timeoutKey)
sourceOut.setHandler(new OutHandler {
override def onPull(): Unit = if (!hasBeenPulled(fromNet)) pull(fromNet)
override def onDownstreamFinish(): Unit = cancel(fromNet)
})
}
override def onDownstreamFinish(): Unit = cancel(fromNet)
})
Websocket.framing.join(frameHandler).runWith(sourceOut.source, sinkIn.sink)(subFusingMaterializer)
}
}
}
}

View file

@ -195,14 +195,19 @@ private[http] object Websocket {
def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) {
passAlong(bypass, out, doFinish = true, doFail = true)
passAlong(user, out, doFinish = false, doFail = false)
class PassAlong[T <: AnyRef](from: Inlet[T]) extends InHandler with (() Unit) {
override def apply(): Unit = tryPull(from)
override def onPush(): Unit = emit(out, grab(from), this)
override def onUpstreamFinish(): Unit =
if (isClosed(bypass) && isClosed(user)) completeStage()
}
setHandler(bypass, new PassAlong(bypass))
setHandler(user, new PassAlong(user))
passAlong(tick, out, doFinish = false, doFail = false)
setHandler(out, eagerTerminateOutput)
override def preStart(): Unit = {
super.preStart()
pull(bypass)
pull(user)
pull(tick)

View file

@ -11,16 +11,19 @@ import akka.stream.scaladsl._
import akka.stream.testkit.AkkaSpec
import akka.http.scaladsl.{ Http, TestUtils }
import akka.http.scaladsl.model._
import akka.stream.testkit.Utils
import org.scalatest.concurrent.ScalaFutures
class HighLevelOutgoingConnectionSpec extends AkkaSpec {
class HighLevelOutgoingConnectionSpec extends AkkaSpec with ScalaFutures {
implicit val materializer = ActorMaterializer()
implicit val patience = PatienceConfig(1.second)
"The connection-level client implementation" should {
"be able to handle 100 pipelined requests across one connection" in {
"be able to handle 100 pipelined requests across one connection" in Utils.assertAllStagesStopped {
val (_, serverHostName, serverPort) = TestUtils.temporaryServerHostnameAndPort()
Http().bindAndHandleSync(r HttpResponse(entity = r.uri.toString.reverse.takeWhile(Character.isDigit).reverse),
val binding = Http().bindAndHandleSync(r HttpResponse(entity = r.uri.toString.reverse.takeWhile(Character.isDigit).reverse),
serverHostName, serverPort)
val N = 100
@ -32,13 +35,14 @@ class HighLevelOutgoingConnectionSpec extends AkkaSpec {
.map { r val s = r.data.utf8String; log.debug(s); s.toInt }
.runFold(0)(_ + _)
Await.result(result, 10.seconds) shouldEqual N * (N + 1) / 2
result.futureValue(PatienceConfig(10.seconds)) shouldEqual N * (N + 1) / 2
binding.futureValue.unbind()
}
"be able to handle 100 pipelined requests across 4 connections (client-flow is reusable)" in {
"be able to handle 100 pipelined requests across 4 connections (client-flow is reusable)" in Utils.assertAllStagesStopped {
val (_, serverHostName, serverPort) = TestUtils.temporaryServerHostnameAndPort()
Http().bindAndHandleSync(r HttpResponse(entity = r.uri.toString.reverse.takeWhile(Character.isDigit).reverse),
val binding = Http().bindAndHandleSync(r HttpResponse(entity = r.uri.toString.reverse.takeWhile(Character.isDigit).reverse),
serverHostName, serverPort)
val connFlow = Http().outgoingConnection(serverHostName, serverPort)
@ -64,12 +68,14 @@ class HighLevelOutgoingConnectionSpec extends AkkaSpec {
.map { r val s = r.data.utf8String; log.debug(s); s.toInt }
.runFold(0)(_ + _)
Await.result(result, 10.seconds) shouldEqual C * N * (N + 1) / 2
result.futureValue(PatienceConfig(10.seconds)) shouldEqual C * N * (N + 1) / 2
binding.futureValue.unbind()
}
"catch response stream truncation" in {
"catch response stream truncation" in Utils.assertAllStagesStopped {
val (_, serverHostName, serverPort) = TestUtils.temporaryServerHostnameAndPort()
Http().bindAndHandleSync({
val binding = Http().bindAndHandleSync({
case HttpRequest(_, Uri.Path("/b"), _, _, _) HttpResponse(headers = List(headers.Connection("close")))
case _ HttpResponse()
}, serverHostName, serverPort)
@ -81,6 +87,7 @@ class HighLevelOutgoingConnectionSpec extends AkkaSpec {
.runWith(Sink.head)
a[One2OneBidiFlow.OutputTruncationException.type] should be thrownBy Await.result(x, 1.second)
binding.futureValue.unbind()
}
}
}

View file

@ -1,50 +0,0 @@
/**
* Copyright (C) 2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.impl.engine.ws
import akka.stream.testkit.AkkaSpec
import scala.concurrent.Await
import com.typesafe.config.ConfigFactory
import com.typesafe.config.Config
import akka.actor.ActorSystem
import akka.http.scaladsl.model.HttpRequest
import akka.http.scaladsl.model.ws._
import akka.http.scaladsl._
import akka.stream.scaladsl._
import akka.stream._
import scala.concurrent.duration._
import org.scalatest.concurrent.ScalaFutures
import org.scalactic.ConversionCheckedTripleEquals
import akka.stream.testkit.Utils
class BypassRouterSpec extends AkkaSpec("akka.stream.materializer.debug.fuzzing-mode = off") with ScalaFutures with ConversionCheckedTripleEquals {
implicit val patience = PatienceConfig(3.seconds)
import system.dispatcher
implicit val materializer = ActorMaterializer()
"BypassRouter" must {
"work without double pull-ing some ports" in Utils.assertAllStagesStopped {
val bindingFuture = Http().bindAndHandleSync({
case HttpRequest(_, _, headers, _, _)
val upgrade = headers.collectFirst { case u: UpgradeToWebsocket u }.get
upgrade.handleMessages(Flow.apply, None)
}, interface = "localhost", port = 8080)
val binding = Await.result(bindingFuture, 3.seconds)
val N = 100
val (response, count) = Http().singleWebsocketRequest(
WebsocketRequest("ws://127.0.0.1:8080"),
Flow.fromSinkAndSourceMat(
Sink.fold(0)((n, _: Message) n + 1),
Source.repeat(TextMessage("hello")).take(N))(Keep.left))
count.futureValue should ===(N)
binding.unbind()
}
}
}

View file

@ -0,0 +1,174 @@
/**
* Copyright (C) 2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.impl.engine.ws
import scala.concurrent.Await
import scala.concurrent.duration.DurationInt
import org.scalactic.ConversionCheckedTripleEquals
import org.scalatest.concurrent.ScalaFutures
import org.scalatest.time.Span.convertDurationToSpan
import akka.http.scaladsl.Http
import akka.http.scaladsl.model.HttpRequest
import akka.http.scaladsl.model.Uri.apply
import akka.http.scaladsl.model.ws._
import akka.stream._
import akka.stream.scaladsl._
import akka.stream.testkit._
import akka.stream.scaladsl.GraphDSL.Implicits._
import org.scalatest.concurrent.Eventually
import akka.stream.io.SslTlsPlacebo
import java.net.InetSocketAddress
import akka.stream.impl.fusing.GraphStages
import akka.util.ByteString
import akka.http.scaladsl.model.StatusCodes
import akka.stream.testkit.scaladsl.TestSink
import scala.concurrent.Future
class WebsocketIntegrationSpec extends AkkaSpec("akka.stream.materializer.debug.fuzzing-mode=off")
with ScalaFutures with ConversionCheckedTripleEquals with Eventually {
implicit val patience = PatienceConfig(3.seconds)
import system.dispatcher
implicit val materializer = ActorMaterializer()
"A Websocket server" must {
"not reset the connection when no data are flowing" in Utils.assertAllStagesStopped {
val source = TestPublisher.probe[Message]()
val bindingFuture = Http().bindAndHandleSync({
case HttpRequest(_, _, headers, _, _)
val upgrade = headers.collectFirst { case u: UpgradeToWebsocket u }.get
upgrade.handleMessages(Flow.fromSinkAndSource(Sink.ignore, Source.fromPublisher(source)), None)
}, interface = "localhost", port = 0)
val binding = Await.result(bindingFuture, 3.seconds)
val myPort = binding.localAddress.getPort
val (response, sink) = Http().singleWebsocketRequest(
WebsocketRequest("ws://127.0.0.1:" + myPort),
Flow.fromSinkAndSourceMat(TestSink.probe[Message], Source.empty)(Keep.left))
response.futureValue.response.status.isSuccess should ===(true)
sink
.request(10)
.expectNoMsg(500.millis)
source
.sendNext(TextMessage("hello"))
.sendComplete()
sink
.expectNext(TextMessage("hello"))
.expectComplete()
binding.unbind()
}
"not reset the connection when no data are flowing and the connection is closed from the client" in Utils.assertAllStagesStopped {
val source = TestPublisher.probe[Message]()
val bindingFuture = Http().bindAndHandleSync({
case HttpRequest(_, _, headers, _, _)
val upgrade = headers.collectFirst { case u: UpgradeToWebsocket u }.get
upgrade.handleMessages(Flow.fromSinkAndSource(Sink.ignore, Source.fromPublisher(source)), None)
}, interface = "localhost", port = 0)
val binding = Await.result(bindingFuture, 3.seconds)
val myPort = binding.localAddress.getPort
val ((response, breaker), sink) =
Source.empty
.viaMat {
Http().websocketClientLayer(WebsocketRequest("ws://localhost:" + myPort))
.atop(SslTlsPlacebo.forScala)
.joinMat(Flow.fromGraph(GraphStages.breaker[ByteString]).via(
Tcp().outgoingConnection(new InetSocketAddress("localhost", myPort), halfClose = true)))(Keep.both)
}(Keep.right)
.toMat(TestSink.probe[Message])(Keep.both)
.run()
response.futureValue.response.status.isSuccess should ===(true)
sink
.request(10)
.expectNoMsg(1500.millis)
breaker.value.get.get.complete()
source
.sendNext(TextMessage("hello"))
.sendComplete()
sink
.expectNext(TextMessage("hello"))
.expectComplete()
binding.unbind()
}
"echo 100 elements and then shut down without error" in Utils.assertAllStagesStopped {
val bindingFuture = Http().bindAndHandleSync({
case HttpRequest(_, _, headers, _, _)
val upgrade = headers.collectFirst { case u: UpgradeToWebsocket u }.get
upgrade.handleMessages(Flow.apply, None)
}, interface = "localhost", port = 0)
val binding = Await.result(bindingFuture, 3.seconds)
val myPort = binding.localAddress.getPort
val N = 100
val (response, count) = Http().singleWebsocketRequest(
WebsocketRequest("ws://127.0.0.1:" + myPort),
Flow.fromSinkAndSourceMat(
Sink.fold(0)((n, _: Message) n + 1),
Source.repeat(TextMessage("hello")).take(N))(Keep.left))
count.futureValue should ===(N)
binding.unbind()
}
"send back 100 elements and then terminate without error even when not ordinarily closed" in Utils.assertAllStagesStopped {
val N = 100
val handler = Flow.fromGraph(GraphDSL.create() { implicit b
val merge = b.add(Merge[Int](2))
// convert to int so we can connect to merge
val mapMsgToInt = b.add(Flow[Message].map(_ -1))
val mapIntToMsg = b.add(Flow[Int].map(x TextMessage.Strict(s"Sending: $x")))
// source we want to use to send message to the connected websocket sink
val rangeSource = b.add(Source(1 to N))
mapMsgToInt ~> merge // this part of the merge will never provide msgs
rangeSource ~> merge ~> mapIntToMsg
FlowShape(mapMsgToInt.in, mapIntToMsg.out)
})
val bindingFuture = Http().bindAndHandleSync({
case HttpRequest(_, _, headers, _, _)
val upgrade = headers.collectFirst { case u: UpgradeToWebsocket u }.get
upgrade.handleMessages(handler, None)
}, interface = "localhost", port = 0)
val binding = Await.result(bindingFuture, 3.seconds)
val myPort = binding.localAddress.getPort
@volatile var messages = 0
val (breaker, completion) =
Source.maybe
.viaMat {
Http().websocketClientLayer(WebsocketRequest("ws://localhost:" + myPort))
.atop(SslTlsPlacebo.forScala)
// the resource leak of #19398 existed only for severed websocket connections
.atopMat(GraphStages.bidiBreaker[ByteString, ByteString])(Keep.right)
.join(Tcp().outgoingConnection(new InetSocketAddress("localhost", myPort), halfClose = true))
}(Keep.right)
.toMat(Sink.foreach(_ messages += 1))(Keep.both)
.run()
eventually(messages should ===(N))
// breaker should have been fulfilled long ago
breaker.value.get.get.completeAndCancel()
completion.futureValue
binding.unbind()
}
}
}

View file

@ -108,6 +108,107 @@ object GraphStages {
private val _detacher = new Detacher[Any]
def detacher[T]: GraphStage[FlowShape[T, T]] = _detacher.asInstanceOf[GraphStage[FlowShape[T, T]]]
final class Breaker(callback: Breaker.Operation Unit) {
import Breaker._
def complete(): Unit = callback(Complete)
def cancel(): Unit = callback(Cancel)
def fail(ex: Throwable): Unit = callback(Fail(ex))
def completeAndCancel(): Unit = callback(CompleteAndCancel)
def failAndCancel(ex: Throwable): Unit = callback(FailAndCancel(ex))
}
object Breaker extends GraphStageWithMaterializedValue[FlowShape[Any, Any], Future[Breaker]] {
sealed trait Operation
case object Complete extends Operation
case object Cancel extends Operation
case class Fail(ex: Throwable) extends Operation
case object CompleteAndCancel extends Operation
case class FailAndCancel(ex: Throwable) extends Operation
override val initialAttributes = Attributes.name("breaker")
override val shape = FlowShape(Inlet[Any]("breaker.in"), Outlet[Any]("breaker.out"))
override def createLogicAndMaterializedValue(attr: Attributes) = {
val promise = Promise[Breaker]
val logic = new GraphStageLogic(shape) {
passAlong(shape.in, shape.out)
setHandler(shape.out, eagerTerminateOutput)
override def preStart(): Unit = {
pull(shape.in)
promise.success(new Breaker(getAsyncCallback[Operation] {
case Complete complete(shape.out)
case Cancel cancel(shape.in)
case Fail(ex) fail(shape.out, ex)
case CompleteAndCancel completeStage()
case FailAndCancel(ex) failStage(ex)
}.invoke))
}
}
(logic, promise.future)
}
}
def breaker[T]: Graph[FlowShape[T, T], Future[Breaker]] = Breaker.asInstanceOf[Graph[FlowShape[T, T], Future[Breaker]]]
object BidiBreaker extends GraphStageWithMaterializedValue[BidiShape[Any, Any, Any, Any], Future[Breaker]] {
import Breaker._
override val initialAttributes = Attributes.name("breaker")
override val shape = BidiShape(
Inlet[Any]("breaker.in1"), Outlet[Any]("breaker.out1"),
Inlet[Any]("breaker.in2"), Outlet[Any]("breaker.out2"))
override def createLogicAndMaterializedValue(attr: Attributes) = {
val promise = Promise[Breaker]
val logic = new GraphStageLogic(shape) {
setHandler(shape.in1, new InHandler {
override def onPush(): Unit = push(shape.out1, grab(shape.in1))
override def onUpstreamFinish(): Unit = complete(shape.out1)
override def onUpstreamFailure(ex: Throwable): Unit = fail(shape.out1, ex)
})
setHandler(shape.in2, new InHandler {
override def onPush(): Unit = push(shape.out2, grab(shape.in2))
override def onUpstreamFinish(): Unit = complete(shape.out2)
override def onUpstreamFailure(ex: Throwable): Unit = fail(shape.out2, ex)
})
setHandler(shape.out1, new OutHandler {
override def onPull(): Unit = pull(shape.in1)
override def onDownstreamFinish(): Unit = cancel(shape.in1)
})
setHandler(shape.out2, new OutHandler {
override def onPull(): Unit = pull(shape.in2)
override def onDownstreamFinish(): Unit = cancel(shape.in2)
})
override def preStart(): Unit = {
promise.success(new Breaker(getAsyncCallback[Operation] {
case Complete
complete(shape.out1)
complete(shape.out2)
case Cancel
cancel(shape.in1)
cancel(shape.in2)
case Fail(ex)
fail(shape.out1, ex)
fail(shape.out2, ex)
case CompleteAndCancel completeStage()
case FailAndCancel(ex) failStage(ex)
}.invoke))
}
}
(logic, promise.future)
}
}
def bidiBreaker[T1, T2]: Graph[BidiShape[T1, T1, T2, T2], Future[Breaker]] = BidiBreaker.asInstanceOf[Graph[BidiShape[T1, T1, T2, T2], Future[Breaker]]]
private object TickSource {
class TickSourceCancellable(cancelled: AtomicBoolean) extends Cancellable {
private val cancelPromise = Promise[Unit]()

View file

@ -261,11 +261,13 @@ private[stream] object TcpConnectionStage {
// (or half-close is turned off)
if (isClosed(bytesOut) || !role.halfClose) connection ! Close
// We still read, so we only close the write side
else connection ! ConfirmedClose
else if (connection != null) connection ! ConfirmedClose
else completeStage()
}
override def onUpstreamFailure(ex: Throwable): Unit = {
connection ! Abort
if (connection != null) connection ! Abort
else failStage(ex)
}
})