+str #15833 TLS with session renegotiation
This commit is contained in:
parent
33919f683c
commit
616838a738
16 changed files with 1367 additions and 403 deletions
|
|
@ -82,10 +82,10 @@ class FlowErrorDocSpec extends AkkaSpec {
|
|||
val result = source.grouped(1000).runWith(Sink.head)
|
||||
// the negative element cause the scan stage to be restarted,
|
||||
// i.e. start from 0 again
|
||||
// result here will be a Future completed with Success(Vector(0, 1, 0, 5, 12))
|
||||
// result here will be a Future completed with Success(Vector(0, 1, 4, 0, 5, 12))
|
||||
//#restart-section
|
||||
|
||||
Await.result(result, remaining) should be(Vector(0, 1, 0, 5, 12))
|
||||
Await.result(result, remaining) should be(Vector(0, 1, 4, 0, 5, 12))
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -343,9 +343,11 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit {
|
|||
"resume when Scan throws" in new TestSetup(Seq(
|
||||
Scan(1, (acc: Int, x: Int) ⇒ if (x == 10) throw TE else acc + x, resumingDecider))) {
|
||||
downstream.requestOne()
|
||||
lastEvents() should be(Set(OnNext(1)))
|
||||
downstream.requestOne()
|
||||
lastEvents() should be(Set(RequestOne))
|
||||
upstream.onNext(2)
|
||||
lastEvents() should be(Set(OnNext(1)))
|
||||
lastEvents() should be(Set(OnNext(3)))
|
||||
|
||||
downstream.requestOne()
|
||||
lastEvents() should be(Set(RequestOne))
|
||||
|
|
@ -353,15 +355,17 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit {
|
|||
lastEvents() should be(Set(RequestOne))
|
||||
|
||||
upstream.onNext(4)
|
||||
lastEvents() should be(Set(OnNext(3))) // 1 + 2
|
||||
lastEvents() should be(Set(OnNext(7))) // 1 + 2 + 4
|
||||
}
|
||||
|
||||
"restart when Scan throws" in new TestSetup(Seq(
|
||||
Scan(1, (acc: Int, x: Int) ⇒ if (x == 10) throw TE else acc + x, restartingDecider))) {
|
||||
downstream.requestOne()
|
||||
lastEvents() should be(Set(OnNext(1)))
|
||||
downstream.requestOne()
|
||||
lastEvents() should be(Set(RequestOne))
|
||||
upstream.onNext(2)
|
||||
lastEvents() should be(Set(OnNext(1)))
|
||||
lastEvents() should be(Set(OnNext(3)))
|
||||
|
||||
downstream.requestOne()
|
||||
lastEvents() should be(Set(RequestOne))
|
||||
|
|
@ -371,10 +375,12 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit {
|
|||
upstream.onNext(4)
|
||||
lastEvents() should be(Set(OnNext(1))) // starts over again
|
||||
|
||||
downstream.requestOne()
|
||||
lastEvents() should be(Set(OnNext(5)))
|
||||
downstream.requestOne()
|
||||
lastEvents() should be(Set(RequestOne))
|
||||
upstream.onNext(20)
|
||||
lastEvents() should be(Set(OnNext(5))) // 1+4
|
||||
lastEvents() should be(Set(OnNext(25))) // 1 + 4 + 20
|
||||
}
|
||||
|
||||
"restart when Conflate `seed` throws" in new TestSetup(Seq(Conflate(
|
||||
|
|
|
|||
411
akka-stream-tests/src/test/scala/akka/stream/io/TlsSpec.scala
Normal file
411
akka-stream-tests/src/test/scala/akka/stream/io/TlsSpec.scala
Normal file
|
|
@ -0,0 +1,411 @@
|
|||
package akka.stream.io
|
||||
|
||||
import java.security.{ KeyStore, SecureRandom }
|
||||
import javax.net.ssl.{ TrustManagerFactory, KeyManagerFactory, SSLContext }
|
||||
import akka.stream.{ Graph, BidiShape, ActorFlowMaterializer }
|
||||
import akka.stream.scaladsl._
|
||||
import akka.stream.io._
|
||||
import akka.stream.testkit.{ TestUtils, AkkaSpec }
|
||||
import akka.util.ByteString
|
||||
import scala.concurrent.Await
|
||||
import scala.concurrent.duration._
|
||||
import scala.collection.immutable
|
||||
import scala.util.Random
|
||||
import akka.stream.stage.AsyncStage
|
||||
import akka.stream.stage.AsyncContext
|
||||
import java.util.concurrent.TimeoutException
|
||||
import akka.actor.ActorSystem
|
||||
import javax.net.ssl.SSLSession
|
||||
import akka.pattern.{ after ⇒ later }
|
||||
import scala.concurrent.Future
|
||||
import java.net.InetSocketAddress
|
||||
import akka.testkit.EventFilter
|
||||
import akka.stream.stage.PushStage
|
||||
import akka.stream.stage.Context
|
||||
|
||||
object TlsSpec {
|
||||
|
||||
val rnd = new Random
|
||||
|
||||
def initSslContext(): SSLContext = {
|
||||
|
||||
val password = "changeme"
|
||||
|
||||
val keyStore = KeyStore.getInstance(KeyStore.getDefaultType)
|
||||
keyStore.load(getClass.getResourceAsStream("/keystore"), password.toCharArray)
|
||||
|
||||
val trustStore = KeyStore.getInstance(KeyStore.getDefaultType)
|
||||
trustStore.load(getClass.getResourceAsStream("/truststore"), password.toCharArray)
|
||||
|
||||
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm)
|
||||
keyManagerFactory.init(keyStore, password.toCharArray)
|
||||
|
||||
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm)
|
||||
trustManagerFactory.init(trustStore)
|
||||
|
||||
val context = SSLContext.getInstance("TLS")
|
||||
context.init(keyManagerFactory.getKeyManagers, trustManagerFactory.getTrustManagers, new SecureRandom)
|
||||
context
|
||||
}
|
||||
|
||||
/**
|
||||
* This is a stage that fires a TimeoutException failure 2 seconds after it was started,
|
||||
* independent of the traffic going through. The purpose is to include the last seen
|
||||
* element in the exception message to help in figuring out what went wrong.
|
||||
*/
|
||||
class Timeout(duration: FiniteDuration)(implicit system: ActorSystem) extends AsyncStage[ByteString, ByteString, Unit] {
|
||||
private var last: ByteString = _
|
||||
|
||||
override def initAsyncInput(ctx: AsyncContext[ByteString, Unit]) = {
|
||||
val cb = ctx.getAsyncCallback()
|
||||
system.scheduler.scheduleOnce(duration)(cb.invoke(()))(system.dispatcher)
|
||||
}
|
||||
|
||||
override def onAsyncInput(u: Unit, ctx: AsyncContext[ByteString, Unit]) =
|
||||
ctx.fail(new TimeoutException(s"timeout expired, last element was $last"))
|
||||
|
||||
override def onPush(elem: ByteString, ctx: AsyncContext[ByteString, Unit]) = {
|
||||
last = elem
|
||||
if (ctx.isHoldingDownstream) ctx.pushAndPull(elem)
|
||||
else ctx.holdUpstream()
|
||||
}
|
||||
|
||||
override def onPull(ctx: AsyncContext[ByteString, Unit]) =
|
||||
if (ctx.isFinishing) ctx.pushAndFinish(last)
|
||||
else if (ctx.isHoldingUpstream) ctx.pushAndPull(last)
|
||||
else ctx.holdDownstream()
|
||||
|
||||
override def onUpstreamFinish(ctx: AsyncContext[ByteString, Unit]) =
|
||||
if (ctx.isHoldingUpstream) ctx.absorbTermination()
|
||||
else ctx.finish()
|
||||
|
||||
override def onDownstreamFinish(ctx: AsyncContext[ByteString, Unit]) = {
|
||||
system.log.debug("cancelled")
|
||||
ctx.finish()
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME #17226 replace by .dropWhile when implemented
|
||||
class DropWhile[T](p: T ⇒ Boolean) extends PushStage[T, T] {
|
||||
private var open = false
|
||||
override def onPush(elem: T, ctx: Context[T]) =
|
||||
if (open) ctx.push(elem)
|
||||
else if (p(elem)) ctx.pull()
|
||||
else {
|
||||
open = true
|
||||
ctx.push(elem)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
class TlsSpec extends AkkaSpec("akka.loglevel=INFO\nakka.actor.debug.receive=off") {
|
||||
import TlsSpec._
|
||||
|
||||
import system.dispatcher
|
||||
implicit val materializer = ActorFlowMaterializer()
|
||||
|
||||
import FlowGraph.Implicits._
|
||||
|
||||
"StreamTLS" must {
|
||||
|
||||
val sslContext = initSslContext()
|
||||
|
||||
val debug = Flow[SslTlsInbound].map { x ⇒
|
||||
x match {
|
||||
case SessionTruncated ⇒ system.log.debug(s" ----------- truncated ")
|
||||
case SessionBytes(_, b) ⇒ system.log.debug(s" ----------- (${b.size}) ${b.take(32).utf8String}")
|
||||
}
|
||||
x
|
||||
}
|
||||
|
||||
val cipherSuites = NegotiateNewSession.withCipherSuites("TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_128_CBC_SHA")
|
||||
def clientTls(closing: Closing) = SslTls(sslContext, cipherSuites, Client, closing)
|
||||
def serverTls(closing: Closing) = SslTls(sslContext, cipherSuites, Server, closing)
|
||||
|
||||
trait Named {
|
||||
def name: String =
|
||||
getClass.getName
|
||||
.reverse
|
||||
.dropWhile(c ⇒ "$0123456789".indexOf(c) != -1)
|
||||
.takeWhile(_ != '$')
|
||||
.reverse
|
||||
}
|
||||
|
||||
trait CommunicationSetup extends Named {
|
||||
def decorateFlow(leftClosing: Closing, rightClosing: Closing,
|
||||
rhs: Flow[SslTlsInbound, SslTlsOutbound, Any]): Flow[SslTlsOutbound, SslTlsInbound, Unit]
|
||||
def cleanup(): Unit = ()
|
||||
}
|
||||
|
||||
object ClientInitiates extends CommunicationSetup {
|
||||
def decorateFlow(leftClosing: Closing, rightClosing: Closing,
|
||||
rhs: Flow[SslTlsInbound, SslTlsOutbound, Any]) =
|
||||
clientTls(leftClosing) atop serverTls(rightClosing).reversed join rhs
|
||||
}
|
||||
|
||||
object ServerInitiates extends CommunicationSetup {
|
||||
def decorateFlow(leftClosing: Closing, rightClosing: Closing,
|
||||
rhs: Flow[SslTlsInbound, SslTlsOutbound, Any]) =
|
||||
serverTls(leftClosing) atop clientTls(rightClosing).reversed join rhs
|
||||
}
|
||||
|
||||
def server(flow: Flow[ByteString, ByteString, Any]) = {
|
||||
val server = StreamTcp()
|
||||
.bind(new InetSocketAddress("localhost", 0))
|
||||
.to(Sink.foreach(c ⇒ c.flow.join(flow).run()))
|
||||
.run()
|
||||
Await.result(server, 2.seconds)
|
||||
}
|
||||
|
||||
object ClientInitiatesViaTcp extends CommunicationSetup {
|
||||
var binding: StreamTcp.ServerBinding = null
|
||||
def decorateFlow(leftClosing: Closing, rightClosing: Closing,
|
||||
rhs: Flow[SslTlsInbound, SslTlsOutbound, Any]) = {
|
||||
binding = server(serverTls(rightClosing).reversed join rhs)
|
||||
clientTls(leftClosing) join StreamTcp().outgoingConnection(binding.localAddress)
|
||||
}
|
||||
override def cleanup(): Unit = binding.unbind()
|
||||
}
|
||||
|
||||
object ServerInitiatesViaTcp extends CommunicationSetup {
|
||||
var binding: StreamTcp.ServerBinding = null
|
||||
def decorateFlow(leftClosing: Closing, rightClosing: Closing,
|
||||
rhs: Flow[SslTlsInbound, SslTlsOutbound, Any]) = {
|
||||
binding = server(clientTls(rightClosing).reversed join rhs)
|
||||
serverTls(leftClosing) join StreamTcp().outgoingConnection(binding.localAddress)
|
||||
}
|
||||
override def cleanup(): Unit = binding.unbind()
|
||||
}
|
||||
|
||||
val communicationPatterns =
|
||||
Seq(
|
||||
ClientInitiates,
|
||||
ServerInitiates,
|
||||
ClientInitiatesViaTcp,
|
||||
ServerInitiatesViaTcp)
|
||||
|
||||
trait PayloadScenario extends Named {
|
||||
def flow: Flow[SslTlsInbound, SslTlsOutbound, Any] =
|
||||
Flow[SslTlsInbound]
|
||||
.map {
|
||||
var session: SSLSession = null
|
||||
def setSession(s: SSLSession) = {
|
||||
session = s
|
||||
system.log.debug(s"new session: $session (${session.getId mkString ","})")
|
||||
}
|
||||
|
||||
{
|
||||
case SessionTruncated ⇒ SendBytes(ByteString("TRUNCATED"))
|
||||
case SessionBytes(s, b) if session == null ⇒
|
||||
setSession(s)
|
||||
SendBytes(b)
|
||||
case SessionBytes(s, b) if s != session ⇒
|
||||
setSession(s)
|
||||
SendBytes(ByteString("NEWSESSION") ++ b)
|
||||
case SessionBytes(s, b) ⇒ SendBytes(b)
|
||||
}
|
||||
}
|
||||
def leftClosing: Closing = IgnoreComplete
|
||||
def rightClosing: Closing = IgnoreComplete
|
||||
|
||||
def inputs: immutable.Seq[SslTlsOutbound]
|
||||
def output: ByteString
|
||||
|
||||
protected def send(str: String) = SendBytes(ByteString(str))
|
||||
protected def send(ch: Char) = SendBytes(ByteString(ch.toByte))
|
||||
}
|
||||
|
||||
object SingleBytes extends PayloadScenario {
|
||||
val str = "0123456789"
|
||||
def inputs = str.map(ch ⇒ SendBytes(ByteString(ch.toByte)))
|
||||
def output = ByteString(str)
|
||||
}
|
||||
|
||||
object MediumMessages extends PayloadScenario {
|
||||
val strs = "0123456789" map (d ⇒ d.toString * (rnd.nextInt(9000) + 1000))
|
||||
def inputs = strs map (s ⇒ SendBytes(ByteString(s)))
|
||||
def output = ByteString((strs :\ "")(_ ++ _))
|
||||
}
|
||||
|
||||
object LargeMessages extends PayloadScenario {
|
||||
// TLS max packet size is 16384 bytes
|
||||
val strs = "0123456789" map (d ⇒ d.toString * (rnd.nextInt(9000) + 17000))
|
||||
def inputs = strs map (s ⇒ SendBytes(ByteString(s)))
|
||||
def output = ByteString((strs :\ "")(_ ++ _))
|
||||
}
|
||||
|
||||
object EmptyBytesFirst extends PayloadScenario {
|
||||
def inputs = List(ByteString.empty, ByteString("hello")).map(SendBytes)
|
||||
def output = ByteString("hello")
|
||||
}
|
||||
|
||||
object EmptyBytesInTheMiddle extends PayloadScenario {
|
||||
def inputs = List(ByteString("hello"), ByteString.empty, ByteString(" world")).map(SendBytes)
|
||||
def output = ByteString("hello world")
|
||||
}
|
||||
|
||||
object EmptyBytesLast extends PayloadScenario {
|
||||
def inputs = List(ByteString("hello"), ByteString.empty).map(SendBytes)
|
||||
def output = ByteString("hello")
|
||||
}
|
||||
|
||||
// this demonstrates that cancellation is ignored so that the five results make it back
|
||||
object CancellingRHS extends PayloadScenario {
|
||||
override def flow =
|
||||
Flow[SslTlsInbound]
|
||||
.mapConcat {
|
||||
case SessionTruncated ⇒ SessionTruncated :: Nil
|
||||
case SessionBytes(s, bytes) ⇒ bytes.map(b ⇒ SessionBytes(s, ByteString(b)))
|
||||
}
|
||||
.take(5)
|
||||
.mapAsync(5, x ⇒ later(500.millis, system.scheduler)(Future.successful(x)))
|
||||
.via(super.flow)
|
||||
override def rightClosing = IgnoreCancel
|
||||
|
||||
val str = "abcdef" * 100
|
||||
def inputs = str.map(send)
|
||||
def output = ByteString(str.take(5))
|
||||
}
|
||||
|
||||
object CancellingRHSIgnoresBoth extends PayloadScenario {
|
||||
override def flow =
|
||||
Flow[SslTlsInbound]
|
||||
.mapConcat {
|
||||
case SessionTruncated ⇒ SessionTruncated :: Nil
|
||||
case SessionBytes(s, bytes) ⇒ bytes.map(b ⇒ SessionBytes(s, ByteString(b)))
|
||||
}
|
||||
.take(5)
|
||||
.mapAsync(5, x ⇒ later(500.millis, system.scheduler)(Future.successful(x)))
|
||||
.via(super.flow)
|
||||
override def rightClosing = IgnoreBoth
|
||||
|
||||
val str = "abcdef" * 100
|
||||
def inputs = str.map(send)
|
||||
def output = ByteString(str.take(5))
|
||||
}
|
||||
|
||||
object LHSIgnoresBoth extends PayloadScenario {
|
||||
override def leftClosing = IgnoreBoth
|
||||
val str = "0123456789"
|
||||
def inputs = str.map(ch ⇒ SendBytes(ByteString(ch.toByte)))
|
||||
def output = ByteString(str)
|
||||
}
|
||||
|
||||
object BothSidesIgnoreBoth extends PayloadScenario {
|
||||
override def leftClosing = IgnoreBoth
|
||||
override def rightClosing = IgnoreBoth
|
||||
val str = "0123456789"
|
||||
def inputs = str.map(ch ⇒ SendBytes(ByteString(ch.toByte)))
|
||||
def output = ByteString(str)
|
||||
}
|
||||
|
||||
object SessionRenegotiationBySender extends PayloadScenario {
|
||||
def inputs = List(send("hello"), NegotiateNewSession, send("world"))
|
||||
def output = ByteString("helloNEWSESSIONworld")
|
||||
}
|
||||
|
||||
// difference is that the RHS engine will now receive the handshake while trying to send
|
||||
object SessionRenegotiationByReceiver extends PayloadScenario {
|
||||
val str = "abcdef" * 100
|
||||
def inputs = str.map(send) ++ Seq(NegotiateNewSession) ++ "hello world".map(send)
|
||||
def output = ByteString(str + "NEWSESSIONhello world")
|
||||
}
|
||||
|
||||
val logCipherSuite = Flow[SslTlsInbound]
|
||||
.map {
|
||||
var session: SSLSession = null
|
||||
def setSession(s: SSLSession) = {
|
||||
session = s
|
||||
system.log.debug(s"new session: $session (${session.getId mkString ","})")
|
||||
}
|
||||
|
||||
{
|
||||
case SessionTruncated ⇒ SendBytes(ByteString("TRUNCATED"))
|
||||
case SessionBytes(s, b) if s != session ⇒
|
||||
setSession(s)
|
||||
SendBytes(ByteString(s.getCipherSuite) ++ b)
|
||||
case SessionBytes(s, b) ⇒ SendBytes(b)
|
||||
}
|
||||
}
|
||||
|
||||
object SessionRenegotiationFirstOne extends PayloadScenario {
|
||||
override def flow = logCipherSuite
|
||||
def inputs = NegotiateNewSession.withCipherSuites("TLS_RSA_WITH_AES_128_CBC_SHA") :: send("hello") :: Nil
|
||||
def output = ByteString("TLS_RSA_WITH_AES_128_CBC_SHAhello")
|
||||
}
|
||||
|
||||
object SessionRenegotiationFirstTwo extends PayloadScenario {
|
||||
override def flow = logCipherSuite
|
||||
def inputs = NegotiateNewSession.withCipherSuites("TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA") :: send("hello") :: Nil
|
||||
def output = ByteString("TLS_ECDHE_RSA_WITH_AES_128_CBC_SHAhello")
|
||||
}
|
||||
|
||||
val scenarios =
|
||||
Seq(
|
||||
SingleBytes,
|
||||
MediumMessages,
|
||||
LargeMessages,
|
||||
EmptyBytesFirst,
|
||||
EmptyBytesInTheMiddle,
|
||||
EmptyBytesLast,
|
||||
CancellingRHS,
|
||||
SessionRenegotiationBySender,
|
||||
SessionRenegotiationByReceiver,
|
||||
SessionRenegotiationFirstOne,
|
||||
SessionRenegotiationFirstTwo)
|
||||
|
||||
for {
|
||||
commPattern ← communicationPatterns
|
||||
scenario ← scenarios
|
||||
} {
|
||||
s"work in mode ${commPattern.name} while sending ${scenario.name}" in {
|
||||
val onRHS = debug.via(scenario.flow)
|
||||
val f =
|
||||
Source(scenario.inputs)
|
||||
.via(commPattern.decorateFlow(scenario.leftClosing, scenario.rightClosing, onRHS))
|
||||
.transform(() ⇒ new PushStage[SslTlsInbound, SslTlsInbound] {
|
||||
override def onPush(elem: SslTlsInbound, ctx: Context[SslTlsInbound]) =
|
||||
ctx.push(elem)
|
||||
override def onDownstreamFinish(ctx: Context[SslTlsInbound]) = {
|
||||
system.log.debug("me cancelled")
|
||||
ctx.finish()
|
||||
}
|
||||
})
|
||||
.via(debug)
|
||||
.collect { case SessionBytes(_, b) ⇒ b }
|
||||
.scan(ByteString.empty)(_ ++ _)
|
||||
.transform(() ⇒ new Timeout(6.seconds))
|
||||
.transform(() ⇒ new DropWhile(_.size < scenario.output.size))
|
||||
.runWith(Sink.head)
|
||||
|
||||
Await.result(f, 8.seconds).utf8String should be(scenario.output.utf8String)
|
||||
|
||||
commPattern.cleanup()
|
||||
|
||||
// flush log so as to not mix up logs of different test cases
|
||||
if (log.isDebugEnabled)
|
||||
EventFilter.debug("stopgap", occurrences = 1) intercept {
|
||||
log.debug("stopgap")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
"A SslTlsPlacebo" must {
|
||||
|
||||
"pass through data" in {
|
||||
val f = Source(1 to 3)
|
||||
.map(b ⇒ SendBytes(ByteString(b.toByte)))
|
||||
.via(SslTlsPlacebo.forScala join Flow.apply)
|
||||
.grouped(10)
|
||||
.runWith(Sink.head)
|
||||
val result = Await.result(f, 3.seconds)
|
||||
result.map(_.bytes) should be((1 to 3).map(b ⇒ ByteString(b.toByte)))
|
||||
result.map(_.session).foreach(s ⇒ s.getCipherSuite should be("SSL_NULL_WITH_NULL_NULL"))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -6,12 +6,12 @@ package akka.stream.scaladsl
|
|||
import scala.concurrent.Await
|
||||
import scala.concurrent.duration._
|
||||
import scala.concurrent.forkjoin.ThreadLocalRandom.{ current ⇒ random }
|
||||
|
||||
import scala.collection.immutable
|
||||
|
||||
import akka.stream.ActorFlowMaterializer
|
||||
import akka.stream.ActorFlowMaterializerSettings
|
||||
import akka.stream.testkit.AkkaSpec
|
||||
import akka.stream.ActorOperationAttributes
|
||||
import akka.stream.Supervision
|
||||
|
||||
class FlowScanSpec extends AkkaSpec {
|
||||
|
||||
|
|
@ -39,5 +39,20 @@ class FlowScanSpec extends AkkaSpec {
|
|||
val v = Vector.empty[Int]
|
||||
scan(Source(v)) should be(v.scan(0)(_ + _))
|
||||
}
|
||||
|
||||
"emit values promptly" in {
|
||||
val f = Source.single(1).concat(Source.lazyEmpty).scan(0)(_ + _).grouped(2).runWith(Sink.head)
|
||||
Await.result(f, 1.second) should be(Seq(0, 1))
|
||||
}
|
||||
|
||||
"fail properly" in {
|
||||
import ActorOperationAttributes._
|
||||
val scan = Flow[Int].scan(0) { (old, current) ⇒
|
||||
require(current > 0)
|
||||
old + current
|
||||
}.withAttributes(supervisionStrategy(Supervision.restartingDecider))
|
||||
val f = Source(List(1, 3, -1, 5, 7)).via(scan).grouped(1000).runWith(Sink.head)
|
||||
Await.result(f, 1.second) should be(Seq(0, 1, 4, 0, 5, 12))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -213,6 +213,7 @@ final case class BidiShape[-In1, +Out1, -In2, +Out2](in1: Inlet[In1],
|
|||
require(outlets.size == 2, s"proposed outlets [${outlets.mkString(", ")}] do not fit BidiShape")
|
||||
BidiShape(inlets(0), outlets(0), inlets(1), outlets(1))
|
||||
}
|
||||
def reversed: Shape = copyFromPorts(inlets.reverse, outlets.reverse)
|
||||
//#implementation-details-elided
|
||||
}
|
||||
//#bidi-shape
|
||||
|
|
|
|||
|
|
@ -11,10 +11,15 @@ import akka.pattern.ask
|
|||
import akka.stream.actor.ActorSubscriber
|
||||
import akka.stream.impl.GenJunctions.ZipWithModule
|
||||
import akka.stream.impl.Junctions._
|
||||
import akka.stream.impl.MultiStreamInputProcessor.SubstreamSubscriber
|
||||
import akka.stream.impl.StreamLayout.Module
|
||||
import akka.stream.impl.fusing.ActorInterpreter
|
||||
import akka.stream.impl.io.SslTlsCipherActor
|
||||
import akka.stream.scaladsl._
|
||||
import akka.stream._
|
||||
import akka.stream.io._
|
||||
import akka.stream.io.SslTls.TlsModule
|
||||
import akka.util.ByteString
|
||||
import org.reactivestreams._
|
||||
|
||||
import scala.concurrent.{ Await, ExecutionContextExecutor }
|
||||
|
|
@ -83,6 +88,22 @@ private[akka] case class ActorFlowMaterializerImpl(override val settings: ActorF
|
|||
assignPort(stage.outPort, processor)
|
||||
mat
|
||||
|
||||
case tls: TlsModule ⇒
|
||||
val es = effectiveSettings(effectiveAttributes)
|
||||
val props = SslTlsCipherActor.props(es, tls.sslContext, tls.firstSession, tracing = true, tls.role, tls.closing)
|
||||
val impl = actorOf(props, stageName(effectiveAttributes), es.dispatcher)
|
||||
def factory(id: Int) = new ActorPublisher[Any](impl) {
|
||||
override val wakeUpMsg = FanOut.SubstreamSubscribePending(id)
|
||||
}
|
||||
val publishers = Vector.tabulate(2)(factory)
|
||||
impl ! FanOut.ExposedPublishers(publishers)
|
||||
|
||||
assignPort(tls.plainOut, publishers(SslTlsCipherActor.UserOut))
|
||||
assignPort(tls.cipherOut, publishers(SslTlsCipherActor.TransportOut))
|
||||
|
||||
assignPort(tls.plainIn, FanIn.SubInput[Any](impl, SslTlsCipherActor.UserIn))
|
||||
assignPort(tls.cipherIn, FanIn.SubInput[Any](impl, SslTlsCipherActor.TransportIn))
|
||||
|
||||
case junction: JunctionModule ⇒ materializeJunction(junction, effectiveAttributes, effectiveSettings(effectiveAttributes))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -155,6 +155,7 @@ private[akka] object FanIn {
|
|||
if (input.inputsDepleted) {
|
||||
if (marked(id)) markedDepleted += 1
|
||||
depleted(id) = true
|
||||
onDepleted(id)
|
||||
}
|
||||
elem
|
||||
}
|
||||
|
|
|
|||
|
|
@ -186,6 +186,16 @@ private[akka] object FanOut {
|
|||
|
||||
def onCancel(output: Int): Unit = ()
|
||||
|
||||
def demandAvailableFor(id: Int) = new TransferState {
|
||||
override def isCompleted: Boolean = cancelled(id) || completed(id)
|
||||
override def isReady: Boolean = pending(id)
|
||||
}
|
||||
|
||||
def demandOrCancelAvailableFor(id: Int) = new TransferState {
|
||||
override def isCompleted: Boolean = false
|
||||
override def isReady: Boolean = pending(id) || cancelled(id)
|
||||
}
|
||||
|
||||
/**
|
||||
* Will only transfer an element when all marked outputs
|
||||
* have demand, and will complete as soon as any of the marked
|
||||
|
|
|
|||
|
|
@ -120,22 +120,26 @@ private[akka] class FlexiMergeImpl[T, S <: Shape](
|
|||
nextPhase(TransferPhase(precondition) { () ⇒
|
||||
behavior.condition match {
|
||||
case read: ReadAny[t] ⇒
|
||||
suppressCompletion()
|
||||
val id = inputBunch.idToDequeue()
|
||||
val elem = inputBunch.dequeueAndYield(id)
|
||||
val inputHandle = inputMapping(id)
|
||||
callOnInput(inputHandle, elem)
|
||||
triggerCompletionAfterRead(inputHandle)
|
||||
case r: ReadPreferred[t] ⇒
|
||||
suppressCompletion()
|
||||
val elem = inputBunch.dequeuePrefering(indexOf(r.preferred))
|
||||
val id = inputBunch.lastDequeuedId
|
||||
val inputHandle = inputMapping(id)
|
||||
callOnInput(inputHandle, elem)
|
||||
triggerCompletionAfterRead(inputHandle)
|
||||
case Read(input) ⇒
|
||||
suppressCompletion()
|
||||
val elem = inputBunch.dequeue(indexOf(input))
|
||||
callOnInput(input, elem)
|
||||
triggerCompletionAfterRead(input)
|
||||
case read: ReadAll[t] ⇒
|
||||
suppressCompletion()
|
||||
val inputs = read.inputs
|
||||
val values = inputs.collect {
|
||||
case input if include(input) ⇒ input → inputBunch.dequeue(indexOf(input))
|
||||
|
|
@ -160,11 +164,18 @@ private[akka] class FlexiMergeImpl[T, S <: Shape](
|
|||
}
|
||||
}
|
||||
|
||||
private def triggerCompletionAfterRead(inputHandle: InPort): Unit =
|
||||
private var completionEnabled = true
|
||||
|
||||
private def suppressCompletion(): Unit = completionEnabled = false
|
||||
|
||||
private def triggerCompletionAfterRead(inputHandle: InPort): Unit = {
|
||||
completionEnabled = true
|
||||
if (inputBunch.isDepleted(indexOf(inputHandle)))
|
||||
triggerCompletion(inputHandle)
|
||||
}
|
||||
|
||||
private def triggerCompletion(in: InPort): Unit =
|
||||
if (completionEnabled)
|
||||
changeBehavior(
|
||||
try completion.onUpstreamFinish(ctx, in)
|
||||
catch {
|
||||
|
|
|
|||
|
|
@ -171,4 +171,3 @@ private[akka] trait Pump {
|
|||
protected def pumpFinished(): Unit
|
||||
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -298,7 +298,7 @@ private[akka] object ActorInterpreter {
|
|||
def props(settings: ActorFlowMaterializerSettings, ops: Seq[Stage[_, _]], materializer: ActorFlowMaterializer): Props =
|
||||
Props(new ActorInterpreter(settings, ops, materializer))
|
||||
|
||||
case class AsyncInput(op: AsyncStage[Any, Any, Any], ctx: AsyncContext[Any, Any], event: Any)
|
||||
case class AsyncInput(op: AsyncStage[Any, Any, Any], ctx: AsyncContext[Any, Any], event: Any) extends DeadLetterSuppression
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -108,18 +108,27 @@ private[akka] final case class Drop[T](count: Long) extends PushStage[T, T] {
|
|||
*/
|
||||
private[akka] final case class Scan[In, Out](zero: Out, f: (Out, In) ⇒ Out, decider: Supervision.Decider) extends PushPullStage[In, Out] {
|
||||
private var aggregator = zero
|
||||
private var pushedZero = false
|
||||
|
||||
override def onPush(elem: In, ctx: Context[Out]): SyncDirective = {
|
||||
val old = aggregator
|
||||
aggregator = f(old, elem)
|
||||
ctx.push(old)
|
||||
if (pushedZero) {
|
||||
aggregator = f(aggregator, elem)
|
||||
ctx.push(aggregator)
|
||||
} else {
|
||||
aggregator = f(zero, elem)
|
||||
ctx.push(zero)
|
||||
}
|
||||
}
|
||||
|
||||
override def onPull(ctx: Context[Out]): SyncDirective =
|
||||
if (ctx.isFinishing) ctx.pushAndFinish(aggregator)
|
||||
else ctx.pull()
|
||||
if (!pushedZero) {
|
||||
pushedZero = true
|
||||
if (ctx.isFinishing) ctx.pushAndFinish(aggregator) else ctx.push(aggregator)
|
||||
} else ctx.pull()
|
||||
|
||||
override def onUpstreamFinish(ctx: Context[Out]): TerminationDirective = ctx.absorbTermination()
|
||||
override def onUpstreamFinish(ctx: Context[Out]): TerminationDirective =
|
||||
if (pushedZero) ctx.finish()
|
||||
else ctx.absorbTermination()
|
||||
|
||||
override def decide(t: Throwable): Supervision.Directive = decider(t)
|
||||
|
||||
|
|
|
|||
449
akka-stream/src/main/scala/akka/stream/impl/io/SslTls.scala
Normal file
449
akka-stream/src/main/scala/akka/stream/impl/io/SslTls.scala
Normal file
|
|
@ -0,0 +1,449 @@
|
|||
/**
|
||||
* Copyright (C) 2015 Typesafe Inc. <http://www.typesafe.com>
|
||||
*/
|
||||
package akka.stream.impl.io
|
||||
|
||||
import java.nio.ByteBuffer
|
||||
import java.security.Principal
|
||||
import java.security.cert.Certificate
|
||||
import javax.net.ssl.SSLEngineResult.HandshakeStatus
|
||||
import javax.net.ssl.SSLEngineResult.HandshakeStatus._
|
||||
import javax.net.ssl.SSLEngineResult.Status._
|
||||
import javax.net.ssl._
|
||||
import akka.actor.{ Props, Actor, ActorLogging, ActorRef }
|
||||
import akka.stream.ActorFlowMaterializerSettings
|
||||
import akka.stream.impl.FanIn.InputBunch
|
||||
import akka.stream.impl.FanOut.OutputBunch
|
||||
import akka.stream.impl._
|
||||
import akka.util.ByteString
|
||||
import akka.util.ByteStringBuilder
|
||||
import org.reactivestreams.Publisher
|
||||
import org.reactivestreams.Subscriber
|
||||
import scala.annotation.tailrec
|
||||
import scala.collection.immutable
|
||||
import akka.stream.io._
|
||||
import akka.event.LoggingReceive
|
||||
|
||||
/**
|
||||
* INTERNAL API.
|
||||
*/
|
||||
private[akka] object SslTlsCipherActor {
|
||||
|
||||
def props(settings: ActorFlowMaterializerSettings,
|
||||
sslContext: SSLContext,
|
||||
firstSession: NegotiateNewSession,
|
||||
tracing: Boolean,
|
||||
role: Role,
|
||||
closing: Closing): Props =
|
||||
Props(new SslTlsCipherActor(settings, sslContext, firstSession, tracing, role, closing))
|
||||
|
||||
final val TransportIn = 0
|
||||
final val TransportOut = 0
|
||||
|
||||
final val UserOut = 1
|
||||
final val UserIn = 1
|
||||
}
|
||||
|
||||
/**
|
||||
* INTERNAL API.
|
||||
*/
|
||||
private[akka] class SslTlsCipherActor(settings: ActorFlowMaterializerSettings, sslContext: SSLContext,
|
||||
firstSession: NegotiateNewSession, tracing: Boolean,
|
||||
role: Role, closing: Closing)
|
||||
extends Actor with ActorLogging with Pump {
|
||||
|
||||
import SslTlsCipherActor._
|
||||
|
||||
protected val outputBunch = new OutputBunch(outputCount = 2, self, this)
|
||||
outputBunch.markAllOutputs()
|
||||
|
||||
protected val inputBunch = new InputBunch(inputCount = 2, settings.maxInputBufferSize, this) {
|
||||
override def onError(input: Int, e: Throwable): Unit = fail(e)
|
||||
}
|
||||
|
||||
/**
|
||||
* The SSLEngine needs bite-sized chunks of data but we get arbitrary ByteString
|
||||
* from both the UserIn and the TransportIn ports. This is used to chop up such
|
||||
* a ByteString by filling the respective ByteBuffer and taking care to dequeue
|
||||
* a new element when data are demanded and none are left lying on the chopping
|
||||
* block.
|
||||
*/
|
||||
class ChoppingBlock(idx: Int, name: String) extends TransferState {
|
||||
override def isReady: Boolean = buffer.nonEmpty
|
||||
override def isCompleted: Boolean = false
|
||||
|
||||
private var buffer = ByteString.empty
|
||||
|
||||
/**
|
||||
* Whether there are no bytes lying on this chopping block.
|
||||
*/
|
||||
def isEmpty: Boolean = buffer.isEmpty
|
||||
|
||||
/**
|
||||
* Pour as many bytes as are available either on the chopping block or in
|
||||
* the inputBunch’s next ByteString into the supplied ByteBuffer, which is
|
||||
* expected to be in “read left-overs” mode, i.e. everything between its
|
||||
* position and limit is retained. In order to allocate a fresh ByteBuffer
|
||||
* with these characteristics, use `prepare()`.
|
||||
*/
|
||||
def chopInto(b: ByteBuffer): Unit = {
|
||||
b.compact()
|
||||
if (buffer.isEmpty) {
|
||||
buffer = inputBunch.dequeue(idx) match {
|
||||
// this class handles both UserIn and TransportIn
|
||||
case bs: ByteString ⇒ bs
|
||||
case SendBytes(bs) ⇒ bs
|
||||
case n: NegotiateNewSession ⇒
|
||||
setNewSessionParameters(n)
|
||||
ByteString.empty
|
||||
}
|
||||
if (tracing) log.debug(s"chopping from new chunk of ${buffer.size} into $name (${b.position})")
|
||||
} else {
|
||||
if (tracing) log.debug(s"chopping from old chunk of ${buffer.size} into $name (${b.position})")
|
||||
}
|
||||
val copied = buffer.copyToBuffer(b)
|
||||
buffer = buffer.drop(copied)
|
||||
b.flip()
|
||||
}
|
||||
|
||||
/**
|
||||
* When potentially complete packet data are left after unwrap() we must
|
||||
* put them back onto the chopping block because otherwise the pump will
|
||||
* not know that we are runnable.
|
||||
*/
|
||||
def putBack(b: ByteBuffer): Unit =
|
||||
if (b.hasRemaining()) {
|
||||
if (tracing) log.debug(s"putting back ${b.remaining} bytes into $name")
|
||||
val bs = ByteString(b)
|
||||
if (bs.nonEmpty) buffer = bs ++ buffer
|
||||
prepare(b)
|
||||
}
|
||||
|
||||
/**
|
||||
* Prepare a fresh ByteBuffer for receiving a chop of data.
|
||||
*/
|
||||
def prepare(b: ByteBuffer): Unit = {
|
||||
b.clear()
|
||||
b.limit(0)
|
||||
}
|
||||
}
|
||||
|
||||
// These are Nettys default values
|
||||
// 16665 + 1024 (room for compressed data) + 1024 (for OpenJDK compatibility)
|
||||
val transportOutBuffer = ByteBuffer.allocate(16665 + 2048)
|
||||
/*
|
||||
* deviating here: chopping multiple input packets into this buffer can lead to
|
||||
* an OVERFLOW signal that also is an UNDERFLOW; avoid unnecessary copying by
|
||||
* increasing this buffer size to host up to two packets
|
||||
*/
|
||||
val userOutBuffer = ByteBuffer.allocate(16665 * 2 + 2048)
|
||||
val transportInBuffer = ByteBuffer.allocate(16665 + 2048)
|
||||
val userInBuffer = ByteBuffer.allocate(16665 + 2048)
|
||||
|
||||
val userInChoppingBlock = new ChoppingBlock(UserIn, "UserIn")
|
||||
userInChoppingBlock.prepare(userInBuffer)
|
||||
val transportInChoppingBlock = new ChoppingBlock(TransportIn, "TransportIn")
|
||||
transportInChoppingBlock.prepare(transportInBuffer)
|
||||
|
||||
val engine: SSLEngine = {
|
||||
val e = sslContext.createSSLEngine()
|
||||
e.setUseClientMode(role == Client)
|
||||
e
|
||||
}
|
||||
var currentSession = engine.getSession
|
||||
var currentSessionParameters = firstSession
|
||||
applySessionParameters()
|
||||
|
||||
def applySessionParameters(): Unit = {
|
||||
val csp = currentSessionParameters
|
||||
import csp._
|
||||
enabledCipherSuites foreach (cs ⇒ engine.setEnabledCipherSuites(cs.toArray))
|
||||
enabledProtocols foreach (p ⇒ engine.setEnabledProtocols(p.toArray))
|
||||
clientAuth match {
|
||||
case Some(ClientAuth.None) ⇒ engine.setNeedClientAuth(false)
|
||||
case Some(ClientAuth.Want) ⇒ engine.setWantClientAuth(true)
|
||||
case Some(ClientAuth.Need) ⇒ engine.setNeedClientAuth(true)
|
||||
case None ⇒ // do nothing
|
||||
}
|
||||
sslParameters foreach (p ⇒ engine.setSSLParameters(p))
|
||||
engine.beginHandshake()
|
||||
lastHandshakeStatus = engine.getHandshakeStatus
|
||||
}
|
||||
|
||||
def setNewSessionParameters(n: NegotiateNewSession): Unit = {
|
||||
if (tracing) log.debug(s"applying $n")
|
||||
currentSession.invalidate()
|
||||
currentSessionParameters = n
|
||||
applySessionParameters()
|
||||
corkUser = true
|
||||
}
|
||||
|
||||
/*
|
||||
* So here’s the big picture summary: the SSLEngine is the boss, and it can
|
||||
* be in several states. Depending on this state, we may want to react to
|
||||
* different input and output conditions.
|
||||
*
|
||||
* - normal bidirectional operation (does both outbound and inbound)
|
||||
* - outbound close initiated, inbound still open
|
||||
* - inbound close initiated, outbound still open
|
||||
* - fully closed
|
||||
*
|
||||
* Upon reaching the last state we obviously just shut down. In addition to
|
||||
* these user-data states, the engine may at any point in time also be
|
||||
* handshaking. This is mostly transparent, but it has an influence on the
|
||||
* outbound direction:
|
||||
*
|
||||
* - if the local user triggered a re-negotiation, cork all user data until
|
||||
* that is finished
|
||||
* - if the outbound direction has been closed, trigger outbound readiness
|
||||
* based upon HandshakeStatus.NEED_WRAP
|
||||
*
|
||||
* These conditions lead to the introduction of a synthetic TransferState
|
||||
* representing the Engine.
|
||||
*/
|
||||
|
||||
var lastHandshakeStatus: HandshakeStatus = _
|
||||
|
||||
val engineNeedsWrap = new TransferState {
|
||||
def isReady = lastHandshakeStatus == NEED_WRAP
|
||||
def isCompleted = false
|
||||
}
|
||||
|
||||
val engineInboundOpen = new TransferState {
|
||||
def isReady = !engine.isInboundDone()
|
||||
def isCompleted = false
|
||||
}
|
||||
|
||||
var corkUser = true
|
||||
|
||||
val userHasData = new TransferState {
|
||||
private val user = inputBunch.inputsOrCompleteAvailableFor(UserIn) || userInChoppingBlock
|
||||
def isReady = !corkUser && user.isReady && lastHandshakeStatus != NEED_UNWRAP
|
||||
def isCompleted = false
|
||||
}
|
||||
|
||||
val transportHasData = inputBunch.inputsOrCompleteAvailableFor(TransportIn) || transportInChoppingBlock
|
||||
val userOutCancelled = new TransferState {
|
||||
def isReady = outputBunch.isCancelled(UserOut)
|
||||
def isCompleted = inputBunch.isDepleted(TransportIn)
|
||||
}
|
||||
|
||||
// bidirectional case
|
||||
val outbound = (userHasData || engineNeedsWrap) && outputBunch.demandAvailableFor(TransportOut)
|
||||
val inbound = (transportHasData || userOutCancelled) && outputBunch.demandOrCancelAvailableFor(UserOut)
|
||||
|
||||
// half-closed
|
||||
val outboundHalfClosed = engineNeedsWrap && outputBunch.demandAvailableFor(TransportOut)
|
||||
val inboundHalfClosed = transportHasData && engineInboundOpen
|
||||
|
||||
def completeOrFlush(): Unit =
|
||||
if (engine.isOutboundDone()) nextPhase(completedPhase)
|
||||
else nextPhase(flushingOutbound)
|
||||
|
||||
private def doInbound(isOutboundClosed: Boolean, inboundState: TransferState): Boolean =
|
||||
if (inputBunch.isDepleted(TransportIn) && transportInChoppingBlock.isEmpty) {
|
||||
if (tracing) log.debug("closing inbound")
|
||||
try engine.closeInbound()
|
||||
catch { case ex: SSLException ⇒ outputBunch.enqueue(UserOut, SessionTruncated) }
|
||||
completeOrFlush()
|
||||
false
|
||||
} else if (inboundState != inboundHalfClosed && outputBunch.isCancelled(UserOut)) {
|
||||
if (!isOutboundClosed && closing.ignoreCancel) {
|
||||
if (tracing) log.debug("ignoring UserIn cancellation")
|
||||
nextPhase(inboundClosed)
|
||||
} else {
|
||||
if (tracing) log.debug("closing inbound due to UserOut cancellation")
|
||||
engine.closeOutbound() // this is the correct way of shutting down the engine
|
||||
lastHandshakeStatus = engine.getHandshakeStatus
|
||||
nextPhase(flushingOutbound)
|
||||
}
|
||||
true
|
||||
} else if (inboundState.isReady) {
|
||||
transportInChoppingBlock.chopInto(transportInBuffer)
|
||||
try {
|
||||
doUnwrap()
|
||||
true
|
||||
} catch {
|
||||
case ex: SSLException ⇒
|
||||
if (tracing) log.debug(s"SSLException during doUnwrap: $ex")
|
||||
completeOrFlush()
|
||||
false
|
||||
}
|
||||
} else true
|
||||
|
||||
private def doOutbound(isInboundClosed: Boolean): Unit =
|
||||
if (inputBunch.isDepleted(UserIn) && userInChoppingBlock.isEmpty) {
|
||||
if (!isInboundClosed && closing.ignoreComplete) {
|
||||
if (tracing) log.debug("ignoring closeOutbound")
|
||||
} else {
|
||||
if (tracing) log.debug("closing outbound directly")
|
||||
engine.closeOutbound()
|
||||
lastHandshakeStatus = engine.getHandshakeStatus
|
||||
}
|
||||
nextPhase(outboundClosed)
|
||||
} else if (outputBunch.isCancelled(TransportOut)) {
|
||||
nextPhase(completedPhase)
|
||||
} else if (outbound.isReady) {
|
||||
if (userHasData.isReady) userInChoppingBlock.chopInto(userInBuffer)
|
||||
try doWrap()
|
||||
catch {
|
||||
case ex: SSLException ⇒
|
||||
if (tracing) log.debug(s"SSLException during doWrap: $ex")
|
||||
completeOrFlush()
|
||||
}
|
||||
}
|
||||
|
||||
val bidirectional = TransferPhase(outbound || inbound) { () ⇒
|
||||
if (tracing) log.debug("bidirectional")
|
||||
val continue = doInbound(isOutboundClosed = false, inbound)
|
||||
if (continue) {
|
||||
if (tracing) log.debug("bidirectional continue")
|
||||
doOutbound(isInboundClosed = false)
|
||||
}
|
||||
}
|
||||
|
||||
val flushingOutbound = TransferPhase(outboundHalfClosed) { () ⇒
|
||||
if (tracing) log.debug("flushingOutbound")
|
||||
try doWrap()
|
||||
catch { case ex: SSLException ⇒ nextPhase(completedPhase) }
|
||||
}
|
||||
|
||||
val awaitingClose = TransferPhase(inputBunch.inputsAvailableFor(TransportIn)) { () ⇒
|
||||
if (tracing) log.debug("awaitingClose")
|
||||
transportInChoppingBlock.chopInto(transportInBuffer)
|
||||
try doUnwrap(ignoreOutput = true)
|
||||
catch { case ex: SSLException ⇒ nextPhase(completedPhase) }
|
||||
}
|
||||
|
||||
val outboundClosed = TransferPhase(outboundHalfClosed || inbound) { () ⇒
|
||||
if (tracing) log.debug("outboundClosed")
|
||||
val continue = doInbound(isOutboundClosed = true, inbound)
|
||||
if (continue && outboundHalfClosed.isReady) {
|
||||
if (tracing) log.debug("outboundClosed continue")
|
||||
try doWrap()
|
||||
catch { case ex: SSLException ⇒ nextPhase(completedPhase) }
|
||||
}
|
||||
}
|
||||
|
||||
val inboundClosed = TransferPhase(outbound || inboundHalfClosed) { () ⇒
|
||||
if (tracing) log.debug("inboundClosed")
|
||||
val continue = doInbound(isOutboundClosed = false, inboundHalfClosed)
|
||||
if (continue) {
|
||||
if (tracing) log.debug("inboundClosed continue")
|
||||
doOutbound(isInboundClosed = true)
|
||||
}
|
||||
}
|
||||
|
||||
def flushToTransport(): Unit = {
|
||||
if (tracing) log.debug("flushToTransport")
|
||||
transportOutBuffer.flip()
|
||||
if (transportOutBuffer.hasRemaining) {
|
||||
val bs = ByteString(transportOutBuffer)
|
||||
outputBunch.enqueue(TransportOut, bs)
|
||||
if (tracing) log.debug(s"sending ${bs.size} bytes")
|
||||
}
|
||||
transportOutBuffer.clear()
|
||||
}
|
||||
|
||||
def flushToUser(): Unit = {
|
||||
if (tracing) log.debug("flushToUser")
|
||||
userOutBuffer.flip()
|
||||
if (userOutBuffer.hasRemaining) {
|
||||
val bs = ByteString(userOutBuffer)
|
||||
outputBunch.enqueue(UserOut, SessionBytes(currentSession, bs))
|
||||
}
|
||||
userOutBuffer.clear()
|
||||
}
|
||||
|
||||
private def doWrap(): Unit = {
|
||||
val result = engine.wrap(userInBuffer, transportOutBuffer)
|
||||
lastHandshakeStatus = result.getHandshakeStatus
|
||||
if (tracing) log.debug(s"wrap: status=${result.getStatus} handshake=$lastHandshakeStatus remaining=${userInBuffer.remaining} out=${transportOutBuffer.position}")
|
||||
if (lastHandshakeStatus == FINISHED) handshakeFinished()
|
||||
runDelegatedTasks()
|
||||
result.getStatus match {
|
||||
case OK ⇒
|
||||
flushToTransport()
|
||||
userInChoppingBlock.putBack(userInBuffer)
|
||||
case CLOSED ⇒
|
||||
flushToTransport()
|
||||
if (engine.isInboundDone()) nextPhase(completedPhase)
|
||||
else nextPhase(awaitingClose)
|
||||
case s ⇒ fail(new IllegalStateException(s"unexpected status $s in doWrap()"))
|
||||
}
|
||||
}
|
||||
|
||||
@tailrec
|
||||
private def doUnwrap(ignoreOutput: Boolean = false): Unit = {
|
||||
val result = engine.unwrap(transportInBuffer, userOutBuffer)
|
||||
if (ignoreOutput) userOutBuffer.clear()
|
||||
lastHandshakeStatus = result.getHandshakeStatus
|
||||
if (tracing) log.debug(s"unwrap: status=${result.getStatus} handshake=$lastHandshakeStatus remaining=${transportInBuffer.remaining} out=${userOutBuffer.position}")
|
||||
runDelegatedTasks()
|
||||
result.getStatus match {
|
||||
case OK ⇒
|
||||
result.getHandshakeStatus match {
|
||||
case NEED_WRAP ⇒ flushToUser()
|
||||
case FINISHED ⇒
|
||||
flushToUser()
|
||||
handshakeFinished()
|
||||
transportInChoppingBlock.putBack(transportInBuffer)
|
||||
case _ ⇒
|
||||
if (transportInBuffer.hasRemaining()) doUnwrap()
|
||||
else flushToUser()
|
||||
}
|
||||
case CLOSED ⇒
|
||||
flushToUser()
|
||||
if (engine.isOutboundDone()) nextPhase(completedPhase)
|
||||
else nextPhase(flushingOutbound)
|
||||
case BUFFER_UNDERFLOW ⇒
|
||||
flushToUser()
|
||||
case BUFFER_OVERFLOW ⇒
|
||||
flushToUser()
|
||||
transportInChoppingBlock.putBack(transportInBuffer)
|
||||
case s ⇒ fail(new IllegalStateException(s"unexpected status $s in doUnwrap()"))
|
||||
}
|
||||
}
|
||||
|
||||
@tailrec
|
||||
private def runDelegatedTasks(): Unit = {
|
||||
val task = engine.getDelegatedTask
|
||||
if (task != null) {
|
||||
if (tracing) log.debug("running task")
|
||||
task.run()
|
||||
runDelegatedTasks()
|
||||
} else {
|
||||
val st = lastHandshakeStatus
|
||||
lastHandshakeStatus = engine.getHandshakeStatus
|
||||
if (tracing && st != lastHandshakeStatus) log.debug(s"handshake status after tasks: $lastHandshakeStatus")
|
||||
}
|
||||
}
|
||||
|
||||
private def handshakeFinished(): Unit = {
|
||||
if (tracing) log.debug("handshake finished")
|
||||
currentSession = engine.getSession
|
||||
corkUser = false
|
||||
}
|
||||
|
||||
override def receive = inputBunch.subreceive.orElse[Any, Unit](outputBunch.subreceive)
|
||||
|
||||
nextPhase(bidirectional)
|
||||
|
||||
protected def fail(e: Throwable): Unit = {
|
||||
// FIXME: escalate to supervisor
|
||||
if (tracing) log.debug("fail {} due to: {}", self, e.getMessage)
|
||||
inputBunch.cancel()
|
||||
outputBunch.error(TransportOut, e)
|
||||
outputBunch.error(UserOut, e)
|
||||
context.stop(self)
|
||||
}
|
||||
|
||||
override protected def pumpFailed(e: Throwable): Unit = fail(e)
|
||||
|
||||
override protected def pumpFinished(): Unit = {
|
||||
inputBunch.cancel()
|
||||
outputBunch.complete()
|
||||
if (tracing) log.debug(s"STOP Outbound Closed: ${engine.isOutboundDone} Inbound closed: ${engine.isInboundDone}")
|
||||
context.stop(self)
|
||||
}
|
||||
}
|
||||
411
akka-stream/src/main/scala/akka/stream/io/SslTls.scala
Normal file
411
akka-stream/src/main/scala/akka/stream/io/SslTls.scala
Normal file
|
|
@ -0,0 +1,411 @@
|
|||
/**
|
||||
* Copyright (C) 2015 Typesafe Inc. <http://www.typesafe.com>
|
||||
*/
|
||||
package akka.stream.io
|
||||
|
||||
import akka.stream._
|
||||
import akka.stream.impl.StreamLayout.Module
|
||||
import akka.util.ByteString
|
||||
import javax.net.ssl._
|
||||
import scala.annotation.varargs
|
||||
import scala.collection.immutable
|
||||
import java.security.cert.Certificate
|
||||
|
||||
/**
|
||||
* Stream cipher support based upon JSSE.
|
||||
*
|
||||
* The underlying SSLEngine has four ports: plaintext input/output and
|
||||
* ciphertext input/output. These are modeled as a [[akka.stream.BidiShape]]
|
||||
* element for use in stream topologies, where the plaintext ports are on the
|
||||
* left hand side of the shape and the ciphertext ports on the right hand side.
|
||||
*
|
||||
* Configuring JSSE is a rather complex topic, please refer to the JDK platform
|
||||
* documentation or the excellent user guide that is part of the Play Framework
|
||||
* documentation. The philosophy of this integration into Akka Streams is to
|
||||
* expose all knobs and dials to client code and therefore not limit the
|
||||
* configuration possibilities. In particular the client code will have to
|
||||
* provide the SSLContext from which the SSLEngine is then created. Handshake
|
||||
* parameters are set using [[NegotiateNewSession]] messages, the settings for
|
||||
* the initial handshake need to be provided up front using the same class;
|
||||
* please refer to the method documentation below.
|
||||
*
|
||||
* '''IMPORTANT NOTE'''
|
||||
*
|
||||
* The TLS specification does not permit half-closing of the user data session
|
||||
* that it transports—to be precise a half-close will always promptly lead to a
|
||||
* full close. This means that canceling the plaintext output or completing the
|
||||
* plaintext input of the SslTls stage will lead to full termination of the
|
||||
* secure connection without regard to whether bytes are remaining to be sent or
|
||||
* received, respectively. Especially for a client the common idiom of attaching
|
||||
* a finite Source to the plaintext input and transforming the plaintext response
|
||||
* bytes coming out will not work out of the box due to early termination of the
|
||||
* connection. For this reason there is a parameter that determines whether the
|
||||
* SslTls stage shall ignore completion and/or cancellation events, and the
|
||||
* default is to ignore completion (in view of the client–server scenario). In
|
||||
* order to terminate the connection the client will then need to cancel the
|
||||
* plaintext output as soon as all expected bytes have been received. When
|
||||
* ignoring both types of events the stage will shut down once both events have
|
||||
* been received. See also [[Closing]].
|
||||
*/
|
||||
object SslTls {
|
||||
|
||||
/**
|
||||
* Scala API: create a StreamTls [[akka.stream.scaladsl.BidiFlow]]. The
|
||||
* SSLContext will be used to create an SSLEngine to which then the
|
||||
* `firstSession` parameters are applied before initiating the first
|
||||
* handshake. The `role` parameter determines the SSLEngine’s role; this is
|
||||
* often the same as the underlying transport’s server or client role, but
|
||||
* that is not a requirement and depends entirely on the application
|
||||
* protocol.
|
||||
*
|
||||
* For a description of the `closing` parameter please refer to [[Closing]].
|
||||
*/
|
||||
def apply(sslContext: SSLContext, firstSession: NegotiateNewSession,
|
||||
role: Role, closing: Closing = IgnoreComplete): scaladsl.BidiFlow[SslTlsOutbound, ByteString, ByteString, SslTlsInbound, Unit] =
|
||||
new scaladsl.BidiFlow(TlsModule(OperationAttributes.none, sslContext, firstSession, role, closing))
|
||||
|
||||
/**
|
||||
* Java API: create a StreamTls [[akka.stream.javadsl.BidiFlow]] in client mode. The
|
||||
* SSLContext will be used to create an SSLEngine to which then the
|
||||
* `firstSession` parameters are applied before initiating the first
|
||||
* handshake. The `role` parameter determines the SSLEngine’s role; this is
|
||||
* often the same as the underlying transport’s server or client role, but
|
||||
* that is not a requirement and depends entirely on the application
|
||||
* protocol.
|
||||
*
|
||||
* This method uses the default closing behavior or [[IgnoreComplete]].
|
||||
*/
|
||||
def create(sslContext: SSLContext, firstSession: NegotiateNewSession, role: Role): javadsl.BidiFlow[SslTlsOutbound, ByteString, ByteString, SslTlsInbound, Unit] =
|
||||
new javadsl.BidiFlow(apply(sslContext, firstSession, role))
|
||||
|
||||
/**
|
||||
* Java API: create a StreamTls [[akka.stream.javadsl.BidiFlow]] in client mode. The
|
||||
* SSLContext will be used to create an SSLEngine to which then the
|
||||
* `firstSession` parameters are applied before initiating the first
|
||||
* handshake. The `role` parameter determines the SSLEngine’s role; this is
|
||||
* often the same as the underlying transport’s server or client role, but
|
||||
* that is not a requirement and depends entirely on the application
|
||||
* protocol.
|
||||
*
|
||||
* For a description of the `closing` parameter please refer to [[Closing]].
|
||||
*/
|
||||
def create(sslContext: SSLContext, firstSession: NegotiateNewSession, role: Role, closing: Closing): javadsl.BidiFlow[SslTlsOutbound, ByteString, ByteString, SslTlsInbound, Unit] =
|
||||
new javadsl.BidiFlow(apply(sslContext, firstSession, role, closing))
|
||||
|
||||
/**
|
||||
* INTERNAL API.
|
||||
*/
|
||||
private[akka] case class TlsModule(plainIn: Inlet[SslTlsOutbound], plainOut: Outlet[SslTlsInbound],
|
||||
cipherIn: Inlet[ByteString], cipherOut: Outlet[ByteString],
|
||||
shape: Shape, attributes: OperationAttributes,
|
||||
sslContext: SSLContext, firstSession: NegotiateNewSession,
|
||||
role: Role, closing: Closing) extends Module {
|
||||
override def subModules: Set[Module] = Set.empty
|
||||
|
||||
override def withAttributes(att: OperationAttributes): Module = copy(attributes = att)
|
||||
override def carbonCopy: Module = {
|
||||
val mod = TlsModule(attributes, sslContext, firstSession, role, closing)
|
||||
if (plainIn == shape.inlets(0)) mod
|
||||
else mod.replaceShape(mod.shape.asInstanceOf[BidiShape[_, _, _, _]].reversed)
|
||||
}
|
||||
|
||||
override def replaceShape(s: Shape) =
|
||||
if (s == shape) this
|
||||
else if (shape.hasSamePortsAs(s)) copy(shape = s)
|
||||
else throw new IllegalArgumentException("trying to replace shape with different ports")
|
||||
}
|
||||
|
||||
/**
|
||||
* INTERNAL API.
|
||||
*/
|
||||
private[akka] object TlsModule {
|
||||
def apply(attributes: OperationAttributes, sslContext: SSLContext, firstSession: NegotiateNewSession, role: Role, closing: Closing): TlsModule = {
|
||||
val name = attributes.nameOrDefault(s"StreamTls($role)")
|
||||
val cipherIn = new Inlet[ByteString](s"$name.cipherIn")
|
||||
val cipherOut = new Outlet[ByteString](s"$name.cipherOut")
|
||||
val plainIn = new Inlet[SslTlsOutbound](s"$name.transportIn")
|
||||
val plainOut = new Outlet[SslTlsInbound](s"$name.transportOut")
|
||||
val shape = new BidiShape(plainIn, cipherOut, cipherIn, plainOut)
|
||||
TlsModule(plainIn, plainOut, cipherIn, cipherOut, shape, attributes, sslContext, firstSession, role, closing)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* This object holds simple wrapping [[BidiFlow]] implementations that can
|
||||
* be used instead of [[SslTls]] when no encryption is desired. The flows will
|
||||
* just adapt the message protocol by wrapping into [[SessionBytes]] and
|
||||
* unwrapping [[SendBytes]].
|
||||
*/
|
||||
object SslTlsPlacebo {
|
||||
val forScala = scaladsl.BidiFlow() { implicit b ⇒
|
||||
// this constructs a session for (invalid) protocol SSL_NULL_WITH_NULL_NULL
|
||||
val session = SSLContext.getDefault.createSSLEngine.getSession
|
||||
val top = b.add(scaladsl.Flow[SslTlsOutbound].collect { case SendBytes(b) ⇒ b })
|
||||
val bottom = b.add(scaladsl.Flow[ByteString].map(SessionBytes(session, _)))
|
||||
BidiShape(top, bottom)
|
||||
}
|
||||
val forJava = new javadsl.BidiFlow(forScala)
|
||||
}
|
||||
|
||||
/**
|
||||
* Many protocols are asymmetric and distinguish between the client and the
|
||||
* server, where the latter listens passively for messages and the former
|
||||
* actively initiates the exchange.
|
||||
*/
|
||||
object Role {
|
||||
/**
|
||||
* Java API: obtain the [[Client]] singleton value.
|
||||
*/
|
||||
def client: Role = Client
|
||||
/**
|
||||
* Java API: obtain the [[Server]] singleton value.
|
||||
*/
|
||||
def server: Role = Server
|
||||
}
|
||||
sealed abstract class Role
|
||||
|
||||
/**
|
||||
* The client is usually the side that consumes the service provided by its
|
||||
* interlocutor. The precise interpretation of this role is protocol specific.
|
||||
*/
|
||||
sealed abstract class Client extends Role
|
||||
case object Client extends Client
|
||||
|
||||
/**
|
||||
* The server is usually the side the provides the service to its interlocutor.
|
||||
* The precise interpretation of this role is protocol specific.
|
||||
*/
|
||||
sealed abstract class Server extends Role
|
||||
case object Server extends Server
|
||||
|
||||
/**
|
||||
* All streams in Akka are unidirectional: while in a complex flow graph data
|
||||
* may flow in multiple directions these individual flows are independent from
|
||||
* each other. The difference between two half-duplex connections in opposite
|
||||
* directions and a full-duplex connection is that the underlying transport
|
||||
* is shared in the latter and tearing it down will end the data transfer in
|
||||
* both directions.
|
||||
*
|
||||
* When integrating a full-duplex transport medium that does not support
|
||||
* half-closing (which means ending one direction of data transfer without
|
||||
* ending the other) into a stream topology, there can be unexpected effects.
|
||||
* Feeding a finite Source into this medium will close the connection after
|
||||
* all elements have been sent, which means that possible replies may not
|
||||
* be received in full. To support this type of usage, the sending and
|
||||
* receiving of data on the same side (e.g. on the [[Client]]) need to be
|
||||
* coordinated such that it is known when all replies have been received.
|
||||
* Only then should the transport be shut down.
|
||||
*
|
||||
* To support these scenarios it is recommended that the full-duplex
|
||||
* transport integration is configurable in terms of termination handling,
|
||||
* which means that the user can optionally suppress the normal (closing)
|
||||
* reaction to completion or cancellation events, as is expressed by the
|
||||
* possible values of this type:
|
||||
*
|
||||
* - [[EagerClose]] means to not ignore signals
|
||||
* - [[IgnoreCancel]] means to not react to cancellation of the receiving
|
||||
* side unless the sending side has already completed
|
||||
* - [[IgnoreComplete]] means to not reacto the completion of the sending
|
||||
* side unless the receiving side has already cancelled
|
||||
* - [[IgnoreBoth]] means to ignore the first termination signal—be that
|
||||
* cancellation or completion—and only act upon the second one
|
||||
*/
|
||||
sealed abstract class Closing {
|
||||
def ignoreCancel: Boolean
|
||||
def ignoreComplete: Boolean
|
||||
}
|
||||
object Closing {
|
||||
/**
|
||||
* Java API: obtain the [[EagerClose]] singleton value.
|
||||
*/
|
||||
def eagerClose: Closing = EagerClose
|
||||
/**
|
||||
* Java API: obtain the [[IgnoreCancel]] singleton value.
|
||||
*/
|
||||
def ignoreCancel: Closing = IgnoreCancel
|
||||
/**
|
||||
* Java API: obtain the [[IgnoreComplete]] singleton value.
|
||||
*/
|
||||
def ignoreComplete: Closing = IgnoreComplete
|
||||
/**
|
||||
* Java API: obtain the [[IgnoreBoth]] singleton value.
|
||||
*/
|
||||
def ignoreBoth: Closing = IgnoreBoth
|
||||
}
|
||||
|
||||
/**
|
||||
* see [[Closing]]
|
||||
*/
|
||||
sealed abstract class EagerClose extends Closing {
|
||||
override def ignoreCancel = false
|
||||
override def ignoreComplete = false
|
||||
}
|
||||
case object EagerClose extends EagerClose
|
||||
|
||||
/**
|
||||
* see [[Closing]]
|
||||
*/
|
||||
sealed abstract class IgnoreCancel extends Closing {
|
||||
override def ignoreCancel = true
|
||||
override def ignoreComplete = false
|
||||
}
|
||||
case object IgnoreCancel extends IgnoreCancel
|
||||
|
||||
/**
|
||||
* see [[Closing]]
|
||||
*/
|
||||
sealed abstract class IgnoreComplete extends Closing {
|
||||
override def ignoreCancel = false
|
||||
override def ignoreComplete = true
|
||||
}
|
||||
case object IgnoreComplete extends IgnoreComplete
|
||||
|
||||
/**
|
||||
* see [[Closing]]
|
||||
*/
|
||||
sealed abstract class IgnoreBoth extends Closing {
|
||||
override def ignoreCancel = true
|
||||
override def ignoreComplete = true
|
||||
}
|
||||
case object IgnoreBoth extends IgnoreBoth
|
||||
|
||||
/**
|
||||
* This is the supertype of all messages that the SslTls stage emits on the
|
||||
* plaintext side.
|
||||
*/
|
||||
sealed trait SslTlsInbound
|
||||
|
||||
/**
|
||||
* If the underlying transport is closed before the final TLS closure command
|
||||
* is received from the peer then the SSLEngine will throw an SSLException that
|
||||
* warns about possible truncation attacks. This exception is caught and
|
||||
* translated into this message when encountered. Most of the time this occurs
|
||||
* not because of a malicious attacker but due to a connection abort or a
|
||||
* misbehaving communication peer.
|
||||
*/
|
||||
sealed abstract class SessionTruncated extends SslTlsInbound
|
||||
case object SessionTruncated extends SessionTruncated
|
||||
|
||||
/**
|
||||
* Plaintext bytes emitted by the SSLEngine are received over one specific
|
||||
* encryption session and this class bundles the bytes with the SSLSession
|
||||
* object. When the session changes due to renegotiation (which can be
|
||||
* initiated by either party) the new session value will not compare equal to
|
||||
* the previous one.
|
||||
*
|
||||
* The Java API for getting session information is given by the SSLSession object,
|
||||
* the Scala API adapters are offered below.
|
||||
*/
|
||||
case class SessionBytes(session: SSLSession, bytes: ByteString) extends SslTlsInbound {
|
||||
/**
|
||||
* Scala API: Extract the certificates that were actually used by this
|
||||
* engine during this session’s negotiation. The list is empty if no
|
||||
* certificates were used.
|
||||
*/
|
||||
def localCertificates: List[Certificate] = Option(session.getLocalCertificates).map(_.toList).getOrElse(Nil)
|
||||
/**
|
||||
* Scala API: Extract the Principal that was actually used by this engine
|
||||
* during this session’s negotiation.
|
||||
*/
|
||||
def localPrincipal = Option(session.getLocalPrincipal)
|
||||
/**
|
||||
* Scala API: Extract the certificates that were used by the peer engine
|
||||
* during this session’s negotiation. The list is empty if no certificates
|
||||
* were used.
|
||||
*/
|
||||
def peerCertificates =
|
||||
try Option(session.getPeerCertificates).map(_.toList).getOrElse(Nil)
|
||||
catch { case e: SSLPeerUnverifiedException ⇒ Nil }
|
||||
/**
|
||||
* Scala API: Extract the Principal that the peer engine presented during
|
||||
* this session’s negotiation.
|
||||
*/
|
||||
def peerPrincipal =
|
||||
try Option(session.getPeerPrincipal)
|
||||
catch { case e: SSLPeerUnverifiedException ⇒ None }
|
||||
}
|
||||
|
||||
/**
|
||||
* This is the supertype of all messages that the SslTls stage accepts on its
|
||||
* plaintext side.
|
||||
*/
|
||||
sealed trait SslTlsOutbound
|
||||
|
||||
/**
|
||||
* Initiate a new session negotiation. Any [[SendBytes]] commands following
|
||||
* this one will be held back (i.e. back-pressured) until the new handshake is
|
||||
* completed, meaning that the bytes following this message will be encrypted
|
||||
* according to the requirements outlined here.
|
||||
*
|
||||
* Each of the values in this message is optional and will have the following
|
||||
* effect if provided:
|
||||
*
|
||||
* - `enabledCipherSuites` will be passed to `SSLEngine::setEnabledCipherSuites()`
|
||||
* - `enabledProtocols` will be passed to `SSLEngine::setEnabledProtocols()`
|
||||
* - `clientAuth` will be passed to `SSLEngine::setWantClientAuth()` or `SSLEngine.setNeedClientAuth()`, respectively
|
||||
* - `sslParameters` will be passed to `SSLEngine::setSSLParameters()`
|
||||
*/
|
||||
case class NegotiateNewSession(
|
||||
enabledCipherSuites: Option[immutable.Seq[String]],
|
||||
enabledProtocols: Option[immutable.Seq[String]],
|
||||
clientAuth: Option[ClientAuth],
|
||||
sslParameters: Option[SSLParameters]) extends SslTlsOutbound {
|
||||
|
||||
/**
|
||||
* Java API: Make a copy of this message with the given `enabledCipherSuites`.
|
||||
*/
|
||||
@varargs
|
||||
def withCipherSuites(s: String*) = copy(enabledCipherSuites = Some(s.toList))
|
||||
|
||||
/**
|
||||
* Java API: Make a copy of this message with the given `enabledProtocols`.
|
||||
*/
|
||||
@varargs
|
||||
def withProtocols(p: String*) = copy(enabledProtocols = Some(p.toList))
|
||||
|
||||
/**
|
||||
* Java API: Make a copy of this message with the given [[ClientAuth]] setting.
|
||||
*/
|
||||
def withClientAuth(ca: ClientAuth) = copy(clientAuth = Some(ca))
|
||||
|
||||
/**
|
||||
* Java API: Make a copy of this message with the given [[SSLParameters]].
|
||||
*/
|
||||
def withParameters(p: SSLParameters) = copy(sslParameters = Some(p))
|
||||
}
|
||||
|
||||
object NegotiateNewSession extends NegotiateNewSession(None, None, None, None) {
|
||||
/**
|
||||
* Java API: obtain the default value (which will leave the SSLEngine’s
|
||||
* settings unchanged).
|
||||
*/
|
||||
def withDefaults = this
|
||||
}
|
||||
|
||||
/**
|
||||
* Send the given [[akka.util.ByteString]] across the encrypted session to the
|
||||
* peer.
|
||||
*/
|
||||
case class SendBytes(bytes: ByteString) extends SslTlsOutbound
|
||||
|
||||
/**
|
||||
* An SSLEngine can either demand, allow or ignore its peer’s authentication
|
||||
* (via certificates), where `Need` will fail the handshake if the peer does
|
||||
* not provide valid credentials, `Want` allows the peer to send credentials
|
||||
* and verifies them if provided, and `None` disables peer certificate
|
||||
* verification.
|
||||
*
|
||||
* See the documentation for `SSLEngine::setWantClientAuth` for more
|
||||
* information.
|
||||
*/
|
||||
sealed abstract class ClientAuth
|
||||
object ClientAuth {
|
||||
case object None extends ClientAuth
|
||||
case object Want extends ClientAuth
|
||||
case object Need extends ClientAuth
|
||||
|
||||
def none: ClientAuth = None
|
||||
def want: ClientAuth = Want
|
||||
def need: ClientAuth = Need
|
||||
}
|
||||
|
|
@ -115,11 +115,7 @@ final class BidiFlow[-I1, +O1, -I2, +O2, +Mat](private[stream] override val modu
|
|||
/**
|
||||
* Turn this BidiFlow around by 180 degrees, logically flipping it upside down in a protocol stack.
|
||||
*/
|
||||
def reversed: BidiFlow[I2, O2, I1, O1, Mat] = {
|
||||
val ins = shape.inlets
|
||||
val outs = shape.outlets
|
||||
new BidiFlow(module.replaceShape(shape.copyFromPorts(ins.reverse, outs.reverse)))
|
||||
}
|
||||
def reversed: BidiFlow[I2, O2, I1, O1, Mat] = new BidiFlow(module.replaceShape(shape.reversed))
|
||||
|
||||
override def withAttributes(attr: OperationAttributes): BidiFlow[I1, O1, I2, O2, Mat] =
|
||||
new BidiFlow(module.withAttributes(attr).wrap())
|
||||
|
|
|
|||
|
|
@ -1,376 +0,0 @@
|
|||
/**
|
||||
* Copyright (C) 2009-2014 Typesafe Inc. <http://www.typesafe.com>
|
||||
*/
|
||||
|
||||
package akka.stream.ssl
|
||||
|
||||
import java.nio.ByteBuffer
|
||||
import java.security.Principal
|
||||
import java.security.cert.Certificate
|
||||
import javax.net.ssl.SSLEngineResult.HandshakeStatus._
|
||||
import javax.net.ssl.SSLEngineResult.Status._
|
||||
import javax.net.ssl.SSLEngine
|
||||
import javax.net.ssl.SSLEngineResult
|
||||
import javax.net.ssl.SSLPeerUnverifiedException
|
||||
import javax.net.ssl.SSLSession
|
||||
|
||||
import akka.actor.Actor
|
||||
import akka.actor.ActorLogging
|
||||
import akka.actor.ActorRef
|
||||
import akka.stream.ActorFlowMaterializerSettings
|
||||
import akka.stream.impl._
|
||||
import akka.util.ByteString
|
||||
import akka.util.ByteStringBuilder
|
||||
import org.reactivestreams.Publisher
|
||||
import org.reactivestreams.Subscriber
|
||||
|
||||
import scala.annotation.tailrec
|
||||
|
||||
object SslTlsCipher {
|
||||
|
||||
/**
|
||||
* An established SSL session.
|
||||
*/
|
||||
final case class InboundSession(
|
||||
sessionInfo: SessionInfo,
|
||||
data: Publisher[ByteString])
|
||||
|
||||
/**
|
||||
* A request to establish an SSL session.
|
||||
* FIXME Not used right now since there is only one session established
|
||||
*/
|
||||
final case class OutboundSession(
|
||||
negotiation: SessionNegotiation,
|
||||
data: Subscriber[ByteString])
|
||||
|
||||
/**
|
||||
* Information about the established SSL session.
|
||||
*/
|
||||
final case class SessionInfo(
|
||||
cipherSuite: String,
|
||||
localCertificates: List[Certificate],
|
||||
localPrincipal: Option[Principal],
|
||||
peerCertificates: List[Certificate],
|
||||
peerPrincipal: Option[Principal])
|
||||
|
||||
object SessionInfo {
|
||||
|
||||
def apply(engine: SSLEngine): SessionInfo =
|
||||
apply(engine.getSession)
|
||||
|
||||
def apply(session: SSLSession): SessionInfo = {
|
||||
val localCertificates = Option(session.getLocalCertificates).map { _.toList } getOrElse Nil
|
||||
val localPrincipal = Option(session.getLocalPrincipal)
|
||||
val peerCertificates =
|
||||
try session.getPeerCertificates.toList
|
||||
catch { case e: SSLPeerUnverifiedException ⇒ Nil }
|
||||
val peerPrincipal =
|
||||
try Option(session.getPeerPrincipal)
|
||||
catch { case e: SSLPeerUnverifiedException ⇒ None }
|
||||
SessionInfo(session.getCipherSuite, localCertificates, localPrincipal, peerCertificates, peerPrincipal)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Information needed to establish an SSL session.
|
||||
*/
|
||||
final case class SessionNegotiation(engine: SSLEngine)
|
||||
}
|
||||
|
||||
final case class SslTlsCipher(
|
||||
sessionInbound: Publisher[SslTlsCipher.InboundSession],
|
||||
// FIXME We only have one session, and the SessionNegotiation is passed in via the constructor.
|
||||
// This should really be a Subscriber[SslTlsCipher.OutboundSession]
|
||||
plainTextOutbound: Subscriber[ByteString],
|
||||
cipherTextInbound: Subscriber[ByteString],
|
||||
cipherTextOutbound: Publisher[ByteString])
|
||||
|
||||
object SslTlsCipherActor {
|
||||
val EmptyByteArray = Array.empty[Byte]
|
||||
val EmptyByteBuffer = ByteBuffer.wrap(EmptyByteArray)
|
||||
}
|
||||
|
||||
class SslTlsCipherActor(val requester: ActorRef, val sessionNegotioation: SslTlsCipher.SessionNegotiation, tracing: Boolean)
|
||||
extends Actor
|
||||
with ActorLogging
|
||||
with Pump
|
||||
with MultiStreamOutputProcessorLike
|
||||
with MultiStreamInputProcessorLike {
|
||||
|
||||
override val subscriptionTimeoutSettings = ActorFlowMaterializerSettings(context.system).subscriptionTimeoutSettings
|
||||
|
||||
def this(requester: ActorRef, sessionNegotioation: SslTlsCipher.SessionNegotiation) =
|
||||
this(requester, sessionNegotioation, false)
|
||||
|
||||
import MultiStreamInputProcessor.SubstreamSubscriber
|
||||
import SslTlsCipherActor._
|
||||
|
||||
private var _nextId = 0L
|
||||
protected def nextId(): Long = { _nextId += 1; _nextId }
|
||||
override protected val inputBufferSize = 1
|
||||
|
||||
// The cipherTextInput (Subscriber[ByteString])
|
||||
val inboundCipherTextInput = createSubstreamInput()
|
||||
|
||||
// The cipherTextOutput (Publisher[ByteString])
|
||||
val outboundCipherTextOutput = createSubstreamOutput()
|
||||
|
||||
// The Publisher[SslTlsCipher.InboundSession]
|
||||
// FIXME For now there is only one session ever exposed
|
||||
val inboundSessionOutput = createSubstreamOutput()
|
||||
|
||||
// The read side for the user (Publisher[ByteString])
|
||||
// FIXME For now there is only one session ever exposed
|
||||
val inboundPlaintextOutput = createSubstreamOutput()
|
||||
|
||||
// The write side for the user (Subscriber[ByteString])
|
||||
// FIXME For now there is only one session ever exposed
|
||||
val outboundPlaintextInput = createSubstreamInput()
|
||||
|
||||
// Plaintext bytes to be encrypted
|
||||
var plaintextOutboundBytes = EmptyByteBuffer
|
||||
val plaintextOutboundBytesPending = new TransferState {
|
||||
override def isReady = plaintextOutboundBytes.hasRemaining
|
||||
override def isCompleted = false
|
||||
}
|
||||
|
||||
// Encrypted bytes to be sent
|
||||
val cipherTextOutboundBytes = new ByteStringBuilder
|
||||
|
||||
// Encrypted bytes to be decrypted
|
||||
var cipherTextInboundBytes = EmptyByteBuffer
|
||||
val cipherTextInboundBytesPending = new TransferState {
|
||||
override def isReady = cipherTextInboundBytes.hasRemaining
|
||||
override def isCompleted = false
|
||||
}
|
||||
|
||||
// Plaintext bytes to be received
|
||||
val plaintextInboundBytes = new ByteStringBuilder
|
||||
|
||||
// FIXME: Change this into a pool of ByteBuffer later
|
||||
// These are Nettys default values
|
||||
// 16665 + 1024 (room for compressed data) + 1024 (for OpenJDK compatibility)
|
||||
val temporaryBuffer = ByteBuffer.allocate(16665 + 2048)
|
||||
|
||||
val engine: SSLEngine = sessionNegotioation.engine
|
||||
|
||||
def doWrap(tempBuf: ByteBuffer): SSLEngineResult = {
|
||||
tempBuf.clear()
|
||||
if (tracing) log.debug("before wrap {}", plaintextOutboundBytes.remaining)
|
||||
val result = engine.wrap(plaintextOutboundBytes, tempBuf)
|
||||
if (tracing) log.debug("after wrap {}", plaintextOutboundBytes.remaining)
|
||||
tempBuf.flip()
|
||||
if (tempBuf.hasRemaining) {
|
||||
val bs = ByteString(tempBuf)
|
||||
if (tracing) log.debug("wrap Enqueue cipher bytes {}", bs)
|
||||
cipherTextOutboundBytes ++= bs
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
def doUnwrap(tempBuf: ByteBuffer): SSLEngineResult = {
|
||||
tempBuf.clear()
|
||||
if (tracing) log.debug("before unwrap {}", cipherTextInboundBytes.remaining)
|
||||
val result = engine.unwrap(cipherTextInboundBytes, tempBuf)
|
||||
if (tracing) log.debug("after unwrap {}", cipherTextInboundBytes.remaining)
|
||||
tempBuf.flip()
|
||||
if (tempBuf.hasRemaining) {
|
||||
val bs = ByteString(tempBuf)
|
||||
if (tracing) log.debug("unwrap Enqueue cipher bytes {}", bs)
|
||||
plaintextInboundBytes ++= bs
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
def enqueueCipherInputBytes(data: ByteString): Unit = {
|
||||
cipherTextInboundBytes =
|
||||
if (cipherTextInboundBytes.hasRemaining) {
|
||||
val buffer = ByteBuffer.allocate(cipherTextInboundBytes.remaining + data.size)
|
||||
buffer.put(cipherTextInboundBytes)
|
||||
data.copyToBuffer(buffer)
|
||||
buffer.flip()
|
||||
buffer
|
||||
} else data.toByteBuffer
|
||||
}
|
||||
|
||||
def writeCipherTextOutboundBytes() = {
|
||||
if (cipherTextOutboundBytes.length > 0) {
|
||||
val bs = cipherTextOutboundBytes.result()
|
||||
cipherTextOutboundBytes.clear()
|
||||
outboundCipherTextOutput.enqueueOutputElement(bs)
|
||||
}
|
||||
}
|
||||
|
||||
def writePlaintextInboundBytes() = {
|
||||
if (plaintextInboundBytes.length > 0) {
|
||||
val bs = plaintextInboundBytes.result()
|
||||
plaintextInboundBytes.clear()
|
||||
inboundPlaintextOutput.enqueueOutputElement(bs)
|
||||
}
|
||||
}
|
||||
|
||||
@tailrec
|
||||
private def runDelegatedTasks(): Unit = {
|
||||
val task = engine.getDelegatedTask
|
||||
if (task != null) {
|
||||
if (tracing) log.debug("Running delegated task {}", task)
|
||||
task.run()
|
||||
runDelegatedTasks()
|
||||
}
|
||||
}
|
||||
|
||||
def publishSSLSessionEstablished(): Unit = {
|
||||
import SslTlsCipher._
|
||||
val info = SessionInfo(engine)
|
||||
val is = InboundSession(info, inboundPlaintextOutput.asInstanceOf[Publisher[ByteString]])
|
||||
if (tracing) log.debug("#### Handshake done!")
|
||||
inboundSessionOutput.enqueueOutputElement(is)
|
||||
}
|
||||
|
||||
val unwrapPhase: TransferPhase = TransferPhase(inboundCipherTextInput.NeedsInput || cipherTextInboundBytesPending) { () ⇒
|
||||
if (tracing) log.debug("### UNWRAP")
|
||||
if (inboundCipherTextInput.NeedsInput.isReady)
|
||||
enqueueCipherInputBytes(inboundCipherTextInput.dequeueInputElement().asInstanceOf[ByteString])
|
||||
val result = doUnwrap(temporaryBuffer)
|
||||
val rs = result.getStatus
|
||||
if (tracing) log.debug("## UNWRAP {}", rs)
|
||||
val hs = result.getHandshakeStatus
|
||||
val next = rs match {
|
||||
case OK ⇒
|
||||
handshakePhase(hs)
|
||||
case CLOSED ⇒ if (!engine.isInboundDone) encryptionPhase else completedPhase
|
||||
case BUFFER_OVERFLOW ⇒ throw new IllegalStateException // the SslBufferPool should make sure that buffers are never too small
|
||||
case BUFFER_UNDERFLOW ⇒ throw new IllegalStateException // should never appear as a result of a wrap
|
||||
}
|
||||
nextPhase(next)
|
||||
}
|
||||
|
||||
val wrapPhase: TransferPhase = TransferPhase(outboundCipherTextOutput.NeedsDemand) { () ⇒
|
||||
if (tracing) log.debug("### WRAP")
|
||||
val result = doWrap(temporaryBuffer)
|
||||
val rs = result.getStatus
|
||||
if (tracing) log.debug("## WRAP {}", rs)
|
||||
val hs = result.getHandshakeStatus
|
||||
val next = rs match {
|
||||
case OK ⇒
|
||||
writeCipherTextOutboundBytes()
|
||||
handshakePhase(hs)
|
||||
case CLOSED ⇒ if (!engine.isInboundDone) decryptionPhase else completedPhase
|
||||
case BUFFER_OVERFLOW ⇒ throw new IllegalStateException // the SslBufferPool should make sure that buffers are never too small
|
||||
case BUFFER_UNDERFLOW ⇒ throw new IllegalStateException // should never appear as a result of a wrap
|
||||
}
|
||||
nextPhase(next)
|
||||
}
|
||||
|
||||
def handshakePhase(hs: SSLEngineResult.HandshakeStatus): TransferPhase = {
|
||||
if (tracing) log.debug("### HS {}", hs)
|
||||
hs match {
|
||||
case status @ (NOT_HANDSHAKING | FINISHED) ⇒
|
||||
if (status == FINISHED) publishSSLSessionEstablished()
|
||||
engineRunningPhase
|
||||
case NEED_WRAP ⇒ wrapPhase
|
||||
case NEED_UNWRAP ⇒ unwrapPhase
|
||||
case NEED_TASK ⇒
|
||||
runDelegatedTasks()
|
||||
engine.getHandshakeStatus match {
|
||||
case NEED_WRAP ⇒ wrapPhase
|
||||
case NEED_UNWRAP ⇒ unwrapPhase
|
||||
case x ⇒ throw new IllegalStateException(s"Bad Handshake status $x")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val waitForHandshakeStartPhase: TransferPhase = TransferPhase((outboundPlaintextInput.NeedsInput || inboundCipherTextInput.NeedsInput) && inboundSessionOutput.NeedsDemand) { () ⇒
|
||||
if (tracing) log.debug("#### Starting Handshake")
|
||||
engine.beginHandshake()
|
||||
nextPhase(handshakePhase(engine.getHandshakeStatus))
|
||||
}
|
||||
|
||||
val encryptionInputAvailable = outboundPlaintextInput.NeedsInput || plaintextOutboundBytesPending
|
||||
val decryptionInputAvailable = inboundCipherTextInput.NeedsInput || cipherTextInboundBytesPending
|
||||
|
||||
val canEncrypt = encryptionInputAvailable && outboundCipherTextOutput.NeedsDemand
|
||||
val canDecrypt = decryptionInputAvailable && inboundPlaintextOutput.NeedsDemand
|
||||
|
||||
val engineRunningPhase: TransferPhase = TransferPhase(canEncrypt || canDecrypt) { () ⇒
|
||||
if (tracing) log.debug("#### Engine running")
|
||||
if (canEncrypt.isExecutable) {
|
||||
nextPhase(encryptionPhase)
|
||||
} else {
|
||||
nextPhase(decryptionPhase)
|
||||
}
|
||||
}
|
||||
|
||||
val encryptionPhase: TransferPhase = TransferPhase(canEncrypt) { () ⇒
|
||||
if (tracing) log.debug("### Encrypting")
|
||||
if (!plaintextOutboundBytesPending.isReady && outboundPlaintextInput.inputsAvailable) {
|
||||
val elem = outboundPlaintextInput.dequeueInputElement().asInstanceOf[ByteString]
|
||||
plaintextOutboundBytes = elem.asByteBuffer
|
||||
}
|
||||
val result = doWrap(temporaryBuffer)
|
||||
val rs = result.getStatus
|
||||
if (tracing) log.debug("## Encrypting {}", rs)
|
||||
val hs = result.getHandshakeStatus
|
||||
val next = rs match {
|
||||
case OK ⇒
|
||||
if (hs == NOT_HANDSHAKING) {
|
||||
writeCipherTextOutboundBytes()
|
||||
engineRunningPhase
|
||||
} else handshakePhase(hs)
|
||||
case CLOSED ⇒ if (!engine.isInboundDone) decryptionPhase else completedPhase
|
||||
case BUFFER_OVERFLOW ⇒ throw new IllegalStateException // the SslBufferPool should make sure that buffers are never too small
|
||||
case BUFFER_UNDERFLOW ⇒ throw new IllegalStateException // should never appear as a result of a wrap
|
||||
}
|
||||
nextPhase(next)
|
||||
}
|
||||
|
||||
val decryptionPhase: TransferPhase = TransferPhase(canDecrypt) { () ⇒
|
||||
if (tracing) log.debug("### Decrypting")
|
||||
if (inboundCipherTextInput.NeedsInput.isReady) {
|
||||
val elem = inboundCipherTextInput.dequeueInputElement().asInstanceOf[ByteString]
|
||||
enqueueCipherInputBytes(elem)
|
||||
}
|
||||
val result = doUnwrap(temporaryBuffer)
|
||||
val rs = result.getStatus
|
||||
if (tracing) log.debug("## Decrypting {}", rs)
|
||||
val hs = result.getHandshakeStatus
|
||||
val next = rs match {
|
||||
case OK ⇒
|
||||
if (hs == NOT_HANDSHAKING) {
|
||||
writePlaintextInboundBytes()
|
||||
engineRunningPhase
|
||||
} else handshakePhase(hs)
|
||||
case CLOSED ⇒ if (!engine.isOutboundDone) encryptionPhase else completedPhase
|
||||
case BUFFER_OVERFLOW ⇒ throw new IllegalStateException // the SslBufferPool should make sure that buffers are never too small
|
||||
case BUFFER_UNDERFLOW ⇒ throw new IllegalStateException // should never appear as a result of a wrap
|
||||
}
|
||||
nextPhase(next)
|
||||
}
|
||||
|
||||
nextPhase(waitForHandshakeStartPhase)
|
||||
|
||||
override def preStart() {
|
||||
val plainTextInput = inboundSessionOutput.asInstanceOf[Publisher[SslTlsCipher.InboundSession]]
|
||||
val plainTextOutput = new SubstreamSubscriber[ByteString](self, outboundPlaintextInput.key)
|
||||
val cipherTextInput = new SubstreamSubscriber[ByteString](self, inboundCipherTextInput.key)
|
||||
val cipherTextOutput = outboundCipherTextOutput.asInstanceOf[Publisher[ByteString]]
|
||||
requester ! SslTlsCipher(plainTextInput, plainTextOutput, cipherTextInput, cipherTextOutput)
|
||||
}
|
||||
|
||||
override def receive = inputSubstreamManagement orElse outputSubstreamManagement
|
||||
|
||||
protected def fail(e: Throwable): Unit = {
|
||||
// FIXME: escalate to supervisor
|
||||
if (tracing) log.debug("fail {} due to: {}", self, e.getMessage)
|
||||
failInputs(e)
|
||||
failOutputs(e)
|
||||
context.stop(self)
|
||||
}
|
||||
|
||||
override protected def pumpFailed(e: Throwable): Unit = fail(e)
|
||||
|
||||
override protected def pumpFinished(): Unit = {
|
||||
finishInputs()
|
||||
finishOutputs()
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue