From d88de3f79da8d674b4aae63e046bd2053c0363f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bjo=CC=88rn=20Antonsson?= Date: Mon, 26 Aug 2013 13:26:15 +0200 Subject: [PATCH] =act #3389 Improve SSL closing sequence * Port of this fix in Spray by @jrudolph from this pull request https://github.com/spray/spray/pull/400 --- .../main/scala/akka/io/SslTlsSupport.scala | 45 +++++++++--- .../scala/akka/io/ssl/SslTlsSupportSpec.scala | 69 ++++++++++++++++++- 2 files changed, 103 insertions(+), 11 deletions(-) diff --git a/akka-actor/src/main/scala/akka/io/SslTlsSupport.scala b/akka-actor/src/main/scala/akka/io/SslTlsSupport.scala index 43e78bfe5b..ba70b18e22 100644 --- a/akka-actor/src/main/scala/akka/io/SslTlsSupport.scala +++ b/akka-actor/src/main/scala/akka/io/SslTlsSupport.scala @@ -54,6 +54,18 @@ object SslTlsSupport { * * Each instance of this stage has a scratch [[ByteBuffer]] of approx. 18kiB * allocated which is used by the SSLEngine. + * + * One thing to keep in mind is that there's no support for half-closed connections + * in SSL (but SSL on the other side requires half-closed connections from its transport + * layer). + * + * This means: + * 1. keepOpenOnPeerClosed is not supported on top of SSL (once you receive PeerClosed + * the connection is closed, further CloseCommands are ignored) + * 2. keepOpenOnPeerClosed should always be enabled on the transport layer beneath SSL so + * that one can wait for the other side's SSL level close_notify message without barfing + * RST to the peer because this socket is already gone + * */ class SslTlsSupport(engine: SSLEngine) extends PipelineStage[HasLogging, Command, Command, Event, Event] { @@ -64,6 +76,7 @@ class SslTlsSupport(engine: SSLEngine) extends PipelineStage[HasLogging, Command val log = ctx.getLogger // TODO: should this be a ThreadLocal? val tempBuf = ByteBuffer.allocate(SslTlsSupport.MaxPacketSize) + var originalCloseCommand: Tcp.CloseCommand = _ override val commandPipeline = (cmd: Command) ⇒ cmd match { case x: Tcp.Write ⇒ @@ -74,9 +87,12 @@ class SslTlsSupport(engine: SSLEngine) extends PipelineStage[HasLogging, Command } case x @ (Tcp.Close | Tcp.ConfirmedClose) ⇒ + originalCloseCommand = x.asInstanceOf[Tcp.CloseCommand] log.debug("Closing SSLEngine due to reception of [{}]", x) engine.closeOutbound() - closeEngine() :+ Right(x) + // don't send close command to network here, it's the job of the SSL engine + // to shutdown the connection when getting CLOSED in encrypt + closeEngine() case cmd ⇒ ctx.singleCommand(cmd) } @@ -92,11 +108,20 @@ class SslTlsSupport(engine: SSLEngine) extends PipelineStage[HasLogging, Command decrypt(buf) case x: Tcp.ConnectionClosed ⇒ - if (!engine.isOutboundDone) { + // After we have closed the connection we ignore FIN from the other side. + // That's to avoid a strange condition where we know that no truncation attack + // can happen any more (because we actively closed the connection) but the peer + // isn't behaving properly and didn't send close_notify. Why is this condition strange? + // Because if we had closed the connection directly after we sent close_notify (which + // is allowed by the spec) we wouldn't even have noticed. + if (!engine.isOutboundDone) try engine.closeInbound() catch { case e: SSLException ⇒ } // ignore warning about possible truncation attacks - } - ctx.singleEvent(x) + + if (x.isAborted || (originalCloseCommand eq null)) ctx.singleEvent(x) + else if (!engine.isInboundDone) ctx.singleEvent(originalCloseCommand.event) + // else the close message was sent by decrypt case CLOSED + else ctx.singleEvent(x) case ev ⇒ ctx.singleEvent(ev) } @@ -139,8 +164,8 @@ class SslTlsSupport(engine: SSLEngine) extends PipelineStage[HasLogging, Command case CLOSED ⇒ if (postContentLeft) { log.warning("SSLEngine closed prematurely while sending") - nextCmds :+ Right(Tcp.Close) - } else nextCmds + nextCmds :+ Right(Tcp.Abort) + } else nextCmds :+ Right(Tcp.ConfirmedClose) case BUFFER_OVERFLOW ⇒ throw new IllegalStateException("BUFFER_OVERFLOW: the SslBufferPool should make sure that buffers are never too small") case BUFFER_UNDERFLOW ⇒ @@ -180,9 +205,11 @@ class SslTlsSupport(engine: SSLEngine) extends PipelineStage[HasLogging, Command } case CLOSED ⇒ if (!engine.isOutboundDone) { - log.warning("SSLEngine closed prematurely while receiving") - nextOutput :+ Right(Tcp.Close) - } else nextOutput + closeEngine(nextOutput :+ Left(Tcp.PeerClosed)) + } else { // now both sides are closed on the SSL level + // close the underlying connection, we don't need it any more + nextOutput :+ Left(originalCloseCommand.event) :+ Right(Tcp.Close) + } case BUFFER_UNDERFLOW ⇒ inboundReceptacle = buffer // save buffer so we can append the next one to it nextOutput diff --git a/akka-remote/src/test/scala/akka/io/ssl/SslTlsSupportSpec.scala b/akka-remote/src/test/scala/akka/io/ssl/SslTlsSupportSpec.scala index d7791cce4a..ca81d7ce0b 100644 --- a/akka-remote/src/test/scala/akka/io/ssl/SslTlsSupportSpec.scala +++ b/akka-remote/src/test/scala/akka/io/ssl/SslTlsSupportSpec.scala @@ -39,7 +39,7 @@ import akka.io.TcpReadWriteAdapter import akka.remote.security.provider.AkkaProvider import akka.testkit.{ AkkaSpec, TestProbe } import akka.util.{ ByteString, Timeout } -import javax.net.ssl.{ KeyManagerFactory, SSLContext, SSLServerSocket, SSLSocket, TrustManagerFactory } +import javax.net.ssl._ import akka.actor.Deploy // TODO move this into akka-actor once AkkaProvider for SecureRandom does not have external dependencies @@ -52,22 +52,32 @@ class SslTlsSupportSpec extends AkkaSpec { "The SslTlsSupport" should { "work between a Java client and a Java server" in { + invalidateSessions() val server = new JavaSslServer val client = new JavaSslClient(server.address) client.run() + val baselineSessionCounts = sessionCounts() client.close() + // make sure not to lose sessions by invalid session closure + sessionCounts() === baselineSessionCounts server.close() + sessionCounts() === baselineSessionCounts // see above } "work between a akka client and a Java server" in { + invalidateSessions() val server = new JavaSslServer val client = new AkkaSslClient(server.address) client.run() + val baselineSessionCounts = sessionCounts() client.close() + sessionCounts() === baselineSessionCounts // see above server.close() + sessionCounts() === baselineSessionCounts // see above } "work between a Java client and a akka server" in { + invalidateSessions() val serverAddress = TestUtils.temporaryServerAddress() val probe = TestProbe() val bindHandler = probe.watch(system.actorOf(Props(new AkkaSslServer(serverAddress)).withDeploy(Deploy.local), "server1")) @@ -75,11 +85,14 @@ class SslTlsSupportSpec extends AkkaSpec { val client = new JavaSslClient(serverAddress) client.run() + val baselineSessionCounts = sessionCounts() client.close() + sessionCounts() === baselineSessionCounts // see above probe.expectTerminated(bindHandler) } "work between a akka client and a akka server" in { + invalidateSessions() val serverAddress = TestUtils.temporaryServerAddress() val probe = TestProbe() val bindHandler = probe.watch(system.actorOf(Props(new AkkaSslServer(serverAddress)).withDeploy(Deploy.local), "server2")) @@ -87,9 +100,36 @@ class SslTlsSupportSpec extends AkkaSpec { val client = new AkkaSslClient(serverAddress) client.run() + val baselineSessionCounts = sessionCounts() client.close() + sessionCounts() === baselineSessionCounts // see above probe.expectTerminated(bindHandler) } + + "work between an akka client and a Java server with confirmedClose" in { + invalidateSessions() + val server = new JavaSslServer + val client = new AkkaSslClient(server.address) + client.run() + val baselineSessionCounts = sessionCounts() + client.closeConfirmed() + sessionCounts() === baselineSessionCounts // see above + server.close() + sessionCounts() === baselineSessionCounts // see above + } + + "akka client runs the full shutdown sequence if peer closes" in { + invalidateSessions() + val server = new JavaSslServer + val client = new AkkaSslClient(server.address) + client.run() + val baselineSessionCounts = serverSessions().length + server.close() + client.peerClosed() + // we only check the akka side server sessions here + // the java client seems to lose the session for some reason + serverSessions().length === baselineSessionCounts + } } val counter = new AtomicInteger @@ -112,7 +152,7 @@ class SslTlsSupportSpec extends AkkaSpec { val handler = system.actorOf(TcpPipelineHandler.props(init, connection, probe.ref).withDeploy(Deploy.local), "client" + counter.incrementAndGet()) - probe.send(connection, Tcp.Register(handler)) + probe.send(connection, Tcp.Register(handler, keepOpenOnPeerClosed = true)) def run() { probe.send(handler, Command("3+4\n")) @@ -127,12 +167,22 @@ class SslTlsSupportSpec extends AkkaSpec { probe.expectMsg(Event("0\n")) } + def peerClosed(): Unit = { + probe.expectMsg(Tcp.PeerClosed) + TestUtils.verifyActorTermination(handler) + } + def close() { probe.send(handler, Management(Tcp.Close)) probe.expectMsgType[Tcp.ConnectionClosed] TestUtils.verifyActorTermination(handler) } + def closeConfirmed(): Unit = { + probe.send(handler, Management(Tcp.ConfirmedClose)) + probe.expectMsg(Tcp.ConfirmedClosed) + TestUtils.verifyActorTermination(handler) + } } //#server @@ -270,6 +320,21 @@ class SslTlsSupportSpec extends AkkaSpec { engine } + import collection.JavaConverters._ + def clientSessions() = sessions(_.getServerSessionContext) + def serverSessions() = sessions(_.getClientSessionContext) + def sessionCounts() = (clientSessions().length, serverSessions().length) + + def sessions(f: SSLContext ⇒ SSLSessionContext): Seq[SSLSession] = { + val ctx = f(sslContext) + val ids = ctx.getIds().asScala.toIndexedSeq + ids.map(ctx.getSession) + } + + def invalidateSessions() = { + clientSessions().foreach(_.invalidate()) + serverSessions().foreach(_.invalidate()) + } } object SslTlsSupportSpec {