refactor: sql_json view endpoint: use execution context instead of query (#16677)

* refactor sql_json view endpoint: use execution context instead of query

* fix failed tests

* fix failed tests

* refactor renaming enum options
diff --git a/superset/utils/sqllab_execution_context.py b/superset/utils/sqllab_execution_context.py
index 0694498..6c5532b 100644
--- a/superset/utils/sqllab_execution_context.py
+++ b/superset/utils/sqllab_execution_context.py
@@ -56,6 +56,8 @@
     expand_data: bool
     create_table_as_select: Optional[CreateTableAsSelect]
     database: Optional[Database]
+    query: Query
+    _sql_result: Optional[SqlResults]
 
     def __init__(self, query_params: Dict[str, Any]):
         self.create_table_as_select = None
@@ -64,6 +66,9 @@
         self.user_id = self._get_user_id()
         self.client_id_or_short_id = cast(str, self.client_id or utils.shortid()[:10])
 
+    def set_query(self, query: Query) -> None:
+        self.query = query
+
     def _init_from_query_params(self, query_params: Dict[str, Any]) -> None:
         self.database_id = cast(int, query_params.get("database_id"))
         self.schema = cast(str, query_params.get("schema"))
@@ -134,6 +139,12 @@
         # TODO validate db.id is equal to self.database_id
         pass
 
+    def get_execution_result(self) -> Optional[SqlResults]:
+        return self._sql_result
+
+    def set_execution_result(self, sql_result: Optional[SqlResults]) -> None:
+        self._sql_result = sql_result
+
     def create_query(self) -> Query:
         # pylint: disable=line-too-long
         start_time = now_as_float()
diff --git a/superset/views/core.py b/superset/views/core.py
index a1bb3e9..34c1f6a 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -22,7 +22,8 @@
 import re
 from contextlib import closing
 from datetime import datetime, timedelta
-from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
+from enum import Enum
+from typing import Any, Callable, cast, Dict, List, Optional, Union
 from urllib import parse
 
 import backoff
@@ -176,7 +177,7 @@
     "your query again."
 )
 
-SqlResults = Optional[Dict[str, Any]]
+SqlResults = Dict[str, Any]
 
 
 class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
@@ -2423,13 +2424,16 @@
             "user_agent": cast(Optional[str], request.headers.get("USER_AGENT"))
         }
         execution_context = SqlJsonExecutionContext(request.json)
-        return json_success(*self.sql_json_exec(execution_context, log_params))
+        status: SqlJsonExecutionStatus = self.sql_json_exec(
+            execution_context, log_params
+        )
+        return self._create_response_from_execution_context(execution_context, status)
 
     def sql_json_exec(  # pylint: disable=too-many-statements,useless-suppression
         self,
         execution_context: SqlJsonExecutionContext,
         log_params: Optional[Dict[str, Any]] = None,
-    ) -> Tuple[str, int]:
+    ) -> SqlJsonExecutionStatus:
         """Runs arbitrary sql and returns data as json"""
 
         session = db.session()
@@ -2437,7 +2441,8 @@
         query = self._get_existing_query(execution_context, session)
 
         if self.is_query_handled(query):
-            return self._convert_query_to_payload(cast(Query, query)), 200
+            execution_context.set_query(query)  # type: ignore
+            return SqlJsonExecutionStatus.QUERY_ALREADY_CREATED
 
         return self._run_sql_json_exec_from_scratch(
             execution_context, session, log_params
@@ -2466,34 +2471,32 @@
             QueryStatus.TIMED_OUT,
         ]
 
-    @staticmethod
-    def _convert_query_to_payload(query: Query) -> str:
-        return json.dumps(
-            {"query": query.to_dict()},
-            default=utils.json_int_dttm_ser,
-            ignore_nan=True,
-        )
-
     def _run_sql_json_exec_from_scratch(
         self,
         execution_context: SqlJsonExecutionContext,
         session: Session,
         log_params: Optional[Dict[str, Any]] = None,
-    ) -> Tuple[str, int]:
+    ) -> SqlJsonExecutionStatus:
         execution_context.set_database(
             self._get_the_query_db(execution_context, session)
         )
         query = execution_context.create_query()
-        self._save_new_query(query, session)
-        logger.info("Triggering query_id: %i", query.id)
-        self._validate_access(query, session)
-        rendered_query = self._render_query(query, execution_context, session)
+        try:
+            self._save_new_query(query, session)
+            logger.info("Triggering query_id: %i", query.id)
+            self._validate_access(query, session)
+            execution_context.set_query(query)
+            rendered_query = self._render_query(execution_context)
 
-        self._set_query_limit_if_required(execution_context, query, rendered_query)
+            self._set_query_limit_if_required(execution_context, rendered_query)
 
