# blob: c1cd361b7f9993e08ba2c37771f129b1d914fe71 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import shutil
import stat
import tempfile
import threading
import time
import unittest
from collections import namedtuple
from pyspark import SparkConf, SparkFiles, SparkContext
from pyspark.testing.sqlutils import SPARK_HOME
from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest
class CheckpointTests(ReusedPySparkTestCase):
    """Tests for RDD checkpointing against a per-test checkpoint directory.

    Each test gets a fresh, uniquely named path that is registered with the
    shared SparkContext via ``setCheckpointDir`` and removed again in
    ``tearDown``.
    """

    def setUp(self):
        """Reserve a unique path and register it as the checkpoint directory."""
        # NamedTemporaryFile is only used to obtain a unique path name; the
        # placeholder file itself is removed so the path can be reused.
        self.checkpointDir = tempfile.NamedTemporaryFile(delete=False)
        # Close the handle before unlinking so the descriptor is not leaked for
        # the lifetime of the test (and so the unlink works on platforms that
        # refuse to delete open files). ``.name`` stays available after close().
        self.checkpointDir.close()
        os.unlink(self.checkpointDir.name)
        self.sc.setCheckpointDir(self.checkpointDir.name)

    def tearDown(self):
        """Remove the checkpoint directory tree created during the test."""
        shutil.rmtree(self.checkpointDir.name)

    def test_basic_checkpointing(self):
        """An RDD reports as checkpointed, with a file under the checkpoint dir."""
        parCollection = self.sc.parallelize([1, 2, 3, 4])
        flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))

        # Before checkpoint() is requested nothing is checkpointed, but the
        # context already knows its checkpoint directory.
        self.assertFalse(flatMappedRDD.isCheckpointed())
        self.assertIsNone(flatMappedRDD.getCheckpointFile())
        self.assertIsNotNone(self.sc.getCheckpointDir())

        flatMappedRDD.checkpoint()
        result = flatMappedRDD.collect()
        time.sleep(1)  # 1 second
        self.assertTrue(flatMappedRDD.isCheckpointed())
        self.assertEqual(flatMappedRDD.collect(), result)
        # The checkpoint file lives two levels below the directory from setUp ...
        self.assertEqual(
            "file:" + self.checkpointDir.name,
            os.path.dirname(os.path.dirname(flatMappedRDD.getCheckpointFile())),
        )
        # ... and one level below the directory the context reports.
        self.assertEqual(
            self.sc.getCheckpointDir(), os.path.dirname(flatMappedRDD.getCheckpointFile())
        )

    def test_checkpoint_and_restore(self):
        """Data read back from the checkpoint file equals the original RDD."""
        parCollection = self.sc.parallelize([1, 2, 3, 4])
        flatMappedRDD = parCollection.flatMap(lambda x: [x])

        self.assertFalse(flatMappedRDD.isCheckpointed())
        self.assertIsNone(flatMappedRDD.getCheckpointFile())

        flatMappedRDD.checkpoint()
        flatMappedRDD.count()  # forces a checkpoint to be computed
        time.sleep(1)  # 1 second

        self.assertIsNotNone(flatMappedRDD.getCheckpointFile())
        recovered = self.sc._checkpointFile(
            flatMappedRDD.getCheckpointFile(), flatMappedRDD._jrdd_deserializer
        )
        self.assertEqual([1, 2, 3, 4], recovered.collect())
class LocalCheckpointTests(ReusedPySparkTestCase):
    """Tests for ``RDD.localCheckpoint()`` on the shared SparkContext."""

    def test_basic_localcheckpointing(self):
        """Both checkpoint flags flip from False to True once the RDD is computed."""
        base_rdd = self.sc.parallelize([1, 2, 3, 4])
        expanded_rdd = base_rdd.flatMap(lambda n: range(1, n + 1))

        # Nothing is checkpointed before localCheckpoint() is requested.
        self.assertFalse(expanded_rdd.isCheckpointed())
        self.assertFalse(expanded_rdd.isLocallyCheckpointed())

        expanded_rdd.localCheckpoint()
        first_pass = expanded_rdd.collect()
        time.sleep(1)  # 1 second

        self.assertTrue(expanded_rdd.isCheckpointed())
        self.assertTrue(expanded_rdd.isLocallyCheckpointed())
        # Collecting again must return the same data as the first pass.
        second_pass = expanded_rdd.collect()
        self.assertEqual(second_pass, first_pass)
