diff --git a/akka-http-core/src/main/java/akka/http/javadsl/model/headers/TlsSessionInfo.java b/akka-http-core/src/main/java/akka/http/javadsl/model/headers/TlsSessionInfo.java new file mode 100644 index 0000000000..d398d0d5c5 --- /dev/null +++ b/akka-http-core/src/main/java/akka/http/javadsl/model/headers/TlsSessionInfo.java @@ -0,0 +1,25 @@ +/** + * Copyright (C) 2009-2016 Typesafe Inc. + */ + +package akka.http.javadsl.model.headers; + +import javax.net.ssl.SSLSession; + +/** + * Model for the synthetic `Tls-Session-Info` header which carries the SSLSession of the connection + * the message carrying this header was received with. + * + * This header will only be added if it enabled in the configuration by setting + * akka.http.[client|server].parsing.tls-session-info-header = on. + */ +public abstract class TlsSessionInfo extends CustomHeader { + /** + * @return the SSLSession this message was received over. + */ + public abstract SSLSession getSession(); + + public static TlsSessionInfo create(SSLSession session) { + return akka.http.scaladsl.model.headers.Tls$minusSession$minusInfo$.MODULE$.apply(session); + } +} diff --git a/akka-http-core/src/main/resources/reference.conf b/akka-http-core/src/main/resources/reference.conf index 35ae9ead17..55da20ba6e 100644 --- a/akka-http-core/src/main/resources/reference.conf +++ b/akka-http-core/src/main/resources/reference.conf @@ -337,5 +337,10 @@ akka.http { If-Unmodified-Since = 0 User-Agent = 32 } + + # Enables/disables inclusion of an Tls-Session-Info header in parsed + # messages over Tls transports (i.e., HttpRequest on server side and + # HttpResponse on client side). + tls-session-info-header = off } } diff --git a/akka-http-core/src/main/scala/akka/http/ParserSettings.scala b/akka-http-core/src/main/scala/akka/http/ParserSettings.scala index fee8caf553..7af10fdb4b 100644 --- a/akka-http-core/src/main/scala/akka/http/ParserSettings.scala +++ b/akka-http-core/src/main/scala/akka/http/ParserSettings.scala @@ -27,6 +27,7 @@ final case class ParserSettings( illegalHeaderWarnings: Boolean, errorLoggingVerbosity: ParserSettings.ErrorLoggingVerbosity, headerValueCacheLimits: Map[String, Int], + includeTlsSessionInfoHeader: Boolean, customMethods: String ⇒ Option[HttpMethod], customStatusCodes: Int ⇒ Option[StatusCode]) extends HttpHeaderParser.Settings { @@ -75,6 +76,7 @@ object ParserSettings extends SettingsCompanion[ParserSettings]("akka.http.parsi c getBoolean "illegal-header-warnings", ErrorLoggingVerbosity(c getString "error-logging-verbosity"), cacheConfig.entrySet.asScala.map(kvp ⇒ kvp.getKey -> cacheConfig.getInt(kvp.getKey))(collection.breakOut), + c getBoolean "tls-session-info-header", _ ⇒ None, _ ⇒ None) } diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala index 0f76d57e3d..2bbd8df266 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/client/OutgoingConnectionBlueprint.scala @@ -100,8 +100,8 @@ private[http] object OutgoingConnectionBlueprint { val wrapTls = b.add(Flow[ByteString].map(SendBytes)) terminationMerge.out ~> requestRendering ~> logger ~> wrapTls - val unwrapTls = b.add(Flow[SslTlsInbound].collect { case SessionBytes(_, bytes) ⇒ bytes }) - unwrapTls ~> responseParsingMerge.in0 + val collectSessionBytes = b.add(Flow[SslTlsInbound].collect { case s: SessionBytes ⇒ s }) + collectSessionBytes ~> responseParsingMerge.in0 methodBypassFanout.out(0) ~> terminationMerge.in0 @@ -113,7 +113,7 @@ private[http] object OutgoingConnectionBlueprint { BidiShape( methodBypassFanout.in, wrapTls.out, - unwrapTls.in, + collectSessionBytes.in, terminationFanout.out(1)) }) @@ -154,8 +154,8 @@ private[http] object OutgoingConnectionBlueprint { * 2. Read from the dataInput until exactly one response has been fully received * 3. Go back to 1. */ - class ResponseParsingMerge(rootParser: HttpResponseParser) extends GraphStage[FanInShape2[ByteString, HttpMethod, List[ResponseOutput]]] { - private val dataInput = Inlet[ByteString]("data") + class ResponseParsingMerge(rootParser: HttpResponseParser) extends GraphStage[FanInShape2[SessionBytes, HttpMethod, List[ResponseOutput]]] { + private val dataInput = Inlet[SessionBytes]("data") private val methodBypassInput = Inlet[HttpMethod]("method") private val out = Outlet[List[ResponseOutput]]("out") @@ -174,7 +174,7 @@ private[http] object OutgoingConnectionBlueprint { override def onPush(): Unit = { val method = grab(methodBypassInput) parser.setRequestMethodForNextResponse(method) - val output = parser.onPush(ByteString.empty) + val output = parser.parseBytes(ByteString.empty) drainParser(output) } override def onUpstreamFinish(): Unit = @@ -185,7 +185,7 @@ private[http] object OutgoingConnectionBlueprint { setHandler(dataInput, new InHandler { override def onPush(): Unit = { val bytes = grab(dataInput) - val output = parser.onPush(bytes) + val output = parser.parseSessionBytes(bytes) drainParser(output) } override def onUpstreamFinish(): Unit = diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpMessageParser.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpMessageParser.scala index af1a3830bd..227aeb5051 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpMessageParser.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/parsing/HttpMessageParser.scala @@ -4,6 +4,10 @@ package akka.http.impl.engine.parsing +import javax.net.ssl.SSLSession + +import akka.stream.io.{ SessionBytes, SslTlsInbound } + import scala.annotation.tailrec import scala.collection.mutable.ListBuffer import akka.parboiled2.CharUtils @@ -30,11 +34,17 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser private[this] var completionHandling: CompletionHandling = CompletionOk private[this] var terminated = false + private[this] var lastSession: SSLSession = null // used to prevent having to recreate header on each message + private[this] var tlsSessionInfoHeader: `Tls-Session-Info` = null + def initialHeaderBuffer: ListBuffer[HttpHeader] = + if (settings.includeTlsSessionInfoHeader && tlsSessionInfoHeader != null) ListBuffer(tlsSessionInfoHeader) + else ListBuffer() + def isTerminated = terminated - val stage: PushPullStage[ByteString, Output] = - new PushPullStage[ByteString, Output] { - def onPush(elem: ByteString, ctx: Context[Output]) = handleParserOutput(self.onPush(elem), ctx) + val stage: PushPullStage[SessionBytes, Output] = + new PushPullStage[SessionBytes, Output] { + def onPush(input: SessionBytes, ctx: Context[Output]) = handleParserOutput(self.parseSessionBytes(input), ctx) def onPull(ctx: Context[Output]) = handleParserOutput(self.onPull(), ctx) private def handleParserOutput(output: Output, ctx: Context[Output]): SyncDirective = output match { @@ -46,7 +56,14 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser if (self.onUpstreamFinish()) ctx.finish() else ctx.absorbTermination() } - final def onPush(input: ByteString): Output = { + final def parseSessionBytes(input: SessionBytes): Output = { + if (input.session ne lastSession) { + lastSession = input.session + tlsSessionInfoHeader = `Tls-Session-Info`(input.session) + } + parseBytes(input.bytes) + } + final def parseBytes(input: ByteString): Output = { @tailrec def run(next: ByteString ⇒ StateResult): StateResult = (try next(input) catch { @@ -102,7 +119,7 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser def badProtocol: Nothing - @tailrec final def parseHeaderLines(input: ByteString, lineStart: Int, headers: ListBuffer[HttpHeader] = ListBuffer[HttpHeader](), + @tailrec final def parseHeaderLines(input: ByteString, lineStart: Int, headers: ListBuffer[HttpHeader] = initialHeaderBuffer, headerCount: Int = 0, ch: Option[Connection] = None, clh: Option[`Content-Length`] = None, cth: Option[`Content-Type`] = None, teh: Option[`Transfer-Encoding`] = None, e100c: Boolean = false, diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala index 2c54a09db4..49af8de1fc 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/server/HttpServerBluePrint.scala @@ -57,16 +57,19 @@ private[http] object HttpServerBluePrint { requestPreparation(settings) atop controller(settings, log) atop parsingRendering(settings, log) atop - new ProtocolSwitchStage(settings, log) atop - unwrapTls + websocketSupport(settings, log) atop + tlsSupport theStack.withAttributes(HttpAttributes.remoteAddress(remoteAddress)) } - val unwrapTls: BidiFlow[ByteString, SslTlsOutbound, SslTlsInbound, ByteString, Unit] = - BidiFlow.fromFlows(Flow[ByteString].map(SendBytes), Flow[SslTlsInbound].collect { case x: SessionBytes ⇒ x.bytes }) + val tlsSupport: BidiFlow[ByteString, SslTlsOutbound, SslTlsInbound, SessionBytes, Unit] = + BidiFlow.fromFlows(Flow[ByteString].map(SendBytes), Flow[SslTlsInbound].collect { case x: SessionBytes ⇒ x }) - def parsingRendering(settings: ServerSettings, log: LoggingAdapter): BidiFlow[ResponseRenderingContext, ResponseRenderingOutput, ByteString, RequestOutput, Unit] = + def websocketSupport(settings: ServerSettings, log: LoggingAdapter): BidiFlow[ResponseRenderingOutput, ByteString, SessionBytes, SessionBytes, Unit] = + BidiFlow.fromGraph(new ProtocolSwitchStage(settings, log)) + + def parsingRendering(settings: ServerSettings, log: LoggingAdapter): BidiFlow[ResponseRenderingContext, ResponseRenderingOutput, SessionBytes, RequestOutput, Unit] = BidiFlow.fromFlows(rendering(settings, log), parsing(settings, log)) def controller(settings: ServerSettings, log: LoggingAdapter): BidiFlow[HttpResponse, ResponseRenderingContext, RequestOutput, RequestOutput, Unit] = @@ -114,7 +117,7 @@ private[http] object HttpServerBluePrint { } }) - def parsing(settings: ServerSettings, log: LoggingAdapter): Flow[ByteString, RequestOutput, Unit] = { + def parsing(settings: ServerSettings, log: LoggingAdapter): Flow[SessionBytes, RequestOutput, Unit] = { import settings._ // the initial header parser we initially use for every connection, @@ -137,7 +140,7 @@ private[http] object HttpServerBluePrint { case x ⇒ x } - Flow[ByteString].transform(() ⇒ + Flow[SessionBytes].transform(() ⇒ // each connection uses a single (private) request parser instance for all its requests // which builds a cache of all header instances seen on that connection rootParser.createShallowCopy().stage).named("rootParser") @@ -344,12 +347,12 @@ private[http] object HttpServerBluePrint { One2OneBidiFlow[HttpRequest, HttpResponse](pipeliningLimit).reversed private class ProtocolSwitchStage(settings: ServerSettings, log: LoggingAdapter) - extends GraphStage[BidiShape[ResponseRenderingOutput, ByteString, ByteString, ByteString]] { + extends GraphStage[BidiShape[ResponseRenderingOutput, ByteString, SessionBytes, SessionBytes]] { - private val fromNet = Inlet[ByteString]("fromNet") + private val fromNet = Inlet[SessionBytes]("fromNet") private val toNet = Outlet[ByteString]("toNet") - private val toHttp = Outlet[ByteString]("toHttp") + private val toHttp = Outlet[SessionBytes]("toHttp") private val fromHttp = Inlet[ResponseRenderingOutput]("fromHttp") override def initialAttributes = Attributes.name("ProtocolSwitchStage") @@ -432,14 +435,14 @@ private[http] object HttpServerBluePrint { }) setHandler(fromNet, new InHandler { - override def onPush(): Unit = sourceOut.push(grab(fromNet)) + override def onPush(): Unit = sourceOut.push(grab(fromNet).bytes) }) sourceOut.setHandler(new OutHandler { override def onPull(): Unit = { if (!hasBeenPulled(fromNet)) pull(fromNet) cancelTimeout(timeoutKey) sourceOut.setHandler(new OutHandler { - override def onPull(): Unit = pull(fromNet) + override def onPull(): Unit = if (!hasBeenPulled(fromNet)) pull(fromNet) }) } }) diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/WebsocketClientBlueprint.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/WebsocketClientBlueprint.scala index d14f627c0b..b4523a4d70 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/ws/WebsocketClientBlueprint.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/ws/WebsocketClientBlueprint.scala @@ -79,7 +79,7 @@ object WebsocketClientBlueprint { parser.setRequestMethodForNextResponse(HttpMethods.GET) def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective = { - parser.onPush(elem) match { + parser.parseBytes(elem) match { case NeedMoreData ⇒ ctx.pull() case ResponseStart(status, protocol, headers, entity, close) ⇒ val response = HttpResponse(status, headers, protocol = protocol) diff --git a/akka-http-core/src/main/scala/akka/http/scaladsl/model/headers/headers.scala b/akka-http-core/src/main/scala/akka/http/scaladsl/model/headers/headers.scala index 1e50a57d9e..7df3a213c6 100644 --- a/akka-http-core/src/main/scala/akka/http/scaladsl/model/headers/headers.scala +++ b/akka-http-core/src/main/scala/akka/http/scaladsl/model/headers/headers.scala @@ -8,8 +8,10 @@ import java.lang.Iterable import java.net.InetSocketAddress import java.security.MessageDigest import java.util +import javax.net.ssl.SSLSession import akka.event.Logging +import akka.stream.io.ScalaSessionAPI import scala.reflect.ClassTag import scala.util.{ Failure, Success, Try } @@ -177,7 +179,6 @@ final case class `If-Range`(entityTagOrDateTime: Either[EntityTag, DateTime]) ex protected def companion = `If-Range` } -// FIXME: resurrect SSL-Session-Info header once akka.io.SslTlsSupport supports it final case class RawHeader(name: String, value: String) extends jm.headers.RawHeader { val lowercaseName = name.toRootLowerCase def render[R <: Rendering](r: R): r.type = r ~~ name ~~ ':' ~~ ' ' ~~ value @@ -760,6 +761,29 @@ final case class `Set-Cookie`(cookie: HttpCookie) extends jm.headers.SetCookie w protected def companion = `Set-Cookie` } +/** + * Model for the synthetic `Tls-Session-Info` header which carries the SSLSession of the connection + * the message carrying this header was received with. + * + * This header will only be added if it enabled in the configuration by setting + * + * ``` + * akka.http.[client|server].parsing.tls-session-info-header = on + * ``` + */ +final case class `Tls-Session-Info`(session: SSLSession) extends jm.headers.TlsSessionInfo with ScalaSessionAPI { + override def suppressRendering: Boolean = true + override def toString = s"SSL-Session-Info($session)" + def name(): String = "SSL-Session-Info" + def value(): String = "" + + /** Java API */ + def getSession(): SSLSession = session + + def lowercaseName: String = name.toRootLowerCase + def render[R <: Rendering](r: R): r.type = r ~~ name ~~ ':' ~~ ' ' ~~ value +} + // http://tools.ietf.org/html/rfc7230#section-3.3.1 object `Transfer-Encoding` extends ModeledCompanion[`Transfer-Encoding`] { def apply(first: TransferEncoding, more: TransferEncoding*): `Transfer-Encoding` = apply(immutable.Seq(first +: more: _*)) diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/client/TlsEndpointVerificationSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/client/TlsEndpointVerificationSpec.scala index bbe90e9eac..0948a75c05 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/client/TlsEndpointVerificationSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/client/TlsEndpointVerificationSpec.scala @@ -13,17 +13,15 @@ import akka.stream.testkit.AkkaSpec import akka.http.impl.util._ import akka.http.scaladsl.{ HttpsContext, Http } import akka.http.scaladsl.model.{ StatusCodes, HttpResponse, HttpRequest } -import akka.http.scaladsl.model.headers.Host +import akka.http.scaladsl.model.headers.{ Host, `Tls-Session-Info` } import org.scalatest.time.{ Span, Seconds } import scala.concurrent.Future -import akka.testkit.EventFilter -import javax.net.ssl.SSLException class TlsEndpointVerificationSpec extends AkkaSpec(""" akka.loglevel = INFO akka.io.tcp.trace-logging = off + akka.http.parsing.tls-session-info-header = on """) with ScalaFutures { - implicit val materializer = ActorMaterializer() /* @@ -47,6 +45,8 @@ class TlsEndpointVerificationSpec extends AkkaSpec(""" whenReady(pipe(HttpRequest(uri = "https://akka.example.org:8080/")), timeout) { response ⇒ response.status shouldEqual StatusCodes.OK + val tlsInfo = response.header[`Tls-Session-Info`].get + tlsInfo.peerPrincipal.get.getName shouldEqual "CN=akka.example.org,O=Internet Widgits Pty Ltd,ST=Some-State,C=AU" } } "not accept certificates for foreign hosts" in { @@ -96,7 +96,12 @@ class TlsEndpointVerificationSpec extends AkkaSpec(""" Source.single(req).via(pipelineFlow(clientContext, hostname)).runWith(Sink.head) def pipelineFlow(clientContext: HttpsContext, hostname: String): Flow[HttpRequest, HttpResponse, Unit] = { - val handler: HttpRequest ⇒ HttpResponse = _ ⇒ HttpResponse() + val handler: HttpRequest ⇒ HttpResponse = { req ⇒ + // verify Tls-Session-Info header information + val name = req.header[`Tls-Session-Info`].flatMap(_.localPrincipal).map(_.getName) + if (name.exists(_ == "CN=akka.example.org,O=Internet Widgits Pty Ltd,ST=Some-State,C=AU")) HttpResponse() + else HttpResponse(StatusCodes.BadRequest, entity = "Tls-Session-Info header verification failed") + } val serverSideTls = Http().sslTlsStage(Some(ExampleHttpContexts.exampleServerContext), Server) val clientSideTls = Http().sslTlsStage(Some(clientContext), Client, Some(hostname -> 8080)) diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/RequestParserSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/RequestParserSpec.scala index 631fc039af..e90494bc3a 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/RequestParserSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/RequestParserSpec.scala @@ -4,13 +4,23 @@ package akka.http.impl.engine.parsing -import akka.stream.impl.fusing.GraphInterpreter -import com.typesafe.config.{ Config, ConfigFactory } import scala.concurrent.Future import scala.concurrent.duration._ + +import com.typesafe.config.{ Config, ConfigFactory } + +import akka.util.ByteString + +import akka.actor.ActorSystem + +import akka.stream.ActorMaterializer +import akka.stream.scaladsl._ + +import akka.stream.io.{ SslTlsPlacebo, SessionBytes } + import org.scalatest.matchers.Matcher import org.scalatest.{ BeforeAndAfterAll, FreeSpec, Matchers } -import akka.actor.ActorSystem + import akka.http.ParserSettings import akka.http.impl.engine.parsing.ParserOutput._ import akka.http.impl.util._ @@ -24,9 +34,6 @@ import akka.http.scaladsl.model._ import akka.http.scaladsl.model.headers._ import akka.http.scaladsl.util.FastFuture import akka.http.scaladsl.util.FastFuture._ -import akka.stream.{ OverflowStrategy, ActorMaterializer } -import akka.stream.scaladsl._ -import akka.util.ByteString class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { val testConf: Config = ConfigFactory.parseString(""" @@ -233,10 +240,10 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { } "don't overflow the stack for large buffers of chunks" in new Test { - override val awaitAtMost = 3000.millis + override val awaitAtMost = 10000.millis val x = NotEnoughDataException - val numChunks = 15000 // failed starting from 4000 with sbt started with `-Xss2m` + val numChunks = 12000 // failed starting from 4000 with sbt started with `-Xss2m` val oneChunk = "1\r\nz\n" val manyChunks = (oneChunk * numChunks) + "0\r\n" @@ -473,7 +480,7 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { def multiParse(parser: HttpRequestParser)(input: Seq[String]): Seq[Either[RequestOutput, StrictEqualHttpRequest]] = Source(input.toList) - .map(ByteString.apply) + .map(bytes ⇒ SessionBytes(SslTlsPlacebo.dummySession, ByteString(bytes))) .transform(() ⇒ parser.stage).named("parser") .splitWhen(x ⇒ x.isInstanceOf[MessageStart] || x.isInstanceOf[EntityStreamError]) .prefixAndTail(1) diff --git a/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/ResponseParserSpec.scala b/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/ResponseParserSpec.scala index 29a4aeab2c..00f46b4b75 100644 --- a/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/ResponseParserSpec.scala +++ b/akka-http-core/src/test/scala/akka/http/impl/engine/parsing/ResponseParserSpec.scala @@ -6,6 +6,7 @@ package akka.http.impl.engine.parsing import akka.http.ParserSettings import akka.http.scaladsl.util.FastFuture +import akka.stream.io.{ SslTlsPlacebo, SessionBytes } import com.typesafe.config.{ ConfigFactory, Config } import scala.concurrent.{ Future, Await } import scala.concurrent.duration._ @@ -290,7 +291,7 @@ class ResponseParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll { def rawParse(requestMethod: HttpMethod, input: String*): Source[Either[ResponseOutput, HttpResponse], Unit] = Source(input.toList) - .map(ByteString.apply) + .map(bytes ⇒ SessionBytes(SslTlsPlacebo.dummySession, ByteString(bytes))) .transform(() ⇒ newParserStage(requestMethod)).named("parser") .splitWhen(x ⇒ x.isInstanceOf[MessageStart] || x.isInstanceOf[EntityStreamError]) .prefixAndTail(1) diff --git a/akka-stream/src/main/scala/akka/stream/io/SslTls.scala b/akka-stream/src/main/scala/akka/stream/io/SslTls.scala index 35dbfd483f..04ec83a536 100644 --- a/akka-stream/src/main/scala/akka/stream/io/SslTls.scala +++ b/akka-stream/src/main/scala/akka/stream/io/SslTls.scala @@ -4,6 +4,7 @@ package akka.stream.io import java.lang.{ Integer ⇒ jInteger } +import java.security.Principal import akka.japi import akka.stream._ @@ -152,18 +153,63 @@ object SslTls { * unwrapping [[SendBytes]]. */ object SslTlsPlacebo { + // this constructs a session for (invalid) protocol SSL_NULL_WITH_NULL_NULL + private[akka] val dummySession = SSLContext.getDefault.createSSLEngine.getSession + val forScala: scaladsl.BidiFlow[SslTlsOutbound, ByteString, ByteString, SessionBytes, Unit] = scaladsl.BidiFlow.fromGraph(scaladsl.GraphDSL.create() { implicit b ⇒ - // this constructs a session for (invalid) protocol SSL_NULL_WITH_NULL_NULL - val session = SSLContext.getDefault.createSSLEngine.getSession val top = b.add(scaladsl.Flow[SslTlsOutbound].collect { case SendBytes(bytes) ⇒ bytes }) - val bottom = b.add(scaladsl.Flow[ByteString].map(SessionBytes(session, _))) + val bottom = b.add(scaladsl.Flow[ByteString].map(SessionBytes(dummySession, _))) BidiShape.fromFlows(top, bottom) }) val forJava: javadsl.BidiFlow[SslTlsOutbound, ByteString, ByteString, SessionBytes, Unit] = new javadsl.BidiFlow(forScala) } +import java.security.Principal +import java.security.cert.Certificate +import javax.net.ssl.{ SSLPeerUnverifiedException, SSLSession } + +/** Allows access to an SSLSession with Scala types */ +trait ScalaSessionAPI { + def session: SSLSession + + /** + * Scala API: Extract the certificates that were actually used by this + * engine during this session’s negotiation. The list is empty if no + * certificates were used. + */ + def localCertificates: List[Certificate] = Option(session.getLocalCertificates).map(_.toList).getOrElse(Nil) + /** + * Scala API: Extract the Principal that was actually used by this engine + * during this session’s negotiation. + */ + def localPrincipal: Option[Principal] = Option(session.getLocalPrincipal) + /** + * Scala API: Extract the certificates that were used by the peer engine + * during this session’s negotiation. The list is empty if no certificates + * were used. + */ + def peerCertificates: List[Certificate] = + try Option(session.getPeerCertificates).map(_.toList).getOrElse(Nil) + catch { case e: SSLPeerUnverifiedException ⇒ Nil } + /** + * Scala API: Extract the Principal that the peer engine presented during + * this session’s negotiation. + */ + def peerPrincipal: Option[Principal] = + try Option(session.getPeerPrincipal) + catch { case e: SSLPeerUnverifiedException ⇒ None } +} + +object ScalaSessionAPI { + /** Constructs a ScalaSessionAPI instance from an SSLSession */ + def apply(_session: SSLSession): ScalaSessionAPI = + new ScalaSessionAPI { + def session: SSLSession = _session + } +} + /** * Many protocols are asymmetric and distinguish between the client and the * server, where the latter listens passively for messages and the former @@ -313,34 +359,7 @@ case object SessionTruncated extends SessionTruncated * The Java API for getting session information is given by the SSLSession object, * the Scala API adapters are offered below. */ -case class SessionBytes(session: SSLSession, bytes: ByteString) extends SslTlsInbound { - /** - * Scala API: Extract the certificates that were actually used by this - * engine during this session’s negotiation. The list is empty if no - * certificates were used. - */ - def localCertificates: List[Certificate] = Option(session.getLocalCertificates).map(_.toList).getOrElse(Nil) - /** - * Scala API: Extract the Principal that was actually used by this engine - * during this session’s negotiation. - */ - def localPrincipal = Option(session.getLocalPrincipal) - /** - * Scala API: Extract the certificates that were used by the peer engine - * during this session’s negotiation. The list is empty if no certificates - * were used. - */ - def peerCertificates = - try Option(session.getPeerCertificates).map(_.toList).getOrElse(Nil) - catch { case e: SSLPeerUnverifiedException ⇒ Nil } - /** - * Scala API: Extract the Principal that the peer engine presented during - * this session’s negotiation. - */ - def peerPrincipal = - try Option(session.getPeerPrincipal) - catch { case e: SSLPeerUnverifiedException ⇒ None } -} +case class SessionBytes(session: SSLSession, bytes: ByteString) extends SslTlsInbound with ScalaSessionAPI /** * This is the supertype of all messages that the SslTls stage accepts on its