Merge pull request #19210 from spray/w/18349-add-TlsSessionInfo-header

+htc #18349 emit Tls-Session-Info header when configured on both client and server
This commit is contained in:
Konrad Malawski 2016-01-07 13:52:34 +01:00
commit 5ac1b7ee47
12 changed files with 180 additions and 72 deletions

View file

@ -0,0 +1,25 @@
/**
* Copyright (C) 2009-2016 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.http.javadsl.model.headers;
import javax.net.ssl.SSLSession;
/**
* Model for the synthetic `Tls-Session-Info` header which carries the SSLSession of the connection
* the message carrying this header was received with.
*
* This header will only be added if it enabled in the configuration by setting
* <code>akka.http.[client|server].parsing.tls-session-info-header = on</code>.
*/
public abstract class TlsSessionInfo extends CustomHeader {
/**
* @return the SSLSession this message was received over.
*/
public abstract SSLSession getSession();
public static TlsSessionInfo create(SSLSession session) {
return akka.http.scaladsl.model.headers.Tls$minusSession$minusInfo$.MODULE$.apply(session);
}
}

View file

@ -337,5 +337,10 @@ akka.http {
If-Unmodified-Since = 0
User-Agent = 32
}
# Enables/disables inclusion of an Tls-Session-Info header in parsed
# messages over Tls transports (i.e., HttpRequest on server side and
# HttpResponse on client side).
tls-session-info-header = off
}
}

View file

@ -27,6 +27,7 @@ final case class ParserSettings(
illegalHeaderWarnings: Boolean,
errorLoggingVerbosity: ParserSettings.ErrorLoggingVerbosity,
headerValueCacheLimits: Map[String, Int],
includeTlsSessionInfoHeader: Boolean,
customMethods: String Option[HttpMethod],
customStatusCodes: Int Option[StatusCode]) extends HttpHeaderParser.Settings {
@ -75,6 +76,7 @@ object ParserSettings extends SettingsCompanion[ParserSettings]("akka.http.parsi
c getBoolean "illegal-header-warnings",
ErrorLoggingVerbosity(c getString "error-logging-verbosity"),
cacheConfig.entrySet.asScala.map(kvp kvp.getKey -> cacheConfig.getInt(kvp.getKey))(collection.breakOut),
c getBoolean "tls-session-info-header",
_ None,
_ None)
}

View file