class AddFileTests(PySparkTestCase):
    """Tests for ``SparkContext.addFile`` and ``SparkContext.addPyFile``."""

    def test_add_py_file(self):
        """A job importing ``userlibrary`` fails before addPyFile and succeeds after."""

        def greet(_):
            from userlibrary import UserClass

            return UserClass().hello()

        # Sanity check that addPyFile is what makes the job work: without it
        # the job must fail. QuietTest wraps the expected failure
        # (disable logging in log4j temporarily).
        with QuietTest(self.sc):
            failing_job = self.sc.parallelize(range(2)).map(greet)
            with self.assertRaises(Exception):
                failing_job.first()

        # Add the file, so the job should now succeed:
        library_path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
        self.sc.addPyFile(library_path)
        greeting = self.sc.parallelize(range(2)).map(greet).first()
        self.assertEqual("Hello World!", greeting)

    def test_add_file_locally(self):
        """A single added file is readable via SparkFiles at a different path."""
        source_path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
        self.sc.addFile(source_path)

        fetched_path = SparkFiles.get("hello.txt")
        self.assertNotEqual(source_path, fetched_path)

        with open(fetched_path) as handle:
            first_line = handle.readline()
        self.assertEqual("Hello World!\n", first_line)

    def test_add_file_recursively_locally(self):
        """A recursively added directory exposes its files and nested files."""
        source_dir = os.path.join(SPARK_HOME, "python/test_support/hello")
        self.sc.addFile(source_dir, True)

        fetched_dir = SparkFiles.get("hello")
        self.assertNotEqual(source_dir, fetched_dir)

        # (path suffix below the fetched directory, expected first line)
        expectations = (
            ("/hello.txt", "Hello World!\n"),
            ("/sub_hello/sub_hello.txt", "Sub Hello World!\n"),
        )
        for suffix, expected_line in expectations:
            with open(fetched_dir + suffix) as handle:
                self.assertEqual(expected_line, handle.readline())

    def test_add_py_file_locally(self):
        """``userlibrary`` becomes importable in this process after addPyFile."""
        # Sanity check that addPyFile is what makes the import work: before
        # the file is added, importing `userlibrary` must fail.
        with self.assertRaises(ImportError):
            from userlibrary import UserClass  # noqa: F401

        library_path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
        self.sc.addPyFile(library_path)

        from userlibrary import UserClass

        self.assertEqual("Hello World!", UserClass().hello())

    def test_add_egg_file_locally(self):
        """The packaged ``userlib`` becomes importable after adding the zip."""
        # Sanity check that addPyFile is what makes the import work: before
        # the archive is added, importing `userlib` must fail.
        with self.assertRaises(ImportError):
            from userlib import UserClass

            UserClass()

        archive_path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1.zip")
        self.sc.addPyFile(archive_path)

        from userlib import UserClass

        self.assertEqual("Hello World from inside a package!", UserClass().hello())

    def test_overwrite_system_module(self):
        """An added ``SimpleHTTPServer.py`` is the module imported here and in a job."""
        shadow_path = os.path.join(SPARK_HOME, "python/test_support/SimpleHTTPServer.py")
        self.sc.addPyFile(shadow_path)

        import SimpleHTTPServer

        self.assertEqual("My Server", SimpleHTTPServer.__name__)

        def module_name(_):
            import SimpleHTTPServer

            return SimpleHTTPServer.__name__

        names_seen_in_job = self.sc.parallelize(range(1)).map(module_name).collect()
        self.assertEqual(["My Server"], names_seen_in_job)
