# blob: 92955ad7f98df5c0dc1a53c7ad1e5e92bfb04d85 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import importlib.util
import linecache
import os
import re
import sys
import tempfile
import traceback
import unittest
import pyspark.sql.functions as sf
from pyspark.errors import PythonException
from pyspark.errors.exceptions.base import AnalysisException
from pyspark.errors.exceptions.tblib import Traceback
from pyspark.sql.datasource import DataSource, DataSourceReader
from pyspark.sql.session import SparkSession
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
have_pandas,
have_pyarrow,
pandas_requirement_message,
pyarrow_requirement_message,
)
class TracebackTests(unittest.TestCase):
    """Tests for the Traceback class."""

    def make_traceback(self):
        """
        Raise and catch a ``ValueError`` and capture it in two textual forms.

        Returns a tuple ``(full, frames)`` where ``full`` is the output of
        ``traceback.format_exc()`` and ``frames`` is the concatenated output of
        ``traceback.format_tb()`` for the same exception.
        """
        try:
            raise ValueError("bar")
        except ValueError:
            _, _, tb = sys.exc_info()
            return traceback.format_exc(), "".join(traceback.format_tb(tb))

    def make_traceback_with_temp_file(self, code="def foo(): 1 / 0", filename=""):
        """
        Produce a traceback that passes through a module loaded from a temporary file.

        ``code`` is written to a temporary file whose name ends with ``f"{filename}.py"``.
        The file is loaded as a module named ``foo`` and ``foo.foo()`` is called.
        Any ``Exception`` raised while loading or calling is captured.

        Returns the same ``(full, frames)`` pair as :meth:`make_traceback`.
        Fails the test if nothing is raised.
        """
        with tempfile.NamedTemporaryFile("w", suffix=f"{filename}.py", delete=True) as f:
            f.write(code)
            # Flush so the loader, which reads the file by name, sees the code.
            f.flush()
            spec = importlib.util.spec_from_file_location("foo", f.name)
            foo = importlib.util.module_from_spec(spec)
            try:
                spec.loader.exec_module(foo)
                foo.foo()
            except Exception:
                # Format while still inside the ``with`` block, i.e. before the
                # temporary file is closed.
                _, _, tb = sys.exc_info()
                return traceback.format_exc(), "".join(traceback.format_tb(tb))
            else:
                self.fail("Error not raised")

    def assert_traceback(self, tb: Traceback, expected: str):
        """
        Assert that ``tb`` formats to the same frames as ``expected``.

        ``tb.populate_linecache()`` is called first, then ``tb.as_traceback()`` is
        formatted with ``traceback.format_tb``. Position-marker lines are stripped
        from both sides before comparing.
        """

        def remove_positions(s):
            """
            Drop lines made only of ``^`` / ``~`` markers from a formatted traceback.

            For example, remove the 2nd line in this traceback:

                result = (x / y / z) * (a / b / c)
                          ~~~~~~^~~
            """
            pattern = r"\s*[\^\~]+\s*\n"
            return re.sub(pattern, "\n", s)

        tb.populate_linecache()
        actual = remove_positions("".join(traceback.format_tb(tb.as_traceback())))
        expected = remove_positions(expected)
        self.assertEqual(actual, expected)

    def test_simple(self):
        # Round-trip: parse the format_exc() text and compare the re-formatted frames
        # with the frames of the original traceback.
        s, expected = self.make_traceback()
        self.assert_traceback(Traceback.from_string(s), expected)

    def test_missing_source(self):
        # Same round-trip, but for a frame whose source came from a temporary file.
        s, expected = self.make_traceback_with_temp_file()
        linecache.clearcache()  # remove temp file from cache
        self.assert_traceback(Traceback.from_string(s), expected)

    def test_recursion(self):
        """
        Don't parse [Previous line repeated n times] because it's expensive for large n.
        Since the input string is not necessarily a Python traceback, Traceback should keep runtime
        linear to the input string length to be safe from malicious inputs.
        """

        def foo(depth):
            if depth > 0:
                return foo(depth - 1)
            # Evaluating ``1 / 0`` raises ZeroDivisionError before ``raise`` runs.
            raise 1 / 0

        try:
            foo(100)
        except ZeroDivisionError:
            s = traceback.format_exc()
        else:
            # Without this guard a missing exception would surface as a NameError on
            # ``s`` below instead of a clear test failure.
            self.fail("ZeroDivisionError not raised")

        actual = "".join(traceback.format_tb(Traceback.from_string(s).as_traceback()))
        self.assertIn("[Previous line repeated", s)
        self.assertNotIn("[Previous line repeated", actual)

    @unittest.skipIf(
        os.name != "posix",
        "These file names may be invalid on non-posix systems",
    )
    def test_filename(self):
        # Each value becomes part of the temporary file's name (see
        # make_traceback_with_temp_file), so the file path in the traceback text
        # contains these characters.
        for filename in [
            "",
            " ",
            "\\",
            '"',
            "'",
            '", line 1, in hello',
        ]:
            with self.subTest(filename=filename):
                s, expected = self.make_traceback_with_temp_file(filename=filename)
                linecache.clearcache()
                self.assert_traceback(Traceback.from_string(s), expected)

    @unittest.skipIf(
        os.name != "posix",
        "These file names may be invalid on non-posix systems",
    )
    def test_filename_failure_newline(self):
        # tblib can't handle newline in the filename
        s, expected = self.make_traceback_with_temp_file(filename="\n")
        linecache.clearcache()
        tb = Traceback.from_string(s)
        tb.populate_linecache()
        actual = "".join(traceback.format_tb(tb.as_traceback()))
        # Known limitation: the round-trip is expected NOT to match here.
        self.assertNotEqual(actual, expected)

    def test_syntax_error(self):
        # The temporary module contains invalid Python, so loading it raises.
        # Only check that the text "bad syntax" shows up in the re-formatted traceback.
        bad_syntax = "bad syntax"
        s, _ = self.make_traceback_with_temp_file(bad_syntax)
        tb = Traceback.from_string(s)
        tb.populate_linecache()
        actual = "".join(traceback.format_tb(tb.as_traceback()))
        self.assertIn("bad syntax", actual)