-        return self._execute_query(
-            query, execution_context, rendered_query, session, log_params
-        )
+            return self._execute_query(
+                execution_context, rendered_query, session, log_params
+            )
+        except Exception as ex:
+            query.status = QueryStatus.FAILED
+            session.commit()
+            raise ex
 
     @classmethod
     def _get_the_query_db(
@@ -2544,7 +2547,7 @@
             raise SupersetErrorException(ex.error, status=403) from ex
 
     def _render_query(  # pylint: disable=no-self-use
-        self, query: Query, execution_context: SqlJsonExecutionContext, session: Session
+        self, execution_context: SqlJsonExecutionContext
     ) -> str:
         def validate(
             rendered_query: str, template_processor: BaseTemplateProcessor
@@ -2554,8 +2557,6 @@
                 ast = template_processor._env.parse(rendered_query)
                 undefined_parameters = find_undeclared_variables(ast)  # type: ignore
                 if undefined_parameters:
-                    query.status = QueryStatus.FAILED
-                    session.commit()
                     raise SupersetTemplateParamsErrorException(
                         message=ngettext(
                             "The parameter %(parameters)s in your query is undefined.",
@@ -2572,6 +2573,8 @@
                         },
                     )
 
+        query = execution_context.query
+
         try:
             template_processor = get_template_processor(
                 database=query.database, query=query
@@ -2581,8 +2584,6 @@
             )
             validate(rendered_query, template_processor)
         except TemplateError as ex:
-            query.status = QueryStatus.FAILED
-            session.commit()
             raise SupersetTemplateParamsErrorException(
                 message=__(
                     'The query contains one or more malformed template parameters. Please check your query and confirm that all template parameters are surround by double braces, for example, "{{ ds }}". Then, try running your query again.'
@@ -2593,13 +2594,10 @@
         return rendered_query
 
     def _set_query_limit_if_required(
-        self,
-        execution_context: SqlJsonExecutionContext,
-        query: Query,
-        rendered_query: str,
+        self, execution_context: SqlJsonExecutionContext, rendered_query: str,
     ) -> None:
         if self._is_required_to_set_limit(execution_context):
-            self._set_query_limit(rendered_query, query, execution_context)
+            self._set_query_limit(rendered_query, execution_context)
 
     def _is_required_to_set_limit(  # pylint: disable=no-self-use
         self, execution_context: SqlJsonExecutionContext
@@ -2609,10 +2607,7 @@
         )
 
     def _set_query_limit(  # pylint: disable=no-self-use
-        self,
-        rendered_query: str,
-        query: Query,
-        execution_context: SqlJsonExecutionContext,
+        self, rendered_query: str, execution_context: SqlJsonExecutionContext,
     ) -> None:
         db_engine_spec = execution_context.database.db_engine_spec  # type: ignore
         limits = [
@@ -2620,49 +2615,40 @@
             execution_context.limit,
         ]
         if limits[0] is None or limits[0] > limits[1]:  # type: ignore
-            query.limiting_factor = LimitingFactor.DROPDOWN
+            execution_context.query.limiting_factor = LimitingFactor.DROPDOWN
         elif limits[1] > limits[0]:  # type: ignore
-            query.limiting_factor = LimitingFactor.QUERY
+            execution_context.query.limiting_factor = LimitingFactor.QUERY
         else:  # limits[0] == limits[1]
-            query.limiting_factor = LimitingFactor.QUERY_AND_DROPDOWN
-        query.limit = min(lim for lim in limits if lim is not None)
+            execution_context.query.limiting_factor = LimitingFactor.QUERY_AND_DROPDOWN
+        execution_context.query.limit = min(lim for lim in limits if lim is not None)
 
-    def _execute_query(  # pylint: disable=too-many-arguments
+    def _execute_query(
         self,
-        query: Query,
         execution_context: SqlJsonExecutionContext,
         rendered_query: str,
         session: Session,
         log_params: Optional[Dict[str, Any]],
-    ) -> Tuple[str, int]:
+    ) -> SqlJsonExecutionStatus:
         # Flag for whether or not to expand data
         # (feature that will expand Presto row objects and arrays)
-        expand_data: bool = execution_context.expand_data
         # Async request.
         if execution_context.is_run_asynchronous():
-            return (
-                self._sql_json_async(
-                    query, rendered_query, expand_data, session, log_params
-                ),
-                202,
+            return self._sql_json_async(
+                execution_context, rendered_query, session, log_params
             )
-        # Sync request.
-        return (
-            self._sql_json_sync(
-                query, rendered_query, expand_data, session, log_params
-            ),
-            200,
+
+        return self._sql_json_sync(
+            execution_context, rendered_query, session, log_params
         )
 
     @classmethod
-    def _sql_json_async(  # pylint: disable=too-many-arguments
+    def _sql_json_async(
         cls,
-        query: Query,
+        execution_context: SqlJsonExecutionContext,
         rendered_query: str,
-        expand_data: bool,
         session: Session,
         log_params: Optional[Dict[str, Any]],
-    ) -> str:
+    ) -> SqlJsonExecutionStatus:
         """
         Send SQL JSON query to celery workers.
 
@@ -2671,6 +2657,7 @@
         :param query: The query (SQLAlchemy) object
         :return: A Flask Response
         """
+        query = execution_context.query
         logger.info("Query %i: Running query on a Celery worker", query.id)
         # Ignore the celery future object and the request may time out.
         query_id = query.id
@@ -2684,7 +2671,7 @@
                 if g.user and hasattr(g.user, "username")
                 else None,
                 start_time=now_as_float(),
-                expand_data=expand_data,
+                expand_data=execution_context.expand_data,
                 log_params=log_params,
             )
 
@@ -2719,17 +2706,16 @@
         QueryDAO.update_saved_query_exec_info(query_id)
 
         session.commit()
-        return cls._convert_query_to_payload(query)
+        return SqlJsonExecutionStatus.QUERY_IS_RUNNING
 
     @classmethod
     def _sql_json_sync(
         cls,
-        query: Query,
+        execution_context: SqlJsonExecutionContext,
         rendered_query: str,
-        expand_data: bool,
         _session: Session,
         log_params: Optional[Dict[str, Any]],
-    ) -> str:
+    ) -> SqlJsonExecutionStatus:
         """
         Execute SQL query (sql json).
 
@@ -2738,18 +2724,22 @@
         :return: A Flask Response
         :raises: SupersetTimeoutException
         """
+        query = execution_context.query
         try:
             timeout = config["SQLLAB_TIMEOUT"]
             timeout_msg = f"The query exceeded the {timeout} seconds timeout."
             query_id = query.id
             data = cls._get_sql_results_with_timeout(
-                query, timeout, rendered_query, expand_data, timeout_msg, log_params
+                query,
+                timeout,
+                rendered_query,
+                execution_context.expand_data,
+                timeout_msg,
+                log_params,
             )
             # Update saved query if needed
             QueryDAO.update_saved_query_exec_info(query_id)
-
-            # TODO: set LimitingFactor to display?
-            payload = cls._convert_sql_result_to_payload(data)
+            execution_context.set_execution_result(data)
         except SupersetTimeoutException as ex:
             # re-raise exception for api exception handler
             raise ex
@@ -2767,8 +2757,7 @@
                 )
             # old string-only error message
             raise SupersetGenericDBErrorException(data["error"])
-
-        return payload
+        return SqlJsonExecutionStatus.HAS_RESULTS
 
     @classmethod
     def _get_sql_results_with_timeout(  # pylint: disable=too-many-arguments
@@ -2779,7 +2768,7 @@
         expand_data: bool,
         timeout_msg: str,
         log_params: Optional[Dict[str, Any]],
-    ) -> SqlResults:
+    ) -> Optional[SqlResults]:
         with utils.timeout(seconds=timeout, error_message=timeout_msg):
             # pylint: disable=no-value-for-parameter
             return sql_lab.get_sql_results(
@@ -2800,14 +2789,38 @@
             is_feature_enabled("SQLLAB_BACKEND_PERSISTENCE") and not query.select_as_cta
         )
 
-    @classmethod
-    def _convert_sql_result_to_payload(cls, sql_results: SqlResults) -> str:
-        return json.dumps(
-            apply_display_max_row_limit(sql_results),  # type: ignore
-            default=utils.pessimistic_json_iso_dttm_ser,
-            ignore_nan=True,
-            encoding=None,
-        )
+    def _create_response_from_execution_context(
+        # pylint: disable=invalid-name, no-self-use
+        self,
+        execution_context: SqlJsonExecutionContext,
+        status: SqlJsonExecutionStatus,
+    ) -> FlaskResponse:
+        def _to_payload_results_based(execution_result: SqlResults) -> str:
+            display_max_row = config["DISPLAY_MAX_ROW"]
+            return json.dumps(
+                apply_display_max_row_limit(execution_result, display_max_row),
+                default=utils.pessimistic_json_iso_dttm_ser,
+                ignore_nan=True,
+                encoding=None,
+            )
+
+        def _to_payload_query_based(query: Query) -> str:
+            return json.dumps(
+                {"query": query.to_dict()},
+                default=utils.json_int_dttm_ser,
+                ignore_nan=True,
+            )
+
+        status_code = 200
+        if status == SqlJsonExecutionStatus.HAS_RESULTS:
+            payload = _to_payload_results_based(
+                execution_context.get_execution_result() or {}
+            )
+        else:
+            payload = _to_payload_query_based(execution_context.query)
+        if status.QUERY_IS_RUNNING:
+            status_code = 202
+        return json_success(payload, status_code)
 
     @has_access
     @event_logger.log_this
@@ -3177,3 +3190,10 @@
                 "Failed to fetch schemas allowed for csv upload in this database! "
                 "Please contact your Superset Admin!"
             )
+
+
+class SqlJsonExecutionStatus(Enum):
+    QUERY_ALREADY_CREATED = 1
+    HAS_RESULTS = 2
+    QUERY_IS_RUNNING = 3
+    FAILED = 4