Added MultiJvmTests trait
This commit is contained in:
parent
2655d44ee9
commit
70bbeba2a0
1 changed files with 203 additions and 0 deletions
203
project/build/MultiJvmTests.scala
Normal file
203
project/build/MultiJvmTests.scala
Normal file
|
|
@ -0,0 +1,203 @@
|
|||
import sbt._
|
||||
import sbt.Process
|
||||
import java.io.File
|
||||
import java.lang.{ProcessBuilder => JProcessBuilder}
|
||||
import java.io.{BufferedReader, Closeable, InputStream, InputStreamReader, IOException, OutputStream}
|
||||
import java.io.{PipedInputStream, PipedOutputStream}
|
||||
import scala.concurrent.SyncVar
|
||||
|
||||
trait MultiJvmTests extends BasicScalaProject {
|
||||
def multiJvmTestName = "MultiJvm"
|
||||
def multiJvmOptions: Seq[String] = Nil
|
||||
|
||||
val MultiJvmTestName = multiJvmTestName
|
||||
|
||||
val ScalaTestRunner = "org.scalatest.tools.Runner"
|
||||
val ScalaTestOptions = "-o"
|
||||
|
||||
val javaPath = Path.fromFile(System.getProperty("java.home")) / "bin" / "java"
|
||||
|
||||
private val HeaderStart = System.getProperty("sbt.start.delimiter", "==")
|
||||
private val HeaderEnd = System.getProperty("sbt.end.delimiter", "==")
|
||||
|
||||
// exclude multi jvm tests from normal tests
|
||||
override def testOptions = super.testOptions ++ Seq(TestFilter(test => !test.name.contains(MultiJvmTestName)))
|
||||
|
||||
lazy val multiJvmTest = multiJvmTestAction
|
||||
lazy val multiJvmRun = multiJvmRunAction
|
||||
|
||||
def multiJvmTestAction = multiJvmAction(getMultiJvmTests, testScalaOptions)
|
||||
def multiJvmRunAction = multiJvmAction(getMultiJvmApps, runScalaOptions)
|
||||
|
||||
def multiJvmAction(getMultiTestsMap: => Map[String, Seq[String]], scalaOptions: String => Seq[String]) = {
|
||||
task { args =>
|
||||
task {
|
||||
val multiTestsMap = getMultiTestsMap
|
||||
def process(tests: List[String]): Option[String] = {
|
||||
if (tests.isEmpty) {
|
||||
None
|
||||
} else {
|
||||
val testName = tests(0)
|
||||
val failed = multiTestsMap.get(testName) match {
|
||||
case Some(testClasses) => runMulti(testName, testClasses, scalaOptions)
|
||||
case None => Some("No multi jvm test called " + testName)
|
||||
}
|
||||
failed orElse process(tests.tail)
|
||||
}
|
||||
}
|
||||
val tests = if (args.size > 0) args.toList else multiTestsMap.keys.toList.asInstanceOf[List[String]]
|
||||
process(tests)
|
||||
} dependsOn (testCompile)
|
||||
} completeWith(getMultiTestsMap.keys.toList)
|
||||
}
|
||||
|
||||
def getMultiJvmTests(): Map[String, Seq[String]] = {
|
||||
val allTests = testCompileConditional.analysis.allTests.toList.map(_.className)
|
||||
filterMultiJvmTests(allTests)
|
||||
}
|
||||
|
||||
def getMultiJvmApps(): Map[String, Seq[String]] = {
|
||||
val allApps = (mainCompileConditional.analysis.allApplications.toSeq ++
|
||||
testCompileConditional.analysis.allApplications.toSeq)
|
||||
filterMultiJvmTests(allApps)
|
||||
}
|
||||
|
||||
def filterMultiJvmTests(allTests: Seq[String]): Map[String, Seq[String]] = {
|
||||
val multiJvmTests = allTests filter (_.contains(MultiJvmTestName))
|
||||
val names = multiJvmTests map { fullName =>
|
||||
val lastDot = fullName.lastIndexOf(".")
|
||||
val className = if (lastDot >= 0) fullName.substring(lastDot + 1) else fullName
|
||||
val i = className.indexOf(MultiJvmTestName)
|
||||
if (i >= 0) className.substring(0, i) else className
|
||||
}
|
||||
val testPairs = names map { name => (name, multiJvmTests.toList.filter(_.contains(name)).sort(_ < _)) }
|
||||
Map(testPairs: _*)
|
||||
}
|
||||
|
||||
def testIdentifier(className: String) = {
|
||||
val i = className.indexOf(MultiJvmTestName)
|
||||
val l = MultiJvmTestName.length
|
||||
className.substring(i + l)
|
||||
}
|
||||
|
||||
def testScalaOptions(testClass: String) = {
|
||||
val scalaTestJars = testClasspath.get.filter(_.name.contains("scalatest"))
|
||||
val cp = Path.makeString(scalaTestJars)
|
||||
val paths = "\"" + testClasspath.get.map(_.absolutePath).mkString(" ", " ", " ") + "\""
|
||||
Seq("-cp", cp, ScalaTestRunner, ScalaTestOptions, "-s", testClass, "-p", paths)
|
||||
}
|
||||
|
||||
def runScalaOptions(appClass: String) = {
|
||||
val cp = Path.makeString(testClasspath.get)
|
||||
Seq("-cp", cp, appClass)
|
||||
}
|
||||
|
||||
def runMulti(testName: String, testClasses: Seq[String], scalaOptions: String => Seq[String]) = {
|
||||
log.control(ControlEvent.Start, "%s multi-jvm / %s %s" format (HeaderStart, testName, HeaderEnd))
|
||||
val processes = testClasses.toList.zipWithIndex map {
|
||||
case (testClass, index) => {
|
||||
val jvmName = "JVM-" + testIdentifier(testClass)
|
||||
val jvmLogger = new JvmLogger(jvmName)
|
||||
log.info("Starting %s for %s" format (jvmName, testClass))
|
||||
(testClass, startJvm(multiJvmOptions, scalaOptions(testClass), jvmLogger, index == 0))
|
||||
}
|
||||
}
|
||||
val exitCodes = processes map {
|
||||
case (testClass, process) => (testClass, process.exitValue)
|
||||
}
|
||||
val failures = exitCodes flatMap {
|
||||
case (testClass, exit) if exit > 0 => Some("%s failed with exit code %s" format (testClass, exit))
|
||||
case _ => None
|
||||
}
|
||||
failures foreach (log.error(_))
|
||||
log.control(ControlEvent.Finish, "%s multi-jvm / %s %s" format (HeaderStart, testName, HeaderEnd))
|
||||
if (!failures.isEmpty) Some("Some processes failed") else None
|
||||
}
|
||||
|
||||
def startJvm(jvmOptions: Seq[String], scalaOptions: Seq[String], logger: Logger, connectInput: Boolean) = {
|
||||
val si = buildScalaInstance
|
||||
val scalaJars = Seq(si.libraryJar, si.compilerJar)
|
||||
forkScala(jvmOptions, scalaJars, scalaOptions, logger, connectInput)
|
||||
}
|
||||
|
||||
def forkScala(jvmOptions: Seq[String], scalaJars: Iterable[File], arguments: Seq[String], logger: Logger, connectInput: Boolean) = {
|
||||
val scalaClasspath = scalaJars.map(_.getAbsolutePath).mkString(File.pathSeparator)
|
||||
val bootClasspath = "-Xbootclasspath/a:" + scalaClasspath
|
||||
val mainScalaClass = "scala.tools.nsc.MainGenericRunner"
|
||||
val options = jvmOptions ++ Seq(bootClasspath, mainScalaClass) ++ arguments
|
||||
forkJava(options, logger, connectInput)
|
||||
}
|
||||
|
||||
def forkJava(options: Seq[String], logger: Logger, connectInput: Boolean) = {
|
||||
val java = javaPath.toString
|
||||
val command = (java :: options.toList).toArray
|
||||
val builder = new JProcessBuilder(command: _*)
|
||||
Process(builder).run(JvmIO(logger, connectInput))
|
||||
}
|
||||
}
|
||||
|
||||
final class JvmLogger(name: String) extends BasicLogger {
|
||||
def jvm(message: String) = "[%s] %s" format (name, message)
|
||||
|
||||
def log(level: Level.Value, message: => String) = System.out.synchronized {
|
||||
System.out.println(jvm(message))
|
||||
}
|
||||
|
||||
def trace(t: => Throwable) = System.out.synchronized {
|
||||
val traceLevel = getTrace
|
||||
if (traceLevel >= 0) System.out.print(StackTrace.trimmed(t, traceLevel))
|
||||
}
|
||||
|
||||
def success(message: => String) = log(Level.Info, message)
|
||||
def control(event: ControlEvent.Value, message: => String) = log(Level.Info, message)
|
||||
|
||||
def logAll(events: Seq[LogEvent]) = System.out.synchronized { events.foreach(log) }
|
||||
}
|
||||
|
||||
object JvmIO {
|
||||
def apply(log: Logger, connectInput: Boolean) =
|
||||
new ProcessIO(input(connectInput), processStream(log, Level.Info), processStream(log, Level.Error))
|
||||
|
||||
final val BufferSize = 8192
|
||||
|
||||
def processStream(log: Logger, level: Level.Value): InputStream => Unit =
|
||||
processStream(line => log.log(level, line))
|
||||
|
||||
def processStream(processLine: String => Unit): InputStream => Unit = in => {
|
||||
val reader = new BufferedReader(new InputStreamReader(in))
|
||||
def process {
|
||||
val line = reader.readLine()
|
||||
if (line != null) {
|
||||
processLine(line)
|
||||
process
|
||||
}
|
||||
}
|
||||
process
|
||||
}
|
||||
|
||||
def input(connectInput: Boolean): OutputStream => Unit =
|
||||
if (connectInput) connectSystemIn else ignoreOutputStream
|
||||
|
||||
def connectSystemIn(out: OutputStream) = transfer(System.in, out)
|
||||
|
||||
def ignoreOutputStream = (out: OutputStream) => ()
|
||||
|
||||
def transfer(in: InputStream, out: OutputStream): Unit = {
|
||||
try {
|
||||
val buffer = new Array[Byte](BufferSize)
|
||||
def read {
|
||||
val byteCount = in.read(buffer)
|
||||
if (Thread.interrupted) throw new InterruptedException
|
||||
if (byteCount > 0) {
|
||||
out.write(buffer, 0, byteCount)
|
||||
out.flush()
|
||||
read
|
||||
}
|
||||
}
|
||||
read
|
||||
} catch {
|
||||
case _: InterruptedException => ()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue