diff --git a/akka-stream-tests/src/test/resources/badtruststore b/akka-stream-tests/src/test/resources/badtruststore new file mode 100644 index 0000000000..52d9e5987c Binary files /dev/null and b/akka-stream-tests/src/test/resources/badtruststore differ diff --git a/akka-stream-tests/src/test/scala/akka/stream/io/TlsSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/io/TlsSpec.scala index ce5025021c..3d6bcf510d 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/io/TlsSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/io/TlsSpec.scala @@ -19,24 +19,20 @@ import akka.stream.testkit._ import akka.stream.testkit.Utils._ import akka.testkit.EventFilter import akka.util.ByteString -import javax.net.ssl.KeyManagerFactory -import javax.net.ssl.SSLContext -import javax.net.ssl.SSLSession -import javax.net.ssl.TrustManagerFactory +import javax.net.ssl._ object TlsSpec { val rnd = new Random - def initSslContext(): SSLContext = { - + def initWithTrust(trustPath: String) = { val password = "changeme" val keyStore = KeyStore.getInstance(KeyStore.getDefaultType) keyStore.load(getClass.getResourceAsStream("/keystore"), password.toCharArray) val trustStore = KeyStore.getInstance(KeyStore.getDefaultType) - trustStore.load(getClass.getResourceAsStream("/truststore"), password.toCharArray) + trustStore.load(getClass.getResourceAsStream(trustPath), password.toCharArray) val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm) keyManagerFactory.init(keyStore, password.toCharArray) @@ -49,6 +45,8 @@ object TlsSpec { context } + def initSslContext(): SSLContext = initWithTrust("/truststore") + /** * This is a stage that fires a TimeoutException failure 2 seconds after it was started, * independent of the traffic going through. The purpose is to include the last seen @@ -110,6 +108,7 @@ class TlsSpec extends AkkaSpec("akka.loglevel=INFO\nakka.actor.debug.receive=off val cipherSuites = NegotiateNewSession.withCipherSuites("TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_128_CBC_SHA") def clientTls(closing: Closing) = SslTls(sslContext, cipherSuites, Client, closing) + def badClientTls(closing: Closing) = SslTls(initWithTrust("/badtruststore"), cipherSuites, Client, closing) def serverTls(closing: Closing) = SslTls(sslContext, cipherSuites, Server, closing) trait Named { @@ -380,6 +379,30 @@ class TlsSpec extends AkkaSpec("akka.loglevel=INFO\nakka.actor.debug.receive=off } } + "emit an error if the TLS handshake fails certificate checks" in assertAllStagesStopped { + val getError = Flow[SslTlsInbound] + .map[Either[SslTlsInbound, SSLException]](i => Left(i)) + .recover { case e: SSLException => Right(e) } + .collect { case Right(e) => e }.toMat(Sink.head)(Keep.right) + + val simple = Flow.wrap(getError, Source.lazyEmpty[SslTlsOutbound])(Keep.left) + + // The creation of actual TCP connections is necessary. It is the easiest way to decouple the client and server + // under error conditions, and has the bonus of matching most actual SSL deployments. + val (server, serverErr) = Tcp() + .bind("localhost", 0) + .map(c ⇒ { + c.flow.joinMat(serverTls(IgnoreBoth).reversed.joinMat(simple)(Keep.right))(Keep.right).run() + }) + .toMat(Sink.head)(Keep.both).run() + + val clientErr = simple.join(badClientTls(IgnoreBoth)) + .join(Tcp().outgoingConnection(Await.result(server, 1.second).localAddress)).run() + + Await.result(serverErr.flatMap(identity), 1.second).getMessage should include ("certificate_unknown") + Await.result(clientErr, 1.second).getMessage should equal ("General SSLEngine problem") + } + "reliably cancel subscriptions when TransportIn fails early" in assertAllStagesStopped { val ex = new Exception("hello") val (sub, out1, out2) = diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/SslTlsCipherActor.scala b/akka-stream/src/main/scala/akka/stream/impl/io/SslTlsCipherActor.scala index 8780f53b8d..7ff5762e40 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/SslTlsCipherActor.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/SslTlsCipherActor.scala @@ -299,6 +299,8 @@ private[akka] class SslTlsCipherActor(settings: ActorMaterializerSettings, sslCo } catch { case ex: SSLException ⇒ if (tracing) log.debug(s"SSLException during doUnwrap: $ex") + fail(ex, closeTransport = false) + engine.closeInbound() completeOrFlush() false } @@ -323,6 +325,7 @@ private[akka] class SslTlsCipherActor(settings: ActorMaterializerSettings, sslCo catch { case ex: SSLException ⇒ if (tracing) log.debug(s"SSLException during doWrap: $ex") + fail(ex, closeTransport = false) completeOrFlush() } } @@ -422,10 +425,13 @@ private[akka] class SslTlsCipherActor(settings: ActorMaterializerSettings, sslCo initialPhase(2, bidirectional) - protected def fail(e: Throwable): Unit = { + protected def fail(e: Throwable, closeTransport: Boolean=true): Unit = { if (tracing) log.debug("fail {} due to: {}", self, e.getMessage) inputBunch.cancel() - outputBunch.error(TransportOut, e) + if(closeTransport) { + log.debug("closing output") + outputBunch.error(TransportOut, e) + } outputBunch.error(UserOut, e) pump() }