import java.net.URL
import java.nio.file._
import java.nio.file.attribute.BasicFileAttributes
import sbt._
import scala.collection.JavaConversions._
import scala.collection.concurrent.TrieMap
import scala.xml.{Elem, XML}
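// Fetches the Apache Spark sources for the version configured in Versions.Spark,
// rewrites the POM versions where necessary, builds Spark, and installs the
// resulting artifacts into the local Maven repository so that this build can
// depend on them.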
object SparkInstaller {
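  // (Spark version, Scala version) pairs for which an install has already been
  // attempted in this JVM, so that repeated task invocations do not rebuild Spark.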
private val attempts = TrieMap[(String, String), Unit]()
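  // A version can be fetched from git if it is a release tag (e.g. "v2.0.0")
  // or a branch snapshot (e.g. "master-SNAPSHOT", "branch-2.0-SNAPSHOT").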
private def isGitVersion(providedVersion: String): Boolean = {
val GitTagPattern = """v[0-9].+""".r
    val GitBranchPattern = """(master|branch).*-SNAPSHOT""".r
providedVersion match {
case GitTagPattern() => true
case GitBranchPattern(_) => true
case _ => false
}
}
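  // Forwards the child process's output to this process's stdout/stderr. Lines
  // that start or end with a bare carriage return are flagged with an "RRR"
  // marker, apparently to make progress-bar style output visible in captured logs.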
private object StdProcessLogger extends sbt.ProcessLogger {
override def info(s: => String): Unit =
if (s.startsWith("\r")) System.out.println("RRR " + s)
else if (s.endsWith("\r")) System.out.println(s + "RRR")
else System.out.println(s)
override def error(s: => String): Unit =
if (s.startsWith("\r")) System.err.println("RRR " + s)
else if (s.endsWith("\r")) System.err.println(s + "RRR")
else System.err.println(s)
override def buffer[T](f: => T): T = f
}
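  // A Spark version counts as installed once the spark-core jar for the given
  // Scala binary version is present in the local Maven repository.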
private def isInstalled(version: String, scalaVersion: String): Boolean = {
val mavenRepoPath = Paths.get(new URL(BuildUtil.mavenLocalResolver.root).toURI)
val sparkCoreJar = mavenRepoPath.resolve("org").resolve("apache").resolve("spark")
.resolve(s"spark-core_$scalaVersion").resolve(version)
.resolve(s"spark-core_$scalaVersion-$version.jar")
Files.isRegularFile(sparkCoreJar)
}
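  // Shallow-clones the tag or branch named by the version (with any "-SNAPSHOT"
  // suffix stripped) from the Apache Spark repository. Snapshot versions are
  // always re-cloned so that the latest branch head gets built.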
private def cloneSpark(version: String, destination: Path, force: Boolean): Unit = {
val isSnapshot = version.endsWith("-SNAPSHOT")
    val baseVersion = version.stripSuffix("-SNAPSHOT")
if (!Files.exists(destination.resolve("pom.xml")) || force || isSnapshot) {
IO.delete(destination.toFile)
val dir = destination.toAbsolutePath.toString
val repo = "https://github.com/apache/spark.git"
val result = s"git clone --branch $baseVersion --depth 1 $repo $dir" ! StdProcessLogger
if (result != 0)
throw new RuntimeException("Failed to fetch Spark source code")
}
}
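  // Rewrites the <version> element of every pom.xml under rootDir, including
  // the org.apache.spark <parent> sections, so that the artifacts install under
  // the exact version string this build requested rather than the version
  // declared in the Spark sources.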
private def replaceVersionInPoms(rootDir: Path, newVersion: String): Unit = {
def replaceVersionInPomXml(xml: Elem, newVersion: String): Elem = {
import scala.xml.Node
def replaceVersion(nodes: Seq[Node]): Seq[Node] =
for (node <- nodes) yield node match {
case <version>{_}</version> =>
<version>{newVersion}</version>
          case parent @ <parent>{children @ _*}</parent> if (parent \ "groupId").text == "org.apache.spark" =>
            <parent>{replaceVersion(children)}</parent>
case other => other
}
xml.copy(child = replaceVersion(xml.child))
}
def replaceVersionInPom(file: Path, newVersion: String): Unit = {
val pom = XML.loadFile(file.toFile)
val updatedPom = replaceVersionInPomXml(pom, newVersion)
XML.save(file.toAbsolutePath.toString, updatedPom, "UTF-8")
}
Files.walkFileTree(rootDir, Set.empty[FileVisitOption], 100, new SimpleFileVisitor[Path]() {
override def visitFile(file: Path, attrs: BasicFileAttributes) = {
if (file.getFileName.toString == "pom.xml")
replaceVersionInPom(file, newVersion)
FileVisitResult.CONTINUE
}
override def preVisitDirectory(dir: Path, attrs: BasicFileAttributes) = {
if (Files.isHidden(dir) || dir.getFileName.toString.startsWith("."))
FileVisitResult.SKIP_SUBTREE
else FileVisitResult.CONTINUE
}
})
}
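  // Builds Spark with its bundled Maven wrapper and installs the artifacts into
  // the local Maven repository. For any Scala version other than 2.10 the
  // source tree is first switched to 2.11 via Spark's dev/change-scala-version.sh
  // script and the build is run with -Dscala-2.11.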
private def installSpark(rootDir: Path, scalaVersion: String): Unit = {
val `scala_2.11` = scalaVersion != "2.10"
if (`scala_2.11`) {
val cmd = List(
Paths.get("dev").resolve("change-scala-version.sh").toString,
"2.11"
)
val result = sbt.Process(cmd, rootDir.toFile) ! StdProcessLogger
if (result != 0)
throw new RuntimeException("Failed to change Scala version")
}
val cmd = List(
Paths.get("build").resolve("mvn").toString,
"--force",
"--batch-mode",
"-DskipTests"
) ::: (if (`scala_2.11`) List("-Dscala-2.11") else Nil) :::
"install" :: Nil
val result = sbt.Process(cmd, rootDir.toFile, "AMPLAB_JENKINS" -> "1") ! StdProcessLogger
if (result != 0)
throw new RuntimeException("Failed to install Spark")
}
private lazy val isWindows = sys.props("os.name").toLowerCase().contains("windows")
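  // Orchestrates the whole installation: clone into target/spark-builds/<version>,
  // patch the POMs, and run the Maven install, unless the version is already in
  // the local Maven repository or an install has already been attempted.
  // Automatic installation is not supported on Windows.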
private def getAndInstallSpark(version: String, scalaVersion: String, force: Boolean, rootDir: Path): Unit = {
    if ((force || !isInstalled(version, scalaVersion)) &&
        attempts.putIfAbsent(version -> scalaVersion, ()).isEmpty) {
if (isWindows) {
println(s"""
Cannot automatically fetch and install Spark on Windows. You need to do this on your own.
""")
throw new RuntimeException("Cannot install Spark")
}
val sparkDir = rootDir.resolve("target").resolve("spark-builds").resolve(version)
if (!Files.isDirectory(sparkDir)) {
Files.createDirectories(sparkDir)
}
try {
println(s"Cloning Spark $version to $sparkDir...")
cloneSpark(version, sparkDir, force)
println("Replacing version in POM files...")
replaceVersionInPoms(sparkDir, version)
println("Building and installing Spark to local Maven repository...")
installSpark(sparkDir, scalaVersion)
println(s"Spark $version has been installed.")
} catch {
case ex: Throwable =>
ex.printStackTrace()
          throw ex
}
} else {
println(s"Spark $version for Scala $scalaVersion detected.")
}
}
private lazy val force = sys.props.get("spark.forceInstall").exists(_.toLowerCase == "true")
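  // Entry point for the build definition: installs Versions.Spark when it is a
  // git-style version (tag or branch snapshot) and installation has not been
  // disabled. A minimal usage sketch, assuming a hypothetical task in the build:
  //   installSpark := SparkInstaller(scalaBinaryVersion.value, baseDirectory.value.toPath)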
def apply(scalaVersion: String, dir: Path): Unit = {
val version = Versions.Spark
if (isGitVersion(version) && !Versions.doNotInstallSpark) {
getAndInstallSpark(version, scalaVersion, force, dir)
}
}
}