class ContextTests(unittest.TestCase):
    """SparkContext lifecycle and API tests.

    Every test creates its own SparkContext, so each one must release the
    context even when an assertion fails; a leaked context would stay active
    for the tests that follow.
    """

    def test_failed_sparkcontext_creation(self):
        # Regression test for SPARK-1550
        self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name"))

    def test_get_or_create(self):
        """getOrCreate() returns the already-active context instance."""
        with SparkContext.getOrCreate() as sc:
            self.assertIs(SparkContext.getOrCreate(), sc)

    def test_parallelize_eager_cleanup(self):
        """parallelize() leaves the listing of the context's temp dir unchanged."""
        with SparkContext() as sc:
            temp_files = os.listdir(sc._temp_dir)
            sc.parallelize([0, 1, 2])
            post_parallelize_temp_files = os.listdir(sc._temp_dir)
            self.assertEqual(temp_files, post_parallelize_temp_files)

    def test_set_conf(self):
        # This is for an internal use case. When there is an existing SparkContext,
        # SparkSession's builder needs to set configs into SparkContext's conf.
        # The context manager guarantees the context is stopped even if the
        # assertion below fails.
        with SparkContext() as sc:
            sc._conf.set("spark.test.SPARK16224", "SPARK16224")
            self.assertEqual(sc._jsc.sc().conf().get("spark.test.SPARK16224"), "SPARK16224")

    def test_stop(self):
        """stop() clears the active-context slot."""
        sc = SparkContext()
        try:
            self.assertIsNotNone(SparkContext._active_spark_context)
        finally:
            # Stop unconditionally so a failed assertion cannot leak the context.
            sc.stop()
        self.assertIsNone(SparkContext._active_spark_context)

    def test_with(self):
        """Leaving a ``with SparkContext()`` block clears the active context."""
        with SparkContext():
            self.assertIsNotNone(SparkContext._active_spark_context)
        self.assertIsNone(SparkContext._active_spark_context)

    def test_with_exception(self):
        """The active context is cleared when the ``with`` body raises."""
        # Only the deliberately raised RuntimeError is expected here; catching
        # anything broader would also hide a failure of the inner assertion.
        with self.assertRaises(RuntimeError):
            with SparkContext():
                self.assertIsNotNone(SparkContext._active_spark_context)
                raise RuntimeError()
        self.assertIsNone(SparkContext._active_spark_context)

    def test_with_stop(self):
        """Calling stop() inside the ``with`` block is tolerated on exit."""
        with SparkContext() as sc:
            self.assertIsNotNone(SparkContext._active_spark_context)
            sc.stop()
        self.assertIsNone(SparkContext._active_spark_context)

    def test_progress_api(self):
        """The status tracker reports a running job and its cancellation."""
        with SparkContext() as sc:
            sc.setJobGroup("test_progress_api", "", True)
            # Each task sleeps long enough that the job is still running when
            # the tracker is queried below.
            rdd = sc.parallelize(range(10)).map(lambda x: time.sleep(100))

            def run():
                # When thread is pinned, job group should be set for each thread for now.
                # Local properties seem not being inherited like Scala side does.
                if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true":
                    sc.setJobGroup("test_progress_api", "", True)
                try:
                    rdd.count()
                except Exception:
                    # Deliberately ignored: the job is cancelled by the main
                    # thread below, so count() is expected to fail here.
                    pass

            t = threading.Thread(target=run)
            t.daemon = True
            t.start()
            # wait for scheduler to start
            time.sleep(3)

            tracker = sc.statusTracker()
            jobIds = tracker.getJobIdsForGroup("test_progress_api")
            self.assertEqual(1, len(jobIds))
            job = tracker.getJobInfo(jobIds[0])
            self.assertEqual(1, len(job.stageIds))
            stage = tracker.getStageInfo(job.stageIds[0])
            self.assertEqual(rdd.getNumPartitions(), stage.numTasks)

            sc.cancelAllJobs()
            t.join()
            # wait for event listener to update the status
            time.sleep(1)

            job = tracker.getJobInfo(jobIds[0])
            self.assertEqual("FAILED", job.status)
            self.assertEqual([], tracker.getActiveJobsIds())
            self.assertEqual([], tracker.getActiveStageIds())

            sc.stop()

    def test_startTime(self):
        """startTime is a positive number."""
        with SparkContext() as sc:
            self.assertGreater(sc.startTime, 0)

    def test_forbid_insecure_gateway(self):
        # Fail immediately if you try to create a SparkContext
        # with an insecure gateway
        parameters = namedtuple("MockGatewayParameters", "auth_token")(None)
        mock_insecure_gateway = namedtuple("MockJavaGateway", "gateway_parameters")(parameters)
        with self.assertRaises(ValueError) as context:
            SparkContext(gateway=mock_insecure_gateway)
        self.assertIn("insecure Py4j gateway", str(context.exception))

    def test_resources(self):
        """Test the resources are empty by default."""
        with SparkContext() as sc:
            resources = sc.resources
            self.assertEqual(len(resources), 0)

    def test_disallow_to_create_spark_context_in_executors(self):
        # SPARK-32160: SparkContext should not be created in executors.
        with SparkContext("local-cluster[3, 1, 1024]") as sc:
            with self.assertRaises(Exception) as context:
                sc.range(2).foreach(lambda _: SparkContext())
            self.assertIn(
                "CONTEXT_ONLY_VALID_ON_DRIVER",
                str(context.exception),
            )

    def test_allow_to_create_spark_context_in_executors(self):
        # SPARK-32160: SparkContext can be created in executors if the config is set.
        def create_spark_context():
            conf = SparkConf().set("spark.executor.allowSparkContext", "true")
            with SparkContext(conf=conf):
                pass

        with SparkContext("local-cluster[3, 1, 1024]") as sc:
            sc.range(2).foreach(lambda _: create_spark_context())
