blob: c2ebd388a2365924b92d752c3a4576fface8b271 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream}
import java.net.{HttpURLConnection, URI, URL}
import java.nio.charset.StandardCharsets
import java.security.SecureRandom
import java.security.cert.X509Certificate
import java.util.{Arrays, Properties}
import java.util.concurrent.{TimeoutException, TimeUnit}
import java.util.jar.{JarEntry, JarOutputStream}
import javax.net.ssl._
import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider}
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.sys.process.{Process, ProcessLogger}
import scala.util.Try
import com.google.common.io.{ByteStreams, Files}
import org.apache.log4j.PropertyConfigurator
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
import org.apache.spark.util.Utils
/**
* Utilities for tests. Included in main codebase since it's used by multiple
* projects.
*
* TODO: See if we can move this to the test codebase by specifying
* test dependencies between projects.
*/
private[spark] object TestUtils {
/**
* Create a jar that defines classes with the given names.
*
* Note: if this is used during class loader tests, class names should be unique
* in order to avoid interference between tests.
*/
def createJarWithClasses(
classNames: Seq[String],
toStringValue: String = "",
classNamesWithBase: Seq[(String, String)] = Seq.empty,
classpathUrls: Seq[URL] = Seq.empty): URL = {
val tempDir = Utils.createTempDir()
val files1 = for (name <- classNames) yield {
createCompiledClass(name, tempDir, toStringValue, classpathUrls = classpathUrls)
}
val files2 = for ((childName, baseName) <- classNamesWithBase) yield {
createCompiledClass(childName, tempDir, toStringValue, baseName, classpathUrls)
}
val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis()))
createJar(files1 ++ files2, jarFile)
}
/**
* Create a jar file containing multiple files. The `files` map contains a mapping of
* file names in the jar file to their contents.
*/
def createJarWithFiles(files: Map[String, String], dir: File = null): URL = {
val tempDir = Option(dir).getOrElse(Utils.createTempDir())
val jarFile = File.createTempFile("testJar", ".jar", tempDir)
val jarStream = new JarOutputStream(new FileOutputStream(jarFile))
files.foreach { case (k, v) =>
val entry = new JarEntry(k)
jarStream.putNextEntry(entry)
ByteStreams.copy(new ByteArrayInputStream(v.getBytes(StandardCharsets.UTF_8)), jarStream)
}
jarStream.close()
jarFile.toURI.toURL
}
/**
* Create a jar file that contains this set of files. All files will be located in the specified
* directory or at the root of the jar.
*/
def createJar(files: Seq[File], jarFile: File, directoryPrefix: Option[String] = None): URL = {
val jarFileStream = new FileOutputStream(jarFile)
val jarStream = new JarOutputStream(jarFileStream, new java.util.jar.Manifest())
for (file <- files) {
// The `name` for the argument in `JarEntry` should use / for its separator. This is
// ZIP specification.
val prefix = directoryPrefix.map(d => s"$d/").getOrElse("")
val jarEntry = new JarEntry(prefix + file.getName)
jarStream.putNextEntry(jarEntry)
val in = new FileInputStream(file)
ByteStreams.copy(in, jarStream)
in.close()
}
jarStream.close()
jarFileStream.close()
jarFile.toURI.toURL
}
// Adapted from the JavaCompiler.java doc examples
private val SOURCE = JavaFileObject.Kind.SOURCE
private def createURI(name: String) = {
URI.create(s"string:///${name.replace(".", "/")}${SOURCE.extension}")
}
private[spark] class JavaSourceFromString(val name: String, val code: String)
extends SimpleJavaFileObject(createURI(name), SOURCE) {
override def getCharContent(ignoreEncodingErrors: Boolean): String = code
}
/** Creates a compiled class with the source file. Class file will be placed in destDir. */
def createCompiledClass(
className: String,
destDir: File,
sourceFile: JavaSourceFromString,
classpathUrls: Seq[URL]): File = {
val compiler = ToolProvider.getSystemJavaCompiler
// Calling this outputs a class file in pwd. It's easier to just rename the files than
// build a custom FileManager that controls the output location.
val options = if (classpathUrls.nonEmpty) {
Seq("-classpath", classpathUrls.map { _.getFile }.mkString(File.pathSeparator))
} else {
Seq.empty
}
compiler.getTask(null, null, null, options.asJava, null, Arrays.asList(sourceFile)).call()
val fileName = className + ".class"
val result = new File(fileName)
assert(result.exists(), "Compiled file not found: " + result.getAbsolutePath())
val out = new File(destDir, fileName)
// renameTo cannot handle in and out files in different filesystems
// use google's Files.move instead
Files.move(result, out)
assert(out.exists(), "Destination file not moved: " + out.getAbsolutePath())
out
}
/** Creates a compiled class with the given name. Class file will be placed in destDir. */
def createCompiledClass(
className: String,
destDir: File,
toStringValue: String = "",
baseClass: String = null,
classpathUrls: Seq[URL] = Seq.empty): File = {
val extendsText = Option(baseClass).map { c => s" extends ${c}" }.getOrElse("")
val sourceFile = new JavaSourceFromString(className,
"public class " + className + extendsText + " implements java.io.Serializable {" +
" @Override public String toString() { return \"" + toStringValue + "\"; }}")
createCompiledClass(className, destDir, sourceFile, classpathUrls)
}
/**
* Run some code involving jobs submitted to the given context and assert that the jobs spilled.
*/
def assertSpilled(sc: SparkContext, identifier: String)(body: => Unit): Unit = {
val listener = new SpillListener
withListener(sc, listener) { _ =>
body
}
assert(listener.numSpilledStages > 0, s"expected $identifier to spill, but did not")
}
/**
* Run some code involving jobs submitted to the given context and assert that the jobs
* did not spill.
*/
def assertNotSpilled(sc: SparkContext, identifier: String)(body: => Unit): Unit = {
val listener = new SpillListener
withListener(sc, listener) { _ =>
body
}
assert(listener.numSpilledStages == 0, s"expected $identifier to not spill, but did")
}
/**
* Test if a command is available.
*/
def testCommandAvailable(command: String): Boolean = {
val attempt = Try(Process(command).run(ProcessLogger(_ => ())).exitValue())
attempt.isSuccess && attempt.get == 0
}
/**
* Returns the response code from an HTTP(S) URL.
*/
def httpResponseCode(
url: URL,
method: String = "GET",
headers: Seq[(String, String)] = Nil): Int = {
val connection = url.openConnection().asInstanceOf[HttpURLConnection]
connection.setRequestMethod(method)
headers.foreach { case (k, v) => connection.setRequestProperty(k, v) }
// Disable cert and host name validation for HTTPS tests.
if (connection.isInstanceOf[HttpsURLConnection]) {
val sslCtx = SSLContext.getInstance("SSL")
val trustManager = new X509TrustManager {
override def getAcceptedIssuers(): Array[X509Certificate] = null
override def checkClientTrusted(x509Certificates: Array[X509Certificate], s: String) {}
override def checkServerTrusted(x509Certificates: Array[X509Certificate], s: String) {}
}
val verifier = new HostnameVerifier() {
override def verify(hostname: String, session: SSLSession): Boolean = true
}
sslCtx.init(null, Array(trustManager), new SecureRandom())
connection.asInstanceOf[HttpsURLConnection].setSSLSocketFactory(sslCtx.getSocketFactory())
connection.asInstanceOf[HttpsURLConnection].setHostnameVerifier(verifier)
}
try {
connection.connect()
connection.getResponseCode()
} finally {
connection.disconnect()
}
}
/**
* Runs some code with the given listener installed in the SparkContext. After the code runs,
* this method will wait until all events posted to the listener bus are processed, and then
* remove the listener from the bus.
*/
def withListener[L <: SparkListener](sc: SparkContext, listener: L) (body: L => Unit): Unit = {
sc.addSparkListener(listener)
try {
body(listener)
} finally {
sc.listenerBus.waitUntilEmpty(TimeUnit.SECONDS.toMillis(10))
sc.listenerBus.removeListener(listener)
}
}
/**
* Wait until at least `numExecutors` executors are up, or throw `TimeoutException` if the waiting
* time elapsed before `numExecutors` executors up. Exposed for testing.
*
* @param numExecutors the number of executors to wait at least
* @param timeout time to wait in milliseconds
*/
private[spark] def waitUntilExecutorsUp(
sc: SparkContext,
numExecutors: Int,
timeout: Long): Unit = {
val finishTime = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(timeout)
while (System.nanoTime() < finishTime) {
if (sc.statusTracker.getExecutorInfos.length > numExecutors) {
return
}
// Sleep rather than using wait/notify, because this is used only for testing and wait/notify
// add overhead in the general case.
Thread.sleep(10)
}
throw new TimeoutException(
s"Can't find $numExecutors executors before $timeout milliseconds elapsed")
}
/**
* config a log4j properties used for testsuite
*/
def configTestLog4j(level: String): Unit = {
val pro = new Properties()
pro.put("log4j.rootLogger", s"$level, console")
pro.put("log4j.appender.console", "org.apache.log4j.ConsoleAppender")
pro.put("log4j.appender.console.target", "System.err")
pro.put("log4j.appender.console.layout", "org.apache.log4j.PatternLayout")
pro.put("log4j.appender.console.layout.ConversionPattern",
"%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n")
PropertyConfigurator.configure(pro)
}
/**
* Lists files recursively.
*/
def recursiveList(f: File): Array[File] = {
require(f.isDirectory)
val current = f.listFiles
current ++ current.filter(_.isDirectory).flatMap(recursiveList)
}
}
/**
* A `SparkListener` that detects whether spills have occurred in Spark jobs.
*/
private class SpillListener extends SparkListener {
private val stageIdToTaskMetrics = new mutable.HashMap[Int, ArrayBuffer[TaskMetrics]]
private val spilledStageIds = new mutable.HashSet[Int]
def numSpilledStages: Int = synchronized {
spilledStageIds.size
}
override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
stageIdToTaskMetrics.getOrElseUpdate(
taskEnd.stageId, new ArrayBuffer[TaskMetrics]) += taskEnd.taskMetrics
}
override def onStageCompleted(stageComplete: SparkListenerStageCompleted): Unit = synchronized {
val stageId = stageComplete.stageInfo.stageId
val metrics = stageIdToTaskMetrics.remove(stageId).toSeq.flatten
val spilled = metrics.map(_.memoryBytesSpilled).sum > 0
if (spilled) {
spilledStageIds += stageId
}
}
}