blob: 5b5749415b7d80d09974175af3a7ce818991f351 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=line-too-long
from __future__ import annotations
import dataclasses
import logging
from typing import Any, Dict, Optional
import simplejson as json
from flask import g
from flask_babel import gettext as __, ngettext
from jinja2.exceptions import TemplateError
from jinja2.meta import find_undeclared_variables
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm.session import Session
from superset import app, db, is_feature_enabled, sql_lab
from superset.commands.base import BaseCommand
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
SupersetErrorException,
SupersetErrorsException,
SupersetGenericDBErrorException,
SupersetGenericErrorException,
SupersetSecurityException,
SupersetTemplateParamsErrorException,
SupersetTimeoutException,
)
from superset.jinja_context import BaseTemplateProcessor, get_template_processor
from superset.models.core import Database
from superset.models.sql_lab import LimitingFactor, Query
from superset.queries.dao import QueryDAO
from superset.sqllab.command_status import SqlJsonExecutionStatus
from superset.utils import core as utils
from superset.utils.dates import now_as_float
from superset.utils.sqllab_execution_context import SqlJsonExecutionContext
from superset.views.utils import apply_display_max_row_limit
config = app.config
QueryStatus = utils.QueryStatus
logger = logging.getLogger(__name__)
PARAMETER_MISSING_ERR = (
"Please check your template parameters for syntax errors and make sure "
"they match across your SQL query and Set Parameters. Then, try running "
"your query again."
)
SqlResults = Dict[str, Any]
CommandResult = Dict[str, Any]
class ExecuteSqlCommand(BaseCommand):
execution_context: SqlJsonExecutionContext
log_params: Optional[Dict[str, Any]] = None
session: Session
def __init__(
self,
execution_context: SqlJsonExecutionContext,
log_params: Optional[Dict[str, Any]] = None,
) -> None:
self.execution_context = execution_context
self.log_params = log_params
self.session = db.session()
def validate(self) -> None:
pass
def run( # pylint: disable=too-many-statements,useless-suppression
self,
) -> CommandResult:
"""Runs arbitrary sql and returns data as json"""
query = self._get_existing_query(self.execution_context, self.session)
if self.is_query_handled(query):
self.execution_context.set_query(query) # type: ignore
status = SqlJsonExecutionStatus.QUERY_ALREADY_CREATED
else:
status = self._run_sql_json_exec_from_scratch()
return {
"status": status,
"payload": self._create_payload_from_execution_context(status),
}
@classmethod
def _get_existing_query(
cls, execution_context: SqlJsonExecutionContext, session: Session
) -> Optional[Query]:
query = (
session.query(Query)
.filter_by(
client_id=execution_context.client_id,
user_id=execution_context.user_id,
sql_editor_id=execution_context.sql_editor_id,
)
.one_or_none()
)
return query
@classmethod
def is_query_handled(cls, query: Optional[Query]) -> bool:
return query is not None and query.status in [
QueryStatus.RUNNING,
QueryStatus.PENDING,
QueryStatus.TIMED_OUT,
]
def _run_sql_json_exec_from_scratch(self,) -> SqlJsonExecutionStatus:
self.execution_context.set_database(self._get_the_query_db())
query = self.execution_context.create_query()
try:
self._save_new_query(query)
logger.info("Triggering query_id: %i", query.id)
self._validate_access(query)
self.execution_context.set_query(query)
rendered_query = self._render_query()
self._set_query_limit_if_required(rendered_query)
return self._execute_query(rendered_query)
except Exception as ex:
query.status = QueryStatus.FAILED
self.session.commit()
raise ex
def _get_the_query_db(self) -> Database:
mydb = self.session.query(Database).get(self.execution_context.database_id)
self._validate_query_db(mydb)
return mydb
@classmethod
def _validate_query_db(cls, database: Optional[Database]) -> None:
if not database:
raise SupersetGenericErrorException(
__(
"The database referenced in this query was not found. Please "
"contact an administrator for further assistance or try again."
)
)
def _save_new_query(self, query: Query) -> None:
try:
self.session.add(query)
self.session.flush()
self.session.commit() # shouldn't be necessary
except SQLAlchemyError as ex:
logger.error("Errors saving query details %s", str(ex), exc_info=True)
self.session.rollback()
if not query.id:
raise SupersetGenericErrorException(
__(
"The query record was not created as expected. Please "
"contact an administrator for further assistance or try again."
)
)
def _validate_access(self, query: Query) -> None:
try:
query.raise_for_access()
except SupersetSecurityException as ex:
query.set_extra_json_key("errors", [dataclasses.asdict(ex.error)])
query.status = QueryStatus.FAILED
query.error_message = ex.error.message
self.session.commit()
raise SupersetErrorException(ex.error, status=403) from ex
def _render_query(self,) -> str:
def validate(
rendered_query: str, template_processor: BaseTemplateProcessor
) -> None:
if is_feature_enabled("ENABLE_TEMPLATE_PROCESSING"):
# pylint: disable=protected-access
ast = template_processor._env.parse(rendered_query)
undefined_parameters = find_undeclared_variables(ast) # type: ignore
if undefined_parameters:
raise SupersetTemplateParamsErrorException(
message=ngettext(
"The parameter %(parameters)s in your query is undefined.",
"The following parameters in your query are undefined: %(parameters)s.",
len(undefined_parameters),
parameters=utils.format_list(undefined_parameters),
)
+ " "
+ PARAMETER_MISSING_ERR,
error=SupersetErrorType.MISSING_TEMPLATE_PARAMS_ERROR,
extra={
"undefined_parameters": list(undefined_parameters),
"template_parameters": self.execution_context.template_params,
},
)
query = self.execution_context.query
try:
template_processor = get_template_processor(
database=query.database, query=query
)
rendered_query = template_processor.process_template(
query.sql, **self.execution_context.template_params
)
validate(rendered_query, template_processor)
except TemplateError as ex:
raise SupersetTemplateParamsErrorException(
message=__(
'The query contains one or more malformed template parameters. Please check your query and confirm that all template parameters are surround by double braces, for example, "{{ ds }}". Then, try running your query again.'
),
error=SupersetErrorType.INVALID_TEMPLATE_PARAMS_ERROR,
) from ex
return rendered_query
def _set_query_limit_if_required(self, rendered_query: str,) -> None:
if self._is_required_to_set_limit():
self._set_query_limit(rendered_query)
def _is_required_to_set_limit(self,) -> bool:
return not (
config.get("SQLLAB_CTAS_NO_LIMIT") and self.execution_context.select_as_cta
)
def _set_query_limit(self, rendered_query: str) -> None:
db_engine_spec = self.execution_context.database.db_engine_spec # type: ignore
limits = [
db_engine_spec.get_limit_from_sql(rendered_query),
self.execution_context.limit,
]
if limits[0] is None or limits[0] > limits[1]: # type: ignore
self.execution_context.query.limiting_factor = LimitingFactor.DROPDOWN
elif limits[1] > limits[0]: # type: ignore
self.execution_context.query.limiting_factor = LimitingFactor.QUERY
else: # limits[0] == limits[1]
self.execution_context.query.limiting_factor = (
LimitingFactor.QUERY_AND_DROPDOWN
)
self.execution_context.query.limit = min(
lim for lim in limits if lim is not None
)
def _execute_query(self, rendered_query: str,) -> SqlJsonExecutionStatus:
# Flag for whether or not to expand data
# (feature that will expand Presto row objects and arrays)
# Async request.
if self.execution_context.is_run_asynchronous():
return self._sql_json_async(rendered_query)
return self._sql_json_sync(rendered_query)
def _sql_json_async(self, rendered_query: str,) -> SqlJsonExecutionStatus:
"""
Send SQL JSON query to celery workers.
:param rendered_query: the rendered query to perform by workers
:return: A Flask Response
"""
query = self.execution_context.query
logger.info("Query %i: Running query on a Celery worker", query.id)
# Ignore the celery future object and the request may time out.
query_id = query.id
try:
task = sql_lab.get_sql_results.delay(
query.id,
rendered_query,
return_results=False,
store_results=not query.select_as_cta,
user_name=g.user.username
if g.user and hasattr(g.user, "username")
else None,
start_time=now_as_float(),
expand_data=self.execution_context.expand_data,
log_params=self.log_params,
)
# Explicitly forget the task to ensure the task metadata is removed from the
# Celery results backend in a timely manner.
try:
task.forget()
except NotImplementedError:
logger.warning(
"Unable to forget Celery task as backend"
"does not support this operation"
)
except Exception as ex:
logger.exception("Query %i: %s", query.id, str(ex))
message = __("Failed to start remote query on a worker.")
error = SupersetError(
message=message,
error_type=SupersetErrorType.ASYNC_WORKERS_ERROR,
level=ErrorLevel.ERROR,
)
error_payload = dataclasses.asdict(error)
query.set_extra_json_key("errors", [error_payload])
query.status = QueryStatus.FAILED
query.error_message = message
self.session.commit()
raise SupersetErrorException(error) from ex
# Update saved query with execution info from the query execution
QueryDAO.update_saved_query_exec_info(query_id)
self.session.commit()
return SqlJsonExecutionStatus.QUERY_IS_RUNNING
def _sql_json_sync(self, rendered_query: str) -> SqlJsonExecutionStatus:
"""
Execute SQL query (sql json).
:param rendered_query: The rendered query (included templates)
:raises: SupersetTimeoutException
"""
query = self.execution_context.query
try:
timeout = config["SQLLAB_TIMEOUT"]
timeout_msg = f"The query exceeded the {timeout} seconds timeout."
query_id = query.id
data = self._get_sql_results_with_timeout(
timeout, rendered_query, timeout_msg,
)
# Update saved query if needed
QueryDAO.update_saved_query_exec_info(query_id)
self.execution_context.set_execution_result(data)
except SupersetTimeoutException as ex:
# re-raise exception for api exception handler
raise ex
except Exception as ex:
logger.exception("Query %i failed unexpectedly", query.id)
raise SupersetGenericDBErrorException(
utils.error_msg_from_exception(ex)
) from ex
if data is not None and data.get("status") == QueryStatus.FAILED:
# new error payload with rich context
if data["errors"]:
raise SupersetErrorsException(
[SupersetError(**params) for params in data["errors"]]
)
# old string-only error message
raise SupersetGenericDBErrorException(data["error"])
return SqlJsonExecutionStatus.HAS_RESULTS
def _get_sql_results_with_timeout(
self, timeout: int, rendered_query: str, timeout_msg: str,
) -> Optional[SqlResults]:
query = self.execution_context.query
with utils.timeout(seconds=timeout, error_message=timeout_msg):
# pylint: disable=no-value-for-parameter
return sql_lab.get_sql_results(
query.id,
rendered_query,
return_results=True,
store_results=self._is_store_results(query),
user_name=g.user.username
if g.user and hasattr(g.user, "username")
else None,
expand_data=self.execution_context.expand_data,
log_params=self.log_params,
)
@classmethod
def _is_store_results(cls, query: Query) -> bool:
return (
is_feature_enabled("SQLLAB_BACKEND_PERSISTENCE") and not query.select_as_cta
)
def _create_payload_from_execution_context(
# pylint: disable=invalid-name
self,
status: SqlJsonExecutionStatus,
) -> str:
if status == SqlJsonExecutionStatus.HAS_RESULTS:
return self._to_payload_results_based(
self.execution_context.get_execution_result() or {}
)
return self._to_payload_query_based(self.execution_context.query)
def _to_payload_results_based( # pylint: disable=no-self-use
self, execution_result: SqlResults
) -> str:
display_max_row = config["DISPLAY_MAX_ROW"]
return json.dumps(
apply_display_max_row_limit(execution_result, display_max_row),
default=utils.pessimistic_json_iso_dttm_ser,
ignore_nan=True,
encoding=None,
)
def _to_payload_query_based( # pylint: disable=no-self-use
self, query: Query
) -> str:
return json.dumps(
{"query": query.to_dict()},
default=utils.json_int_dttm_ser,
ignore_nan=True,
)