diff --git a/akka-stream-tests/src/test/scala/akka/stream/io/TcpSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/io/TcpSpec.scala index 3e7173b091..40eb176cc9 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/io/TcpSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/io/TcpSpec.scala @@ -9,19 +9,27 @@ import akka.stream.scaladsl.Tcp.OutgoingConnection import scala.collection.immutable import scala.concurrent.{ Future, Await } import akka.io.Tcp._ +import akka.stream.{ BindFailedException, ActorFlowMaterializer, ActorFlowMaterializerSettings, StreamTcpException } +import akka.stream.scaladsl.Tcp.IncomingConnection +import akka.stream.scaladsl.{ Flow, _ } +import akka.stream.testkit.TestUtils.temporaryServerAddress +import akka.stream.testkit.Utils._ +import akka.stream.testkit._ +import akka.util.{ Helpers, ByteString } +import scala.collection.immutable import akka.stream.{ ActorFlowMaterializer, StreamTcpException, BindFailedException } import scala.concurrent.Await import scala.concurrent.duration._ -import akka.util.{ Helpers, ByteString } +import akka.util.ByteString import akka.stream.scaladsl.Flow import akka.stream.testkit._ import akka.stream.testkit.Utils._ import akka.stream.scaladsl._ import akka.stream.testkit.TestUtils.temporaryServerAddress -class TcpSpec extends AkkaSpec("akka.io.tcp.windows-connection-abort-workaround-enabled=auto") with TcpHelper { +class TcpSpec extends AkkaSpec("akka.io.tcp.windows-connection-abort-workaround-enabled=auto\nakka.stream.subscription-timeout.timeout = 3s") with TcpHelper { import akka.stream.io.TcpHelper._ var demand = 0L @@ -494,6 +502,38 @@ class TcpSpec extends AkkaSpec("akka.io.tcp.windows-connection-abort-workaround- Await.result(binding4.unbind(), 1.second) } + "not shut down connections after the connection stream cancelled" in assertAllStagesStopped { + val address = temporaryServerAddress() + Tcp().bind(address.getHostName, address.getPort).take(1).runForeach(_.flow.join(Flow[ByteString]).run()) + + val total = Source(immutable.Iterable.fill(1000)(ByteString(0))) + .via(Tcp().outgoingConnection(address)) + .runFold(0)(_ + _.size) + + Await.result(total, 3.seconds) should ===(1000) + } + + "shut down properly even if some accepted connection Flows have not been subscribed to" in assertAllStagesStopped { + val address = temporaryServerAddress() + val takeTwoAndDropSecond = Flow[IncomingConnection].grouped(2).take(1).map(_.head) + Tcp().bind(address.getHostName, address.getPort) + .via(takeTwoAndDropSecond) + .runForeach(_.flow.join(Flow[ByteString]).run()) + + val folder = Source(immutable.Iterable.fill(1000)(ByteString(0))) + .via(Tcp().outgoingConnection(address)) + .toMat(Sink.fold(0)(_ + _.size))(Keep.right) + + val total = folder.run() + val rejected = folder.run() + + Await.result(total, 3.seconds) should ===(1000) + + a[StreamTcpException] should be thrownBy { + Await.result(rejected, 5.seconds) should ===(1000) + } + } + } def validateServerClientCommunication(testData: ByteString, diff --git a/akka-stream/src/main/scala/akka/stream/impl/ActorProcessor.scala b/akka-stream/src/main/scala/akka/stream/impl/ActorProcessor.scala index 170893d4bc..8102bd1d08 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/ActorProcessor.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/ActorProcessor.scala @@ -172,6 +172,8 @@ private[akka] class SimpleOutputs(val actor: ActorRef, val pump: Pump) extends D override def subreceive = _subreceive private val _subreceive = new SubReceive(waitingExposedPublisher) + def isSubscribed = subscriber ne null + def enqueueOutputElement(elem: Any): Unit = { ReactiveStreamsCompliance.requireNonNullElement(elem) downstreamDemand -= 1 diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/TcpConnectionStream.scala b/akka-stream/src/main/scala/akka/stream/impl/io/TcpConnectionStream.scala index bf18bbe02b..f1512ae04f 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/TcpConnectionStream.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/TcpConnectionStream.scala @@ -5,16 +5,16 @@ package akka.stream.impl.io import java.net.InetSocketAddress import akka.io.{ IO, Tcp } +import akka.stream.impl.io.StreamTcpManager.ExposedProcessor import scala.concurrent.Promise -import scala.util.control.NoStackTrace import akka.actor._ import akka.util.ByteString import akka.io.Tcp._ -import akka.stream.ActorFlowMaterializerSettings -import akka.stream.StreamTcpException -import org.reactivestreams.Processor +import akka.stream.{ StreamSubscriptionTimeoutSettings, ActorFlowMaterializerSettings, StreamTcpException } +import org.reactivestreams.{ Publisher, Processor } import akka.stream.impl._ -import akka.actor.ActorLogging + +import scala.util.control.NoStackTrace /** * INTERNAL API @@ -33,6 +33,7 @@ private[akka] object TcpStreamActor { def inboundProps(connection: ActorRef, halfClose: Boolean, settings: ActorFlowMaterializerSettings): Props = Props(new InboundTcpStreamActor(connection, halfClose, settings)).withDispatcher(settings.dispatcher).withDeploy(Deploy.local) + case object SubscriptionTimeout extends NoSerializationVerificationNeeded } /** @@ -47,7 +48,7 @@ private[akka] abstract class TcpStreamActor(val settings: ActorFlowMaterializerS override def inputOnError(e: Throwable): Unit = fail(e) } - val primaryOutputs: Outputs = new SimpleOutputs(self, readPump) + val primaryOutputs: SimpleOutputs = new SimpleOutputs(self, readPump) def fullClose: Boolean = !halfClose @@ -216,7 +217,14 @@ private[akka] abstract class TcpStreamActor(val settings: ActorFlowMaterializerS final override def receive = new ExposedPublisherReceive(activeReceive, unhandled) { override def receiveExposedPublisher(ep: ExposedPublisher): Unit = { + import context.dispatcher primaryOutputs.subreceive(ep) + subscriptionTimer = Some( + context.system.scheduler.scheduleOnce( + settings.subscriptionTimeoutSettings.timeout, + self, + SubscriptionTimeout)) + context become activeReceive } } @@ -226,7 +234,8 @@ private[akka] abstract class TcpStreamActor(val settings: ActorFlowMaterializerS primaryOutputs.subreceive orElse tcpInputs.subreceive orElse tcpOutputs.subreceive orElse - commonCloseHandling + commonCloseHandling orElse + handleSubscriptionTimeout def commonCloseHandling: Receive = { case Terminated(_) ⇒ fail(new StreamTcpException("The connection actor has terminated. Stopping now.")) @@ -240,9 +249,20 @@ private[akka] abstract class TcpStreamActor(val settings: ActorFlowMaterializerS case Aborted ⇒ fail(new StreamTcpException("The connection has been aborted")) } + def handleSubscriptionTimeout: Receive = { + case SubscriptionTimeout ⇒ + val millis = settings.subscriptionTimeoutSettings.timeout.toMillis + if (!primaryOutputs.isSubscribed) { + fail(new SubscriptionTimeoutException(s"Publisher was not attached to upstream within deadline (${millis}) ms") with NoStackTrace) + context.stop(self) + } + } + readPump.nextPhase(readPump.running) writePump.nextPhase(writePump.running) + var subscriptionTimer: Option[Cancellable] = None + def fail(e: Throwable): Unit = { if (settings.debugLogging) log.debug("fail due to: {}", e.getMessage) @@ -263,6 +283,7 @@ private[akka] abstract class TcpStreamActor(val settings: ActorFlowMaterializerS tcpOutputs.complete() primaryInputs.cancel() primaryOutputs.complete() + subscriptionTimer.foreach(_.cancel()) super.postStop() // Remember, we have a Stash } } diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/TcpListenStreamActor.scala b/akka-stream/src/main/scala/akka/stream/impl/io/TcpListenStreamActor.scala index 63df82ae4d..2bd462abaf 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/TcpListenStreamActor.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/TcpListenStreamActor.scala @@ -65,7 +65,8 @@ private[akka] class TcpListenStreamActor(localAddressPromise: Promise[InetSocket finished = true incomingConnections.cancel() primaryOutputs.complete() - context.stop(self) + // Stop only after all already accepted connections have been shut down + if (context.children.isEmpty) context.stop(self) } } @@ -123,6 +124,7 @@ private[akka] class TcpListenStreamActor(localAddressPromise: Promise[InetSocket if (!closed && listener != null) listener ! Unbind closed = true pendingConnection = null + pump() } override def dequeueInputElement(): Any = { val elem = pendingConnection @@ -139,11 +141,15 @@ private[akka] class TcpListenStreamActor(localAddressPromise: Promise[InetSocket } } - def activeReceive: Actor.Receive = primaryOutputs.subreceive orElse incomingConnections.subreceive + def activeReceive: Actor.Receive = primaryOutputs.subreceive orElse incomingConnections.subreceive orElse { + case Terminated(_) ⇒ + // If the Source is cancelled, and this was our last child, stop ourselves + if (incomingConnections.isClosed && context.children.isEmpty) context.stop(self) + } def runningPhase = TransferPhase(primaryOutputs.NeedsDemand && incomingConnections.NeedsInput) { () ⇒ val (connected: Connected, connection: ActorRef) = incomingConnections.dequeueInputElement() - val tcpStreamActor = context.actorOf(TcpStreamActor.inboundProps(connection, halfClose, settings)) + val tcpStreamActor = context.watch(context.actorOf(TcpStreamActor.inboundProps(connection, halfClose, settings))) val processor = ActorProcessor[ByteString, ByteString](tcpStreamActor) val conn = StreamTcp.IncomingConnection( connected.localAddress,