pekko/project/build/MultiJvmTests.scala

229 lines
8.7 KiB
Scala

import sbt._
import sbt.Process
import java.io.File
import java.lang.{ProcessBuilder => JProcessBuilder}
import java.io.{BufferedReader, Closeable, InputStream, InputStreamReader, IOException, OutputStream}
import java.io.{PipedInputStream, PipedOutputStream}
import scala.concurrent.SyncVar
/**
 * Mixin for an sbt project definition that runs "multi-JVM" tests and applications.
 *
 * A discovered test/application class takes part when its fully qualified name contains
 * `multiJvmTestName` (default "MultiJvm"). Classes are grouped by the part of their simple
 * name that precedes the marker, so `FooMultiJvmNode1` and `FooMultiJvmNode2` form the group
 * `Foo` and run in JVMs named `JVM-Node1` and `JVM-Node2`. All JVMs of a group are started
 * before any exit code is awaited, and the group fails if any of them exits non-zero.
 */
trait MultiJvmTests extends DefaultProject {
  /** Marker substring identifying multi-JVM classes; override to use another marker. */
  def multiJvmTestName = "MultiJvm"

  /** JVM options passed to every forked JVM. */
  def multiJvmOptions: Seq[String] = Seq.empty

  /** Additional JVM options for the JVM running the class with simple name `className`. */
  def multiJvmExtraOptions(className: String): Seq[String] = Seq.empty

  val MultiJvmTestName = multiJvmTestName
  val ScalaTestRunner = "org.scalatest.tools.Runner"
  val ScalaTestOptions = "-o"
  val javaPath = Path.fromFile(System.getProperty("java.home")) / "bin" / "java"

  // Delimiters printed around each group's header; configurable through system properties.
  private val HeaderStart = System.getProperty("sbt.start.delimiter", "==")
  private val HeaderEnd = System.getProperty("sbt.end.delimiter", "==")

  // Exclude multi-JVM tests from the normal single-JVM test run.
  override def testOptions = super.testOptions ++ Seq(TestFilter(test => !test.name.contains(MultiJvmTestName)))

  lazy val multiJvmTest = multiJvmTestAction
  lazy val multiJvmRun = multiJvmRunAction
  lazy val multiJvmTestAll = multiJvmTestAllAction

  def multiJvmTestAction = multiJvmMethod(getMultiJvmTests, testScalaOptions)
  def multiJvmRunAction = multiJvmMethod(getMultiJvmApps, runScalaOptions)
  def multiJvmTestAllAction = multiJvmTask(Nil, getMultiJvmTests, testScalaOptions)

  /**
   * Method task that runs the groups named in its arguments, or all groups when none are
   * given. Tab completion offers the discovered group names.
   */
  def multiJvmMethod(getMultiTestsMap: => Map[String, Seq[String]], scalaOptions: String => Seq[String]) = {
    task { args =>
      multiJvmTask(args.toList, getMultiTestsMap, scalaOptions)
    } completeWith(getMultiTestsMap.keys.toList)
  }

  /**
   * Task that runs each group in `tests` in order. When `tests` is empty it runs every group
   * in `getMultiTestsMap`. It stops at the first group that fails or does not exist.
   * The task body yields None on success or Some(message) describing the first failure.
   * The task depends on `testCompile`.
   */
  def multiJvmTask(tests: List[String], getMultiTestsMap: => Map[String, Seq[String]], scalaOptions: String => Seq[String]) = {
    task {
      val multiTestsMap = getMultiTestsMap
      def process(runTests: List[String]): Option[String] = runTests match {
        case Nil => None
        case testName :: remaining =>
          val failed = multiTestsMap.get(testName) match {
            case Some(testClasses) => runMulti(testName, testClasses, scalaOptions)
            case None => Some("No multi jvm test called " + testName)
          }
          // orElse is by-name: later groups only run while no failure has occurred.
          failed orElse process(remaining)
      }
      val runTests = if (tests.isEmpty) multiTestsMap.keys.toList else tests
      process(runTests)
    } dependsOn (testCompile)
  }

  /** Multi-JVM groups among the compiled test classes, keyed by group name. */
  def getMultiJvmTests(): Map[String, Seq[String]] = {
    val allTests = testCompileConditional.analysis.allTests.toList.map(_.className)
    filterMultiJvmTests(allTests)
  }

  /** Multi-JVM groups among the main and test applications, keyed by group name. */
  def getMultiJvmApps(): Map[String, Seq[String]] = {
    val allApps = (mainCompileConditional.analysis.allApplications.toSeq ++
      testCompileConditional.analysis.allApplications.toSeq)
    filterMultiJvmTests(allApps)
  }

