Merge pull request #17289 from akka/wip-SslTls-termination-∂π
refactor SslTlsActor and stop it reliably
This commit is contained in:
commit
7b4a640147
5 changed files with 125 additions and 83 deletions
|
|
@ -1,27 +1,31 @@
|
|||
package akka.stream.io
|
||||
|
||||
import java.security.{ KeyStore, SecureRandom }
|
||||
import javax.net.ssl.{ TrustManagerFactory, KeyManagerFactory, SSLContext }
|
||||
import akka.stream.{ Graph, BidiShape, ActorFlowMaterializer }
|
||||
import akka.stream.scaladsl._
|
||||
import akka.stream.io._
|
||||
import akka.stream.testkit.{ TestUtils, AkkaSpec }
|
||||
import akka.util.ByteString
|
||||
import scala.concurrent.Await
|
||||
import scala.concurrent.duration._
|
||||
import scala.collection.immutable
|
||||
import scala.util.Random
|
||||
import akka.stream.stage.AsyncStage
|
||||
import akka.stream.stage.AsyncContext
|
||||
import java.util.concurrent.TimeoutException
|
||||
import akka.actor.ActorSystem
|
||||
import javax.net.ssl.SSLSession
|
||||
import akka.pattern.{ after ⇒ later }
|
||||
import scala.concurrent.Future
|
||||
import java.net.InetSocketAddress
|
||||
import java.security.KeyStore
|
||||
import java.security.SecureRandom
|
||||
import java.util.concurrent.TimeoutException
|
||||
|
||||
import scala.collection.immutable
|
||||
import scala.concurrent.Await
|
||||
import scala.concurrent.Future
|
||||
import scala.concurrent.duration._
|
||||
import scala.util.Random
|
||||
|
||||
import akka.actor.ActorSystem
|
||||
import akka.pattern.{ after ⇒ later }
|
||||
import akka.stream.ActorFlowMaterializer
|
||||
import akka.stream.scaladsl._
|
||||
import akka.stream.scaladsl.FlowGraph.Implicits._
|
||||
import akka.stream.stage._
|
||||
import akka.stream.testkit.AkkaSpec
|
||||
import akka.stream.testkit.StreamTestKit.PublisherProbe
|
||||
import akka.stream.testkit.StreamTestKit.assertAllStagesStopped
|
||||
import akka.testkit.EventFilter
|
||||
import akka.stream.stage.PushStage
|
||||
import akka.stream.stage.Context
|
||||
import akka.util.ByteString
|
||||
import javax.net.ssl.KeyManagerFactory
|
||||
import javax.net.ssl.SSLContext
|
||||
import javax.net.ssl.SSLSession
|
||||
import javax.net.ssl.TrustManagerFactory
|
||||
|
||||
object TlsSpec {
|
||||
|
||||
|
|
@ -107,7 +111,7 @@ class TlsSpec extends AkkaSpec("akka.loglevel=INFO\nakka.actor.debug.receive=off
|
|||
|
||||
import FlowGraph.Implicits._
|
||||
|
||||
"StreamTLS" must {
|
||||
"SslTls" must {
|
||||
|
||||
val sslContext = initSslContext()
|
||||
|
||||
|
|
@ -359,7 +363,7 @@ class TlsSpec extends AkkaSpec("akka.loglevel=INFO\nakka.actor.debug.receive=off
|
|||
commPattern ← communicationPatterns
|
||||
scenario ← scenarios
|
||||
} {
|
||||
s"work in mode ${commPattern.name} while sending ${scenario.name}" in {
|
||||
s"work in mode ${commPattern.name} while sending ${scenario.name}" in assertAllStagesStopped {
|
||||
val onRHS = debug.via(scenario.flow)
|
||||
val f =
|
||||
Source(scenario.inputs)
|
||||
|
|
@ -391,6 +395,40 @@ class TlsSpec extends AkkaSpec("akka.loglevel=INFO\nakka.actor.debug.receive=off
|
|||
}
|
||||
}
|
||||
|
||||
"reliably cancel subscriptions when TransportIn fails early" in assertAllStagesStopped {
|
||||
val ex = new Exception("hello")
|
||||
val (sub, out1, out2) =
|
||||
FlowGraph.closed(Source.subscriber[SslTlsOutbound], Sink.head[ByteString], Sink.head[SslTlsInbound])((_, _, _)) { implicit b ⇒
|
||||
(s, o1, o2) ⇒
|
||||
val tls = b.add(clientTls(EagerClose))
|
||||
s ~> tls.in1; tls.out1 ~> o1
|
||||
o2 <~ tls.out2; tls.in2 <~ Source.failed(ex)
|
||||
}.run()
|
||||
the[Exception] thrownBy Await.result(out1, 1.second) should be(ex)
|
||||
the[Exception] thrownBy Await.result(out2, 1.second) should be(ex)
|
||||
Thread.sleep(500)
|
||||
val pub = PublisherProbe()
|
||||
pub.subscribe(sub)
|
||||
pub.expectSubscription().expectCancellation()
|
||||
}
|
||||
|
||||
"reliably cancel subscriptions when UserIn fails early" in assertAllStagesStopped {
|
||||
val ex = new Exception("hello")
|
||||
val (sub, out1, out2) =
|
||||
FlowGraph.closed(Source.subscriber[ByteString], Sink.head[ByteString], Sink.head[SslTlsInbound])((_, _, _)) { implicit b ⇒
|
||||
(s, o1, o2) ⇒
|
||||
val tls = b.add(clientTls(EagerClose))
|
||||
Source.failed[SslTlsOutbound](ex) ~> tls.in1; tls.out1 ~> o1
|
||||
o2 <~ tls.out2; tls.in2 <~ s
|
||||
}.run()
|
||||
the[Exception] thrownBy Await.result(out1, 1.second) should be(ex)
|
||||
the[Exception] thrownBy Await.result(out2, 1.second) should be(ex)
|
||||
Thread.sleep(500)
|
||||
val pub = PublisherProbe()
|
||||
pub.subscribe(sub)
|
||||
pub.expectSubscription().expectCancellation()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
"A SslTlsPlacebo" must {
|
||||
|
|
|
|||
|
|
@ -92,7 +92,7 @@ private[akka] case class ActorFlowMaterializerImpl(
|
|||
|
||||
case tls: TlsModule ⇒
|
||||
val es = effectiveSettings(effectiveAttributes)
|
||||
val props = SslTlsCipherActor.props(es, tls.sslContext, tls.firstSession, tracing = true, tls.role, tls.closing)
|
||||
val props = SslTlsCipherActor.props(es, tls.sslContext, tls.firstSession, tracing = false, tls.role, tls.closing)
|
||||
val impl = actorOf(props, stageName(effectiveAttributes), es.dispatcher)
|
||||
def factory(id: Int) = new ActorPublisher[Any](impl) {
|
||||
override val wakeUpMsg = FanOut.SubstreamSubscribePending(id)
|
||||
|
|
|
|||
|
|
@ -186,7 +186,7 @@ private[akka] object FanIn {
|
|||
}
|
||||
|
||||
def inputsAvailableFor(id: Int) = new TransferState {
|
||||
override def isCompleted: Boolean = depleted(id)
|
||||
override def isCompleted: Boolean = depleted(id) || cancelled(id)
|
||||
override def isReady: Boolean = pending(id)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -69,6 +69,8 @@ private[akka] object FanOut {
|
|||
|
||||
def isCancelled(output: Int): Boolean = cancelled(output)
|
||||
|
||||
def isErrored(output: Int): Boolean = errored(output)
|
||||
|
||||
def complete(): Unit =
|
||||
if (!bunchCancelled) {
|
||||
bunchCancelled = true
|
||||
|
|
@ -187,7 +189,7 @@ private[akka] object FanOut {
|
|||
def onCancel(output: Int): Unit = ()
|
||||
|
||||
def demandAvailableFor(id: Int) = new TransferState {
|
||||
override def isCompleted: Boolean = cancelled(id) || completed(id)
|
||||
override def isCompleted: Boolean = cancelled(id) || completed(id) || errored(id)
|
||||
override def isReady: Boolean = pending(id)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -69,8 +69,8 @@ private[akka] class SslTlsCipherActor(settings: ActorFlowMaterializerSettings, s
|
|||
* block.
|
||||
*/
|
||||
class ChoppingBlock(idx: Int, name: String) extends TransferState {
|
||||
override def isReady: Boolean = buffer.nonEmpty
|
||||
override def isCompleted: Boolean = false
|
||||
override def isReady: Boolean = buffer.nonEmpty || inputBunch.isPending(idx) || inputBunch.isDepleted(idx)
|
||||
override def isCompleted: Boolean = inputBunch.isCancelled(idx)
|
||||
|
||||
private var buffer = ByteString.empty
|
||||
|
||||
|
|
@ -203,38 +203,76 @@ private[akka] class SslTlsCipherActor(settings: ActorFlowMaterializerSettings, s
|
|||
*/
|
||||
|
||||
var lastHandshakeStatus: HandshakeStatus = _
|
||||
var corkUser = true
|
||||
|
||||
val engineNeedsWrap = new TransferState {
|
||||
def isReady = lastHandshakeStatus == NEED_WRAP
|
||||
def isCompleted = false
|
||||
def isCompleted = engine.isOutboundDone
|
||||
}
|
||||
|
||||
val engineInboundOpen = new TransferState {
|
||||
def isReady = !engine.isInboundDone()
|
||||
def isCompleted = false
|
||||
def isReady = true
|
||||
def isCompleted = engine.isInboundDone
|
||||
}
|
||||
|
||||
var corkUser = true
|
||||
|
||||
val userHasData = new TransferState {
|
||||
private val user = inputBunch.inputsOrCompleteAvailableFor(UserIn) || userInChoppingBlock
|
||||
def isReady = !corkUser && user.isReady && lastHandshakeStatus != NEED_UNWRAP
|
||||
def isCompleted = false
|
||||
def isReady = !corkUser && userInChoppingBlock.isReady && lastHandshakeStatus != NEED_UNWRAP
|
||||
def isCompleted = inputBunch.isCancelled(UserIn) || inputBunch.isDepleted(UserIn)
|
||||
}
|
||||
|
||||
val transportHasData = inputBunch.inputsOrCompleteAvailableFor(TransportIn) || transportInChoppingBlock
|
||||
val userOutCancelled = new TransferState {
|
||||
def isReady = outputBunch.isCancelled(UserOut)
|
||||
def isCompleted = inputBunch.isDepleted(TransportIn)
|
||||
def isCompleted = engine.isInboundDone || outputBunch.isErrored(UserOut)
|
||||
}
|
||||
|
||||
// bidirectional case
|
||||
val outbound = (userHasData || engineNeedsWrap) && outputBunch.demandAvailableFor(TransportOut)
|
||||
val inbound = (transportHasData || userOutCancelled) && outputBunch.demandOrCancelAvailableFor(UserOut)
|
||||
val inbound = (transportInChoppingBlock && outputBunch.demandAvailableFor(UserOut)) || userOutCancelled
|
||||
|
||||
// half-closed
|
||||
val outboundHalfClosed = engineNeedsWrap && outputBunch.demandAvailableFor(TransportOut)
|
||||
val inboundHalfClosed = transportHasData && engineInboundOpen
|
||||
val inboundHalfClosed = transportInChoppingBlock && engineInboundOpen
|
||||
|
||||
val bidirectional = TransferPhase(outbound || inbound) { () ⇒
|
||||
if (tracing) log.debug("bidirectional")
|
||||
val continue = doInbound(isOutboundClosed = false, inbound)
|
||||
if (continue) {
|
||||
if (tracing) log.debug("bidirectional continue")
|
||||
doOutbound(isInboundClosed = false)
|
||||
}
|
||||
}
|
||||
|
||||
val flushingOutbound = TransferPhase(outboundHalfClosed) { () ⇒
|
||||
if (tracing) log.debug("flushingOutbound")
|
||||
try doWrap()
|
||||
catch { case ex: SSLException ⇒ nextPhase(completedPhase) }
|
||||
}
|
||||
|
||||
val awaitingClose = TransferPhase(inputBunch.inputsAvailableFor(TransportIn) && engineInboundOpen) { () ⇒
|
||||
if (tracing) log.debug("awaitingClose")
|
||||
transportInChoppingBlock.chopInto(transportInBuffer)
|
||||
try doUnwrap(ignoreOutput = true)
|
||||
catch { case ex: SSLException ⇒ nextPhase(completedPhase) }
|
||||
}
|
||||
|
||||
val outboundClosed = TransferPhase(outboundHalfClosed || inbound) { () ⇒
|
||||
if (tracing) log.debug("outboundClosed")
|
||||
val continue = doInbound(isOutboundClosed = true, inbound)
|
||||
if (continue && outboundHalfClosed.isReady) {
|
||||
if (tracing) log.debug("outboundClosed continue")
|
||||
try doWrap()
|
||||
catch { case ex: SSLException ⇒ nextPhase(completedPhase) }
|
||||
}
|
||||
}
|
||||
|
||||
val inboundClosed = TransferPhase(outbound || inboundHalfClosed) { () ⇒
|
||||
if (tracing) log.debug("inboundClosed")
|
||||
val continue = doInbound(isOutboundClosed = false, inboundHalfClosed)
|
||||
if (continue) {
|
||||
if (tracing) log.debug("inboundClosed continue")
|
||||
doOutbound(isInboundClosed = true)
|
||||
}
|
||||
}
|
||||
|
||||
def completeOrFlush(): Unit =
|
||||
if (engine.isOutboundDone()) nextPhase(completedPhase)
|
||||
|
|
@ -282,6 +320,7 @@ private[akka] class SslTlsCipherActor(settings: ActorFlowMaterializerSettings, s
|
|||
}
|
||||
nextPhase(outboundClosed)
|
||||
} else if (outputBunch.isCancelled(TransportOut)) {
|
||||
if (tracing) log.debug("shutting down because TransportOut is cancelled")
|
||||
nextPhase(completedPhase)
|
||||
} else if (outbound.isReady) {
|
||||
if (userHasData.isReady) userInChoppingBlock.chopInto(userInBuffer)
|
||||
|
|
@ -293,47 +332,6 @@ private[akka] class SslTlsCipherActor(settings: ActorFlowMaterializerSettings, s
|
|||
}
|
||||
}
|
||||
|
||||
val bidirectional = TransferPhase(outbound || inbound) { () ⇒
|
||||
if (tracing) log.debug("bidirectional")
|
||||
val continue = doInbound(isOutboundClosed = false, inbound)
|
||||
if (continue) {
|
||||
if (tracing) log.debug("bidirectional continue")
|
||||
doOutbound(isInboundClosed = false)
|
||||
}
|
||||
}
|
||||
|
||||
val flushingOutbound = TransferPhase(outboundHalfClosed) { () ⇒
|
||||
if (tracing) log.debug("flushingOutbound")
|
||||
try doWrap()
|
||||
catch { case ex: SSLException ⇒ nextPhase(completedPhase) }
|
||||
}
|
||||
|
||||
val awaitingClose = TransferPhase(inputBunch.inputsAvailableFor(TransportIn)) { () ⇒
|
||||
if (tracing) log.debug("awaitingClose")
|
||||
transportInChoppingBlock.chopInto(transportInBuffer)
|
||||
try doUnwrap(ignoreOutput = true)
|
||||
catch { case ex: SSLException ⇒ nextPhase(completedPhase) }
|
||||
}
|
||||
|
||||
val outboundClosed = TransferPhase(outboundHalfClosed || inbound) { () ⇒
|
||||
if (tracing) log.debug("outboundClosed")
|
||||
val continue = doInbound(isOutboundClosed = true, inbound)
|
||||
if (continue && outboundHalfClosed.isReady) {
|
||||
if (tracing) log.debug("outboundClosed continue")
|
||||
try doWrap()
|
||||
catch { case ex: SSLException ⇒ nextPhase(completedPhase) }
|
||||
}
|
||||
}
|
||||
|
||||
val inboundClosed = TransferPhase(outbound || inboundHalfClosed) { () ⇒
|
||||
if (tracing) log.debug("inboundClosed")
|
||||
val continue = doInbound(isOutboundClosed = false, inboundHalfClosed)
|
||||
if (continue) {
|
||||
if (tracing) log.debug("inboundClosed continue")
|
||||
doOutbound(isInboundClosed = true)
|
||||
}
|
||||
}
|
||||
|
||||
def flushToTransport(): Unit = {
|
||||
if (tracing) log.debug("flushToTransport")
|
||||
transportOutBuffer.flip()
|
||||
|
|
@ -427,15 +425,19 @@ private[akka] class SslTlsCipherActor(settings: ActorFlowMaterializerSettings, s
|
|||
|
||||
override def receive = inputBunch.subreceive.orElse[Any, Unit](outputBunch.subreceive)
|
||||
|
||||
nextPhase(bidirectional)
|
||||
initialPhase(2, bidirectional)
|
||||
|
||||
protected def fail(e: Throwable): Unit = {
|
||||
// FIXME: escalate to supervisor
|
||||
if (tracing) log.debug("fail {} due to: {}", self, e.getMessage)
|
||||
inputBunch.cancel()
|
||||
outputBunch.error(TransportOut, e)
|
||||
outputBunch.error(UserOut, e)
|
||||
context.stop(self)
|
||||
pump()
|
||||
}
|
||||
|
||||
override def postStop(): Unit = {
|
||||
if (tracing) log.debug("postStop")
|
||||
super.postStop()
|
||||
}
|
||||
|
||||
override protected def pumpFailed(e: Throwable): Unit = fail(e)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue