blob: b677a5134ec9325257bf2a34019911c4320b9d75 [file] [log] [blame]
#!/usr/bin/env python3
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from argparse import ArgumentParser
import os
import re
import shutil
import subprocess
import sys
import tempfile
from threading import Thread, Lock
import time
import uuid
if sys.version < '3':
import Queue
else:
import queue as Queue
from multiprocessing import Manager
# Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../dev/"))
from sparktestsupport import SPARK_HOME # noqa (suppress pep8 warnings)
from sparktestsupport.shellutils import which, subprocess_check_output # noqa
from sparktestsupport.modules import all_modules, pyspark_sql # noqa
# Map module name -> module for every module that declares Python test goals,
# leaving out the aggregate 'root' module.
python_modules = {
    module.name: module
    for module in all_modules
    if module.python_test_goals and module.name != 'root'
}
def print_red(text):
    """Print *text* to stdout wrapped in the ANSI escape codes for a red foreground."""
    red, reset = '\033[31m', '\033[0m'
    print(red + text + reset)
# Skipped-test output lines keyed by (python executable, test name); filled in by
# run_individual_python_test() and reported at the end of main().
SKIPPED_TESTS = Manager().dict()
# Output of failed tests is appended to this file; main() removes any previous copy on start.
LOG_FILE = os.path.join(SPARK_HOME, "python/unit-tests.log")
# Held while a failed test's output is written to LOG_FILE and echoed to stdout.
FAILURE_REPORTING_LOCK = Lock()
# The root logger (getLogger() is called without a name); configured in main().
LOGGER = logging.getLogger()
# Locate the directory holding the assembly jars; Spark has to be built beforehand.
# TODO: revisit for Scala 2.13
for scala in ["2.12"]:
    build_dir = os.path.join(SPARK_HOME, "assembly", "target", "scala-" + scala)
    if not os.path.isdir(build_dir):
        continue
    # Wildcard entry covering everything inside the build's "jars" directory.
    SPARK_DIST_CLASSPATH = os.path.join(build_dir, "jars", "*")
    break
else:
    # The loop finished without a break: none of the candidate directories exists.
    raise Exception("Cannot find assembly build directory, please build Spark first.")
def _flush_and_exit(status):
    """Flush stdout/stderr (best effort), then terminate the whole process with *status*.

    os._exit() is used instead of sys.exit() in order to force Python to exit even if this
    code is invoked from a thread other than the main thread. os._exit() does not flush
    stdio buffers, so anything written with print() could be lost when stdout is
    block-buffered (for example when redirected to a pipe or a file); flush explicitly first.
    """
    for stream in (sys.stdout, sys.stderr):
        try:
            stream.flush()
        except Exception:
            # Best effort only: a broken or missing stream must never prevent the exit.
            pass
    os._exit(status)