@ -100,8 +100,8 @@ private[http] object OutgoingConnectionBlueprint {
val wrapTls = b.add(Flow[ByteString].map(SendBytes))
terminationMerge.out ~> requestRendering ~> logger ~> wrapTls
val unwrapTls = b.add(Flow[SslTlsInbound].collect { case SessionBytes(_, bytes) bytes })
unwrapTls ~> responseParsingMerge.in0
val collectSessionBytes = b.add(Flow[SslTlsInbound].collect { case s: SessionBytes s })
collectSessionBytes ~> responseParsingMerge.in0
methodBypassFanout.out(0) ~> terminationMerge.in0
@ -113,7 +113,7 @@ private[http] object OutgoingConnectionBlueprint {
BidiShape(
methodBypassFanout.in,
wrapTls.out,
unwrapTls.in,
collectSessionBytes.in,
terminationFanout.out(1))
})
@ -154,8 +154,8 @@ private[http] object OutgoingConnectionBlueprint {
* 2. Read from the dataInput until exactly one response has been fully received
* 3. Go back to 1.
*/
class ResponseParsingMerge(rootParser: HttpResponseParser) extends GraphStage[FanInShape2[ByteString, HttpMethod, List[ResponseOutput]]] {
private val dataInput = Inlet[ByteString]("data")
class ResponseParsingMerge(rootParser: HttpResponseParser) extends GraphStage[FanInShape2[SessionBytes, HttpMethod, List[ResponseOutput]]] {
private val dataInput = Inlet[SessionBytes]("data")
private val methodBypassInput = Inlet[HttpMethod]("method")
private val out = Outlet[List[ResponseOutput]]("out")
@ -174,7 +174,7 @@ private[http] object OutgoingConnectionBlueprint {
override def onPush(): Unit = {
val method = grab(methodBypassInput)
parser.setRequestMethodForNextResponse(method)
val output = parser.onPush(ByteString.empty)
val output = parser.parseBytes(ByteString.empty)
drainParser(output)
}
override def onUpstreamFinish(): Unit =
@ -185,7 +185,7 @@ private[http] object OutgoingConnectionBlueprint {
setHandler(dataInput, new InHandler {
override def onPush(): Unit = {
val bytes = grab(dataInput)
val output = parser.onPush(bytes)
val output = parser.parseSessionBytes(bytes)
drainParser(output)
}
override def onUpstreamFinish(): Unit =

View file

@ -4,6 +4,10 @@
package akka.http.impl.engine.parsing
import javax.net.ssl.SSLSession
import akka.stream.io.{ SessionBytes, SslTlsInbound }
import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer
import akka.parboiled2.CharUtils
@ -30,11 +34,17 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser
private[this] var completionHandling: CompletionHandling = CompletionOk
private[this] var terminated = false
private[this] var lastSession: SSLSession = null // used to prevent having to recreate header on each message
private[this] var tlsSessionInfoHeader: `Tls-Session-Info` = null
def initialHeaderBuffer: ListBuffer[HttpHeader] =
if (settings.includeTlsSessionInfoHeader && tlsSessionInfoHeader != null) ListBuffer(tlsSessionInfoHeader)
else ListBuffer()
def isTerminated = terminated
val stage: PushPullStage[ByteString, Output] =
new PushPullStage[ByteString, Output] {
def onPush(elem: ByteString, ctx: Context[Output]) = handleParserOutput(self.onPush(elem), ctx)
val stage: PushPullStage[SessionBytes, Output] =
new PushPullStage[SessionBytes, Output] {
def onPush(input: SessionBytes, ctx: Context[Output]) = handleParserOutput(self.parseSessionBytes(input), ctx)
def onPull(ctx: Context[Output]) = handleParserOutput(self.onPull(), ctx)
private def handleParserOutput(output: Output, ctx: Context[Output]): SyncDirective =
output match {
@ -46,7 +56,14 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser
if (self.onUpstreamFinish()) ctx.finish() else ctx.absorbTermination()
}
final def onPush(input: ByteString): Output = {
final def parseSessionBytes(input: SessionBytes): Output = {
if (input.session ne lastSession) {
lastSession = input.session
tlsSessionInfoHeader = `Tls-Session-Info`(input.session)
}
parseBytes(input.bytes)
}
final def parseBytes(input: ByteString): Output = {
@tailrec def run(next: ByteString StateResult): StateResult =
(try next(input)
catch {
@ -102,7 +119,7 @@ private[http] abstract class HttpMessageParser[Output >: MessageOutput <: Parser
def badProtocol: Nothing
@tailrec final def parseHeaderLines(input: ByteString, lineStart: Int, headers: ListBuffer[HttpHeader] = ListBuffer[HttpHeader](),
@tailrec final def parseHeaderLines(input: ByteString, lineStart: Int, headers: ListBuffer[HttpHeader] = initialHeaderBuffer,
headerCount: Int = 0, ch: Option[Connection] = None,
clh: Option[`Content-Length`] = None, cth: Option[`Content-Type`] = None,
teh: Option[`Transfer-Encoding`] = None, e100c: Boolean = false,

View file

@ -57,16 +57,19 @@ private[http] object HttpServerBluePrint {
requestPreparation(settings) atop
controller(settings, log) atop
parsingRendering(settings, log) atop
new ProtocolSwitchStage(settings, log) atop
unwrapTls
websocketSupport(settings, log) atop
tlsSupport
theStack.withAttributes(HttpAttributes.remoteAddress(remoteAddress))
}
val unwrapTls: BidiFlow[ByteString, SslTlsOutbound, SslTlsInbound, ByteString, Unit] =
BidiFlow.fromFlows(Flow[ByteString].map(SendBytes), Flow[SslTlsInbound].collect { case x: SessionBytes x.bytes })
val tlsSupport: BidiFlow[ByteString, SslTlsOutbound, SslTlsInbound, SessionBytes, Unit] =
BidiFlow.fromFlows(Flow[ByteString].map(SendBytes), Flow[SslTlsInbound].collect { case x: SessionBytes x })
def parsingRendering(settings: ServerSettings, log: LoggingAdapter): BidiFlow[ResponseRenderingContext, ResponseRenderingOutput, ByteString, RequestOutput, Unit] =
def websocketSupport(settings: ServerSettings, log: LoggingAdapter): BidiFlow[ResponseRenderingOutput, ByteString, SessionBytes, SessionBytes, Unit] =
BidiFlow.fromGraph(new ProtocolSwitchStage(settings, log))
def parsingRendering(settings: ServerSettings, log: LoggingAdapter): BidiFlow[ResponseRenderingContext, ResponseRenderingOutput, SessionBytes, RequestOutput, Unit] =
BidiFlow.fromFlows(rendering(settings, log), parsing(settings, log))
def controller(settings: ServerSettings, log: LoggingAdapter): BidiFlow[HttpResponse, ResponseRenderingContext, RequestOutput, RequestOutput, Unit] =
@ -114,7 +117,7 @@ private[http] object HttpServerBluePrint {
}
})
def parsing(settings: ServerSettings, log: LoggingAdapter): Flow[ByteString, RequestOutput, Unit] = {
def parsing(settings: ServerSettings, log: LoggingAdapter): Flow[SessionBytes, RequestOutput, Unit] = {
import settings._
// the initial header parser we initially use for every connection,
@ -137,7 +140,7 @@ private[http] object HttpServerBluePrint {
case x x
}
Flow[ByteString].transform(()
Flow[SessionBytes].transform(()
// each connection uses a single (private) request parser instance for all its requests
// which builds a cache of all header instances seen on that connection
rootParser.createShallowCopy().stage).named("rootParser")
@ -344,12 +347,12 @@ private[http] object HttpServerBluePrint {
One2OneBidiFlow[HttpRequest, HttpResponse](pipeliningLimit).reversed
private class ProtocolSwitchStage(settings: ServerSettings, log: LoggingAdapter)
extends GraphStage[BidiShape[ResponseRenderingOutput, ByteString, ByteString, ByteString]] {
extends GraphStage[BidiShape[ResponseRenderingOutput, ByteString, SessionBytes, SessionBytes]] {
private val fromNet = Inlet[ByteString]("fromNet")
private val fromNet = Inlet[SessionBytes]("fromNet")
private val toNet = Outlet[ByteString]("toNet")
private val toHttp = Outlet[ByteString]("toHttp")
private val toHttp = Outlet[SessionBytes]("toHttp")
private val fromHttp = Inlet[ResponseRenderingOutput]("fromHttp")
override def initialAttributes = Attributes.name("ProtocolSwitchStage")
@ -432,14 +435,14 @@ private[http] object HttpServerBluePrint {
})
setHandler(fromNet, new InHandler {
override def onPush(): Unit = sourceOut.push(grab(fromNet))
override def onPush(): Unit = sourceOut.push(grab(fromNet).bytes)
})
sourceOut.setHandler(new OutHandler {
override def onPull(): Unit = {
if (!hasBeenPulled(fromNet)) pull(fromNet)
cancelTimeout(timeoutKey)
sourceOut.setHandler(new OutHandler {
override def onPull(): Unit = pull(fromNet)
override def onPull(): Unit = if (!hasBeenPulled(fromNet)) pull(fromNet)
})
}
})

View file

@ -79,7 +79,7 @@ object WebsocketClientBlueprint {
parser.setRequestMethodForNextResponse(HttpMethods.GET)
def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective = {
parser.onPush(elem) match {
parser.parseBytes(elem) match {
case NeedMoreData ctx.pull()
case ResponseStart(status, protocol, headers, entity, close)
val response = HttpResponse(status, headers, protocol = protocol)

View file

@ -8,8 +8,10 @@ import java.lang.Iterable
import java.net.InetSocketAddress
import java.security.MessageDigest
import java.util
import javax.net.ssl.SSLSession
import akka.event.Logging
import akka.stream.io.ScalaSessionAPI
import scala.reflect.ClassTag
import scala.util.{ Failure, Success, Try }
@ -177,7 +179,6 @@ final case class `If-Range`(entityTagOrDateTime: Either[EntityTag, DateTime]) ex
protected def companion = `If-Range`
}
// FIXME: resurrect SSL-Session-Info header once akka.io.SslTlsSupport supports it
final case class RawHeader(name: String, value: String) extends jm.headers.RawHeader {
val lowercaseName = name.toRootLowerCase
def render[R <: Rendering](r: R): r.type = r ~~ name ~~ ':' ~~ ' ' ~~ value
@ -760,6 +761,29 @@ final case class `Set-Cookie`(cookie: HttpCookie) extends jm.headers.SetCookie w
protected def companion = `Set-Cookie`
}
/**
* Model for the synthetic `Tls-Session-Info` header which carries the SSLSession of the connection
* the message carrying this header was received with.
*
* This header will only be added if it enabled in the configuration by setting
*
* ```
* akka.http.[client|server].parsing.tls-session-info-header = on
* ```
*/
final case class `Tls-Session-Info`(session: SSLSession) extends jm.headers.TlsSessionInfo with ScalaSessionAPI {
override def suppressRendering: Boolean = true
override def toString = s"SSL-Session-Info($session)"
def name(): String = "SSL-Session-Info"
def value(): String = ""
/** Java API */
def getSession(): SSLSession = session
def lowercaseName: String = name.toRootLowerCase
def render[R <: Rendering](r: R): r.type = r ~~ name ~~ ':' ~~ ' ' ~~ value
}
// http://tools.ietf.org/html/rfc7230#section-3.3.1
object `Transfer-Encoding` extends ModeledCompanion[`Transfer-Encoding`] {
def apply(first: TransferEncoding, more: TransferEncoding*): `Transfer-Encoding` = apply(immutable.Seq(first +: more: _*))

View file

@ -13,17 +13,15 @@ import akka.stream.testkit.AkkaSpec
import akka.http.impl.util._
import akka.http.scaladsl.{ HttpsContext, Http }
import akka.http.scaladsl.model.{ StatusCodes, HttpResponse, HttpRequest }
import akka.http.scaladsl.model.headers.Host
import akka.http.scaladsl.model.headers.{ Host, `Tls-Session-Info` }
import org.scalatest.time.{ Span, Seconds }
import scala.concurrent.Future
import akka.testkit.EventFilter
import javax.net.ssl.SSLException
class TlsEndpointVerificationSpec extends AkkaSpec("""
akka.loglevel = INFO
akka.io.tcp.trace-logging = off
akka.http.parsing.tls-session-info-header = on
""") with ScalaFutures {
implicit val materializer = ActorMaterializer()
/*
@ -47,6 +45,8 @@ class TlsEndpointVerificationSpec extends AkkaSpec("""
whenReady(pipe(HttpRequest(uri = "https://akka.example.org:8080/")), timeout) { response
response.status shouldEqual StatusCodes.OK
val tlsInfo = response.header[`Tls-Session-Info`].get
tlsInfo.peerPrincipal.get.getName shouldEqual "CN=akka.example.org,O=Internet Widgits Pty Ltd,ST=Some-State,C=AU"
}
}
"not accept certificates for foreign hosts" in {
@ -96,7 +96,12 @@ class TlsEndpointVerificationSpec extends AkkaSpec("""
Source.single(req).via(pipelineFlow(clientContext, hostname)).runWith(Sink.head)
def pipelineFlow(clientContext: HttpsContext, hostname: String): Flow[HttpRequest, HttpResponse, Unit] = {
val handler: HttpRequest HttpResponse = _ HttpResponse()
val handler: HttpRequest HttpResponse = { req
// verify Tls-Session-Info header information
val name = req.header[`Tls-Session-Info`].flatMap(_.localPrincipal).map(_.getName)
if (name.exists(_ == "CN=akka.example.org,O=Internet Widgits Pty Ltd,ST=Some-State,C=AU")) HttpResponse()
else HttpResponse(StatusCodes.BadRequest, entity = "Tls-Session-Info header verification failed")
}
val serverSideTls = Http().sslTlsStage(Some(ExampleHttpContexts.exampleServerContext), Server)
val clientSideTls = Http().sslTlsStage(Some(clientContext), Client, Some(hostname -> 8080))

View file

@ -4,13 +4,23 @@
package akka.http.impl.engine.parsing
import akka.stream.impl.fusing.GraphInterpreter
import com.typesafe.config.{ Config, ConfigFactory }
import scala.concurrent.Future
import scala.concurrent.duration._
import com.typesafe.config.{ Config, ConfigFactory }
import akka.util.ByteString
import akka.actor.ActorSystem
import akka.stream.ActorMaterializer
import akka.stream.scaladsl._
import akka.stream.io.{ SslTlsPlacebo, SessionBytes }
import org.scalatest.matchers.Matcher
import org.scalatest.{ BeforeAndAfterAll, FreeSpec, Matchers }
import akka.actor.ActorSystem
import akka.http.ParserSettings
import akka.http.impl.engine.parsing.ParserOutput._
import akka.http.impl.util._
@ -24,9 +34,6 @@ import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers._
import akka.http.scaladsl.util.FastFuture
import akka.http.scaladsl.util.FastFuture._
import akka.stream.{ OverflowStrategy, ActorMaterializer }
import akka.stream.scaladsl._
import akka.util.ByteString
class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
val testConf: Config = ConfigFactory.parseString("""
@ -233,10 +240,10 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
}
"don't overflow the stack for large buffers of chunks" in new Test {
override val awaitAtMost = 3000.millis
override val awaitAtMost = 10000.millis
val x = NotEnoughDataException
val numChunks = 15000 // failed starting from 4000 with sbt started with `-Xss2m`
val numChunks = 12000 // failed starting from 4000 with sbt started with `-Xss2m`
val oneChunk = "1\r\nz\n"
val manyChunks = (oneChunk * numChunks) + "0\r\n"
@ -473,7 +480,7 @@ class RequestParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
def multiParse(parser: HttpRequestParser)(input: Seq[String]): Seq[Either[RequestOutput, StrictEqualHttpRequest]] =
Source(input.toList)
.map(ByteString.apply)
.map(bytes SessionBytes(SslTlsPlacebo.dummySession, ByteString(bytes)))
.transform(() parser.stage).named("parser")
.splitWhen(x x.isInstanceOf[MessageStart] || x.isInstanceOf[EntityStreamError])
.prefixAndTail(1)

View file

@ -6,6 +6,7 @@ package akka.http.impl.engine.parsing
import akka.http.ParserSettings
import akka.http.scaladsl.util.FastFuture
import akka.stream.io.{ SslTlsPlacebo, SessionBytes }
import com.typesafe.config.{ ConfigFactory, Config }
import scala.concurrent.{ Future, Await }
import scala.concurrent.duration._
@ -290,7 +291,7 @@ class ResponseParserSpec extends FreeSpec with Matchers with BeforeAndAfterAll {
def rawParse(requestMethod: HttpMethod, input: String*): Source[Either[ResponseOutput, HttpResponse], Unit] =
Source(input.toList)
.map(ByteString.apply)
.map(bytes SessionBytes(SslTlsPlacebo.dummySession, ByteString(bytes)))
.transform(() newParserStage(requestMethod)).named("parser")
.splitWhen(x x.isInstanceOf[MessageStart] || x.isInstanceOf[EntityStreamError])
.prefixAndTail(1)