  /**
   * Groups the class names that contain the multi-JVM marker by their group name.
   * Each group's classes are sorted by full name. A class belongs only to the group whose
   * name is exactly its own group name, so `Foo` does not absorb the classes of `FooBar`.
   */
  def filterMultiJvmTests(allTests: Seq[String]): Map[String, Seq[String]] = {
    val multiJvmTests = allTests.toList filter (_.contains(MultiJvmTestName))
    val names = multiJvmTests map multiJvmGroupName
    val testPairs = names map { name =>
      (name, multiJvmTests.filter(multiJvmGroupName(_) == name).sort(_ < _))
    }
    // Duplicate names map to identical class lists, so the Map simply collapses them.
    Map(testPairs: _*)
  }

  /** Group name of a class: its simple name up to (not including) the marker, if present. */
  private def multiJvmGroupName(fullName: String): String = {
    val lastDot = fullName.lastIndexOf(".")
    val className = if (lastDot >= 0) fullName.substring(lastDot + 1) else fullName
    val i = className.indexOf(MultiJvmTestName)
    if (i >= 0) className.substring(0, i) else className
  }

  /** The part of `className` after the first occurrence of the marker, e.g. "Node1". */
  def testIdentifier(className: String) = {
    val i = className.indexOf(MultiJvmTestName)
    val l = MultiJvmTestName.length
    className.substring(i + l)
  }

  /** The simple (unqualified) name of a fully qualified class name. */
  def testSimpleName(className: String) = {
    className.split("\\.").last
  }

  /**
   * Arguments for running `testClass` through ScalaTest's Runner. Only the ScalaTest jars are
   * put on the JVM classpath. The full test classpath is passed via `-p` as a single
   * space-separated string.
   */
  def testScalaOptions(testClass: String) = {
    val scalaTestJars = testClasspath.get.filter(_.name.contains("scalatest"))
    val cp = Path.makeString(scalaTestJars)
    // NOTE(review): no shell is involved, so these quotes reach the Runner literally;
    // presumably the Runner strips them when parsing -p — confirm.
    val paths = "\"" + testClasspath.get.map(_.absolutePath).mkString(" ", " ", " ") + "\""
    Seq("-cp", cp, ScalaTestRunner, ScalaTestOptions, "-s", testClass, "-p", paths)
  }

  /** Arguments for running the main class `appClass` with the full test classpath. */
  def runScalaOptions(appClass: String) = {
    val cp = Path.makeString(testClasspath.get)
    Seq("-cp", cp, appClass)
  }

  /**
   * Forks one JVM per class in `testClasses` and waits for all of them. Only the first JVM
   * is connected to this process's standard input.
   *
   * @return None when every JVM exits with code 0, otherwise Some(error message)
   */
  def runMulti(testName: String, testClasses: Seq[String], scalaOptions: String => Seq[String]) = {
    log.control(ControlEvent.Start, "%s multi-jvm / %s %s" format (HeaderStart, testName, HeaderEnd))
    val processes = testClasses.toList.zipWithIndex map {
      case (testClass, index) =>
        val jvmName = "JVM-" + testIdentifier(testClass)
        val jvmLogger = new JvmLogger(jvmName)
        val className = testSimpleName(testClass)
        val optionsFromFile = multiJvmOptionsFromFile(className)
        val extraOptions = multiJvmExtraOptions(className)
        val jvmOptions = multiJvmOptions ++ optionsFromFile ++ extraOptions
        log.info("Starting %s for %s" format (jvmName, testClass))
        log.info(" with JVM options: %s" format jvmOptions.mkString(" "))
        (testClass, startJvm(jvmOptions, scalaOptions(testClass), jvmLogger, index == 0))
    }
    // All JVMs have been started above; now block until each one exits.
    val exitCodes = processes map {
      case (testClass, process) => (testClass, process.exitValue)
    }
    // Any non-zero exit code (including negative ones) counts as a failure.
    val failures = exitCodes flatMap {
      case (testClass, exit) if exit != 0 => Some("%s failed with exit code %s" format (testClass, exit))
      case _ => None
    }
    failures foreach (log.error(_))
    log.control(ControlEvent.Finish, "%s multi-jvm / %s %s" format (HeaderStart, testName, HeaderEnd))
    if (!failures.isEmpty) Some("Some processes failed") else None
  }

  /**
   * JVM options read from the first `<className>.opts` file found under the test sources.
   * Options are separated by any whitespace, including newlines. Returns an empty sequence
   * when no file exists or the file cannot be read; a read failure is logged as a warning.
   */
  private def multiJvmOptionsFromFile(className: String): Seq[String] = {
    val optionsFiles = (testSourcePath ** (className + ".opts")).get.toList
    optionsFiles match {
      case Nil => Seq.empty
      case first :: _ =>
        val file = first.asFile
        log.info("Reading JVM options from %s" format file)
        FileUtilities.readString(file, log) match {
          case Right(opts: String) => opts.trim.split("\\s+").toList.filter(_.length > 0)
          case Left(error) =>
            log.warn("Could not read JVM options from %s: %s" format (file, error))
            Seq.empty
        }
    }
  }