def run_individual_python_test(target_dir, test_name, pyspark_python):
    """Run one test goal through ``bin/pyspark`` with the given Python executable.

    The test process gets its own temporary directory under *target_dir* (used for both
    TMPDIR and the JVM's java.io.tmpdir); its stdout and stderr are captured in a
    temporary file.

    :param target_dir: existing directory in which the per-run temp directory is created.
    :param test_name: test goal; split on whitespace and passed as arguments to
        ``bin/pyspark``.
    :param pyspark_python: Python executable name, resolved with ``which`` and exported
        as PYSPARK_PYTHON / PYSPARK_DRIVER_PYTHON.
    :return: None. On success the duration (and the number of skipped tests, if any) is
        logged and skipped-test lines are stored in SKIPPED_TESTS.

    The whole process is terminated (it does not raise) when the test cannot be launched
    (status 1), when the test exits with a non-zero code (status -1, after its output has
    been appended to LOG_FILE and echoed), or when collecting skipped tests fails
    (status -1).
    """
    python_path = which(pyspark_python)
    env = dict(os.environ)
    env.update({
        'SPARK_DIST_CLASSPATH': SPARK_DIST_CLASSPATH,
        'SPARK_TESTING': '1',
        'SPARK_PREPEND_CLASSES': '1',
        'PYSPARK_PYTHON': python_path,
        'PYSPARK_DRIVER_PYTHON': python_path,
        'PYSPARK_ROW_FIELD_SORTING_ENABLED': 'true'
    })

    # Create a unique temp directory under 'target/' for each run. The TMPDIR variable is
    # recognized by the tempfile module to override the default system temp directory.
    tmp_dir = os.path.join(target_dir, str(uuid.uuid4()))
    while os.path.isdir(tmp_dir):
        tmp_dir = os.path.join(target_dir, str(uuid.uuid4()))
    os.mkdir(tmp_dir)
    env["TMPDIR"] = tmp_dir

    # Also override the JVM's temp directory by setting driver and executor options.
    java_options = "-Djava.io.tmpdir={0} -Dio.netty.tryReflectionSetAccessible=true".format(tmp_dir)
    spark_args = [
        "--conf", "spark.driver.extraJavaOptions='{0}'".format(java_options),
        "--conf", "spark.executor.extraJavaOptions='{0}'".format(java_options),
        "pyspark-shell"
    ]
    env["PYSPARK_SUBMIT_ARGS"] = " ".join(spark_args)

    LOGGER.info("Starting test(%s): %s", pyspark_python, test_name)
    start_time = time.time()
    try:
        per_test_output = tempfile.TemporaryFile()
        retcode = subprocess.Popen(
            [os.path.join(SPARK_HOME, "bin/pyspark")] + test_name.split(),
            stderr=per_test_output, stdout=per_test_output, env=env).wait()
        shutil.rmtree(tmp_dir, ignore_errors=True)
    except BaseException:
        # Deliberately broad: any failure to launch or wait for the test aborts the run.
        LOGGER.exception("Got exception while running %s with %s", test_name, pyspark_python)
        _flush_and_exit(1)
    duration = time.time() - start_time

    # Exit on the first failure.
    if retcode != 0:
        try:
            with FAILURE_REPORTING_LOCK:
                with open(LOG_FILE, 'ab') as log_file:
                    per_test_output.seek(0)
                    log_file.writelines(per_test_output)
                per_test_output.seek(0)
                for line in per_test_output:
                    decoded_line = line.decode("utf-8", "replace")
                    # Lines that begin with a digit are only written to the log file.
                    if not re.match('[0-9]+', decoded_line):
                        print(decoded_line, end='')
                per_test_output.close()
        except BaseException:
            LOGGER.exception("Got an exception while trying to print failed test output")
        finally:
            print_red("\nHad test failures in %s with %s; see logs." % (test_name, pyspark_python))
            _flush_and_exit(-1)
    else:
        skipped_counts = 0
        try:
            per_test_output.seek(0)
            # Here expects skipped test output from unittest when verbosity level is
            # 2 (or --verbose option is enabled).
            decoded_lines = (line.decode("utf-8", "replace") for line in per_test_output)
            skipped_tests = [
                line for line in decoded_lines
                if re.search(r'test_.* \(pyspark\..*\) ... (skip|SKIP)', line)]
            skipped_counts = len(skipped_tests)
            if skipped_counts > 0:
                key = (pyspark_python, test_name)
                SKIPPED_TESTS[key] = skipped_tests
            per_test_output.close()
        except BaseException:
            import traceback
            print_red("\nGot an exception while trying to store "
                      "skipped test output:\n%s" % traceback.format_exc())
            _flush_and_exit(-1)
        if skipped_counts != 0:
            LOGGER.info(
                "Finished test(%s): %s (%is) ... %s tests were skipped", pyspark_python, test_name,
                duration, skipped_counts)
        else:
            LOGGER.info(
                "Finished test(%s): %s (%is)", pyspark_python, test_name, duration)
def get_default_python_executables():
    """Return the Python executables to test against when none are given explicitly.

    The result holds whichever of "python3.6", "python2.7" and "pypy" are reported by
    ``which``. When "python3.6" is not among them, the result of ``which("python3")`` is
    placed first instead; if that lookup fails too, an error is logged and the process
    exits with status 1.
    """
    candidates = ("python3.6", "python2.7", "pypy")
    found = [name for name in candidates if which(name)]
    if "python3.6" in found:
        return found

    fallback = which("python3")
    if fallback:
        found.insert(0, fallback)
    else:
        LOGGER.error("No python3 executable found. Exiting!")
        os._exit(1)
    return found
def parse_opts():
    """Define the ``run-tests`` command-line interface and return the parsed options.

    Unrecognized arguments and a parallelism below 1 are reported through
    ``parser.error``.
    """
    parser = ArgumentParser(prog="run-tests")

    default_execs = ','.join(get_default_python_executables())
    parser.add_argument(
        "--python-executables", type=str, default=default_execs,
        help="A comma-separated list of Python executables to test against (default: %(default)s)")

    default_modules = ",".join(sorted(python_modules.keys()))
    parser.add_argument(
        "--modules", type=str, default=default_modules,
        help="A comma-separated list of Python modules to test (default: %(default)s)")

    parser.add_argument(
        "-p", "--parallelism", type=int, default=4,
        help="The number of suites to test in parallel (default %(default)d)")

    parser.add_argument(
        "--verbose", action="store_true",
        help="Enable additional debug logging")

    testnames_help = (
        "A comma-separated list of specific modules, classes and functions of doctest "
        "or unittest to test. "
        "For example, 'pyspark.sql.foo' to run the module as unittests or doctests, "
        "'pyspark.sql.tests FooTests' to run the specific class of unittests, "
        "'pyspark.sql.tests FooTests.test_foo' to run the specific unittest in the class. "
        "'--modules' option is ignored if they are given.")
    developer_opts = parser.add_argument_group("Developer Options")
    developer_opts.add_argument("--testnames", type=str, default=None, help=testnames_help)

    parsed, leftover = parser.parse_known_args()
    if leftover:
        parser.error("Unsupported arguments: %s" % ' '.join(leftover))
    if parsed.parallelism < 1:
        parser.error("Parallelism cannot be less than 1")
    return parsed
