diff --git a/akka-stream-testkit/src/test/scala/akka/stream/testkit/StreamTestDefaultMailbox.scala b/akka-stream-testkit/src/test/scala/akka/stream/testkit/StreamTestDefaultMailbox.scala index 5506678d48..d05c709471 100644 --- a/akka-stream-testkit/src/test/scala/akka/stream/testkit/StreamTestDefaultMailbox.scala +++ b/akka-stream-testkit/src/test/scala/akka/stream/testkit/StreamTestDefaultMailbox.scala @@ -3,14 +3,12 @@ package akka.stream.testkit import akka.dispatch.ProducesMessageQueue import akka.dispatch.UnboundedMailbox import akka.dispatch.MessageQueue -import akka.stream.impl.io.StreamTcpManager import com.typesafe.config.Config import akka.actor.ActorSystem import akka.dispatch.MailboxType import akka.actor.ActorRef import akka.actor.ActorRefWithCell import akka.actor.Actor -import akka.stream.impl.io.TcpListenStreamActor /** * INTERNAL API @@ -28,8 +26,7 @@ private[akka] final case class StreamTestDefaultMailbox() extends MailboxType wi val actorClass = r.underlying.props.actorClass assert(actorClass != classOf[Actor], s"Don't use anonymous actor classes, actor class for $r was [${actorClass.getName}]") // StreamTcpManager is allowed to use another dispatcher - val specialCases: Set[Class[_]] = Set(classOf[StreamTcpManager], classOf[TcpListenStreamActor]) - assert(!actorClass.getName.startsWith("akka.stream.") || specialCases(actorClass), + assert(!actorClass.getName.startsWith("akka.stream."), s"$r with actor class [${actorClass.getName}] must not run on default dispatcher in tests. " + "Did you forget to define `props.withDispatcher` when creating the actor? " + "Or did you forget to configure the `akka.stream.materializer` setting accordingly or force the " + diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/StreamTcpManager.scala b/akka-stream/src/main/scala/akka/stream/impl/io/StreamTcpManager.scala deleted file mode 100644 index 18ab1a3b61..0000000000 --- a/akka-stream/src/main/scala/akka/stream/impl/io/StreamTcpManager.scala +++ /dev/null @@ -1,96 +0,0 @@ -/** - * Copyright (C) 2014 Typesafe Inc. - */ -package akka.stream.impl.io - -import java.net.InetSocketAddress -import java.net.URLEncoder -import scala.collection.immutable -import scala.concurrent.Future -import scala.concurrent.Promise -import scala.concurrent.duration.Duration -import scala.concurrent.duration.FiniteDuration -import akka.actor.{ NoSerializationVerificationNeeded, Actor, DeadLetterSuppression } -import akka.io.Inet.SocketOption -import akka.io.Tcp -import akka.stream.ActorMaterializerSettings -import akka.stream.impl.ActorProcessor -import akka.stream.impl.ActorPublisher -import akka.stream.scaladsl.{ Tcp ⇒ StreamTcp } -import akka.util.ByteString -import org.reactivestreams.Processor -import org.reactivestreams.Subscriber - -/** - * INTERNAL API - */ -private[akka] object StreamTcpManager { - /** - * INTERNAL API - */ - private[akka] final case class Connect( - processorPromise: Promise[Processor[ByteString, ByteString]], - localAddressPromise: Promise[InetSocketAddress], - remoteAddress: InetSocketAddress, - localAddress: Option[InetSocketAddress], - halfClose: Boolean, - options: immutable.Traversable[SocketOption], - connectTimeout: Duration, - idleTimeout: Duration) - extends DeadLetterSuppression with NoSerializationVerificationNeeded - - /** - * INTERNAL API - */ - private[akka] final case class Bind( - localAddressPromise: Promise[InetSocketAddress], - unbindPromise: Promise[() ⇒ Future[Unit]], - flowSubscriber: Subscriber[StreamTcp.IncomingConnection], - endpoint: InetSocketAddress, - backlog: Int, - halfClose: Boolean, - options: immutable.Traversable[SocketOption], - idleTimeout: Duration) - extends DeadLetterSuppression with NoSerializationVerificationNeeded - - /** - * INTERNAL API - */ - private[akka] final case class ExposedProcessor(processor: Processor[ByteString, ByteString]) - extends DeadLetterSuppression with NoSerializationVerificationNeeded - -} - -/** - * INTERNAL API - */ -private[akka] class StreamTcpManager extends Actor { - import StreamTcpManager._ - - var nameCounter = 0 - def encName(prefix: String, endpoint: InetSocketAddress) = { - nameCounter += 1 - s"$prefix-$nameCounter-${URLEncoder.encode(endpoint.toString, "utf-8")}" - } - - def receive: Receive = { - case Connect(processorPromise, localAddressPromise, remoteAddress, localAddress, halfClose, options, connectTimeout, idleTimeout) ⇒ - val connTimeout = connectTimeout match { - case x: FiniteDuration ⇒ Some(x) - case _ ⇒ None - } - val processorActor = context.actorOf(TcpStreamActor.outboundProps(processorPromise, localAddressPromise, halfClose, idleTimeout, - Tcp.Connect(remoteAddress, localAddress, options, connTimeout, pullMode = true), - materializerSettings = ActorMaterializerSettings(context.system)), name = encName("client", remoteAddress)) - processorActor ! ExposedProcessor(ActorProcessor[ByteString, ByteString](processorActor)) - - case Bind(localAddressPromise, unbindPromise, flowSubscriber, endpoint, backlog, halfClose, options, idleTimeout) ⇒ - val props = TcpListenStreamActor.props(localAddressPromise, unbindPromise, flowSubscriber, halfClose, idleTimeout, - Tcp.Bind(context.system.deadLetters, endpoint, backlog, options, pullMode = true), - ActorMaterializerSettings(context.system)) - .withDispatcher(context.props.dispatcher) - val publisherActor = context.actorOf(props, name = encName("server", endpoint)) - // this sends the ExposedPublisher message to the publisher actor automatically - ActorPublisher[Any](publisherActor) - } -} diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/TcpConnectionStream.scala b/akka-stream/src/main/scala/akka/stream/impl/io/TcpConnectionStream.scala deleted file mode 100644 index cf0b0b887c..0000000000 --- a/akka-stream/src/main/scala/akka/stream/impl/io/TcpConnectionStream.scala +++ /dev/null @@ -1,345 +0,0 @@ -/** - * Copyright (C) 2009-2014 Typesafe Inc. - */ -package akka.stream.impl.io - -import java.net.InetSocketAddress -import akka.io.{ IO, Tcp } -import scala.concurrent.Promise -import akka.actor._ -import akka.util.ByteString -import akka.io.Tcp._ -import akka.stream.{ AbruptTerminationException, ActorMaterializerSettings, StreamTcpException } -import org.reactivestreams.Processor -import akka.stream.impl._ - -import scala.concurrent.duration.Duration -import scala.util.control.NoStackTrace - -/** - * INTERNAL API - */ -private[akka] object TcpStreamActor { - case object WriteAck extends Tcp.Event - - def outboundProps(processorPromise: Promise[Processor[ByteString, ByteString]], - localAddressPromise: Promise[InetSocketAddress], - halfClose: Boolean, - idleTimeout: Duration, - connectCmd: Connect, - materializerSettings: ActorMaterializerSettings): Props = - Props(new OutboundTcpStreamActor(processorPromise, localAddressPromise, halfClose, idleTimeout, connectCmd, - materializerSettings)).withDispatcher(materializerSettings.dispatcher).withDeploy(Deploy.local) - - def inboundProps(connection: ActorRef, halfClose: Boolean, settings: ActorMaterializerSettings): Props = - Props(new InboundTcpStreamActor(connection, halfClose, settings)).withDispatcher(settings.dispatcher).withDeploy(Deploy.local) - - case object SubscriptionTimeout extends NoSerializationVerificationNeeded -} - -/** - * INTERNAL API - */ -private[akka] abstract class TcpStreamActor(val settings: ActorMaterializerSettings, halfClose: Boolean) extends Actor - with ActorLogging { - - import TcpStreamActor._ - - val primaryInputs: Inputs = new BatchingInputBuffer(settings.initialInputBufferSize, writePump) { - override def inputOnError(e: Throwable): Unit = fail(e) - } - - val primaryOutputs: SimpleOutputs = new SimpleOutputs(self, readPump) - - def fullClose: Boolean = !halfClose - - object tcpInputs extends DefaultInputTransferStates { - private var closed: Boolean = false - private var pendingElement: ByteString = null - private var connection: ActorRef = _ - - val subreceive = new SubReceive(Actor.emptyBehavior) - - def setConnection(c: ActorRef): Unit = { - connection = c - // Prefetch - c ! ResumeReading - subreceive.become(handleRead) - readPump.pump() - } - - def handleRead: Receive = { - case Received(data) ⇒ - if (closed) connection ! ResumeReading - else { - pendingElement = data - readPump.pump() - } - case ConfirmedClosed ⇒ - cancel() - readPump.pump() - case PeerClosed ⇒ - cancel() - readPump.pump() - } - - override def inputsAvailable: Boolean = pendingElement ne null - override def inputsDepleted: Boolean = closed && !inputsAvailable - override def isClosed: Boolean = closed - - override def cancel(): Unit = { - if (!closed) { - closed = true - pendingElement = null - if (!tcpOutputs.isFlushed && (connection ne null)) connection ! ResumeReading - } - } - - override def dequeueInputElement(): Any = { - val elem = pendingElement - pendingElement = null - connection ! ResumeReading - elem - } - - } - - object tcpOutputs extends DefaultOutputTransferStates { - private var closed: Boolean = false - private var lastWriteAcked = true - private var connection: ActorRef = _ - - def isClosed: Boolean = closed - // Full-close mode needs to wait for the last write Ack before sending Close to avoid doing a connection reset - def isFlushed: Boolean = closed && (halfClose || lastWriteAcked) - - private def initialized: Boolean = connection ne null - - def setConnection(c: ActorRef): Unit = { - connection = c - writePump.pump() - subreceive.become(handleWrite) - } - - val subreceive = new SubReceive(Actor.emptyBehavior) - - def handleWrite: Receive = { - case WriteAck ⇒ - lastWriteAcked = true - if (fullClose && closed) { - // Finish the closing after the last write has been flushed in full close mode. - connection ! Close - tryShutdown() - } - writePump.pump() - - } - - override def error(e: Throwable): Unit = { - if (!closed && initialized) connection ! Abort - closed = true - } - - override def complete(): Unit = { - if (!closed && initialized) { - closed = true - - if (halfClose) { - if (tcpInputs.isClosed) { - // Reading has stopped, either because of cancel, or PeerClosed, just Close now - connection ! Close - tryShutdown() - } else { - // We still read, so we only close the write side - connection ! ConfirmedClose - } - } else { - if (lastWriteAcked) { - // No pending writes, close now - connection ! Close - tryShutdown() - } - // Else wait for final Ack (see handleWrite) - } - } - } - - override def cancel(): Unit = complete() - - override def enqueueOutputElement(elem: Any): Unit = { - ReactiveStreamsCompliance.requireNonNullElement(elem) - connection ! Write(elem.asInstanceOf[ByteString], WriteAck) - lastWriteAcked = false - } - - override def demandAvailable: Boolean = lastWriteAcked - } - - object writePump extends Pump { - - def running = TransferPhase(primaryInputs.NeedsInput && tcpOutputs.NeedsDemand) { () ⇒ - var batch = ByteString.empty - while (primaryInputs.inputsAvailable) batch ++= primaryInputs.dequeueInputElement().asInstanceOf[ByteString] - tcpOutputs.enqueueOutputElement(batch) - } - - override protected def pumpFinished(): Unit = { - if (fullClose) { - // In full close mode we shut down the read size immediately once the write side is finished - tcpInputs.cancel() - primaryOutputs.complete() - readPump.pump() - } - tcpOutputs.complete() - primaryInputs.cancel() - tryShutdown() - } - override protected def pumpFailed(e: Throwable): Unit = fail(e) - } - - object readPump extends Pump { - - def running = TransferPhase(tcpInputs.NeedsInput && primaryOutputs.NeedsDemand) { () ⇒ - primaryOutputs.enqueueOutputElement(tcpInputs.dequeueInputElement()) - } - - override protected def pumpFinished(): Unit = { - tcpInputs.cancel() - primaryOutputs.complete() - tryShutdown() - } - override protected def pumpFailed(e: Throwable): Unit = fail(e) - } - - final override def receive = new ExposedPublisherReceive(activeReceive, unhandled) { - override def receiveExposedPublisher(ep: ExposedPublisher): Unit = { - import context.dispatcher - primaryOutputs.subreceive(ep) - subscriptionTimer = Some( - context.system.scheduler.scheduleOnce( - settings.subscriptionTimeoutSettings.timeout, - self, - SubscriptionTimeout)) - - context become activeReceive - } - } - - def activeReceive = - primaryInputs.subreceive orElse - primaryOutputs.subreceive orElse - tcpInputs.subreceive orElse - tcpOutputs.subreceive orElse - commonCloseHandling orElse - handleSubscriptionTimeout - - def commonCloseHandling: Receive = { - case Terminated(_) ⇒ fail(new StreamTcpException("The connection actor has terminated. Stopping now.")) - case Closed ⇒ - tcpInputs.cancel() - tcpOutputs.complete() - writePump.pump() - readPump.pump() - case ErrorClosed(cause) ⇒ fail(new StreamTcpException(s"The connection closed with error $cause")) - case CommandFailed(cmd) ⇒ fail(new StreamTcpException(s"Tcp command [$cmd] failed")) - case Aborted ⇒ fail(new StreamTcpException("The connection has been aborted")) - } - - def handleSubscriptionTimeout: Receive = { - case SubscriptionTimeout ⇒ - val millis = settings.subscriptionTimeoutSettings.timeout.toMillis - if (!primaryOutputs.isSubscribed) { - fail(new SubscriptionTimeoutException(s"Publisher was not attached to upstream within deadline ($millis) ms") with NoStackTrace) - context.stop(self) - } - } - - readPump.nextPhase(readPump.running) - writePump.nextPhase(writePump.running) - - var subscriptionTimer: Option[Cancellable] = None - - def fail(e: Throwable): Unit = { - if (settings.debugLogging) - log.debug("fail due to: {}", e.getMessage) - tcpInputs.cancel() - tcpOutputs.error(e) - primaryInputs.cancel() - primaryOutputs.error(e) - tryShutdown() - } - - def tryShutdown(): Unit = - if (primaryInputs.isClosed && tcpInputs.isClosed && tcpOutputs.isFlushed) - context.stop(self) - - override def postStop(): Unit = { - // Close if it has not yet been done - val abruptTermination = AbruptTerminationException(self) - tcpInputs.cancel() - tcpOutputs.error(abruptTermination) - primaryInputs.cancel() - primaryOutputs.error(abruptTermination) - subscriptionTimer.foreach(_.cancel()) - super.postStop() // Remember, we have a Stash - } -} - -/** - * INTERNAL API - */ -private[akka] class InboundTcpStreamActor( - val connection: ActorRef, _halfClose: Boolean, _settings: ActorMaterializerSettings) - extends TcpStreamActor(_settings, _halfClose) { - context.watch(connection) - - connection ! Register(self, keepOpenOnPeerClosed = true, useResumeWriting = false) - tcpInputs.setConnection(connection) - tcpOutputs.setConnection(connection) -} - -/** - * INTERNAL API - */ -private[akka] class OutboundTcpStreamActor(processorPromise: Promise[Processor[ByteString, ByteString]], - localAddressPromise: Promise[InetSocketAddress], - _halfClose: Boolean, - idleTimeout: Duration, - val connectCmd: Connect, _settings: ActorMaterializerSettings) - extends TcpStreamActor(_settings, _halfClose) { - import context.system - - val initSteps = new SubReceive(waitingExposedProcessor) - - override def activeReceive = initSteps orElse super.activeReceive - - def waitingExposedProcessor: Receive = { - case StreamTcpManager.ExposedProcessor(processor) ⇒ - IO(Tcp) ! connectCmd - initSteps.become(waitConnection(processor)) - } - - def waitConnection(exposedProcessor: Processor[ByteString, ByteString]): Receive = { - case Connected(remoteAddress, localAddress) ⇒ - val connection = sender() - context.watch(connection) - connection ! Register(self, keepOpenOnPeerClosed = true, useResumeWriting = false) - tcpOutputs.setConnection(connection) - tcpInputs.setConnection(connection) - localAddressPromise.success(localAddress) - processorPromise.success(exposedProcessor) - initSteps.become(Actor.emptyBehavior) - - case f: CommandFailed ⇒ - val ex = new StreamTcpException("Connection failed.") - localAddressPromise.failure(ex) - processorPromise.failure(ex) - fail(ex) - } - - override def fail(e: Throwable): Unit = { - processorPromise.tryFailure(e) - localAddressPromise.tryFailure(e) - super.fail(e) - } -} diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/TcpListenStreamActor.scala b/akka-stream/src/main/scala/akka/stream/impl/io/TcpListenStreamActor.scala deleted file mode 100644 index b304c8cc85..0000000000 --- a/akka-stream/src/main/scala/akka/stream/impl/io/TcpListenStreamActor.scala +++ /dev/null @@ -1,183 +0,0 @@ -/** - * Copyright (C) 2009-2014 Typesafe Inc. - */ -package akka.stream.impl.io - -import java.net.InetSocketAddress - -import akka.actor._ -import akka.io.Tcp._ -import akka.io.{ IO, Tcp } -import akka.stream.impl._ -import akka.stream.scaladsl.{ Tcp ⇒ StreamTcp, BidiFlow, Flow } -import akka.stream.{ ActorMaterializerSettings, BindFailedException, ConnectionException } -import akka.util.ByteString -import org.reactivestreams.Subscriber - -import scala.concurrent.duration.Duration -import scala.concurrent.{ Future, Promise } - -/** - * INTERNAL API - */ -private[akka] object TcpListenStreamActor { - def props(localAddressPromise: Promise[InetSocketAddress], - unbindPromise: Promise[() ⇒ Future[Unit]], - flowSubscriber: Subscriber[StreamTcp.IncomingConnection], - halfClose: Boolean, - idleTimeout: Duration, - bindCmd: Tcp.Bind, materializerSettings: ActorMaterializerSettings): Props = { - Props(new TcpListenStreamActor(localAddressPromise, unbindPromise, flowSubscriber, halfClose, bindCmd, idleTimeout, materializerSettings)) - .withDeploy(Deploy.local) - } -} - -/** - * INTERNAL API - */ -private[akka] class TcpListenStreamActor(localAddressPromise: Promise[InetSocketAddress], - unbindPromise: Promise[() ⇒ Future[Unit]], - flowSubscriber: Subscriber[StreamTcp.IncomingConnection], - halfClose: Boolean, - bindCmd: Tcp.Bind, - idleTimeout: Duration, - settings: ActorMaterializerSettings) extends Actor - with Pump with ActorLogging { - import ReactiveStreamsCompliance._ - import context.system - - object primaryOutputs extends SimpleOutputs(self, pump = this) { - - override def waitingExposedPublisher: Actor.Receive = { - case ExposedPublisher(publisher) ⇒ - exposedPublisher = publisher - IO(Tcp) ! bindCmd.copy(handler = self) - subreceive.become(downstreamRunning) - case other ⇒ - throw new IllegalStateException(s"The first message must be ExposedPublisher but was [$other]") - } - - def getExposedPublisher = exposedPublisher - } - - private val unboundPromise = Promise[Unit]() - private var finished = false - - override protected def pumpFinished(): Unit = { - if (!finished) { - finished = true - incomingConnections.cancel() - primaryOutputs.complete() - // Stop only after all already accepted connections have been shut down - if (context.children.isEmpty) context.stop(self) - } - } - - override protected def pumpFailed(e: Throwable): Unit = fail(e) - - val incomingConnections: Inputs = new DefaultInputTransferStates { - var listener: ActorRef = _ - private var closed: Boolean = false - private var pendingConnection: (Connected, ActorRef) = null - - def waitBound: Receive = { - case Bound(localAddress) ⇒ - listener = sender() - nextPhase(runningPhase) - listener ! ResumeAccepting(1) - val target = self - localAddressPromise.success(localAddress) - unbindPromise.success(() ⇒ { target ! Unbind; unboundPromise.future }) - primaryOutputs.getExposedPublisher.subscribe(flowSubscriber.asInstanceOf[Subscriber[Any]]) - subreceive.become(running) - case f: CommandFailed ⇒ - val ex = BindFailedException - localAddressPromise.failure(ex) - unbindPromise.success(() ⇒ Future.successful(())) - try { - tryOnSubscribe(flowSubscriber, CancelledSubscription) - tryOnError(flowSubscriber, ex) - } finally fail(ex) - } - - def running: Receive = { - case c: Connected ⇒ - pendingConnection = (c, sender()) - pump() - case f: CommandFailed ⇒ - val ex = new ConnectionException(s"Command [${f.cmd}] failed") - if (f.cmd.isInstanceOf[Unbind.type]) unboundPromise.tryFailure(BindFailedException) - fail(ex) - case Unbind ⇒ - if (!closed && listener != null) listener ! Unbind - listener = null - pump() - case Unbound ⇒ // If we're unbound then just shut down - cancel() - unboundPromise.trySuccess(()) - pump() - } - - override val subreceive = new SubReceive(waitBound) - - override def inputsAvailable: Boolean = pendingConnection ne null - override def inputsDepleted: Boolean = closed && !inputsAvailable - override def isClosed: Boolean = closed - override def cancel(): Unit = { - if (!closed && listener != null) listener ! Unbind - closed = true - pendingConnection = null - pump() - } - override def dequeueInputElement(): Any = { - val elem = pendingConnection - pendingConnection = null - listener ! ResumeAccepting(1) - elem - } - } - - final override def receive = new ExposedPublisherReceive(activeReceive, unhandled) { - override def receiveExposedPublisher(ep: ExposedPublisher): Unit = { - primaryOutputs.subreceive(ep) - context become activeReceive - } - } - - def activeReceive: Actor.Receive = primaryOutputs.subreceive orElse incomingConnections.subreceive orElse { - case Terminated(_) ⇒ - // If the Source is cancelled, and this was our last child, stop ourselves - if (incomingConnections.isClosed && context.children.isEmpty) context.stop(self) - } - - def runningPhase = TransferPhase(primaryOutputs.NeedsDemand && incomingConnections.NeedsInput) { () ⇒ - val (connected: Connected, connection: ActorRef) = incomingConnections.dequeueInputElement() - val tcpStreamActor = context.watch(context.actorOf(TcpStreamActor.inboundProps(connection, halfClose, settings))) - val processor = ActorProcessor[ByteString, ByteString](tcpStreamActor) - - import scala.concurrent.duration.FiniteDuration - val handler = (idleTimeout match { - case d: FiniteDuration ⇒ Flow[ByteString].join(BidiFlow.bidirectionalIdleTimeout[ByteString, ByteString](d)) - case _ ⇒ Flow[ByteString] - }).via(Flow.fromProcessor(() ⇒ processor)) - - val conn = StreamTcp.IncomingConnection( - connected.localAddress, - connected.remoteAddress, - handler) - primaryOutputs.enqueueOutputElement(conn) - } - - override def postStop(): Unit = { - unboundPromise.trySuccess(()) - primaryOutputs.complete() - super.postStop() - } - - def fail(e: Throwable): Unit = { - if (settings.debugLogging) - log.debug("fail due to: {}", e.getMessage) - incomingConnections.cancel() - primaryOutputs.error(e) - } -} diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala b/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala new file mode 100644 index 0000000000..30db86af67 --- /dev/null +++ b/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala @@ -0,0 +1,320 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ +package akka.stream.impl.io + +import java.net.InetSocketAddress +import java.util.concurrent.atomic.AtomicBoolean + +import akka.actor.{ ActorRef, Terminated } +import akka.dispatch.ExecutionContexts +import akka.io.Inet.SocketOption +import akka.io.Tcp +import akka.io.Tcp._ +import akka.stream._ +import akka.stream.impl.ReactiveStreamsCompliance +import akka.stream.impl.fusing.GraphStages.Detacher +import akka.stream.scaladsl.Tcp.{ OutgoingConnection, ServerBinding } +import akka.stream.scaladsl.{ BidiFlow, Flow, Tcp ⇒ StreamTcp } +import akka.stream.stage.GraphStageLogic.StageActorRef +import akka.stream.stage._ +import akka.util.ByteString + +import scala.collection.immutable +import scala.concurrent.duration.{ Duration, FiniteDuration } +import scala.concurrent.{ Future, Promise } + +/** + * INTERNAL API + */ +private[stream] class ConnectionSourceStage(val tcpManager: ActorRef, + val endpoint: InetSocketAddress, + val backlog: Int, + val options: immutable.Traversable[SocketOption], + val halfClose: Boolean, + val idleTimeout: Duration) + extends GraphStageWithMaterializedValue[SourceShape[StreamTcp.IncomingConnection], Future[StreamTcp.ServerBinding]] { + + val out: Outlet[StreamTcp.IncomingConnection] = Outlet("IncomingConnections.out") + val shape: SourceShape[StreamTcp.IncomingConnection] = SourceShape(out) + + private val BindTimer = "BindTimer" + + // TODO: Timeout on bind + override def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, Future[ServerBinding]) = { + val bindingPromise = Promise[ServerBinding] + + val logic = new TimerGraphStageLogic(shape) { + implicit var self: StageActorRef = _ + var listener: ActorRef = _ + var unbindPromise = Promise[Unit]() + + override def preStart(): Unit = { + self = getStageActorRef(receive) + tcpManager ! Tcp.Bind(self, endpoint, backlog, options, pullMode = true) + } + + private def receive(evt: (ActorRef, Any)): Unit = { + val sender = evt._1 + val msg = evt._2 + msg match { + case Bound(localAddress) ⇒ + listener = sender + self.watch(listener) + if (isAvailable(out)) listener ! ResumeAccepting(1) + val target = self + bindingPromise.success(ServerBinding(localAddress)(() ⇒ { target ! Unbind; unbindPromise.future })) + case f: CommandFailed ⇒ + val ex = BindFailedException + bindingPromise.failure(ex) + unbindPromise.success(() ⇒ Future.successful(())) + failStage(ex) + case c: Connected ⇒ + push(out, connectionFor(c, sender)) + case Unbind ⇒ + if (!isClosed(out) && (listener ne null)) listener ! Unbind + case Unbound ⇒ // If we're unbound then just shut down + completeStage() + case Terminated(ref) if ref == listener ⇒ + failStage(new IllegalStateException("IO Listener actor terminated unexpectedly")) + } + } + + setHandler(out, new OutHandler { + override def onPull(): Unit = { + // Ignore if still binding + if (listener ne null) listener ! ResumeAccepting(1) + } + + override def onDownstreamFinish(): Unit = tryUnbind() + }) + + private def connectionFor(connected: Connected, connection: ActorRef): StreamTcp.IncomingConnection = { + val tcpFlow = + Flow.fromGraph(new IncomingConnectionStage(connection, connected.remoteAddress, halfClose)) + .via(new Detacher[ByteString]) // must read ahead for proper completions + + // FIXME: Previous code was wrong, must add new tests + val handler = idleTimeout match { + case d: FiniteDuration ⇒ tcpFlow.join(BidiFlow.bidirectionalIdleTimeout[ByteString, ByteString](d)) + case _ ⇒ tcpFlow + } + + StreamTcp.IncomingConnection( + connected.localAddress, + connected.remoteAddress, + handler) + } + + private def tryUnbind(): Unit = { + if (listener ne null) listener ! Unbind + else completeStage() + } + + override def postStop(): Unit = { + unbindPromise.trySuccess(()) + bindingPromise.tryFailure(new NoSuchElementException("Binding was unbound before it was completely finished")) + } + } + + (logic, bindingPromise.future) + } + +} + +/** + * INTERNAL API + */ +private[stream] object TcpConnectionStage { + case object WriteAck extends Tcp.Event + + trait TcpRole { + def halfClose: Boolean + } + case class Outbound( + manager: ActorRef, + connectCmd: Connect, + localAddressPromise: Promise[InetSocketAddress], + halfClose: Boolean) extends TcpRole + case class Inbound(connection: ActorRef, halfClose: Boolean) extends TcpRole + + /* + * This is a *non-deatched* design, i.e. this does not prefetch itself any of the inputs. It relies on downstream + * stages to provide the necessary prefetch on `bytesOut` and the framework to do the proper prefetch in the buffer + * backing `bytesIn`. If prefetch on `bytesOut` is required (i.e. user stages cannot be trusted) then it is better + * to attach an extra, fused buffer to the end of this flow. Keeping this stage non-detached makes it much simpler and + * easier to maintain and understand. + */ + class TcpStreamLogic(val shape: FlowShape[ByteString, ByteString], val role: TcpRole) extends GraphStageLogic(shape) { + implicit private var self: StageActorRef = _ + + private def bytesIn = shape.inlet + private def bytesOut = shape.outlet + private var connection: ActorRef = _ + + // No reading until role have been decided + setHandler(bytesOut, new OutHandler { + override def onPull(): Unit = () + }) + + override def preStart(): Unit = role match { + case Inbound(conn, _) ⇒ + setHandler(bytesOut, readHandler) + self = getStageActorRef(connected) + connection = conn + self.watch(connection) + connection ! Register(self, keepOpenOnPeerClosed = true, useResumeWriting = false) + pull(bytesIn) + case ob @ Outbound(manager, cmd, _, _) ⇒ + self = getStageActorRef(connecting(ob)) + self.watch(manager) + manager ! cmd + } + + private def connecting(ob: Outbound)(evt: (ActorRef, Any)): Unit = { + val sender = evt._1 + val msg = evt._2 + msg match { + case Terminated(_) ⇒ failStage(new StreamTcpException("The IO manager actor (TCP) has terminated. Stopping now.")) + case CommandFailed(cmd) ⇒ failStage(new StreamTcpException(s"Tcp command [$cmd] failed")) + case c: Connected ⇒ + role.asInstanceOf[Outbound].localAddressPromise.success(c.localAddress) + connection = sender + setHandler(bytesOut, readHandler) + self.unwatch(ob.manager) + self = getStageActorRef(connected) + self.watch(connection) + connection ! Register(self, keepOpenOnPeerClosed = true, useResumeWriting = false) + if (isAvailable(bytesOut)) connection ! ResumeReading + pull(bytesIn) + } + } + + private def connected(evt: (ActorRef, Any)): Unit = { + val sender = evt._1 + val msg = evt._2 + msg match { + case Terminated(_) ⇒ + failStage(new StreamTcpException("The connection actor has terminated. Stopping now.")) + case CommandFailed(cmd) ⇒ failStage(new StreamTcpException(s"Tcp command [$cmd] failed")) + + case ErrorClosed(cause) ⇒ failStage(new StreamTcpException(s"The connection closed with error $cause")) + case Aborted ⇒ failStage(new StreamTcpException("The connection has been aborted")) + case Closed ⇒ completeStage() + case ConfirmedClosed ⇒ completeStage() + case PeerClosed ⇒ complete(bytesOut) + + case Received(data) ⇒ + // Keep on reading even when closed. There is no "close-read-side" in TCP + if (isClosed(bytesOut)) connection ! ResumeReading + else push(bytesOut, data) + + case WriteAck ⇒ if (!isClosed(bytesIn)) pull(bytesIn) + } + } + + val readHandler = new OutHandler { + override def onPull(): Unit = { + connection ! ResumeReading + } + + override def onDownstreamFinish(): Unit = { + if (!isClosed(bytesIn)) connection ! ResumeReading + else { + connection ! Abort + completeStage() + } + } + } + + setHandler(bytesIn, new InHandler { + override def onPush(): Unit = { + val elem = grab(bytesIn) + ReactiveStreamsCompliance.requireNonNullElement(elem) + connection ! Write(elem.asInstanceOf[ByteString], WriteAck) + } + + override def onUpstreamFinish(): Unit = { + // Reading has stopped before, either because of cancel, or PeerClosed, so just Close now + // (or half-close is turned off) + if (isClosed(bytesOut) || !role.halfClose) connection ! Close + // We still read, so we only close the write side + else connection ! ConfirmedClose + } + + override def onUpstreamFailure(ex: Throwable): Unit = { + connection ! Abort + } + }) + + override def keepGoingAfterAllPortsClosed: Boolean = true + + override def postStop(): Unit = role match { + case Outbound(_, _, localAddressPromise, _) ⇒ + // Fail if has not been completed with an address eariler + localAddressPromise.tryFailure(new StreamTcpException("Connection failed.")) + case _ ⇒ + } + } +} + +/** + * INTERNAL API + */ +private[stream] class IncomingConnectionStage(connection: ActorRef, remoteAddress: InetSocketAddress, halfClose: Boolean) + extends GraphStage[FlowShape[ByteString, ByteString]] { + import TcpConnectionStage._ + + private val hasBeenCreated = new AtomicBoolean(false) + + val bytesIn: Inlet[ByteString] = Inlet("IncomingTCP.in") + val bytesOut: Outlet[ByteString] = Outlet("IncomingTCP.out") + val shape: FlowShape[ByteString, ByteString] = FlowShape(bytesIn, bytesOut) + + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = { + if (hasBeenCreated.get) throw new IllegalStateException("Cannot materialize an incoming connection Flow twice.") + hasBeenCreated.set(true) + + new TcpStreamLogic(shape, Inbound(connection, halfClose)) + } + + override def toString = s"TCP-from($remoteAddress)" +} + +/** + * INTERNAL API + */ +private[stream] class OutgoingConnectionStage(manager: ActorRef, + remoteAddress: InetSocketAddress, + localAddress: Option[InetSocketAddress] = None, + options: immutable.Traversable[SocketOption] = Nil, + halfClose: Boolean = true, + connectTimeout: Duration = Duration.Inf) + + extends GraphStageWithMaterializedValue[FlowShape[ByteString, ByteString], Future[StreamTcp.OutgoingConnection]] { + import TcpConnectionStage._ + + val bytesIn: Inlet[ByteString] = Inlet("IncomingTCP.in") + val bytesOut: Outlet[ByteString] = Outlet("IncomingTCP.out") + val shape: FlowShape[ByteString, ByteString] = FlowShape(bytesIn, bytesOut) + + override def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, Future[StreamTcp.OutgoingConnection]) = { + + // FIXME: A method like this would make soo much sense on Duration (i.e. toOption) + val connTimeout = connectTimeout match { + case x: FiniteDuration ⇒ Some(x) + case _ ⇒ None + } + + val localAddressPromise = Promise[InetSocketAddress] + val logic = new TcpStreamLogic(shape, Outbound( + manager, + Connect(remoteAddress, localAddress, options, connTimeout, pullMode = true), + localAddressPromise, + halfClose)) + + (logic, localAddressPromise.future.map(OutgoingConnection(remoteAddress, _))(ExecutionContexts.sameThreadExecutionContext)) + } + + override def toString = s"TCP-to($remoteAddress)" +} \ No newline at end of file diff --git a/akka-stream/src/main/scala/akka/stream/scaladsl/Tcp.scala b/akka-stream/src/main/scala/akka/stream/scaladsl/Tcp.scala index dd218b11af..37a1a2b681 100644 --- a/akka-stream/src/main/scala/akka/stream/scaladsl/Tcp.scala +++ b/akka-stream/src/main/scala/akka/stream/scaladsl/Tcp.scala @@ -7,18 +7,15 @@ import java.net.InetSocketAddress import akka.actor._ import akka.io.Inet.SocketOption -import akka.io.{ Tcp ⇒ IoTcp } +import akka.io.{ IO, Tcp ⇒ IoTcp } import akka.stream._ -import akka.stream.impl.ReactiveStreamsCompliance._ -import akka.stream.impl.StreamLayout.Module -import akka.stream.impl._ -import akka.stream.impl.io.{ DelayedInitProcessor, StreamTcpManager } +import akka.stream.impl.fusing.GraphStages.Detacher +import akka.stream.impl.io.{ ConnectionSourceStage, OutgoingConnectionStage } import akka.util.ByteString -import org.reactivestreams.{ Processor, Publisher, Subscriber } import scala.collection.immutable +import scala.concurrent.Future import scala.concurrent.duration.{ Duration, FiniteDuration } -import scala.concurrent.{ Future, Promise } object Tcp extends ExtensionId[Tcp] with ExtensionIdProvider { @@ -65,53 +62,6 @@ object Tcp extends ExtensionId[Tcp] with ExtensionIdProvider { class Tcp(system: ExtendedActorSystem) extends akka.actor.Extension { import Tcp._ - private val manager: ActorRef = system.systemActorOf(Props[StreamTcpManager] - .withDispatcher(IoTcp(system).Settings.ManagementDispatcher).withDeploy(Deploy.local), name = "IO-TCP-STREAM") - - private class BindSource( - val endpoint: InetSocketAddress, - val backlog: Int, - val options: immutable.Traversable[SocketOption], - val halfClose: Boolean, - val idleTimeout: Duration = Duration.Inf, - val attributes: Attributes, - _shape: SourceShape[IncomingConnection]) extends SourceModule[IncomingConnection, Future[ServerBinding]](_shape) { - - override def create(context: MaterializationContext): (Publisher[IncomingConnection], Future[ServerBinding]) = { - val localAddressPromise = Promise[InetSocketAddress]() - val unbindPromise = Promise[() ⇒ Future[Unit]]() - val publisher = new Publisher[IncomingConnection] { - - override def subscribe(s: Subscriber[_ >: IncomingConnection]): Unit = { - requireNonNullSubscriber(s) - manager ! StreamTcpManager.Bind( - localAddressPromise, - unbindPromise, - s.asInstanceOf[Subscriber[IncomingConnection]], - endpoint, - backlog, - halfClose, - options, - idleTimeout) - } - - } - - import system.dispatcher - val bindingFuture = unbindPromise.future.zip(localAddressPromise.future).map { - case (unbindAction, localAddress) ⇒ - ServerBinding(localAddress)(unbindAction) - } - - (publisher, bindingFuture) - } - - override protected def newInstance(s: SourceShape[IncomingConnection]): SourceModule[IncomingConnection, Future[ServerBinding]] = - new BindSource(endpoint, backlog, options, halfClose, idleTimeout, attributes, shape) - override def withAttributes(attr: Attributes): Module = - new BindSource(endpoint, backlog, options, halfClose, idleTimeout, attr, shape) - } - /** * Creates a [[Tcp.ServerBinding]] instance which represents a prospective TCP server binding on the given `endpoint`. * @@ -138,10 +88,14 @@ class Tcp(system: ExtendedActorSystem) extends akka.actor.Extension { backlog: Int = 100, options: immutable.Traversable[SocketOption] = Nil, halfClose: Boolean = false, - idleTimeout: Duration = Duration.Inf): Source[IncomingConnection, Future[ServerBinding]] = { - new Source(new BindSource(new InetSocketAddress(interface, port), backlog, options, halfClose, idleTimeout, - Attributes.none, SourceShape(Outlet("BindSource.out")))) - } + idleTimeout: Duration = Duration.Inf): Source[IncomingConnection, Future[ServerBinding]] = + Source.fromGraph(new ConnectionSourceStage( + IO(IoTcp)(system), + new InetSocketAddress(interface, port), + backlog, + options, + halfClose, + idleTimeout)) /** * Creates a [[Tcp.ServerBinding]] instance which represents a prospective TCP server binding on the given `endpoint` @@ -202,20 +156,18 @@ class Tcp(system: ExtendedActorSystem) extends akka.actor.Extension { connectTimeout: Duration = Duration.Inf, idleTimeout: Duration = Duration.Inf): Flow[ByteString, ByteString, Future[OutgoingConnection]] = { - val timeoutHandling = idleTimeout match { - case d: FiniteDuration ⇒ Flow[ByteString].join(BidiFlow.bidirectionalIdleTimeout[ByteString, ByteString](d)) - case _ ⇒ Flow[ByteString] - } + val tcpFlow = Flow.fromGraph(new OutgoingConnectionStage( + IO(IoTcp)(system), + remoteAddress, + localAddress, + options, + halfClose, + connectTimeout)).via(new Detacher[ByteString]) // must read ahead for proper completions - Flow[ByteString].deprecatedAndThenMat(() ⇒ { - val processorPromise = Promise[Processor[ByteString, ByteString]]() - val localAddressPromise = Promise[InetSocketAddress]() - manager ! StreamTcpManager.Connect(processorPromise, localAddressPromise, remoteAddress, localAddress, halfClose, options, - connectTimeout, idleTimeout) - import system.dispatcher - val outgoingConnection = localAddressPromise.future.map(OutgoingConnection(remoteAddress, _)) - (new DelayedInitProcessor[ByteString, ByteString](processorPromise.future), outgoingConnection) - }).via(timeoutHandling) + idleTimeout match { + case d: FiniteDuration ⇒ tcpFlow.join(BidiFlow.bidirectionalIdleTimeout[ByteString, ByteString](d)) + case _ ⇒ tcpFlow + } } diff --git a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala index 2a0ae4c0c0..dd90160969 100644 --- a/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala +++ b/akka-stream/src/main/scala/akka/stream/stage/GraphStage.scala @@ -3,7 +3,8 @@ */ package akka.stream.stage -import akka.actor.{ Cancellable, DeadLetterSuppression } +import akka.actor._ +import akka.dispatch.sysmsg.{ Unwatch, Watch, DeathWatchNotification, SystemMessage } import akka.stream._ import akka.stream.impl.ReactiveStreamsCompliance import akka.stream.impl.StreamLayout.Module @@ -137,6 +138,32 @@ object GraphStageLogic { private object DoNothing extends (() ⇒ Unit) { def apply(): Unit = () } + + /** + * Minimal actor to work with other actors and watch them in a synchronous ways + */ + class StageActorRef(asyncCallback: AsyncCallback[(ActorRef, Any)], override val path: ActorPath) extends akka.actor.MinimalActorRef { + override def provider: ActorRefProvider = + throw new UnsupportedOperationException("GraphStage MinimalActor does not provide") + + override def !(message: Any)(implicit sender: ActorRef = Actor.noSender): Unit = { + asyncCallback.invoke((sender, message)) + } + + // FIXME: This ActorRef is not watchable! + override def sendSystemMessage(message: SystemMessage): Unit = message match { + case DeathWatchNotification(actorRef, _, _) ⇒ + this.!(Terminated(actorRef)(existenceConfirmed = true, addressTerminated = false)) + case _ ⇒ //ignore all other messages + } + + def watch(actorRef: ActorRef): Unit = + actorRef.asInstanceOf[InternalActorRef].sendSystemMessage(Watch(actorRef.asInstanceOf[InternalActorRef], this)) + + def unwatch(actorRef: ActorRef): Unit = + actorRef.asInstanceOf[InternalActorRef].sendSystemMessage(Unwatch(actorRef.asInstanceOf[InternalActorRef], this)) + + } } /** @@ -703,13 +730,25 @@ abstract class GraphStageLogic private[stream] (val inCount: Int, val outCount: * * This object can be cached and reused within the same [[GraphStageLogic]]. */ - final def getAsyncCallback[T](handler: T ⇒ Unit): AsyncCallback[T] = { + final protected def getAsyncCallback[T](handler: T ⇒ Unit): AsyncCallback[T] = { new AsyncCallback[T] { override def invoke(event: T): Unit = interpreter.onAsyncInput(GraphStageLogic.this, event, handler.asInstanceOf[Any ⇒ Unit]) } } + /** + * Created MinimalActorRef to get messages and watch other actors in synchronous way. + * @param receive - callback that will be + * @return - minimal actor with watch method + */ + // FIXME: I don't like the Pair allocation :( + final protected def getStageActorRef(receive: ((ActorRef, Any)) ⇒ Unit): StageActorRef = + // FIXME: Avoid returning multiple actorrefs. Probably push down this feature to a subclass? + // FIXME: toString is a completely wrong path + new StageActorRef(getAsyncCallback(receive), + ActorMaterializer.downcast(interpreter.materializer).supervisor.path / toString) + // Internal hooks to avoid reliance on user calling super in preStart protected[stream] def beforePreStart(): Unit = ()