| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| import os |
| import tempfile |
| import unittest |
| |
| from pyspark.install import ( |
| install_spark, |
| DEFAULT_HADOOP, |
| DEFAULT_HIVE, |
| UNSUPPORTED_COMBINATIONS, |
| checked_versions, |
| checked_package_name, |
| ) |
| |
| |
| class SparkInstallationTestCase(unittest.TestCase): |
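    """Unit tests for the Spark installation helpers in pyspark.install."""
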
| def test_install_spark(self): |
        # Test only one case. Testing this is expensive because it downloads the
        # full Spark distribution, so make sure the chosen version is still
        # available at https://dlcdn.apache.org/spark/
| spark_version, hadoop_version, hive_version = checked_versions("3.5.6", "3", "2.3") |
| |
| with tempfile.TemporaryDirectory(prefix="test_install_spark") as tmp_dir: |
| install_spark( |
| dest=tmp_dir, |
| spark_version=spark_version, |
| hadoop_version=hadoop_version, |
| hive_version=hive_version, |
| ) |
| |
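            # A successful install unpacks the standard distribution layout
            # (jars/, bin/spark-submit, and the RELEASE file) into the destination.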
| self.assertTrue(os.path.isdir("%s/jars" % tmp_dir)) |
| self.assertTrue(os.path.exists("%s/bin/spark-submit" % tmp_dir)) |
| self.assertTrue(os.path.exists("%s/RELEASE" % tmp_dir)) |
| |
| def test_package_name(self): |
        self.assertEqual(
            "spark-3.0.0-bin-hadoop3.2",
            checked_package_name("spark-3.0.0", "hadoop3.2", "hive2.3"),
        )
| |
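        # checked_versions expands short version strings ("3.2.0", "3", "2.3")
        # into the canonical forms expected by checked_package_name.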
| spark_version, hadoop_version, hive_version = checked_versions("3.2.0", "3", "2.3") |
| self.assertEqual( |
| "spark-3.2.0-bin-hadoop3.2", |
| checked_package_name(spark_version, hadoop_version, hive_version), |
| ) |
| |
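        # Since Spark 3.3, the binary package name carries only the Hadoop major
        # version ("hadoop3" rather than "hadoop3.2").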
| spark_version, hadoop_version, hive_version = checked_versions("3.3.0", "3", "2.3") |
| self.assertEqual( |
| "spark-3.3.0-bin-hadoop3", |
| checked_package_name(spark_version, hadoop_version, hive_version), |
| ) |
| |
| def test_checked_versions(self): |
| test_version = "3.0.1" # Just pick one version to test. |
| |
        # Positive test cases: short forms like "2.4.1", "without", and "2.3"
        # are expanded to their canonical names.
| self.assertEqual( |
| ("spark-2.4.1", "without-hadoop", "hive2.3"), |
| checked_versions("2.4.1", "without", "2.3"), |
| ) |
| |
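        # Fully qualified names pass through unchanged.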
| self.assertEqual( |
| ("spark-3.0.1", "without-hadoop", "hive2.3"), |
| checked_versions("spark-3.0.1", "without-hadoop", "hive2.3"), |
| ) |
| |
| self.assertEqual( |
| ("spark-3.3.0", "hadoop3", "hive2.3"), |
| checked_versions("spark-3.3.0", "hadoop3", "hive2.3"), |
| ) |
| |
| # Negative test cases |
| for hadoop_version, hive_version in UNSUPPORTED_COMBINATIONS: |
| with self.assertRaisesRegex(RuntimeError, "Hive.*should.*Hadoop"): |
| checked_versions( |
| spark_version=test_version, |
| hadoop_version=hadoop_version, |
| hive_version=hive_version, |
| ) |
| |
| with self.assertRaisesRegex(RuntimeError, "Spark version should start with 'spark-'"): |
| checked_versions( |
| spark_version="malformed", hadoop_version=DEFAULT_HADOOP, hive_version=DEFAULT_HIVE |
| ) |
| |
| with self.assertRaisesRegex(RuntimeError, "Spark distribution.*malformed.*"): |
| checked_versions( |
| spark_version=test_version, hadoop_version="malformed", hive_version=DEFAULT_HIVE |
| ) |
| |
| with self.assertRaisesRegex(RuntimeError, "Spark distribution.*malformed.*"): |
| checked_versions( |
| spark_version=test_version, hadoop_version=DEFAULT_HADOOP, hive_version="malformed" |
| ) |
| |
| with self.assertRaisesRegex(RuntimeError, "Spark distribution of hive1.2 is not supported"): |
| checked_versions( |
| spark_version=test_version, hadoop_version="hadoop3", hive_version="hive1.2" |
| ) |
| |
| |
| if __name__ == "__main__": |
| from pyspark.tests.test_install_spark import * # noqa: F401 |
| |
| try: |
| import xmlrunner |
| |
| testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) |
| except ImportError: |
| testRunner = None |
| unittest.main(testRunner=testRunner, verbosity=2) |