refactor: sql_json view endpoint: use execution context instead of query (#16677)
* refactor sql_json view endpoint: use execution context instead of query
* fix failed tests
* fix failed tests
* refactor renaming enum options
diff --git a/superset/utils/sqllab_execution_context.py b/superset/utils/sqllab_execution_context.py
index 0694498..6c5532b 100644
--- a/superset/utils/sqllab_execution_context.py
+++ b/superset/utils/sqllab_execution_context.py
@@ -56,6 +56,8 @@
expand_data: bool
create_table_as_select: Optional[CreateTableAsSelect]
database: Optional[Database]
+ query: Query
+ _sql_result: Optional[SqlResults]
def __init__(self, query_params: Dict[str, Any]):
self.create_table_as_select = None
@@ -64,6 +66,9 @@
self.user_id = self._get_user_id()
self.client_id_or_short_id = cast(str, self.client_id or utils.shortid()[:10])
+ def set_query(self, query: Query) -> None:
+ self.query = query
+
def _init_from_query_params(self, query_params: Dict[str, Any]) -> None:
self.database_id = cast(int, query_params.get("database_id"))
self.schema = cast(str, query_params.get("schema"))
@@ -134,6 +139,12 @@
# TODO validate db.id is equal to self.database_id
pass
+ def get_execution_result(self) -> Optional[SqlResults]:
+ return self._sql_result
+
+ def set_execution_result(self, sql_result: Optional[SqlResults]) -> None:
+ self._sql_result = sql_result
+
def create_query(self) -> Query:
# pylint: disable=line-too-long
start_time = now_as_float()
diff --git a/superset/views/core.py b/superset/views/core.py
index a1bb3e9..34c1f6a 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -22,7 +22,8 @@
import re
from contextlib import closing
from datetime import datetime, timedelta
-from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
+from enum import Enum
+from typing import Any, Callable, cast, Dict, List, Optional, Union
from urllib import parse
import backoff
@@ -176,7 +177,7 @@
"your query again."
)
-SqlResults = Optional[Dict[str, Any]]
+SqlResults = Dict[str, Any]
class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@@ -2423,13 +2424,16 @@
"user_agent": cast(Optional[str], request.headers.get("USER_AGENT"))
}
execution_context = SqlJsonExecutionContext(request.json)
- return json_success(*self.sql_json_exec(execution_context, log_params))
+ status: SqlJsonExecutionStatus = self.sql_json_exec(
+ execution_context, log_params
+ )
+ return self._create_response_from_execution_context(execution_context, status)
def sql_json_exec( # pylint: disable=too-many-statements,useless-suppression
self,
execution_context: SqlJsonExecutionContext,
log_params: Optional[Dict[str, Any]] = None,
- ) -> Tuple[str, int]:
+ ) -> SqlJsonExecutionStatus:
"""Runs arbitrary sql and returns data as json"""
session = db.session()
@@ -2437,7 +2441,8 @@
query = self._get_existing_query(execution_context, session)
if self.is_query_handled(query):
- return self._convert_query_to_payload(cast(Query, query)), 200
+ execution_context.set_query(query) # type: ignore
+ return SqlJsonExecutionStatus.QUERY_ALREADY_CREATED
return self._run_sql_json_exec_from_scratch(
execution_context, session, log_params
@@ -2466,34 +2471,32 @@
QueryStatus.TIMED_OUT,
]
- @staticmethod
- def _convert_query_to_payload(query: Query) -> str:
- return json.dumps(
- {"query": query.to_dict()},
- default=utils.json_int_dttm_ser,
- ignore_nan=True,
- )
-
def _run_sql_json_exec_from_scratch(
self,
execution_context: SqlJsonExecutionContext,
session: Session,
log_params: Optional[Dict[str, Any]] = None,
- ) -> Tuple[str, int]:
+ ) -> SqlJsonExecutionStatus:
execution_context.set_database(
self._get_the_query_db(execution_context, session)
)
query = execution_context.create_query()
- self._save_new_query(query, session)
- logger.info("Triggering query_id: %i", query.id)
- self._validate_access(query, session)
- rendered_query = self._render_query(query, execution_context, session)
+ try:
+ self._save_new_query(query, session)
+ logger.info("Triggering query_id: %i", query.id)
+ self._validate_access(query, session)
+ execution_context.set_query(query)
+ rendered_query = self._render_query(execution_context)
- self._set_query_limit_if_required(execution_context, query, rendered_query)
+ self._set_query_limit_if_required(execution_context, rendered_query)
- return self._execute_query(
- query, execution_context, rendered_query, session, log_params
- )
+ return self._execute_query(
+ execution_context, rendered_query, session, log_params
+ )
+ except Exception as ex:
+ query.status = QueryStatus.FAILED
+ session.commit()
+ raise ex
@classmethod
def _get_the_query_db(
@@ -2544,7 +2547,7 @@
raise SupersetErrorException(ex.error, status=403) from ex
def _render_query( # pylint: disable=no-self-use
- self, query: Query, execution_context: SqlJsonExecutionContext, session: Session
+ self, execution_context: SqlJsonExecutionContext
) -> str:
def validate(
rendered_query: str, template_processor: BaseTemplateProcessor
@@ -2554,8 +2557,6 @@
ast = template_processor._env.parse(rendered_query)
undefined_parameters = find_undeclared_variables(ast) # type: ignore
if undefined_parameters:
- query.status = QueryStatus.FAILED
- session.commit()
raise SupersetTemplateParamsErrorException(
message=ngettext(
"The parameter %(parameters)s in your query is undefined.",
@@ -2572,6 +2573,8 @@
},
)
+ query = execution_context.query
+
try:
template_processor = get_template_processor(
database=query.database, query=query
@@ -2581,8 +2584,6 @@
)
validate(rendered_query, template_processor)
except TemplateError as ex:
- query.status = QueryStatus.FAILED
- session.commit()
raise SupersetTemplateParamsErrorException(
message=__(
'The query contains one or more malformed template parameters. Please check your query and confirm that all template parameters are surround by double braces, for example, "{{ ds }}". Then, try running your query again.'
@@ -2593,13 +2594,10 @@
return rendered_query
def _set_query_limit_if_required(
- self,
- execution_context: SqlJsonExecutionContext,
- query: Query,
- rendered_query: str,
+ self, execution_context: SqlJsonExecutionContext, rendered_query: str,
) -> None:
if self._is_required_to_set_limit(execution_context):
- self._set_query_limit(rendered_query, query, execution_context)
+ self._set_query_limit(rendered_query, execution_context)
def _is_required_to_set_limit( # pylint: disable=no-self-use
self, execution_context: SqlJsonExecutionContext
@@ -2609,10 +2607,7 @@
)
def _set_query_limit( # pylint: disable=no-self-use
- self,
- rendered_query: str,
- query: Query,
- execution_context: SqlJsonExecutionContext,
+ self, rendered_query: str, execution_context: SqlJsonExecutionContext,
) -> None:
db_engine_spec = execution_context.database.db_engine_spec # type: ignore
limits = [
@@ -2620,49 +2615,40 @@
execution_context.limit,
]
if limits[0] is None or limits[0] > limits[1]: # type: ignore
- query.limiting_factor = LimitingFactor.DROPDOWN
+ execution_context.query.limiting_factor = LimitingFactor.DROPDOWN
elif limits[1] > limits[0]: # type: ignore
- query.limiting_factor = LimitingFactor.QUERY
+ execution_context.query.limiting_factor = LimitingFactor.QUERY
else: # limits[0] == limits[1]
- query.limiting_factor = LimitingFactor.QUERY_AND_DROPDOWN
- query.limit = min(lim for lim in limits if lim is not None)
+ execution_context.query.limiting_factor = LimitingFactor.QUERY_AND_DROPDOWN
+ execution_context.query.limit = min(lim for lim in limits if lim is not None)
- def _execute_query( # pylint: disable=too-many-arguments
+ def _execute_query(
self,
- query: Query,
execution_context: SqlJsonExecutionContext,
rendered_query: str,
session: Session,
log_params: Optional[Dict[str, Any]],
- ) -> Tuple[str, int]:
+ ) -> SqlJsonExecutionStatus:
# Flag for whether or not to expand data
# (feature that will expand Presto row objects and arrays)
- expand_data: bool = execution_context.expand_data
# Async request.
if execution_context.is_run_asynchronous():
- return (
- self._sql_json_async(
- query, rendered_query, expand_data, session, log_params
- ),
- 202,
+ return self._sql_json_async(
+ execution_context, rendered_query, session, log_params
)
- # Sync request.
- return (
- self._sql_json_sync(
- query, rendered_query, expand_data, session, log_params
- ),
- 200,
+
+ return self._sql_json_sync(
+ execution_context, rendered_query, session, log_params
)
@classmethod
- def _sql_json_async( # pylint: disable=too-many-arguments
+ def _sql_json_async(
cls,
- query: Query,
+ execution_context: SqlJsonExecutionContext,
rendered_query: str,
- expand_data: bool,
session: Session,
log_params: Optional[Dict[str, Any]],
- ) -> str:
+ ) -> SqlJsonExecutionStatus:
"""
Send SQL JSON query to celery workers.
@@ -2671,6 +2657,7 @@
:param query: The query (SQLAlchemy) object
:return: A Flask Response
"""
+ query = execution_context.query
logger.info("Query %i: Running query on a Celery worker", query.id)
# Ignore the celery future object and the request may time out.
query_id = query.id
@@ -2684,7 +2671,7 @@
if g.user and hasattr(g.user, "username")
else None,
start_time=now_as_float(),
- expand_data=expand_data,
+ expand_data=execution_context.expand_data,
log_params=log_params,
)
@@ -2719,17 +2706,16 @@
QueryDAO.update_saved_query_exec_info(query_id)
session.commit()
- return cls._convert_query_to_payload(query)
+ return SqlJsonExecutionStatus.QUERY_IS_RUNNING
@classmethod
def _sql_json_sync(
cls,
- query: Query,
+ execution_context: SqlJsonExecutionContext,
rendered_query: str,
- expand_data: bool,
_session: Session,
log_params: Optional[Dict[str, Any]],
- ) -> str:
+ ) -> SqlJsonExecutionStatus:
"""
Execute SQL query (sql json).
@@ -2738,18 +2724,22 @@
:return: A Flask Response
:raises: SupersetTimeoutException
"""
+ query = execution_context.query
try:
timeout = config["SQLLAB_TIMEOUT"]
timeout_msg = f"The query exceeded the {timeout} seconds timeout."
query_id = query.id
data = cls._get_sql_results_with_timeout(
- query, timeout, rendered_query, expand_data, timeout_msg, log_params
+ query,
+ timeout,
+ rendered_query,
+ execution_context.expand_data,
+ timeout_msg,
+ log_params,
)
# Update saved query if needed
QueryDAO.update_saved_query_exec_info(query_id)
-
- # TODO: set LimitingFactor to display?
- payload = cls._convert_sql_result_to_payload(data)
+ execution_context.set_execution_result(data)
except SupersetTimeoutException as ex:
# re-raise exception for api exception handler
raise ex
@@ -2767,8 +2757,7 @@
)
# old string-only error message
raise SupersetGenericDBErrorException(data["error"])
-
- return payload
+ return SqlJsonExecutionStatus.HAS_RESULTS
@classmethod
def _get_sql_results_with_timeout( # pylint: disable=too-many-arguments
@@ -2779,7 +2768,7 @@
expand_data: bool,
timeout_msg: str,
log_params: Optional[Dict[str, Any]],
- ) -> SqlResults:
+ ) -> Optional[SqlResults]:
with utils.timeout(seconds=timeout, error_message=timeout_msg):
# pylint: disable=no-value-for-parameter
return sql_lab.get_sql_results(
@@ -2800,14 +2789,38 @@
is_feature_enabled("SQLLAB_BACKEND_PERSISTENCE") and not query.select_as_cta
)
- @classmethod
- def _convert_sql_result_to_payload(cls, sql_results: SqlResults) -> str:
- return json.dumps(
- apply_display_max_row_limit(sql_results), # type: ignore
- default=utils.pessimistic_json_iso_dttm_ser,
- ignore_nan=True,
- encoding=None,
- )
+ def _create_response_from_execution_context(
+ # pylint: disable=invalid-name, no-self-use
+ self,
+ execution_context: SqlJsonExecutionContext,
+ status: SqlJsonExecutionStatus,
+ ) -> FlaskResponse:
+ def _to_payload_results_based(execution_result: SqlResults) -> str:
+ display_max_row = config["DISPLAY_MAX_ROW"]
+ return json.dumps(
+ apply_display_max_row_limit(execution_result, display_max_row),
+ default=utils.pessimistic_json_iso_dttm_ser,
+ ignore_nan=True,
+ encoding=None,
+ )
+
+ def _to_payload_query_based(query: Query) -> str:
+ return json.dumps(
+ {"query": query.to_dict()},
+ default=utils.json_int_dttm_ser,
+ ignore_nan=True,
+ )
+
+ status_code = 200
+ if status == SqlJsonExecutionStatus.HAS_RESULTS:
+ payload = _to_payload_results_based(
+ execution_context.get_execution_result() or {}
+ )
+ else:
+ payload = _to_payload_query_based(execution_context.query)
+        if status == SqlJsonExecutionStatus.QUERY_IS_RUNNING:
+ status_code = 202
+ return json_success(payload, status_code)
@has_access
@event_logger.log_this
@@ -3177,3 +3190,10 @@
"Failed to fetch schemas allowed for csv upload in this database! "
"Please contact your Superset Admin!"
)
+
+
+class SqlJsonExecutionStatus(Enum):
+ QUERY_ALREADY_CREATED = 1
+ HAS_RESULTS = 2
+ QUERY_IS_RUNNING = 3
+ FAILED = 4