Merge pull request #1679 from akka/wip-3389-fix-ssl-engine-shutdown-ban

=act #3389 Improve SSL closing sequence
This commit is contained in:
Björn Antonsson 2013-08-27 05:55:10 -07:00
commit ba92d38fbd
2 changed files with 103 additions and 11 deletions

View file

@ -54,6 +54,18 @@ object SslTlsSupport {
*
* Each instance of this stage has a scratch [[ByteBuffer]] of approx. 18kiB
* allocated which is used by the SSLEngine.
*
* One thing to keep in mind is that there's no support for half-closed connections
* in SSL (but SSL on the other side requires half-closed connections from its transport
* layer).
*
* This means:
* 1. keepOpenOnPeerClosed is not supported on top of SSL (once you receive PeerClosed
* the connection is closed, further CloseCommands are ignored)
* 2. keepOpenOnPeerClosed should always be enabled on the transport layer beneath SSL so
* that one can wait for the other side's SSL level close_notify message without barfing
* RST to the peer because this socket is already gone
*
*/
class SslTlsSupport(engine: SSLEngine) extends PipelineStage[HasLogging, Command, Command, Event, Event] {
@ -64,6 +76,7 @@ class SslTlsSupport(engine: SSLEngine) extends PipelineStage[HasLogging, Command
val log = ctx.getLogger
// TODO: should this be a ThreadLocal?
val tempBuf = ByteBuffer.allocate(SslTlsSupport.MaxPacketSize)
var originalCloseCommand: Tcp.CloseCommand = _
override val commandPipeline = (cmd: Command) cmd match {
case x: Tcp.Write
@ -74,9 +87,12 @@ class SslTlsSupport(engine: SSLEngine) extends PipelineStage[HasLogging, Command
}
case x @ (Tcp.Close | Tcp.ConfirmedClose)
originalCloseCommand = x.asInstanceOf[Tcp.CloseCommand]
log.debug("Closing SSLEngine due to reception of [{}]", x)
engine.closeOutbound()
closeEngine() :+ Right(x)
// don't send close command to network here, it's the job of the SSL engine
// to shutdown the connection when getting CLOSED in encrypt
closeEngine()
case cmd ctx.singleCommand(cmd)
}
@ -92,11 +108,20 @@ class SslTlsSupport(engine: SSLEngine) extends PipelineStage[HasLogging, Command
decrypt(buf)
case x: Tcp.ConnectionClosed
if (!engine.isOutboundDone) {
// After we have closed the connection we ignore FIN from the other side.
// That's to avoid a strange condition where we know that no truncation attack
// can happen any more (because we actively closed the connection) but the peer
// isn't behaving properly and didn't send close_notify. Why is this condition strange?
// Because if we had closed the connection directly after we sent close_notify (which
// is allowed by the spec) we wouldn't even have noticed.
if (!engine.isOutboundDone)
try engine.closeInbound()
catch { case e: SSLException } // ignore warning about possible truncation attacks
}
ctx.singleEvent(x)
if (x.isAborted || (originalCloseCommand eq null)) ctx.singleEvent(x)
else if (!engine.isInboundDone) ctx.singleEvent(originalCloseCommand.event)
// else the close message was sent by decrypt case CLOSED
else ctx.singleEvent(x)
case ev ctx.singleEvent(ev)
}
@ -139,8 +164,8 @@ class SslTlsSupport(engine: SSLEngine) extends PipelineStage[HasLogging, Command
case CLOSED
if (postContentLeft) {
log.warning("SSLEngine closed prematurely while sending")
nextCmds :+ Right(Tcp.Close)
} else nextCmds
nextCmds :+ Right(Tcp.Abort)
} else nextCmds :+ Right(Tcp.ConfirmedClose)
case BUFFER_OVERFLOW
throw new IllegalStateException("BUFFER_OVERFLOW: the SslBufferPool should make sure that buffers are never too small")
case BUFFER_UNDERFLOW
@ -180,9 +205,11 @@ class SslTlsSupport(engine: SSLEngine) extends PipelineStage[HasLogging, Command
}
case CLOSED
if (!engine.isOutboundDone) {
log.warning("SSLEngine closed prematurely while receiving")
nextOutput :+ Right(Tcp.Close)
} else nextOutput
closeEngine(nextOutput :+ Left(Tcp.PeerClosed))
} else { // now both sides are closed on the SSL level
// close the underlying connection, we don't need it any more
nextOutput :+ Left(originalCloseCommand.event) :+ Right(Tcp.Close)
}
case BUFFER_UNDERFLOW
inboundReceptacle = buffer // save buffer so we can append the next one to it
nextOutput

View file

@ -39,7 +39,7 @@ import akka.io.TcpReadWriteAdapter
import akka.remote.security.provider.AkkaProvider
import akka.testkit.{ AkkaSpec, TestProbe }
import akka.util.{ ByteString, Timeout }
import javax.net.ssl.{ KeyManagerFactory, SSLContext, SSLServerSocket, SSLSocket, TrustManagerFactory }
import javax.net.ssl._
import akka.actor.Deploy
// TODO move this into akka-actor once AkkaProvider for SecureRandom does not have external dependencies
@ -52,22 +52,32 @@ class SslTlsSupportSpec extends AkkaSpec {
"The SslTlsSupport" should {
"work between a Java client and a Java server" in {
invalidateSessions()
val server = new JavaSslServer
val client = new JavaSslClient(server.address)
client.run()
val baselineSessionCounts = sessionCounts()
client.close()
// make sure not to lose sessions by invalid session closure
sessionCounts() === baselineSessionCounts
server.close()
sessionCounts() === baselineSessionCounts // see above
}
"work between a akka client and a Java server" in {
invalidateSessions()
val server = new JavaSslServer
val client = new AkkaSslClient(server.address)
client.run()
val baselineSessionCounts = sessionCounts()
client.close()
sessionCounts() === baselineSessionCounts // see above
server.close()
sessionCounts() === baselineSessionCounts // see above
}
"work between a Java client and a akka server" in {
invalidateSessions()
val serverAddress = TestUtils.temporaryServerAddress()
val probe = TestProbe()
val bindHandler = probe.watch(system.actorOf(Props(new AkkaSslServer(serverAddress)).withDeploy(Deploy.local), "server1"))
@ -75,11 +85,14 @@ class SslTlsSupportSpec extends AkkaSpec {
val client = new JavaSslClient(serverAddress)
client.run()
val baselineSessionCounts = sessionCounts()
client.close()
sessionCounts() === baselineSessionCounts // see above
probe.expectTerminated(bindHandler)
}
"work between a akka client and a akka server" in {
invalidateSessions()
val serverAddress = TestUtils.temporaryServerAddress()
val probe = TestProbe()
val bindHandler = probe.watch(system.actorOf(Props(new AkkaSslServer(serverAddress)).withDeploy(Deploy.local), "server2"))
@ -87,9 +100,36 @@ class SslTlsSupportSpec extends AkkaSpec {
val client = new AkkaSslClient(serverAddress)
client.run()
val baselineSessionCounts = sessionCounts()
client.close()
sessionCounts() === baselineSessionCounts // see above
probe.expectTerminated(bindHandler)
}
"work between an akka client and a Java server with confirmedClose" in {
invalidateSessions()
val server = new JavaSslServer
val client = new AkkaSslClient(server.address)
client.run()
val baselineSessionCounts = sessionCounts()
client.closeConfirmed()
sessionCounts() === baselineSessionCounts // see above
server.close()
sessionCounts() === baselineSessionCounts // see above
}
"akka client runs the full shutdown sequence if peer closes" in {
invalidateSessions()
val server = new JavaSslServer
val client = new AkkaSslClient(server.address)
client.run()
val baselineSessionCounts = serverSessions().length
server.close()
client.peerClosed()
// we only check the akka side server sessions here
// the java client seems to lose the session for some reason
serverSessions().length === baselineSessionCounts
}
}
val counter = new AtomicInteger
@ -112,7 +152,7 @@ class SslTlsSupportSpec extends AkkaSpec {
val handler = system.actorOf(TcpPipelineHandler.props(init, connection, probe.ref).withDeploy(Deploy.local),
"client" + counter.incrementAndGet())
probe.send(connection, Tcp.Register(handler))
probe.send(connection, Tcp.Register(handler, keepOpenOnPeerClosed = true))
def run() {
probe.send(handler, Command("3+4\n"))
@ -127,12 +167,22 @@ class SslTlsSupportSpec extends AkkaSpec {
probe.expectMsg(Event("0\n"))
}
def peerClosed(): Unit = {
probe.expectMsg(Tcp.PeerClosed)
TestUtils.verifyActorTermination(handler)
}
def close() {
probe.send(handler, Management(Tcp.Close))
probe.expectMsgType[Tcp.ConnectionClosed]
TestUtils.verifyActorTermination(handler)
}
def closeConfirmed(): Unit = {
probe.send(handler, Management(Tcp.ConfirmedClose))
probe.expectMsg(Tcp.ConfirmedClosed)
TestUtils.verifyActorTermination(handler)
}
}
//#server
@ -270,6 +320,21 @@ class SslTlsSupportSpec extends AkkaSpec {
engine
}
import collection.JavaConverters._
def clientSessions() = sessions(_.getServerSessionContext)
def serverSessions() = sessions(_.getClientSessionContext)
def sessionCounts() = (clientSessions().length, serverSessions().length)
def sessions(f: SSLContext SSLSessionContext): Seq[SSLSession] = {
val ctx = f(sslContext)
val ids = ctx.getIds().asScala.toIndexedSeq
ids.map(ctx.getSession)
}
def invalidateSessions() = {
clientSessions().foreach(_.invalidate())
serverSessions().foreach(_.invalidate())
}
}
object SslTlsSupportSpec {