#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import glob
import os
import struct
import sys
import unittest

from pyspark import SparkContext, SparkConf


have_scipy = False
have_numpy = False
try:
    import scipy.sparse
    have_scipy = True
except ImportError:
    # No SciPy, but that's okay, we'll skip those tests
    pass
try:
    import numpy as np
    have_numpy = True
except ImportError:
    # No NumPy, but that's okay, we'll skip those tests
    pass
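
# Usage sketch: the have_scipy / have_numpy flags are meant to gate optional
# tests, e.g. with unittest.skipIf (class name below is illustrative only):
#
#     @unittest.skipIf(not have_scipy, "SciPy not installed")
#     class SciPyTests(ReusedPySparkTestCase):
#         ...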


SPARK_HOME = os.environ["SPARK_HOME"]


def read_int(b):
    """Unpack a 4-byte big-endian (network byte order) signed int from bytes."""
    return struct.unpack("!i", b)[0]


def write_int(i):
    """Pack a signed int into 4 big-endian (network byte order) bytes."""
    return struct.pack("!i", i)
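
# Usage sketch: write_int and read_int round-trip a 32-bit signed integer,
# e.g. read_int(write_int(42)) == 42.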


class QuietTest(object):
    """Context manager that temporarily raises the JVM root logger level to
    FATAL, silencing expected error output, and restores the previous level
    on exit."""

    def __init__(self, sc):
        self.log4j = sc._jvm.org.apache.log4j

    def __enter__(self):
        self.old_level = self.log4j.LogManager.getRootLogger().getLevel()
        self.log4j.LogManager.getRootLogger().setLevel(self.log4j.Level.FATAL)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.log4j.LogManager.getRootLogger().setLevel(self.old_level)
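
# Usage sketch: wrap code that is expected to raise so the JVM does not flood
# the test output with stack traces (broken_rdd is a placeholder name), e.g.
#
#     with QuietTest(self.sc):
#         self.assertRaises(Exception, lambda: broken_rdd.count())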


class PySparkTestCase(unittest.TestCase):

    def setUp(self):
        # Remember sys.path so entries added by individual tests can be undone.
        self._old_sys_path = list(sys.path)
        class_name = self.__class__.__name__
        self.sc = SparkContext('local[4]', class_name)

    def tearDown(self):
        self.sc.stop()
        sys.path = self._old_sys_path


class ReusedPySparkTestCase(unittest.TestCase):

    @classmethod
    def conf(cls):
        """
        Override this in subclasses to supply a more specific conf.
        """
        return SparkConf()

    @classmethod
    def setUpClass(cls):
        cls.sc = SparkContext('local[4]', cls.__name__, conf=cls.conf())

    @classmethod
    def tearDownClass(cls):
        cls.sc.stop()
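
# Usage sketch: subclasses can override conf() to tune the shared context;
# the class name and config key below are shown for illustration only, e.g.
#
#     class ArrowTests(ReusedPySparkTestCase):
#         @classmethod
#         def conf(cls):
#             return SparkConf().set(
#                 "spark.sql.execution.arrow.pyspark.enabled", "true")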


class ByteArrayOutput(object):
    """A minimal file-like object that accumulates written bytes in memory."""

    def __init__(self):
        self.buffer = bytearray()

    def write(self, b):
        self.buffer += b

    def close(self):
        pass
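
# Usage sketch: ByteArrayOutput can serve as an in-memory write sink, e.g.
#
#     out = ByteArrayOutput()
#     out.write(write_int(7))
#     assert read_int(out.buffer) == 7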


def search_jar(project_relative_path, sbt_jar_name_prefix, mvn_jar_name_prefix):
    # Separate 'sbt_jar_name_prefix' and 'mvn_jar_name_prefix' are taken because the jar
    # name prefix can differ between the SBT and Maven builds. See also SPARK-26856.
    project_full_path = os.path.join(
        os.environ["SPARK_HOME"], project_relative_path)

    # Ignore jars that only hold documentation, sources or tests.
    ignored_jar_suffixes = ("javadoc.jar", "sources.jar", "test-sources.jar", "tests.jar")

    # Search for the jar in the project directory using the name prefix for both the SBT
    # build and the Maven build, because the two builds place their artifact jars in
    # different directories.
    sbt_build = glob.glob(os.path.join(
        project_full_path, "target/scala-*/%s*.jar" % sbt_jar_name_prefix))
    maven_build = glob.glob(os.path.join(
        project_full_path, "target/%s*.jar" % mvn_jar_name_prefix))
    jar_paths = sbt_build + maven_build
    jars = [jar for jar in jar_paths if not jar.endswith(ignored_jar_suffixes)]

    if not jars:
        return None
    elif len(jars) > 1:
        raise Exception("Found multiple JARs: %s; please remove all but one" % (", ".join(jars)))
    else:
        return jars[0]
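

# Usage sketch (hypothetical project path and prefixes, for illustration only):
#
#     avro_jar = search_jar("connector/avro", "spark-avro", "spark-avro")
#     if avro_jar is None:
#         raise unittest.SkipTest("spark-avro jar not found; build it first")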