fix(sqla): labels_expected contains mutated label (#14095)
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 2b61520..4d263fe 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -521,6 +521,7 @@
if db_engine_spec.allows_alias_in_select:
label = db_engine_spec.make_label_compatible(label_expected)
sqla_col = sqla_col.label(label)
+ sqla_col.key = label_expected
return sqla_col
def __repr__(self) -> str:
@@ -1094,7 +1095,7 @@
)
select_exprs += metrics_exprs
- labels_expected = [c.name for c in select_exprs]
+ labels_expected = [c.key for c in select_exprs]
select_exprs = db_engine_spec.make_select_compatible(
groupby_exprs_with_timestamp.values(), select_exprs
)
diff --git a/tests/sqla_models_tests.py b/tests/sqla_models_tests.py
index cdd77c2..a3d7063 100644
--- a/tests/sqla_models_tests.py
+++ b/tests/sqla_models_tests.py
@@ -22,6 +22,7 @@
from superset import db
from superset.connectors.sqla.models import SqlaTable, TableColumn
+from superset.db_engine_specs.bigquery import BigQueryEngineSpec
from superset.db_engine_specs.druid import DruidEngineSpec
from superset.exceptions import QueryObjectValidationError
from superset.models.core import Database
@@ -305,3 +306,48 @@
         assert cols["mycase"].expression == ""
         assert VIRTUAL_TABLE_STRING_TYPES[backend].match(cols["mycase"].type)
         assert cols["expr"].expression == "case when 1 then 1 else 0 end"
+
+    @patch("superset.models.core.Database.db_engine_spec", BigQueryEngineSpec)
+    def test_labels_expected_on_mutated_query(self):
+        """labels_expected must hold the requested labels, not mutated ones.
+
+        With BigQueryEngineSpec the label ``COUNT_DISTINCT(user)`` is aliased
+        as ``COUNT_DISTINCT_user__00db1`` in the compiled SQL, but
+        ``get_sqla_query`` must still report the original label in
+        ``labels_expected``.
+        """
+        query_obj = {
+            "granularity": None,
+            "from_dttm": None,
+            "to_dttm": None,
+            "groupby": ["user"],
+            "metrics": [
+                {
+                    "expressionType": "SIMPLE",
+                    "column": {"column_name": "user"},
+                    "aggregate": "COUNT_DISTINCT",
+                    "label": "COUNT_DISTINCT(user)",
+                }
+            ],
+            "is_timeseries": False,
+            "filter": [],
+            "extras": {},
+        }
+
+        database = Database(database_name="testdb", sqlalchemy_uri="sqlite://")
+        table = SqlaTable(table_name="bq_table", database=database)
+        db.session.add(database)
+        db.session.add(table)
+        db.session.commit()
+        try:
+            sqlaq = table.get_sqla_query(**query_obj)
+            assert sqlaq.labels_expected == ["user", "COUNT_DISTINCT(user)"]
+            sql = table.database.compile_sqla_query(sqlaq.sqla_query)
+            # the compiled SQL must use the engine-compatible (mutated) alias
+            assert "COUNT_DISTINCT_user__00db1" in sql
+        finally:
+            # clean up even if an assertion fails so leftover rows do not
+            # leak into other tests sharing the session
+            db.session.delete(table)
+            db.session.delete(database)
+            db.session.commit()