diff --git a/project/build/MultiJvmTests.scala b/project/build/MultiJvmTests.scala new file mode 100644 index 0000000000..0ec0359c96 --- /dev/null +++ b/project/build/MultiJvmTests.scala @@ -0,0 +1,203 @@ +import sbt._ +import sbt.Process +import java.io.File +import java.lang.{ProcessBuilder => JProcessBuilder} +import java.io.{BufferedReader, Closeable, InputStream, InputStreamReader, IOException, OutputStream} +import java.io.{PipedInputStream, PipedOutputStream} +import scala.concurrent.SyncVar + +trait MultiJvmTests extends BasicScalaProject { + def multiJvmTestName = "MultiJvm" + def multiJvmOptions: Seq[String] = Nil + + val MultiJvmTestName = multiJvmTestName + + val ScalaTestRunner = "org.scalatest.tools.Runner" + val ScalaTestOptions = "-o" + + val javaPath = Path.fromFile(System.getProperty("java.home")) / "bin" / "java" + + private val HeaderStart = System.getProperty("sbt.start.delimiter", "==") + private val HeaderEnd = System.getProperty("sbt.end.delimiter", "==") + + // exclude multi jvm tests from normal tests + override def testOptions = super.testOptions ++ Seq(TestFilter(test => !test.name.contains(MultiJvmTestName))) + + lazy val multiJvmTest = multiJvmTestAction + lazy val multiJvmRun = multiJvmRunAction + + def multiJvmTestAction = multiJvmAction(getMultiJvmTests, testScalaOptions) + def multiJvmRunAction = multiJvmAction(getMultiJvmApps, runScalaOptions) + + def multiJvmAction(getMultiTestsMap: => Map[String, Seq[String]], scalaOptions: String => Seq[String]) = { + task { args => + task { + val multiTestsMap = getMultiTestsMap + def process(tests: List[String]): Option[String] = { + if (tests.isEmpty) { + None + } else { + val testName = tests(0) + val failed = multiTestsMap.get(testName) match { + case Some(testClasses) => runMulti(testName, testClasses, scalaOptions) + case None => Some("No multi jvm test called " + testName) + } + failed orElse process(tests.tail) + } + } + val tests = if (args.size > 0) args.toList else multiTestsMap.keys.toList.asInstanceOf[List[String]] + process(tests) + } dependsOn (testCompile) + } completeWith(getMultiTestsMap.keys.toList) + } + + def getMultiJvmTests(): Map[String, Seq[String]] = { + val allTests = testCompileConditional.analysis.allTests.toList.map(_.className) + filterMultiJvmTests(allTests) + } + + def getMultiJvmApps(): Map[String, Seq[String]] = { + val allApps = (mainCompileConditional.analysis.allApplications.toSeq ++ + testCompileConditional.analysis.allApplications.toSeq) + filterMultiJvmTests(allApps) + } + + def filterMultiJvmTests(allTests: Seq[String]): Map[String, Seq[String]] = { + val multiJvmTests = allTests filter (_.contains(MultiJvmTestName)) + val names = multiJvmTests map { fullName => + val lastDot = fullName.lastIndexOf(".") + val className = if (lastDot >= 0) fullName.substring(lastDot + 1) else fullName + val i = className.indexOf(MultiJvmTestName) + if (i >= 0) className.substring(0, i) else className + } + val testPairs = names map { name => (name, multiJvmTests.toList.filter(_.contains(name)).sort(_ < _)) } + Map(testPairs: _*) + } + + def testIdentifier(className: String) = { + val i = className.indexOf(MultiJvmTestName) + val l = MultiJvmTestName.length + className.substring(i + l) + } + + def testScalaOptions(testClass: String) = { + val scalaTestJars = testClasspath.get.filter(_.name.contains("scalatest")) + val cp = Path.makeString(scalaTestJars) + val paths = "\"" + testClasspath.get.map(_.absolutePath).mkString(" ", " ", " ") + "\"" + Seq("-cp", cp, ScalaTestRunner, ScalaTestOptions, "-s", testClass, "-p", paths) + } + + def runScalaOptions(appClass: String) = { + val cp = Path.makeString(testClasspath.get) + Seq("-cp", cp, appClass) + } + + def runMulti(testName: String, testClasses: Seq[String], scalaOptions: String => Seq[String]) = { + log.control(ControlEvent.Start, "%s multi-jvm / %s %s" format (HeaderStart, testName, HeaderEnd)) + val processes = testClasses.toList.zipWithIndex map { + case (testClass, index) => { + val jvmName = "JVM-" + testIdentifier(testClass) + val jvmLogger = new JvmLogger(jvmName) + log.info("Starting %s for %s" format (jvmName, testClass)) + (testClass, startJvm(multiJvmOptions, scalaOptions(testClass), jvmLogger, index == 0)) + } + } + val exitCodes = processes map { + case (testClass, process) => (testClass, process.exitValue) + } + val failures = exitCodes flatMap { + case (testClass, exit) if exit > 0 => Some("%s failed with exit code %s" format (testClass, exit)) + case _ => None + } + failures foreach (log.error(_)) + log.control(ControlEvent.Finish, "%s multi-jvm / %s %s" format (HeaderStart, testName, HeaderEnd)) + if (!failures.isEmpty) Some("Some processes failed") else None + } + + def startJvm(jvmOptions: Seq[String], scalaOptions: Seq[String], logger: Logger, connectInput: Boolean) = { + val si = buildScalaInstance + val scalaJars = Seq(si.libraryJar, si.compilerJar) + forkScala(jvmOptions, scalaJars, scalaOptions, logger, connectInput) + } + + def forkScala(jvmOptions: Seq[String], scalaJars: Iterable[File], arguments: Seq[String], logger: Logger, connectInput: Boolean) = { + val scalaClasspath = scalaJars.map(_.getAbsolutePath).mkString(File.pathSeparator) + val bootClasspath = "-Xbootclasspath/a:" + scalaClasspath + val mainScalaClass = "scala.tools.nsc.MainGenericRunner" + val options = jvmOptions ++ Seq(bootClasspath, mainScalaClass) ++ arguments + forkJava(options, logger, connectInput) + } + + def forkJava(options: Seq[String], logger: Logger, connectInput: Boolean) = { + val java = javaPath.toString + val command = (java :: options.toList).toArray + val builder = new JProcessBuilder(command: _*) + Process(builder).run(JvmIO(logger, connectInput)) + } +} + +final class JvmLogger(name: String) extends BasicLogger { + def jvm(message: String) = "[%s] %s" format (name, message) + + def log(level: Level.Value, message: => String) = System.out.synchronized { + System.out.println(jvm(message)) + } + + def trace(t: => Throwable) = System.out.synchronized { + val traceLevel = getTrace + if (traceLevel >= 0) System.out.print(StackTrace.trimmed(t, traceLevel)) + } + + def success(message: => String) = log(Level.Info, message) + def control(event: ControlEvent.Value, message: => String) = log(Level.Info, message) + + def logAll(events: Seq[LogEvent]) = System.out.synchronized { events.foreach(log) } +} + +object JvmIO { + def apply(log: Logger, connectInput: Boolean) = + new ProcessIO(input(connectInput), processStream(log, Level.Info), processStream(log, Level.Error)) + + final val BufferSize = 8192 + + def processStream(log: Logger, level: Level.Value): InputStream => Unit = + processStream(line => log.log(level, line)) + + def processStream(processLine: String => Unit): InputStream => Unit = in => { + val reader = new BufferedReader(new InputStreamReader(in)) + def process { + val line = reader.readLine() + if (line != null) { + processLine(line) + process + } + } + process + } + + def input(connectInput: Boolean): OutputStream => Unit = + if (connectInput) connectSystemIn else ignoreOutputStream + + def connectSystemIn(out: OutputStream) = transfer(System.in, out) + + def ignoreOutputStream = (out: OutputStream) => () + + def transfer(in: InputStream, out: OutputStream): Unit = { + try { + val buffer = new Array[Byte](BufferSize) + def read { + val byteCount = in.read(buffer) + if (Thread.interrupted) throw new InterruptedException + if (byteCount > 0) { + out.write(buffer, 0, byteCount) + out.flush() + read + } + } + read + } catch { + case _: InterruptedException => () + } + } +} +