From cf6cb5f2d919d2ba751b58083a46284888b67efd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johan=20Andr=C3=A9n?= Date: Wed, 15 Feb 2017 14:15:51 +0100 Subject: [PATCH] Eliminate race in connection closed tcp test case #21903 --- .../test/scala/akka/stream/io/TcpSpec.scala | 93 +++++++++++++++---- .../scala/akka/stream/impl/io/TcpStages.scala | 23 +++-- 2 files changed, 91 insertions(+), 25 deletions(-) diff --git a/akka-stream-tests/src/test/scala/akka/stream/io/TcpSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/io/TcpSpec.scala index 3b71b08840..8db445bdbd 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/io/TcpSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/io/TcpSpec.scala @@ -3,25 +3,26 @@ */ package akka.stream.io -import akka.NotUsed +import akka.{ Done, NotUsed } import akka.actor.{ ActorSystem, Address, Kill } import akka.io.Tcp._ -import akka.stream.scaladsl.Tcp.IncomingConnection +import akka.stream.scaladsl.Tcp.{ IncomingConnection, ServerBinding } import akka.stream.scaladsl.{ Flow, _ } import akka.stream.testkit.TestUtils.temporaryServerAddress import scala.util.control.NonFatal import akka.stream.testkit.Utils._ import akka.stream.testkit._ -import akka.stream.{ ActorMaterializer, BindFailedException, StreamTcpException } +import akka.stream._ import akka.util.{ ByteString, Helpers } import scala.collection.immutable -import scala.concurrent.{ Await, Promise } +import scala.concurrent.{ Await, Future, Promise } import scala.concurrent.duration._ -import java.net.{ BindException, InetSocketAddress } +import java.net._ -import akka.testkit.{ EventFilter, TestLatch } +import akka.testkit.{ EventFilter, TestKit, TestLatch } +import com.typesafe.config.ConfigFactory import org.scalatest.concurrent.PatienceConfiguration.Timeout class TcpSpec extends StreamSpec("akka.stream.materializer.subscription-timeout.timeout = 2s") with TcpHelper { @@ -509,22 +510,76 @@ class TcpSpec extends StreamSpec("akka.stream.materializer.subscription-timeout. } "not shut down connections after the connection stream cancelled" in assertAllStagesStopped { - val address = temporaryServerAddress() - val (futureBinding, _) = Tcp().bind(address.getHostName, address.getPort).take(1).toMat(Sink.foreach { tcp ⇒ - Thread.sleep(1000) // we're testing here to see if it survives such race - tcp.flow.join(Flow[ByteString]).run() - })(Keep.both) - .run() - // make sure server is running first - futureBinding.futureValue + // configure a few timeouts we do not want to hit + val config = ConfigFactory.parseString(""" + akka.actor.serializer-messages = off + akka.io.tcp.register-timeout = 42s + """) + val serverSystem = ActorSystem("server", config) + val clientSystem = ActorSystem("client", config) + val serverMaterializer = ActorMaterializer(ActorMaterializerSettings(serverSystem) + .withSubscriptionTimeoutSettings(StreamSubscriptionTimeoutSettings( + StreamSubscriptionTimeoutTerminationMode.cancel, 42.seconds)))(serverSystem) + val clientMaterializer = ActorMaterializer(ActorMaterializerSettings(clientSystem) + .withSubscriptionTimeoutSettings(StreamSubscriptionTimeoutSettings( + StreamSubscriptionTimeoutTerminationMode.cancel, 42.seconds)))(clientSystem) - // then connect, should trigger a block and then - val total = Source(immutable.Iterable.fill(1000)(ByteString(0))) - .via(Tcp().outgoingConnection(address)) - .runFold(0)(_ + _.size) + try { - total.futureValue should ===(1000) + val address = temporaryServerAddress() + val completeRequest = TestLatch()(serverSystem) + val serverGotRequest = Promise[Done]() + + def portClosed(): Boolean = + try { + val socket = new Socket() + socket.connect(address, 250) + socket.close() + serverSystem.log.info("port open") + false + } catch { + case _: SocketTimeoutException ⇒ true + case _: SocketException ⇒ true + } + + import serverSystem.dispatcher + val futureBinding: Future[ServerBinding] = + Tcp(serverSystem).bind(address.getHostName, address.getPort) + // accept one connection, then cancel + .take(1) + // keep the accepted request hanging + .map { connection ⇒ + serverGotRequest.success(Done) + Future { + Await.ready(completeRequest, remainingOrDefault) // wait for the port close below + // when the server has closed the port and stopped accepting incoming + // connections, complete the one accepted connection + connection.flow.join(Flow[ByteString]).run() + } + } + .to(Sink.ignore) + .run()(serverMaterializer) + + // make sure server is running first + futureBinding.futureValue + + // then connect once, which should lead to the server cancelling + val total = Source(immutable.Iterable.fill(100)(ByteString(0))) + .via(Tcp(clientSystem).outgoingConnection(address)) + .runFold(0)(_ + _.size)(clientMaterializer) + + serverGotRequest.future.futureValue + // this can take a bit of time worst case but is often swift + awaitCond(portClosed()) + completeRequest.open() + + total.futureValue should ===(100) // connection + + } finally { + TestKit.shutdownActorSystem(serverSystem) + TestKit.shutdownActorSystem(clientSystem) + } } "shut down properly even if some accepted connection Flows have not been subscribed to" in assertAllStagesStopped { diff --git a/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala b/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala index 09913694c9..f0f19edc8f 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/io/TcpStages.scala @@ -54,6 +54,7 @@ private[stream] class ConnectionSourceStage( val connectionFlowsAwaitingInitialization = new AtomicLong() var listener: ActorRef = _ var unbindPromise = Promise[Unit]() + var unbindStarted = false override def preStart(): Unit = { getStageActor(receive) @@ -84,11 +85,15 @@ private[stream] class ConnectionSourceStage( push(out, connectionFor(c, sender)) case Unbind ⇒ if (!isClosed(out) && (listener ne null)) tryUnbind() - case Unbound ⇒ // If we're unbound then just shut down - if (connectionFlowsAwaitingInitialization.get() == 0) completeStage() - else scheduleOnce(BindShutdownTimer, bindShutdownTimeout) + case Unbound ⇒ + unbindCompleted() case Terminated(ref) if ref == listener ⇒ - failStage(new IllegalStateException("IO Listener actor terminated unexpectedly")) + if (unbindStarted) { + unbindCompleted() + } else { + failStage(new IllegalStateException("IO Listener actor terminated unexpectedly for remote endpoint [" + + endpoint.getHostString + ":" + endpoint.getPort + "]")) + } } } @@ -125,13 +130,19 @@ private[stream] class ConnectionSourceStage( } private def tryUnbind(): Unit = { - if (listener ne null) { - stageActor.unwatch(listener) + if ((listener ne null) && !unbindStarted) { + unbindStarted = true setKeepGoing(true) listener ! Unbind } } + private def unbindCompleted(): Unit = { + stageActor.unwatch(listener) + if (connectionFlowsAwaitingInitialization.get() == 0) completeStage() + else scheduleOnce(BindShutdownTimer, bindShutdownTimeout) + } + override def onTimer(timerKey: Any): Unit = timerKey match { case BindShutdownTimer ⇒ completeStage() // TODO need to manually shut down instead right?