Simpler tls over tcp #24153

This commit is contained in:
Johan Andrén 2018-01-16 18:05:08 +01:00 committed by GitHub
parent 14c427bd4f
commit 32987c8704
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 382 additions and 24 deletions

View file

@ -123,6 +123,21 @@ see
@java[[Javadoc](http://doc.akka.io/japi/akka/current/akka/stream/javadsl/Framing.html#simpleFramingProtocol-int-)]
for more information.
### TLS
Similar factories as shown above for raw TCP but where the data is encrypted using TLS are available from `Tcp` through `outgoingTlsConnection`, `bindTls` and `bindAndHandleTls`, see the @scala[@scaladoc[`Tcp Scaladoc`](akka.stream.scaladsl.Tcp)]@java[@javadoc[`Tcp Javadoc`](akka.stream.javadsl.Tcp)] for details.
Using TLS requires a keystore and a truststore and then a somewhat involved dance of configuring the SSLContext and the details for how the session should be negotiated:
Scala
: @@snip [TcpSpec.scala]($akka$akka-stream-tests/src/test/scala/akka/stream/io/TcpSpec.scala) { #setting-up-ssl-context }
Java
: @@snip [TcpSpec.scala]($akka$akka-stream-tests/src/test/java/akka/stream/javadsl/TcpTest.java) { #setting-up-ssl-context }
The `SslContext` and `NegotiateFirstSession` instances can then be used with the binding or outgoing connection factory methods.
## Streaming File IO
Akka Streams provide simple Sources and Sinks that can work with `ByteString` instances to perform IO operations

View file

@ -3,34 +3,44 @@
*/
package akka.stream.javadsl;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import akka.Done;
import akka.actor.ActorSystem;
import akka.japi.function.Function2;
import akka.japi.function.Procedure;
import akka.stream.BindFailedException;
import akka.stream.StreamTcpException;
import akka.stream.StreamTest;
import akka.stream.javadsl.Tcp.IncomingConnection;
import akka.stream.javadsl.Tcp.ServerBinding;
import akka.testkit.AkkaJUnitActorSystemResource;
import akka.testkit.AkkaSpec;
import akka.testkit.SocketUtil;
import akka.testkit.javadsl.EventFilter;
import akka.testkit.javadsl.TestKit;
import akka.util.ByteString;
import org.junit.ClassRule;
import org.junit.Test;
import java.net.BindException;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.net.BindException;
import akka.Done;
import akka.NotUsed;
import akka.testkit.javadsl.EventFilter;
import akka.testkit.javadsl.TestKit;
import org.junit.ClassRule;
import org.junit.Test;
import scala.concurrent.Await;
import scala.concurrent.Future;
import scala.concurrent.duration.FiniteDuration;
import scala.runtime.BoxedUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import akka.stream.*;
import akka.stream.javadsl.Tcp.*;
import akka.japi.function.*;
import akka.testkit.AkkaSpec;
import akka.testkit.SocketUtil;
import akka.util.ByteString;
import akka.testkit.AkkaJUnitActorSystemResource;
// #setting-up-ssl-context
// imports
import akka.stream.TLSClientAuth;
import akka.stream.TLSProtocol;
import com.typesafe.sslconfig.akka.AkkaSSLConfig;
import java.security.KeyStore;
import javax.net.ssl.*;
import java.security.SecureRandom;
// #setting-up-ssl-context
public class TcpTest extends StreamTest {
public TcpTest() {
@ -126,4 +136,51 @@ public class TcpTest extends StreamTest {
}
}
// compile only sample
public void constructSslContext() throws Exception {
ActorSystem system = null;
// #setting-up-ssl-context
// -- setup logic ---
AkkaSSLConfig sslConfig = AkkaSSLConfig.get(system);
// Don't hardcode your password in actual code
char[] password = "abcdef".toCharArray();
// trust store and keys in one keystore
KeyStore keyStore = KeyStore.getInstance("PKCS12");
keyStore.load(getClass().getResourceAsStream("/tcp-spec-keystore.p12"), password);
TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
tmf.init(keyStore);
KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance("SunX509");
keyManagerFactory.init(keyStore, password);
// initial ssl context
SSLContext sslContext = SSLContext.getInstance("TLS");
sslContext.init(keyManagerFactory.getKeyManagers(), tmf.getTrustManagers(), new SecureRandom());
// protocols
SSLParameters defaultParams = sslContext.getDefaultSSLParameters();
String[] defaultProtocols = defaultParams.getProtocols();
String[] protocols = sslConfig.configureProtocols(defaultProtocols, sslConfig.config());
defaultParams.setProtocols(protocols);
// ciphers
String[] defaultCiphers = defaultParams.getCipherSuites();
String[] cipherSuites = sslConfig.configureCipherSuites(defaultCiphers, sslConfig.config());
defaultParams.setCipherSuites(cipherSuites);
TLSProtocol.NegotiateNewSession negotiateNewSession = TLSProtocol.negotiateNewSession()
.withCipherSuites(cipherSuites)
.withProtocols(protocols)
.withParameters(defaultParams)
.withClientAuth(TLSClientAuth.none());
// #setting-up-ssl-context
}
}

View file

@ -4,7 +4,9 @@
package akka.stream.io
import java.net._
import java.security.SecureRandom
import java.util.concurrent.atomic.AtomicInteger
import javax.net.ssl.{ KeyManagerFactory, SSLContext, TrustManagerFactory }
import akka.actor.{ ActorIdentity, ActorSystem, ExtendedActorSystem, Identify, Kill }
import akka.io.Tcp._
@ -18,6 +20,7 @@ import akka.testkit.SocketUtil.temporaryServerAddress
import akka.util.ByteString
import akka.{ Done, NotUsed }
import com.typesafe.config.ConfigFactory
import org.scalatest.concurrent.PatienceConfiguration
import org.scalatest.concurrent.PatienceConfiguration.Timeout
import scala.collection.immutable
@ -25,7 +28,10 @@ import scala.concurrent.duration._
import scala.concurrent.{ Await, Future, Promise }
import scala.util.control.NonFatal
class TcpSpec extends StreamSpec("akka.stream.materializer.subscription-timeout.timeout = 2s") with TcpHelper {
class TcpSpec extends StreamSpec("""
akka.loglevel = info
akka.stream.materializer.subscription-timeout.timeout = 2s
""") with TcpHelper {
"Outgoing TCP stream" must {
@ -692,6 +698,94 @@ class TcpSpec extends StreamSpec("akka.stream.materializer.subscription-timeout.
}
}
"TLS client and server convenience methods" should {
"allow for 'simple' TLS" in {
// cert is valid until 2025, so if this tests starts failing after that you need to create a new one
val (sslContext, firstSession) = initSslMess()
val address = temporaryServerAddress()
Tcp().bindAndHandleTls(
// just echo charactes until we reach '\n', then complete stream
// also - byte is our framing
Flow[ByteString].mapConcat(_.utf8String.toList)
.takeWhile(_ != '\n')
.map(c ByteString(c)),
address.getHostName,
address.getPort,
sslContext,
firstSession
).futureValue
system.log.info(s"Server bound to ${address.getHostString}:${address.getPort}")
val connectionFlow = Tcp().outgoingTlsConnection(address.getHostName, address.getPort, sslContext, firstSession)
val chars = "hello\n".toList.map(_.toString)
val (connectionF, result) =
Source(chars).map(c ByteString(c))
.concat(Source.maybe) // do not complete it from our side
.viaMat(connectionFlow)(Keep.right)
.map(_.utf8String)
.toMat(Sink.fold("")(_ + _))(Keep.both)
.run()
connectionF.futureValue
system.log.info(s"Client connected to ${address.getHostString}:${address.getPort}")
result.futureValue(PatienceConfiguration.Timeout(10.seconds)) should ===("hello")
}
def initSslMess() = {
// #setting-up-ssl-context
import akka.stream.TLSClientAuth
import akka.stream.TLSProtocol
import com.typesafe.sslconfig.akka.AkkaSSLConfig
import java.security.KeyStore
import javax.net.ssl._
val sslConfig = AkkaSSLConfig()
// Don't hardcode your password in actual code
val password = "abcdef".toCharArray
// trust store and keys in one keystore
val keyStore = KeyStore.getInstance("PKCS12")
keyStore.load(classOf[TcpSpec].getResourceAsStream("/tcp-spec-keystore.p12"), password)
val tmf = TrustManagerFactory.getInstance("SunX509")
tmf.init(keyStore)
val keyManagerFactory = KeyManagerFactory.getInstance("SunX509")
keyManagerFactory.init(keyStore, password)
// initial ssl context
val sslContext = SSLContext.getInstance("TLS")
sslContext.init(keyManagerFactory.getKeyManagers, tmf.getTrustManagers, new SecureRandom)
// protocols
val defaultParams = sslContext.getDefaultSSLParameters
val defaultProtocols = defaultParams.getProtocols
val protocols = sslConfig.configureProtocols(defaultProtocols, sslConfig.config)
defaultParams.setProtocols(protocols)
// ciphers
val defaultCiphers = defaultParams.getCipherSuites
val cipherSuites = sslConfig.configureCipherSuites(defaultCiphers, sslConfig.config)
defaultParams.setCipherSuites(cipherSuites)
val negotiateNewSession = TLSProtocol.NegotiateNewSession
.withCipherSuites(cipherSuites: _*)
.withProtocols(protocols: _*)
.withParameters(defaultParams)
.withClientAuth(TLSClientAuth.None)
// #setting-up-ssl-context
(sslContext, negotiateNewSession)
}
}
def validateServerClientCommunication(
testData: ByteString,
serverConnection: ServerConnection,

View file

@ -24,8 +24,10 @@ import akka.io.Inet.SocketOption
import scala.compat.java8.OptionConverters._
import scala.compat.java8.FutureConverters._
import java.util.concurrent.CompletionStage
import javax.net.ssl.SSLContext
import akka.annotation.InternalApi
import akka.annotation.{ ApiMayChange, InternalApi }
import akka.stream.TLSProtocol.NegotiateNewSession
object Tcp extends ExtensionId[Tcp] with ExtensionIdProvider {
@ -199,4 +201,82 @@ class Tcp(system: ExtendedActorSystem) extends akka.actor.Extension {
Flow.fromGraph(delegate.outgoingConnection(new InetSocketAddress(host, port))
.mapMaterializedValue(_.map(new OutgoingConnection(_))(ec).toJava))
/**
* Creates an [[Tcp.OutgoingConnection]] with TLS.
* The returned flow represents a TCP client connection to the given endpoint where all bytes in and
* out go through TLS.
*
* @see [[Tcp.outgoingConnection()]]
*/
def outgoingTlsConnection(host: String, port: Int, sslContext: SSLContext, negotiateNewSession: NegotiateNewSession): Flow[ByteString, ByteString, CompletionStage[OutgoingConnection]] =
Flow.fromGraph(delegate.outgoingTlsConnection(host, port, sslContext, negotiateNewSession)
.mapMaterializedValue(_.map(new OutgoingConnection(_))(ec).toJava))
/**
* Creates an [[Tcp.OutgoingConnection]] with TLS.
* The returned flow represents a TCP client connection to the given endpoint where all bytes in and
* out go through TLS.
*
* @see [[Tcp.outgoingConnection()]]
*
* Marked API-may-change to leave room for an improvement around the very long parameter list.
*/
@ApiMayChange
def outgoingTlsConnection(
remoteAddress: InetSocketAddress,
sslContext: SSLContext,
negotiateNewSession: NegotiateNewSession,
localAddress: Optional[InetSocketAddress],
options: JIterable[SocketOption],
connectTimeout: Duration,
idleTimeout: Duration
): Flow[ByteString, ByteString, CompletionStage[OutgoingConnection]] =
Flow.fromGraph(delegate.outgoingTlsConnection(
remoteAddress,
sslContext,
negotiateNewSession,
localAddress.asScala,
immutableSeq(options),
connectTimeout,
idleTimeout)
.mapMaterializedValue(_.map(new OutgoingConnection(_))(ec).toJava))
/**
* Creates a [[Tcp.ServerBinding]] instance which represents a prospective TCP server binding on the given `endpoint`
* where all incoming and outgoing bytes are passed through TLS.
*
* @see [[Tcp.bind()]]
* Marked API-may-change to leave room for an improvement around the very long parameter list.
*/
@ApiMayChange
def bindTls(
interface: String,
port: Int,
sslContext: SSLContext,
negotiateNewSession: NegotiateNewSession,
backlog: Int,
options: JIterable[SocketOption],
halfClose: Boolean,
idleTimeout: Duration
): Source[IncomingConnection, CompletionStage[ServerBinding]] =
Source.fromGraph(delegate.bindTls(interface, port, sslContext, negotiateNewSession, backlog, immutableSeq(options), idleTimeout)
.map(new IncomingConnection(_))
.mapMaterializedValue(_.map(new ServerBinding(_))(ec).toJava))
/**
* Creates a [[Tcp.ServerBinding]] instance which represents a prospective TCP server binding on the given `endpoint`
* where all incoming and outgoing bytes are passed through TLS.
*
* @see [[Tcp.bind()]]
*/
def bindTls(
interface: String,
port: Int,
sslContext: SSLContext,
negotiateNewSession: NegotiateNewSession
): Source[IncomingConnection, CompletionStage[ServerBinding]] =
Source.fromGraph(delegate.bindTls(interface, port, sslContext, negotiateNewSession)
.map(new IncomingConnection(_))
.mapMaterializedValue(_.map(new ServerBinding(_))(ec).toJava))
}

View file

@ -5,16 +5,18 @@ package akka.stream.scaladsl
import java.net.InetSocketAddress
import java.util.concurrent.TimeoutException
import javax.net.ssl.SSLContext
import akka.{ Done, NotUsed }
import akka.actor._
import akka.annotation.InternalApi
import akka.annotation.{ ApiMayChange, InternalApi }
import akka.io.Inet.SocketOption
import akka.io.{ IO, Tcp IoTcp }
import akka.stream.TLSProtocol.NegotiateNewSession
import akka.stream._
import akka.stream.impl.fusing.GraphStages.detacher
import akka.stream.impl.io.{ ConnectionSourceStage, OutgoingConnectionStage, TcpIdleTimeout }
import akka.util.ByteString
import akka.{ Done, NotUsed }
import scala.collection.immutable
import scala.concurrent.Future
@ -70,6 +72,15 @@ object Tcp extends ExtensionId[Tcp] with ExtensionIdProvider {
def lookup() = Tcp
def createExtension(system: ExtendedActorSystem): Tcp = new Tcp(system)
// just wraps/unwraps the TLS byte events to provide ByteString, ByteString flows
private val tlsWrapping: BidiFlow[ByteString, TLSProtocol.SendBytes, TLSProtocol.SslTlsInbound, ByteString, NotUsed] = BidiFlow.fromFlows(
Flow[ByteString].map(TLSProtocol.SendBytes),
Flow[TLSProtocol.SslTlsInbound].collect {
case sb: TLSProtocol.SessionBytes sb.bytes
// ignore other kinds of inbounds (currently only Truncated)
}
)
}
final class Tcp(system: ExtendedActorSystem) extends akka.actor.Extension {
@ -208,6 +219,107 @@ final class Tcp(system: ExtendedActorSystem) extends akka.actor.Extension {
*/
def outgoingConnection(host: String, port: Int): Flow[ByteString, ByteString, Future[OutgoingConnection]] =
outgoingConnection(InetSocketAddress.createUnresolved(host, port))
/**
* Creates an [[Tcp.OutgoingConnection]] with TLS.
* The returned flow represents a TCP client connection to the given endpoint where all bytes in and
* out go through TLS.
*
* For more advanced use cases you can manually combine [[Tcp.outgoingConnection()]] and [[TLS]]
*
* @param negotiateNewSession Details about what to require when negotiating the connection with the server
* @param sslContext Context containing details such as the trust and keystore
*
* @see [[Tcp.outgoingConnection()]]
*/
def outgoingTlsConnection(
host: String,
port: Int,
sslContext: SSLContext,
negotiateNewSession: NegotiateNewSession): Flow[ByteString, ByteString, Future[OutgoingConnection]] =
outgoingTlsConnection(InetSocketAddress.createUnresolved(host, port), sslContext, negotiateNewSession)
/**
* Creates an [[Tcp.OutgoingConnection]] with TLS.
* The returned flow represents a TCP client connection to the given endpoint where all bytes in and
* out go through TLS.
*
* @see [[Tcp.outgoingConnection()]]
* @param negotiateNewSession Details about what to require when negotiating the connection with the server
* @param sslContext Context containing details such as the trust and keystore
*
* Marked API-may-change to leave room for an improvement around the very long parameter list.
*/
@ApiMayChange
def outgoingTlsConnection(
remoteAddress: InetSocketAddress,
sslContext: SSLContext,
negotiateNewSession: NegotiateNewSession,
localAddress: Option[InetSocketAddress] = None,
options: immutable.Traversable[SocketOption] = Nil,
connectTimeout: Duration = Duration.Inf,
idleTimeout: Duration = Duration.Inf): Flow[ByteString, ByteString, Future[OutgoingConnection]] = {
val connection = outgoingConnection(remoteAddress, localAddress, options, true, connectTimeout, idleTimeout)
val tls = TLS(sslContext, negotiateNewSession, TLSRole.client)
connection.join(tlsWrapping.atop(tls).reversed)
}
/**
* Creates a [[Tcp.ServerBinding]] instance which represents a prospective TCP server binding on the given `endpoint`
* where all incoming and outgoing bytes are passed through TLS.
*
* @param negotiateNewSession Details about what to require when negotiating the connection with the server
* @param sslContext Context containing details such as the trust and keystore
* @see [[Tcp.bind]]
*
* Marked API-may-change to leave room for an improvement around the very long parameter list.
*/
@ApiMayChange
def bindTls(
interface: String,
port: Int,
sslContext: SSLContext,
negotiateNewSession: NegotiateNewSession,
backlog: Int = 100,
options: immutable.Traversable[SocketOption] = Nil,
idleTimeout: Duration = Duration.Inf): Source[IncomingConnection, Future[ServerBinding]] = {
val tls = tlsWrapping.atop(TLS(sslContext, negotiateNewSession, TLSRole.server)).reversed
bind(interface, port, backlog, options, true, idleTimeout).map { incomingConnection
incomingConnection.copy(
flow = incomingConnection.flow.join(tls)
)
}
}
/**
* Creates a [[Tcp.ServerBinding]] instance which represents a prospective TCP server binding on the given `endpoint`
* handling the incoming connections through TLS and then run using the provided Flow.
*
* @param negotiateNewSession Details about what to require when negotiating the connection with the server
* @param sslContext Context containing details such as the trust and keystore
* @see [[Tcp.bindAndHandle()]]
*
* Marked API-may-change to leave room for an improvement around the very long parameter list.
*/
@ApiMayChange
def bindAndHandleTls(
handler: Flow[ByteString, ByteString, _],
interface: String,
port: Int,
sslContext: SSLContext,
negotiateNewSession: NegotiateNewSession,
backlog: Int = 100,
options: immutable.Traversable[SocketOption] = Nil,
idleTimeout: Duration = Duration.Inf)(implicit m: Materializer): Future[ServerBinding] = {
bindTls(interface, port, sslContext, negotiateNewSession, backlog, options, idleTimeout)
.to(Sink.foreach { conn: IncomingConnection
conn.handleWith(handler)
}).run()
}
}
final class TcpIdleTimeoutException(msg: String, timeout: Duration)