def _check_coverage(python_exec):
    """Exit with status -1 unless ``import coverage`` succeeds under *python_exec*.

    :param python_exec: Python executable used to attempt the import in a subprocess.
    """
    # Make sure if coverage is installed.
    try:
        # Discard the subprocess's stderr; the context manager closes the devnull handle
        # instead of leaking it.
        with open(os.devnull, 'w') as devnull:
            subprocess_check_output(
                [python_exec, "-c", "import coverage"],
                stderr=devnull)
    except Exception:
        # Only genuine failures are reported here; KeyboardInterrupt and SystemExit
        # propagate instead of being misreported as a missing coverage package.
        print_red("Coverage is not installed in Python executable '%s' "
                  "but 'COVERAGE_PROCESS_START' environment variable is set, "
                  "exiting." % python_exec)
        sys.exit(-1)
def main():
    """Parse the options, queue every requested test and run the queue on worker threads.

    Tests are queued per Python executable, either from the selected modules' test goals
    or from the explicit ``--testnames`` list. Once the queue has drained, the total
    duration and any skipped tests recorded in SKIPPED_TESTS are logged.
    """
    opts = parse_opts()
    run_by_module = opts.testnames is None
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.DEBUG if opts.verbose else logging.INFO,
        format="%(message)s")
    LOGGER.info("Running PySpark tests. Output is in %s", LOG_FILE)
    # Start from a clean log file; failed test output gets appended to it.
    if os.path.exists(LOG_FILE):
        os.remove(LOG_FILE)
    python_execs = opts.python_executables.split(',')
    LOGGER.info("Will test against the following Python executables: %s", python_execs)

    if run_by_module:
        modules_to_test = []
        for module_name in opts.modules.split(','):
            if module_name not in python_modules:
                print("Error: unrecognized module '%s'. Supported modules: %s" %
                      (module_name, ", ".join(python_modules)))
                sys.exit(-1)
            modules_to_test.append(python_modules[module_name])
        LOGGER.info("Will test the following Python modules: %s", [x.name for x in modules_to_test])
    else:
        testnames_to_test = opts.testnames.split(',')
        LOGGER.info("Will test the following Python tests: %s", testnames_to_test)

    # Test goals starting with one of these prefixes get priority 0, so the
    # PriorityQueue hands them out before the remaining goals (priority 100).
    heavy_prefixes = ('pyspark.streaming.tests', 'pyspark.mllib.tests',
                      'pyspark.tests', 'pyspark.sql.tests', 'pyspark.ml.tests')
    task_queue = Queue.PriorityQueue()
    for python_exec in python_execs:
        # Check if the python executable has coverage installed when 'COVERAGE_PROCESS_START'
        # environmental variable is set.
        if "COVERAGE_PROCESS_START" in os.environ:
            _check_coverage(python_exec)
        python_implementation = subprocess_check_output(
            [python_exec, "-c", "import platform; print(platform.python_implementation())"],
            universal_newlines=True).strip()
        LOGGER.info("%s python_implementation is %s", python_exec, python_implementation)
        python_version = subprocess_check_output(
            [python_exec, "--version"], stderr=subprocess.STDOUT, universal_newlines=True).strip()
        LOGGER.info("%s version is: %s", python_exec, python_version)

        if not run_by_module:
            for test_goal in testnames_to_test:
                task_queue.put((0, (python_exec, test_goal)))
            continue
        for module in modules_to_test:
            if python_implementation in module.blacklisted_python_implementations:
                continue
            for test_goal in module.python_test_goals:
                priority = 0 if test_goal.startswith(heavy_prefixes) else 100
                task_queue.put((priority, (python_exec, test_goal)))

    # Create the target directory before starting tasks to avoid races.
    target_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'target'))
    if not os.path.isdir(target_dir):
        os.mkdir(target_dir)

    def process_queue(task_queue):
        # Worker loop: take tasks until the queue is empty, marking each one done even
        # when the test run raises.
        while True:
            try:
                (_priority, (python_exec, test_goal)) = task_queue.get_nowait()
            except Queue.Empty:
                return
            try:
                run_individual_python_test(target_dir, test_goal, python_exec)
            finally:
                task_queue.task_done()

    start_time = time.time()
    for _ in range(opts.parallelism):
        worker = Thread(target=process_queue, args=(task_queue,))
        worker.daemon = True
        worker.start()
    try:
        task_queue.join()
    except (KeyboardInterrupt, SystemExit):
        print_red("Exiting due to interrupt")
        sys.exit(-1)
    total_duration = time.time() - start_time
    LOGGER.info("Tests passed in %i seconds", total_duration)

    for (pyspark_python, test_name), lines in sorted(SKIPPED_TESTS.items()):
        LOGGER.info("\nSkipped tests in %s with %s:" % (test_name, pyspark_python))
        for line in lines:
            LOGGER.info(" %s" % line.rstrip())
# Run the test driver only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()