blob: b977fdf51024880f23508bbd1ed3faf8e1997060 [file]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import re
import tempfile
import unittest
import urllib.request
from pyspark.install import (
get_preferred_mirrors,
install_spark,
DEFAULT_HADOOP,
DEFAULT_HIVE,
UNSUPPORTED_COMBINATIONS,
checked_versions,
checked_package_name,
)
class SparkInstallationTestCase(unittest.TestCase):
    """Tests for the pip-installation helpers in ``pyspark.install``."""

    def get_latest_spark_version(self):
        """Return the newest Spark version advertised on a mirror, or ``None``.

        Tries the mirror pinned via the ``PYSPARK_RELEASE_MIRROR`` environment
        variable if set, otherwise each preferred (non-archive) mirror in turn,
        scraping the directory listing for ``spark-X.Y.Z/`` entries. Returns
        ``None`` when no mirror could be read.
        """
        if "PYSPARK_RELEASE_MIRROR" in os.environ:
            sites = [os.environ["PYSPARK_RELEASE_MIRROR"]]
        else:
            sites = get_preferred_mirrors()
        # Filter out the archive sites; they list every historical release.
        sites = [site for site in sites if "archive.apache.org" not in site]
        for site in sites:
            url = site + "/spark/"
            try:
                with urllib.request.urlopen(url) as response:
                    html = response.read().decode("utf-8")
                versions = re.findall(r"spark-(\d+[\.-]\d+[\.-]\d+)/", html)
                versions = [v.replace("-", ".") for v in versions]
                if not versions:
                    # Listing readable but no release directories; try next mirror.
                    continue
                # Compare numerically: a plain string max() sorts
                # lexicographically and would rank e.g. "3.9.0" above "3.10.0".
                return max(versions, key=lambda v: tuple(map(int, v.split("."))))
            except Exception:
                # Mirror unreachable or listing unparsable; try the next one.
                continue
        return None

    def test_install_spark(self):
        """End-to-end install into a temp dir (downloads a distribution).

        Test only one case. Testing this is expensive because it needs to
        download the Spark distribution. We try to get the latest version, but
        if we can't, we just use a hard-coded version.
        """
        spark_version = self.get_latest_spark_version() or "4.1.1"
        spark_version, hadoop_version, hive_version = checked_versions(spark_version, "3", "2.3")

        with tempfile.TemporaryDirectory(prefix="test_install_spark") as tmp_dir:
            install_spark(
                dest=tmp_dir,
                spark_version=spark_version,
                hadoop_version=hadoop_version,
                hive_version=hive_version,
            )

            # Sanity-check the unpacked distribution layout.
            self.assertTrue(os.path.isdir("%s/jars" % tmp_dir))
            self.assertTrue(os.path.exists("%s/bin/spark-submit" % tmp_dir))
            self.assertTrue(os.path.exists("%s/RELEASE" % tmp_dir))

    def test_package_name(self):
        """checked_package_name builds ``spark-X-bin-hadoopY`` tarball names."""
        self.assertEqual(
            "spark-3.0.0-bin-hadoop3.2", checked_package_name("spark-3.0.0", "hadoop3.2", "hive2.3")
        )
        # Spark < 3.3 uses the minor-versioned "hadoop3.2" suffix...
        spark_version, hadoop_version, hive_version = checked_versions("3.2.0", "3", "2.3")
        self.assertEqual(
            "spark-3.2.0-bin-hadoop3.2",
            checked_package_name(spark_version, hadoop_version, hive_version),
        )
        # ...while Spark >= 3.3 uses the plain "hadoop3" suffix.
        spark_version, hadoop_version, hive_version = checked_versions("3.3.0", "3", "2.3")
        self.assertEqual(
            "spark-3.3.0-bin-hadoop3",
            checked_package_name(spark_version, hadoop_version, hive_version),
        )

    def test_checked_versions(self):
        """checked_versions normalizes inputs and rejects unsupported combos."""
        test_version = "3.0.1"  # Just pick one version to test.

        # Positive test cases: both short forms ("2.4.1", "without", "2.3")
        # and fully-prefixed forms are accepted and normalized.
        self.assertEqual(
            ("spark-2.4.1", "without-hadoop", "hive2.3"),
            checked_versions("2.4.1", "without", "2.3"),
        )
        self.assertEqual(
            ("spark-3.0.1", "without-hadoop", "hive2.3"),
            checked_versions("spark-3.0.1", "without-hadoop", "hive2.3"),
        )
        self.assertEqual(
            ("spark-3.3.0", "hadoop3", "hive2.3"),
            checked_versions("spark-3.3.0", "hadoop3", "hive2.3"),
        )

        # Negative test cases
        for hadoop_version, hive_version in UNSUPPORTED_COMBINATIONS:
            with self.assertRaisesRegex(RuntimeError, "Hive.*should.*Hadoop"):
                checked_versions(
                    spark_version=test_version,
                    hadoop_version=hadoop_version,
                    hive_version=hive_version,
                )

        with self.assertRaisesRegex(RuntimeError, "Spark version should start with 'spark-'"):
            checked_versions(
                spark_version="malformed", hadoop_version=DEFAULT_HADOOP, hive_version=DEFAULT_HIVE
            )

        with self.assertRaisesRegex(RuntimeError, "Spark distribution.*malformed.*"):
            checked_versions(
                spark_version=test_version, hadoop_version="malformed", hive_version=DEFAULT_HIVE
            )

        with self.assertRaisesRegex(RuntimeError, "Spark distribution.*malformed.*"):
            checked_versions(
                spark_version=test_version, hadoop_version=DEFAULT_HADOOP, hive_version="malformed"
            )

        with self.assertRaisesRegex(RuntimeError, "Spark distribution of hive1.2 is not supported"):
            checked_versions(
                spark_version=test_version, hadoop_version="hadoop3", hive_version="hive1.2"
            )
if __name__ == "__main__":
    # Delegate to PySpark's shared test runner when executed as a script.
    from pyspark.testing import main
    main()