class ContextTestsWithResources(unittest.TestCase):
    """Tests ``SparkContext.resources`` with a driver GPU resource configured.

    ``setUp`` writes an executable discovery script that echoes a JSON
    description of one GPU and starts a ``local-cluster`` context configured
    to use it; ``tearDown`` removes the script and stops the context.
    """

    def setUp(self):
        """Create the discovery script and start the SparkContext."""
        class_name = self.__class__.__name__

        # The script must be closed before it can be executed.
        self.tempFile = tempfile.NamedTemporaryFile(delete=False)
        self.tempFile.write(b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\"]}')
        self.tempFile.close()

        # create temporary directory for Worker resources coordination
        # NOTE(review): only the unique path name is reserved here and nothing
        # in this class reads self.tempdir afterwards; confirm it is still needed.
        self.tempdir = tempfile.NamedTemporaryFile(delete=False)
        # Close the handle before unlinking so the descriptor is not leaked.
        self.tempdir.close()
        os.unlink(self.tempdir.name)

        # rwx for the owner, r-x for group and others.
        os.chmod(
            self.tempFile.name,
            stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH | stat.S_IXOTH,
        )

        conf = SparkConf().set("spark.test.home", SPARK_HOME)
        conf = conf.set("spark.driver.resource.gpu.amount", "1")
        conf = conf.set("spark.driver.resource.gpu.discoveryScript", self.tempFile.name)
        try:
            self.sc = SparkContext("local-cluster[2,1,1024]", class_name, conf=conf)
        except Exception:
            # unittest does not call tearDown when setUp fails, so remove the
            # script here instead of leaving it behind.
            os.unlink(self.tempFile.name)
            raise

    def test_resources(self):
        """Test the resources are available."""
        resources = self.sc.resources
        self.assertEqual(len(resources), 1)
        self.assertIn("gpu", resources)
        self.assertEqual(resources["gpu"].name, "gpu")
        self.assertEqual(resources["gpu"].addresses, ["0"])

    def tearDown(self):
        """Remove the discovery script and always stop the SparkContext."""
        try:
            os.unlink(self.tempFile.name)
        finally:
            # Stop the context even if removing the script failed, so it does
            # not stay active for subsequent tests.
            self.sc.stop()
if __name__ == "__main__":
    from pyspark.tests.test_context import *  # noqa: F401

    # Default to unittest's own runner; switch to XML reports only when the
    # optional xmlrunner package can be imported.
    testRunner = None
    try:
        import xmlrunner

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        pass

    unittest.main(testRunner=testRunner, verbosity=2)