Added MultiJvmTests trait

This commit is contained in:
Jonas Bonér 2011-05-16 12:33:53 +02:00
parent 2655d44ee9
commit 70bbeba2a0

View file

@ -0,0 +1,203 @@
import sbt._
import sbt.Process
import java.io.File
import java.lang.{ProcessBuilder => JProcessBuilder}
import java.io.{BufferedReader, Closeable, InputStream, InputStreamReader, IOException, OutputStream}
import java.io.{PipedInputStream, PipedOutputStream}
import scala.concurrent.SyncVar
trait MultiJvmTests extends BasicScalaProject {
def multiJvmTestName = "MultiJvm"
def multiJvmOptions: Seq[String] = Nil
val MultiJvmTestName = multiJvmTestName
val ScalaTestRunner = "org.scalatest.tools.Runner"
val ScalaTestOptions = "-o"
val javaPath = Path.fromFile(System.getProperty("java.home")) / "bin" / "java"
private val HeaderStart = System.getProperty("sbt.start.delimiter", "==")
private val HeaderEnd = System.getProperty("sbt.end.delimiter", "==")
// exclude multi jvm tests from normal tests
override def testOptions = super.testOptions ++ Seq(TestFilter(test => !test.name.contains(MultiJvmTestName)))
lazy val multiJvmTest = multiJvmTestAction
lazy val multiJvmRun = multiJvmRunAction
def multiJvmTestAction = multiJvmAction(getMultiJvmTests, testScalaOptions)
def multiJvmRunAction = multiJvmAction(getMultiJvmApps, runScalaOptions)
def multiJvmAction(getMultiTestsMap: => Map[String, Seq[String]], scalaOptions: String => Seq[String]) = {
task { args =>
task {
val multiTestsMap = getMultiTestsMap
def process(tests: List[String]): Option[String] = {
if (tests.isEmpty) {
None
} else {
val testName = tests(0)
val failed = multiTestsMap.get(testName) match {
case Some(testClasses) => runMulti(testName, testClasses, scalaOptions)
case None => Some("No multi jvm test called " + testName)
}
failed orElse process(tests.tail)
}
}
val tests = if (args.size > 0) args.toList else multiTestsMap.keys.toList.asInstanceOf[List[String]]
process(tests)
} dependsOn (testCompile)
} completeWith(getMultiTestsMap.keys.toList)
}
def getMultiJvmTests(): Map[String, Seq[String]] = {
val allTests = testCompileConditional.analysis.allTests.toList.map(_.className)
filterMultiJvmTests(allTests)
}
def getMultiJvmApps(): Map[String, Seq[String]] = {
val allApps = (mainCompileConditional.analysis.allApplications.toSeq ++
testCompileConditional.analysis.allApplications.toSeq)
filterMultiJvmTests(allApps)
}
def filterMultiJvmTests(allTests: Seq[String]): Map[String, Seq[String]] = {
val multiJvmTests = allTests filter (_.contains(MultiJvmTestName))
val names = multiJvmTests map { fullName =>
val lastDot = fullName.lastIndexOf(".")
val className = if (lastDot >= 0) fullName.substring(lastDot + 1) else fullName
val i = className.indexOf(MultiJvmTestName)
if (i >= 0) className.substring(0, i) else className
}
val testPairs = names map { name => (name, multiJvmTests.toList.filter(_.contains(name)).sort(_ < _)) }
Map(testPairs: _*)
}
def testIdentifier(className: String) = {
val i = className.indexOf(MultiJvmTestName)
val l = MultiJvmTestName.length
className.substring(i + l)
}
def testScalaOptions(testClass: String) = {
val scalaTestJars = testClasspath.get.filter(_.name.contains("scalatest"))
val cp = Path.makeString(scalaTestJars)
val paths = "\"" + testClasspath.get.map(_.absolutePath).mkString(" ", " ", " ") + "\""
Seq("-cp", cp, ScalaTestRunner, ScalaTestOptions, "-s", testClass, "-p", paths)
}
def runScalaOptions(appClass: String) = {
val cp = Path.makeString(testClasspath.get)
Seq("-cp", cp, appClass)
}
def runMulti(testName: String, testClasses: Seq[String], scalaOptions: String => Seq[String]) = {
log.control(ControlEvent.Start, "%s multi-jvm / %s %s" format (HeaderStart, testName, HeaderEnd))
val processes = testClasses.toList.zipWithIndex map {
case (testClass, index) => {
val jvmName = "JVM-" + testIdentifier(testClass)
val jvmLogger = new JvmLogger(jvmName)
log.info("Starting %s for %s" format (jvmName, testClass))
(testClass, startJvm(multiJvmOptions, scalaOptions(testClass), jvmLogger, index == 0))
}
}
val exitCodes = processes map {
case (testClass, process) => (testClass, process.exitValue)
}
val failures = exitCodes flatMap {
case (testClass, exit) if exit > 0 => Some("%s failed with exit code %s" format (testClass, exit))
case _ => None
}
failures foreach (log.error(_))
log.control(ControlEvent.Finish, "%s multi-jvm / %s %s" format (HeaderStart, testName, HeaderEnd))
if (!failures.isEmpty) Some("Some processes failed") else None
}
def startJvm(jvmOptions: Seq[String], scalaOptions: Seq[String], logger: Logger, connectInput: Boolean) = {
val si = buildScalaInstance
val scalaJars = Seq(si.libraryJar, si.compilerJar)
forkScala(jvmOptions, scalaJars, scalaOptions, logger, connectInput)
}
def forkScala(jvmOptions: Seq[String], scalaJars: Iterable[File], arguments: Seq[String], logger: Logger, connectInput: Boolean) = {
val scalaClasspath = scalaJars.map(_.getAbsolutePath).mkString(File.pathSeparator)
val bootClasspath = "-Xbootclasspath/a:" + scalaClasspath
val mainScalaClass = "scala.tools.nsc.MainGenericRunner"
val options = jvmOptions ++ Seq(bootClasspath, mainScalaClass) ++ arguments
forkJava(options, logger, connectInput)
}
def forkJava(options: Seq[String], logger: Logger, connectInput: Boolean) = {
val java = javaPath.toString
val command = (java :: options.toList).toArray
val builder = new JProcessBuilder(command: _*)
Process(builder).run(JvmIO(logger, connectInput))
}
}
final class JvmLogger(name: String) extends BasicLogger {
def jvm(message: String) = "[%s] %s" format (name, message)
def log(level: Level.Value, message: => String) = System.out.synchronized {
System.out.println(jvm(message))
}
def trace(t: => Throwable) = System.out.synchronized {
val traceLevel = getTrace
if (traceLevel >= 0) System.out.print(StackTrace.trimmed(t, traceLevel))
}
def success(message: => String) = log(Level.Info, message)
def control(event: ControlEvent.Value, message: => String) = log(Level.Info, message)
def logAll(events: Seq[LogEvent]) = System.out.synchronized { events.foreach(log) }
}
object JvmIO {
def apply(log: Logger, connectInput: Boolean) =
new ProcessIO(input(connectInput), processStream(log, Level.Info), processStream(log, Level.Error))
final val BufferSize = 8192
def processStream(log: Logger, level: Level.Value): InputStream => Unit =
processStream(line => log.log(level, line))
def processStream(processLine: String => Unit): InputStream => Unit = in => {
val reader = new BufferedReader(new InputStreamReader(in))
def process {
val line = reader.readLine()
if (line != null) {
processLine(line)
process
}
}
process
}
def input(connectInput: Boolean): OutputStream => Unit =
if (connectInput) connectSystemIn else ignoreOutputStream
def connectSystemIn(out: OutputStream) = transfer(System.in, out)
def ignoreOutputStream = (out: OutputStream) => ()
def transfer(in: InputStream, out: OutputStream): Unit = {
try {
val buffer = new Array[Byte](BufferSize)
def read {
val byteCount = in.read(buffer)
if (Thread.interrupted) throw new InterruptedException
if (byteCount > 0) {
out.write(buffer, 0, byteCount)
out.flush()
read
}
}
read
} catch {
case _: InterruptedException => ()
}
}
}