| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| import grpc |
| import json |
| from grpc import StatusCode |
| from typing import Any, Dict, List, Optional, TYPE_CHECKING |
| |
| from pyspark.errors.exceptions.base import ( |
| AnalysisException as BaseAnalysisException, |
| IllegalArgumentException as BaseIllegalArgumentException, |
| ArithmeticException as BaseArithmeticException, |
| UnsupportedOperationException as BaseUnsupportedOperationException, |
| ArrayIndexOutOfBoundsException as BaseArrayIndexOutOfBoundsException, |
| DateTimeException as BaseDateTimeException, |
| NumberFormatException as BaseNumberFormatException, |
| ParseException as BaseParseException, |
| PySparkException, |
| PythonException as BasePythonException, |
| StreamingQueryException as BaseStreamingQueryException, |
| QueryExecutionException as BaseQueryExecutionException, |
| SparkRuntimeException as BaseSparkRuntimeException, |
| SparkNoSuchElementException as BaseNoSuchElementException, |
| SparkUpgradeException as BaseSparkUpgradeException, |
| QueryContext as BaseQueryContext, |
| QueryContextType, |
| StreamingPythonRunnerInitializationException as BaseStreamingPythonRunnerInitException, |
| PickleException as BasePickleException, |
| UnknownException as BaseUnknownException, |
| recover_python_exception, |
| ) |
| |
| if TYPE_CHECKING: |
| import pyspark.sql.connect.proto as pb2 |
| from google.rpc.error_details_pb2 import ErrorInfo |
| |
| |
| class SparkConnectException(PySparkException): |
| """ |
| Exception thrown from Spark Connect. |
| """ |
| |
| |
| def convert_exception( |
| info: "ErrorInfo", |
| truncated_message: str, |
| resp: Optional["pb2.FetchErrorDetailsResponse"], |
| display_server_stacktrace: bool = False, |
| grpc_status_code: grpc.StatusCode = StatusCode.UNKNOWN, |
| ) -> SparkConnectException: |
    """
    Convert an ``ErrorInfo`` received over gRPC, plus the optional
    ``FetchErrorDetailsResponse``, into the best-matching
    :class:`SparkConnectException` subclass.
    """
    converted = _convert_exception(
| info, truncated_message, resp, display_server_stacktrace, grpc_status_code |
| ) |
| return recover_python_exception(converted) |
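

# A minimal sketch of driving the converter by hand (the metadata values below
# are illustrative, not a contract of the server):
#
#   from google.rpc.error_details_pb2 import ErrorInfo
#
#   info = ErrorInfo(
#       reason="org.apache.spark.sql.AnalysisException",
#       metadata={
#           "classes": '["org.apache.spark.sql.AnalysisException"]',
#           "sqlState": "42703",
#       },
#   )
#   exc = convert_exception(info, "[UNRESOLVED_COLUMN] ...", resp=None)
#   assert isinstance(exc, AnalysisException)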
| |
| |
| def _convert_exception( |
| info: "ErrorInfo", |
| truncated_message: str, |
| resp: Optional["pb2.FetchErrorDetailsResponse"], |
| display_server_stacktrace: bool = False, |
| grpc_status_code: grpc.StatusCode = StatusCode.UNKNOWN, |
| ) -> SparkConnectException: |
| import pyspark.sql.connect.proto as pb2 |
| |
| raw_classes = info.metadata.get("classes") |
| classes: List[str] = json.loads(raw_classes) if raw_classes else [] |
| sql_state = info.metadata.get("sqlState") |
| error_class = info.metadata.get("errorClass") |
| raw_message_parameters = info.metadata.get("messageParameters") |
| message_parameters: Dict[str, str] = ( |
| json.loads(raw_message_parameters) if raw_message_parameters else {} |
| ) |
| stacktrace: Optional[str] = None |
| |
| if resp is not None and resp.HasField("root_error_idx"): |
| message = resp.errors[resp.root_error_idx].message |
| stacktrace = _extract_jvm_stacktrace(resp) |
| else: |
| message = truncated_message |
| stacktrace = info.metadata.get("stackTrace") |
| display_server_stacktrace = display_server_stacktrace if stacktrace else False |
| |
| contexts = None |
| breaking_change_info = None |
    if resp is not None and resp.HasField("root_error_idx"):
        root_error = resp.errors[resp.root_error_idx]
        # `hasattr` is always True for fields declared in the proto; HasField
        # checks whether the server actually populated spark_throwable.
        if root_error.HasField("spark_throwable"):
| message_parameters = dict(root_error.spark_throwable.message_parameters) |
| contexts = [ |
| SQLQueryContext(c) |
| if c.context_type == pb2.FetchErrorDetailsResponse.QueryContext.SQL |
| else DataFrameQueryContext(c) |
| for c in root_error.spark_throwable.query_contexts |
| ] |
            # Extract breaking change info if present. `hasattr` guards against
            # older generated protos that predate the field, while HasField
            # checks that the server actually set it.
            if hasattr(
                root_error.spark_throwable, "breaking_change_info"
            ) and root_error.spark_throwable.HasField("breaking_change_info"):
| bci = root_error.spark_throwable.breaking_change_info |
| breaking_change_info = { |
| "migration_message": list(bci.migration_message), |
| "needs_audit": bci.needs_audit if bci.HasField("needs_audit") else True, |
| } |
| if bci.HasField("mitigation_config"): |
| breaking_change_info["mitigation_config"] = { |
| "key": bci.mitigation_config.key, |
| "value": bci.mitigation_config.value, |
| } |
| |
| if "org.apache.spark.api.python.PythonException" in classes: |
| return PythonException( |
| message="\n An exception was thrown from the Python worker. " |
| "Please see the stack trace below.\n%s" % message, |
| grpc_status_code=grpc_status_code, |
| ) |
| |
| # Return exception based on class mapping |
    for error_class_name in classes:
        ExceptionClass = EXCEPTION_CLASS_MAPPING.get(error_class_name)
        if ExceptionClass is SparkException:
            # A SparkException may wrap a third-party exception; detect it by
            # scanning the message and stop at the first match.
            for third_party_exception_class in THIRD_PARTY_EXCEPTION_CLASS_MAPPING:
                if third_party_exception_class in message:
                    ExceptionClass = THIRD_PARTY_EXCEPTION_CLASS_MAPPING[
                        third_party_exception_class
                    ]
                    break
| |
| if ExceptionClass: |
| return ExceptionClass( |
| message, |
| errorClass=error_class, |
| messageParameters=message_parameters, |
| sql_state=sql_state, |
| server_stacktrace=stacktrace, |
| display_server_stacktrace=display_server_stacktrace, |
| contexts=contexts, |
| grpc_status_code=grpc_status_code, |
| breaking_change_info=breaking_change_info, |
| ) |
| |
| # Return UnknownException if there is no matched exception class |
| return UnknownException( |
| message, |
| reason=info.reason, |
| messageParameters=message_parameters, |
| errorClass=error_class, |
| sql_state=sql_state, |
| server_stacktrace=stacktrace, |
| display_server_stacktrace=display_server_stacktrace, |
| contexts=contexts, |
| grpc_status_code=grpc_status_code, |
| breaking_change_info=breaking_change_info, |
| ) |
| |
| |
| def _extract_jvm_stacktrace(resp: "pb2.FetchErrorDetailsResponse") -> str: |
    """
    Render the server-side error chain as a JVM-style stack trace string.
    """
    if len(resp.errors[resp.root_error_idx].stack_trace) == 0:
| return "" |
| |
| lines: List[str] = [] |
| |
| def format_stacktrace(error: "pb2.FetchErrorDetailsResponse.Error") -> None: |
| message = f"{error.error_type_hierarchy[0]}: {error.message}" |
        if len(lines) == 0:
            lines.append(message)
| else: |
| lines.append(f"Caused by: {message}") |
| for elem in error.stack_trace: |
| lines.append( |
| f"\tat {elem.declaring_class}.{elem.method_name}" |
| f"({elem.file_name}:{elem.line_number})" |
| ) |
| |
| # If this error has a cause, format that recursively |
| if error.HasField("cause_idx"): |
| format_stacktrace(resp.errors[error.cause_idx]) |
| |
| format_stacktrace(resp.errors[resp.root_error_idx]) |
| |
| return "\n".join(lines) |
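

# The result mimics a JVM stack trace, with each cause chained via
# "Caused by:". An illustrative example (frame details are made up):
#
#   org.apache.spark.SparkException: Job aborted due to stage failure: ...
#           at org.apache.spark.scheduler.DAGScheduler.failJob(DAGScheduler.scala:...)
#   Caused by: java.lang.ArithmeticException: / by zero
#           at org.apache.spark.sql.errors.QueryExecutionErrors$.divideByZeroError(...)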
| |
| |
| class SparkConnectGrpcException(SparkConnectException): |
| """ |
| Base class to handle the errors from GRPC. |
| """ |
| |
| def __init__( |
| self, |
| message: Optional[str] = None, |
| errorClass: Optional[str] = None, |
| messageParameters: Optional[Dict[str, str]] = None, |
| reason: Optional[str] = None, |
| sql_state: Optional[str] = None, |
| server_stacktrace: Optional[str] = None, |
| display_server_stacktrace: bool = False, |
| contexts: Optional[List[BaseQueryContext]] = None, |
| grpc_status_code: grpc.StatusCode = StatusCode.UNKNOWN, |
| breaking_change_info: Optional[Dict[str, Any]] = None, |
| ) -> None: |
| if contexts is None: |
| contexts = [] |
| self._message = message # type: ignore[assignment] |
| if reason is not None: |
| self._message = f"({reason}) {self._message}" |
| |
        # PySparkException assumes that errorClass and messageParameters always
        # occur together. If only one of them is set, we treat the message as
        # already fully parsed.
| tmp_error_class = errorClass |
| tmp_message_parameters = messageParameters |
| if errorClass is not None and messageParameters is None: |
| tmp_error_class = None |
| elif errorClass is None and messageParameters is not None: |
| tmp_message_parameters = None |
| |
| super().__init__( |
| message=self._message, |
| errorClass=tmp_error_class, |
| messageParameters=tmp_message_parameters, |
| ) |
| self._errorClass = errorClass |
| self._sql_state: Optional[str] = sql_state |
| self._stacktrace: Optional[str] = server_stacktrace |
| self._display_stacktrace: bool = display_server_stacktrace |
| self._contexts: List[BaseQueryContext] = contexts |
| self._grpc_status_code = grpc_status_code |
| self._breaking_change_info: Optional[Dict[str, Any]] = breaking_change_info |
| self._log_exception() |
| |
| def getSqlState(self) -> Optional[str]: |
| if self._sql_state is not None: |
| return self._sql_state |
| else: |
| return super().getSqlState() |
| |
| def getStackTrace(self) -> Optional[str]: |
| return self._stacktrace |
| |
| def getMessage(self) -> str: |
| desc = self._message |
| if self._display_stacktrace: |
| desc += "\n\nJVM stacktrace:\n%s" % self._stacktrace |
| return desc |
| |
| def getGrpcStatusCode(self) -> grpc.StatusCode: |
| return self._grpc_status_code |
| |
| def getBreakingChangeInfo(self) -> Optional[Dict[str, Any]]: |
| """ |
| Returns the breaking change info for an error, or None. |
| |
| For Spark Connect exceptions, this returns the breaking change info |
| received from the server, rather than looking it up from local error files. |
| """ |
| return self._breaking_change_info |
| |
    def getQueryContext(self) -> List[BaseQueryContext]:
        return self._contexts

    def __str__(self) -> str:
        return self.getMessage()
| |
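# A hedged usage sketch (the query and the resulting error are illustrative;
# assumes an active Spark Connect session `spark`):
#
#   try:
#       spark.sql("SELECT * FROM nonexistent_table").collect()
#   except AnalysisException as e:
#       print(e.getSqlState())  # SQLSTATE reported by the server
#       print(e.getMessage())   # appends the JVM stacktrace when enabled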
| |
| class UnknownException(SparkConnectGrpcException, BaseUnknownException): |
| """ |
| Exception for unmapped errors in Spark Connect. |
| This class is functionally identical to SparkConnectGrpcException but has a different name |
| for consistency. |
| """ |
| |
| def __init__( |
| self, |
| message: Optional[str] = None, |
| errorClass: Optional[str] = None, |
| messageParameters: Optional[Dict[str, str]] = None, |
| reason: Optional[str] = None, |
| sql_state: Optional[str] = None, |
| server_stacktrace: Optional[str] = None, |
| display_server_stacktrace: bool = False, |
| contexts: Optional[List[BaseQueryContext]] = None, |
| grpc_status_code: grpc.StatusCode = StatusCode.UNKNOWN, |
| breaking_change_info: Optional[Dict[str, Any]] = None, |
| ) -> None: |
| super().__init__( |
| message=message, |
| errorClass=errorClass, |
| messageParameters=messageParameters, |
| reason=reason, |
| sql_state=sql_state, |
| server_stacktrace=server_stacktrace, |
| display_server_stacktrace=display_server_stacktrace, |
| contexts=contexts, |
| grpc_status_code=grpc_status_code, |
| breaking_change_info=breaking_change_info, |
| ) |
| |
| |
| class AnalysisException(SparkConnectGrpcException, BaseAnalysisException): |
| """ |
| Failed to analyze a SQL query plan, thrown from Spark Connect. |
| """ |
| |
| |
| class ParseException(AnalysisException, BaseParseException): |
| """ |
| Failed to parse a SQL command, thrown from Spark Connect. |
| """ |
| |
| |
| class IllegalArgumentException(SparkConnectGrpcException, BaseIllegalArgumentException): |
| """ |
| Passed an illegal or inappropriate argument, thrown from Spark Connect. |
| """ |
| |
| |
| class StreamingQueryException(SparkConnectGrpcException, BaseStreamingQueryException): |
| """ |
    Exception that stopped a :class:`StreamingQuery`, thrown from Spark Connect.
| """ |
| |
| |
| class QueryExecutionException(SparkConnectGrpcException, BaseQueryExecutionException): |
| """ |
| Failed to execute a query, thrown from Spark Connect. |
| """ |
| |
| |
| class PythonException(SparkConnectGrpcException, BasePythonException): |
| """ |
    Exception thrown from a Python worker, surfaced through Spark Connect.
| """ |
| |
| |
| class ArithmeticException(SparkConnectGrpcException, BaseArithmeticException): |
| """ |
| Arithmetic exception thrown from Spark Connect. |
| """ |
| |
| |
| class UnsupportedOperationException(SparkConnectGrpcException, BaseUnsupportedOperationException): |
| """ |
| Unsupported operation exception thrown from Spark Connect. |
| """ |
| |
| |
| class ArrayIndexOutOfBoundsException(SparkConnectGrpcException, BaseArrayIndexOutOfBoundsException): |
| """ |
| Array index out of bounds exception thrown from Spark Connect. |
| """ |
| |
| |
| class DateTimeException(SparkConnectGrpcException, BaseDateTimeException): |
| """ |
| Datetime exception thrown from Spark Connect. |
| """ |
| |
| |
| class NumberFormatException(IllegalArgumentException, BaseNumberFormatException): |
| """ |
| Number format exception thrown from Spark Connect. |
| """ |
| |
| |
| class SparkRuntimeException(SparkConnectGrpcException, BaseSparkRuntimeException): |
| """ |
| Runtime exception thrown from Spark Connect. |
| """ |
| |
| |
| class SparkUpgradeException(SparkConnectGrpcException, BaseSparkUpgradeException): |
| """ |
    Exception caused by a Spark version upgrade, thrown from Spark Connect.
| """ |
| |
| |
class SparkException(SparkConnectGrpcException):
    """
    Generic ``org.apache.spark.SparkException`` thrown from Spark Connect.
    """
| |
| |
| class SparkNoSuchElementException(SparkConnectGrpcException, BaseNoSuchElementException): |
| """ |
    No such element exception, thrown from Spark Connect.
| """ |
| |
| |
| class InvalidPlanInput(SparkConnectGrpcException): |
| """ |
    Error thrown when a Spark Connect plan is not valid.
| """ |
| |
| |
| class StreamingPythonRunnerInitializationException( |
| SparkConnectGrpcException, BaseStreamingPythonRunnerInitException |
| ): |
| """ |
| Failed to initialize a streaming Python runner. |
| """ |
| |
| |
| class PickleException(SparkConnectGrpcException, BasePickleException): |
| """ |
    Represents a pickling failure on the server side, such as
    `net.razorvine.pickle.PickleException`. This is distinct from
    `PySparkPicklingError`, which wraps a failure from Python's built-in
    `pickle.PicklingError`.
| """ |
| |
| |
| # Update EXCEPTION_CLASS_MAPPING here when adding a new exception |
| EXCEPTION_CLASS_MAPPING = { |
| "org.apache.spark.sql.catalyst.parser.ParseException": ParseException, |
| "org.apache.spark.sql.AnalysisException": AnalysisException, |
| "org.apache.spark.sql.streaming.StreamingQueryException": StreamingQueryException, |
| "org.apache.spark.sql.execution.QueryExecutionException": QueryExecutionException, |
| "java.lang.NumberFormatException": NumberFormatException, |
| "java.lang.IllegalArgumentException": IllegalArgumentException, |
| "java.lang.ArithmeticException": ArithmeticException, |
| "java.lang.UnsupportedOperationException": UnsupportedOperationException, |
| "java.lang.ArrayIndexOutOfBoundsException": ArrayIndexOutOfBoundsException, |
| "java.time.DateTimeException": DateTimeException, |
| "org.apache.spark.SparkRuntimeException": SparkRuntimeException, |
| "org.apache.spark.SparkUpgradeException": SparkUpgradeException, |
| "org.apache.spark.api.python.PythonException": PythonException, |
| "org.apache.spark.SparkNoSuchElementException": SparkNoSuchElementException, |
| "org.apache.spark.SparkException": SparkException, |
| "org.apache.spark.sql.connect.common.InvalidPlanInput": InvalidPlanInput, |
| "org.apache.spark.api.python.StreamingPythonRunner" |
| "$StreamingPythonRunnerInitializationException": StreamingPythonRunnerInitializationException, |
| } |
| |
| THIRD_PARTY_EXCEPTION_CLASS_MAPPING = { |
| "net.razorvine.pickle.PickleException": PickleException, |
| } |
| |
| |
| class SQLQueryContext(BaseQueryContext): |
    """
    :class:`QueryContext` backed by a SQL query fragment reported by the
    Spark Connect server.
    """

    def __init__(self, q: "pb2.FetchErrorDetailsResponse.QueryContext"):
| self._q = q |
| |
| def contextType(self) -> QueryContextType: |
| return QueryContextType.SQL |
| |
| def objectType(self) -> str: |
| return str(self._q.object_type) |
| |
| def objectName(self) -> str: |
| return str(self._q.object_name) |
| |
| def startIndex(self) -> int: |
| return int(self._q.start_index) |
| |
| def stopIndex(self) -> int: |
| return int(self._q.stop_index) |
| |
| def fragment(self) -> str: |
| return str(self._q.fragment) |
| |
| def callSite(self) -> str: |
| raise UnsupportedOperationException( |
| "", |
| errorClass="UNSUPPORTED_CALL.WITHOUT_SUGGESTION", |
| messageParameters={"className": "SQLQueryContext", "methodName": "callSite"}, |
| sql_state="0A000", |
| server_stacktrace=None, |
| display_server_stacktrace=False, |
| contexts=[], |
| ) |
| |
| def summary(self) -> str: |
| return str(self._q.summary) |
| |
| |
| class DataFrameQueryContext(BaseQueryContext): |
    """
    :class:`QueryContext` backed by a DataFrame operation's call site, as
    reported by the Spark Connect server.
    """

    def __init__(self, q: "pb2.FetchErrorDetailsResponse.QueryContext"):
| self._q = q |
| |
| def contextType(self) -> QueryContextType: |
| return QueryContextType.DataFrame |
| |
| def objectType(self) -> str: |
| raise UnsupportedOperationException( |
| "", |
| errorClass="UNSUPPORTED_CALL.WITHOUT_SUGGESTION", |
| messageParameters={"className": "DataFrameQueryContext", "methodName": "objectType"}, |
| sql_state="0A000", |
| server_stacktrace=None, |
| display_server_stacktrace=False, |
| contexts=[], |
| ) |
| |
| def objectName(self) -> str: |
| raise UnsupportedOperationException( |
| "", |
| errorClass="UNSUPPORTED_CALL.WITHOUT_SUGGESTION", |
| messageParameters={"className": "DataFrameQueryContext", "methodName": "objectName"}, |
| sql_state="0A000", |
| server_stacktrace=None, |
| display_server_stacktrace=False, |
| contexts=[], |
| ) |
| |
| def startIndex(self) -> int: |
| raise UnsupportedOperationException( |
| "", |
| errorClass="UNSUPPORTED_CALL.WITHOUT_SUGGESTION", |
| messageParameters={"className": "DataFrameQueryContext", "methodName": "startIndex"}, |
| sql_state="0A000", |
| server_stacktrace=None, |
| display_server_stacktrace=False, |
| contexts=[], |
| ) |
| |
| def stopIndex(self) -> int: |
| raise UnsupportedOperationException( |
| "", |
| errorClass="UNSUPPORTED_CALL.WITHOUT_SUGGESTION", |
| messageParameters={"className": "DataFrameQueryContext", "methodName": "stopIndex"}, |
| sql_state="0A000", |
| server_stacktrace=None, |
| display_server_stacktrace=False, |
| contexts=[], |
| ) |
| |
| def fragment(self) -> str: |
| return str(self._q.fragment) |
| |
| def callSite(self) -> str: |
| return str(self._q.call_site) |
| |
| def summary(self) -> str: |
| return str(self._q.summary) |
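

# A sketch of consuming query contexts from a converted exception (assumes `e`
# is a SparkConnectGrpcException raised by a failed query):
#
#   for ctx in e.getQueryContext():
#       if ctx.contextType() == QueryContextType.DataFrame:
#           print(ctx.fragment(), ctx.callSite())
#       else:
#           print(ctx.summary())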