#!/usr/bin/env python3
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
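"""
Lint helper that scans PySpark source files and flags `raise` statements that do not use
one of the PySpark-specific error classes from `pyspark.errors` (or another explicitly
allowed exception).
"""
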
import sys

# Make modules under the repository's python/ directory importable; this assumes the
# script is run from the root of the Spark repository.
sys.path.insert(0, "python")

import os


def find_py_files(path, exclude_paths):
"""
Find all .py files in a directory, excluding files in specified subdirectories.
Parameters
----------
path : str
The base directory to search for .py files.
exclude_paths : list of str
A list of subdirectories to exclude from the search.
Returns
-------
list of str
A list of paths to .py files.
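
    Examples
    --------
    A minimal usage sketch; the arguments shown here are illustrative:

    >>> find_py_files("python/pyspark/sql", ["python/pyspark/sql/tests"])  # doctest: +SKIP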
"""
py_files = []
for root, dirs, files in os.walk(path):
if any(exclude_path in root for exclude_path in exclude_paths):
continue
for file in files:
if file.endswith(".py"):
py_files.append(os.path.join(root, file))
    return py_files


def check_errors_in_file(file_path, pyspark_error_list):
"""
Check if a file uses PySpark-specific errors correctly.
Parameters
----------
file_path : str
Path to the file to check.
pyspark_error_list : list of str
List of PySpark-specific error names.
Returns
-------
list of str
A list of strings describing the errors found in the file, with line numbers.
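
    Examples
    --------
    A minimal usage sketch; the path and the allowed-error list are illustrative:

    >>> check_errors_in_file(
    ...     "python/pyspark/sql/dataframe.py", ["PySparkValueError"]
    ... )  # doctest: +SKIP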
"""
errors_found = []
with open(file_path, "r") as file:
for line_num, line in enumerate(file, start=1):
if line.strip().startswith("raise"):
parts = line.split()
# Check for 'raise' statement and ensure the error raised is a capitalized word.
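                # For example, `raise ValueError("...")` with no allowed error name on the
                # line is reported, while `raise PySparkValueError(...)`, a bare `raise`,
                # or a lowercase re-raise such as `raise exc` is not.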
if len(parts) > 1 and parts[1][0].isupper():
if not any(pyspark_error in line for pyspark_error in pyspark_error_list):
errors_found.append(f"{file_path}:{line_num}: {line.strip()}")
    return errors_found


def check_pyspark_custom_errors(target_paths, exclude_paths, pyspark_error_list):
    """
    Check PySpark-specific errors in multiple paths.

    Parameters
    ----------
    target_paths : list of str
        List of paths to check for PySpark-specific errors.
    exclude_paths : list of str
        List of paths to exclude from the check.
    pyspark_error_list : list of str
        List of allowed error names; `raise` statements that reference none of these
        names are reported.

    Returns
    -------
    list of str
        A list of strings describing the errors found, with file paths and line numbers.
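
    Examples
    --------
    A minimal usage sketch mirroring the `__main__` block below; the arguments are
    illustrative:

    >>> check_pyspark_custom_errors(
    ...     ["python/pyspark/sql"], ["python/pyspark/sql/tests"], ["PySparkValueError"]
    ... )  # doctest: +SKIP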
"""
all_errors = []
for path in target_paths:
for py_file in find_py_files(path, exclude_paths):
file_errors = check_errors_in_file(py_file, pyspark_error_list)
all_errors.extend(file_errors)
    return all_errors


if __name__ == "__main__":
# PySpark-specific errors
pyspark_error_list = [
"AnalysisException",
"ArithmeticException",
"ArrayIndexOutOfBoundsException",
"DateTimeException",
"IllegalArgumentException",
"NumberFormatException",
"ParseException",
"PySparkAssertionError",
"PySparkAttributeError",
"PySparkException",
"PySparkImportError",
"PySparkIndexError",
"PySparkKeyError",
"PySparkNotImplementedError",
"PySparkPicklingError",
"PySparkRuntimeError",
"PySparkTypeError",
"PySparkValueError",
"PythonException",
"QueryExecutionException",
"RetriesExceeded",
"SessionNotSameException",
"SparkNoSuchElementException",
"SparkRuntimeException",
"SparkUpgradeException",
"StreamingQueryException",
"TempTableAlreadyExistsException",
"UnknownException",
"UnsupportedOperationException",
]
connect_error_list = [
"AnalysisException",
"ArithmeticException",
"ArrayIndexOutOfBoundsException",
"BaseAnalysisException",
"BaseArithmeticException",
"BaseArrayIndexOutOfBoundsException",
"BaseDateTimeException",
"BaseIllegalArgumentException",
"BaseNoSuchElementException",
"BaseNumberFormatException",
"BaseParseException",
"BasePythonException",
"BaseQueryExecutionException",
"BaseSparkRuntimeException",
"BaseSparkUpgradeException",
"BaseStreamingQueryException",
"BaseUnsupportedOperationException",
"DateTimeException",
"IllegalArgumentException",
"NumberFormatException",
"ParseException",
"PySparkException",
"PythonException",
"QueryExecutionException",
"SparkConnectException",
"SparkConnectGrpcException",
"SparkException",
"SparkNoSuchElementException",
"SparkRuntimeException",
"SparkUpgradeException",
"StreamingQueryException",
"UnsupportedOperationException",
]
internal_error_list = ["RetryException", "StopIteration"]
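    # Spark Connect-specific and internal control-flow exceptions are also accepted,
    # so merge them into the allowed list.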
pyspark_error_list += connect_error_list
pyspark_error_list += internal_error_list
# Target paths and exclude paths
TARGET_PATHS = ["python/pyspark/sql"]
EXCLUDE_PATHS = [
"python/pyspark/sql/tests",
"python/pyspark/sql/connect/resource",
"python/pyspark/sql/connect/proto",
]
# Check errors
    errors_found = check_pyspark_custom_errors(TARGET_PATHS, EXCLUDE_PATHS, pyspark_error_list)
if errors_found:
print("\nPySpark custom errors check found issues in the following files:", file=sys.stderr)
for error in errors_found:
print(error, file=sys.stderr)
print("\nUse existing or create a new custom error from pyspark.errors.", file=sys.stderr)
sys.exit(1)