Merge pull request #17289 from akka/wip-SslTls-termination-∂π

refactor SslTlsActor and stop it reliably
This commit is contained in:
Roland Kuhn 2015-04-24 11:22:28 +02:00
commit 7b4a640147
5 changed files with 125 additions and 83 deletions

View file

@ -1,27 +1,31 @@
package akka.stream.io
import java.security.{ KeyStore, SecureRandom }
import javax.net.ssl.{ TrustManagerFactory, KeyManagerFactory, SSLContext }
import akka.stream.{ Graph, BidiShape, ActorFlowMaterializer }
import akka.stream.scaladsl._
import akka.stream.io._
import akka.stream.testkit.{ TestUtils, AkkaSpec }
import akka.util.ByteString
import scala.concurrent.Await
import scala.concurrent.duration._
import scala.collection.immutable
import scala.util.Random
import akka.stream.stage.AsyncStage
import akka.stream.stage.AsyncContext
import java.util.concurrent.TimeoutException
import akka.actor.ActorSystem
import javax.net.ssl.SSLSession
import akka.pattern.{ after later }
import scala.concurrent.Future
import java.net.InetSocketAddress
import java.security.KeyStore
import java.security.SecureRandom
import java.util.concurrent.TimeoutException
import scala.collection.immutable
import scala.concurrent.Await
import scala.concurrent.Future
import scala.concurrent.duration._
import scala.util.Random
import akka.actor.ActorSystem
import akka.pattern.{ after later }
import akka.stream.ActorFlowMaterializer
import akka.stream.scaladsl._
import akka.stream.scaladsl.FlowGraph.Implicits._
import akka.stream.stage._
import akka.stream.testkit.AkkaSpec
import akka.stream.testkit.StreamTestKit.PublisherProbe
import akka.stream.testkit.StreamTestKit.assertAllStagesStopped
import akka.testkit.EventFilter
import akka.stream.stage.PushStage
import akka.stream.stage.Context
import akka.util.ByteString
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SSLContext
import javax.net.ssl.SSLSession
import javax.net.ssl.TrustManagerFactory
object TlsSpec {
@ -107,7 +111,7 @@ class TlsSpec extends AkkaSpec("akka.loglevel=INFO\nakka.actor.debug.receive=off
import FlowGraph.Implicits._
"StreamTLS" must {
"SslTls" must {
val sslContext = initSslContext()
@ -359,7 +363,7 @@ class TlsSpec extends AkkaSpec("akka.loglevel=INFO\nakka.actor.debug.receive=off
commPattern communicationPatterns
scenario scenarios
} {
s"work in mode ${commPattern.name} while sending ${scenario.name}" in {
s"work in mode ${commPattern.name} while sending ${scenario.name}" in assertAllStagesStopped {
val onRHS = debug.via(scenario.flow)
val f =
Source(scenario.inputs)
@ -391,6 +395,40 @@ class TlsSpec extends AkkaSpec("akka.loglevel=INFO\nakka.actor.debug.receive=off
}
}
"reliably cancel subscriptions when TransportIn fails early" in assertAllStagesStopped {
val ex = new Exception("hello")
val (sub, out1, out2) =
FlowGraph.closed(Source.subscriber[SslTlsOutbound], Sink.head[ByteString], Sink.head[SslTlsInbound])((_, _, _)) { implicit b
(s, o1, o2)
val tls = b.add(clientTls(EagerClose))
s ~> tls.in1; tls.out1 ~> o1
o2 <~ tls.out2; tls.in2 <~ Source.failed(ex)
}.run()
the[Exception] thrownBy Await.result(out1, 1.second) should be(ex)
the[Exception] thrownBy Await.result(out2, 1.second) should be(ex)
Thread.sleep(500)
val pub = PublisherProbe()
pub.subscribe(sub)
pub.expectSubscription().expectCancellation()
}
"reliably cancel subscriptions when UserIn fails early" in assertAllStagesStopped {
val ex = new Exception("hello")
val (sub, out1, out2) =
FlowGraph.closed(Source.subscriber[ByteString], Sink.head[ByteString], Sink.head[SslTlsInbound])((_, _, _)) { implicit b
(s, o1, o2)
val tls = b.add(clientTls(EagerClose))
Source.failed[SslTlsOutbound](ex) ~> tls.in1; tls.out1 ~> o1
o2 <~ tls.out2; tls.in2 <~ s
}.run()
the[Exception] thrownBy Await.result(out1, 1.second) should be(ex)
the[Exception] thrownBy Await.result(out2, 1.second) should be(ex)
Thread.sleep(500)
val pub = PublisherProbe()
pub.subscribe(sub)
pub.expectSubscription().expectCancellation()
}
}
"A SslTlsPlacebo" must {

View file

@ -92,7 +92,7 @@ private[akka] case class ActorFlowMaterializerImpl(
case tls: TlsModule
val es = effectiveSettings(effectiveAttributes)
val props = SslTlsCipherActor.props(es, tls.sslContext, tls.firstSession, tracing = true, tls.role, tls.closing)
val props = SslTlsCipherActor.props(es, tls.sslContext, tls.firstSession, tracing = false, tls.role, tls.closing)
val impl = actorOf(props, stageName(effectiveAttributes), es.dispatcher)
def factory(id: Int) = new ActorPublisher[Any](impl) {
override val wakeUpMsg = FanOut.SubstreamSubscribePending(id)

View file

@ -186,7 +186,7 @@ private[akka] object FanIn {
}
def inputsAvailableFor(id: Int) = new TransferState {
override def isCompleted: Boolean = depleted(id)
override def isCompleted: Boolean = depleted(id) || cancelled(id)
override def isReady: Boolean = pending(id)
}

View file

@ -69,6 +69,8 @@ private[akka] object FanOut {
def isCancelled(output: Int): Boolean = cancelled(output)
def isErrored(output: Int): Boolean = errored(output)
def complete(): Unit =
if (!bunchCancelled) {
bunchCancelled = true
@ -187,7 +189,7 @@ private[akka] object FanOut {
def onCancel(output: Int): Unit = ()
def demandAvailableFor(id: Int) = new TransferState {
override def isCompleted: Boolean = cancelled(id) || completed(id)
override def isCompleted: Boolean = cancelled(id) || completed(id) || errored(id)
override def isReady: Boolean = pending(id)
}

View file

@ -69,8 +69,8 @@ private[akka] class SslTlsCipherActor(settings: ActorFlowMaterializerSettings, s
* block.
*/
class ChoppingBlock(idx: Int, name: String) extends TransferState {
override def isReady: Boolean = buffer.nonEmpty
override def isCompleted: Boolean = false
override def isReady: Boolean = buffer.nonEmpty || inputBunch.isPending(idx) || inputBunch.isDepleted(idx)
override def isCompleted: Boolean = inputBunch.isCancelled(idx)
private var buffer = ByteString.empty
@ -203,38 +203,76 @@ private[akka] class SslTlsCipherActor(settings: ActorFlowMaterializerSettings, s
*/
var lastHandshakeStatus: HandshakeStatus = _
var corkUser = true
val engineNeedsWrap = new TransferState {
def isReady = lastHandshakeStatus == NEED_WRAP
def isCompleted = false
def isCompleted = engine.isOutboundDone
}
val engineInboundOpen = new TransferState {
def isReady = !engine.isInboundDone()
def isCompleted = false
def isReady = true
def isCompleted = engine.isInboundDone
}
var corkUser = true
val userHasData = new TransferState {
private val user = inputBunch.inputsOrCompleteAvailableFor(UserIn) || userInChoppingBlock
def isReady = !corkUser && user.isReady && lastHandshakeStatus != NEED_UNWRAP
def isCompleted = false
def isReady = !corkUser && userInChoppingBlock.isReady && lastHandshakeStatus != NEED_UNWRAP
def isCompleted = inputBunch.isCancelled(UserIn) || inputBunch.isDepleted(UserIn)
}
val transportHasData = inputBunch.inputsOrCompleteAvailableFor(TransportIn) || transportInChoppingBlock
val userOutCancelled = new TransferState {
def isReady = outputBunch.isCancelled(UserOut)
def isCompleted = inputBunch.isDepleted(TransportIn)
def isCompleted = engine.isInboundDone || outputBunch.isErrored(UserOut)
}
// bidirectional case
val outbound = (userHasData || engineNeedsWrap) && outputBunch.demandAvailableFor(TransportOut)
val inbound = (transportHasData || userOutCancelled) && outputBunch.demandOrCancelAvailableFor(UserOut)
val inbound = (transportInChoppingBlock && outputBunch.demandAvailableFor(UserOut)) || userOutCancelled
// half-closed
val outboundHalfClosed = engineNeedsWrap && outputBunch.demandAvailableFor(TransportOut)
val inboundHalfClosed = transportHasData && engineInboundOpen
val inboundHalfClosed = transportInChoppingBlock && engineInboundOpen
val bidirectional = TransferPhase(outbound || inbound) { ()
if (tracing) log.debug("bidirectional")
val continue = doInbound(isOutboundClosed = false, inbound)
if (continue) {
if (tracing) log.debug("bidirectional continue")
doOutbound(isInboundClosed = false)
}
}
val flushingOutbound = TransferPhase(outboundHalfClosed) { ()
if (tracing) log.debug("flushingOutbound")
try doWrap()
catch { case ex: SSLException nextPhase(completedPhase) }
}
val awaitingClose = TransferPhase(inputBunch.inputsAvailableFor(TransportIn) && engineInboundOpen) { ()
if (tracing) log.debug("awaitingClose")
transportInChoppingBlock.chopInto(transportInBuffer)
try doUnwrap(ignoreOutput = true)
catch { case ex: SSLException nextPhase(completedPhase) }
}
val outboundClosed = TransferPhase(outboundHalfClosed || inbound) { ()
if (tracing) log.debug("outboundClosed")
val continue = doInbound(isOutboundClosed = true, inbound)
if (continue && outboundHalfClosed.isReady) {
if (tracing) log.debug("outboundClosed continue")
try doWrap()
catch { case ex: SSLException nextPhase(completedPhase) }
}
}
val inboundClosed = TransferPhase(outbound || inboundHalfClosed) { ()
if (tracing) log.debug("inboundClosed")
val continue = doInbound(isOutboundClosed = false, inboundHalfClosed)
if (continue) {
if (tracing) log.debug("inboundClosed continue")
doOutbound(isInboundClosed = true)
}
}
def completeOrFlush(): Unit =
if (engine.isOutboundDone()) nextPhase(completedPhase)
@ -282,6 +320,7 @@ private[akka] class SslTlsCipherActor(settings: ActorFlowMaterializerSettings, s
}
nextPhase(outboundClosed)
} else if (outputBunch.isCancelled(TransportOut)) {
if (tracing) log.debug("shutting down because TransportOut is cancelled")
nextPhase(completedPhase)
} else if (outbound.isReady) {
if (userHasData.isReady) userInChoppingBlock.chopInto(userInBuffer)
@ -293,47 +332,6 @@ private[akka] class SslTlsCipherActor(settings: ActorFlowMaterializerSettings, s
}
}
val bidirectional = TransferPhase(outbound || inbound) { ()
if (tracing) log.debug("bidirectional")
val continue = doInbound(isOutboundClosed = false, inbound)
if (continue) {
if (tracing) log.debug("bidirectional continue")
doOutbound(isInboundClosed = false)
}
}
val flushingOutbound = TransferPhase(outboundHalfClosed) { ()
if (tracing) log.debug("flushingOutbound")
try doWrap()
catch { case ex: SSLException nextPhase(completedPhase) }
}
val awaitingClose = TransferPhase(inputBunch.inputsAvailableFor(TransportIn)) { ()
if (tracing) log.debug("awaitingClose")
transportInChoppingBlock.chopInto(transportInBuffer)
try doUnwrap(ignoreOutput = true)
catch { case ex: SSLException nextPhase(completedPhase) }
}
val outboundClosed = TransferPhase(outboundHalfClosed || inbound) { ()
if (tracing) log.debug("outboundClosed")
val continue = doInbound(isOutboundClosed = true, inbound)
if (continue && outboundHalfClosed.isReady) {
if (tracing) log.debug("outboundClosed continue")
try doWrap()
catch { case ex: SSLException nextPhase(completedPhase) }
}
}
val inboundClosed = TransferPhase(outbound || inboundHalfClosed) { ()
if (tracing) log.debug("inboundClosed")
val continue = doInbound(isOutboundClosed = false, inboundHalfClosed)
if (continue) {
if (tracing) log.debug("inboundClosed continue")
doOutbound(isInboundClosed = true)
}
}
def flushToTransport(): Unit = {
if (tracing) log.debug("flushToTransport")
transportOutBuffer.flip()
@ -427,15 +425,19 @@ private[akka] class SslTlsCipherActor(settings: ActorFlowMaterializerSettings, s
override def receive = inputBunch.subreceive.orElse[Any, Unit](outputBunch.subreceive)
nextPhase(bidirectional)
initialPhase(2, bidirectional)
protected def fail(e: Throwable): Unit = {
// FIXME: escalate to supervisor
if (tracing) log.debug("fail {} due to: {}", self, e.getMessage)
inputBunch.cancel()
outputBunch.error(TransportOut, e)
outputBunch.error(UserOut, e)
context.stop(self)
pump()
}
override def postStop(): Unit = {
if (tracing) log.debug("postStop")
super.postStop()
}
override protected def pumpFailed(e: Throwable): Unit = fail(e)