# -*- coding: utf-8 -*-
| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| import os |
| import signal |
| import sys |
| import tempfile |
| import threading |
| import time |
| import unittest |
| |
| has_resource_module = True |
| try: |
| import resource |
| except ImportError: |
| has_resource_module = False |
| |
| from py4j.protocol import Py4JJavaError |
| |
| from pyspark import SparkConf, SparkContext |
| from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest, eventually |
| |
| |
| class WorkerTests(ReusedPySparkTestCase): |
| def test_cancel_task(self): |
        # NamedTemporaryFile(delete=True) removes the file when closed; only
        # its unique name is kept here, to be recreated by the task below.
        temp = tempfile.NamedTemporaryFile(delete=True)
        temp.close()
        path = temp.name
| |
        def sleep(x):
            import os
            import time

            # Runs inside the Python worker: record the daemon PID (the
            # worker's parent process when the daemon is enabled, as it is by
            # default) and the worker's own PID, then sleep long enough for
            # the job to be cancelled mid-task.
            with open(path, "w") as f:
                f.write("%d %d" % (os.getppid(), os.getpid()))
            time.sleep(100)
| |
| # start job in background thread |
| def run(): |
| try: |
| self.sc.parallelize(range(1), 1).foreach(sleep) |
            except Exception:
                # Expected: the job fails once it is cancelled.
                pass
| |
| t = threading.Thread(target=run) |
| t.daemon = True |
| t.start() |
| |
        # Poll until the background task has written both PIDs.
        daemon_pid, worker_pid = 0, 0
        cnt = 0
        while True:
            if os.path.exists(path):
                with open(path) as f:
                    data = f.read().split(" ")
                try:
                    daemon_pid, worker_pid = map(int, data)
                except ValueError:
                    # The file exists but the PIDs are not fully written yet;
                    # give up after ten partial reads.
                    cnt += 1
                    if cnt == 10:
                        raise
                else:
                    # Both PIDs were parsed successfully.
                    break
            time.sleep(1)
| |
| # cancel jobs |
| self.sc.cancelAllJobs() |
| t.join() |
| |
        # os.kill(pid, 0) delivers no signal; it only probes whether the
        # process still exists, raising OSError once it has exited.
        for _ in range(50):
            try:
                os.kill(worker_pid, 0)
                time.sleep(0.1)
            except OSError:
                break  # worker was killed
        else:
            self.fail("worker has not been killed after 5 seconds")
| |
| try: |
| os.kill(daemon_pid, 0) |
| except OSError: |
| self.fail("daemon had been killed") |
| |
| # run a normal job |
| rdd = self.sc.parallelize(range(100), 1) |
| self.assertEqual(100, rdd.map(str).count()) |
| |
| def test_after_exception(self): |
| def raise_exception(_): |
| raise RuntimeError() |
| |
| rdd = self.sc.parallelize(range(100), 1) |
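        # The worker should survive a Python exception raised by a task and
        # serve the follow-up job below.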
| with QuietTest(self.sc): |
| self.assertRaises(Py4JJavaError, lambda: rdd.foreach(raise_exception)) |
| self.assertEqual(100, rdd.map(str).count()) |
| |
| def test_after_non_exception_error(self): |
        # SPARK-33339: PySpark should not hang when a task raises a
        # non-Exception error such as SystemExit.
| def raise_system_exit(_): |
| raise SystemExit() |
| |
| rdd = self.sc.parallelize(range(100), 1) |
| with QuietTest(self.sc): |
| self.assertRaises(Exception, lambda: rdd.foreach(raise_system_exit)) |
| self.assertEqual(100, rdd.map(str).count()) |
| |
| def test_after_jvm_exception(self): |
| tempFile = tempfile.NamedTemporaryFile(delete=False) |
| tempFile.write(b"Hello World!") |
| tempFile.close() |
| data = self.sc.textFile(tempFile.name, 1) |
| filtered_data = data.filter(lambda x: True) |
| self.assertEqual(1, filtered_data.count()) |
| os.unlink(tempFile.name) |
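        # The input file is gone, so the next action fails while reading it in
        # the JVM; the SparkContext itself should remain usable afterwards.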
| with QuietTest(self.sc): |
| self.assertRaises(Exception, lambda: filtered_data.count()) |
| |
| rdd = self.sc.parallelize(range(100), 1) |
| self.assertEqual(100, rdd.map(str).count()) |
| |
| def test_accumulator_when_reuse_worker(self): |
| from pyspark.accumulators import INT_ACCUMULATOR_PARAM |
| |
| acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) |
| self.sc.parallelize(range(100), 20).foreach(lambda x: acc1.add(x)) |
| self.assertEqual(sum(range(100)), acc1.value) |
| |
| acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) |
| self.sc.parallelize(range(100), 20).foreach(lambda x: acc2.add(x)) |
| self.assertEqual(sum(range(100)), acc2.value) |
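        # acc1 must be unchanged even though the reused workers ran a second job.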
| self.assertEqual(sum(range(100)), acc1.value) |
| |
| def test_reuse_worker_after_take(self): |
| rdd = self.sc.parallelize(range(100000), 1) |
| self.assertEqual(0, rdd.first()) |
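        # first() reads only part of the partition; the worker that served it
        # should still be reusable for the full count() below.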
| |
| def count(): |
| try: |
| rdd.count() |
| except Exception: |
| pass |
| |
| t = threading.Thread(target=count) |
| t.daemon = True |
| t.start() |
| t.join(5) |
        self.assertFalse(t.is_alive())
| self.assertEqual(100000, rdd.count()) |
| |
| def test_with_different_versions_of_python(self): |
| rdd = self.sc.parallelize(range(10)) |
| rdd.count() |
| version = self.sc.pythonVer |
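        # Forge an impossible worker Python version so that the driver/worker
        # version check fails fast instead of hanging.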
| self.sc.pythonVer = "2.0" |
| try: |
| with QuietTest(self.sc): |
| self.assertRaises(Py4JJavaError, lambda: rdd.count()) |
| finally: |
| self.sc.pythonVer = version |
| |
| def test_python_exception_non_hanging(self): |
        # SPARK-21045: exceptions with non-ASCII characters in the message
        # must not hang PySpark.
| try: |
| |
| def f(): |
| raise RuntimeError("exception with 中 and \xd6\xd0") |
| |
| self.sc.parallelize([1]).map(lambda x: f()).count() |
| except Py4JJavaError as e: |
| self.assertRegex(str(e), "exception with 中") |
| |
| |
| class WorkerReuseTest(PySparkTestCase): |
| @eventually(catch_assertions=True) |
| def test_reuse_worker_of_parallelize_range(self): |
| rdd = self.sc.parallelize(range(20), 8) |
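        # Two consecutive jobs over the same RDD should be served by the same
        # pool of reused workers; @eventually retries the assertions because
        # worker reuse can take a moment to settle.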
| previous_pids = rdd.map(lambda x: os.getpid()).collect() |
| current_pids = rdd.map(lambda x: os.getpid()).collect() |
| for pid in current_pids: |
            self.assertIn(pid, previous_pids)
| |
| |
| @unittest.skipIf( |
| not has_resource_module or sys.platform != "linux", |
| "Memory limit feature in Python worker is dependent on " |
| "Python's 'resource' module on Linux; however, not found or not on Linux.", |
| ) |
| class WorkerMemoryTest(unittest.TestCase): |
| def setUp(self): |
| class_name = self.__class__.__name__ |
| conf = SparkConf().set("spark.executor.pyspark.memory", "2g") |
| self.sc = SparkContext("local[4]", class_name, conf=conf) |
| |
| def test_memory_limit(self): |
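        # 'spark.executor.pyspark.memory' (2g in setUp) is expected to be
        # applied in the worker as its address-space limit, so RLIMIT_AS
        # should report 2 GiB for both the soft and the hard limit.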
| rdd = self.sc.parallelize(range(1), 1) |
| |
| def getrlimit(): |
| return resource.getrlimit(resource.RLIMIT_AS) |
| |
| actual = rdd.map(lambda _: getrlimit()).collect() |
        self.assertEqual(1, len(actual))
        self.assertEqual(2, len(actual[0]))
| [(soft_limit, hard_limit)] = actual |
| self.assertEqual(soft_limit, 2 * 1024 * 1024 * 1024) |
| self.assertEqual(hard_limit, 2 * 1024 * 1024 * 1024) |
| |
| def tearDown(self): |
| self.sc.stop() |
| |
| |
| class WorkerSegfaultTest(ReusedPySparkTestCase): |
| @classmethod |
| def conf(cls): |
        _conf = super().conf()
| _conf.set("spark.python.worker.faulthandler.enabled", "true") |
| return _conf |
| |
    @unittest.skipIf(sys.version_info >= (3, 12), "SPARK-46130: Flaky with Python 3.12")
| def test_python_segfault(self): |
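        # With the worker faulthandler enabled (see conf() above), a crash in
        # the worker should surface a fault traceback through the JVM error
        # instead of hanging the job.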
| try: |
| |
| def f(): |
| import ctypes |
| |
                # Dereference a NULL pointer to crash the worker process.
                ctypes.string_at(0)
| |
| self.sc.parallelize([1]).map(lambda x: f()).count() |
| except Py4JJavaError as e: |
| self.assertRegex(str(e), "Segmentation fault") |
| |
| |
| @unittest.skipIf( |
| "COVERAGE_PROCESS_START" in os.environ, |
| "Flaky with coverage enabled, skipping for now.", |
| ) |
| class WorkerSegfaultNonDaemonTest(WorkerSegfaultTest): |
| @classmethod |
| def conf(cls): |
        _conf = super().conf()
| _conf.set("spark.python.use.daemon", "false") |
| return _conf |
| |
| |
| class WorkerPoolCrashTest(PySparkTestCase): |
| def test_worker_crash(self): |
| # SPARK-47565: Kill a worker that is currently idling |
| rdd = self.sc.parallelize(range(20), 4) |
| # first ensure that workers are reused |
| worker_pids1 = set(rdd.map(lambda x: os.getpid()).collect()) |
| worker_pids2 = set(rdd.map(lambda x: os.getpid()).collect()) |
| self.assertEqual(worker_pids1, worker_pids2) |
| for pid in list(worker_pids1)[1:]: # kill all workers except for one |
| os.kill(pid, signal.SIGTERM) |
| # give things a moment to settle |
| time.sleep(5) |
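        # The daemon should replace the killed idle workers, so this final job
        # must still succeed.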
| rdd.map(lambda x: os.getpid()).collect() |
| |
| |
| if __name__ == "__main__": |
| from pyspark.tests.test_worker import * # noqa: F401 |
| |
| try: |
| import xmlrunner |
| |
| testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) |
| except ImportError: |
| testRunner = None |
| unittest.main(testRunner=testRunner, verbosity=2) |