diff --git a/akka-stream-tests/src/test/scala/akka/stream/io/TlsSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/io/TlsSpec.scala index dbf077b96d..2fddc74a26 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/io/TlsSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/io/TlsSpec.scala @@ -1,27 +1,31 @@ package akka.stream.io -import java.security.{ KeyStore, SecureRandom } -import javax.net.ssl.{ TrustManagerFactory, KeyManagerFactory, SSLContext } -import akka.stream.{ Graph, BidiShape, ActorFlowMaterializer } -import akka.stream.scaladsl._ -import akka.stream.io._ -import akka.stream.testkit.{ TestUtils, AkkaSpec } -import akka.util.ByteString -import scala.concurrent.Await -import scala.concurrent.duration._ -import scala.collection.immutable -import scala.util.Random -import akka.stream.stage.AsyncStage -import akka.stream.stage.AsyncContext -import java.util.concurrent.TimeoutException -import akka.actor.ActorSystem -import javax.net.ssl.SSLSession -import akka.pattern.{ after ⇒ later } -import scala.concurrent.Future import java.net.InetSocketAddress +import java.security.KeyStore +import java.security.SecureRandom +import java.util.concurrent.TimeoutException + +import scala.collection.immutable +import scala.concurrent.Await +import scala.concurrent.Future +import scala.concurrent.duration._ +import scala.util.Random + +import akka.actor.ActorSystem +import akka.pattern.{ after ⇒ later } +import akka.stream.ActorFlowMaterializer +import akka.stream.scaladsl._ +import akka.stream.scaladsl.FlowGraph.Implicits._ +import akka.stream.stage._ +import akka.stream.testkit.AkkaSpec +import akka.stream.testkit.StreamTestKit.PublisherProbe +import akka.stream.testkit.StreamTestKit.assertAllStagesStopped import akka.testkit.EventFilter -import akka.stream.stage.PushStage -import akka.stream.stage.Context +import akka.util.ByteString +import javax.net.ssl.KeyManagerFactory +import javax.net.ssl.SSLContext +import javax.net.ssl.SSLSession +import javax.net.ssl.TrustManagerFactory object TlsSpec { @@ -107,7 +111,7 @@ class TlsSpec extends AkkaSpec("akka.loglevel=INFO\nakka.actor.debug.receive=off import FlowGraph.Implicits._ - "StreamTLS" must { + "SslTls" must { val sslContext = initSslContext() @@ -359,7 +363,7 @@ class TlsSpec extends AkkaSpec("akka.loglevel=INFO\nakka.actor.debug.receive=off commPattern ← communicationPatterns scenario ← scenarios } { - s"work in mode ${commPattern.name} while sending ${scenario.name}" in { + s"work in mode ${commPattern.name} while sending ${scenario.name}" in assertAllStagesStopped { val onRHS = debug.via(scenario.flow) val f = Source(scenario.inputs) @@ -391,6 +395,40 @@ class TlsSpec extends AkkaSpec("akka.loglevel=INFO\nakka.actor.debug.receive=off } } + "reliably cancel subscriptions when TransportIn fails early" in assertAllStagesStopped { + val ex = new Exception("hello") + val (sub, out1, out2) = + FlowGraph.closed(Source.subscriber[SslTlsOutbound], Sink.head[ByteString], Sink.head[SslTlsInbound])((_, _, _)) { implicit b ⇒ + (s, o1, o2) ⇒ + val tls = b.add(clientTls(EagerClose)) + s ~> tls.in1; tls.out1 ~> o1 + o2 <~ tls.out2; tls.in2 <~ Source.failed(ex) + }.run() + the[Exception] thrownBy Await.result(out1, 1.second) should be(ex) + the[Exception] thrownBy Await.result(out2, 1.second) should be(ex) + Thread.sleep(500) + val pub = PublisherProbe() + pub.subscribe(sub) + pub.expectSubscription().expectCancellation() + } + + "reliably cancel subscriptions when UserIn fails early" in assertAllStagesStopped { + val ex = new Exception("hello") + val (sub, out1, out2) = + FlowGraph.closed(Source.subscriber[ByteString], Sink.head[ByteString], Sink.head[SslTlsInbound])((_, _, _)) { implicit b ⇒ + (s, o1, o2) ⇒ + val tls = b.add(clientTls(EagerClose)) + Source.failed[SslTlsOutbound](ex) ~> tls.in1; tls.out1 ~> o1 + o2 <~ tls.out2; tls.in2 <~ s + }.run() + the[Exception] thrownBy Await.result(out1, 1.second) should be(ex) + the[Exception] thrownBy Await.result(out2, 1.second) should be(ex) + Thread.sleep(500) + val pub = PublisherProbe() + pub.subscribe(sub) + pub.expectSubscription().expectCancellation() + } + } "A SslTlsPlacebo" must { diff --git a/akka-stream/src/main/scala/akka/stream/impl/ActorFlowMaterializerImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/ActorFlowMaterializerImpl.scala index 34ed5dee75..c83527908a 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/ActorFlowMaterializerImpl.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/ActorFlowMaterializerImpl.scala @@ -92,7 +92,7 @@ private[akka] case class ActorFlowMaterializerImpl( case tls: TlsModule ⇒ val es = effectiveSettings(effectiveAttributes) - val props = SslTlsCipherActor.props(es, tls.sslContext, tls.firstSession, tracing = true, tls.role, tls.closing) + val props = SslTlsCipherActor.props(es, tls.sslContext, tls.firstSession, tracing = false, tls.role, tls.closing) val impl = actorOf(props, stageName(effectiveAttributes), es.dispatcher) def factory(id: Int) = new ActorPublisher[Any](impl) { override val wakeUpMsg = FanOut.SubstreamSubscribePending(id) diff --git a/akka-stream/src/main/scala/akka/stream/impl/FanIn.scala b/akka-stream/src/main/scala/akka/stream/impl/FanIn.scala index 1e3cae43d7..19a13fac44 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/FanIn.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/FanIn.scala @@ -186,7 +186,7 @@ private[akka] object FanIn { } def inputsAvailableFor(id: Int) = new TransferState { - override def isCompleted: Boolean = depleted(id) + override def isCompleted: Boolean = depleted(id) || cancelled(id) override def isReady: Boolean = pending(id) } diff --git a/akka-stream/src/main/scala/akka/stream/impl/FanOut.scala b/akka-stream/src/main/scala/akka/stream/impl/FanOut.scala index 81de929af2..baf2dc32ef 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/FanOut.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/FanOut.scala @@ -69,6 +69,8 @@ private[akka] object FanOut { def isCancelled(output: Int): Boolean = cancelled(output) + def isErrored(output: Int): Boolean = errored(output) + def complete(): Unit = if (!bunchCancelled) { bunchCancelled = true @@ -187,7 +189,7 @@ private[akka] object FanOut { def onCancel(output: Int): Unit = () def demandAvailableFor(id: Int) = new TransferState { - override def isCompleted: Boolean = cancelled(id) || completed(id) + override def isCompleted: Boolean = cancelled(id) || completed(id) || errored(id) override def isReady: Boolean = pending(id) } diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/SslTls.scala b/akka-stream/src/main/scala/akka/stream/impl/io/SslTls.scala index c01ab68d1e..2d7043cc6b 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/SslTls.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/SslTls.scala @@ -69,8 +69,8 @@ private[akka] class SslTlsCipherActor(settings: ActorFlowMaterializerSettings, s * block. */ class ChoppingBlock(idx: Int, name: String) extends TransferState { - override def isReady: Boolean = buffer.nonEmpty - override def isCompleted: Boolean = false + override def isReady: Boolean = buffer.nonEmpty || inputBunch.isPending(idx) || inputBunch.isDepleted(idx) + override def isCompleted: Boolean = inputBunch.isCancelled(idx) private var buffer = ByteString.empty @@ -203,38 +203,76 @@ private[akka] class SslTlsCipherActor(settings: ActorFlowMaterializerSettings, s */ var lastHandshakeStatus: HandshakeStatus = _ + var corkUser = true val engineNeedsWrap = new TransferState { def isReady = lastHandshakeStatus == NEED_WRAP - def isCompleted = false + def isCompleted = engine.isOutboundDone } val engineInboundOpen = new TransferState { - def isReady = !engine.isInboundDone() - def isCompleted = false + def isReady = true + def isCompleted = engine.isInboundDone } - var corkUser = true - val userHasData = new TransferState { - private val user = inputBunch.inputsOrCompleteAvailableFor(UserIn) || userInChoppingBlock - def isReady = !corkUser && user.isReady && lastHandshakeStatus != NEED_UNWRAP - def isCompleted = false + def isReady = !corkUser && userInChoppingBlock.isReady && lastHandshakeStatus != NEED_UNWRAP + def isCompleted = inputBunch.isCancelled(UserIn) || inputBunch.isDepleted(UserIn) } - val transportHasData = inputBunch.inputsOrCompleteAvailableFor(TransportIn) || transportInChoppingBlock val userOutCancelled = new TransferState { def isReady = outputBunch.isCancelled(UserOut) - def isCompleted = inputBunch.isDepleted(TransportIn) + def isCompleted = engine.isInboundDone || outputBunch.isErrored(UserOut) } // bidirectional case val outbound = (userHasData || engineNeedsWrap) && outputBunch.demandAvailableFor(TransportOut) - val inbound = (transportHasData || userOutCancelled) && outputBunch.demandOrCancelAvailableFor(UserOut) + val inbound = (transportInChoppingBlock && outputBunch.demandAvailableFor(UserOut)) || userOutCancelled // half-closed val outboundHalfClosed = engineNeedsWrap && outputBunch.demandAvailableFor(TransportOut) - val inboundHalfClosed = transportHasData && engineInboundOpen + val inboundHalfClosed = transportInChoppingBlock && engineInboundOpen + + val bidirectional = TransferPhase(outbound || inbound) { () ⇒ + if (tracing) log.debug("bidirectional") + val continue = doInbound(isOutboundClosed = false, inbound) + if (continue) { + if (tracing) log.debug("bidirectional continue") + doOutbound(isInboundClosed = false) + } + } + + val flushingOutbound = TransferPhase(outboundHalfClosed) { () ⇒ + if (tracing) log.debug("flushingOutbound") + try doWrap() + catch { case ex: SSLException ⇒ nextPhase(completedPhase) } + } + + val awaitingClose = TransferPhase(inputBunch.inputsAvailableFor(TransportIn) && engineInboundOpen) { () ⇒ + if (tracing) log.debug("awaitingClose") + transportInChoppingBlock.chopInto(transportInBuffer) + try doUnwrap(ignoreOutput = true) + catch { case ex: SSLException ⇒ nextPhase(completedPhase) } + } + + val outboundClosed = TransferPhase(outboundHalfClosed || inbound) { () ⇒ + if (tracing) log.debug("outboundClosed") + val continue = doInbound(isOutboundClosed = true, inbound) + if (continue && outboundHalfClosed.isReady) { + if (tracing) log.debug("outboundClosed continue") + try doWrap() + catch { case ex: SSLException ⇒ nextPhase(completedPhase) } + } + } + + val inboundClosed = TransferPhase(outbound || inboundHalfClosed) { () ⇒ + if (tracing) log.debug("inboundClosed") + val continue = doInbound(isOutboundClosed = false, inboundHalfClosed) + if (continue) { + if (tracing) log.debug("inboundClosed continue") + doOutbound(isInboundClosed = true) + } + } def completeOrFlush(): Unit = if (engine.isOutboundDone()) nextPhase(completedPhase) @@ -282,6 +320,7 @@ private[akka] class SslTlsCipherActor(settings: ActorFlowMaterializerSettings, s } nextPhase(outboundClosed) } else if (outputBunch.isCancelled(TransportOut)) { + if (tracing) log.debug("shutting down because TransportOut is cancelled") nextPhase(completedPhase) } else if (outbound.isReady) { if (userHasData.isReady) userInChoppingBlock.chopInto(userInBuffer) @@ -293,47 +332,6 @@ private[akka] class SslTlsCipherActor(settings: ActorFlowMaterializerSettings, s } } - val bidirectional = TransferPhase(outbound || inbound) { () ⇒ - if (tracing) log.debug("bidirectional") - val continue = doInbound(isOutboundClosed = false, inbound) - if (continue) { - if (tracing) log.debug("bidirectional continue") - doOutbound(isInboundClosed = false) - } - } - - val flushingOutbound = TransferPhase(outboundHalfClosed) { () ⇒ - if (tracing) log.debug("flushingOutbound") - try doWrap() - catch { case ex: SSLException ⇒ nextPhase(completedPhase) } - } - - val awaitingClose = TransferPhase(inputBunch.inputsAvailableFor(TransportIn)) { () ⇒ - if (tracing) log.debug("awaitingClose") - transportInChoppingBlock.chopInto(transportInBuffer) - try doUnwrap(ignoreOutput = true) - catch { case ex: SSLException ⇒ nextPhase(completedPhase) } - } - - val outboundClosed = TransferPhase(outboundHalfClosed || inbound) { () ⇒ - if (tracing) log.debug("outboundClosed") - val continue = doInbound(isOutboundClosed = true, inbound) - if (continue && outboundHalfClosed.isReady) { - if (tracing) log.debug("outboundClosed continue") - try doWrap() - catch { case ex: SSLException ⇒ nextPhase(completedPhase) } - } - } - - val inboundClosed = TransferPhase(outbound || inboundHalfClosed) { () ⇒ - if (tracing) log.debug("inboundClosed") - val continue = doInbound(isOutboundClosed = false, inboundHalfClosed) - if (continue) { - if (tracing) log.debug("inboundClosed continue") - doOutbound(isInboundClosed = true) - } - } - def flushToTransport(): Unit = { if (tracing) log.debug("flushToTransport") transportOutBuffer.flip() @@ -427,15 +425,19 @@ private[akka] class SslTlsCipherActor(settings: ActorFlowMaterializerSettings, s override def receive = inputBunch.subreceive.orElse[Any, Unit](outputBunch.subreceive) - nextPhase(bidirectional) + initialPhase(2, bidirectional) protected def fail(e: Throwable): Unit = { - // FIXME: escalate to supervisor if (tracing) log.debug("fail {} due to: {}", self, e.getMessage) inputBunch.cancel() outputBunch.error(TransportOut, e) outputBunch.error(UserOut, e) - context.stop(self) + pump() + } + + override def postStop(): Unit = { + if (tracing) log.debug("postStop") + super.postStop() } override protected def pumpFailed(e: Throwable): Unit = fail(e)