/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream}
import java.net.{HttpURLConnection, InetSocketAddress, URL}
import java.nio.charset.StandardCharsets
import java.nio.file.{Files => JavaFiles, Paths}
import java.nio.file.attribute.PosixFilePermission.{OWNER_EXECUTE, OWNER_READ, OWNER_WRITE}
import java.security.SecureRandom
import java.security.cert.X509Certificate
import java.util.{EnumSet, Locale}
import java.util.concurrent.{TimeoutException, TimeUnit}
import java.util.jar.{JarEntry, JarOutputStream, Manifest}
import java.util.regex.Pattern
import javax.net.ssl._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
import scala.reflect.{classTag, ClassTag}
import scala.sys.process.Process
import scala.util.Try
import com.google.common.io.{ByteStreams, Files}
import org.apache.commons.lang3.StringUtils
import org.apache.logging.log4j.LogManager
import org.apache.logging.log4j.core.LoggerContext
import org.apache.logging.log4j.core.appender.ConsoleAppender
import org.apache.logging.log4j.core.config.builder.api.ConfigurationBuilderFactory
import org.eclipse.jetty.server.Handler
import org.eclipse.jetty.server.Server
import org.eclipse.jetty.server.handler.DefaultHandler
import org.eclipse.jetty.server.handler.HandlerList
import org.eclipse.jetty.server.handler.ResourceHandler
import org.json4s.JsonAST.JValue
import org.json4s.jackson.JsonMethods.{compact, render}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
import org.apache.spark.util.{SparkTestUtils, Utils}
/**
* Utilities for tests. Included in the main codebase since it is used by multiple
* projects.
*
* TODO: See if we can move this to the test codebase by specifying
* test dependencies between projects.
*/
private[spark] object TestUtils extends SparkTestUtils {
/**
* Create a jar that defines classes with the given names.
*
* Note: if this is used during class loader tests, class names should be unique
* in order to avoid interference between tests.
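*
* A minimal usage sketch (the class name is illustrative):
* {{{
*   val jarUrl = TestUtils.createJarWithClasses(Seq("com.example.TestClass"))
*   val loader = new java.net.URLClassLoader(Array(jarUrl), getClass.getClassLoader)
*   assert(loader.loadClass("com.example.TestClass") != null)
* }}}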
*/
def createJarWithClasses(
classNames: Seq[String],
toStringValue: String = "",
classNamesWithBase: Seq[(String, String)] = Seq.empty,
classpathUrls: Seq[URL] = Seq.empty): URL = {
val tempDir = Utils.createTempDir()
val files1 = for (name <- classNames) yield {
createCompiledClass(name, tempDir, toStringValue, classpathUrls = classpathUrls)
}
val files2 = for ((childName, baseName) <- classNamesWithBase) yield {
createCompiledClass(childName, tempDir, toStringValue, baseName, classpathUrls)
}
val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis()))
createJar(files1 ++ files2, jarFile)
}
/**
* Create a jar file containing multiple files. The `files` map contains a mapping of
* file names in the jar file to their contents.
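*
* For example (entry name and contents are illustrative):
* {{{
*   val jarUrl = TestUtils.createJarWithFiles(Map("conf/app.properties" -> "key=value"))
* }}}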
*/
def createJarWithFiles(files: Map[String, String], dir: File = null): URL = {
val tempDir = Option(dir).getOrElse(Utils.createTempDir())
val jarFile = File.createTempFile("testJar", ".jar", tempDir)
val jarStream = new JarOutputStream(new FileOutputStream(jarFile))
files.foreach { case (k, v) =>
val entry = new JarEntry(k)
jarStream.putNextEntry(entry)
ByteStreams.copy(new ByteArrayInputStream(v.getBytes(StandardCharsets.UTF_8)), jarStream)
}
jarStream.close()
jarFile.toURI.toURL
}
/**
* Create a jar file that contains this set of files. All files will be located in the specified
* directory or at the root of the jar.
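*
* A sketch of packaging compiled classes into a runnable jar (paths are illustrative):
* {{{
*   val url = TestUtils.createJar(
*     Seq(new File("/tmp/classes/Main.class")),
*     new File("/tmp/app.jar"),
*     mainClass = Some("Main"))
* }}}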
*/
def createJar(
files: Seq[File],
jarFile: File,
directoryPrefix: Option[String] = None,
mainClass: Option[String] = None): URL = {
val manifest = mainClass match {
case Some(mc) =>
val m = new Manifest()
m.getMainAttributes.putValue("Manifest-Version", "1.0")
m.getMainAttributes.putValue("Main-Class", mc)
m
case None =>
new Manifest()
}
val jarFileStream = new FileOutputStream(jarFile)
val jarStream = new JarOutputStream(jarFileStream, manifest)
for (file <- files) {
// The entry name passed to `JarEntry` must use '/' as its separator, as required by
// the ZIP specification.
val prefix = directoryPrefix.map(d => s"$d/").getOrElse("")
val jarEntry = new JarEntry(prefix + file.getName)
jarStream.putNextEntry(jarEntry)
val in = new FileInputStream(file)
ByteStreams.copy(in, jarStream)
in.close()
}
jarStream.close()
jarFileStream.close()
jarFile.toURI.toURL
}
/**
* Run some code involving jobs submitted to the given context and assert that the jobs spilled.
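*
* Usage sketch:
* {{{
*   TestUtils.assertSpilled(sc, "external sort") {
*     sc.parallelize(1 to 1000000, 2).sortBy(identity).count()
*   }
* }}}
* Whether the body actually spills depends on the memory configuration, so tests typically
* lower the relevant spill thresholds before relying on this assertion.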
*/
def assertSpilled(sc: SparkContext, identifier: String)(body: => Unit): Unit = {
val listener = new SpillListener
withListener(sc, listener) { _ =>
body
}
assert(listener.numSpilledStages > 0, s"expected $identifier to spill, but did not")
}
/**
* Run some code involving jobs submitted to the given context and assert that the jobs
* did not spill.
*/
def assertNotSpilled(sc: SparkContext, identifier: String)(body: => Unit): Unit = {
val listener = new SpillListener
withListener(sc, listener) { _ =>
body
}
assert(listener.numSpilledStages == 0, s"expected $identifier to not spill, but did")
}
/**
* Asserts that the exception's message contains the given message. Note that this checks all
* exceptions in the cause chain. If a type parameter `E` is supplied, this will additionally
* confirm that the exception is a subtype of `E`.
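*
* Usage sketch:
* {{{
*   val e = new RuntimeException("outer", new IllegalStateException("Boom"))
*   TestUtils.assertExceptionMsg[IllegalStateException](e, "boom", ignoreCase = true)
* }}}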
*/
def assertExceptionMsg[E <: Throwable : ClassTag](
exception: Throwable,
msg: String,
ignoreCase: Boolean = false): Unit = {
val (typeMsg, typeCheck) = if (classTag[E] == classTag[Nothing]) {
("", (_: Throwable) => true)
} else {
val clazz = classTag[E].runtimeClass
(s"of type ${clazz.getName} ", (e: Throwable) => clazz.isAssignableFrom(e.getClass))
}
def contain(e: Throwable, msg: String): Boolean = {
// Compute the message match first: with `&& typeCheck(e)` written directly after the
// if/else, the `&&` would bind only to the else branch, so the type check would be
// skipped whenever `ignoreCase` is true. Also guard against exceptions without a message.
val message = Option(e.getMessage).getOrElse("")
val messageMatches = if (ignoreCase) {
message.toLowerCase(Locale.ROOT).contains(msg.toLowerCase(Locale.ROOT))
} else {
message.contains(msg)
}
messageMatches && typeCheck(e)
}
var e = exception
var contains = contain(e, msg)
while (e.getCause != null && !contains) {
e = e.getCause
contains = contain(e, msg)
}
assert(contains,
s"Exception tree doesn't contain the expected exception ${typeMsg}with message: $msg\n" +
Utils.exceptionString(e))
}
/**
* Test if a command is available.
*/
def testCommandAvailable(command: String): Boolean = Utils.checkCommandAvailable(command)
// SPARK-40053: This string needs to be updated when the
// minimum supported Python version changes.
val minimumPythonSupportedVersion: String = "3.7.0"
def isPythonVersionAvailable: Boolean = {
val version = minimumPythonSupportedVersion.split('.').map(_.toInt)
assert(version.length == 3)
isPythonVersionAtLeast(version(0), version(1), version(2))
}
private def isPythonVersionAtLeast(major: Int, minor: Int, micro: Int): Boolean = {
val cmdSeq = if (Utils.isWindows) Seq("cmd.exe", "/C") else Seq("sh", "-c")
// The snippet exits with 0 iff the interpreter's version is at least (major, minor, micro).
val pythonSnippet = s"import sys; sys.exit(sys.version_info < ($major, $minor, $micro))"
Try(Process(cmdSeq :+ s"python3 -c '$pythonSnippet'").! == 0).getOrElse(false)
}
/**
* Get the absolute path of the given executable. This implementation was borrowed from
* `spark/dev/sparktestsupport/shellutils.py`.
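*
* For example, `getAbsolutePathFromExecutable("ls")` would typically return `Some("/bin/ls")`
* or `Some("/usr/bin/ls")` on a Unix system, and `None` when the command is not on the PATH.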
*/
def getAbsolutePathFromExecutable(executable: String): Option[String] = {
val command = if (Utils.isWindows) s"$executable.exe" else executable
if (command.split(Pattern.quote(File.separator), 2).length == 1 &&
JavaFiles.isRegularFile(Paths.get(command)) &&
JavaFiles.isExecutable(Paths.get(command))) {
Some(Paths.get(command).toAbsolutePath.toString)
} else {
sys.env("PATH").split(Pattern.quote(File.pathSeparator))
.map(path => Paths.get(s"${StringUtils.strip(path, "\"")}${File.separator}$command"))
.find(p => JavaFiles.isRegularFile(p) && JavaFiles.isExecutable(p))
.map(_.toString)
}
}
/**
* Returns the response code from an HTTP(S) URL.
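*
* Usage sketch (the URL is illustrative):
* {{{
*   assert(TestUtils.httpResponseCode(new URL("http://localhost:4040/jobs/")) == 200)
* }}}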
*/
def httpResponseCode(
url: URL,
method: String = "GET",
headers: Seq[(String, String)] = Nil): Int = {
withHttpConnection(url, method, headers = headers) { connection =>
connection.getResponseCode()
}
}
/**
* Returns the Location header from an HTTP(S) URL.
*/
def redirectUrl(
url: URL,
method: String = "GET",
headers: Seq[(String, String)] = Nil): String = {
withHttpConnection(url, method, headers = headers) { connection =>
connection.getHeaderField("Location")
}
}
/**
* Returns the response message from an HTTP(S) URL.
*/
def httpResponseMessage(
url: URL,
method: String = "GET",
headers: Seq[(String, String)] = Nil): String = {
withHttpConnection(url, method, headers = headers) { connection =>
Source.fromInputStream(connection.getInputStream, "utf-8").getLines().mkString("\n")
}
}
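/**
* Opens an HTTP(S) connection to `url` with the given method and headers, passes the connection
* to `fn`, and always disconnects afterwards. For HTTPS, certificate and host name validation
* are disabled, since tests commonly use self-signed certificates.
*/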
def withHttpConnection[T](
url: URL,
method: String = "GET",
headers: Seq[(String, String)] = Nil)
(fn: HttpURLConnection => T): T = {
val connection = url.openConnection().asInstanceOf[HttpURLConnection]
connection.setRequestMethod(method)
headers.foreach { case (k, v) => connection.setRequestProperty(k, v) }
connection match {
// Disable cert and host name validation for HTTPS tests.
case httpConnection: HttpsURLConnection =>
val sslCtx = SSLContext.getInstance("SSL")
val trustManager = new X509TrustManager {
override def getAcceptedIssuers: Array[X509Certificate] = null
override def checkClientTrusted(x509Certificates: Array[X509Certificate],
s: String): Unit = {}
override def checkServerTrusted(x509Certificates: Array[X509Certificate],
s: String): Unit = {}
}
val verifier = new HostnameVerifier() {
override def verify(hostname: String, session: SSLSession): Boolean = true
}
sslCtx.init(null, Array(trustManager), new SecureRandom())
httpConnection.setSSLSocketFactory(sslCtx.getSocketFactory)
httpConnection.setHostnameVerifier(verifier)
case _ => // do nothing
}
try {
connection.connect()
fn(connection)
} finally {
connection.disconnect()
}
}
/**
* Runs some code with the given listener installed in the SparkContext. After the code runs,
* this method will wait until all events posted to the listener bus are processed, and then
* remove the listener from the bus.
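*
* Usage sketch with an ad-hoc listener (the listener is illustrative):
* {{{
*   var jobs = 0
*   val listener = new SparkListener {
*     override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = jobs += 1
*   }
*   TestUtils.withListener(sc, listener) { _ =>
*     sc.parallelize(1 to 100).count()
*   }
*   assert(jobs == 1) // safe to read: withListener drains the listener bus before returning
* }}}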
*/
def withListener[L <: SparkListener](sc: SparkContext, listener: L) (body: L => Unit): Unit = {
sc.addSparkListener(listener)
try {
body(listener)
} finally {
sc.listenerBus.waitUntilEmpty()
sc.listenerBus.removeListener(listener)
}
}
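/**
* Starts an embedded Jetty server serving files from `resBaseDir`, passes the server's base URL
* to `body`, and stops the server afterwards.
*
* A sketch of fetching a served file (names are illustrative):
* {{{
*   TestUtils.withHttpServer(tempDir.getPath) { baseUrl =>
*     assert(TestUtils.httpResponseCode(new URL(baseUrl, "data.txt")) == 200)
*   }
* }}}
*/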
def withHttpServer(resBaseDir: String = ".")(body: URL => Unit): Unit = {
// Binding to port 0 lets the OS pick an arbitrary available port.
val server = new Server(new InetSocketAddress(Utils.localCanonicalHostName(), 0))
val resHandler = new ResourceHandler()
resHandler.setResourceBase(resBaseDir)
val handlers = new HandlerList()
handlers.setHandlers(Array[Handler](resHandler, new DefaultHandler()))
server.setHandler(handlers)
server.start()
try {
body(server.getURI.toURL)
} finally {
server.stop()
}
}
/**
* Wait until at least `numExecutors` executors are up, or throw a `TimeoutException` if they
* do not all come up before the timeout elapses. Exposed for testing.
*
* @param numExecutors the minimum number of executors to wait for
* @param timeout how long to wait, in milliseconds
*/
private[spark] def waitUntilExecutorsUp(
sc: SparkContext,
numExecutors: Int,
timeout: Long): Unit = {
val finishTime = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(timeout)
while (System.nanoTime() < finishTime) {
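// Note: getExecutorInfos also includes the driver, so strictly more than `numExecutors`
// entries means at least `numExecutors` executors have registered.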
if (sc.statusTracker.getExecutorInfos.length > numExecutors) {
return
}
// Sleep rather than using wait/notify, because this is used only for testing and wait/notify
// add overhead in the general case.
Thread.sleep(10)
}
throw new TimeoutException(
s"Can't find $numExecutors executors before $timeout milliseconds elapsed")
}
/**
* Configures Log4j2 with a console appender at the given log level, for use in test suites.
*/
def configTestLog4j2(level: String): Unit = {
val builder = ConfigurationBuilderFactory.newConfigurationBuilder()
val appenderBuilder = builder.newAppender("console", "CONSOLE")
.addAttribute("target", ConsoleAppender.Target.SYSTEM_ERR)
appenderBuilder.add(builder.newLayout("PatternLayout")
.addAttribute("pattern", "%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n%ex"))
builder.add(appenderBuilder)
builder.add(builder.newRootLogger(level).add(builder.newAppenderRef("console")))
val configuration = builder.build()
LogManager.getContext(false).asInstanceOf[LoggerContext].reconfigure(configuration)
}
/**
* Lists all files and directories under `f` recursively; `f` itself is not included.
*/
def recursiveList(f: File): Array[File] = {
require(f.isDirectory)
val current = f.listFiles
current ++ current.filter(_.isDirectory).flatMap(recursiveList)
}
/**
* Returns the list of files under `path`, recursively. This skips files that MapReduce
* normally ignores, i.e. those whose names start with '.' or '_'.
*/
def listDirectory(path: File): Array[String] = {
val result = ArrayBuffer.empty[String]
if (path.isDirectory) {
path.listFiles.foreach(f => result.appendAll(listDirectory(f)))
} else {
val c = path.getName.charAt(0)
if (c != '.' && c != '_') result.append(path.getAbsolutePath)
}
result.toArray
}
/** Creates a temp JSON file that contains the input JSON record. */
def createTempJsonFile(dir: File, prefix: String, jsonValue: JValue): String = {
val file = File.createTempFile(prefix, ".json", dir)
JavaFiles.write(file.toPath, compact(render(jsonValue)).getBytes(StandardCharsets.UTF_8))
file.getPath
}
/**
* Creates a temp bash script that prints the given output when executed.
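*
* Usage sketch (assumes `bash` is available on the test machine):
* {{{
*   val script = TestUtils.createTempScriptWithExpectedOutput(tempDir, "test-", "hello")
*   assert(Process(Seq("bash", script)).!!.trim == "hello")
* }}}
*/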
def createTempScriptWithExpectedOutput(dir: File, prefix: String, output: String): String = {
val file = File.createTempFile(prefix, ".sh", dir)
val script = s"cat <<EOF\n$output\nEOF\n"
Files.write(script, file, StandardCharsets.UTF_8)
JavaFiles.setPosixFilePermissions(file.toPath,
EnumSet.of(OWNER_READ, OWNER_EXECUTE, OWNER_WRITE))
file.getPath
}
}
/**
* A `SparkListener` that detects whether spills have occurred in Spark jobs.
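*
* It records `TaskMetrics` for every completed task and, when a stage completes, marks the
* stage as spilled if the sum of its tasks' `memoryBytesSpilled` is positive.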
*/
private class SpillListener extends SparkListener {
private val stageIdToTaskMetrics = new mutable.HashMap[Int, ArrayBuffer[TaskMetrics]]
private val spilledStageIds = new mutable.HashSet[Int]
def numSpilledStages: Int = synchronized {
spilledStageIds.size
}
override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
stageIdToTaskMetrics.getOrElseUpdate(
taskEnd.stageId, new ArrayBuffer[TaskMetrics]) += taskEnd.taskMetrics
}
override def onStageCompleted(stageComplete: SparkListenerStageCompleted): Unit = synchronized {
val stageId = stageComplete.stageInfo.stageId
val metrics = stageIdToTaskMetrics.remove(stageId).toSeq.flatten
val spilled = metrics.map(_.memoryBytesSpilled).sum > 0
if (spilled) {
spilledStageIds += stageId
}
}
}