| # -*- encoding: utf-8 -*- |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| import unittest |
| |
| from pyspark.testing.utils import should_test_connect, connect_requirement_message |
| |
| if should_test_connect: |
| from pyspark.errors.exceptions.connect import ( |
| convert_exception, |
| EXCEPTION_CLASS_MAPPING, |
| SparkConnectGrpcException, |
| PythonException, |
| AnalysisException, |
| ) |
| |
| |
@unittest.skipIf(not should_test_connect, connect_requirement_message)
class ConnectErrorsTest(unittest.TestCase):
    """Tests for turning Spark Connect gRPC error details into PySpark exceptions.

    The tests exercise ``convert_exception`` (known and unknown classes, the
    fallback path, server stack traces and breaking change info) and check
    that ``SparkConnectGrpcException`` and two of its subclasses store and
    return breaking change info.
    """

    @staticmethod
    def _spark_exception_info():
        """Return an ``ErrorInfo`` whose reason and class list name ``SparkException``."""
        from google.rpc.error_details_pb2 import ErrorInfo

        info = ErrorInfo()
        info.reason = "org.apache.spark.SparkException"
        info.metadata["classes"] = '["org.apache.spark.SparkException"]'
        return info

    @staticmethod
    def _response_with_spark_throwable(message, error_class):
        """Return a ``FetchErrorDetailsResponse`` with one root ``SparkException`` error.

        The error carries ``message`` and a ``spark_throwable`` whose
        ``error_class`` is set to ``error_class``.  The ``spark_throwable``
        sub-message is returned alongside the response so a test can attach
        breaking change info to it.
        """
        import pyspark.sql.connect.proto as pb2

        resp = pb2.FetchErrorDetailsResponse()
        resp.root_error_idx = 0

        error = resp.errors.add()
        error.message = message
        error.error_type_hierarchy.append("org.apache.spark.SparkException")

        spark_throwable = error.spark_throwable
        spark_throwable.error_class = error_class
        return resp, spark_throwable

    def test_convert_exception_known_class(self):
        """An ``ErrorInfo`` naming ``AnalysisException`` converts to ``AnalysisException``."""
        from google.rpc.error_details_pb2 import ErrorInfo
        from grpc import StatusCode

        info = {
            "reason": "org.apache.spark.sql.AnalysisException",
            "metadata": {
                "classes": '["org.apache.spark.sql.AnalysisException"]',
                "sqlState": "42000",
                "errorClass": "ANALYSIS.ERROR",
                "messageParameters": '{"param1": "value1"}',
            },
        }
        truncated_message = "Analysis error occurred"
        exception = convert_exception(
            info=ErrorInfo(**info),
            truncated_message=truncated_message,
            resp=None,
            grpc_status_code=StatusCode.INTERNAL,
        )

        self.assertIsInstance(exception, AnalysisException)
        self.assertEqual(exception.getSqlState(), "42000")
        self.assertEqual(exception._errorClass, "ANALYSIS.ERROR")
        self.assertEqual(exception._messageParameters, {"param1": "value1"})
        self.assertEqual(exception.getGrpcStatusCode(), StatusCode.INTERNAL)

    def test_convert_exception_python_exception(self):
        """An ``ErrorInfo`` naming the JVM ``PythonException`` converts to ``PythonException``."""
        from google.rpc.error_details_pb2 import ErrorInfo
        from grpc import StatusCode

        info = {
            "reason": "org.apache.spark.api.python.PythonException",
            "metadata": {
                "classes": '["org.apache.spark.api.python.PythonException"]',
            },
        }
        truncated_message = "Python worker error occurred"
        exception = convert_exception(
            info=ErrorInfo(**info),
            truncated_message=truncated_message,
            resp=None,
            grpc_status_code=StatusCode.INTERNAL,
        )

        self.assertIsInstance(exception, PythonException)
        self.assertIn("An exception was thrown from the Python worker", exception.getMessage())
        self.assertEqual(exception.getGrpcStatusCode(), StatusCode.INTERNAL)

    def test_convert_exception_unknown_class(self):
        """An unrecognised class yields a ``SparkConnectGrpcException`` with a prefixed message."""
        from google.rpc.error_details_pb2 import ErrorInfo
        from grpc import StatusCode

        info = {
            "reason": "org.apache.spark.UnknownException",
            "metadata": {"classes": '["org.apache.spark.UnknownException"]'},
        }
        truncated_message = "Unknown error occurred"
        exception = convert_exception(
            info=ErrorInfo(**info),
            truncated_message=truncated_message,
            resp=None,
            grpc_status_code=StatusCode.INTERNAL,
        )

        self.assertIsInstance(exception, SparkConnectGrpcException)
        self.assertEqual(
            exception.getMessage(), "(org.apache.spark.UnknownException) Unknown error occurred"
        )
        self.assertEqual(exception.getGrpcStatusCode(), StatusCode.INTERNAL)

    def test_exception_class_mapping(self):
        """Every ``EXCEPTION_CLASS_MAPPING`` entry maps a string key to an exception class."""
        for error_class_name, exception_class in EXCEPTION_CLASS_MAPPING.items():
            # subTest reports each offending entry rather than stopping at the first.
            with self.subTest(error_class_name=error_class_name):
                self.assertIsInstance(error_class_name, str)
                # hasattr(obj, "__name__") is also true for functions and modules,
                # so check explicitly that the value is an exception class.
                self.assertTrue(
                    isinstance(exception_class, type) and issubclass(exception_class, Exception),
                    f"{exception_class} in EXCEPTION_CLASS_MAPPING is not a valid class",
                )

    def test_convert_exception_with_stacktrace(self):
        """With ``display_server_stacktrace=True`` the message includes the cause chain."""
        from google.rpc.error_details_pb2 import ErrorInfo
        from pyspark.sql.connect.proto import FetchErrorDetailsResponse as pb2

        # Two errors: the root (index 0) points at its cause (index 1) via cause_idx.
        resp = pb2(
            root_error_idx=0,
            errors=[
                pb2.Error(
                    message="Root error message",
                    error_type_hierarchy=["org.apache.spark.SparkException"],
                    stack_trace=[
                        pb2.StackTraceElement(
                            declaring_class="org.apache.spark.Main",
                            method_name="main",
                            file_name="Main.scala",
                            line_number=42,
                        ),
                    ],
                    cause_idx=1,
                ),
                pb2.Error(
                    message="Cause error message",
                    error_type_hierarchy=["java.lang.RuntimeException"],
                    stack_trace=[
                        pb2.StackTraceElement(
                            declaring_class="org.apache.utils.Helper",
                            method_name="help",
                            file_name="Helper.java",
                            line_number=10,
                        ),
                    ],
                ),
            ],
        )

        info = {
            "reason": "org.apache.spark.SparkException",
            "metadata": {
                "classes": '["org.apache.spark.SparkException"]',
                "sqlState": "42000",
            },
        }
        truncated_message = "Root error message"
        exception = convert_exception(
            info=ErrorInfo(**info),
            truncated_message=truncated_message,
            resp=resp,
            display_server_stacktrace=True,
        )

        self.assertIsInstance(exception, SparkConnectGrpcException)
        self.assertIn("Root error message", exception.getMessage())
        self.assertIn("Caused by", exception.getMessage())

    def test_convert_exception_fallback(self):
        """Without class metadata or a status code the result falls back to ``UNKNOWN``."""
        from google.rpc.error_details_pb2 import ErrorInfo
        from grpc import StatusCode

        info = {
            "reason": "org.apache.spark.UnknownReason",
            "metadata": {},
        }
        truncated_message = "Fallback error occurred"
        # grpc_status_code is deliberately omitted to exercise the default.
        exception = convert_exception(
            info=ErrorInfo(**info), truncated_message=truncated_message, resp=None
        )

        self.assertIsInstance(exception, SparkConnectGrpcException)
        self.assertEqual(
            exception.getMessage(), "(org.apache.spark.UnknownReason) Fallback error occurred"
        )
        self.assertEqual(exception.getGrpcStatusCode(), StatusCode.UNKNOWN)

    def test_convert_exception_with_breaking_change_info(self):
        """Test that breaking change info is correctly extracted from protobuf response."""
        from grpc import StatusCode

        resp, spark_throwable = self._response_with_spark_throwable(
            "Test error with breaking change", "TEST_BREAKING_CHANGE_ERROR"
        )

        # Add breaking change info
        bci = spark_throwable.breaking_change_info
        bci.migration_message.append("Please update your code to use new API")
        bci.migration_message.append("See documentation for details")
        bci.needs_audit = False

        # Add mitigation config
        mitigation_config = bci.mitigation_config
        mitigation_config.key = "spark.sql.legacy.behavior.enabled"
        mitigation_config.value = "true"

        exception = convert_exception(
            info=self._spark_exception_info(),
            truncated_message="Test error",
            resp=resp,
            grpc_status_code=StatusCode.INTERNAL,
        )

        # Verify breaking change info is correctly extracted
        breaking_change_info = exception.getBreakingChangeInfo()
        self.assertIsNotNone(breaking_change_info)
        self.assertEqual(
            breaking_change_info["migration_message"],
            ["Please update your code to use new API", "See documentation for details"],
        )
        self.assertEqual(breaking_change_info["needs_audit"], False)
        self.assertIn("mitigation_config", breaking_change_info)
        self.assertEqual(
            breaking_change_info["mitigation_config"]["key"],
            "spark.sql.legacy.behavior.enabled",
        )
        self.assertEqual(breaking_change_info["mitigation_config"]["value"], "true")

    def test_convert_exception_without_breaking_change_info(self):
        """Test that getBreakingChangeInfo returns None when no breaking change info."""
        from grpc import StatusCode

        # The SparkThrowable only carries an error class, no breaking change info.
        resp, _ = self._response_with_spark_throwable(
            "Test error without breaking change", "REGULAR_ERROR"
        )

        exception = convert_exception(
            info=self._spark_exception_info(),
            truncated_message="Test error",
            resp=resp,
            grpc_status_code=StatusCode.INTERNAL,
        )

        # Verify breaking change info is None
        breaking_change_info = exception.getBreakingChangeInfo()
        self.assertIsNone(breaking_change_info)

    def test_breaking_change_info_storage_in_exception(self):
        """Test SparkConnectGrpcException correctly stores and retrieves breaking change info."""
        from pyspark.errors.exceptions.connect import SparkConnectGrpcException

        breaking_change_info = {
            "migration_message": ["Test migration message"],
            "mitigation_config": {"key": "test.config.key", "value": "test.config.value"},
            "needs_audit": True,
        }

        exception = SparkConnectGrpcException(
            message="Test error", errorClass="TEST_ERROR", breaking_change_info=breaking_change_info
        )

        stored_info = exception.getBreakingChangeInfo()
        self.assertEqual(stored_info, breaking_change_info)

    def test_breaking_change_info_inheritance(self):
        """Test that subclasses of SparkConnectGrpcException
        correctly inherit breaking change info."""
        from pyspark.errors.exceptions.connect import AnalysisException, UnknownException

        breaking_change_info = {
            "migration_message": ["Inheritance test message"],
            "needs_audit": False,
        }

        # Test AnalysisException
        analysis_exception = AnalysisException(
            message="Analysis error with breaking change",
            errorClass="TEST_ANALYSIS_ERROR",
            breaking_change_info=breaking_change_info,
        )

        stored_info = analysis_exception.getBreakingChangeInfo()
        self.assertEqual(stored_info, breaking_change_info)

        # Test UnknownException
        unknown_exception = UnknownException(
            message="Unknown error with breaking change",
            errorClass="TEST_UNKNOWN_ERROR",
            breaking_change_info=breaking_change_info,
        )

        stored_info = unknown_exception.getBreakingChangeInfo()
        self.assertEqual(stored_info, breaking_change_info)

    def test_breaking_change_info_without_mitigation_config(self):
        """Test breaking change info that only has migration messages."""
        from grpc import StatusCode

        resp, spark_throwable = self._response_with_spark_throwable(
            "Test error with breaking change", "TEST_BREAKING_CHANGE_ERROR"
        )

        # Add breaking change info without mitigation config
        bci = spark_throwable.breaking_change_info
        bci.migration_message.append("Migration message only")
        bci.needs_audit = True

        exception = convert_exception(
            info=self._spark_exception_info(),
            truncated_message="Test error",
            resp=resp,
            grpc_status_code=StatusCode.INTERNAL,
        )

        # Verify breaking change info is correctly extracted
        breaking_change_info = exception.getBreakingChangeInfo()
        self.assertIsNotNone(breaking_change_info)
        self.assertEqual(breaking_change_info["migration_message"], ["Migration message only"])
        self.assertEqual(breaking_change_info["needs_audit"], True)
        self.assertNotIn("mitigation_config", breaking_change_info)
| |
| |
if __name__ == "__main__":
    import unittest
    from pyspark.errors.tests.test_errors import *  # noqa: F401

    # Start with no explicit runner; use xmlrunner's XML reports (written under
    # target/test-reports) only when that optional package can be imported.
    testRunner = None
    try:
        import xmlrunner

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        # xmlrunner is unavailable: keep testRunner as None.
        pass
    unittest.main(testRunner=testRunner, verbosity=2)