| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| import os |
| import sys |
| import tempfile |
| import unittest |
| from io import StringIO |
| |
| from pyspark import SparkConf, SparkContext, BasicProfiler |
| from pyspark.profiler import has_memory_profiler |
| from pyspark.sql import SparkSession |
| from pyspark.sql.functions import udf |
| from pyspark.errors import PythonException, PySparkRuntimeError |
| from pyspark.testing.utils import PySparkTestCase, PySparkErrorTestUtils |
| |
| |
| class ProfilerTests(PySparkTestCase): |
| def setUp(self): |
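|         # PySparkTestCase.tearDown restores sys.path from this snapshot and stops self.sc. |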
| self._old_sys_path = list(sys.path) |
| class_name = self.__class__.__name__ |
| conf = SparkConf().set("spark.python.profile", "true") |
| self.sc = SparkContext("local[4]", class_name, conf=conf) |
| |
| def test_profiler(self): |
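|         # End-to-end check of the BasicProfiler path: run a job, then verify the |
|         # collected stats can be inspected, printed, and dumped to disk. |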
| self.do_computation() |
| |
| profilers = self.sc.profiler_collector.profilers |
| self.assertEqual(1, len(profilers)) |
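|         # Each entry pairs the profiled RDD's id with its profiler instance (the |
|         # trailing flag records whether the profile was already shown or dumped). |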
|         rdd_id, profiler, _ = profilers[0] |
|         stats = profiler.stats() |
|         self.assertIsNotNone(stats) |
|         width, stat_list = stats.get_print_list([]) |
|         func_names = [func_name for fname, n, func_name in stat_list] |
|         self.assertIn("heavy_foo", func_names) |
| |
|         old_stdout = sys.stdout |
|         sys.stdout = io = StringIO() |
|         try: |
|             self.sc.show_profiles() |
|             self.assertIn("heavy_foo", io.getvalue()) |
|         finally: |
|             sys.stdout = old_stdout |
| |
| with tempfile.TemporaryDirectory(prefix="test_profiler") as d: |
| self.sc.dump_profiles(d) |
|             self.assertIn("rdd_%d.pstats" % rdd_id, os.listdir(d)) |
| |
| def test_custom_profiler(self): |
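|         # Swapping profiler_cls on the collector makes subsequent jobs use the |
|         # custom profiler, so show_profiles() should call its show() override. |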
| class TestCustomProfiler(BasicProfiler): |
| def show(self, id): |
| self.result = "Custom formatting" |
| |
| self.sc.profiler_collector.profiler_cls = TestCustomProfiler |
| |
| self.do_computation() |
| |
| profilers = self.sc.profiler_collector.profilers |
| self.assertEqual(1, len(profilers)) |
| _, profiler, _ = profilers[0] |
|         self.assertIsInstance(profiler, TestCustomProfiler) |
| |
| self.sc.show_profiles() |
| self.assertEqual("Custom formatting", profiler.result) |
| |
| def do_computation(self): |
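|         # heavy_foo burns CPU on every element so that it reliably shows up in |
|         # the profiler's collected stats. |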
| def heavy_foo(x): |
| for i in range(1 << 18): |
| x = 1 # noqa: F841 |
| |
| rdd = self.sc.parallelize(range(100)) |
| rdd.foreach(heavy_foo) |
| |
| |
| class ProfilerTests2(unittest.TestCase, PySparkErrorTestUtils): |
| def test_profiler_disabled(self): |
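|         # With profiling disabled, show_profiles and dump_profiles should both |
|         # fail fast with INCORRECT_CONF_FOR_PROFILE rather than silently no-op. |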
| sc = SparkContext( |
| conf=SparkConf() |
| .set("spark.python.profile", "false") |
| .set("spark.python.profile.memory", "false") |
| ) |
| try: |
| with self.assertRaises(PySparkRuntimeError) as pe: |
| sc.show_profiles() |
| self.check_error( |
| exception=pe.exception, |
| errorClass="INCORRECT_CONF_FOR_PROFILE", |
| messageParameters={}, |
| ) |
| |
| with self.assertRaises(PySparkRuntimeError) as pe: |
| sc.dump_profiles("/tmp/abc") |
| self.check_error( |
| exception=pe.exception, |
| errorClass="INCORRECT_CONF_FOR_PROFILE", |
| messageParameters={}, |
| ) |
| finally: |
| sc.stop() |
| |
| def test_profiler_all_enabled(self): |
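|         # spark.python.profile and spark.python.profile.memory are mutually |
|         # exclusive; evaluating a UDF with both enabled should fail with |
|         # CANNOT_SET_TOGETHER. |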
| sc = SparkContext( |
| conf=SparkConf() |
| .set("spark.python.profile", "true") |
| .set("spark.python.profile.memory", "true") |
| ) |
| spark = SparkSession(sparkContext=sc) |
| |
| @udf("int") |
| def plus_one(v): |
| return v + 1 |
| |
| try: |
| with self.assertRaises(PySparkRuntimeError) as pe: |
| spark.range(10).select(plus_one("id")).collect() |
| |
| self.check_error( |
| exception=pe.exception, |
| errorClass="CANNOT_SET_TOGETHER", |
| messageParameters={ |
| "arg_list": "'spark.python.profile' and " |
| "'spark.python.profile.memory' configuration" |
| }, |
| ) |
| finally: |
| sc.stop() |
| |
|     @unittest.skipIf(has_memory_profiler, "Requires memory_profiler to not be installed.") |
| def test_no_memory_profile_installed(self): |
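|         # Memory profiling is requested but memory_profiler is absent, so the |
|         # error surfaces as a PythonException once the UDF actually executes. |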
| sc = SparkContext( |
| conf=SparkConf() |
| .set("spark.python.profile", "false") |
| .set("spark.python.profile.memory", "true") |
| ) |
| spark = SparkSession(sparkContext=sc) |
| |
| @udf("int") |
| def plus_one(v): |
| return v + 1 |
| |
| try: |
| self.assertRaisesRegex( |
| PythonException, |
| "Install the 'memory_profiler' library in the cluster to enable memory " |
| "profiling", |
| lambda: spark.range(10).select(plus_one("id")).collect(), |
| ) |
| finally: |
| sc.stop() |
| |
| |
| if __name__ == "__main__": |
| from pyspark.tests.test_profiler import * # noqa: F401 |
| |
| try: |
| import xmlrunner |
| |
| testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) |
| except ImportError: |
| testRunner = None |
| unittest.main(testRunner=testRunner, verbosity=2) |