diff --git a/akka-docs-dev/rst/scala/code/docs/stream/FlowErrorDocSpec.scala b/akka-docs-dev/rst/scala/code/docs/stream/FlowErrorDocSpec.scala index 9502926710..0b3b77327a 100644 --- a/akka-docs-dev/rst/scala/code/docs/stream/FlowErrorDocSpec.scala +++ b/akka-docs-dev/rst/scala/code/docs/stream/FlowErrorDocSpec.scala @@ -82,10 +82,10 @@ class FlowErrorDocSpec extends AkkaSpec { val result = source.grouped(1000).runWith(Sink.head) // the negative element cause the scan stage to be restarted, // i.e. start from 0 again - // result here will be a Future completed with Success(Vector(0, 1, 0, 5, 12)) + // result here will be a Future completed with Success(Vector(0, 1, 4, 0, 5, 12)) //#restart-section - Await.result(result, remaining) should be(Vector(0, 1, 0, 5, 12)) + Await.result(result, remaining) should be(Vector(0, 1, 4, 0, 5, 12)) } } diff --git a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala index 14061ba7b6..2ce04baa81 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/impl/fusing/InterpreterSupervisionSpec.scala @@ -343,9 +343,11 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { "resume when Scan throws" in new TestSetup(Seq( Scan(1, (acc: Int, x: Int) ⇒ if (x == 10) throw TE else acc + x, resumingDecider))) { downstream.requestOne() + lastEvents() should be(Set(OnNext(1))) + downstream.requestOne() lastEvents() should be(Set(RequestOne)) upstream.onNext(2) - lastEvents() should be(Set(OnNext(1))) + lastEvents() should be(Set(OnNext(3))) downstream.requestOne() lastEvents() should be(Set(RequestOne)) @@ -353,15 +355,17 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { lastEvents() should be(Set(RequestOne)) upstream.onNext(4) - lastEvents() should be(Set(OnNext(3))) // 1 + 2 + lastEvents() should be(Set(OnNext(7))) // 1 + 2 + 4 } "restart when Scan throws" in new TestSetup(Seq( Scan(1, (acc: Int, x: Int) ⇒ if (x == 10) throw TE else acc + x, restartingDecider))) { downstream.requestOne() + lastEvents() should be(Set(OnNext(1))) + downstream.requestOne() lastEvents() should be(Set(RequestOne)) upstream.onNext(2) - lastEvents() should be(Set(OnNext(1))) + lastEvents() should be(Set(OnNext(3))) downstream.requestOne() lastEvents() should be(Set(RequestOne)) @@ -371,10 +375,12 @@ class InterpreterSupervisionSpec extends InterpreterSpecKit { upstream.onNext(4) lastEvents() should be(Set(OnNext(1))) // starts over again + downstream.requestOne() + lastEvents() should be(Set(OnNext(5))) downstream.requestOne() lastEvents() should be(Set(RequestOne)) upstream.onNext(20) - lastEvents() should be(Set(OnNext(5))) // 1+4 + lastEvents() should be(Set(OnNext(25))) // 1 + 4 + 20 } "restart when Conflate `seed` throws" in new TestSetup(Seq(Conflate( diff --git a/akka-stream-tests/src/test/scala/akka/stream/io/TlsSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/io/TlsSpec.scala new file mode 100644 index 0000000000..dbf077b96d --- /dev/null +++ b/akka-stream-tests/src/test/scala/akka/stream/io/TlsSpec.scala @@ -0,0 +1,411 @@ +package akka.stream.io + +import java.security.{ KeyStore, SecureRandom } +import javax.net.ssl.{ TrustManagerFactory, KeyManagerFactory, SSLContext } +import akka.stream.{ Graph, BidiShape, ActorFlowMaterializer } +import akka.stream.scaladsl._ +import akka.stream.io._ +import akka.stream.testkit.{ TestUtils, AkkaSpec } +import akka.util.ByteString +import scala.concurrent.Await +import scala.concurrent.duration._ +import scala.collection.immutable +import scala.util.Random +import akka.stream.stage.AsyncStage +import akka.stream.stage.AsyncContext +import java.util.concurrent.TimeoutException +import akka.actor.ActorSystem +import javax.net.ssl.SSLSession +import akka.pattern.{ after ⇒ later } +import scala.concurrent.Future +import java.net.InetSocketAddress +import akka.testkit.EventFilter +import akka.stream.stage.PushStage +import akka.stream.stage.Context + +object TlsSpec { + + val rnd = new Random + + def initSslContext(): SSLContext = { + + val password = "changeme" + + val keyStore = KeyStore.getInstance(KeyStore.getDefaultType) + keyStore.load(getClass.getResourceAsStream("/keystore"), password.toCharArray) + + val trustStore = KeyStore.getInstance(KeyStore.getDefaultType) + trustStore.load(getClass.getResourceAsStream("/truststore"), password.toCharArray) + + val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm) + keyManagerFactory.init(keyStore, password.toCharArray) + + val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm) + trustManagerFactory.init(trustStore) + + val context = SSLContext.getInstance("TLS") + context.init(keyManagerFactory.getKeyManagers, trustManagerFactory.getTrustManagers, new SecureRandom) + context + } + + /** + * This is a stage that fires a TimeoutException failure 2 seconds after it was started, + * independent of the traffic going through. The purpose is to include the last seen + * element in the exception message to help in figuring out what went wrong. + */ + class Timeout(duration: FiniteDuration)(implicit system: ActorSystem) extends AsyncStage[ByteString, ByteString, Unit] { + private var last: ByteString = _ + + override def initAsyncInput(ctx: AsyncContext[ByteString, Unit]) = { + val cb = ctx.getAsyncCallback() + system.scheduler.scheduleOnce(duration)(cb.invoke(()))(system.dispatcher) + } + + override def onAsyncInput(u: Unit, ctx: AsyncContext[ByteString, Unit]) = + ctx.fail(new TimeoutException(s"timeout expired, last element was $last")) + + override def onPush(elem: ByteString, ctx: AsyncContext[ByteString, Unit]) = { + last = elem + if (ctx.isHoldingDownstream) ctx.pushAndPull(elem) + else ctx.holdUpstream() + } + + override def onPull(ctx: AsyncContext[ByteString, Unit]) = + if (ctx.isFinishing) ctx.pushAndFinish(last) + else if (ctx.isHoldingUpstream) ctx.pushAndPull(last) + else ctx.holdDownstream() + + override def onUpstreamFinish(ctx: AsyncContext[ByteString, Unit]) = + if (ctx.isHoldingUpstream) ctx.absorbTermination() + else ctx.finish() + + override def onDownstreamFinish(ctx: AsyncContext[ByteString, Unit]) = { + system.log.debug("cancelled") + ctx.finish() + } + } + + // FIXME #17226 replace by .dropWhile when implemented + class DropWhile[T](p: T ⇒ Boolean) extends PushStage[T, T] { + private var open = false + override def onPush(elem: T, ctx: Context[T]) = + if (open) ctx.push(elem) + else if (p(elem)) ctx.pull() + else { + open = true + ctx.push(elem) + } + } + +} + +class TlsSpec extends AkkaSpec("akka.loglevel=INFO\nakka.actor.debug.receive=off") { + import TlsSpec._ + + import system.dispatcher + implicit val materializer = ActorFlowMaterializer() + + import FlowGraph.Implicits._ + + "StreamTLS" must { + + val sslContext = initSslContext() + + val debug = Flow[SslTlsInbound].map { x ⇒ + x match { + case SessionTruncated ⇒ system.log.debug(s" ----------- truncated ") + case SessionBytes(_, b) ⇒ system.log.debug(s" ----------- (${b.size}) ${b.take(32).utf8String}") + } + x + } + + val cipherSuites = NegotiateNewSession.withCipherSuites("TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_128_CBC_SHA") + def clientTls(closing: Closing) = SslTls(sslContext, cipherSuites, Client, closing) + def serverTls(closing: Closing) = SslTls(sslContext, cipherSuites, Server, closing) + + trait Named { + def name: String = + getClass.getName + .reverse + .dropWhile(c ⇒ "$0123456789".indexOf(c) != -1) + .takeWhile(_ != '$') + .reverse + } + + trait CommunicationSetup extends Named { + def decorateFlow(leftClosing: Closing, rightClosing: Closing, + rhs: Flow[SslTlsInbound, SslTlsOutbound, Any]): Flow[SslTlsOutbound, SslTlsInbound, Unit] + def cleanup(): Unit = () + } + + object ClientInitiates extends CommunicationSetup { + def decorateFlow(leftClosing: Closing, rightClosing: Closing, + rhs: Flow[SslTlsInbound, SslTlsOutbound, Any]) = + clientTls(leftClosing) atop serverTls(rightClosing).reversed join rhs + } + + object ServerInitiates extends CommunicationSetup { + def decorateFlow(leftClosing: Closing, rightClosing: Closing, + rhs: Flow[SslTlsInbound, SslTlsOutbound, Any]) = + serverTls(leftClosing) atop clientTls(rightClosing).reversed join rhs + } + + def server(flow: Flow[ByteString, ByteString, Any]) = { + val server = StreamTcp() + .bind(new InetSocketAddress("localhost", 0)) + .to(Sink.foreach(c ⇒ c.flow.join(flow).run())) + .run() + Await.result(server, 2.seconds) + } + + object ClientInitiatesViaTcp extends CommunicationSetup { + var binding: StreamTcp.ServerBinding = null + def decorateFlow(leftClosing: Closing, rightClosing: Closing, + rhs: Flow[SslTlsInbound, SslTlsOutbound, Any]) = { + binding = server(serverTls(rightClosing).reversed join rhs) + clientTls(leftClosing) join StreamTcp().outgoingConnection(binding.localAddress) + } + override def cleanup(): Unit = binding.unbind() + } + + object ServerInitiatesViaTcp extends CommunicationSetup { + var binding: StreamTcp.ServerBinding = null + def decorateFlow(leftClosing: Closing, rightClosing: Closing, + rhs: Flow[SslTlsInbound, SslTlsOutbound, Any]) = { + binding = server(clientTls(rightClosing).reversed join rhs) + serverTls(leftClosing) join StreamTcp().outgoingConnection(binding.localAddress) + } + override def cleanup(): Unit = binding.unbind() + } + + val communicationPatterns = + Seq( + ClientInitiates, + ServerInitiates, + ClientInitiatesViaTcp, + ServerInitiatesViaTcp) + + trait PayloadScenario extends Named { + def flow: Flow[SslTlsInbound, SslTlsOutbound, Any] = + Flow[SslTlsInbound] + .map { + var session: SSLSession = null + def setSession(s: SSLSession) = { + session = s + system.log.debug(s"new session: $session (${session.getId mkString ","})") + } + + { + case SessionTruncated ⇒ SendBytes(ByteString("TRUNCATED")) + case SessionBytes(s, b) if session == null ⇒ + setSession(s) + SendBytes(b) + case SessionBytes(s, b) if s != session ⇒ + setSession(s) + SendBytes(ByteString("NEWSESSION") ++ b) + case SessionBytes(s, b) ⇒ SendBytes(b) + } + } + def leftClosing: Closing = IgnoreComplete + def rightClosing: Closing = IgnoreComplete + + def inputs: immutable.Seq[SslTlsOutbound] + def output: ByteString + + protected def send(str: String) = SendBytes(ByteString(str)) + protected def send(ch: Char) = SendBytes(ByteString(ch.toByte)) + } + + object SingleBytes extends PayloadScenario { + val str = "0123456789" + def inputs = str.map(ch ⇒ SendBytes(ByteString(ch.toByte))) + def output = ByteString(str) + } + + object MediumMessages extends PayloadScenario { + val strs = "0123456789" map (d ⇒ d.toString * (rnd.nextInt(9000) + 1000)) + def inputs = strs map (s ⇒ SendBytes(ByteString(s))) + def output = ByteString((strs :\ "")(_ ++ _)) + } + + object LargeMessages extends PayloadScenario { + // TLS max packet size is 16384 bytes + val strs = "0123456789" map (d ⇒ d.toString * (rnd.nextInt(9000) + 17000)) + def inputs = strs map (s ⇒ SendBytes(ByteString(s))) + def output = ByteString((strs :\ "")(_ ++ _)) + } + + object EmptyBytesFirst extends PayloadScenario { + def inputs = List(ByteString.empty, ByteString("hello")).map(SendBytes) + def output = ByteString("hello") + } + + object EmptyBytesInTheMiddle extends PayloadScenario { + def inputs = List(ByteString("hello"), ByteString.empty, ByteString(" world")).map(SendBytes) + def output = ByteString("hello world") + } + + object EmptyBytesLast extends PayloadScenario { + def inputs = List(ByteString("hello"), ByteString.empty).map(SendBytes) + def output = ByteString("hello") + } + + // this demonstrates that cancellation is ignored so that the five results make it back + object CancellingRHS extends PayloadScenario { + override def flow = + Flow[SslTlsInbound] + .mapConcat { + case SessionTruncated ⇒ SessionTruncated :: Nil + case SessionBytes(s, bytes) ⇒ bytes.map(b ⇒ SessionBytes(s, ByteString(b))) + } + .take(5) + .mapAsync(5, x ⇒ later(500.millis, system.scheduler)(Future.successful(x))) + .via(super.flow) + override def rightClosing = IgnoreCancel + + val str = "abcdef" * 100 + def inputs = str.map(send) + def output = ByteString(str.take(5)) + } + + object CancellingRHSIgnoresBoth extends PayloadScenario { + override def flow = + Flow[SslTlsInbound] + .mapConcat { + case SessionTruncated ⇒ SessionTruncated :: Nil + case SessionBytes(s, bytes) ⇒ bytes.map(b ⇒ SessionBytes(s, ByteString(b))) + } + .take(5) + .mapAsync(5, x ⇒ later(500.millis, system.scheduler)(Future.successful(x))) + .via(super.flow) + override def rightClosing = IgnoreBoth + + val str = "abcdef" * 100 + def inputs = str.map(send) + def output = ByteString(str.take(5)) + } + + object LHSIgnoresBoth extends PayloadScenario { + override def leftClosing = IgnoreBoth + val str = "0123456789" + def inputs = str.map(ch ⇒ SendBytes(ByteString(ch.toByte))) + def output = ByteString(str) + } + + object BothSidesIgnoreBoth extends PayloadScenario { + override def leftClosing = IgnoreBoth + override def rightClosing = IgnoreBoth + val str = "0123456789" + def inputs = str.map(ch ⇒ SendBytes(ByteString(ch.toByte))) + def output = ByteString(str) + } + + object SessionRenegotiationBySender extends PayloadScenario { + def inputs = List(send("hello"), NegotiateNewSession, send("world")) + def output = ByteString("helloNEWSESSIONworld") + } + + // difference is that the RHS engine will now receive the handshake while trying to send + object SessionRenegotiationByReceiver extends PayloadScenario { + val str = "abcdef" * 100 + def inputs = str.map(send) ++ Seq(NegotiateNewSession) ++ "hello world".map(send) + def output = ByteString(str + "NEWSESSIONhello world") + } + + val logCipherSuite = Flow[SslTlsInbound] + .map { + var session: SSLSession = null + def setSession(s: SSLSession) = { + session = s + system.log.debug(s"new session: $session (${session.getId mkString ","})") + } + + { + case SessionTruncated ⇒ SendBytes(ByteString("TRUNCATED")) + case SessionBytes(s, b) if s != session ⇒ + setSession(s) + SendBytes(ByteString(s.getCipherSuite) ++ b) + case SessionBytes(s, b) ⇒ SendBytes(b) + } + } + + object SessionRenegotiationFirstOne extends PayloadScenario { + override def flow = logCipherSuite + def inputs = NegotiateNewSession.withCipherSuites("TLS_RSA_WITH_AES_128_CBC_SHA") :: send("hello") :: Nil + def output = ByteString("TLS_RSA_WITH_AES_128_CBC_SHAhello") + } + + object SessionRenegotiationFirstTwo extends PayloadScenario { + override def flow = logCipherSuite + def inputs = NegotiateNewSession.withCipherSuites("TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA") :: send("hello") :: Nil + def output = ByteString("TLS_ECDHE_RSA_WITH_AES_128_CBC_SHAhello") + } + + val scenarios = + Seq( + SingleBytes, + MediumMessages, + LargeMessages, + EmptyBytesFirst, + EmptyBytesInTheMiddle, + EmptyBytesLast, + CancellingRHS, + SessionRenegotiationBySender, + SessionRenegotiationByReceiver, + SessionRenegotiationFirstOne, + SessionRenegotiationFirstTwo) + + for { + commPattern ← communicationPatterns + scenario ← scenarios + } { + s"work in mode ${commPattern.name} while sending ${scenario.name}" in { + val onRHS = debug.via(scenario.flow) + val f = + Source(scenario.inputs) + .via(commPattern.decorateFlow(scenario.leftClosing, scenario.rightClosing, onRHS)) + .transform(() ⇒ new PushStage[SslTlsInbound, SslTlsInbound] { + override def onPush(elem: SslTlsInbound, ctx: Context[SslTlsInbound]) = + ctx.push(elem) + override def onDownstreamFinish(ctx: Context[SslTlsInbound]) = { + system.log.debug("me cancelled") + ctx.finish() + } + }) + .via(debug) + .collect { case SessionBytes(_, b) ⇒ b } + .scan(ByteString.empty)(_ ++ _) + .transform(() ⇒ new Timeout(6.seconds)) + .transform(() ⇒ new DropWhile(_.size < scenario.output.size)) + .runWith(Sink.head) + + Await.result(f, 8.seconds).utf8String should be(scenario.output.utf8String) + + commPattern.cleanup() + + // flush log so as to not mix up logs of different test cases + if (log.isDebugEnabled) + EventFilter.debug("stopgap", occurrences = 1) intercept { + log.debug("stopgap") + } + } + } + + } + + "A SslTlsPlacebo" must { + + "pass through data" in { + val f = Source(1 to 3) + .map(b ⇒ SendBytes(ByteString(b.toByte))) + .via(SslTlsPlacebo.forScala join Flow.apply) + .grouped(10) + .runWith(Sink.head) + val result = Await.result(f, 3.seconds) + result.map(_.bytes) should be((1 to 3).map(b ⇒ ByteString(b.toByte))) + result.map(_.session).foreach(s ⇒ s.getCipherSuite should be("SSL_NULL_WITH_NULL_NULL")) + } + + } + +} diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowScanSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowScanSpec.scala index f44f4a352e..68e344a390 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowScanSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowScanSpec.scala @@ -6,12 +6,12 @@ package akka.stream.scaladsl import scala.concurrent.Await import scala.concurrent.duration._ import scala.concurrent.forkjoin.ThreadLocalRandom.{ current ⇒ random } - import scala.collection.immutable - import akka.stream.ActorFlowMaterializer import akka.stream.ActorFlowMaterializerSettings import akka.stream.testkit.AkkaSpec +import akka.stream.ActorOperationAttributes +import akka.stream.Supervision class FlowScanSpec extends AkkaSpec { @@ -39,5 +39,20 @@ class FlowScanSpec extends AkkaSpec { val v = Vector.empty[Int] scan(Source(v)) should be(v.scan(0)(_ + _)) } + + "emit values promptly" in { + val f = Source.single(1).concat(Source.lazyEmpty).scan(0)(_ + _).grouped(2).runWith(Sink.head) + Await.result(f, 1.second) should be(Seq(0, 1)) + } + + "fail properly" in { + import ActorOperationAttributes._ + val scan = Flow[Int].scan(0) { (old, current) ⇒ + require(current > 0) + old + current + }.withAttributes(supervisionStrategy(Supervision.restartingDecider)) + val f = Source(List(1, 3, -1, 5, 7)).via(scan).grouped(1000).runWith(Sink.head) + Await.result(f, 1.second) should be(Seq(0, 1, 4, 0, 5, 12)) + } } } diff --git a/akka-stream/src/main/scala/akka/stream/Shape.scala b/akka-stream/src/main/scala/akka/stream/Shape.scala index 2337087f40..c794a2b118 100644 --- a/akka-stream/src/main/scala/akka/stream/Shape.scala +++ b/akka-stream/src/main/scala/akka/stream/Shape.scala @@ -213,6 +213,7 @@ final case class BidiShape[-In1, +Out1, -In2, +Out2](in1: Inlet[In1], require(outlets.size == 2, s"proposed outlets [${outlets.mkString(", ")}] do not fit BidiShape") BidiShape(inlets(0), outlets(0), inlets(1), outlets(1)) } + def reversed: Shape = copyFromPorts(inlets.reverse, outlets.reverse) //#implementation-details-elided } //#bidi-shape diff --git a/akka-stream/src/main/scala/akka/stream/impl/ActorFlowMaterializerImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/ActorFlowMaterializerImpl.scala index a6a4f91c59..8af7083ca4 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/ActorFlowMaterializerImpl.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/ActorFlowMaterializerImpl.scala @@ -11,10 +11,15 @@ import akka.pattern.ask import akka.stream.actor.ActorSubscriber import akka.stream.impl.GenJunctions.ZipWithModule import akka.stream.impl.Junctions._ +import akka.stream.impl.MultiStreamInputProcessor.SubstreamSubscriber import akka.stream.impl.StreamLayout.Module import akka.stream.impl.fusing.ActorInterpreter +import akka.stream.impl.io.SslTlsCipherActor import akka.stream.scaladsl._ import akka.stream._ +import akka.stream.io._ +import akka.stream.io.SslTls.TlsModule +import akka.util.ByteString import org.reactivestreams._ import scala.concurrent.{ Await, ExecutionContextExecutor } @@ -83,6 +88,22 @@ private[akka] case class ActorFlowMaterializerImpl(override val settings: ActorF assignPort(stage.outPort, processor) mat + case tls: TlsModule ⇒ + val es = effectiveSettings(effectiveAttributes) + val props = SslTlsCipherActor.props(es, tls.sslContext, tls.firstSession, tracing = true, tls.role, tls.closing) + val impl = actorOf(props, stageName(effectiveAttributes), es.dispatcher) + def factory(id: Int) = new ActorPublisher[Any](impl) { + override val wakeUpMsg = FanOut.SubstreamSubscribePending(id) + } + val publishers = Vector.tabulate(2)(factory) + impl ! FanOut.ExposedPublishers(publishers) + + assignPort(tls.plainOut, publishers(SslTlsCipherActor.UserOut)) + assignPort(tls.cipherOut, publishers(SslTlsCipherActor.TransportOut)) + + assignPort(tls.plainIn, FanIn.SubInput[Any](impl, SslTlsCipherActor.UserIn)) + assignPort(tls.cipherIn, FanIn.SubInput[Any](impl, SslTlsCipherActor.TransportIn)) + case junction: JunctionModule ⇒ materializeJunction(junction, effectiveAttributes, effectiveSettings(effectiveAttributes)) } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/FanIn.scala b/akka-stream/src/main/scala/akka/stream/impl/FanIn.scala index 6980fbb279..0ae7090d3b 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/FanIn.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/FanIn.scala @@ -155,6 +155,7 @@ private[akka] object FanIn { if (input.inputsDepleted) { if (marked(id)) markedDepleted += 1 depleted(id) = true + onDepleted(id) } elem } diff --git a/akka-stream/src/main/scala/akka/stream/impl/FanOut.scala b/akka-stream/src/main/scala/akka/stream/impl/FanOut.scala index 963ae08464..759a15c267 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/FanOut.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/FanOut.scala @@ -186,6 +186,16 @@ private[akka] object FanOut { def onCancel(output: Int): Unit = () + def demandAvailableFor(id: Int) = new TransferState { + override def isCompleted: Boolean = cancelled(id) || completed(id) + override def isReady: Boolean = pending(id) + } + + def demandOrCancelAvailableFor(id: Int) = new TransferState { + override def isCompleted: Boolean = false + override def isReady: Boolean = pending(id) || cancelled(id) + } + /** * Will only transfer an element when all marked outputs * have demand, and will complete as soon as any of the marked diff --git a/akka-stream/src/main/scala/akka/stream/impl/FlexiMergeImpl.scala b/akka-stream/src/main/scala/akka/stream/impl/FlexiMergeImpl.scala index 2a43b078fa..49c8469ffa 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/FlexiMergeImpl.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/FlexiMergeImpl.scala @@ -120,22 +120,26 @@ private[akka] class FlexiMergeImpl[T, S <: Shape]( nextPhase(TransferPhase(precondition) { () ⇒ behavior.condition match { case read: ReadAny[t] ⇒ + suppressCompletion() val id = inputBunch.idToDequeue() val elem = inputBunch.dequeueAndYield(id) val inputHandle = inputMapping(id) callOnInput(inputHandle, elem) triggerCompletionAfterRead(inputHandle) case r: ReadPreferred[t] ⇒ + suppressCompletion() val elem = inputBunch.dequeuePrefering(indexOf(r.preferred)) val id = inputBunch.lastDequeuedId val inputHandle = inputMapping(id) callOnInput(inputHandle, elem) triggerCompletionAfterRead(inputHandle) case Read(input) ⇒ + suppressCompletion() val elem = inputBunch.dequeue(indexOf(input)) callOnInput(input, elem) triggerCompletionAfterRead(input) case read: ReadAll[t] ⇒ + suppressCompletion() val inputs = read.inputs val values = inputs.collect { case input if include(input) ⇒ input → inputBunch.dequeue(indexOf(input)) @@ -160,15 +164,22 @@ private[akka] class FlexiMergeImpl[T, S <: Shape]( } } - private def triggerCompletionAfterRead(inputHandle: InPort): Unit = + private var completionEnabled = true + + private def suppressCompletion(): Unit = completionEnabled = false + + private def triggerCompletionAfterRead(inputHandle: InPort): Unit = { + completionEnabled = true if (inputBunch.isDepleted(indexOf(inputHandle))) triggerCompletion(inputHandle) + } private def triggerCompletion(in: InPort): Unit = - changeBehavior( - try completion.onUpstreamFinish(ctx, in) - catch { - case NonFatal(e) ⇒ fail(e); mergeLogic.SameState - }) + if (completionEnabled) + changeBehavior( + try completion.onUpstreamFinish(ctx, in) + catch { + case NonFatal(e) ⇒ fail(e); mergeLogic.SameState + }) } diff --git a/akka-stream/src/main/scala/akka/stream/impl/Transfer.scala b/akka-stream/src/main/scala/akka/stream/impl/Transfer.scala index 23835646d7..cf278bfc50 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/Transfer.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/Transfer.scala @@ -171,4 +171,3 @@ private[akka] trait Pump { protected def pumpFinished(): Unit } - diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorInterpreter.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorInterpreter.scala index a42673f648..a6852715a8 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorInterpreter.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/ActorInterpreter.scala @@ -298,7 +298,7 @@ private[akka] object ActorInterpreter { def props(settings: ActorFlowMaterializerSettings, ops: Seq[Stage[_, _]], materializer: ActorFlowMaterializer): Props = Props(new ActorInterpreter(settings, ops, materializer)) - case class AsyncInput(op: AsyncStage[Any, Any, Any], ctx: AsyncContext[Any, Any], event: Any) + case class AsyncInput(op: AsyncStage[Any, Any, Any], ctx: AsyncContext[Any, Any], event: Any) extends DeadLetterSuppression } /** diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala index 1bdbcd25de..cd59892e5b 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala @@ -108,18 +108,27 @@ private[akka] final case class Drop[T](count: Long) extends PushStage[T, T] { */ private[akka] final case class Scan[In, Out](zero: Out, f: (Out, In) ⇒ Out, decider: Supervision.Decider) extends PushPullStage[In, Out] { private var aggregator = zero + private var pushedZero = false override def onPush(elem: In, ctx: Context[Out]): SyncDirective = { - val old = aggregator - aggregator = f(old, elem) - ctx.push(old) + if (pushedZero) { + aggregator = f(aggregator, elem) + ctx.push(aggregator) + } else { + aggregator = f(zero, elem) + ctx.push(zero) + } } override def onPull(ctx: Context[Out]): SyncDirective = - if (ctx.isFinishing) ctx.pushAndFinish(aggregator) - else ctx.pull() + if (!pushedZero) { + pushedZero = true + if (ctx.isFinishing) ctx.pushAndFinish(aggregator) else ctx.push(aggregator) + } else ctx.pull() - override def onUpstreamFinish(ctx: Context[Out]): TerminationDirective = ctx.absorbTermination() + override def onUpstreamFinish(ctx: Context[Out]): TerminationDirective = + if (pushedZero) ctx.finish() + else ctx.absorbTermination() override def decide(t: Throwable): Supervision.Directive = decider(t) diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/SslTls.scala b/akka-stream/src/main/scala/akka/stream/impl/io/SslTls.scala new file mode 100644 index 0000000000..c01ab68d1e --- /dev/null +++ b/akka-stream/src/main/scala/akka/stream/impl/io/SslTls.scala @@ -0,0 +1,449 @@ +/** + * Copyright (C) 2015 Typesafe Inc. + */ +package akka.stream.impl.io + +import java.nio.ByteBuffer +import java.security.Principal +import java.security.cert.Certificate +import javax.net.ssl.SSLEngineResult.HandshakeStatus +import javax.net.ssl.SSLEngineResult.HandshakeStatus._ +import javax.net.ssl.SSLEngineResult.Status._ +import javax.net.ssl._ +import akka.actor.{ Props, Actor, ActorLogging, ActorRef } +import akka.stream.ActorFlowMaterializerSettings +import akka.stream.impl.FanIn.InputBunch +import akka.stream.impl.FanOut.OutputBunch +import akka.stream.impl._ +import akka.util.ByteString +import akka.util.ByteStringBuilder +import org.reactivestreams.Publisher +import org.reactivestreams.Subscriber +import scala.annotation.tailrec +import scala.collection.immutable +import akka.stream.io._ +import akka.event.LoggingReceive + +/** + * INTERNAL API. + */ +private[akka] object SslTlsCipherActor { + + def props(settings: ActorFlowMaterializerSettings, + sslContext: SSLContext, + firstSession: NegotiateNewSession, + tracing: Boolean, + role: Role, + closing: Closing): Props = + Props(new SslTlsCipherActor(settings, sslContext, firstSession, tracing, role, closing)) + + final val TransportIn = 0 + final val TransportOut = 0 + + final val UserOut = 1 + final val UserIn = 1 +} + +/** + * INTERNAL API. + */ +private[akka] class SslTlsCipherActor(settings: ActorFlowMaterializerSettings, sslContext: SSLContext, + firstSession: NegotiateNewSession, tracing: Boolean, + role: Role, closing: Closing) + extends Actor with ActorLogging with Pump { + + import SslTlsCipherActor._ + + protected val outputBunch = new OutputBunch(outputCount = 2, self, this) + outputBunch.markAllOutputs() + + protected val inputBunch = new InputBunch(inputCount = 2, settings.maxInputBufferSize, this) { + override def onError(input: Int, e: Throwable): Unit = fail(e) + } + + /** + * The SSLEngine needs bite-sized chunks of data but we get arbitrary ByteString + * from both the UserIn and the TransportIn ports. This is used to chop up such + * a ByteString by filling the respective ByteBuffer and taking care to dequeue + * a new element when data are demanded and none are left lying on the chopping + * block. + */ + class ChoppingBlock(idx: Int, name: String) extends TransferState { + override def isReady: Boolean = buffer.nonEmpty + override def isCompleted: Boolean = false + + private var buffer = ByteString.empty + + /** + * Whether there are no bytes lying on this chopping block. + */ + def isEmpty: Boolean = buffer.isEmpty + + /** + * Pour as many bytes as are available either on the chopping block or in + * the inputBunch’s next ByteString into the supplied ByteBuffer, which is + * expected to be in “read left-overs” mode, i.e. everything between its + * position and limit is retained. In order to allocate a fresh ByteBuffer + * with these characteristics, use `prepare()`. + */ + def chopInto(b: ByteBuffer): Unit = { + b.compact() + if (buffer.isEmpty) { + buffer = inputBunch.dequeue(idx) match { + // this class handles both UserIn and TransportIn + case bs: ByteString ⇒ bs + case SendBytes(bs) ⇒ bs + case n: NegotiateNewSession ⇒ + setNewSessionParameters(n) + ByteString.empty + } + if (tracing) log.debug(s"chopping from new chunk of ${buffer.size} into $name (${b.position})") + } else { + if (tracing) log.debug(s"chopping from old chunk of ${buffer.size} into $name (${b.position})") + } + val copied = buffer.copyToBuffer(b) + buffer = buffer.drop(copied) + b.flip() + } + + /** + * When potentially complete packet data are left after unwrap() we must + * put them back onto the chopping block because otherwise the pump will + * not know that we are runnable. + */ + def putBack(b: ByteBuffer): Unit = + if (b.hasRemaining()) { + if (tracing) log.debug(s"putting back ${b.remaining} bytes into $name") + val bs = ByteString(b) + if (bs.nonEmpty) buffer = bs ++ buffer + prepare(b) + } + + /** + * Prepare a fresh ByteBuffer for receiving a chop of data. + */ + def prepare(b: ByteBuffer): Unit = { + b.clear() + b.limit(0) + } + } + + // These are Nettys default values + // 16665 + 1024 (room for compressed data) + 1024 (for OpenJDK compatibility) + val transportOutBuffer = ByteBuffer.allocate(16665 + 2048) + /* + * deviating here: chopping multiple input packets into this buffer can lead to + * an OVERFLOW signal that also is an UNDERFLOW; avoid unnecessary copying by + * increasing this buffer size to host up to two packets + */ + val userOutBuffer = ByteBuffer.allocate(16665 * 2 + 2048) + val transportInBuffer = ByteBuffer.allocate(16665 + 2048) + val userInBuffer = ByteBuffer.allocate(16665 + 2048) + + val userInChoppingBlock = new ChoppingBlock(UserIn, "UserIn") + userInChoppingBlock.prepare(userInBuffer) + val transportInChoppingBlock = new ChoppingBlock(TransportIn, "TransportIn") + transportInChoppingBlock.prepare(transportInBuffer) + + val engine: SSLEngine = { + val e = sslContext.createSSLEngine() + e.setUseClientMode(role == Client) + e + } + var currentSession = engine.getSession + var currentSessionParameters = firstSession + applySessionParameters() + + def applySessionParameters(): Unit = { + val csp = currentSessionParameters + import csp._ + enabledCipherSuites foreach (cs ⇒ engine.setEnabledCipherSuites(cs.toArray)) + enabledProtocols foreach (p ⇒ engine.setEnabledProtocols(p.toArray)) + clientAuth match { + case Some(ClientAuth.None) ⇒ engine.setNeedClientAuth(false) + case Some(ClientAuth.Want) ⇒ engine.setWantClientAuth(true) + case Some(ClientAuth.Need) ⇒ engine.setNeedClientAuth(true) + case None ⇒ // do nothing + } + sslParameters foreach (p ⇒ engine.setSSLParameters(p)) + engine.beginHandshake() + lastHandshakeStatus = engine.getHandshakeStatus + } + + def setNewSessionParameters(n: NegotiateNewSession): Unit = { + if (tracing) log.debug(s"applying $n") + currentSession.invalidate() + currentSessionParameters = n + applySessionParameters() + corkUser = true + } + + /* + * So here’s the big picture summary: the SSLEngine is the boss, and it can + * be in several states. Depending on this state, we may want to react to + * different input and output conditions. + * + * - normal bidirectional operation (does both outbound and inbound) + * - outbound close initiated, inbound still open + * - inbound close initiated, outbound still open + * - fully closed + * + * Upon reaching the last state we obviously just shut down. In addition to + * these user-data states, the engine may at any point in time also be + * handshaking. This is mostly transparent, but it has an influence on the + * outbound direction: + * + * - if the local user triggered a re-negotiation, cork all user data until + * that is finished + * - if the outbound direction has been closed, trigger outbound readiness + * based upon HandshakeStatus.NEED_WRAP + * + * These conditions lead to the introduction of a synthetic TransferState + * representing the Engine. + */ + + var lastHandshakeStatus: HandshakeStatus = _ + + val engineNeedsWrap = new TransferState { + def isReady = lastHandshakeStatus == NEED_WRAP + def isCompleted = false + } + + val engineInboundOpen = new TransferState { + def isReady = !engine.isInboundDone() + def isCompleted = false + } + + var corkUser = true + + val userHasData = new TransferState { + private val user = inputBunch.inputsOrCompleteAvailableFor(UserIn) || userInChoppingBlock + def isReady = !corkUser && user.isReady && lastHandshakeStatus != NEED_UNWRAP + def isCompleted = false + } + + val transportHasData = inputBunch.inputsOrCompleteAvailableFor(TransportIn) || transportInChoppingBlock + val userOutCancelled = new TransferState { + def isReady = outputBunch.isCancelled(UserOut) + def isCompleted = inputBunch.isDepleted(TransportIn) + } + + // bidirectional case + val outbound = (userHasData || engineNeedsWrap) && outputBunch.demandAvailableFor(TransportOut) + val inbound = (transportHasData || userOutCancelled) && outputBunch.demandOrCancelAvailableFor(UserOut) + + // half-closed + val outboundHalfClosed = engineNeedsWrap && outputBunch.demandAvailableFor(TransportOut) + val inboundHalfClosed = transportHasData && engineInboundOpen + + def completeOrFlush(): Unit = + if (engine.isOutboundDone()) nextPhase(completedPhase) + else nextPhase(flushingOutbound) + + private def doInbound(isOutboundClosed: Boolean, inboundState: TransferState): Boolean = + if (inputBunch.isDepleted(TransportIn) && transportInChoppingBlock.isEmpty) { + if (tracing) log.debug("closing inbound") + try engine.closeInbound() + catch { case ex: SSLException ⇒ outputBunch.enqueue(UserOut, SessionTruncated) } + completeOrFlush() + false + } else if (inboundState != inboundHalfClosed && outputBunch.isCancelled(UserOut)) { + if (!isOutboundClosed && closing.ignoreCancel) { + if (tracing) log.debug("ignoring UserIn cancellation") + nextPhase(inboundClosed) + } else { + if (tracing) log.debug("closing inbound due to UserOut cancellation") + engine.closeOutbound() // this is the correct way of shutting down the engine + lastHandshakeStatus = engine.getHandshakeStatus + nextPhase(flushingOutbound) + } + true + } else if (inboundState.isReady) { + transportInChoppingBlock.chopInto(transportInBuffer) + try { + doUnwrap() + true + } catch { + case ex: SSLException ⇒ + if (tracing) log.debug(s"SSLException during doUnwrap: $ex") + completeOrFlush() + false + } + } else true + + private def doOutbound(isInboundClosed: Boolean): Unit = + if (inputBunch.isDepleted(UserIn) && userInChoppingBlock.isEmpty) { + if (!isInboundClosed && closing.ignoreComplete) { + if (tracing) log.debug("ignoring closeOutbound") + } else { + if (tracing) log.debug("closing outbound directly") + engine.closeOutbound() + lastHandshakeStatus = engine.getHandshakeStatus + } + nextPhase(outboundClosed) + } else if (outputBunch.isCancelled(TransportOut)) { + nextPhase(completedPhase) + } else if (outbound.isReady) { + if (userHasData.isReady) userInChoppingBlock.chopInto(userInBuffer) + try doWrap() + catch { + case ex: SSLException ⇒ + if (tracing) log.debug(s"SSLException during doWrap: $ex") + completeOrFlush() + } + } + + val bidirectional = TransferPhase(outbound || inbound) { () ⇒ + if (tracing) log.debug("bidirectional") + val continue = doInbound(isOutboundClosed = false, inbound) + if (continue) { + if (tracing) log.debug("bidirectional continue") + doOutbound(isInboundClosed = false) + } + } + + val flushingOutbound = TransferPhase(outboundHalfClosed) { () ⇒ + if (tracing) log.debug("flushingOutbound") + try doWrap() + catch { case ex: SSLException ⇒ nextPhase(completedPhase) } + } + + val awaitingClose = TransferPhase(inputBunch.inputsAvailableFor(TransportIn)) { () ⇒ + if (tracing) log.debug("awaitingClose") + transportInChoppingBlock.chopInto(transportInBuffer) + try doUnwrap(ignoreOutput = true) + catch { case ex: SSLException ⇒ nextPhase(completedPhase) } + } + + val outboundClosed = TransferPhase(outboundHalfClosed || inbound) { () ⇒ + if (tracing) log.debug("outboundClosed") + val continue = doInbound(isOutboundClosed = true, inbound) + if (continue && outboundHalfClosed.isReady) { + if (tracing) log.debug("outboundClosed continue") + try doWrap() + catch { case ex: SSLException ⇒ nextPhase(completedPhase) } + } + } + + val inboundClosed = TransferPhase(outbound || inboundHalfClosed) { () ⇒ + if (tracing) log.debug("inboundClosed") + val continue = doInbound(isOutboundClosed = false, inboundHalfClosed) + if (continue) { + if (tracing) log.debug("inboundClosed continue") + doOutbound(isInboundClosed = true) + } + } + + def flushToTransport(): Unit = { + if (tracing) log.debug("flushToTransport") + transportOutBuffer.flip() + if (transportOutBuffer.hasRemaining) { + val bs = ByteString(transportOutBuffer) + outputBunch.enqueue(TransportOut, bs) + if (tracing) log.debug(s"sending ${bs.size} bytes") + } + transportOutBuffer.clear() + } + + def flushToUser(): Unit = { + if (tracing) log.debug("flushToUser") + userOutBuffer.flip() + if (userOutBuffer.hasRemaining) { + val bs = ByteString(userOutBuffer) + outputBunch.enqueue(UserOut, SessionBytes(currentSession, bs)) + } + userOutBuffer.clear() + } + + private def doWrap(): Unit = { + val result = engine.wrap(userInBuffer, transportOutBuffer) + lastHandshakeStatus = result.getHandshakeStatus + if (tracing) log.debug(s"wrap: status=${result.getStatus} handshake=$lastHandshakeStatus remaining=${userInBuffer.remaining} out=${transportOutBuffer.position}") + if (lastHandshakeStatus == FINISHED) handshakeFinished() + runDelegatedTasks() + result.getStatus match { + case OK ⇒ + flushToTransport() + userInChoppingBlock.putBack(userInBuffer) + case CLOSED ⇒ + flushToTransport() + if (engine.isInboundDone()) nextPhase(completedPhase) + else nextPhase(awaitingClose) + case s ⇒ fail(new IllegalStateException(s"unexpected status $s in doWrap()")) + } + } + + @tailrec + private def doUnwrap(ignoreOutput: Boolean = false): Unit = { + val result = engine.unwrap(transportInBuffer, userOutBuffer) + if (ignoreOutput) userOutBuffer.clear() + lastHandshakeStatus = result.getHandshakeStatus + if (tracing) log.debug(s"unwrap: status=${result.getStatus} handshake=$lastHandshakeStatus remaining=${transportInBuffer.remaining} out=${userOutBuffer.position}") + runDelegatedTasks() + result.getStatus match { + case OK ⇒ + result.getHandshakeStatus match { + case NEED_WRAP ⇒ flushToUser() + case FINISHED ⇒ + flushToUser() + handshakeFinished() + transportInChoppingBlock.putBack(transportInBuffer) + case _ ⇒ + if (transportInBuffer.hasRemaining()) doUnwrap() + else flushToUser() + } + case CLOSED ⇒ + flushToUser() + if (engine.isOutboundDone()) nextPhase(completedPhase) + else nextPhase(flushingOutbound) + case BUFFER_UNDERFLOW ⇒ + flushToUser() + case BUFFER_OVERFLOW ⇒ + flushToUser() + transportInChoppingBlock.putBack(transportInBuffer) + case s ⇒ fail(new IllegalStateException(s"unexpected status $s in doUnwrap()")) + } + } + + @tailrec + private def runDelegatedTasks(): Unit = { + val task = engine.getDelegatedTask + if (task != null) { + if (tracing) log.debug("running task") + task.run() + runDelegatedTasks() + } else { + val st = lastHandshakeStatus + lastHandshakeStatus = engine.getHandshakeStatus + if (tracing && st != lastHandshakeStatus) log.debug(s"handshake status after tasks: $lastHandshakeStatus") + } + } + + private def handshakeFinished(): Unit = { + if (tracing) log.debug("handshake finished") + currentSession = engine.getSession + corkUser = false + } + + override def receive = inputBunch.subreceive.orElse[Any, Unit](outputBunch.subreceive) + + nextPhase(bidirectional) + + protected def fail(e: Throwable): Unit = { + // FIXME: escalate to supervisor + if (tracing) log.debug("fail {} due to: {}", self, e.getMessage) + inputBunch.cancel() + outputBunch.error(TransportOut, e) + outputBunch.error(UserOut, e) + context.stop(self) + } + + override protected def pumpFailed(e: Throwable): Unit = fail(e) + + override protected def pumpFinished(): Unit = { + inputBunch.cancel() + outputBunch.complete() + if (tracing) log.debug(s"STOP Outbound Closed: ${engine.isOutboundDone} Inbound closed: ${engine.isInboundDone}") + context.stop(self) + } +} diff --git a/akka-stream/src/main/scala/akka/stream/io/SslTls.scala b/akka-stream/src/main/scala/akka/stream/io/SslTls.scala new file mode 100644 index 0000000000..52187e742c --- /dev/null +++ b/akka-stream/src/main/scala/akka/stream/io/SslTls.scala @@ -0,0 +1,411 @@ +/** + * Copyright (C) 2015 Typesafe Inc. + */ +package akka.stream.io + +import akka.stream._ +import akka.stream.impl.StreamLayout.Module +import akka.util.ByteString +import javax.net.ssl._ +import scala.annotation.varargs +import scala.collection.immutable +import java.security.cert.Certificate + +/** + * Stream cipher support based upon JSSE. + * + * The underlying SSLEngine has four ports: plaintext input/output and + * ciphertext input/output. These are modeled as a [[akka.stream.BidiShape]] + * element for use in stream topologies, where the plaintext ports are on the + * left hand side of the shape and the ciphertext ports on the right hand side. + * + * Configuring JSSE is a rather complex topic, please refer to the JDK platform + * documentation or the excellent user guide that is part of the Play Framework + * documentation. The philosophy of this integration into Akka Streams is to + * expose all knobs and dials to client code and therefore not limit the + * configuration possibilities. In particular the client code will have to + * provide the SSLContext from which the SSLEngine is then created. Handshake + * parameters are set using [[NegotiateNewSession]] messages, the settings for + * the initial handshake need to be provided up front using the same class; + * please refer to the method documentation below. + * + * '''IMPORTANT NOTE''' + * + * The TLS specification does not permit half-closing of the user data session + * that it transports—to be precise a half-close will always promptly lead to a + * full close. This means that canceling the plaintext output or completing the + * plaintext input of the SslTls stage will lead to full termination of the + * secure connection without regard to whether bytes are remaining to be sent or + * received, respectively. Especially for a client the common idiom of attaching + * a finite Source to the plaintext input and transforming the plaintext response + * bytes coming out will not work out of the box due to early termination of the + * connection. For this reason there is a parameter that determines whether the + * SslTls stage shall ignore completion and/or cancellation events, and the + * default is to ignore completion (in view of the client–server scenario). In + * order to terminate the connection the client will then need to cancel the + * plaintext output as soon as all expected bytes have been received. When + * ignoring both types of events the stage will shut down once both events have + * been received. See also [[Closing]]. + */ +object SslTls { + + /** + * Scala API: create a StreamTls [[akka.stream.scaladsl.BidiFlow]]. The + * SSLContext will be used to create an SSLEngine to which then the + * `firstSession` parameters are applied before initiating the first + * handshake. The `role` parameter determines the SSLEngine’s role; this is + * often the same as the underlying transport’s server or client role, but + * that is not a requirement and depends entirely on the application + * protocol. + * + * For a description of the `closing` parameter please refer to [[Closing]]. + */ + def apply(sslContext: SSLContext, firstSession: NegotiateNewSession, + role: Role, closing: Closing = IgnoreComplete): scaladsl.BidiFlow[SslTlsOutbound, ByteString, ByteString, SslTlsInbound, Unit] = + new scaladsl.BidiFlow(TlsModule(OperationAttributes.none, sslContext, firstSession, role, closing)) + + /** + * Java API: create a StreamTls [[akka.stream.javadsl.BidiFlow]] in client mode. The + * SSLContext will be used to create an SSLEngine to which then the + * `firstSession` parameters are applied before initiating the first + * handshake. The `role` parameter determines the SSLEngine’s role; this is + * often the same as the underlying transport’s server or client role, but + * that is not a requirement and depends entirely on the application + * protocol. + * + * This method uses the default closing behavior or [[IgnoreComplete]]. + */ + def create(sslContext: SSLContext, firstSession: NegotiateNewSession, role: Role): javadsl.BidiFlow[SslTlsOutbound, ByteString, ByteString, SslTlsInbound, Unit] = + new javadsl.BidiFlow(apply(sslContext, firstSession, role)) + + /** + * Java API: create a StreamTls [[akka.stream.javadsl.BidiFlow]] in client mode. The + * SSLContext will be used to create an SSLEngine to which then the + * `firstSession` parameters are applied before initiating the first + * handshake. The `role` parameter determines the SSLEngine’s role; this is + * often the same as the underlying transport’s server or client role, but + * that is not a requirement and depends entirely on the application + * protocol. + * + * For a description of the `closing` parameter please refer to [[Closing]]. + */ + def create(sslContext: SSLContext, firstSession: NegotiateNewSession, role: Role, closing: Closing): javadsl.BidiFlow[SslTlsOutbound, ByteString, ByteString, SslTlsInbound, Unit] = + new javadsl.BidiFlow(apply(sslContext, firstSession, role, closing)) + + /** + * INTERNAL API. + */ + private[akka] case class TlsModule(plainIn: Inlet[SslTlsOutbound], plainOut: Outlet[SslTlsInbound], + cipherIn: Inlet[ByteString], cipherOut: Outlet[ByteString], + shape: Shape, attributes: OperationAttributes, + sslContext: SSLContext, firstSession: NegotiateNewSession, + role: Role, closing: Closing) extends Module { + override def subModules: Set[Module] = Set.empty + + override def withAttributes(att: OperationAttributes): Module = copy(attributes = att) + override def carbonCopy: Module = { + val mod = TlsModule(attributes, sslContext, firstSession, role, closing) + if (plainIn == shape.inlets(0)) mod + else mod.replaceShape(mod.shape.asInstanceOf[BidiShape[_, _, _, _]].reversed) + } + + override def replaceShape(s: Shape) = + if (s == shape) this + else if (shape.hasSamePortsAs(s)) copy(shape = s) + else throw new IllegalArgumentException("trying to replace shape with different ports") + } + + /** + * INTERNAL API. + */ + private[akka] object TlsModule { + def apply(attributes: OperationAttributes, sslContext: SSLContext, firstSession: NegotiateNewSession, role: Role, closing: Closing): TlsModule = { + val name = attributes.nameOrDefault(s"StreamTls($role)") + val cipherIn = new Inlet[ByteString](s"$name.cipherIn") + val cipherOut = new Outlet[ByteString](s"$name.cipherOut") + val plainIn = new Inlet[SslTlsOutbound](s"$name.transportIn") + val plainOut = new Outlet[SslTlsInbound](s"$name.transportOut") + val shape = new BidiShape(plainIn, cipherOut, cipherIn, plainOut) + TlsModule(plainIn, plainOut, cipherIn, cipherOut, shape, attributes, sslContext, firstSession, role, closing) + } + } +} + +/** + * This object holds simple wrapping [[BidiFlow]] implementations that can + * be used instead of [[SslTls]] when no encryption is desired. The flows will + * just adapt the message protocol by wrapping into [[SessionBytes]] and + * unwrapping [[SendBytes]]. + */ +object SslTlsPlacebo { + val forScala = scaladsl.BidiFlow() { implicit b ⇒ + // this constructs a session for (invalid) protocol SSL_NULL_WITH_NULL_NULL + val session = SSLContext.getDefault.createSSLEngine.getSession + val top = b.add(scaladsl.Flow[SslTlsOutbound].collect { case SendBytes(b) ⇒ b }) + val bottom = b.add(scaladsl.Flow[ByteString].map(SessionBytes(session, _))) + BidiShape(top, bottom) + } + val forJava = new javadsl.BidiFlow(forScala) +} + +/** + * Many protocols are asymmetric and distinguish between the client and the + * server, where the latter listens passively for messages and the former + * actively initiates the exchange. + */ +object Role { + /** + * Java API: obtain the [[Client]] singleton value. + */ + def client: Role = Client + /** + * Java API: obtain the [[Server]] singleton value. + */ + def server: Role = Server +} +sealed abstract class Role + +/** + * The client is usually the side that consumes the service provided by its + * interlocutor. The precise interpretation of this role is protocol specific. + */ +sealed abstract class Client extends Role +case object Client extends Client + +/** + * The server is usually the side the provides the service to its interlocutor. + * The precise interpretation of this role is protocol specific. + */ +sealed abstract class Server extends Role +case object Server extends Server + +/** + * All streams in Akka are unidirectional: while in a complex flow graph data + * may flow in multiple directions these individual flows are independent from + * each other. The difference between two half-duplex connections in opposite + * directions and a full-duplex connection is that the underlying transport + * is shared in the latter and tearing it down will end the data transfer in + * both directions. + * + * When integrating a full-duplex transport medium that does not support + * half-closing (which means ending one direction of data transfer without + * ending the other) into a stream topology, there can be unexpected effects. + * Feeding a finite Source into this medium will close the connection after + * all elements have been sent, which means that possible replies may not + * be received in full. To support this type of usage, the sending and + * receiving of data on the same side (e.g. on the [[Client]]) need to be + * coordinated such that it is known when all replies have been received. + * Only then should the transport be shut down. + * + * To support these scenarios it is recommended that the full-duplex + * transport integration is configurable in terms of termination handling, + * which means that the user can optionally suppress the normal (closing) + * reaction to completion or cancellation events, as is expressed by the + * possible values of this type: + * + * - [[EagerClose]] means to not ignore signals + * - [[IgnoreCancel]] means to not react to cancellation of the receiving + * side unless the sending side has already completed + * - [[IgnoreComplete]] means to not reacto the completion of the sending + * side unless the receiving side has already cancelled + * - [[IgnoreBoth]] means to ignore the first termination signal—be that + * cancellation or completion—and only act upon the second one + */ +sealed abstract class Closing { + def ignoreCancel: Boolean + def ignoreComplete: Boolean +} +object Closing { + /** + * Java API: obtain the [[EagerClose]] singleton value. + */ + def eagerClose: Closing = EagerClose + /** + * Java API: obtain the [[IgnoreCancel]] singleton value. + */ + def ignoreCancel: Closing = IgnoreCancel + /** + * Java API: obtain the [[IgnoreComplete]] singleton value. + */ + def ignoreComplete: Closing = IgnoreComplete + /** + * Java API: obtain the [[IgnoreBoth]] singleton value. + */ + def ignoreBoth: Closing = IgnoreBoth +} + +/** + * see [[Closing]] + */ +sealed abstract class EagerClose extends Closing { + override def ignoreCancel = false + override def ignoreComplete = false +} +case object EagerClose extends EagerClose + +/** + * see [[Closing]] + */ +sealed abstract class IgnoreCancel extends Closing { + override def ignoreCancel = true + override def ignoreComplete = false +} +case object IgnoreCancel extends IgnoreCancel + +/** + * see [[Closing]] + */ +sealed abstract class IgnoreComplete extends Closing { + override def ignoreCancel = false + override def ignoreComplete = true +} +case object IgnoreComplete extends IgnoreComplete + +/** + * see [[Closing]] + */ +sealed abstract class IgnoreBoth extends Closing { + override def ignoreCancel = true + override def ignoreComplete = true +} +case object IgnoreBoth extends IgnoreBoth + +/** + * This is the supertype of all messages that the SslTls stage emits on the + * plaintext side. + */ +sealed trait SslTlsInbound + +/** + * If the underlying transport is closed before the final TLS closure command + * is received from the peer then the SSLEngine will throw an SSLException that + * warns about possible truncation attacks. This exception is caught and + * translated into this message when encountered. Most of the time this occurs + * not because of a malicious attacker but due to a connection abort or a + * misbehaving communication peer. + */ +sealed abstract class SessionTruncated extends SslTlsInbound +case object SessionTruncated extends SessionTruncated + +/** + * Plaintext bytes emitted by the SSLEngine are received over one specific + * encryption session and this class bundles the bytes with the SSLSession + * object. When the session changes due to renegotiation (which can be + * initiated by either party) the new session value will not compare equal to + * the previous one. + * + * The Java API for getting session information is given by the SSLSession object, + * the Scala API adapters are offered below. + */ +case class SessionBytes(session: SSLSession, bytes: ByteString) extends SslTlsInbound { + /** + * Scala API: Extract the certificates that were actually used by this + * engine during this session’s negotiation. The list is empty if no + * certificates were used. + */ + def localCertificates: List[Certificate] = Option(session.getLocalCertificates).map(_.toList).getOrElse(Nil) + /** + * Scala API: Extract the Principal that was actually used by this engine + * during this session’s negotiation. + */ + def localPrincipal = Option(session.getLocalPrincipal) + /** + * Scala API: Extract the certificates that were used by the peer engine + * during this session’s negotiation. The list is empty if no certificates + * were used. + */ + def peerCertificates = + try Option(session.getPeerCertificates).map(_.toList).getOrElse(Nil) + catch { case e: SSLPeerUnverifiedException ⇒ Nil } + /** + * Scala API: Extract the Principal that the peer engine presented during + * this session’s negotiation. + */ + def peerPrincipal = + try Option(session.getPeerPrincipal) + catch { case e: SSLPeerUnverifiedException ⇒ None } +} + +/** + * This is the supertype of all messages that the SslTls stage accepts on its + * plaintext side. + */ +sealed trait SslTlsOutbound + +/** + * Initiate a new session negotiation. Any [[SendBytes]] commands following + * this one will be held back (i.e. back-pressured) until the new handshake is + * completed, meaning that the bytes following this message will be encrypted + * according to the requirements outlined here. + * + * Each of the values in this message is optional and will have the following + * effect if provided: + * + * - `enabledCipherSuites` will be passed to `SSLEngine::setEnabledCipherSuites()` + * - `enabledProtocols` will be passed to `SSLEngine::setEnabledProtocols()` + * - `clientAuth` will be passed to `SSLEngine::setWantClientAuth()` or `SSLEngine.setNeedClientAuth()`, respectively + * - `sslParameters` will be passed to `SSLEngine::setSSLParameters()` + */ +case class NegotiateNewSession( + enabledCipherSuites: Option[immutable.Seq[String]], + enabledProtocols: Option[immutable.Seq[String]], + clientAuth: Option[ClientAuth], + sslParameters: Option[SSLParameters]) extends SslTlsOutbound { + + /** + * Java API: Make a copy of this message with the given `enabledCipherSuites`. + */ + @varargs + def withCipherSuites(s: String*) = copy(enabledCipherSuites = Some(s.toList)) + + /** + * Java API: Make a copy of this message with the given `enabledProtocols`. + */ + @varargs + def withProtocols(p: String*) = copy(enabledProtocols = Some(p.toList)) + + /** + * Java API: Make a copy of this message with the given [[ClientAuth]] setting. + */ + def withClientAuth(ca: ClientAuth) = copy(clientAuth = Some(ca)) + + /** + * Java API: Make a copy of this message with the given [[SSLParameters]]. + */ + def withParameters(p: SSLParameters) = copy(sslParameters = Some(p)) +} + +object NegotiateNewSession extends NegotiateNewSession(None, None, None, None) { + /** + * Java API: obtain the default value (which will leave the SSLEngine’s + * settings unchanged). + */ + def withDefaults = this +} + +/** + * Send the given [[akka.util.ByteString]] across the encrypted session to the + * peer. + */ +case class SendBytes(bytes: ByteString) extends SslTlsOutbound + +/** + * An SSLEngine can either demand, allow or ignore its peer’s authentication + * (via certificates), where `Need` will fail the handshake if the peer does + * not provide valid credentials, `Want` allows the peer to send credentials + * and verifies them if provided, and `None` disables peer certificate + * verification. + * + * See the documentation for `SSLEngine::setWantClientAuth` for more + * information. + */ +sealed abstract class ClientAuth +object ClientAuth { + case object None extends ClientAuth + case object Want extends ClientAuth + case object Need extends ClientAuth + + def none: ClientAuth = None + def want: ClientAuth = Want + def need: ClientAuth = Need +} diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/BidiFlow.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/BidiFlow.scala index c02effab40..66d4e3f8da 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/BidiFlow.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/BidiFlow.scala @@ -115,11 +115,7 @@ final class BidiFlow[-I1, +O1, -I2, +O2, +Mat](private[stream] override val modu /** * Turn this BidiFlow around by 180 degrees, logically flipping it upside down in a protocol stack. */ - def reversed: BidiFlow[I2, O2, I1, O1, Mat] = { - val ins = shape.inlets - val outs = shape.outlets - new BidiFlow(module.replaceShape(shape.copyFromPorts(ins.reverse, outs.reverse))) - } + def reversed: BidiFlow[I2, O2, I1, O1, Mat] = new BidiFlow(module.replaceShape(shape.reversed)) override def withAttributes(attr: OperationAttributes): BidiFlow[I1, O1, I2, O2, Mat] = new BidiFlow(module.withAttributes(attr).wrap()) diff --git a/akka-stream/src/main/scala/akka/stream/ssl/SslTlsCipher.scala b/akka-stream/src/main/scala/akka/stream/ssl/SslTlsCipher.scala deleted file mode 100644 index 6857f719b8..0000000000 --- a/akka-stream/src/main/scala/akka/stream/ssl/SslTlsCipher.scala +++ /dev/null @@ -1,376 +0,0 @@ -/** - * Copyright (C) 2009-2014 Typesafe Inc. - */ - -package akka.stream.ssl - -import java.nio.ByteBuffer -import java.security.Principal -import java.security.cert.Certificate -import javax.net.ssl.SSLEngineResult.HandshakeStatus._ -import javax.net.ssl.SSLEngineResult.Status._ -import javax.net.ssl.SSLEngine -import javax.net.ssl.SSLEngineResult -import javax.net.ssl.SSLPeerUnverifiedException -import javax.net.ssl.SSLSession - -import akka.actor.Actor -import akka.actor.ActorLogging -import akka.actor.ActorRef -import akka.stream.ActorFlowMaterializerSettings -import akka.stream.impl._ -import akka.util.ByteString -import akka.util.ByteStringBuilder -import org.reactivestreams.Publisher -import org.reactivestreams.Subscriber - -import scala.annotation.tailrec - -object SslTlsCipher { - - /** - * An established SSL session. - */ - final case class InboundSession( - sessionInfo: SessionInfo, - data: Publisher[ByteString]) - - /** - * A request to establish an SSL session. - * FIXME Not used right now since there is only one session established - */ - final case class OutboundSession( - negotiation: SessionNegotiation, - data: Subscriber[ByteString]) - - /** - * Information about the established SSL session. - */ - final case class SessionInfo( - cipherSuite: String, - localCertificates: List[Certificate], - localPrincipal: Option[Principal], - peerCertificates: List[Certificate], - peerPrincipal: Option[Principal]) - - object SessionInfo { - - def apply(engine: SSLEngine): SessionInfo = - apply(engine.getSession) - - def apply(session: SSLSession): SessionInfo = { - val localCertificates = Option(session.getLocalCertificates).map { _.toList } getOrElse Nil - val localPrincipal = Option(session.getLocalPrincipal) - val peerCertificates = - try session.getPeerCertificates.toList - catch { case e: SSLPeerUnverifiedException ⇒ Nil } - val peerPrincipal = - try Option(session.getPeerPrincipal) - catch { case e: SSLPeerUnverifiedException ⇒ None } - SessionInfo(session.getCipherSuite, localCertificates, localPrincipal, peerCertificates, peerPrincipal) - } - } - - /** - * Information needed to establish an SSL session. - */ - final case class SessionNegotiation(engine: SSLEngine) -} - -final case class SslTlsCipher( - sessionInbound: Publisher[SslTlsCipher.InboundSession], - // FIXME We only have one session, and the SessionNegotiation is passed in via the constructor. - // This should really be a Subscriber[SslTlsCipher.OutboundSession] - plainTextOutbound: Subscriber[ByteString], - cipherTextInbound: Subscriber[ByteString], - cipherTextOutbound: Publisher[ByteString]) - -object SslTlsCipherActor { - val EmptyByteArray = Array.empty[Byte] - val EmptyByteBuffer = ByteBuffer.wrap(EmptyByteArray) -} - -class SslTlsCipherActor(val requester: ActorRef, val sessionNegotioation: SslTlsCipher.SessionNegotiation, tracing: Boolean) - extends Actor - with ActorLogging - with Pump - with MultiStreamOutputProcessorLike - with MultiStreamInputProcessorLike { - - override val subscriptionTimeoutSettings = ActorFlowMaterializerSettings(context.system).subscriptionTimeoutSettings - - def this(requester: ActorRef, sessionNegotioation: SslTlsCipher.SessionNegotiation) = - this(requester, sessionNegotioation, false) - - import MultiStreamInputProcessor.SubstreamSubscriber - import SslTlsCipherActor._ - - private var _nextId = 0L - protected def nextId(): Long = { _nextId += 1; _nextId } - override protected val inputBufferSize = 1 - - // The cipherTextInput (Subscriber[ByteString]) - val inboundCipherTextInput = createSubstreamInput() - - // The cipherTextOutput (Publisher[ByteString]) - val outboundCipherTextOutput = createSubstreamOutput() - - // The Publisher[SslTlsCipher.InboundSession] - // FIXME For now there is only one session ever exposed - val inboundSessionOutput = createSubstreamOutput() - - // The read side for the user (Publisher[ByteString]) - // FIXME For now there is only one session ever exposed - val inboundPlaintextOutput = createSubstreamOutput() - - // The write side for the user (Subscriber[ByteString]) - // FIXME For now there is only one session ever exposed - val outboundPlaintextInput = createSubstreamInput() - - // Plaintext bytes to be encrypted - var plaintextOutboundBytes = EmptyByteBuffer - val plaintextOutboundBytesPending = new TransferState { - override def isReady = plaintextOutboundBytes.hasRemaining - override def isCompleted = false - } - - // Encrypted bytes to be sent - val cipherTextOutboundBytes = new ByteStringBuilder - - // Encrypted bytes to be decrypted - var cipherTextInboundBytes = EmptyByteBuffer - val cipherTextInboundBytesPending = new TransferState { - override def isReady = cipherTextInboundBytes.hasRemaining - override def isCompleted = false - } - - // Plaintext bytes to be received - val plaintextInboundBytes = new ByteStringBuilder - - // FIXME: Change this into a pool of ByteBuffer later - // These are Nettys default values - // 16665 + 1024 (room for compressed data) + 1024 (for OpenJDK compatibility) - val temporaryBuffer = ByteBuffer.allocate(16665 + 2048) - - val engine: SSLEngine = sessionNegotioation.engine - - def doWrap(tempBuf: ByteBuffer): SSLEngineResult = { - tempBuf.clear() - if (tracing) log.debug("before wrap {}", plaintextOutboundBytes.remaining) - val result = engine.wrap(plaintextOutboundBytes, tempBuf) - if (tracing) log.debug("after wrap {}", plaintextOutboundBytes.remaining) - tempBuf.flip() - if (tempBuf.hasRemaining) { - val bs = ByteString(tempBuf) - if (tracing) log.debug("wrap Enqueue cipher bytes {}", bs) - cipherTextOutboundBytes ++= bs - } - result - } - - def doUnwrap(tempBuf: ByteBuffer): SSLEngineResult = { - tempBuf.clear() - if (tracing) log.debug("before unwrap {}", cipherTextInboundBytes.remaining) - val result = engine.unwrap(cipherTextInboundBytes, tempBuf) - if (tracing) log.debug("after unwrap {}", cipherTextInboundBytes.remaining) - tempBuf.flip() - if (tempBuf.hasRemaining) { - val bs = ByteString(tempBuf) - if (tracing) log.debug("unwrap Enqueue cipher bytes {}", bs) - plaintextInboundBytes ++= bs - } - result - } - - def enqueueCipherInputBytes(data: ByteString): Unit = { - cipherTextInboundBytes = - if (cipherTextInboundBytes.hasRemaining) { - val buffer = ByteBuffer.allocate(cipherTextInboundBytes.remaining + data.size) - buffer.put(cipherTextInboundBytes) - data.copyToBuffer(buffer) - buffer.flip() - buffer - } else data.toByteBuffer - } - - def writeCipherTextOutboundBytes() = { - if (cipherTextOutboundBytes.length > 0) { - val bs = cipherTextOutboundBytes.result() - cipherTextOutboundBytes.clear() - outboundCipherTextOutput.enqueueOutputElement(bs) - } - } - - def writePlaintextInboundBytes() = { - if (plaintextInboundBytes.length > 0) { - val bs = plaintextInboundBytes.result() - plaintextInboundBytes.clear() - inboundPlaintextOutput.enqueueOutputElement(bs) - } - } - - @tailrec - private def runDelegatedTasks(): Unit = { - val task = engine.getDelegatedTask - if (task != null) { - if (tracing) log.debug("Running delegated task {}", task) - task.run() - runDelegatedTasks() - } - } - - def publishSSLSessionEstablished(): Unit = { - import SslTlsCipher._ - val info = SessionInfo(engine) - val is = InboundSession(info, inboundPlaintextOutput.asInstanceOf[Publisher[ByteString]]) - if (tracing) log.debug("#### Handshake done!") - inboundSessionOutput.enqueueOutputElement(is) - } - - val unwrapPhase: TransferPhase = TransferPhase(inboundCipherTextInput.NeedsInput || cipherTextInboundBytesPending) { () ⇒ - if (tracing) log.debug("### UNWRAP") - if (inboundCipherTextInput.NeedsInput.isReady) - enqueueCipherInputBytes(inboundCipherTextInput.dequeueInputElement().asInstanceOf[ByteString]) - val result = doUnwrap(temporaryBuffer) - val rs = result.getStatus - if (tracing) log.debug("## UNWRAP {}", rs) - val hs = result.getHandshakeStatus - val next = rs match { - case OK ⇒ - handshakePhase(hs) - case CLOSED ⇒ if (!engine.isInboundDone) encryptionPhase else completedPhase - case BUFFER_OVERFLOW ⇒ throw new IllegalStateException // the SslBufferPool should make sure that buffers are never too small - case BUFFER_UNDERFLOW ⇒ throw new IllegalStateException // should never appear as a result of a wrap - } - nextPhase(next) - } - - val wrapPhase: TransferPhase = TransferPhase(outboundCipherTextOutput.NeedsDemand) { () ⇒ - if (tracing) log.debug("### WRAP") - val result = doWrap(temporaryBuffer) - val rs = result.getStatus - if (tracing) log.debug("## WRAP {}", rs) - val hs = result.getHandshakeStatus - val next = rs match { - case OK ⇒ - writeCipherTextOutboundBytes() - handshakePhase(hs) - case CLOSED ⇒ if (!engine.isInboundDone) decryptionPhase else completedPhase - case BUFFER_OVERFLOW ⇒ throw new IllegalStateException // the SslBufferPool should make sure that buffers are never too small - case BUFFER_UNDERFLOW ⇒ throw new IllegalStateException // should never appear as a result of a wrap - } - nextPhase(next) - } - - def handshakePhase(hs: SSLEngineResult.HandshakeStatus): TransferPhase = { - if (tracing) log.debug("### HS {}", hs) - hs match { - case status @ (NOT_HANDSHAKING | FINISHED) ⇒ - if (status == FINISHED) publishSSLSessionEstablished() - engineRunningPhase - case NEED_WRAP ⇒ wrapPhase - case NEED_UNWRAP ⇒ unwrapPhase - case NEED_TASK ⇒ - runDelegatedTasks() - engine.getHandshakeStatus match { - case NEED_WRAP ⇒ wrapPhase - case NEED_UNWRAP ⇒ unwrapPhase - case x ⇒ throw new IllegalStateException(s"Bad Handshake status $x") - } - } - } - - val waitForHandshakeStartPhase: TransferPhase = TransferPhase((outboundPlaintextInput.NeedsInput || inboundCipherTextInput.NeedsInput) && inboundSessionOutput.NeedsDemand) { () ⇒ - if (tracing) log.debug("#### Starting Handshake") - engine.beginHandshake() - nextPhase(handshakePhase(engine.getHandshakeStatus)) - } - - val encryptionInputAvailable = outboundPlaintextInput.NeedsInput || plaintextOutboundBytesPending - val decryptionInputAvailable = inboundCipherTextInput.NeedsInput || cipherTextInboundBytesPending - - val canEncrypt = encryptionInputAvailable && outboundCipherTextOutput.NeedsDemand - val canDecrypt = decryptionInputAvailable && inboundPlaintextOutput.NeedsDemand - - val engineRunningPhase: TransferPhase = TransferPhase(canEncrypt || canDecrypt) { () ⇒ - if (tracing) log.debug("#### Engine running") - if (canEncrypt.isExecutable) { - nextPhase(encryptionPhase) - } else { - nextPhase(decryptionPhase) - } - } - - val encryptionPhase: TransferPhase = TransferPhase(canEncrypt) { () ⇒ - if (tracing) log.debug("### Encrypting") - if (!plaintextOutboundBytesPending.isReady && outboundPlaintextInput.inputsAvailable) { - val elem = outboundPlaintextInput.dequeueInputElement().asInstanceOf[ByteString] - plaintextOutboundBytes = elem.asByteBuffer - } - val result = doWrap(temporaryBuffer) - val rs = result.getStatus - if (tracing) log.debug("## Encrypting {}", rs) - val hs = result.getHandshakeStatus - val next = rs match { - case OK ⇒ - if (hs == NOT_HANDSHAKING) { - writeCipherTextOutboundBytes() - engineRunningPhase - } else handshakePhase(hs) - case CLOSED ⇒ if (!engine.isInboundDone) decryptionPhase else completedPhase - case BUFFER_OVERFLOW ⇒ throw new IllegalStateException // the SslBufferPool should make sure that buffers are never too small - case BUFFER_UNDERFLOW ⇒ throw new IllegalStateException // should never appear as a result of a wrap - } - nextPhase(next) - } - - val decryptionPhase: TransferPhase = TransferPhase(canDecrypt) { () ⇒ - if (tracing) log.debug("### Decrypting") - if (inboundCipherTextInput.NeedsInput.isReady) { - val elem = inboundCipherTextInput.dequeueInputElement().asInstanceOf[ByteString] - enqueueCipherInputBytes(elem) - } - val result = doUnwrap(temporaryBuffer) - val rs = result.getStatus - if (tracing) log.debug("## Decrypting {}", rs) - val hs = result.getHandshakeStatus - val next = rs match { - case OK ⇒ - if (hs == NOT_HANDSHAKING) { - writePlaintextInboundBytes() - engineRunningPhase - } else handshakePhase(hs) - case CLOSED ⇒ if (!engine.isOutboundDone) encryptionPhase else completedPhase - case BUFFER_OVERFLOW ⇒ throw new IllegalStateException // the SslBufferPool should make sure that buffers are never too small - case BUFFER_UNDERFLOW ⇒ throw new IllegalStateException // should never appear as a result of a wrap - } - nextPhase(next) - } - - nextPhase(waitForHandshakeStartPhase) - - override def preStart() { - val plainTextInput = inboundSessionOutput.asInstanceOf[Publisher[SslTlsCipher.InboundSession]] - val plainTextOutput = new SubstreamSubscriber[ByteString](self, outboundPlaintextInput.key) - val cipherTextInput = new SubstreamSubscriber[ByteString](self, inboundCipherTextInput.key) - val cipherTextOutput = outboundCipherTextOutput.asInstanceOf[Publisher[ByteString]] - requester ! SslTlsCipher(plainTextInput, plainTextOutput, cipherTextInput, cipherTextOutput) - } - - override def receive = inputSubstreamManagement orElse outputSubstreamManagement - - protected def fail(e: Throwable): Unit = { - // FIXME: escalate to supervisor - if (tracing) log.debug("fail {} due to: {}", self, e.getMessage) - failInputs(e) - failOutputs(e) - context.stop(self) - } - - override protected def pumpFailed(e: Throwable): Unit = fail(e) - - override protected def pumpFinished(): Unit = { - finishInputs() - finishOutputs() - } -}