+str #15833 TLS with session renegotiation

This commit is contained in:
Roland Kuhn 2015-04-20 16:33:57 +02:00
parent 33919f683c
commit 616838a738
16 changed files with 1367 additions and 403 deletions

View file

@ -82,10 +82,10 @@ class FlowErrorDocSpec extends AkkaSpec {
val result = source.grouped(1000).runWith(Sink.head)
// the negative element cause the scan stage to be restarted,
// i.e. start from 0 again
// result here will be a Future completed with Success(Vector(0, 1, 0, 5, 12))
// result here will be a Future completed with Success(Vector(0, 1, 4, 0, 5, 12))
//#restart-section
Await.result(result, remaining) should be(Vector(0, 1, 0, 5, 12))
Await.result(result, remaining) should be(Vector(0, 1, 4, 0, 5, 12))
}
}

View file

@ -343,9 +343,11 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit {
"resume when Scan throws" in new TestSetup(Seq(
Scan(1, (acc: Int, x: Int) if (x == 10) throw TE else acc + x, resumingDecider))) {
downstream.requestOne()
lastEvents() should be(Set(OnNext(1)))
downstream.requestOne()
lastEvents() should be(Set(RequestOne))
upstream.onNext(2)
lastEvents() should be(Set(OnNext(1)))
lastEvents() should be(Set(OnNext(3)))
downstream.requestOne()
lastEvents() should be(Set(RequestOne))
@ -353,15 +355,17 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit {
lastEvents() should be(Set(RequestOne))
upstream.onNext(4)
lastEvents() should be(Set(OnNext(3))) // 1 + 2
lastEvents() should be(Set(OnNext(7))) // 1 + 2 + 4
}
"restart when Scan throws" in new TestSetup(Seq(
Scan(1, (acc: Int, x: Int) if (x == 10) throw TE else acc + x, restartingDecider))) {
downstream.requestOne()
lastEvents() should be(Set(OnNext(1)))
downstream.requestOne()
lastEvents() should be(Set(RequestOne))
upstream.onNext(2)
lastEvents() should be(Set(OnNext(1)))
lastEvents() should be(Set(OnNext(3)))
downstream.requestOne()
lastEvents() should be(Set(RequestOne))
@ -371,10 +375,12 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit {
upstream.onNext(4)
lastEvents() should be(Set(OnNext(1))) // starts over again
downstream.requestOne()
lastEvents() should be(Set(OnNext(5)))
downstream.requestOne()
lastEvents() should be(Set(RequestOne))
upstream.onNext(20)
lastEvents() should be(Set(OnNext(5))) // 1+4
lastEvents() should be(Set(OnNext(25))) // 1 + 4 + 20
}
"restart when Conflate `seed` throws" in new TestSetup(Seq(Conflate(

View file

@ -0,0 +1,411 @@
package akka.stream.io
import java.security.{ KeyStore, SecureRandom }
import javax.net.ssl.{ TrustManagerFactory, KeyManagerFactory, SSLContext }
import akka.stream.{ Graph, BidiShape, ActorFlowMaterializer }
import akka.stream.scaladsl._
import akka.stream.io._
import akka.stream.testkit.{ TestUtils, AkkaSpec }
import akka.util.ByteString
import scala.concurrent.Await
import scala.concurrent.duration._
import scala.collection.immutable
import scala.util.Random
import akka.stream.stage.AsyncStage
import akka.stream.stage.AsyncContext
import java.util.concurrent.TimeoutException
import akka.actor.ActorSystem
import javax.net.ssl.SSLSession
import akka.pattern.{ after later }
import scala.concurrent.Future
import java.net.InetSocketAddress
import akka.testkit.EventFilter
import akka.stream.stage.PushStage
import akka.stream.stage.Context
object TlsSpec {
val rnd = new Random
def initSslContext(): SSLContext = {
val password = "changeme"
val keyStore = KeyStore.getInstance(KeyStore.getDefaultType)
keyStore.load(getClass.getResourceAsStream("/keystore"), password.toCharArray)
val trustStore = KeyStore.getInstance(KeyStore.getDefaultType)
trustStore.load(getClass.getResourceAsStream("/truststore"), password.toCharArray)
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm)
keyManagerFactory.init(keyStore, password.toCharArray)
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm)
trustManagerFactory.init(trustStore)
val context = SSLContext.getInstance("TLS")
context.init(keyManagerFactory.getKeyManagers, trustManagerFactory.getTrustManagers, new SecureRandom)
context
}
/**
* This is a stage that fires a TimeoutException failure 2 seconds after it was started,
* independent of the traffic going through. The purpose is to include the last seen
* element in the exception message to help in figuring out what went wrong.
*/
class Timeout(duration: FiniteDuration)(implicit system: ActorSystem) extends AsyncStage[ByteString, ByteString, Unit] {
private var last: ByteString = _
override def initAsyncInput(ctx: AsyncContext[ByteString, Unit]) = {
val cb = ctx.getAsyncCallback()
system.scheduler.scheduleOnce(duration)(cb.invoke(()))(system.dispatcher)
}
override def onAsyncInput(u: Unit, ctx: AsyncContext[ByteString, Unit]) =
ctx.fail(new TimeoutException(s"timeout expired, last element was $last"))
override def onPush(elem: ByteString, ctx: AsyncContext[ByteString, Unit]) = {
last = elem
if (ctx.isHoldingDownstream) ctx.pushAndPull(elem)
else ctx.holdUpstream()
}
override def onPull(ctx: AsyncContext[ByteString, Unit]) =
if (ctx.isFinishing) ctx.pushAndFinish(last)
else if (ctx.isHoldingUpstream) ctx.pushAndPull(last)
else ctx.holdDownstream()
override def onUpstreamFinish(ctx: AsyncContext[ByteString, Unit]) =
if (ctx.isHoldingUpstream) ctx.absorbTermination()
else ctx.finish()
override def onDownstreamFinish(ctx: AsyncContext[ByteString, Unit]) = {
system.log.debug("cancelled")
ctx.finish()
}
}
// FIXME #17226 replace by .dropWhile when implemented
class DropWhile[T](p: T Boolean) extends PushStage[T, T] {
private var open = false
override def onPush(elem: T, ctx: Context[T]) =
if (open) ctx.push(elem)
else if (p(elem)) ctx.pull()
else {
open = true
ctx.push(elem)
}
}
}
class TlsSpec extends AkkaSpec("akka.loglevel=INFO\nakka.actor.debug.receive=off") {
import TlsSpec._
import system.dispatcher
implicit val materializer = ActorFlowMaterializer()
import FlowGraph.Implicits._
"StreamTLS" must {
val sslContext = initSslContext()
val debug = Flow[SslTlsInbound].map { x
x match {
case SessionTruncated system.log.debug(s" ----------- truncated ")
case SessionBytes(_, b) system.log.debug(s" ----------- (${b.size}) ${b.take(32).utf8String}")
}
x
}
val cipherSuites = NegotiateNewSession.withCipherSuites("TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_128_CBC_SHA")
def clientTls(closing: Closing) = SslTls(sslContext, cipherSuites, Client, closing)
def serverTls(closing: Closing) = SslTls(sslContext, cipherSuites, Server, closing)
trait Named {
def name: String =
getClass.getName
.reverse
.dropWhile(c "$0123456789".indexOf(c) != -1)
.takeWhile(_ != '$')
.reverse
}
trait CommunicationSetup extends Named {
def decorateFlow(leftClosing: Closing, rightClosing: Closing,
rhs: Flow[SslTlsInbound, SslTlsOutbound, Any]): Flow[SslTlsOutbound, SslTlsInbound, Unit]
def cleanup(): Unit = ()
}
object ClientInitiates extends CommunicationSetup {
def decorateFlow(leftClosing: Closing, rightClosing: Closing,
rhs: Flow[SslTlsInbound, SslTlsOutbound, Any]) =
clientTls(leftClosing) atop serverTls(rightClosing).reversed join rhs
}
object ServerInitiates extends CommunicationSetup {
def decorateFlow(leftClosing: Closing, rightClosing: Closing,
rhs: Flow[SslTlsInbound, SslTlsOutbound, Any]) =
serverTls(leftClosing) atop clientTls(rightClosing).reversed join rhs
}
def server(flow: Flow[ByteString, ByteString, Any]) = {
val server = StreamTcp()
.bind(new InetSocketAddress("localhost", 0))
.to(Sink.foreach(c c.flow.join(flow).run()))
.run()
Await.result(server, 2.seconds)
}
object ClientInitiatesViaTcp extends CommunicationSetup {
var binding: StreamTcp.ServerBinding = null
def decorateFlow(leftClosing: Closing, rightClosing: Closing,
rhs: Flow[SslTlsInbound, SslTlsOutbound, Any]) = {
binding = server(serverTls(rightClosing).reversed join rhs)
clientTls(leftClosing) join StreamTcp().outgoingConnection(binding.localAddress)
}
override def cleanup(): Unit = binding.unbind()
}
object ServerInitiatesViaTcp extends CommunicationSetup {
var binding: StreamTcp.ServerBinding = null
def decorateFlow(leftClosing: Closing, rightClosing: Closing,
rhs: Flow[SslTlsInbound, SslTlsOutbound, Any]) = {
binding = server(clientTls(rightClosing).reversed join rhs)
serverTls(leftClosing) join StreamTcp().outgoingConnection(binding.localAddress)
}
override def cleanup(): Unit = binding.unbind()
}
val communicationPatterns =
Seq(
ClientInitiates,
ServerInitiates,
ClientInitiatesViaTcp,
ServerInitiatesViaTcp)
trait PayloadScenario extends Named {
def flow: Flow[SslTlsInbound, SslTlsOutbound, Any] =
Flow[SslTlsInbound]
.map {
var session: SSLSession = null
def setSession(s: SSLSession) = {
session = s
system.log.debug(s"new session: $session (${session.getId mkString ","})")
}
{
case SessionTruncated SendBytes(ByteString("TRUNCATED"))
case SessionBytes(s, b) if session == null
setSession(s)
SendBytes(b)
case SessionBytes(s, b) if s != session
setSession(s)
SendBytes(ByteString("NEWSESSION") ++ b)
case SessionBytes(s, b) SendBytes(b)
}
}
def leftClosing: Closing = IgnoreComplete
def rightClosing: Closing = IgnoreComplete
def inputs: immutable.Seq[SslTlsOutbound]
def output: ByteString
protected def send(str: String) = SendBytes(ByteString(str))
protected def send(ch: Char) = SendBytes(ByteString(ch.toByte))
}
object SingleBytes extends PayloadScenario {
val str = "0123456789"
def inputs = str.map(ch SendBytes(ByteString(ch.toByte)))
def output = ByteString(str)
}
object MediumMessages extends PayloadScenario {
val strs = "0123456789" map (d d.toString * (rnd.nextInt(9000) + 1000))
def inputs = strs map (s SendBytes(ByteString(s)))
def output = ByteString((strs :\ "")(_ ++ _))
}
object LargeMessages extends PayloadScenario {
// TLS max packet size is 16384 bytes
val strs = "0123456789" map (d d.toString * (rnd.nextInt(9000) + 17000))
def inputs = strs map (s SendBytes(ByteString(s)))
def output = ByteString((strs :\ "")(_ ++ _))
}
object EmptyBytesFirst extends PayloadScenario {
def inputs = List(ByteString.empty, ByteString("hello")).map(SendBytes)
def output = ByteString("hello")
}
object EmptyBytesInTheMiddle extends PayloadScenario {
def inputs = List(ByteString("hello"), ByteString.empty, ByteString(" world")).map(SendBytes)
def output = ByteString("hello world")
}
object EmptyBytesLast extends PayloadScenario {
def inputs = List(ByteString("hello"), ByteString.empty).map(SendBytes)
def output = ByteString("hello")
}
// this demonstrates that cancellation is ignored so that the five results make it back
object CancellingRHS extends PayloadScenario {
override def flow =
Flow[SslTlsInbound]
.mapConcat {
case SessionTruncated SessionTruncated :: Nil
case SessionBytes(s, bytes) bytes.map(b SessionBytes(s, ByteString(b)))
}
.take(5)
.mapAsync(5, x later(500.millis, system.scheduler)(Future.successful(x)))
.via(super.flow)
override def rightClosing = IgnoreCancel
val str = "abcdef" * 100
def inputs = str.map(send)
def output = ByteString(str.take(5))
}
object CancellingRHSIgnoresBoth extends PayloadScenario {
override def flow =
Flow[SslTlsInbound]
.mapConcat {
case SessionTruncated SessionTruncated :: Nil
case SessionBytes(s, bytes) bytes.map(b SessionBytes(s, ByteString(b)))
}
.take(5)
.mapAsync(5, x later(500.millis, system.scheduler)(Future.successful(x)))
.via(super.flow)
override def rightClosing = IgnoreBoth
val str = "abcdef" * 100
def inputs = str.map(send)
def output = ByteString(str.take(5))
}
object LHSIgnoresBoth extends PayloadScenario {
override def leftClosing = IgnoreBoth
val str = "0123456789"
def inputs = str.map(ch SendBytes(ByteString(ch.toByte)))
def output = ByteString(str)
}
object BothSidesIgnoreBoth extends PayloadScenario {
override def leftClosing = IgnoreBoth
override def rightClosing = IgnoreBoth
val str = "0123456789"
def inputs = str.map(ch SendBytes(ByteString(ch.toByte)))
def output = ByteString(str)
}
object SessionRenegotiationBySender extends PayloadScenario {
def inputs = List(send("hello"), NegotiateNewSession, send("world"))
def output = ByteString("helloNEWSESSIONworld")
}
// difference is that the RHS engine will now receive the handshake while trying to send
object SessionRenegotiationByReceiver extends PayloadScenario {
val str = "abcdef" * 100
def inputs = str.map(send) ++ Seq(NegotiateNewSession) ++ "hello world".map(send)
def output = ByteString(str + "NEWSESSIONhello world")
}
val logCipherSuite = Flow[SslTlsInbound]
.map {
var session: SSLSession = null
def setSession(s: SSLSession) = {
session = s
system.log.debug(s"new session: $session (${session.getId mkString ","})")
}
{
case SessionTruncated SendBytes(ByteString("TRUNCATED"))
case SessionBytes(s, b) if s != session
setSession(s)
SendBytes(ByteString(s.getCipherSuite) ++ b)
case SessionBytes(s, b) SendBytes(b)
}
}
object SessionRenegotiationFirstOne extends PayloadScenario {
override def flow = logCipherSuite
def inputs = NegotiateNewSession.withCipherSuites("TLS_RSA_WITH_AES_128_CBC_SHA") :: send("hello") :: Nil
def output = ByteString("TLS_RSA_WITH_AES_128_CBC_SHAhello")
}
object SessionRenegotiationFirstTwo extends PayloadScenario {
override def flow = logCipherSuite
def inputs = NegotiateNewSession.withCipherSuites("TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA") :: send("hello") :: Nil
def output = ByteString("TLS_ECDHE_RSA_WITH_AES_128_CBC_SHAhello")
}
val scenarios =
Seq(
SingleBytes,
MediumMessages,
LargeMessages,
EmptyBytesFirst,
EmptyBytesInTheMiddle,
EmptyBytesLast,
CancellingRHS,
SessionRenegotiationBySender,
SessionRenegotiationByReceiver,
SessionRenegotiationFirstOne,
SessionRenegotiationFirstTwo)
for {
commPattern communicationPatterns
scenario scenarios
} {
s"work in mode ${commPattern.name} while sending ${scenario.name}" in {
val onRHS = debug.via(scenario.flow)
val f =
Source(scenario.inputs)
.via(commPattern.decorateFlow(scenario.leftClosing, scenario.rightClosing, onRHS))
.transform(() new PushStage[SslTlsInbound, SslTlsInbound] {
override def onPush(elem: SslTlsInbound, ctx: Context[SslTlsInbound]) =
ctx.push(elem)
override def onDownstreamFinish(ctx: Context[SslTlsInbound]) = {
system.log.debug("me cancelled")
ctx.finish()
}
})
.via(debug)
.collect { case SessionBytes(_, b) b }
.scan(ByteString.empty)(_ ++ _)
.transform(() new Timeout(6.seconds))
.transform(() new DropWhile(_.size < scenario.output.size))
.runWith(Sink.head)
Await.result(f, 8.seconds).utf8String should be(scenario.output.utf8String)
commPattern.cleanup()
// flush log so as to not mix up logs of different test cases
if (log.isDebugEnabled)
EventFilter.debug("stopgap", occurrences = 1) intercept {
log.debug("stopgap")
}
}
}
}
"A SslTlsPlacebo" must {
"pass through data" in {
val f = Source(1 to 3)
.map(b SendBytes(ByteString(b.toByte)))
.via(SslTlsPlacebo.forScala join Flow.apply)
.grouped(10)
.runWith(Sink.head)
val result = Await.result(f, 3.seconds)
result.map(_.bytes) should be((1 to 3).map(b ByteString(b.toByte)))
result.map(_.session).foreach(s s.getCipherSuite should be("SSL_NULL_WITH_NULL_NULL"))
}
}
}

View file

@ -6,12 +6,12 @@ package akka.stream.scaladsl
import scala.concurrent.Await
import scala.concurrent.duration._
import scala.concurrent.forkjoin.ThreadLocalRandom.{ current random }
import scala.collection.immutable
import akka.stream.ActorFlowMaterializer
import akka.stream.ActorFlowMaterializerSettings
import akka.stream.testkit.AkkaSpec
import akka.stream.ActorOperationAttributes
import akka.stream.Supervision
class FlowScanSpec extends AkkaSpec {
@ -39,5 +39,20 @@ class FlowScanSpec extends AkkaSpec {
val v = Vector.empty[Int]
scan(Source(v)) should be(v.scan(0)(_ + _))
}
"emit values promptly" in {
val f = Source.single(1).concat(Source.lazyEmpty).scan(0)(_ + _).grouped(2).runWith(Sink.head)
Await.result(f, 1.second) should be(Seq(0, 1))
}
"fail properly" in {
import ActorOperationAttributes._
val scan = Flow[Int].scan(0) { (old, current)
require(current > 0)
old + current
}.withAttributes(supervisionStrategy(Supervision.restartingDecider))
val f = Source(List(1, 3, -1, 5, 7)).via(scan).grouped(1000).runWith(Sink.head)
Await.result(f, 1.second) should be(Seq(0, 1, 4, 0, 5, 12))
}
}
}

View file

@ -213,6 +213,7 @@ final case class BidiShape[-In1, +Out1, -In2, +Out2](in1: Inlet[In1],
require(outlets.size == 2, s"proposed outlets [${outlets.mkString(", ")}] do not fit BidiShape")
BidiShape(inlets(0), outlets(0), inlets(1), outlets(1))
}
def reversed: Shape = copyFromPorts(inlets.reverse, outlets.reverse)
//#implementation-details-elided
}
//#bidi-shape

View file

@ -11,10 +11,15 @@ import akka.pattern.ask
import akka.stream.actor.ActorSubscriber
import akka.stream.impl.GenJunctions.ZipWithModule
import akka.stream.impl.Junctions._
import akka.stream.impl.MultiStreamInputProcessor.SubstreamSubscriber
import akka.stream.impl.StreamLayout.Module
import akka.stream.impl.fusing.ActorInterpreter
import akka.stream.impl.io.SslTlsCipherActor
import akka.stream.scaladsl._
import akka.stream._
import akka.stream.io._
import akka.stream.io.SslTls.TlsModule
import akka.util.ByteString
import org.reactivestreams._
import scala.concurrent.{ Await, ExecutionContextExecutor }
@ -83,6 +88,22 @@ private[akka] case class ActorFlowMaterializerImpl(override val settings: ActorF
assignPort(stage.outPort, processor)
mat
case tls: TlsModule
val es = effectiveSettings(effectiveAttributes)
val props = SslTlsCipherActor.props(es, tls.sslContext, tls.firstSession, tracing = true, tls.role, tls.closing)
val impl = actorOf(props, stageName(effectiveAttributes), es.dispatcher)
def factory(id: Int) = new ActorPublisher[Any](impl) {
override val wakeUpMsg = FanOut.SubstreamSubscribePending(id)
}
val publishers = Vector.tabulate(2)(factory)
impl ! FanOut.ExposedPublishers(publishers)
assignPort(tls.plainOut, publishers(SslTlsCipherActor.UserOut))
assignPort(tls.cipherOut, publishers(SslTlsCipherActor.TransportOut))
assignPort(tls.plainIn, FanIn.SubInput[Any](impl, SslTlsCipherActor.UserIn))
assignPort(tls.cipherIn, FanIn.SubInput[Any](impl, SslTlsCipherActor.TransportIn))
case junction: JunctionModule materializeJunction(junction, effectiveAttributes, effectiveSettings(effectiveAttributes))
}
}

View file

@ -155,6 +155,7 @@ private[akka] object FanIn {
if (input.inputsDepleted) {
if (marked(id)) markedDepleted += 1
depleted(id) = true
onDepleted(id)
}
elem
}

View file

@ -186,6 +186,16 @@ private[akka] object FanOut {
def onCancel(output: Int): Unit = ()
def demandAvailableFor(id: Int) = new TransferState {
override def isCompleted: Boolean = cancelled(id) || completed(id)
override def isReady: Boolean = pending(id)
}
def demandOrCancelAvailableFor(id: Int) = new TransferState {
override def isCompleted: Boolean = false
override def isReady: Boolean = pending(id) || cancelled(id)
}
/**
* Will only transfer an element when all marked outputs
* have demand, and will complete as soon as any of the marked

View file

@ -120,22 +120,26 @@ private[akka] class FlexiMergeImpl[T, S <: Shape](
nextPhase(TransferPhase(precondition) { ()
behavior.condition match {
case read: ReadAny[t]
suppressCompletion()
val id = inputBunch.idToDequeue()
val elem = inputBunch.dequeueAndYield(id)
val inputHandle = inputMapping(id)
callOnInput(inputHandle, elem)
triggerCompletionAfterRead(inputHandle)
case r: ReadPreferred[t]
suppressCompletion()
val elem = inputBunch.dequeuePrefering(indexOf(r.preferred))
val id = inputBunch.lastDequeuedId
val inputHandle = inputMapping(id)
callOnInput(inputHandle, elem)
triggerCompletionAfterRead(inputHandle)
case Read(input)
suppressCompletion()
val elem = inputBunch.dequeue(indexOf(input))
callOnInput(input, elem)
triggerCompletionAfterRead(input)
case read: ReadAll[t]
suppressCompletion()
val inputs = read.inputs
val values = inputs.collect {
case input if include(input) input inputBunch.dequeue(indexOf(input))
@ -160,11 +164,18 @@ private[akka] class FlexiMergeImpl[T, S <: Shape](
}
}
private def triggerCompletionAfterRead(inputHandle: InPort): Unit =
private var completionEnabled = true
private def suppressCompletion(): Unit = completionEnabled = false
private def triggerCompletionAfterRead(inputHandle: InPort): Unit = {
completionEnabled = true
if (inputBunch.isDepleted(indexOf(inputHandle)))
triggerCompletion(inputHandle)
}
private def triggerCompletion(in: InPort): Unit =
if (completionEnabled)
changeBehavior(
try completion.onUpstreamFinish(ctx, in)
catch {

View file

@ -171,4 +171,3 @@ private[akka] trait Pump {
protected def pumpFinished(): Unit
}

View file

@ -298,7 +298,7 @@ private[akka] object ActorInterpreter {
def props(settings: ActorFlowMaterializerSettings, ops: Seq[Stage[_, _]], materializer: ActorFlowMaterializer): Props =
Props(new ActorInterpreter(settings, ops, materializer))
case class AsyncInput(op: AsyncStage[Any, Any, Any], ctx: AsyncContext[Any, Any], event: Any)
case class AsyncInput(op: AsyncStage[Any, Any, Any], ctx: AsyncContext[Any, Any], event: Any) extends DeadLetterSuppression
}
/**

View file

@ -108,18 +108,27 @@ private[akka] final case class Drop[T](count: Long) extends PushStage[T, T] {
*/
private[akka] final case class Scan[In, Out](zero: Out, f: (Out, In) Out, decider: Supervision.Decider) extends PushPullStage[In, Out] {
private var aggregator = zero
private var pushedZero = false
override def onPush(elem: In, ctx: Context[Out]): SyncDirective = {
val old = aggregator
aggregator = f(old, elem)
ctx.push(old)
if (pushedZero) {
aggregator = f(aggregator, elem)
ctx.push(aggregator)
} else {
aggregator = f(zero, elem)
ctx.push(zero)
}
}
override def onPull(ctx: Context[Out]): SyncDirective =
if (ctx.isFinishing) ctx.pushAndFinish(aggregator)
else ctx.pull()
if (!pushedZero) {
pushedZero = true
if (ctx.isFinishing) ctx.pushAndFinish(aggregator) else ctx.push(aggregator)
} else ctx.pull()
override def onUpstreamFinish(ctx: Context[Out]): TerminationDirective = ctx.absorbTermination()
override def onUpstreamFinish(ctx: Context[Out]): TerminationDirective =
if (pushedZero) ctx.finish()
else ctx.absorbTermination()
override def decide(t: Throwable): Supervision.Directive = decider(t)

View file

@ -0,0 +1,449 @@
/**
* Copyright (C) 2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.stream.impl.io
import java.nio.ByteBuffer
import java.security.Principal
import java.security.cert.Certificate
import javax.net.ssl.SSLEngineResult.HandshakeStatus
import javax.net.ssl.SSLEngineResult.HandshakeStatus._
import javax.net.ssl.SSLEngineResult.Status._
import javax.net.ssl._
import akka.actor.{ Props, Actor, ActorLogging, ActorRef }
import akka.stream.ActorFlowMaterializerSettings
import akka.stream.impl.FanIn.InputBunch
import akka.stream.impl.FanOut.OutputBunch
import akka.stream.impl._
import akka.util.ByteString
import akka.util.ByteStringBuilder
import org.reactivestreams.Publisher
import org.reactivestreams.Subscriber
import scala.annotation.tailrec
import scala.collection.immutable
import akka.stream.io._
import akka.event.LoggingReceive
/**
* INTERNAL API.
*/
private[akka] object SslTlsCipherActor {
def props(settings: ActorFlowMaterializerSettings,
sslContext: SSLContext,
firstSession: NegotiateNewSession,
tracing: Boolean,
role: Role,
closing: Closing): Props =
Props(new SslTlsCipherActor(settings, sslContext, firstSession, tracing, role, closing))
final val TransportIn = 0
final val TransportOut = 0
final val UserOut = 1
final val UserIn = 1
}
/**
* INTERNAL API.
*/
private[akka] class SslTlsCipherActor(settings: ActorFlowMaterializerSettings, sslContext: SSLContext,
firstSession: NegotiateNewSession, tracing: Boolean,
role: Role, closing: Closing)
extends Actor with ActorLogging with Pump {
import SslTlsCipherActor._
protected val outputBunch = new OutputBunch(outputCount = 2, self, this)
outputBunch.markAllOutputs()
protected val inputBunch = new InputBunch(inputCount = 2, settings.maxInputBufferSize, this) {
override def onError(input: Int, e: Throwable): Unit = fail(e)
}
/**
* The SSLEngine needs bite-sized chunks of data but we get arbitrary ByteString
* from both the UserIn and the TransportIn ports. This is used to chop up such
* a ByteString by filling the respective ByteBuffer and taking care to dequeue
* a new element when data are demanded and none are left lying on the chopping
* block.
*/
class ChoppingBlock(idx: Int, name: String) extends TransferState {
override def isReady: Boolean = buffer.nonEmpty
override def isCompleted: Boolean = false
private var buffer = ByteString.empty
/**
* Whether there are no bytes lying on this chopping block.
*/
def isEmpty: Boolean = buffer.isEmpty
/**
* Pour as many bytes as are available either on the chopping block or in
* the inputBunchs next ByteString into the supplied ByteBuffer, which is
* expected to be in read left-overs mode, i.e. everything between its
* position and limit is retained. In order to allocate a fresh ByteBuffer
* with these characteristics, use `prepare()`.
*/
def chopInto(b: ByteBuffer): Unit = {
b.compact()
if (buffer.isEmpty) {
buffer = inputBunch.dequeue(idx) match {
// this class handles both UserIn and TransportIn
case bs: ByteString bs
case SendBytes(bs) bs
case n: NegotiateNewSession
setNewSessionParameters(n)
ByteString.empty
}
if (tracing) log.debug(s"chopping from new chunk of ${buffer.size} into $name (${b.position})")
} else {
if (tracing) log.debug(s"chopping from old chunk of ${buffer.size} into $name (${b.position})")
}
val copied = buffer.copyToBuffer(b)
buffer = buffer.drop(copied)
b.flip()
}
/**
* When potentially complete packet data are left after unwrap() we must
* put them back onto the chopping block because otherwise the pump will
* not know that we are runnable.
*/
def putBack(b: ByteBuffer): Unit =
if (b.hasRemaining()) {
if (tracing) log.debug(s"putting back ${b.remaining} bytes into $name")
val bs = ByteString(b)
if (bs.nonEmpty) buffer = bs ++ buffer
prepare(b)
}
/**
* Prepare a fresh ByteBuffer for receiving a chop of data.
*/
def prepare(b: ByteBuffer): Unit = {
b.clear()
b.limit(0)
}
}
// These are Nettys default values
// 16665 + 1024 (room for compressed data) + 1024 (for OpenJDK compatibility)
val transportOutBuffer = ByteBuffer.allocate(16665 + 2048)
/*
* deviating here: chopping multiple input packets into this buffer can lead to
* an OVERFLOW signal that also is an UNDERFLOW; avoid unnecessary copying by
* increasing this buffer size to host up to two packets
*/
val userOutBuffer = ByteBuffer.allocate(16665 * 2 + 2048)
val transportInBuffer = ByteBuffer.allocate(16665 + 2048)
val userInBuffer = ByteBuffer.allocate(16665 + 2048)
val userInChoppingBlock = new ChoppingBlock(UserIn, "UserIn")
userInChoppingBlock.prepare(userInBuffer)
val transportInChoppingBlock = new ChoppingBlock(TransportIn, "TransportIn")
transportInChoppingBlock.prepare(transportInBuffer)
val engine: SSLEngine = {
val e = sslContext.createSSLEngine()
e.setUseClientMode(role == Client)
e
}
var currentSession = engine.getSession
var currentSessionParameters = firstSession
applySessionParameters()
def applySessionParameters(): Unit = {
val csp = currentSessionParameters
import csp._
enabledCipherSuites foreach (cs engine.setEnabledCipherSuites(cs.toArray))
enabledProtocols foreach (p engine.setEnabledProtocols(p.toArray))
clientAuth match {
case Some(ClientAuth.None) engine.setNeedClientAuth(false)
case Some(ClientAuth.Want) engine.setWantClientAuth(true)
case Some(ClientAuth.Need) engine.setNeedClientAuth(true)
case None // do nothing
}
sslParameters foreach (p engine.setSSLParameters(p))
engine.beginHandshake()
lastHandshakeStatus = engine.getHandshakeStatus
}
def setNewSessionParameters(n: NegotiateNewSession): Unit = {
if (tracing) log.debug(s"applying $n")
currentSession.invalidate()
currentSessionParameters = n
applySessionParameters()
corkUser = true
}
/*
* So heres the big picture summary: the SSLEngine is the boss, and it can
* be in several states. Depending on this state, we may want to react to
* different input and output conditions.
*
* - normal bidirectional operation (does both outbound and inbound)
* - outbound close initiated, inbound still open
* - inbound close initiated, outbound still open
* - fully closed
*
* Upon reaching the last state we obviously just shut down. In addition to
* these user-data states, the engine may at any point in time also be
* handshaking. This is mostly transparent, but it has an influence on the
* outbound direction:
*
* - if the local user triggered a re-negotiation, cork all user data until
* that is finished
* - if the outbound direction has been closed, trigger outbound readiness
* based upon HandshakeStatus.NEED_WRAP
*
* These conditions lead to the introduction of a synthetic TransferState
* representing the Engine.
*/
var lastHandshakeStatus: HandshakeStatus = _
val engineNeedsWrap = new TransferState {
def isReady = lastHandshakeStatus == NEED_WRAP
def isCompleted = false
}
val engineInboundOpen = new TransferState {
def isReady = !engine.isInboundDone()
def isCompleted = false
}
var corkUser = true
val userHasData = new TransferState {
private val user = inputBunch.inputsOrCompleteAvailableFor(UserIn) || userInChoppingBlock
def isReady = !corkUser && user.isReady && lastHandshakeStatus != NEED_UNWRAP
def isCompleted = false
}
val transportHasData = inputBunch.inputsOrCompleteAvailableFor(TransportIn) || transportInChoppingBlock
val userOutCancelled = new TransferState {
def isReady = outputBunch.isCancelled(UserOut)
def isCompleted = inputBunch.isDepleted(TransportIn)
}
// bidirectional case
val outbound = (userHasData || engineNeedsWrap) && outputBunch.demandAvailableFor(TransportOut)
val inbound = (transportHasData || userOutCancelled) && outputBunch.demandOrCancelAvailableFor(UserOut)
// half-closed
val outboundHalfClosed = engineNeedsWrap && outputBunch.demandAvailableFor(TransportOut)
val inboundHalfClosed = transportHasData && engineInboundOpen
def completeOrFlush(): Unit =
if (engine.isOutboundDone()) nextPhase(completedPhase)
else nextPhase(flushingOutbound)
private def doInbound(isOutboundClosed: Boolean, inboundState: TransferState): Boolean =
if (inputBunch.isDepleted(TransportIn) && transportInChoppingBlock.isEmpty) {
if (tracing) log.debug("closing inbound")
try engine.closeInbound()
catch { case ex: SSLException outputBunch.enqueue(UserOut, SessionTruncated) }
completeOrFlush()
false
} else if (inboundState != inboundHalfClosed && outputBunch.isCancelled(UserOut)) {
if (!isOutboundClosed && closing.ignoreCancel) {
if (tracing) log.debug("ignoring UserIn cancellation")
nextPhase(inboundClosed)
} else {
if (tracing) log.debug("closing inbound due to UserOut cancellation")
engine.closeOutbound() // this is the correct way of shutting down the engine
lastHandshakeStatus = engine.getHandshakeStatus
nextPhase(flushingOutbound)
}
true
} else if (inboundState.isReady) {
transportInChoppingBlock.chopInto(transportInBuffer)
try {
doUnwrap()
true
} catch {
case ex: SSLException
if (tracing) log.debug(s"SSLException during doUnwrap: $ex")
completeOrFlush()
false
}
} else true
private def doOutbound(isInboundClosed: Boolean): Unit =
if (inputBunch.isDepleted(UserIn) && userInChoppingBlock.isEmpty) {
if (!isInboundClosed && closing.ignoreComplete) {
if (tracing) log.debug("ignoring closeOutbound")
} else {
if (tracing) log.debug("closing outbound directly")
engine.closeOutbound()
lastHandshakeStatus = engine.getHandshakeStatus
}
nextPhase(outboundClosed)
} else if (outputBunch.isCancelled(TransportOut)) {
nextPhase(completedPhase)
} else if (outbound.isReady) {
if (userHasData.isReady) userInChoppingBlock.chopInto(userInBuffer)
try doWrap()
catch {
case ex: SSLException
if (tracing) log.debug(s"SSLException during doWrap: $ex")
completeOrFlush()
}
}
val bidirectional = TransferPhase(outbound || inbound) { ()
if (tracing) log.debug("bidirectional")
val continue = doInbound(isOutboundClosed = false, inbound)
if (continue) {
if (tracing) log.debug("bidirectional continue")
doOutbound(isInboundClosed = false)
}
}
val flushingOutbound = TransferPhase(outboundHalfClosed) { ()
if (tracing) log.debug("flushingOutbound")
try doWrap()
catch { case ex: SSLException nextPhase(completedPhase) }
}
val awaitingClose = TransferPhase(inputBunch.inputsAvailableFor(TransportIn)) { ()
if (tracing) log.debug("awaitingClose")
transportInChoppingBlock.chopInto(transportInBuffer)
try doUnwrap(ignoreOutput = true)
catch { case ex: SSLException nextPhase(completedPhase) }
}
val outboundClosed = TransferPhase(outboundHalfClosed || inbound) { ()
if (tracing) log.debug("outboundClosed")
val continue = doInbound(isOutboundClosed = true, inbound)
if (continue && outboundHalfClosed.isReady) {
if (tracing) log.debug("outboundClosed continue")
try doWrap()
catch { case ex: SSLException nextPhase(completedPhase) }
}
}
val inboundClosed = TransferPhase(outbound || inboundHalfClosed) { ()
if (tracing) log.debug("inboundClosed")
val continue = doInbound(isOutboundClosed = false, inboundHalfClosed)
if (continue) {
if (tracing) log.debug("inboundClosed continue")
doOutbound(isInboundClosed = true)
}
}
def flushToTransport(): Unit = {
if (tracing) log.debug("flushToTransport")
transportOutBuffer.flip()
if (transportOutBuffer.hasRemaining) {
val bs = ByteString(transportOutBuffer)
outputBunch.enqueue(TransportOut, bs)
if (tracing) log.debug(s"sending ${bs.size} bytes")
}
transportOutBuffer.clear()
}
def flushToUser(): Unit = {
if (tracing) log.debug("flushToUser")
userOutBuffer.flip()
if (userOutBuffer.hasRemaining) {
val bs = ByteString(userOutBuffer)
outputBunch.enqueue(UserOut, SessionBytes(currentSession, bs))
}
userOutBuffer.clear()
}
private def doWrap(): Unit = {
val result = engine.wrap(userInBuffer, transportOutBuffer)
lastHandshakeStatus = result.getHandshakeStatus
if (tracing) log.debug(s"wrap: status=${result.getStatus} handshake=$lastHandshakeStatus remaining=${userInBuffer.remaining} out=${transportOutBuffer.position}")
if (lastHandshakeStatus == FINISHED) handshakeFinished()
runDelegatedTasks()
result.getStatus match {
case OK
flushToTransport()
userInChoppingBlock.putBack(userInBuffer)
case CLOSED
flushToTransport()
if (engine.isInboundDone()) nextPhase(completedPhase)
else nextPhase(awaitingClose)
case s fail(new IllegalStateException(s"unexpected status $s in doWrap()"))
}
}
@tailrec
private def doUnwrap(ignoreOutput: Boolean = false): Unit = {
val result = engine.unwrap(transportInBuffer, userOutBuffer)
if (ignoreOutput) userOutBuffer.clear()
lastHandshakeStatus = result.getHandshakeStatus
if (tracing) log.debug(s"unwrap: status=${result.getStatus} handshake=$lastHandshakeStatus remaining=${transportInBuffer.remaining} out=${userOutBuffer.position}")
runDelegatedTasks()
result.getStatus match {
case OK
result.getHandshakeStatus match {
case NEED_WRAP flushToUser()
case FINISHED
flushToUser()
handshakeFinished()
transportInChoppingBlock.putBack(transportInBuffer)
case _
if (transportInBuffer.hasRemaining()) doUnwrap()
else flushToUser()
}
case CLOSED
flushToUser()
if (engine.isOutboundDone()) nextPhase(completedPhase)
else nextPhase(flushingOutbound)
case BUFFER_UNDERFLOW
flushToUser()
case BUFFER_OVERFLOW
flushToUser()
transportInChoppingBlock.putBack(transportInBuffer)
case s fail(new IllegalStateException(s"unexpected status $s in doUnwrap()"))
}
}
@tailrec
private def runDelegatedTasks(): Unit = {
val task = engine.getDelegatedTask
if (task != null) {
if (tracing) log.debug("running task")
task.run()
runDelegatedTasks()
} else {
val st = lastHandshakeStatus
lastHandshakeStatus = engine.getHandshakeStatus
if (tracing && st != lastHandshakeStatus) log.debug(s"handshake status after tasks: $lastHandshakeStatus")
}
}
private def handshakeFinished(): Unit = {
if (tracing) log.debug("handshake finished")
currentSession = engine.getSession
corkUser = false
}
override def receive = inputBunch.subreceive.orElse[Any, Unit](outputBunch.subreceive)
nextPhase(bidirectional)
protected def fail(e: Throwable): Unit = {
// FIXME: escalate to supervisor
if (tracing) log.debug("fail {} due to: {}", self, e.getMessage)
inputBunch.cancel()
outputBunch.error(TransportOut, e)
outputBunch.error(UserOut, e)
context.stop(self)
}
override protected def pumpFailed(e: Throwable): Unit = fail(e)
override protected def pumpFinished(): Unit = {
inputBunch.cancel()
outputBunch.complete()
if (tracing) log.debug(s"STOP Outbound Closed: ${engine.isOutboundDone} Inbound closed: ${engine.isInboundDone}")
context.stop(self)
}
}

View file

@ -0,0 +1,411 @@
/**
* Copyright (C) 2015 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.stream.io
import akka.stream._
import akka.stream.impl.StreamLayout.Module
import akka.util.ByteString
import javax.net.ssl._
import scala.annotation.varargs
import scala.collection.immutable
import java.security.cert.Certificate
/**
* Stream cipher support based upon JSSE.
*
* The underlying SSLEngine has four ports: plaintext input/output and
* ciphertext input/output. These are modeled as a [[akka.stream.BidiShape]]
* element for use in stream topologies, where the plaintext ports are on the
* left hand side of the shape and the ciphertext ports on the right hand side.
*
* Configuring JSSE is a rather complex topic, please refer to the JDK platform
* documentation or the excellent user guide that is part of the Play Framework
* documentation. The philosophy of this integration into Akka Streams is to
* expose all knobs and dials to client code and therefore not limit the
* configuration possibilities. In particular the client code will have to
* provide the SSLContext from which the SSLEngine is then created. Handshake
* parameters are set using [[NegotiateNewSession]] messages, the settings for
* the initial handshake need to be provided up front using the same class;
* please refer to the method documentation below.
*
* '''IMPORTANT NOTE'''
*
* The TLS specification does not permit half-closing of the user data session
* that it transportsto be precise a half-close will always promptly lead to a
* full close. This means that canceling the plaintext output or completing the
* plaintext input of the SslTls stage will lead to full termination of the
* secure connection without regard to whether bytes are remaining to be sent or
* received, respectively. Especially for a client the common idiom of attaching
* a finite Source to the plaintext input and transforming the plaintext response
* bytes coming out will not work out of the box due to early termination of the
* connection. For this reason there is a parameter that determines whether the
* SslTls stage shall ignore completion and/or cancellation events, and the
* default is to ignore completion (in view of the clientserver scenario). In
* order to terminate the connection the client will then need to cancel the
* plaintext output as soon as all expected bytes have been received. When
* ignoring both types of events the stage will shut down once both events have
* been received. See also [[Closing]].
*/
object SslTls {
/**
* Scala API: create a StreamTls [[akka.stream.scaladsl.BidiFlow]]. The
* SSLContext will be used to create an SSLEngine to which then the
* `firstSession` parameters are applied before initiating the first
* handshake. The `role` parameter determines the SSLEngines role; this is
* often the same as the underlying transports server or client role, but
* that is not a requirement and depends entirely on the application
* protocol.
*
* For a description of the `closing` parameter please refer to [[Closing]].
*/
def apply(sslContext: SSLContext, firstSession: NegotiateNewSession,
role: Role, closing: Closing = IgnoreComplete): scaladsl.BidiFlow[SslTlsOutbound, ByteString, ByteString, SslTlsInbound, Unit] =
new scaladsl.BidiFlow(TlsModule(OperationAttributes.none, sslContext, firstSession, role, closing))
/**
* Java API: create a StreamTls [[akka.stream.javadsl.BidiFlow]] in client mode. The
* SSLContext will be used to create an SSLEngine to which then the
* `firstSession` parameters are applied before initiating the first
* handshake. The `role` parameter determines the SSLEngines role; this is
* often the same as the underlying transports server or client role, but
* that is not a requirement and depends entirely on the application
* protocol.
*
* This method uses the default closing behavior or [[IgnoreComplete]].
*/
def create(sslContext: SSLContext, firstSession: NegotiateNewSession, role: Role): javadsl.BidiFlow[SslTlsOutbound, ByteString, ByteString, SslTlsInbound, Unit] =
new javadsl.BidiFlow(apply(sslContext, firstSession, role))
/**
* Java API: create a StreamTls [[akka.stream.javadsl.BidiFlow]] in client mode. The
* SSLContext will be used to create an SSLEngine to which then the
* `firstSession` parameters are applied before initiating the first
* handshake. The `role` parameter determines the SSLEngines role; this is
* often the same as the underlying transports server or client role, but
* that is not a requirement and depends entirely on the application
* protocol.
*
* For a description of the `closing` parameter please refer to [[Closing]].
*/
def create(sslContext: SSLContext, firstSession: NegotiateNewSession, role: Role, closing: Closing): javadsl.BidiFlow[SslTlsOutbound, ByteString, ByteString, SslTlsInbound, Unit] =
new javadsl.BidiFlow(apply(sslContext, firstSession, role, closing))
/**
* INTERNAL API.
*/
private[akka] case class TlsModule(plainIn: Inlet[SslTlsOutbound], plainOut: Outlet[SslTlsInbound],
cipherIn: Inlet[ByteString], cipherOut: Outlet[ByteString],
shape: Shape, attributes: OperationAttributes,
sslContext: SSLContext, firstSession: NegotiateNewSession,
role: Role, closing: Closing) extends Module {
override def subModules: Set[Module] = Set.empty
override def withAttributes(att: OperationAttributes): Module = copy(attributes = att)
override def carbonCopy: Module = {
val mod = TlsModule(attributes, sslContext, firstSession, role, closing)
if (plainIn == shape.inlets(0)) mod
else mod.replaceShape(mod.shape.asInstanceOf[BidiShape[_, _, _, _]].reversed)
}
override def replaceShape(s: Shape) =
if (s == shape) this
else if (shape.hasSamePortsAs(s)) copy(shape = s)
else throw new IllegalArgumentException("trying to replace shape with different ports")
}
/**
* INTERNAL API.
*/
private[akka] object TlsModule {
def apply(attributes: OperationAttributes, sslContext: SSLContext, firstSession: NegotiateNewSession, role: Role, closing: Closing): TlsModule = {
val name = attributes.nameOrDefault(s"StreamTls($role)")
val cipherIn = new Inlet[ByteString](s"$name.cipherIn")
val cipherOut = new Outlet[ByteString](s"$name.cipherOut")
val plainIn = new Inlet[SslTlsOutbound](s"$name.transportIn")
val plainOut = new Outlet[SslTlsInbound](s"$name.transportOut")
val shape = new BidiShape(plainIn, cipherOut, cipherIn, plainOut)
TlsModule(plainIn, plainOut, cipherIn, cipherOut, shape, attributes, sslContext, firstSession, role, closing)
}
}
}
/**
* This object holds simple wrapping [[BidiFlow]] implementations that can
* be used instead of [[SslTls]] when no encryption is desired. The flows will
* just adapt the message protocol by wrapping into [[SessionBytes]] and
* unwrapping [[SendBytes]].
*/
object SslTlsPlacebo {
val forScala = scaladsl.BidiFlow() { implicit b
// this constructs a session for (invalid) protocol SSL_NULL_WITH_NULL_NULL
val session = SSLContext.getDefault.createSSLEngine.getSession
val top = b.add(scaladsl.Flow[SslTlsOutbound].collect { case SendBytes(b) b })
val bottom = b.add(scaladsl.Flow[ByteString].map(SessionBytes(session, _)))
BidiShape(top, bottom)
}
val forJava = new javadsl.BidiFlow(forScala)
}
/**
* Many protocols are asymmetric and distinguish between the client and the
* server, where the latter listens passively for messages and the former
* actively initiates the exchange.
*/
object Role {
/**
* Java API: obtain the [[Client]] singleton value.
*/
def client: Role = Client
/**
* Java API: obtain the [[Server]] singleton value.
*/
def server: Role = Server
}
sealed abstract class Role
/**
* The client is usually the side that consumes the service provided by its
* interlocutor. The precise interpretation of this role is protocol specific.
*/
sealed abstract class Client extends Role
case object Client extends Client
/**
* The server is usually the side the provides the service to its interlocutor.
* The precise interpretation of this role is protocol specific.
*/
sealed abstract class Server extends Role
case object Server extends Server
/**
* All streams in Akka are unidirectional: while in a complex flow graph data
* may flow in multiple directions these individual flows are independent from
* each other. The difference between two half-duplex connections in opposite
* directions and a full-duplex connection is that the underlying transport
* is shared in the latter and tearing it down will end the data transfer in
* both directions.
*
* When integrating a full-duplex transport medium that does not support
* half-closing (which means ending one direction of data transfer without
* ending the other) into a stream topology, there can be unexpected effects.
* Feeding a finite Source into this medium will close the connection after
* all elements have been sent, which means that possible replies may not
* be received in full. To support this type of usage, the sending and
* receiving of data on the same side (e.g. on the [[Client]]) need to be
* coordinated such that it is known when all replies have been received.
* Only then should the transport be shut down.
*
* To support these scenarios it is recommended that the full-duplex
* transport integration is configurable in terms of termination handling,
* which means that the user can optionally suppress the normal (closing)
* reaction to completion or cancellation events, as is expressed by the
* possible values of this type:
*
* - [[EagerClose]] means to not ignore signals
* - [[IgnoreCancel]] means to not react to cancellation of the receiving
* side unless the sending side has already completed
* - [[IgnoreComplete]] means to not reacto the completion of the sending
* side unless the receiving side has already cancelled
* - [[IgnoreBoth]] means to ignore the first termination signalbe that
* cancellation or completionand only act upon the second one
*/
sealed abstract class Closing {
def ignoreCancel: Boolean
def ignoreComplete: Boolean
}
object Closing {
/**
* Java API: obtain the [[EagerClose]] singleton value.
*/
def eagerClose: Closing = EagerClose
/**
* Java API: obtain the [[IgnoreCancel]] singleton value.
*/
def ignoreCancel: Closing = IgnoreCancel
/**
* Java API: obtain the [[IgnoreComplete]] singleton value.
*/
def ignoreComplete: Closing = IgnoreComplete
/**
* Java API: obtain the [[IgnoreBoth]] singleton value.
*/
def ignoreBoth: Closing = IgnoreBoth
}
/**
* see [[Closing]]
*/
sealed abstract class EagerClose extends Closing {
override def ignoreCancel = false
override def ignoreComplete = false
}
case object EagerClose extends EagerClose
/**
* see [[Closing]]
*/
sealed abstract class IgnoreCancel extends Closing {
override def ignoreCancel = true
override def ignoreComplete = false
}
case object IgnoreCancel extends IgnoreCancel
/**
* see [[Closing]]
*/
sealed abstract class IgnoreComplete extends Closing {
override def ignoreCancel = false
override def ignoreComplete = true
}
case object IgnoreComplete extends IgnoreComplete
/**
* see [[Closing]]
*/
sealed abstract class IgnoreBoth extends Closing {
override def ignoreCancel = true
override def ignoreComplete = true
}
case object IgnoreBoth extends IgnoreBoth
/**
* This is the supertype of all messages that the SslTls stage emits on the
* plaintext side.
*/
sealed trait SslTlsInbound
/**
* If the underlying transport is closed before the final TLS closure command
* is received from the peer then the SSLEngine will throw an SSLException that
* warns about possible truncation attacks. This exception is caught and
* translated into this message when encountered. Most of the time this occurs
* not because of a malicious attacker but due to a connection abort or a
* misbehaving communication peer.
*/
sealed abstract class SessionTruncated extends SslTlsInbound
case object SessionTruncated extends SessionTruncated
/**
* Plaintext bytes emitted by the SSLEngine are received over one specific
* encryption session and this class bundles the bytes with the SSLSession
* object. When the session changes due to renegotiation (which can be
* initiated by either party) the new session value will not compare equal to
* the previous one.
*
* The Java API for getting session information is given by the SSLSession object,
* the Scala API adapters are offered below.
*/
case class SessionBytes(session: SSLSession, bytes: ByteString) extends SslTlsInbound {
/**
* Scala API: Extract the certificates that were actually used by this
* engine during this sessions negotiation. The list is empty if no
* certificates were used.
*/
def localCertificates: List[Certificate] = Option(session.getLocalCertificates).map(_.toList).getOrElse(Nil)
/**
* Scala API: Extract the Principal that was actually used by this engine
* during this sessions negotiation.
*/
def localPrincipal = Option(session.getLocalPrincipal)
/**
* Scala API: Extract the certificates that were used by the peer engine
* during this sessions negotiation. The list is empty if no certificates
* were used.
*/
def peerCertificates =
try Option(session.getPeerCertificates).map(_.toList).getOrElse(Nil)
catch { case e: SSLPeerUnverifiedException Nil }
/**
* Scala API: Extract the Principal that the peer engine presented during
* this sessions negotiation.
*/
def peerPrincipal =
try Option(session.getPeerPrincipal)
catch { case e: SSLPeerUnverifiedException None }
}
/**
* This is the supertype of all messages that the SslTls stage accepts on its
* plaintext side.
*/
sealed trait SslTlsOutbound
/**
* Initiate a new session negotiation. Any [[SendBytes]] commands following
* this one will be held back (i.e. back-pressured) until the new handshake is
* completed, meaning that the bytes following this message will be encrypted
* according to the requirements outlined here.
*
* Each of the values in this message is optional and will have the following
* effect if provided:
*
* - `enabledCipherSuites` will be passed to `SSLEngine::setEnabledCipherSuites()`
* - `enabledProtocols` will be passed to `SSLEngine::setEnabledProtocols()`
* - `clientAuth` will be passed to `SSLEngine::setWantClientAuth()` or `SSLEngine.setNeedClientAuth()`, respectively
* - `sslParameters` will be passed to `SSLEngine::setSSLParameters()`
*/
case class NegotiateNewSession(
enabledCipherSuites: Option[immutable.Seq[String]],
enabledProtocols: Option[immutable.Seq[String]],
clientAuth: Option[ClientAuth],
sslParameters: Option[SSLParameters]) extends SslTlsOutbound {
/**
* Java API: Make a copy of this message with the given `enabledCipherSuites`.
*/
@varargs
def withCipherSuites(s: String*) = copy(enabledCipherSuites = Some(s.toList))
/**
* Java API: Make a copy of this message with the given `enabledProtocols`.
*/
@varargs
def withProtocols(p: String*) = copy(enabledProtocols = Some(p.toList))
/**
* Java API: Make a copy of this message with the given [[ClientAuth]] setting.
*/
def withClientAuth(ca: ClientAuth) = copy(clientAuth = Some(ca))
/**
* Java API: Make a copy of this message with the given [[SSLParameters]].
*/
def withParameters(p: SSLParameters) = copy(sslParameters = Some(p))
}
object NegotiateNewSession extends NegotiateNewSession(None, None, None, None) {
/**
* Java API: obtain the default value (which will leave the SSLEngines
* settings unchanged).
*/
def withDefaults = this
}
/**
* Send the given [[akka.util.ByteString]] across the encrypted session to the
* peer.
*/
case class SendBytes(bytes: ByteString) extends SslTlsOutbound
/**
* An SSLEngine can either demand, allow or ignore its peers authentication
* (via certificates), where `Need` will fail the handshake if the peer does
* not provide valid credentials, `Want` allows the peer to send credentials
* and verifies them if provided, and `None` disables peer certificate
* verification.
*
* See the documentation for `SSLEngine::setWantClientAuth` for more
* information.
*/
sealed abstract class ClientAuth
object ClientAuth {
case object None extends ClientAuth
case object Want extends ClientAuth
case object Need extends ClientAuth
def none: ClientAuth = None
def want: ClientAuth = Want
def need: ClientAuth = Need
}

View file

@ -115,11 +115,7 @@ final class BidiFlow[-I1, +O1, -I2, +O2, +Mat](private[stream] override val modu
/**
* Turn this BidiFlow around by 180 degrees, logically flipping it upside down in a protocol stack.
*/
def reversed: BidiFlow[I2, O2, I1, O1, Mat] = {
val ins = shape.inlets
val outs = shape.outlets
new BidiFlow(module.replaceShape(shape.copyFromPorts(ins.reverse, outs.reverse)))
}
def reversed: BidiFlow[I2, O2, I1, O1, Mat] = new BidiFlow(module.replaceShape(shape.reversed))
override def withAttributes(attr: OperationAttributes): BidiFlow[I1, O1, I2, O2, Mat] =
new BidiFlow(module.withAttributes(attr).wrap())

View file

@ -1,376 +0,0 @@
/**
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
*/
package akka.stream.ssl
import java.nio.ByteBuffer
import java.security.Principal
import java.security.cert.Certificate
import javax.net.ssl.SSLEngineResult.HandshakeStatus._
import javax.net.ssl.SSLEngineResult.Status._
import javax.net.ssl.SSLEngine
import javax.net.ssl.SSLEngineResult
import javax.net.ssl.SSLPeerUnverifiedException
import javax.net.ssl.SSLSession
import akka.actor.Actor
import akka.actor.ActorLogging
import akka.actor.ActorRef
import akka.stream.ActorFlowMaterializerSettings
import akka.stream.impl._
import akka.util.ByteString
import akka.util.ByteStringBuilder
import org.reactivestreams.Publisher
import org.reactivestreams.Subscriber
import scala.annotation.tailrec
object SslTlsCipher {
/**
* An established SSL session.
*/
final case class InboundSession(
sessionInfo: SessionInfo,
data: Publisher[ByteString])
/**
* A request to establish an SSL session.
* FIXME Not used right now since there is only one session established
*/
final case class OutboundSession(
negotiation: SessionNegotiation,
data: Subscriber[ByteString])
/**
* Information about the established SSL session.
*/
final case class SessionInfo(
cipherSuite: String,
localCertificates: List[Certificate],
localPrincipal: Option[Principal],
peerCertificates: List[Certificate],
peerPrincipal: Option[Principal])
object SessionInfo {
def apply(engine: SSLEngine): SessionInfo =
apply(engine.getSession)
def apply(session: SSLSession): SessionInfo = {
val localCertificates = Option(session.getLocalCertificates).map { _.toList } getOrElse Nil
val localPrincipal = Option(session.getLocalPrincipal)
val peerCertificates =
try session.getPeerCertificates.toList
catch { case e: SSLPeerUnverifiedException Nil }
val peerPrincipal =
try Option(session.getPeerPrincipal)
catch { case e: SSLPeerUnverifiedException None }
SessionInfo(session.getCipherSuite, localCertificates, localPrincipal, peerCertificates, peerPrincipal)
}
}
/**
* Information needed to establish an SSL session.
*/
final case class SessionNegotiation(engine: SSLEngine)
}
final case class SslTlsCipher(
sessionInbound: Publisher[SslTlsCipher.InboundSession],
// FIXME We only have one session, and the SessionNegotiation is passed in via the constructor.
// This should really be a Subscriber[SslTlsCipher.OutboundSession]
plainTextOutbound: Subscriber[ByteString],
cipherTextInbound: Subscriber[ByteString],
cipherTextOutbound: Publisher[ByteString])
object SslTlsCipherActor {
val EmptyByteArray = Array.empty[Byte]
val EmptyByteBuffer = ByteBuffer.wrap(EmptyByteArray)
}
class SslTlsCipherActor(val requester: ActorRef, val sessionNegotioation: SslTlsCipher.SessionNegotiation, tracing: Boolean)
extends Actor
with ActorLogging
with Pump
with MultiStreamOutputProcessorLike
with MultiStreamInputProcessorLike {
override val subscriptionTimeoutSettings = ActorFlowMaterializerSettings(context.system).subscriptionTimeoutSettings
def this(requester: ActorRef, sessionNegotioation: SslTlsCipher.SessionNegotiation) =
this(requester, sessionNegotioation, false)
import MultiStreamInputProcessor.SubstreamSubscriber
import SslTlsCipherActor._
private var _nextId = 0L
protected def nextId(): Long = { _nextId += 1; _nextId }
override protected val inputBufferSize = 1
// The cipherTextInput (Subscriber[ByteString])
val inboundCipherTextInput = createSubstreamInput()
// The cipherTextOutput (Publisher[ByteString])
val outboundCipherTextOutput = createSubstreamOutput()
// The Publisher[SslTlsCipher.InboundSession]
// FIXME For now there is only one session ever exposed
val inboundSessionOutput = createSubstreamOutput()
// The read side for the user (Publisher[ByteString])
// FIXME For now there is only one session ever exposed
val inboundPlaintextOutput = createSubstreamOutput()
// The write side for the user (Subscriber[ByteString])
// FIXME For now there is only one session ever exposed
val outboundPlaintextInput = createSubstreamInput()
// Plaintext bytes to be encrypted
var plaintextOutboundBytes = EmptyByteBuffer
val plaintextOutboundBytesPending = new TransferState {
override def isReady = plaintextOutboundBytes.hasRemaining
override def isCompleted = false
}
// Encrypted bytes to be sent
val cipherTextOutboundBytes = new ByteStringBuilder
// Encrypted bytes to be decrypted
var cipherTextInboundBytes = EmptyByteBuffer
val cipherTextInboundBytesPending = new TransferState {
override def isReady = cipherTextInboundBytes.hasRemaining
override def isCompleted = false
}
// Plaintext bytes to be received
val plaintextInboundBytes = new ByteStringBuilder
// FIXME: Change this into a pool of ByteBuffer later
// These are Nettys default values
// 16665 + 1024 (room for compressed data) + 1024 (for OpenJDK compatibility)
val temporaryBuffer = ByteBuffer.allocate(16665 + 2048)
val engine: SSLEngine = sessionNegotioation.engine
def doWrap(tempBuf: ByteBuffer): SSLEngineResult = {
tempBuf.clear()
if (tracing) log.debug("before wrap {}", plaintextOutboundBytes.remaining)
val result = engine.wrap(plaintextOutboundBytes, tempBuf)
if (tracing) log.debug("after wrap {}", plaintextOutboundBytes.remaining)
tempBuf.flip()
if (tempBuf.hasRemaining) {
val bs = ByteString(tempBuf)
if (tracing) log.debug("wrap Enqueue cipher bytes {}", bs)
cipherTextOutboundBytes ++= bs
}
result
}
def doUnwrap(tempBuf: ByteBuffer): SSLEngineResult = {
tempBuf.clear()
if (tracing) log.debug("before unwrap {}", cipherTextInboundBytes.remaining)
val result = engine.unwrap(cipherTextInboundBytes, tempBuf)
if (tracing) log.debug("after unwrap {}", cipherTextInboundBytes.remaining)
tempBuf.flip()
if (tempBuf.hasRemaining) {
val bs = ByteString(tempBuf)
if (tracing) log.debug("unwrap Enqueue cipher bytes {}", bs)
plaintextInboundBytes ++= bs
}
result
}
def enqueueCipherInputBytes(data: ByteString): Unit = {
cipherTextInboundBytes =
if (cipherTextInboundBytes.hasRemaining) {
val buffer = ByteBuffer.allocate(cipherTextInboundBytes.remaining + data.size)
buffer.put(cipherTextInboundBytes)
data.copyToBuffer(buffer)
buffer.flip()
buffer
} else data.toByteBuffer
}
def writeCipherTextOutboundBytes() = {
if (cipherTextOutboundBytes.length > 0) {
val bs = cipherTextOutboundBytes.result()
cipherTextOutboundBytes.clear()
outboundCipherTextOutput.enqueueOutputElement(bs)
}
}
def writePlaintextInboundBytes() = {
if (plaintextInboundBytes.length > 0) {
val bs = plaintextInboundBytes.result()
plaintextInboundBytes.clear()
inboundPlaintextOutput.enqueueOutputElement(bs)
}
}
@tailrec
private def runDelegatedTasks(): Unit = {
val task = engine.getDelegatedTask
if (task != null) {
if (tracing) log.debug("Running delegated task {}", task)
task.run()
runDelegatedTasks()
}
}
def publishSSLSessionEstablished(): Unit = {
import SslTlsCipher._
val info = SessionInfo(engine)
val is = InboundSession(info, inboundPlaintextOutput.asInstanceOf[Publisher[ByteString]])
if (tracing) log.debug("#### Handshake done!")
inboundSessionOutput.enqueueOutputElement(is)
}
val unwrapPhase: TransferPhase = TransferPhase(inboundCipherTextInput.NeedsInput || cipherTextInboundBytesPending) { ()
if (tracing) log.debug("### UNWRAP")
if (inboundCipherTextInput.NeedsInput.isReady)
enqueueCipherInputBytes(inboundCipherTextInput.dequeueInputElement().asInstanceOf[ByteString])
val result = doUnwrap(temporaryBuffer)
val rs = result.getStatus
if (tracing) log.debug("## UNWRAP {}", rs)
val hs = result.getHandshakeStatus
val next = rs match {
case OK
handshakePhase(hs)
case CLOSED if (!engine.isInboundDone) encryptionPhase else completedPhase
case BUFFER_OVERFLOW throw new IllegalStateException // the SslBufferPool should make sure that buffers are never too small
case BUFFER_UNDERFLOW throw new IllegalStateException // should never appear as a result of a wrap
}
nextPhase(next)
}
val wrapPhase: TransferPhase = TransferPhase(outboundCipherTextOutput.NeedsDemand) { ()
if (tracing) log.debug("### WRAP")
val result = doWrap(temporaryBuffer)
val rs = result.getStatus
if (tracing) log.debug("## WRAP {}", rs)
val hs = result.getHandshakeStatus
val next = rs match {
case OK
writeCipherTextOutboundBytes()
handshakePhase(hs)
case CLOSED if (!engine.isInboundDone) decryptionPhase else completedPhase
case BUFFER_OVERFLOW throw new IllegalStateException // the SslBufferPool should make sure that buffers are never too small
case BUFFER_UNDERFLOW throw new IllegalStateException // should never appear as a result of a wrap
}
nextPhase(next)
}
def handshakePhase(hs: SSLEngineResult.HandshakeStatus): TransferPhase = {
if (tracing) log.debug("### HS {}", hs)
hs match {
case status @ (NOT_HANDSHAKING | FINISHED)
if (status == FINISHED) publishSSLSessionEstablished()
engineRunningPhase
case NEED_WRAP wrapPhase
case NEED_UNWRAP unwrapPhase
case NEED_TASK
runDelegatedTasks()
engine.getHandshakeStatus match {
case NEED_WRAP wrapPhase
case NEED_UNWRAP unwrapPhase
case x throw new IllegalStateException(s"Bad Handshake status $x")
}
}
}
val waitForHandshakeStartPhase: TransferPhase = TransferPhase((outboundPlaintextInput.NeedsInput || inboundCipherTextInput.NeedsInput) && inboundSessionOutput.NeedsDemand) { ()
if (tracing) log.debug("#### Starting Handshake")
engine.beginHandshake()
nextPhase(handshakePhase(engine.getHandshakeStatus))
}
val encryptionInputAvailable = outboundPlaintextInput.NeedsInput || plaintextOutboundBytesPending
val decryptionInputAvailable = inboundCipherTextInput.NeedsInput || cipherTextInboundBytesPending
val canEncrypt = encryptionInputAvailable && outboundCipherTextOutput.NeedsDemand
val canDecrypt = decryptionInputAvailable && inboundPlaintextOutput.NeedsDemand
val engineRunningPhase: TransferPhase = TransferPhase(canEncrypt || canDecrypt) { ()
if (tracing) log.debug("#### Engine running")
if (canEncrypt.isExecutable) {
nextPhase(encryptionPhase)
} else {
nextPhase(decryptionPhase)
}
}
val encryptionPhase: TransferPhase = TransferPhase(canEncrypt) { ()
if (tracing) log.debug("### Encrypting")
if (!plaintextOutboundBytesPending.isReady && outboundPlaintextInput.inputsAvailable) {
val elem = outboundPlaintextInput.dequeueInputElement().asInstanceOf[ByteString]
plaintextOutboundBytes = elem.asByteBuffer
}
val result = doWrap(temporaryBuffer)
val rs = result.getStatus
if (tracing) log.debug("## Encrypting {}", rs)
val hs = result.getHandshakeStatus
val next = rs match {
case OK
if (hs == NOT_HANDSHAKING) {
writeCipherTextOutboundBytes()
engineRunningPhase
} else handshakePhase(hs)
case CLOSED if (!engine.isInboundDone) decryptionPhase else completedPhase
case BUFFER_OVERFLOW throw new IllegalStateException // the SslBufferPool should make sure that buffers are never too small
case BUFFER_UNDERFLOW throw new IllegalStateException // should never appear as a result of a wrap
}
nextPhase(next)
}
val decryptionPhase: TransferPhase = TransferPhase(canDecrypt) { ()
if (tracing) log.debug("### Decrypting")
if (inboundCipherTextInput.NeedsInput.isReady) {
val elem = inboundCipherTextInput.dequeueInputElement().asInstanceOf[ByteString]
enqueueCipherInputBytes(elem)
}
val result = doUnwrap(temporaryBuffer)
val rs = result.getStatus
if (tracing) log.debug("## Decrypting {}", rs)
val hs = result.getHandshakeStatus
val next = rs match {
case OK
if (hs == NOT_HANDSHAKING) {
writePlaintextInboundBytes()
engineRunningPhase
} else handshakePhase(hs)
case CLOSED if (!engine.isOutboundDone) encryptionPhase else completedPhase
case BUFFER_OVERFLOW throw new IllegalStateException // the SslBufferPool should make sure that buffers are never too small
case BUFFER_UNDERFLOW throw new IllegalStateException // should never appear as a result of a wrap
}
nextPhase(next)
}
nextPhase(waitForHandshakeStartPhase)
override def preStart() {
val plainTextInput = inboundSessionOutput.asInstanceOf[Publisher[SslTlsCipher.InboundSession]]
val plainTextOutput = new SubstreamSubscriber[ByteString](self, outboundPlaintextInput.key)
val cipherTextInput = new SubstreamSubscriber[ByteString](self, inboundCipherTextInput.key)
val cipherTextOutput = outboundCipherTextOutput.asInstanceOf[Publisher[ByteString]]
requester ! SslTlsCipher(plainTextInput, plainTextOutput, cipherTextInput, cipherTextOutput)
}
override def receive = inputSubstreamManagement orElse outputSubstreamManagement
protected def fail(e: Throwable): Unit = {
// FIXME: escalate to supervisor
if (tracing) log.debug("fail {} due to: {}", self, e.getMessage)
failInputs(e)
failOutputs(e)
context.stop(self)
}
override protected def pumpFailed(e: Throwable): Unit = fail(e)
override protected def pumpFinished(): Unit = {
finishInputs()
finishOutputs()
}
}