Merge pull request #18655 from AnIrishDuck/wip-18058-tls-errors

#18058 Bubble up tls errors
This commit is contained in:
Konrad Malawski 2015-10-19 16:31:47 +02:00
commit 8270fedb6c
3 changed files with 38 additions and 9 deletions

Binary file not shown.

View file

@ -19,24 +19,20 @@ import akka.stream.testkit._
import akka.stream.testkit.Utils._
import akka.testkit.EventFilter
import akka.util.ByteString
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SSLContext
import javax.net.ssl.SSLSession
import javax.net.ssl.TrustManagerFactory
import javax.net.ssl._
object TlsSpec {
val rnd = new Random
def initSslContext(): SSLContext = {
def initWithTrust(trustPath: String) = {
val password = "changeme"
val keyStore = KeyStore.getInstance(KeyStore.getDefaultType)
keyStore.load(getClass.getResourceAsStream("/keystore"), password.toCharArray)
val trustStore = KeyStore.getInstance(KeyStore.getDefaultType)
trustStore.load(getClass.getResourceAsStream("/truststore"), password.toCharArray)
trustStore.load(getClass.getResourceAsStream(trustPath), password.toCharArray)
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm)
keyManagerFactory.init(keyStore, password.toCharArray)
@ -49,6 +45,8 @@ object TlsSpec {
context
}
def initSslContext(): SSLContext = initWithTrust("/truststore")
/**
* This is a stage that fires a TimeoutException failure 2 seconds after it was started,
* independent of the traffic going through. The purpose is to include the last seen
@ -110,6 +108,7 @@ class TlsSpec extends AkkaSpec("akka.loglevel=INFO\nakka.actor.debug.receive=off
val cipherSuites = NegotiateNewSession.withCipherSuites("TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_128_CBC_SHA")
def clientTls(closing: Closing) = SslTls(sslContext, cipherSuites, Client, closing)
def badClientTls(closing: Closing) = SslTls(initWithTrust("/badtruststore"), cipherSuites, Client, closing)
def serverTls(closing: Closing) = SslTls(sslContext, cipherSuites, Server, closing)
trait Named {
@ -380,6 +379,30 @@ class TlsSpec extends AkkaSpec("akka.loglevel=INFO\nakka.actor.debug.receive=off
}
}
"emit an error if the TLS handshake fails certificate checks" in assertAllStagesStopped {
val getError = Flow[SslTlsInbound]
.map[Either[SslTlsInbound, SSLException]](i => Left(i))
.recover { case e: SSLException => Right(e) }
.collect { case Right(e) => e }.toMat(Sink.head)(Keep.right)
val simple = Flow.wrap(getError, Source.lazyEmpty[SslTlsOutbound])(Keep.left)
// The creation of actual TCP connections is necessary. It is the easiest way to decouple the client and server
// under error conditions, and has the bonus of matching most actual SSL deployments.
val (server, serverErr) = Tcp()
.bind("localhost", 0)
.map(c {
c.flow.joinMat(serverTls(IgnoreBoth).reversed.joinMat(simple)(Keep.right))(Keep.right).run()
})
.toMat(Sink.head)(Keep.both).run()
val clientErr = simple.join(badClientTls(IgnoreBoth))
.join(Tcp().outgoingConnection(Await.result(server, 1.second).localAddress)).run()
Await.result(serverErr.flatMap(identity), 1.second).getMessage should include ("certificate_unknown")
Await.result(clientErr, 1.second).getMessage should equal ("General SSLEngine problem")
}
"reliably cancel subscriptions when TransportIn fails early" in assertAllStagesStopped {
val ex = new Exception("hello")
val (sub, out1, out2) =

View file

@ -299,6 +299,8 @@ private[akka] class SslTlsCipherActor(settings: ActorMaterializerSettings, sslCo
} catch {
case ex: SSLException
if (tracing) log.debug(s"SSLException during doUnwrap: $ex")
fail(ex, closeTransport = false)
engine.closeInbound()
completeOrFlush()
false
}
@ -323,6 +325,7 @@ private[akka] class SslTlsCipherActor(settings: ActorMaterializerSettings, sslCo
catch {
case ex: SSLException
if (tracing) log.debug(s"SSLException during doWrap: $ex")
fail(ex, closeTransport = false)
completeOrFlush()
}
}
@ -422,10 +425,13 @@ private[akka] class SslTlsCipherActor(settings: ActorMaterializerSettings, sslCo
initialPhase(2, bidirectional)
protected def fail(e: Throwable): Unit = {
protected def fail(e: Throwable, closeTransport: Boolean=true): Unit = {
if (tracing) log.debug("fail {} due to: {}", self, e.getMessage)
inputBunch.cancel()
outputBunch.error(TransportOut, e)
if(closeTransport) {
log.debug("closing output")
outputBunch.error(TransportOut, e)
}
outputBunch.error(UserOut, e)
pump()
}