refactor: sql lab: handling command exceptions (#16852)
* chore: support error_type in SupersetException and add a to_dict() method to convert the exception to a dictionary
* refactor: handle command exceptions; fix updating the query status when the query was not created
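
For reference, a minimal sketch of how the new pieces compose (illustrative only, not part of the diff; `execution_context` stands in for a SqlJsonExecutionContext built from the request JSON, and the error message and values are made up):

    # Sketch only: how a failed command turns into a serializable error payload.
    from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
    from superset.exceptions import SupersetErrorException
    from superset.sqllab.exceptions import SqlLabException

    error = SupersetError(
        message="Only SELECT statements are allowed.",  # illustrative message
        error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR,
        level=ErrorLevel.ERROR,
    )
    wrapped = SupersetErrorException(error, status=403)

    # ExecuteSqlCommand.run() wraps unexpected failures the same way before
    # re-raising, so the view only needs to catch SqlLabException.
    ex = SqlLabException(execution_context, exception=wrapped)

    # error_type is picked up from the wrapped exception when not passed
    # explicitly, and to_dict() merges the wrapped exception's dict over the
    # outer message/error_type, giving the view a ready-to-serialize body:
    payload = {"errors": [ex.to_dict()]}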
diff --git a/superset/errors.py b/superset/errors.py
index abfa653..f186819 100644
--- a/superset/errors.py
+++ b/superset/errors.py
@@ -218,3 +218,9 @@
]
}
)
+
+ def to_dict(self) -> Dict[str, Any]:
+ rv = {"message": self.message, "error_type": self.error_type}
+ if self.extra:
+ rv["extra"] = self.extra # type: ignore
+ return rv
diff --git a/superset/exceptions.py b/superset/exceptions.py
index f14506c..b5ae747 100644
--- a/superset/exceptions.py
+++ b/superset/exceptions.py
@@ -28,17 +28,35 @@
message = ""
def __init__(
- self, message: str = "", exception: Optional[Exception] = None,
+ self,
+ message: str = "",
+ exception: Optional[Exception] = None,
+ error_type: Optional[SupersetErrorType] = None,
) -> None:
if message:
self.message = message
self._exception = exception
+ self._error_type = error_type
super().__init__(self.message)
@property
def exception(self) -> Optional[Exception]:
return self._exception
+ @property
+ def error_type(self) -> Optional[SupersetErrorType]:
+ return self._error_type
+
+ def to_dict(self) -> Dict[str, Any]:
+ rv = {}
+ if hasattr(self, "message"):
+ rv["message"] = self.message
+ if self.error_type:
+ rv["error_type"] = self.error_type
+ if self.exception is not None and hasattr(self.exception, "to_dict"):
+ rv = {**rv, **self.exception.to_dict()} # type: ignore
+ return rv
+
class SupersetErrorException(SupersetException):
"""Exceptions with a single SupersetErrorType associated with them"""
@@ -49,6 +67,9 @@
if status is not None:
self.status = status
+ def to_dict(self) -> Dict[str, Any]:
+ return self.error.to_dict()
+
class SupersetGenericErrorException(SupersetErrorException):
"""Exceptions that are too generic to have their own type"""
diff --git a/superset/sqllab/command.py b/superset/sqllab/command.py
index ea4fb45..c9b9df4 100644
--- a/superset/sqllab/command.py
+++ b/superset/sqllab/command.py
@@ -47,6 +47,7 @@
from superset.models.sql_lab import Query
from superset.queries.dao import QueryDAO
from superset.sqllab.command_status import SqlJsonExecutionStatus
+from superset.sqllab.exceptions import SqlLabException
from superset.sqllab.limiting_factor import LimitingFactor
from superset.sqllab.utils import apply_display_max_row_configuration_if_require
from superset.utils import core as utils
@@ -68,18 +69,18 @@
class ExecuteSqlCommand(BaseCommand):
- execution_context: SqlJsonExecutionContext
- log_params: Optional[Dict[str, Any]] = None
- session: Session
+ _execution_context: SqlJsonExecutionContext
+ _log_params: Optional[Dict[str, Any]] = None
+ _session: Session
def __init__(
self,
execution_context: SqlJsonExecutionContext,
log_params: Optional[Dict[str, Any]] = None,
) -> None:
- self.execution_context = execution_context
- self.log_params = log_params
- self.session = db.session()
+ self._execution_context = execution_context
+ self._log_params = log_params
+ self._session = db.session()
def validate(self) -> None:
pass
@@ -88,30 +89,29 @@
self,
) -> CommandResult:
"""Runs arbitrary sql and returns data as json"""
+ try:
+ query = self._get_existing_query()
+ if self.is_query_handled(query):
+ self._execution_context.set_query(query) # type: ignore
+ status = SqlJsonExecutionStatus.QUERY_ALREADY_CREATED
+ else:
+ status = self._run_sql_json_exec_from_scratch()
+ return {
+ "status": status,
+ "payload": self._create_payload_from_execution_context(status),
+ }
+ except (SqlLabException, SupersetErrorsException) as ex:
+ raise ex
+ except Exception as ex:
+ raise SqlLabException(self._execution_context, exception=ex) from ex
- query = self._get_existing_query(self.execution_context, self.session)
-
- if self.is_query_handled(query):
- self.execution_context.set_query(query) # type: ignore
- status = SqlJsonExecutionStatus.QUERY_ALREADY_CREATED
- else:
- status = self._run_sql_json_exec_from_scratch()
-
- return {
- "status": status,
- "payload": self._create_payload_from_execution_context(status),
- }
-
- @classmethod
- def _get_existing_query(
- cls, execution_context: SqlJsonExecutionContext, session: Session
- ) -> Optional[Query]:
+ def _get_existing_query(self) -> Optional[Query]:
query = (
- session.query(Query)
+ self._session.query(Query)
.filter_by(
- client_id=execution_context.client_id,
- user_id=execution_context.user_id,
- sql_editor_id=execution_context.sql_editor_id,
+ client_id=self._execution_context.client_id,
+ user_id=self._execution_context.user_id,
+ sql_editor_id=self._execution_context.sql_editor_id,
)
.one_or_none()
)
@@ -126,25 +126,24 @@
]
def _run_sql_json_exec_from_scratch(self) -> SqlJsonExecutionStatus:
- self.execution_context.set_database(self._get_the_query_db())
- query = self.execution_context.create_query()
+ self._execution_context.set_database(self._get_the_query_db())
+ query = self._execution_context.create_query()
+ self._save_new_query(query)
try:
- self._save_new_query(query)
logger.info("Triggering query_id: %i", query.id)
self._validate_access(query)
- self.execution_context.set_query(query)
+ self._execution_context.set_query(query)
rendered_query = self._render_query()
-
self._set_query_limit_if_required(rendered_query)
-
return self._execute_query(rendered_query)
except Exception as ex:
query.status = QueryStatus.FAILED
- self.session.commit()
+ self._session.commit()
raise ex
def _get_the_query_db(self) -> Database:
- mydb = self.session.query(Database).get(self.execution_context.database_id)
+ mydb = self._session.query(Database).get(self._execution_context.database_id)
self._validate_query_db(mydb)
return mydb
@@ -160,12 +159,12 @@
def _save_new_query(self, query: Query) -> None:
try:
- self.session.add(query)
- self.session.flush()
- self.session.commit() # shouldn't be necessary
+ self._session.add(query)
+ self._session.flush()
+ self._session.commit() # shouldn't be necessary
except SQLAlchemyError as ex:
logger.error("Errors saving query details %s", str(ex), exc_info=True)
- self.session.rollback()
+ self._session.rollback()
if not query.id:
raise SupersetGenericErrorException(
__(
@@ -181,7 +180,7 @@
query.set_extra_json_key("errors", [dataclasses.asdict(ex.error)])
query.status = QueryStatus.FAILED
query.error_message = ex.error.message
- self.session.commit()
+ self._session.commit()
raise SupersetErrorException(ex.error, status=403) from ex
def _render_query(self) -> str:
@@ -205,18 +204,18 @@
error=SupersetErrorType.MISSING_TEMPLATE_PARAMS_ERROR,
extra={
"undefined_parameters": list(undefined_parameters),
- "template_parameters": self.execution_context.template_params,
+ "template_parameters": self._execution_context.template_params,
},
)
- query = self.execution_context.query
+ query = self._execution_context.query
try:
template_processor = get_template_processor(
database=query.database, query=query
)
rendered_query = template_processor.process_template(
- query.sql, **self.execution_context.template_params
+ query.sql, **self._execution_context.template_params
)
validate(rendered_query, template_processor)
except TemplateError as ex:
@@ -235,24 +234,24 @@
def _is_required_to_set_limit(self) -> bool:
return not (
- config.get("SQLLAB_CTAS_NO_LIMIT") and self.execution_context.select_as_cta
+ config.get("SQLLAB_CTAS_NO_LIMIT") and self._execution_context.select_as_cta
)
def _set_query_limit(self, rendered_query: str) -> None:
- db_engine_spec = self.execution_context.database.db_engine_spec # type: ignore
+ db_engine_spec = self._execution_context.database.db_engine_spec # type: ignore
limits = [
db_engine_spec.get_limit_from_sql(rendered_query),
- self.execution_context.limit,
+ self._execution_context.limit,
]
if limits[0] is None or limits[0] > limits[1]: # type: ignore
- self.execution_context.query.limiting_factor = LimitingFactor.DROPDOWN
+ self._execution_context.query.limiting_factor = LimitingFactor.DROPDOWN
elif limits[1] > limits[0]: # type: ignore
- self.execution_context.query.limiting_factor = LimitingFactor.QUERY
+ self._execution_context.query.limiting_factor = LimitingFactor.QUERY
else: # limits[0] == limits[1]
- self.execution_context.query.limiting_factor = (
+ self._execution_context.query.limiting_factor = (
LimitingFactor.QUERY_AND_DROPDOWN
)
- self.execution_context.query.limit = min(
+ self._execution_context.query.limit = min(
lim for lim in limits if lim is not None
)
@@ -260,7 +259,7 @@
# Flag for whether or not to expand data
# (feature that will expand Presto row objects and arrays)
# Async request.
- if self.execution_context.is_run_asynchronous():
+ if self._execution_context.is_run_asynchronous():
return self._sql_json_async(rendered_query)
return self._sql_json_sync(rendered_query)
@@ -271,7 +270,7 @@
:param rendered_query: the rendered query to perform by workers
:return: A Flask Response
"""
- query = self.execution_context.query
+ query = self._execution_context.query
logger.info("Query %i: Running query on a Celery worker", query.id)
# Ignore the celery future object and the request may time out.
query_id = query.id
@@ -285,8 +284,8 @@
if g.user and hasattr(g.user, "username")
else None,
start_time=now_as_float(),
- expand_data=self.execution_context.expand_data,
- log_params=self.log_params,
+ expand_data=self._execution_context.expand_data,
+ log_params=self._log_params,
)
# Explicitly forget the task to ensure the task metadata is removed from the
@@ -312,14 +311,14 @@
query.set_extra_json_key("errors", [error_payload])
query.status = QueryStatus.FAILED
query.error_message = message
- self.session.commit()
+ self._session.commit()
raise SupersetErrorException(error) from ex
# Update saved query with execution info from the query execution
QueryDAO.update_saved_query_exec_info(query_id)
- self.session.commit()
+ self._session.commit()
return SqlJsonExecutionStatus.QUERY_IS_RUNNING
def _sql_json_sync(self, rendered_query: str) -> SqlJsonExecutionStatus:
@@ -329,7 +328,7 @@
:param rendered_query: The rendered query (included templates)
:raises: SupersetTimeoutException
"""
- query = self.execution_context.query
+ query = self._execution_context.query
try:
timeout = config["SQLLAB_TIMEOUT"]
timeout_msg = f"The query exceeded the {timeout} seconds timeout."
@@ -339,7 +338,7 @@
)
# Update saved query if needed
QueryDAO.update_saved_query_exec_info(query_id)
- self.execution_context.set_execution_result(data)
+ self._execution_context.set_execution_result(data)
except SupersetTimeoutException as ex:
# re-raise exception for api exception handler
raise ex
@@ -362,7 +361,7 @@
def _get_sql_results_with_timeout(
self, timeout: int, rendered_query: str, timeout_msg: str,
) -> Optional[SqlResults]:
- query = self.execution_context.query
+ query = self._execution_context.query
with utils.timeout(seconds=timeout, error_message=timeout_msg):
# pylint: disable=no-value-for-parameter
return sql_lab.get_sql_results(
@@ -373,8 +372,8 @@
user_name=g.user.username
if g.user and hasattr(g.user, "username")
else None,
- expand_data=self.execution_context.expand_data,
- log_params=self.log_params,
+ expand_data=self._execution_context.expand_data,
+ log_params=self._log_params,
)
@classmethod
@@ -389,9 +388,9 @@
if status == SqlJsonExecutionStatus.HAS_RESULTS:
return self._to_payload_results_based(
- self.execution_context.get_execution_result() or {}
+ self._execution_context.get_execution_result() or {}
)
- return self._to_payload_query_based(self.execution_context.query)
+ return self._to_payload_query_based(self._execution_context.query)
def _to_payload_results_based( # pylint: disable=no-self-use
self, execution_result: SqlResults
diff --git a/superset/sqllab/exceptions.py b/superset/sqllab/exceptions.py
new file mode 100644
index 0000000..6b1736c
--- /dev/null
+++ b/superset/sqllab/exceptions.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import os
+from typing import Optional, TYPE_CHECKING
+
+from superset.errors import SupersetError, SupersetErrorType
+from superset.exceptions import SupersetException
+
+MSG_FORMAT = "Failed to execute {}"
+
+if TYPE_CHECKING:
+ from superset.utils.sqllab_execution_context import SqlJsonExecutionContext
+
+
+class SqlLabException(SupersetException):
+ sql_json_execution_context: SqlJsonExecutionContext
+ failed_reason_msg: str
+ suggestion_help_msg: Optional[str]
+
+ def __init__( # pylint: disable=too-many-arguments
+ self,
+ sql_json_execution_context: SqlJsonExecutionContext,
+ error_type: Optional[SupersetErrorType] = None,
+ reason_message: Optional[str] = None,
+ exception: Optional[Exception] = None,
+ suggestion_help_msg: Optional[str] = None,
+ ) -> None:
+ self.sql_json_execution_context = sql_json_execution_context
+ self.failed_reason_msg = self._get_reason(reason_message, exception)
+ self.suggestion_help_msg = suggestion_help_msg
+ if error_type is None:
+ if exception is not None:
+ if (
+ hasattr(exception, "error_type")
+ and exception.error_type is not None # type: ignore
+ ):
+ error_type = exception.error_type # type: ignore
+ elif hasattr(exception, "error") and isinstance(
+ exception.error, SupersetError # type: ignore
+ ):
+ error_type = exception.error.error_type # type: ignore
+ else:
+ error_type = SupersetErrorType.GENERIC_BACKEND_ERROR
+
+ super().__init__(self._generate_message(), exception, error_type)
+
+ def _generate_message(self) -> str:
+ msg = MSG_FORMAT.format(self.sql_json_execution_context.get_query_details())
+ if self.failed_reason_msg:
+ msg = msg + self.failed_reason_msg
+ if self.suggestion_help_msg is not None:
+ msg = "{} {} {}".format(msg, os.linesep, self.suggestion_help_msg)
+ return msg
+
+ @classmethod
+ def _get_reason(
+ cls, reason_message: Optional[str] = None, exception: Optional[Exception] = None
+ ) -> str:
+ if reason_message is not None:
+ return ": {}".format(reason_message)
+ if exception is not None:
+ if hasattr(exception, "get_message"):
+ return ": {}".format(exception.get_message()) # type: ignore
+ if hasattr(exception, "message"):
+ return ": {}".format(exception.message) # type: ignore
+ return ": {}".format(str(exception))
+ return ""
diff --git a/superset/utils/sqllab_execution_context.py b/superset/utils/sqllab_execution_context.py
index c8cc344..52ba2e7 100644
--- a/superset/utils/sqllab_execution_context.py
+++ b/superset/utils/sqllab_execution_context.py
@@ -22,6 +22,7 @@
from typing import Any, cast, Dict, Optional, TYPE_CHECKING
from flask import g
+from sqlalchemy.orm.exc import DetachedInstanceError
from superset import is_feature_enabled
from superset.models.sql_lab import Query
@@ -177,6 +178,15 @@
client_id=self.client_id_or_short_id,
)
+ def get_query_details(self) -> str:
+ try:
+ if self.query:
+ if self.query.id:
+ return "query '{}' - '{}'".format(self.query.id, self.query.sql)
+ except DetachedInstanceError:
+ pass
+ return "query '{}'".format(self.sql)
+
class CreateTableAsSelect: # pylint: disable=too-few-public-methods
ctas_method: CtasMethod
diff --git a/superset/views/core.py b/superset/views/core.py
index 7bdd77b..dadda30 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -100,6 +100,7 @@
from superset.sql_validators import get_validator_by_name
from superset.sqllab.command import CommandResult, ExecuteSqlCommand
from superset.sqllab.command_status import SqlJsonExecutionStatus
+from superset.sqllab.exceptions import SqlLabException
from superset.sqllab.limiting_factor import LimitingFactor
from superset.sqllab.utils import apply_display_max_row_configuration_if_require
from superset.tasks.async_queries import load_explore_json_into_cache
@@ -2434,13 +2435,17 @@
@event_logger.log_this
@expose("/sql_json/", methods=["POST"])
def sql_json(self) -> FlaskResponse:
- log_params = {
- "user_agent": cast(Optional[str], request.headers.get("USER_AGENT"))
- }
- execution_context = SqlJsonExecutionContext(request.json)
- command = ExecuteSqlCommand(execution_context, log_params)
- command_result: CommandResult = command.run()
- return self._create_response_from_execution_context(command_result)
+ try:
+ log_params = {
+ "user_agent": cast(Optional[str], request.headers.get("USER_AGENT"))
+ }
+ execution_context = SqlJsonExecutionContext(request.json)
+ command = ExecuteSqlCommand(execution_context, log_params)
+ command_result: CommandResult = command.run()
+ return self._create_response_from_execution_context(command_result)
+ except SqlLabException as ex:
+ payload = {"errors": [ex.to_dict()]}
+ return json_error_response(status=ex.status, payload=payload)
def _create_response_from_execution_context( # pylint: disable=invalid-name, no-self-use
self, command_result: CommandResult,
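
End to end, a failure while executing a SQL Lab query now comes back to the client as an errors payload, with the HTTP status taken from the exception (500 for a generic SupersetException). An illustrative response body, assuming an unexpected backend failure wrapped by SqlLabException (query id, SQL, and reason are placeholders):

    {
      "errors": [
        {
          "message": "Failed to execute query '42' - 'SELECT ...': <reason>",
          "error_type": "GENERIC_BACKEND_ERROR"
        }
      ]
    }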