  /** Forks a JVM that runs Scala's generic runner with the build's Scala library and compiler jars. */
  def startJvm(jvmOptions: Seq[String], scalaOptions: Seq[String], logger: Logger, connectInput: Boolean) = {
    val si = buildScalaInstance
    val scalaJars = Seq(si.libraryJar, si.compilerJar)
    forkScala(jvmOptions, scalaJars, scalaOptions, logger, connectInput)
  }

  /**
   * Forks `java` with `scalaJars` appended to the boot classpath. The main class is
   * `scala.tools.nsc.MainGenericRunner`, which receives `arguments`.
   */
  def forkScala(jvmOptions: Seq[String], scalaJars: Iterable[File], arguments: Seq[String], logger: Logger, connectInput: Boolean) = {
    val scalaClasspath = scalaJars.map(_.getAbsolutePath).mkString(File.pathSeparator)
    val bootClasspath = "-Xbootclasspath/a:" + scalaClasspath
    val mainScalaClass = "scala.tools.nsc.MainGenericRunner"
    val options = jvmOptions ++ Seq(bootClasspath, mainScalaClass) ++ arguments
    forkJava(options, logger, connectInput)
  }

  /** Starts `<java.home>/bin/java` with `options`, routing its output through `logger`. */
  def forkJava(options: Seq[String], logger: Logger, connectInput: Boolean) = {
    val java = javaPath.toString
    val command = (java :: options.toList).toArray
    val builder = new JProcessBuilder(command: _*)
    Process(builder).run(JvmIO(logger, connectInput))
  }
}
/**
 * Logger for one forked JVM's output. Every message is printed to standard output,
 * whatever its level, prefixed with the JVM's name in brackets, e.g. "[JVM-Node1] text".
 * All writes synchronize on `System.out`.
 */
final class JvmLogger(name: String) extends BasicLogger {
  /** Returns `message` prefixed with this JVM's bracketed name. */
  def jvm(message: String) = "[" + name + "] " + message

  def log(level: Level.Value, message: => String) = System.out.synchronized {
    val prefixed = jvm(message)
    System.out.println(prefixed)
  }

  // Prints the trimmed stack trace only when the configured trace depth is non-negative.
  def trace(t: => Throwable) = System.out.synchronized {
    val depth = getTrace
    if (depth < 0) () else System.out.print(StackTrace.trimmed(t, depth))
  }

  // Success and control messages are printed exactly like Info messages.
  def success(message: => String) = log(Level.Info, message)
  def control(event: ControlEvent.Value, message: => String) = log(Level.Info, message)

  def logAll(events: Seq[LogEvent]) = System.out.synchronized { events.foreach(log) }
}
/**
 * Builds the `ProcessIO` used for a forked JVM. The child's stdout lines are logged at Info
 * and its stderr lines at Error. Its stdin is optionally fed from this process's `System.in`.
 */
object JvmIO {
  /**
   * @param log          logger receiving every output line of the child process
   * @param connectInput whether to copy `System.in` into the child's stdin
   */
  def apply(log: Logger, connectInput: Boolean) =
    new ProcessIO(input(connectInput), processStream(log, Level.Info), processStream(log, Level.Error))

  /** Size in bytes of the buffer used when copying `System.in` to the child. */
  final val BufferSize = 8192

  /** Stream handler that logs every line of the stream at `level`. */
  def processStream(log: Logger, level: Level.Value): InputStream => Unit =
    processStream(line => log.log(level, line))

  /**
   * Stream handler that passes each line to `processLine` until end of stream.
   * The stream is then closed, so its handle is released even if `processLine` throws.
   */
  def processStream(processLine: String => Unit): InputStream => Unit = in => {
    val reader = new BufferedReader(new InputStreamReader(in))
    try {
      var line = reader.readLine()
      while (line != null) {
        processLine(line)
        line = reader.readLine()
      }
    } finally {
      reader.close()
    }
  }

  /** Stdin handler: copy `System.in` when `connectInput` is true, otherwise do nothing. */
  def input(connectInput: Boolean): OutputStream => Unit =
    if (connectInput) connectSystemIn else ignoreOutputStream

  def connectSystemIn(out: OutputStream) = transfer(System.in, out)

  // NOTE(review): the child's stdin is left open (not closed) when input is not connected.
  def ignoreOutputStream = (out: OutputStream) => ()

  /**
   * Copies `in` to `out`, flushing after each chunk, until `in` reports no more bytes.
   * If the thread is interrupted, copying stops after the read in progress returns.
   * The resulting InterruptedException is swallowed.
   */
  def transfer(in: InputStream, out: OutputStream): Unit = {
    try {
      val buffer = new Array[Byte](BufferSize)
      def read {
        val byteCount = in.read(buffer)
        if (Thread.interrupted) throw new InterruptedException
        if (byteCount > 0) {
          out.write(buffer, 0, byteCount)
          out.flush()
          read
        }
      }
      read
    } catch {
      case _: InterruptedException => ()
    }
  }
}