@unittest.skipIf(
    not have_pandas or not have_pyarrow,
    pandas_requirement_message or pyarrow_requirement_message,
)
class BaseTracebackSqlTestsMixin:
    """
    Tests for recovering the original traceback from JVM exceptions.

    This is a mixin: it relies on the concrete test class to provide ``spark``,
    ``sql_conf`` and the usual ``unittest.TestCase`` assertion helpers.
    """

    spark: SparkSession

    @staticmethod
    def raise_exception():
        raise ValueError("bar")

    def assertInOnce(self, needle: str, haystack: str):
        """Assert that a string appears exactly once in another string."""
        count = haystack.count(needle)
        # Report the real count: the check also fails when the needle is absent,
        # so a fixed "more than once" message would be misleading.
        self.assertEqual(
            count,
            1,
            f"{needle} appears {count} times (expected exactly once) in {haystack}",
        )

    def test_udf(self):
        # Run once with and once without the JVM stack trace setting enabled.
        for jvm_stack_trace in [False, True]:
            with self.subTest(jvm_stack_trace=jvm_stack_trace), self.sql_conf(
                {"spark.sql.pyspark.jvmStacktrace.enabled": jvm_stack_trace}
            ):

                @sf.udf()
                def foo():
                    raise ValueError("bar")

                df = self.spark.range(1).select(foo())
                try:
                    df.show()
                except PythonException:
                    _, _, tb = sys.exc_info()
                else:
                    self.fail("PythonException not raised")

                # The formatted traceback must show both the driver-side call and
                # the line inside the UDF that raised, each exactly once.
                s = "".join(traceback.format_tb(tb))
                self.assertInOnce("""df.show()""", s)
                self.assertInOnce("""raise ValueError("bar")""", s)

    def test_datasource_analysis(self):
        # The data source raises from ``schema``; the test expects an AnalysisException.
        class MyDataSource(DataSource):
            def schema(self):
                raise ValueError("bar")

        self.spark.dataSource.register(MyDataSource)
        try:
            self.spark.read.format("MyDataSource").load().show()
        except AnalysisException:
            _, _, tb = sys.exc_info()
        else:
            self.fail("AnalysisException not raised")

        s = "".join(traceback.format_tb(tb))
        self.assertInOnce("""self.spark.read.format("MyDataSource").load().show()""", s)
        self.assertInOnce("""raise ValueError("bar")""", s)

    def test_datasource_execution(self):
        # The data source raises from the reader's ``read``; the test expects a
        # PythonException.
        class MyDataSource(DataSource):
            def schema(self):
                return "x int"

            def reader(self, schema):
                return MyDataSourceReader()

        class MyDataSourceReader(DataSourceReader):
            def read(self, partitions):
                raise ValueError("bar")

        self.spark.dataSource.register(MyDataSource)
        try:
            self.spark.read.format("MyDataSource").load().show()
        except PythonException:
            _, _, tb = sys.exc_info()
        else:
            self.fail("PythonException not raised")

        s = "".join(traceback.format_tb(tb))
        self.assertInOnce("""self.spark.read.format("MyDataSource").load().show()""", s)
        self.assertInOnce("""raise ValueError("bar")""", s)

    def test_udtf_analysis(self):
        # The UDTF raises from its static ``analyze``; the test expects an
        # AnalysisException.
        @sf.udtf()
        class MyUdtf:
            @staticmethod
            def analyze():
                raise ValueError("bar")

            def eval(self):
                pass

        try:
            MyUdtf().show()
        except AnalysisException:
            _, _, tb = sys.exc_info()
        else:
            self.fail("AnalysisException not raised")

        s = "".join(traceback.format_tb(tb))
        self.assertInOnce("""MyUdtf().show()""", s)
        self.assertInOnce("""raise ValueError("bar")""", s)

    def test_udtf_execution(self):
        # The UDTF raises from ``eval``; the test expects a PythonException.
        @sf.udtf(returnType="x int")
        class MyUdtf:
            def eval(self):
                raise ValueError("bar")

        try:
            MyUdtf().show()
        except PythonException:
            _, _, tb = sys.exc_info()
        else:
            self.fail("PythonException not raised")

        s = "".join(traceback.format_tb(tb))
        self.assertInOnce("""MyUdtf().show()""", s)
        self.assertInOnce("""raise ValueError("bar")""", s)
class TracebackSqlClassicTests(BaseTracebackSqlTestsMixin, ReusedSQLTestCase):
    """Run the tests from ``BaseTracebackSqlTestsMixin`` on top of ``ReusedSQLTestCase``."""
if __name__ == "__main__":
    from pyspark.errors.tests.test_traceback import *  # noqa: F401

    # Fall back to unittest's default runner unless xmlrunner can be imported.
    testRunner = None
    try:
        import xmlrunner  # type: ignore

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        # xmlrunner is optional; keep the default runner.
        pass

    unittest.main(testRunner=testRunner, verbosity=2)