Merge pull request #1679 from akka/wip-3389-fix-ssl-engine-shutdown-ban

=act #3389 Improve SSL closing sequence
This commit is contained in:
Björn Antonsson 2013-08-27 05:55:10 -07:00
commit ba92d38fbd
2 changed files with 103 additions and 11 deletions

View file

@ -54,6 +54,18 @@ object SslTlsSupport {
* *
* Each instance of this stage has a scratch [[ByteBuffer]] of approx. 18kiB * Each instance of this stage has a scratch [[ByteBuffer]] of approx. 18kiB
* allocated which is used by the SSLEngine. * allocated which is used by the SSLEngine.
*
* One thing to keep in mind is that there's no support for half-closed connections
* in SSL (but SSL on the other side requires half-closed connections from its transport
* layer).
*
* This means:
* 1. keepOpenOnPeerClosed is not supported on top of SSL (once you receive PeerClosed
* the connection is closed, further CloseCommands are ignored)
* 2. keepOpenOnPeerClosed should always be enabled on the transport layer beneath SSL so
* that one can wait for the other side's SSL level close_notify message without barfing
* RST to the peer because this socket is already gone
*
*/ */
class SslTlsSupport(engine: SSLEngine) extends PipelineStage[HasLogging, Command, Command, Event, Event] { class SslTlsSupport(engine: SSLEngine) extends PipelineStage[HasLogging, Command, Command, Event, Event] {
@ -64,6 +76,7 @@ class SslTlsSupport(engine: SSLEngine) extends PipelineStage[HasLogging, Command
val log = ctx.getLogger val log = ctx.getLogger
// TODO: should this be a ThreadLocal? // TODO: should this be a ThreadLocal?
val tempBuf = ByteBuffer.allocate(SslTlsSupport.MaxPacketSize) val tempBuf = ByteBuffer.allocate(SslTlsSupport.MaxPacketSize)
var originalCloseCommand: Tcp.CloseCommand = _
override val commandPipeline = (cmd: Command) cmd match { override val commandPipeline = (cmd: Command) cmd match {
case x: Tcp.Write case x: Tcp.Write
@ -74,9 +87,12 @@ class SslTlsSupport(engine: SSLEngine) extends PipelineStage[HasLogging, Command
} }
case x @ (Tcp.Close | Tcp.ConfirmedClose) case x @ (Tcp.Close | Tcp.ConfirmedClose)
originalCloseCommand = x.asInstanceOf[Tcp.CloseCommand]
log.debug("Closing SSLEngine due to reception of [{}]", x) log.debug("Closing SSLEngine due to reception of [{}]", x)
engine.closeOutbound() engine.closeOutbound()
closeEngine() :+ Right(x) // don't send close command to network here, it's the job of the SSL engine
// to shutdown the connection when getting CLOSED in encrypt
closeEngine()
case cmd ctx.singleCommand(cmd) case cmd ctx.singleCommand(cmd)
} }
@ -92,11 +108,20 @@ class SslTlsSupport(engine: SSLEngine) extends PipelineStage[HasLogging, Command
decrypt(buf) decrypt(buf)
case x: Tcp.ConnectionClosed case x: Tcp.ConnectionClosed
if (!engine.isOutboundDone) { // After we have closed the connection we ignore FIN from the other side.
// That's to avoid a strange condition where we know that no truncation attack
// can happen any more (because we actively closed the connection) but the peer
// isn't behaving properly and didn't send close_notify. Why is this condition strange?
// Because if we had closed the connection directly after we sent close_notify (which
// is allowed by the spec) we wouldn't even have noticed.
if (!engine.isOutboundDone)
try engine.closeInbound() try engine.closeInbound()
catch { case e: SSLException } // ignore warning about possible truncation attacks catch { case e: SSLException } // ignore warning about possible truncation attacks
}
ctx.singleEvent(x) if (x.isAborted || (originalCloseCommand eq null)) ctx.singleEvent(x)
else if (!engine.isInboundDone) ctx.singleEvent(originalCloseCommand.event)
// else the close message was sent by decrypt case CLOSED
else ctx.singleEvent(x)
case ev ctx.singleEvent(ev) case ev ctx.singleEvent(ev)
} }
@ -139,8 +164,8 @@ class SslTlsSupport(engine: SSLEngine) extends PipelineStage[HasLogging, Command
case CLOSED case CLOSED
if (postContentLeft) { if (postContentLeft) {
log.warning("SSLEngine closed prematurely while sending") log.warning("SSLEngine closed prematurely while sending")
nextCmds :+ Right(Tcp.Close) nextCmds :+ Right(Tcp.Abort)
} else nextCmds } else nextCmds :+ Right(Tcp.ConfirmedClose)
case BUFFER_OVERFLOW case BUFFER_OVERFLOW
throw new IllegalStateException("BUFFER_OVERFLOW: the SslBufferPool should make sure that buffers are never too small") throw new IllegalStateException("BUFFER_OVERFLOW: the SslBufferPool should make sure that buffers are never too small")
case BUFFER_UNDERFLOW case BUFFER_UNDERFLOW
@ -180,9 +205,11 @@ class SslTlsSupport(engine: SSLEngine) extends PipelineStage[HasLogging, Command
} }
case CLOSED case CLOSED
if (!engine.isOutboundDone) { if (!engine.isOutboundDone) {
log.warning("SSLEngine closed prematurely while receiving") closeEngine(nextOutput :+ Left(Tcp.PeerClosed))
nextOutput :+ Right(Tcp.Close) } else { // now both sides are closed on the SSL level
} else nextOutput // close the underlying connection, we don't need it any more
nextOutput :+ Left(originalCloseCommand.event) :+ Right(Tcp.Close)
}
case BUFFER_UNDERFLOW case BUFFER_UNDERFLOW
inboundReceptacle = buffer // save buffer so we can append the next one to it inboundReceptacle = buffer // save buffer so we can append the next one to it
nextOutput nextOutput

View file

@ -39,7 +39,7 @@ import akka.io.TcpReadWriteAdapter
import akka.remote.security.provider.AkkaProvider import akka.remote.security.provider.AkkaProvider
import akka.testkit.{ AkkaSpec, TestProbe } import akka.testkit.{ AkkaSpec, TestProbe }
import akka.util.{ ByteString, Timeout } import akka.util.{ ByteString, Timeout }
import javax.net.ssl.{ KeyManagerFactory, SSLContext, SSLServerSocket, SSLSocket, TrustManagerFactory } import javax.net.ssl._
import akka.actor.Deploy import akka.actor.Deploy
// TODO move this into akka-actor once AkkaProvider for SecureRandom does not have external dependencies // TODO move this into akka-actor once AkkaProvider for SecureRandom does not have external dependencies
@ -52,22 +52,32 @@ class SslTlsSupportSpec extends AkkaSpec {
"The SslTlsSupport" should { "The SslTlsSupport" should {
"work between a Java client and a Java server" in { "work between a Java client and a Java server" in {
invalidateSessions()
val server = new JavaSslServer val server = new JavaSslServer
val client = new JavaSslClient(server.address) val client = new JavaSslClient(server.address)
client.run() client.run()
val baselineSessionCounts = sessionCounts()
client.close() client.close()
// make sure not to lose sessions by invalid session closure
sessionCounts() === baselineSessionCounts
server.close() server.close()
sessionCounts() === baselineSessionCounts // see above
} }
"work between a akka client and a Java server" in { "work between a akka client and a Java server" in {
invalidateSessions()
val server = new JavaSslServer val server = new JavaSslServer
val client = new AkkaSslClient(server.address) val client = new AkkaSslClient(server.address)
client.run() client.run()
val baselineSessionCounts = sessionCounts()
client.close() client.close()
sessionCounts() === baselineSessionCounts // see above
server.close() server.close()
sessionCounts() === baselineSessionCounts // see above
} }
"work between a Java client and a akka server" in { "work between a Java client and a akka server" in {
invalidateSessions()
val serverAddress = TestUtils.temporaryServerAddress() val serverAddress = TestUtils.temporaryServerAddress()
val probe = TestProbe() val probe = TestProbe()
val bindHandler = probe.watch(system.actorOf(Props(new AkkaSslServer(serverAddress)).withDeploy(Deploy.local), "server1")) val bindHandler = probe.watch(system.actorOf(Props(new AkkaSslServer(serverAddress)).withDeploy(Deploy.local), "server1"))
@ -75,11 +85,14 @@ class SslTlsSupportSpec extends AkkaSpec {
val client = new JavaSslClient(serverAddress) val client = new JavaSslClient(serverAddress)
client.run() client.run()
val baselineSessionCounts = sessionCounts()
client.close() client.close()
sessionCounts() === baselineSessionCounts // see above
probe.expectTerminated(bindHandler) probe.expectTerminated(bindHandler)
} }
"work between a akka client and a akka server" in { "work between a akka client and a akka server" in {
invalidateSessions()
val serverAddress = TestUtils.temporaryServerAddress() val serverAddress = TestUtils.temporaryServerAddress()
val probe = TestProbe() val probe = TestProbe()
val bindHandler = probe.watch(system.actorOf(Props(new AkkaSslServer(serverAddress)).withDeploy(Deploy.local), "server2")) val bindHandler = probe.watch(system.actorOf(Props(new AkkaSslServer(serverAddress)).withDeploy(Deploy.local), "server2"))
@ -87,9 +100,36 @@ class SslTlsSupportSpec extends AkkaSpec {
val client = new AkkaSslClient(serverAddress) val client = new AkkaSslClient(serverAddress)
client.run() client.run()
val baselineSessionCounts = sessionCounts()
client.close() client.close()
sessionCounts() === baselineSessionCounts // see above
probe.expectTerminated(bindHandler) probe.expectTerminated(bindHandler)
} }
"work between an akka client and a Java server with confirmedClose" in {
invalidateSessions()
val server = new JavaSslServer
val client = new AkkaSslClient(server.address)
client.run()
val baselineSessionCounts = sessionCounts()
client.closeConfirmed()
sessionCounts() === baselineSessionCounts // see above
server.close()
sessionCounts() === baselineSessionCounts // see above
}
"akka client runs the full shutdown sequence if peer closes" in {
invalidateSessions()
val server = new JavaSslServer
val client = new AkkaSslClient(server.address)
client.run()
val baselineSessionCounts = serverSessions().length
server.close()
client.peerClosed()
// we only check the akka side server sessions here
// the java client seems to lose the session for some reason
serverSessions().length === baselineSessionCounts
}
} }
val counter = new AtomicInteger val counter = new AtomicInteger
@ -112,7 +152,7 @@ class SslTlsSupportSpec extends AkkaSpec {
val handler = system.actorOf(TcpPipelineHandler.props(init, connection, probe.ref).withDeploy(Deploy.local), val handler = system.actorOf(TcpPipelineHandler.props(init, connection, probe.ref).withDeploy(Deploy.local),
"client" + counter.incrementAndGet()) "client" + counter.incrementAndGet())
probe.send(connection, Tcp.Register(handler)) probe.send(connection, Tcp.Register(handler, keepOpenOnPeerClosed = true))
def run() { def run() {
probe.send(handler, Command("3+4\n")) probe.send(handler, Command("3+4\n"))
@ -127,12 +167,22 @@ class SslTlsSupportSpec extends AkkaSpec {
probe.expectMsg(Event("0\n")) probe.expectMsg(Event("0\n"))
} }
def peerClosed(): Unit = {
probe.expectMsg(Tcp.PeerClosed)
TestUtils.verifyActorTermination(handler)
}
def close() { def close() {
probe.send(handler, Management(Tcp.Close)) probe.send(handler, Management(Tcp.Close))
probe.expectMsgType[Tcp.ConnectionClosed] probe.expectMsgType[Tcp.ConnectionClosed]
TestUtils.verifyActorTermination(handler) TestUtils.verifyActorTermination(handler)
} }
def closeConfirmed(): Unit = {
probe.send(handler, Management(Tcp.ConfirmedClose))
probe.expectMsg(Tcp.ConfirmedClosed)
TestUtils.verifyActorTermination(handler)
}
} }
//#server //#server
@ -270,6 +320,21 @@ class SslTlsSupportSpec extends AkkaSpec {
engine engine
} }
import collection.JavaConverters._
def clientSessions() = sessions(_.getServerSessionContext)
def serverSessions() = sessions(_.getClientSessionContext)
def sessionCounts() = (clientSessions().length, serverSessions().length)
def sessions(f: SSLContext SSLSessionContext): Seq[SSLSession] = {
val ctx = f(sslContext)
val ids = ctx.getIds().asScala.toIndexedSeq
ids.map(ctx.getSession)
}
def invalidateSessions() = {
clientSessions().foreach(_.invalidate())
serverSessions().foreach(_.invalidate())
}
} }
object SslTlsSupportSpec { object SslTlsSupportSpec {