feat: add global max row limit (#16683)
* feat: add global max limit
* fix lint and tests
* leave SAMPLES_ROW_LIMIT unchanged
* fix sample rowcount test
* replace max global limit with existing sql max row limit
* fix test
* make max_limit optional in util
* improve comments
(cherry picked from commit 4e3d4f6daf01749b8f28e0770e138db5ed8fae91)
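The capping behavior introduced in superset/utils/core.py can be sketched standalone as
follows. This is a minimal illustration mirroring the diff below, not the actual module:
DEFAULT_MAX_ROW stands in for config["SQL_MAX_ROW"] so the sketch does not depend on the
Flask app config.

    from typing import Optional

    # Stand-in for config["SQL_MAX_ROW"]; in Superset this is read from the app config.
    DEFAULT_MAX_ROW = 100000

    def apply_max_row_limit(limit: int, max_limit: Optional[int] = None) -> int:
        """Cap a requested row limit at a global maximum."""
        if max_limit is None:
            max_limit = DEFAULT_MAX_ROW
        # A limit of 0 means "no explicit limit requested", so fall back to the cap.
        if limit != 0:
            return min(max_limit, limit)
        return max_limit

    assert apply_max_row_limit(100000, 10) == 10   # requested limit exceeds the cap
    assert apply_max_row_limit(10, 100000) == 10   # requested limit within the cap
    assert apply_max_row_limit(0, 10000) == 10000  # no limit requested, use the cap

Chart data requests, SQL Lab query limits, and the filter values endpoint all pass their
requested limits through this helper, so SQL_MAX_ROW acts as a single global ceiling.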
diff --git a/superset/common/query_actions.py b/superset/common/query_actions.py
index 349d62d..659d79b 100644
--- a/superset/common/query_actions.py
+++ b/superset/common/query_actions.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
import copy
-import math
from typing import Any, Callable, cast, Dict, List, Optional, TYPE_CHECKING
from flask_babel import _
@@ -131,15 +130,12 @@
query_context: "QueryContext", query_obj: "QueryObject", force_cached: bool = False
) -> Dict[str, Any]:
datasource = _get_datasource(query_context, query_obj)
- row_limit = query_obj.row_limit or math.inf
query_obj = copy.copy(query_obj)
query_obj.is_timeseries = False
query_obj.orderby = []
query_obj.groupby = []
query_obj.metrics = []
query_obj.post_processing = []
- query_obj.row_limit = min(row_limit, config["SAMPLES_ROW_LIMIT"])
- query_obj.row_offset = 0
query_obj.columns = [o.column_name for o in datasource.columns]
return _get_full(query_context, query_obj, force_cached)
diff --git a/superset/common/query_context.py b/superset/common/query_context.py
index 566b01c..83070f7 100644
--- a/superset/common/query_context.py
+++ b/superset/common/query_context.py
@@ -100,11 +100,11 @@
self.datasource = ConnectorRegistry.get_datasource(
str(datasource["type"]), int(datasource["id"]), db.session
)
- self.queries = [QueryObject(**query_obj) for query_obj in queries]
self.force = force
self.custom_cache_timeout = custom_cache_timeout
self.result_type = result_type or ChartDataResultType.FULL
self.result_format = result_format or ChartDataResultFormat.JSON
+ self.queries = [QueryObject(self, **query_obj) for query_obj in queries]
self.cache_values = {
"datasource": datasource,
"queries": queries,
diff --git a/superset/common/query_object.py b/superset/common/query_object.py
index bdf5f89..074de61 100644
--- a/superset/common/query_object.py
+++ b/superset/common/query_object.py
@@ -16,7 +16,7 @@
# under the License.
import logging
from datetime import datetime, timedelta
-from typing import Any, Dict, List, NamedTuple, Optional
+from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING
from flask_babel import gettext as _
from pandas import DataFrame
@@ -28,6 +28,7 @@
from superset.typing import Metric, OrderBy
from superset.utils import pandas_postprocessing
from superset.utils.core import (
+ apply_max_row_limit,
ChartDataResultType,
DatasourceDict,
DTTM_ALIAS,
@@ -41,6 +42,10 @@
from superset.utils.hashing import md5_sha_from_dict
from superset.views.utils import get_time_range_endpoints
+if TYPE_CHECKING:
+ from superset.common.query_context import QueryContext # pragma: no cover
+
+
config = app.config
logger = logging.getLogger(__name__)
@@ -100,6 +105,7 @@
def __init__( # pylint: disable=too-many-arguments,too-many-locals
self,
+ query_context: "QueryContext",
datasource: Optional[DatasourceDict] = None,
result_type: Optional[ChartDataResultType] = None,
annotation_layers: Optional[List[Dict[str, Any]]] = None,
@@ -138,7 +144,7 @@
self.datasource = ConnectorRegistry.get_datasource(
str(datasource["type"]), int(datasource["id"]), db.session
)
- self.result_type = result_type
+ self.result_type = result_type or query_context.result_type
self.apply_fetch_values_predicate = apply_fetch_values_predicate or False
self.annotation_layers = [
layer
@@ -180,7 +186,12 @@
for x in metrics
]
- self.row_limit = config["ROW_LIMIT"] if row_limit is None else row_limit
+ default_row_limit = (
+ config["SAMPLES_ROW_LIMIT"]
+ if self.result_type == ChartDataResultType.SAMPLES
+ else config["ROW_LIMIT"]
+ )
+ self.row_limit = apply_max_row_limit(row_limit or default_row_limit)
self.row_offset = row_offset or 0
self.filter = filters or []
self.timeseries_limit = timeseries_limit
diff --git a/superset/config.py b/superset/config.py
index 25027a3..21671a7 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -115,9 +115,9 @@
# default viz used in chart explorer
DEFAULT_VIZ_TYPE = "table"
+# default row limit when requesting chart data
ROW_LIMIT = 50000
-VIZ_ROW_LIMIT = 10000
-# max rows retreieved when requesting samples from datasource in explore view
+# default row limit when requesting samples from datasource in explore view
SAMPLES_ROW_LIMIT = 1000
# max rows retrieved by filter select auto complete
FILTER_SELECT_ROW_LIMIT = 10000
@@ -665,9 +665,7 @@
# Set this API key to enable Mapbox visualizations
MAPBOX_API_KEY = os.environ.get("MAPBOX_API_KEY", "")
-# Maximum number of rows returned from a database
-# in async mode, no more than SQL_MAX_ROW will be returned and stored
-# in the results backend. This also becomes the limit when exporting CSVs
+# Maximum number of rows returned for any analytical database query
SQL_MAX_ROW = 100000
# Maximum number of rows displayed in SQL Lab UI
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 646b04c..ac4c8b8a 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -1761,3 +1761,25 @@
return bool(strtobool(bool_str.lower()))
except ValueError:
return False
+
+
+def apply_max_row_limit(limit: int, max_limit: Optional[int] = None) -> int:
+    """
+    Override the requested row limit if a global maximum limit is defined
+
+    :param limit: requested row limit
+    :param max_limit: maximum allowed row limit
+    :return: capped row limit
+
+ >>> apply_max_row_limit(100000, 10)
+ 10
+ >>> apply_max_row_limit(10, 100000)
+ 10
+ >>> apply_max_row_limit(0, 10000)
+ 10000
+ """
+ if max_limit is None:
+ max_limit = current_app.config["SQL_MAX_ROW"]
+ if limit != 0:
+ return min(max_limit, limit)
+ return max_limit
diff --git a/superset/utils/sqllab_execution_context.py b/superset/utils/sqllab_execution_context.py
index 0694498..c87ac3d 100644
--- a/superset/utils/sqllab_execution_context.py
+++ b/superset/utils/sqllab_execution_context.py
@@ -23,10 +23,11 @@
from flask import g
-from superset import app, is_feature_enabled
+from superset import is_feature_enabled
from superset.models.sql_lab import Query
from superset.sql_parse import CtasMethod
from superset.utils import core as utils
+from superset.utils.core import apply_max_row_limit
from superset.utils.dates import now_as_float
from superset.views.utils import get_cta_schema_name
@@ -97,7 +98,7 @@
@staticmethod
def _get_limit_param(query_params: Dict[str, Any]) -> int:
- limit: int = query_params.get("queryLimit") or app.config["SQL_MAX_ROW"]
+ limit = apply_max_row_limit(query_params.get("queryLimit") or 0)
if limit < 0:
logger.warning(
"Invalid limit of %i specified. Defaulting to max limit.", limit
diff --git a/superset/views/core.py b/superset/views/core.py
index 13c9417..b77f0c1 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -107,7 +107,7 @@
from superset.utils import core as utils, csv
from superset.utils.async_query_manager import AsyncQueryTokenException
from superset.utils.cache import etag_cache
-from superset.utils.core import ReservedUrlParameters
+from superset.utils.core import apply_max_row_limit, ReservedUrlParameters
from superset.utils.dates import now_as_float
from superset.utils.decorators import check_dashboard_access
from superset.utils.sqllab_execution_context import SqlJsonExecutionContext
@@ -898,8 +898,9 @@
return json_error_response(DATASOURCE_MISSING_ERR)
datasource.raise_for_access()
+ row_limit = apply_max_row_limit(config["FILTER_SELECT_ROW_LIMIT"])
payload = json.dumps(
- datasource.values_for_column(column, config["FILTER_SELECT_ROW_LIMIT"]),
+ datasource.values_for_column(column, row_limit),
default=utils.json_int_dttm_ser,
ignore_nan=True,
)
diff --git a/superset/viz.py b/superset/viz.py
index 357b6c8..9359e88 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -21,7 +21,7 @@
Superset can render.
"""
import copy
-import inspect
+import dataclasses
import logging
import math
import re
@@ -70,6 +70,7 @@
from superset.utils import core as utils, csv
from superset.utils.cache import set_and_log_cache
from superset.utils.core import (
+ apply_max_row_limit,
DTTM_ALIAS,
ExtraFiltersReasonType,
JS_MAX_INTEGER,
@@ -81,9 +82,6 @@
from superset.utils.dates import datetime_to_epoch
from superset.utils.hashing import md5_sha_from_str
-import dataclasses # isort:skip
-
-
if TYPE_CHECKING:
from superset.connectors.base.models import BaseDatasource
@@ -110,7 +108,7 @@
FILTER_VALUES_REGEX = re.compile(r"filter_values\(['\"](\w+)['\"]\,")
-class BaseViz:
+class BaseViz: # pylint: disable=too-many-public-methods
"""All visualizations derive this base class"""
@@ -332,6 +330,7 @@
limit = int(form_data.get("limit") or 0)
timeseries_limit_metric = form_data.get("timeseries_limit_metric")
row_limit = int(form_data.get("row_limit") or config["ROW_LIMIT"])
+ row_limit = apply_max_row_limit(row_limit)
# default order direction
order_desc = form_data.get("order_desc", True)
@@ -556,7 +555,7 @@
)
self.errors.append(error)
self.status = utils.QueryStatus.FAILED
- except Exception as ex:
+ except Exception as ex: # pylint: disable=broad-except
logger.exception(ex)
error = dataclasses.asdict(
@@ -625,7 +624,7 @@
include_index = not isinstance(df.index, pd.RangeIndex)
return csv.df_to_escaped_csv(df, index=include_index, **config["CSV_EXPORT"])
- def get_data(self, df: pd.DataFrame) -> VizData:
+ def get_data(self, df: pd.DataFrame) -> VizData: # pylint: disable=no-self-use
return df.to_dict(orient="records")
@property
@@ -1242,7 +1241,7 @@
d["orderby"] = [(sort_by, is_asc)]
return d
- def to_series(
+ def to_series( # pylint: disable=too-many-branches
self, df: pd.DataFrame, classed: str = "", title_suffix: str = ""
) -> List[Dict[str, Any]]:
cols = []
@@ -1446,6 +1445,7 @@
return {}
def get_data(self, df: pd.DataFrame) -> VizData:
+ # pylint: disable=import-outside-toplevel,too-many-locals
multiline_fd = self.form_data
# Late import to avoid circular import issues
from superset.charts.dao import ChartDAO
@@ -1669,19 +1669,20 @@
def query_obj(self) -> QueryObjectDict:
"""Returns the query object for this visualization"""
- d = super().query_obj()
- d["row_limit"] = self.form_data.get("row_limit", int(config["VIZ_ROW_LIMIT"]))
+ query_obj = super().query_obj()
numeric_columns = self.form_data.get("all_columns_x")
if numeric_columns is None:
raise QueryObjectValidationError(
_("Must have at least one numeric column specified")
)
- self.columns = numeric_columns
- d["columns"] = numeric_columns + self.groupby
+ self.columns = ( # pylint: disable=attribute-defined-outside-init
+ numeric_columns
+ )
+ query_obj["columns"] = numeric_columns + self.groupby
# override groupby entry to avoid aggregation
- d["groupby"] = None
- d["metrics"] = None
- return d
+ query_obj["groupby"] = None
+ query_obj["metrics"] = None
+ return query_obj
def labelify(self, keys: Union[List[str], str], column: str) -> str:
if isinstance(keys, str):
@@ -1751,7 +1752,7 @@
return d
- def get_data(self, df: pd.DataFrame) -> VizData:
+ def get_data(self, df: pd.DataFrame) -> VizData: # pylint: disable=too-many-locals
if df.empty:
return None
@@ -2061,6 +2062,7 @@
return {}
def run_extra_queries(self) -> None:
+ # pylint: disable=import-outside-toplevel
from superset.common.query_context import QueryContext
qry = super().query_obj()
@@ -2373,6 +2375,7 @@
def get_data(self, df: pd.DataFrame) -> VizData:
fd = self.form_data
# Late imports to avoid circular import issues
+ # pylint: disable=import-outside-toplevel
from superset import db
from superset.models.slice import Slice
@@ -2393,6 +2396,7 @@
spatial_control_keys: List[str] = []
def get_metrics(self) -> List[str]:
+ # pylint: disable=attribute-defined-outside-init
self.metric = self.form_data.get("size")
return [self.metric] if self.metric else []
@@ -2557,15 +2561,18 @@
is_timeseries = True
def query_obj(self) -> QueryObjectDict:
- fd = self.form_data
- self.is_timeseries = bool(fd.get("time_grain_sqla") or fd.get("granularity"))
- self.point_radius_fixed = fd.get("point_radius_fixed") or {
+ # pylint: disable=attribute-defined-outside-init
+ self.is_timeseries = bool(
+ self.form_data.get("time_grain_sqla") or self.form_data.get("granularity")
+ )
+ self.point_radius_fixed = self.form_data.get("point_radius_fixed") or {
"type": "fix",
"value": 500,
}
return super().query_obj()
def get_metrics(self) -> List[str]:
+ # pylint: disable=attribute-defined-outside-init
self.metric = None
if self.point_radius_fixed.get("type") == "metric":
self.metric = self.point_radius_fixed["value"]
diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py
index 7499e9e..8517e97 100644
--- a/tests/integration_tests/charts/api_tests.py
+++ b/tests/integration_tests/charts/api_tests.py
@@ -18,7 +18,7 @@
"""Unit tests for Superset"""
import json
import unittest
-from datetime import datetime, timedelta
+from datetime import datetime
from io import BytesIO
from typing import Optional
from unittest import mock
@@ -1203,6 +1203,7 @@
self.login(username="admin")
request_payload = get_query_context("birth_names")
del request_payload["queries"][0]["row_limit"]
+
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
@@ -1210,11 +1211,46 @@
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@mock.patch(
- "superset.common.query_actions.config", {**app.config, "SAMPLES_ROW_LIMIT": 5},
+ "superset.utils.core.current_app.config", {**app.config, "SQL_MAX_ROW": 10},
)
- def test_chart_data_default_sample_limit(self):
+ def test_chart_data_sql_max_row_limit(self):
"""
- Chart data API: Ensure sample response row count doesn't exceed default limit
+ Chart data API: Ensure row count doesn't exceed max global row limit
+ """
+ self.login(username="admin")
+ request_payload = get_query_context("birth_names")
+ request_payload["queries"][0]["row_limit"] = 10000000
+ rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
+ response_payload = json.loads(rv.data.decode("utf-8"))
+ result = response_payload["result"][0]
+ self.assertEqual(result["rowcount"], 10)
+
+ @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+ @mock.patch(
+ "superset.common.query_object.config", {**app.config, "SAMPLES_ROW_LIMIT": 5},
+ )
+ def test_chart_data_sample_default_limit(self):
+ """
+        Chart data API: Ensure sample response row count defaults to SAMPLES_ROW_LIMIT
+ """
+ self.login(username="admin")
+ request_payload = get_query_context("birth_names")
+ request_payload["result_type"] = utils.ChartDataResultType.SAMPLES
+ del request_payload["queries"][0]["row_limit"]
+ rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
+ response_payload = json.loads(rv.data.decode("utf-8"))
+ result = response_payload["result"][0]
+ self.assertEqual(result["rowcount"], 5)
+
+ @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+ @mock.patch(
+ "superset.common.query_actions.config",
+ {**app.config, "SAMPLES_ROW_LIMIT": 5, "SQL_MAX_ROW": 15},
+ )
+ def test_chart_data_sample_custom_limit(self):
+ """
+ Chart data API: Ensure requested sample response row count is between
+        the default and the SQL max row limit
"""
self.login(username="admin")
request_payload = get_query_context("birth_names")
@@ -1223,6 +1259,24 @@
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
+ self.assertEqual(result["rowcount"], 10)
+
+ @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+ @mock.patch(
+ "superset.utils.core.current_app.config", {**app.config, "SQL_MAX_ROW": 5},
+ )
+ def test_chart_data_sql_max_row_sample_limit(self):
+ """
+ Chart data API: Ensure requested sample response row count doesn't
+ exceed SQL max row limit
+ """
+ self.login(username="admin")
+ request_payload = get_query_context("birth_names")
+ request_payload["result_type"] = utils.ChartDataResultType.SAMPLES
+ request_payload["queries"][0]["row_limit"] = 10000000
+ rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
+ response_payload = json.loads(rv.data.decode("utf-8"))
+ result = response_payload["result"][0]
self.assertEqual(result["rowcount"], 5)
def test_chart_data_incorrect_result_type(self):
diff --git a/tests/integration_tests/charts/schema_tests.py b/tests/integration_tests/charts/schema_tests.py
index e34b7d7..977cf72 100644
--- a/tests/integration_tests/charts/schema_tests.py
+++ b/tests/integration_tests/charts/schema_tests.py
@@ -16,17 +16,25 @@
# under the License.
# isort:skip_file
"""Unit tests for Superset"""
-from typing import Any, Dict, Tuple
+from unittest import mock
+
+import pytest
from marshmallow import ValidationError
from tests.integration_tests.test_app import app
from superset.charts.schemas import ChartDataQueryContextSchema
-from superset.common.query_context import QueryContext
from tests.integration_tests.base_tests import SupersetTestCase
+from tests.integration_tests.fixtures.birth_names_dashboard import (
+ load_birth_names_dashboard_with_slices,
+)
from tests.integration_tests.fixtures.query_context import get_query_context
class TestSchema(SupersetTestCase):
+ @mock.patch(
+ "superset.common.query_object.config", {**app.config, "ROW_LIMIT": 5000},
+ )
+ @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_query_context_limit_and_offset(self):
self.login(username="admin")
payload = get_query_context("birth_names")
@@ -36,7 +44,7 @@
payload["queries"][0].pop("row_offset", None)
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
- self.assertEqual(query_object.row_limit, app.config["ROW_LIMIT"])
+ self.assertEqual(query_object.row_limit, 5000)
self.assertEqual(query_object.row_offset, 0)
# Valid limit and offset
@@ -55,12 +63,14 @@
self.assertIn("row_limit", context.exception.messages["queries"][0])
self.assertIn("row_offset", context.exception.messages["queries"][0])
+ @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_query_context_null_timegrain(self):
self.login(username="admin")
payload = get_query_context("birth_names")
payload["queries"][0]["extras"]["time_grain_sqla"] = None
_ = ChartDataQueryContextSchema().load(payload)
+ @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_query_context_series_limit(self):
self.login(username="admin")
payload = get_query_context("birth_names")
@@ -82,6 +92,7 @@
}
_ = ChartDataQueryContextSchema().load(payload)
+ @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_query_context_null_post_processing_op(self):
self.login(username="admin")
payload = get_query_context("birth_names")
diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py
index 87af4c9..e895e7e 100644
--- a/tests/integration_tests/query_context_tests.py
+++ b/tests/integration_tests/query_context_tests.py
@@ -90,6 +90,7 @@
self.assertEqual(post_proc["operation"], payload_post_proc["operation"])
self.assertEqual(post_proc["options"], payload_post_proc["options"])
+ @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_cache(self):
table_name = "birth_names"
table = self.get_table(name=table_name)