| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| # pylint: disable=import-outside-toplevel |
| |
| from __future__ import annotations |
| |
| from contextlib import contextmanager |
| from typing import TYPE_CHECKING |
| from unittest.mock import patch |
| |
| import pytest |
| from pytest_mock import MockerFixture |
| from sqlalchemy import create_engine |
| from sqlalchemy.orm.session import Session |
| from sqlalchemy.pool import StaticPool |
| from sqlalchemy.sql.elements import ColumnElement |
| |
| from superset.superset_typing import AdhocColumn |
| |
| if TYPE_CHECKING: |
| from superset.models.core import Database |
| |
| |
| @pytest.fixture |
| def database(mocker: MockerFixture, session: Session) -> Database: |
| from superset.connectors.sqla.models import SqlaTable |
| from superset.models.core import Database |
| |
| SqlaTable.metadata.create_all(session.get_bind()) |
| |
| engine = create_engine( |
| "sqlite://", |
| connect_args={"check_same_thread": False}, |
| poolclass=StaticPool, |
| ) |
| database = Database(database_name="db", sqlalchemy_uri="sqlite://") |
| |
| connection = engine.raw_connection() |
| connection.execute("CREATE TABLE t (a INTEGER, b TEXT)") |
| connection.execute("INSERT INTO t VALUES (1, 'Alice')") |
| connection.execute("INSERT INTO t VALUES (NULL, 'Bob')") |
| connection.commit() |
| |
| # since we're using an in-memory SQLite database, make sure we always |
| # return the same engine where the table was created |
| @contextmanager |
| def mock_get_sqla_engine(catalog=None, schema=None, **kwargs): |
| yield engine |
| |
| mocker.patch.object( |
| database, |
| "get_sqla_engine", |
| new=mock_get_sqla_engine, |
| ) |
| |
| return database |
| |
| |
| def test_values_for_column(database: Database) -> None: |
| """ |
| Test the `values_for_column` method. |
| |
| NULL values should be returned as `None`, not `np.nan`, since NaN cannot be |
| serialized to JSON. |
| """ |
| import numpy as np |
| import pandas as pd |
| |
| from superset.connectors.sqla.models import SqlaTable, TableColumn |
| |
| table = SqlaTable( |
| database=database, |
| schema=None, |
| table_name="t", |
| columns=[TableColumn(column_name="a")], |
| ) |
| |
| # Mock pd.read_sql_query to return a dataframe with the expected values |
| with patch( |
| "pandas.read_sql_query", |
| return_value=pd.DataFrame({"column_values": [1, np.nan]}), |
| ): |
| assert table.values_for_column("a") == [1, None] |
| |
| |
| def test_values_for_column_with_rls(database: Database) -> None: |
| """ |
| Test the `values_for_column` method with RLS enabled. |
| """ |
| import pandas as pd |
| from sqlalchemy.sql.elements import TextClause |
| |
| from superset.connectors.sqla.models import SqlaTable, TableColumn |
| |
| table = SqlaTable( |
| database=database, |
| schema=None, |
| table_name="t", |
| columns=[ |
| TableColumn(column_name="a"), |
| ], |
| ) |
| |
| # Mock RLS filters and pd.read_sql_query |
| with ( |
| patch.object( |
| table, |
| "get_sqla_row_level_filters", |
| return_value=[ |
| TextClause("a = 1"), |
| ], |
| ), |
| patch( |
| "pandas.read_sql_query", |
| return_value=pd.DataFrame({"column_values": [1]}), |
| ), |
| ): |
| assert table.values_for_column("a") == [1] |
| |
| |
| def test_values_for_column_with_rls_no_values(database: Database) -> None: |
| """ |
| Test the `values_for_column` method with RLS enabled and no values. |
| """ |
| import pandas as pd |
| from sqlalchemy.sql.elements import TextClause |
| |
| from superset.connectors.sqla.models import SqlaTable, TableColumn |
| |
| table = SqlaTable( |
| database=database, |
| schema=None, |
| table_name="t", |
| columns=[ |
| TableColumn(column_name="a"), |
| ], |
| ) |
| |
| # Mock RLS filters and pd.read_sql_query to return empty dataframe |
| with ( |
| patch.object( |
| table, |
| "get_sqla_row_level_filters", |
| return_value=[ |
| TextClause("a = 2"), |
| ], |
| ), |
| patch( |
| "pandas.read_sql_query", |
| return_value=pd.DataFrame({"column_values": []}), |
| ), |
| ): |
| assert table.values_for_column("a") == [] |
| |
| |
| def test_values_for_column_calculated( |
| mocker: MockerFixture, |
| database: Database, |
| ) -> None: |
| """ |
| Test that calculated columns work. |
| """ |
| import pandas as pd |
| |
| from superset.connectors.sqla.models import SqlaTable, TableColumn |
| |
| table = SqlaTable( |
| database=database, |
| schema=None, |
| table_name="t", |
| columns=[ |
| TableColumn( |
| column_name="starts_with_A", |
| expression="CASE WHEN b LIKE 'A%' THEN 'yes' ELSE 'nope' END", |
| ) |
| ], |
| ) |
| |
| # Mock pd.read_sql_query to return expected values for calculated column |
| with patch( |
| "pandas.read_sql_query", |
| return_value=pd.DataFrame({"column_values": ["yes", "nope"]}), |
| ): |
| assert table.values_for_column("starts_with_A") == ["yes", "nope"] |
| |
| |
| def test_values_for_column_double_percents( |
| mocker: MockerFixture, |
| database: Database, |
| ) -> None: |
| """ |
| Test the behavior of `double_percents`. |
| """ |
| import pandas as pd |
| |
| from superset.connectors.sqla.models import SqlaTable, TableColumn |
| |
| with database.get_sqla_engine() as engine: |
| engine.dialect.identifier_preparer._double_percents = "pyformat" |
| |
| table = SqlaTable( |
| database=database, |
| schema=None, |
| table_name="t", |
| columns=[ |
| TableColumn( |
| column_name="starts_with_A", |
| expression="CASE WHEN b LIKE 'A%' THEN 'yes' ELSE 'nope' END", |
| ) |
| ], |
| ) |
| |
| # Mock pd.read_sql_query to capture the SQL and return expected values |
| read_sql_mock = mocker.patch( |
| "pandas.read_sql_query", |
| return_value=pd.DataFrame({"column_values": ["yes", "nope"]}), |
| ) |
| |
| result = table.values_for_column("starts_with_A") |
| |
| # Verify the result |
| assert result == ["yes", "nope"] |
| |
| # Verify read_sql_query was called |
| read_sql_mock.assert_called_once() |
| |
| # Get the SQL that was passed to read_sql_query |
| called_sql = str(read_sql_mock.call_args[1]["sql"]) |
| |
| # The SQL should have single percents (after replacement) |
| assert "LIKE 'A%'" in called_sql |
| assert "LIKE 'A%%'" not in called_sql |
| |
| |
| def test_apply_series_others_grouping(database: Database) -> None: |
| """ |
| Test the `_apply_series_others_grouping` method. |
| |
| This method should replace series columns with CASE expressions that |
| group remaining series into an "Others" category based on a condition. |
| """ |
| from unittest.mock import Mock |
| |
| from superset.connectors.sqla.models import SqlaTable, TableColumn |
| |
| # Create a mock table for testing |
| table = SqlaTable( |
| database=database, |
| schema=None, |
| table_name="test_table", |
| columns=[ |
| TableColumn(column_name="category", type="TEXT"), |
| TableColumn(column_name="metric_col", type="INTEGER"), |
| TableColumn(column_name="other_col", type="TEXT"), |
| ], |
| ) |
| |
| # Mock SELECT expressions |
| category_expr = Mock() |
| category_expr.name = "category" |
| metric_expr = Mock() |
| metric_expr.name = "metric_col" |
| other_expr = Mock() |
| other_expr.name = "other_col" |
| |
| select_exprs = [category_expr, metric_expr, other_expr] |
| |
| # Mock GROUP BY columns |
| groupby_all_columns = { |
| "category": category_expr, |
| "other_col": other_expr, |
| } |
| |
| # Define series columns (only category should be modified) |
| groupby_series_columns = {"category": category_expr} |
| |
| # Create a condition factory that always returns True |
| def always_true_condition(col_name: str, expr) -> bool: |
| return True |
| |
| # Mock the make_sqla_column_compatible method |
| def mock_make_compatible(expr, name=None): |
| mock_result = Mock() |
| mock_result.name = name |
| return mock_result |
| |
| with patch.object( |
| table, "make_sqla_column_compatible", side_effect=mock_make_compatible |
| ): |
| # Call the method |
| result_select_exprs, result_groupby_columns = ( |
| table._apply_series_others_grouping( |
| select_exprs, |
| groupby_all_columns, |
| groupby_series_columns, |
| always_true_condition, |
| ) |
| ) |
| |
| # Verify SELECT expressions |
| assert len(result_select_exprs) == 3 |
| |
| # Category (series column) should be replaced with CASE expression |
| category_result = result_select_exprs[0] |
| assert category_result.name == "category" # Should be made compatible |
| |
| # Metric (non-series column) should remain unchanged |
| assert result_select_exprs[1] == metric_expr |
| |
| # Other (non-series column) should remain unchanged |
| assert result_select_exprs[2] == other_expr |
| |
| # Verify GROUP BY columns |
| assert len(result_groupby_columns) == 2 |
| |
| # Category (series column) should be replaced with CASE expression |
| assert "category" in result_groupby_columns |
| category_groupby_result = result_groupby_columns["category"] |
| # After our fix, GROUP BY expressions are NOT wrapped with |
| # make_sqla_column_compatible, so it should be a raw CASE expression, |
| # not a Mock with .name attribute. Verify it's different from the original |
| assert category_groupby_result != category_expr |
| |
| # Other (non-series column) should remain unchanged |
| assert result_groupby_columns["other_col"] == other_expr |
| |
| |
| def test_apply_series_others_grouping_with_false_condition(database: Database) -> None: |
| """ |
| Test the `_apply_series_others_grouping` method with a condition that returns False. |
| |
| This should result in CASE expressions that always use "Others". |
| """ |
| from unittest.mock import Mock |
| |
| from superset.connectors.sqla.models import SqlaTable, TableColumn |
| |
| # Create a mock table for testing |
| table = SqlaTable( |
| database=database, |
| schema=None, |
| table_name="test_table", |
| columns=[TableColumn(column_name="category", type="TEXT")], |
| ) |
| |
| # Mock SELECT expressions |
| category_expr = Mock() |
| category_expr.name = "category" |
| select_exprs = [category_expr] |
| |
| # Mock GROUP BY columns |
| groupby_all_columns = {"category": category_expr} |
| groupby_series_columns = {"category": category_expr} |
| |
| # Create a condition factory that always returns False |
| def always_false_condition(col_name: str, expr) -> bool: |
| return False |
| |
| # Mock the make_sqla_column_compatible method |
| def mock_make_compatible(expr, name=None): |
| mock_result = Mock() |
| mock_result.name = name |
| return mock_result |
| |
| with patch.object( |
| table, "make_sqla_column_compatible", side_effect=mock_make_compatible |
| ): |
| # Call the method |
| result_select_exprs, result_groupby_columns = ( |
| table._apply_series_others_grouping( |
| select_exprs, |
| groupby_all_columns, |
| groupby_series_columns, |
| always_false_condition, |
| ) |
| ) |
| |
| # Verify that the expressions were replaced (we can't test SQL generation |
| # in a unit test, but we can verify the structure changed) |
| assert len(result_select_exprs) == 1 |
| assert result_select_exprs[0].name == "category" |
| |
| assert len(result_groupby_columns) == 1 |
| assert "category" in result_groupby_columns |
| # GROUP BY expression should be a CASE expression, not the original |
| assert result_groupby_columns["category"] != category_expr |
| |
| |
| def test_apply_series_others_grouping_sql_compilation(database: Database) -> None: |
| """ |
| Test that the `_apply_series_others_grouping` method properly quotes |
| the 'Others' literal in both SELECT and GROUP BY clauses. |
| |
| This test verifies the fix for the bug where 'Others' was not quoted |
| in the GROUP BY clause, causing SQL syntax errors. |
| """ |
| import sqlalchemy as sa |
| |
| from superset.connectors.sqla.models import SqlaTable, TableColumn |
| |
| # Create a real table instance |
| table = SqlaTable( |
| database=database, |
| schema=None, |
| table_name="test_table", |
| columns=[ |
| TableColumn(column_name="name", type="TEXT"), |
| TableColumn(column_name="value", type="INTEGER"), |
| ], |
| ) |
| |
| # Create real SQLAlchemy expressions |
| name_col = sa.column("name") |
| value_col = sa.column("value") |
| |
| select_exprs = [name_col, value_col] |
| groupby_all_columns = {"name": name_col} |
| groupby_series_columns = {"name": name_col} |
| |
| # Condition factory that checks if a subquery column is not null |
| def condition_factory(col_name: str, expr): |
| return sa.column("series_limit.name__").is_not(None) |
| |
| # Call the method |
| result_select_exprs, result_groupby_columns = table._apply_series_others_grouping( |
| select_exprs, |
| groupby_all_columns, |
| groupby_series_columns, |
| condition_factory, |
| ) |
| |
| # Get the database dialect from the actual database |
| with database.get_sqla_engine() as engine: |
| dialect = engine.dialect |
| |
| # Test SELECT expression compilation |
| select_case_expr = result_select_exprs[0] |
| select_sql = str( |
| select_case_expr.compile( |
| dialect=dialect, compile_kwargs={"literal_binds": True} |
| ) |
| ) |
| |
| # Test GROUP BY expression compilation |
| groupby_case_expr = result_groupby_columns["name"] |
| groupby_sql = str( |
| groupby_case_expr.compile( |
| dialect=dialect, compile_kwargs={"literal_binds": True} |
| ) |
| ) |
| |
| # Different databases may use different quote characters |
| # PostgreSQL/MySQL use single quotes, some might use double quotes |
| # The key is that Others should be quoted, not bare |
| |
| # Check that 'Others' appears with some form of quotes |
| # and not as a bare identifier |
| assert " Others " not in select_sql, "Found unquoted 'Others' in SELECT" |
| assert " Others " not in groupby_sql, "Found unquoted 'Others' in GROUP BY" |
| |
| # Check for common quoting patterns |
| has_single_quotes = "'Others'" in select_sql and "'Others'" in groupby_sql |
| has_double_quotes = '"Others"' in select_sql and '"Others"' in groupby_sql |
| |
| assert has_single_quotes or has_double_quotes, ( |
| "Others literal should be quoted with either single or double quotes" |
| ) |
| |
| # Verify the structure of the generated SQL |
| assert "CASE WHEN" in select_sql |
| assert "CASE WHEN" in groupby_sql |
| |
| # Check that ELSE is followed by a quoted value |
| assert "ELSE " in select_sql |
| assert "ELSE " in groupby_sql |
| |
| # The key test is that GROUP BY expression doesn't have a label |
| # while SELECT might or might not have one depending on the database |
| # What matters is that GROUP BY should NOT have label |
| assert " AS " not in groupby_sql # GROUP BY should NOT have label |
| |
| # Also verify that if SELECT has a label, it's different from GROUP BY |
| if " AS " in select_sql: |
| # If labeled, SELECT and GROUP BY should be different |
| assert select_sql != groupby_sql |
| |
| |
| def test_apply_series_others_grouping_no_label_in_groupby(database: Database) -> None: |
| """ |
| Test that GROUP BY expressions don't get wrapped with make_sqla_column_compatible. |
| |
| This is a specific test for the bug fix where make_sqla_column_compatible |
| was causing issues with literal quoting in GROUP BY clauses. |
| """ |
| from unittest.mock import ANY, call, Mock, patch |
| |
| from superset.connectors.sqla.models import SqlaTable, TableColumn |
| |
| # Create a table instance |
| table = SqlaTable( |
| database=database, |
| schema=None, |
| table_name="test_table", |
| columns=[TableColumn(column_name="category", type="TEXT")], |
| ) |
| |
| # Mock expressions |
| category_expr = Mock() |
| category_expr.name = "category" |
| |
| select_exprs = [category_expr] |
| groupby_all_columns = {"category": category_expr} |
| groupby_series_columns = {"category": category_expr} |
| |
| def condition_factory(col_name: str, expr): |
| return True |
| |
| # Track calls to make_sqla_column_compatible |
| with patch.object( |
| table, "make_sqla_column_compatible", side_effect=lambda expr, name: expr |
| ) as mock_make_compatible: |
| result_select_exprs, result_groupby_columns = ( |
| table._apply_series_others_grouping( |
| select_exprs, |
| groupby_all_columns, |
| groupby_series_columns, |
| condition_factory, |
| ) |
| ) |
| |
| # Verify make_sqla_column_compatible was called for SELECT expressions |
| # but NOT for GROUP BY expressions |
| calls = mock_make_compatible.call_args_list |
| |
| # Should have exactly one call (for the SELECT expression) |
| assert len(calls) == 1 |
| |
| # The call should be for the SELECT expression with the column name |
| # Using unittest.mock.ANY to match any CASE expression |
| assert calls[0] == call(ANY, "category") |
| |
| # Verify the GROUP BY expression was NOT passed through |
| # make_sqla_column_compatible - it should be the raw CASE expression |
| assert "category" in result_groupby_columns |
| # The GROUP BY expression should be different from the SELECT expression |
| # because only SELECT gets make_sqla_column_compatible applied |
| |
| |
| def test_process_orderby_expression_basic( |
| mocker: MockerFixture, |
| database: Database, |
| ) -> None: |
| """ |
| Test basic ORDER BY expression processing. |
| """ |
| from superset.connectors.sqla.models import SqlaTable |
| |
| table = SqlaTable( |
| database=database, |
| schema=None, |
| table_name="t", |
| ) |
| |
| # Mock _process_sql_expression to return a processed SELECT statement |
| mocker.patch.object( |
| table, |
| "_process_sql_expression", |
| return_value="SELECT 1 ORDER BY column_name DESC", |
| ) |
| |
| result = table._process_orderby_expression( |
| expression="column_name DESC", |
| database_id=database.id, |
| engine="sqlite", |
| schema="", |
| template_processor=None, |
| ) |
| |
| assert result == "column_name DESC" |
| |
| |
| def test_process_orderby_expression_with_case_insensitive_order_by( |
| mocker: MockerFixture, |
| database: Database, |
| ) -> None: |
| """ |
| Test ORDER BY expression processing with case-insensitive matching. |
| """ |
| from superset.connectors.sqla.models import SqlaTable |
| |
| table = SqlaTable( |
| database=database, |
| schema=None, |
| table_name="t", |
| ) |
| |
| # Mock with lowercase "order by" |
| mocker.patch.object( |
| table, |
| "_process_sql_expression", |
| return_value="SELECT 1 order by column_name ASC", |
| ) |
| |
| result = table._process_orderby_expression( |
| expression="column_name ASC", |
| database_id=database.id, |
| engine="sqlite", |
| schema="", |
| template_processor=None, |
| ) |
| |
| assert result == "column_name ASC" |
| |
| |
| def test_process_orderby_expression_complex( |
| mocker: MockerFixture, |
| database: Database, |
| ) -> None: |
| """ |
| Test ORDER BY expression with complex expressions. |
| """ |
| from superset.connectors.sqla.models import SqlaTable |
| |
| table = SqlaTable( |
| database=database, |
| schema=None, |
| table_name="t", |
| ) |
| |
| complex_orderby = "CASE WHEN status = 'active' THEN 1 ELSE 2 END, name DESC" |
| mocker.patch.object( |
| table, |
| "_process_sql_expression", |
| return_value=f"SELECT 1 ORDER BY {complex_orderby}", |
| ) |
| |
| result = table._process_orderby_expression( |
| expression=complex_orderby, |
| database_id=database.id, |
| engine="sqlite", |
| schema="", |
| template_processor=None, |
| ) |
| |
| assert result == complex_orderby |
| |
| |
| def test_process_orderby_expression_none( |
| mocker: MockerFixture, |
| database: Database, |
| ) -> None: |
| """ |
| Test ORDER BY expression processing with None expression. |
| """ |
| from superset.connectors.sqla.models import SqlaTable |
| |
| table = SqlaTable( |
| database=database, |
| schema=None, |
| table_name="t", |
| ) |
| |
| # Mock should return None when input is None |
| mocker.patch.object( |
| table, |
| "_process_sql_expression", |
| return_value=None, |
| ) |
| |
| result = table._process_orderby_expression( |
| expression=None, |
| database_id=database.id, |
| engine="sqlite", |
| schema="", |
| template_processor=None, |
| ) |
| |
| assert result is None |
| |
| |
| def test_process_orderby_expression_empty_string( |
| mocker: MockerFixture, |
| database: Database, |
| ) -> None: |
| """ |
| Test ORDER BY expression processing with empty string. |
| """ |
| from superset.connectors.sqla.models import SqlaTable |
| |
| table = SqlaTable( |
| database=database, |
| schema=None, |
| table_name="t", |
| ) |
| |
| # Mock should return None for empty string |
| mocker.patch.object( |
| table, |
| "_process_sql_expression", |
| return_value=None, |
| ) |
| |
| result = table._process_orderby_expression( |
| expression="", |
| database_id=database.id, |
| engine="sqlite", |
| schema="", |
| template_processor=None, |
| ) |
| |
| assert result is None |
| |
| |
| def test_process_orderby_expression_strips_whitespace( |
| mocker: MockerFixture, |
| database: Database, |
| ) -> None: |
| """ |
| Test that ORDER BY expression processing strips leading/trailing whitespace. |
| """ |
| from superset.connectors.sqla.models import SqlaTable |
| |
| table = SqlaTable( |
| database=database, |
| schema=None, |
| table_name="t", |
| ) |
| |
| # Mock with extra whitespace after ORDER BY |
| mocker.patch.object( |
| table, |
| "_process_sql_expression", |
| return_value="SELECT 1 ORDER BY column_name DESC ", |
| ) |
| |
| result = table._process_orderby_expression( |
| expression="column_name DESC", |
| database_id=database.id, |
| engine="sqlite", |
| schema="", |
| template_processor=None, |
| ) |
| |
| assert result == "column_name DESC" |
| |
| |
| def test_process_orderby_expression_with_template_processor( |
| mocker: MockerFixture, |
| database: Database, |
| ) -> None: |
| """ |
| Test ORDER BY expression with template processor. |
| """ |
| from unittest.mock import Mock |
| |
| from superset.connectors.sqla.models import SqlaTable |
| |
| table = SqlaTable( |
| database=database, |
| schema=None, |
| table_name="t", |
| ) |
| |
| # Create a mock template processor |
| template_processor = Mock() |
| |
| # Mock the _process_sql_expression to verify it receives the prefixed expression |
| mock_process = mocker.patch.object( |
| table, |
| "_process_sql_expression", |
| return_value="SELECT 1 ORDER BY processed_column DESC", |
| ) |
| |
| result = table._process_orderby_expression( |
| expression="column_name DESC", |
| database_id=database.id, |
| engine="sqlite", |
| schema="", |
| template_processor=template_processor, |
| ) |
| |
| # Verify _process_sql_expression was called with SELECT prefix |
| mock_process.assert_called_once() |
| call_args = mock_process.call_args[1] |
| assert call_args["expression"] == "SELECT 1 ORDER BY column_name DESC" |
| assert call_args["template_processor"] is template_processor |
| |
| assert result == "processed_column DESC" |
| |
| |
| def test_process_select_expression_basic( |
| mocker: MockerFixture, |
| database: Database, |
| ) -> None: |
| """ |
| Test basic SELECT expression processing. |
| """ |
| from superset.connectors.sqla.models import SqlaTable |
| |
| table = SqlaTable( |
| database=database, |
| schema=None, |
| table_name="t", |
| ) |
| |
| # Mock _process_sql_expression to return a processed SELECT statement |
| mocker.patch.object( |
| table, |
| "_process_sql_expression", |
| return_value="SELECT COUNT(*)", |
| ) |
| |
| result = table._process_select_expression( |
| expression="COUNT(*)", |
| database_id=database.id, |
| engine="sqlite", |
| schema="", |
| template_processor=None, |
| ) |
| |
| assert result == "COUNT(*)" |
| |
| |
| def test_process_select_expression_with_case_insensitive_select( |
| mocker: MockerFixture, |
| database: Database, |
| ) -> None: |
| """ |
| Test SELECT expression processing with case-insensitive matching. |
| """ |
| from superset.connectors.sqla.models import SqlaTable |
| |
| table = SqlaTable( |
| database=database, |
| schema=None, |
| table_name="t", |
| ) |
| |
| # Mock with lowercase "select" |
| mocker.patch.object( |
| table, |
| "_process_sql_expression", |
| return_value="select column_name", |
| ) |
| |
| result = table._process_select_expression( |
| expression="column_name", |
| database_id=database.id, |
| engine="sqlite", |
| schema="", |
| template_processor=None, |
| ) |
| |
| assert result == "column_name" |
| |
| |
| def test_process_select_expression_complex( |
| mocker: MockerFixture, |
| database: Database, |
| ) -> None: |
| """ |
| Test SELECT expression with complex expressions. |
| """ |
| from superset.connectors.sqla.models import SqlaTable |
| |
| table = SqlaTable( |
| database=database, |
| schema=None, |
| table_name="t", |
| ) |
| |
| complex_select = "CASE WHEN status = 'active' THEN 1 ELSE 0 END" |
| mocker.patch.object( |
| table, |
| "_process_sql_expression", |
| return_value=f"SELECT {complex_select}", |
| ) |
| |
| result = table._process_select_expression( |
| expression=complex_select, |
| database_id=database.id, |
| engine="sqlite", |
| schema="", |
| template_processor=None, |
| ) |
| |
| assert result == complex_select |
| |
| |
| def test_process_select_expression_none( |
| mocker: MockerFixture, |
| database: Database, |
| ) -> None: |
| """ |
| Test SELECT expression processing with None expression. |
| """ |
| from superset.connectors.sqla.models import SqlaTable |
| |
| table = SqlaTable( |
| database=database, |
| schema=None, |
| table_name="t", |
| ) |
| |
| # Mock should return None when input is None |
| mocker.patch.object( |
| table, |
| "_process_sql_expression", |
| return_value=None, |
| ) |
| |
| result = table._process_select_expression( |
| expression=None, |
| database_id=database.id, |
| engine="sqlite", |
| schema="", |
| template_processor=None, |
| ) |
| |
| assert result is None |
| |
| |
| def test_process_select_expression_empty_string( |
| mocker: MockerFixture, |
| database: Database, |
| ) -> None: |
| """ |
| Test SELECT expression processing with empty string. |
| """ |
| from superset.connectors.sqla.models import SqlaTable |
| |
| table = SqlaTable( |
| database=database, |
| schema=None, |
| table_name="t", |
| ) |
| |
| # Mock should return None for empty string |
| mocker.patch.object( |
| table, |
| "_process_sql_expression", |
| return_value=None, |
| ) |
| |
| result = table._process_select_expression( |
| expression="", |
| database_id=database.id, |
| engine="sqlite", |
| schema="", |
| template_processor=None, |
| ) |
| |
| assert result is None |
| |
| |
| def test_process_select_expression_strips_whitespace( |
| mocker: MockerFixture, |
| database: Database, |
| ) -> None: |
| """ |
| Test that SELECT expression processing strips leading/trailing whitespace. |
| """ |
| from superset.connectors.sqla.models import SqlaTable |
| |
| table = SqlaTable( |
| database=database, |
| schema=None, |
| table_name="t", |
| ) |
| |
| # Mock with extra whitespace after SELECT |
| mocker.patch.object( |
| table, |
| "_process_sql_expression", |
| return_value="SELECT column_name ", |
| ) |
| |
| result = table._process_select_expression( |
| expression="column_name", |
| database_id=database.id, |
| engine="sqlite", |
| schema="", |
| template_processor=None, |
| ) |
| |
| assert result == "column_name" |
| |
| |
| def test_process_select_expression_with_template_processor( |
| mocker: MockerFixture, |
| database: Database, |
| ) -> None: |
| """ |
| Test SELECT expression with template processor. |
| """ |
| from unittest.mock import Mock |
| |
| from superset.connectors.sqla.models import SqlaTable |
| |
| table = SqlaTable( |
| database=database, |
| schema=None, |
| table_name="t", |
| ) |
| |
| # Create a mock template processor |
| template_processor = Mock() |
| |
| # Mock the _process_sql_expression to verify it receives the prefixed expression |
| mock_process = mocker.patch.object( |
| table, |
| "_process_sql_expression", |
| return_value="SELECT processed_expression", |
| ) |
| |
| result = table._process_select_expression( |
| expression="some_expression", |
| database_id=database.id, |
| engine="sqlite", |
| schema="", |
| template_processor=template_processor, |
| ) |
| |
| # Verify _process_sql_expression was called with SELECT prefix |
| mock_process.assert_called_once() |
| call_args = mock_process.call_args[1] |
| assert call_args["expression"] == "SELECT some_expression" |
| assert call_args["template_processor"] is template_processor |
| |
| assert result == "processed_expression" |
| |
| |
| def test_process_select_expression_distinct_column( |
| mocker: MockerFixture, |
| database: Database, |
| ) -> None: |
| """ |
| Test SELECT expression with DISTINCT keyword (e.g., "distinct owners"). |
| |
| This test ensures that expressions like "distinct owners" used in adhoc |
| metrics or columns are properly parsed and validated. |
| """ |
| from superset.connectors.sqla.models import SqlaTable |
| |
| table = SqlaTable( |
| database=database, |
| schema=None, |
| table_name="t", |
| ) |
| |
| # Mock _process_sql_expression to return a processed SELECT with DISTINCT |
| mocker.patch.object( |
| table, |
| "_process_sql_expression", |
| return_value="SELECT DISTINCT owners", |
| ) |
| |
| result = table._process_select_expression( |
| expression="distinct owners", |
| database_id=database.id, |
| engine="sqlite", |
| schema="", |
| template_processor=None, |
| ) |
| |
| assert result == "DISTINCT owners" |
| |
| |
| def test_process_select_expression_end_to_end(database: Database) -> None: |
| """ |
| End-to-end test that verifies the regex split works with real sqlglot processing. |
| |
| This test does NOT mock _process_sql_expression, allowing the full flow |
| through sqlglot parsing and validation to ensure the regex extraction works. |
| """ |
| from superset.connectors.sqla.models import SqlaTable |
| |
| table = SqlaTable( |
| database=database, |
| schema=None, |
| table_name="t", |
| ) |
| |
| # Test various real-world expressions |
| test_cases = [ |
| # (input, expected_output) |
| ("COUNT(*)", "COUNT(*)"), |
| ("DISTINCT owners", "DISTINCT owners"), |
| ("column_name", "column_name"), |
| ( |
| "CASE WHEN status = 'active' THEN 1 ELSE 0 END", |
| "CASE WHEN status = 'active' THEN 1 ELSE 0 END", |
| ), |
| ("SUM(amount) / COUNT(*)", "SUM(amount) / COUNT(*)"), |
| ("UPPER(name)", "UPPER(name)"), |
| ] |
| |
| for expression, expected in test_cases: |
| result = table._process_select_expression( |
| expression=expression, |
| database_id=database.id, |
| engine="sqlite", |
| schema="", |
| template_processor=None, |
| ) |
| # sqlglot may normalize the SQL slightly, so we check the result exists |
| # and doesn't contain the SELECT prefix |
| assert result is not None, f"Failed to process: {expression}" |
| assert not result.upper().startswith("SELECT"), ( |
| f"Result still has SELECT prefix: {result}" |
| ) |
| # The result should contain the core expression (case-insensitive check) |
| assert expected.replace(" ", "").lower() in result.replace(" ", "").lower(), ( |
| f"Expected '{expected}' to be in result '{result}' for input '{expression}'" |
| ) |
| |
| |
| def test_reapply_query_filters_with_granularity(database: Database) -> None: |
| """ |
| Test that _reapply_query_filters correctly applies filters with granularity. |
| |
| When granularity is provided, both time_filters and where_clause_and should |
| be combined in the WHERE clause. |
| """ |
| import sqlalchemy as sa |
| |
| from superset.connectors.sqla.models import SqlaTable, TableColumn |
| |
| table = SqlaTable( |
| database=database, |
| schema=None, |
| table_name="test_table", |
| columns=[TableColumn(column_name="value", type="INTEGER")], |
| ) |
| |
| # Create a simple query |
| qry = sa.select(sa.column("value")) |
| |
| # Create mock filter conditions |
| time_filter = sa.column("time_col") >= "2025-01-01" |
| where_filter = sa.column("value") > 10 |
| |
| time_filters = [time_filter] |
| where_clause_and = [where_filter] |
| having_clause_and: list[ColumnElement] = [] |
| |
| # Call the method |
| result_qry = table._reapply_query_filters( |
| qry=qry, |
| apply_fetch_values_predicate=False, |
| template_processor=None, |
| granularity="time_col", |
| time_filters=time_filters, |
| where_clause_and=where_clause_and, |
| having_clause_and=having_clause_and, |
| ) |
| |
| # Compile the query to SQL |
| with database.get_sqla_engine() as engine: |
| sql = str( |
| result_qry.compile( |
| dialect=engine.dialect, compile_kwargs={"literal_binds": True} |
| ) |
| ) |
| |
| # Verify WHERE clause is present |
| assert "WHERE" in sql |
| # Both filters should be in the query |
| assert "time_col" in sql |
| assert "value" in sql |
| |
| |
| def test_reapply_query_filters_without_granularity(database: Database) -> None: |
| """ |
| Test that _reapply_query_filters works correctly without granularity. |
| |
| This test verifies the bug fix where time_filters was not initialized |
| when granularity is None. The method should handle empty time_filters |
| gracefully and only apply where_clause_and. |
| """ |
| import sqlalchemy as sa |
| |
| from superset.connectors.sqla.models import SqlaTable, TableColumn |
| |
| table = SqlaTable( |
| database=database, |
| schema=None, |
| table_name="test_table", |
| columns=[TableColumn(column_name="value", type="INTEGER")], |
| ) |
| |
| # Create a simple query |
| qry = sa.select(sa.column("value")) |
| |
| # Empty time_filters (as would happen without granularity) |
| time_filters: list[ColumnElement] = [] |
| where_filter = sa.column("value") > 10 |
| where_clause_and = [where_filter] |
| having_clause_and: list[ColumnElement] = [] |
| |
| # Call the method with granularity=None |
| result_qry = table._reapply_query_filters( |
| qry=qry, |
| apply_fetch_values_predicate=False, |
| template_processor=None, |
| granularity=None, |
| time_filters=time_filters, |
| where_clause_and=where_clause_and, |
| having_clause_and=having_clause_and, |
| ) |
| |
| # Compile the query to SQL |
| with database.get_sqla_engine() as engine: |
| sql = str( |
| result_qry.compile( |
| dialect=engine.dialect, compile_kwargs={"literal_binds": True} |
| ) |
| ) |
| |
| # Verify WHERE clause is present with the where_filter |
| assert "WHERE" in sql |
| assert "value" in sql |
| |
| |
| def test_reapply_query_filters_with_having_clause(database: Database) -> None: |
| """ |
| Test that _reapply_query_filters correctly applies HAVING clause. |
| |
| HAVING clauses are used for filtering on aggregated metrics. |
| """ |
| import sqlalchemy as sa |
| |
| from superset.connectors.sqla.models import SqlaTable, TableColumn |
| |
| table = SqlaTable( |
| database=database, |
| schema=None, |
| table_name="test_table", |
| columns=[TableColumn(column_name="value", type="INTEGER")], |
| ) |
| |
| # Create a query with GROUP BY |
| qry = sa.select(sa.column("category"), sa.func.sum(sa.column("value"))).group_by( |
| sa.column("category") |
| ) |
| |
| # Create HAVING condition |
| having_filter = sa.func.sum(sa.column("value")) > 100 |
| having_clause_and = [having_filter] |
| |
| # Call the method |
| result_qry = table._reapply_query_filters( |
| qry=qry, |
| apply_fetch_values_predicate=False, |
| template_processor=None, |
| granularity=None, |
| time_filters=[], |
| where_clause_and=[], |
| having_clause_and=having_clause_and, |
| ) |
| |
| # Compile the query to SQL |
| with database.get_sqla_engine() as engine: |
| sql = str( |
| result_qry.compile( |
| dialect=engine.dialect, compile_kwargs={"literal_binds": True} |
| ) |
| ) |
| |
| # Verify HAVING clause is present |
| assert "HAVING" in sql |
| assert "sum" in sql.lower() |
| |
| |
| def test_reapply_query_filters_with_fetch_values_predicate(database: Database) -> None: |
| """ |
| Test that _reapply_query_filters applies fetch_values_predicate when enabled. |
| |
| Fetch values predicate is used for filtering specific column values. |
| """ |
| from unittest.mock import Mock |
| |
| import sqlalchemy as sa |
| |
| from superset.connectors.sqla.models import SqlaTable, TableColumn |
| |
| table = SqlaTable( |
| database=database, |
| schema=None, |
| table_name="test_table", |
| columns=[TableColumn(column_name="value", type="INTEGER")], |
| ) |
| |
| # Mock fetch_values_predicate |
| fetch_predicate = sa.column("value").in_([1, 2, 3]) |
| table.fetch_values_predicate = True |
| |
| # Mock get_fetch_values_predicate method |
| mock_template_processor = Mock() |
| with patch.object( |
| table, "get_fetch_values_predicate", return_value=fetch_predicate |
| ): |
| # Create a simple query |
| qry = sa.select(sa.column("value")) |
| |
| # Call the method with apply_fetch_values_predicate=True |
| result_qry = table._reapply_query_filters( |
| qry=qry, |
| apply_fetch_values_predicate=True, |
| template_processor=mock_template_processor, |
| granularity=None, |
| time_filters=[], |
| where_clause_and=[], |
| having_clause_and=[], |
| ) |
| |
| # Compile the query to SQL |
| with database.get_sqla_engine() as engine: |
| sql = str( |
| result_qry.compile( |
| dialect=engine.dialect, compile_kwargs={"literal_binds": True} |
| ) |
| ) |
| |
| # Verify WHERE clause with IN condition is present |
| assert "WHERE" in sql |
| assert "IN" in sql |
| |
| |
| def test_reapply_query_filters_with_empty_filters(database: Database) -> None: |
| """ |
| Test that _reapply_query_filters handles empty filter lists gracefully. |
| |
| This is an edge case test to ensure the method doesn't fail when |
| all filter lists are empty. |
| """ |
| import sqlalchemy as sa |
| |
| from superset.connectors.sqla.models import SqlaTable, TableColumn |
| |
| table = SqlaTable( |
| database=database, |
| schema=None, |
| table_name="test_table", |
| columns=[TableColumn(column_name="value", type="INTEGER")], |
| ) |
| |
| # Create a simple query |
| qry = sa.select(sa.column("value")) |
| |
| # All empty filter lists |
| time_filters: list[ColumnElement] = [] |
| where_clause_and: list[ColumnElement] = [] |
| having_clause_and: list[ColumnElement] = [] |
| |
| # Call the method with empty filters |
| result_qry = table._reapply_query_filters( |
| qry=qry, |
| apply_fetch_values_predicate=False, |
| template_processor=None, |
| granularity=None, |
| time_filters=time_filters, |
| where_clause_and=where_clause_and, |
| having_clause_and=having_clause_and, |
| ) |
| |
| # Should not raise an error |
| # Compile the query to verify it's valid |
| with database.get_sqla_engine() as engine: |
| sql = str( |
| result_qry.compile( |
| dialect=engine.dialect, compile_kwargs={"literal_binds": True} |
| ) |
| ) |
| |
| # Query should be valid without WHERE or HAVING |
| assert "SELECT" in sql |
| assert "value" in sql |
| |
| |
| def test_adhoc_column_to_sqla_with_column_reference(database: Database) -> None: |
| """ |
| Test that adhoc_column_to_sqla properly handles column references |
| by looking up the column in metadata instead of quoting and processing through |
| SQLGlot. |
| |
| This tests the fix for column names with spaces being properly handled |
| without going through SQLGlot which could misinterpret "column AS alias" patterns. |
| """ |
| from superset.connectors.sqla.models import SqlaTable, TableColumn |
| |
| table = SqlaTable( |
| table_name="test_table", |
| database=database, |
| columns=[ |
| TableColumn(column_name="Customer Name", type="TEXT"), |
| ], |
| ) |
| |
| # Test: Column reference with spaces should be found in metadata |
| col_with_spaces: AdhocColumn = { |
| "sqlExpression": "Customer Name", |
| "label": "Customer Name", |
| "isColumnReference": True, |
| } |
| |
| result = table.adhoc_column_to_sqla(col_with_spaces) |
| |
| # Should return a valid SQLAlchemy column |
| assert result is not None |
| result_str = str(result) |
| |
| # The column name should be present (may or may not be quoted depending on dialect) |
| assert "Customer Name" in result_str or '"Customer Name"' in result_str |
| |
| |
| def test_adhoc_column_to_sqla_preserves_column_type_for_time_grain( |
| database: Database, |
| ) -> None: |
| """ |
| Test that adhoc_column_to_sqla preserves column type info in column references. |
| |
| This tests the fix where column references now look up metadata first, preserving |
| type information needed for time grain operations. Previously, quoting the column |
| name before metadata lookup would cause the column to not be found, resulting in |
| NULL type and failing to apply time grain transformations properly. |
| |
| The test verifies that: |
| 1. Column metadata is found by looking up the unquoted column name |
| 2. The column type (DATE) is preserved when creating the SQLAlchemy column |
| 3. The get_timestamp_expr method is properly called with the column type info |
| """ |
| from superset.connectors.sqla.models import SqlaTable, TableColumn |
| |
| # Create a table with a temporal column |
| table = SqlaTable( |
| table_name="test_table", |
| database=database, |
| columns=[ |
| TableColumn( |
| column_name="local_date", |
| type="DATE", |
| is_dttm=True, |
| ) |
| ], |
| ) |
| |
| # Test with a DATE column reference with time grain |
| date_col: AdhocColumn = { |
| "sqlExpression": "local_date", |
| "label": "local_date", |
| "isColumnReference": True, |
| "timeGrain": "P1D", # Daily time grain |
| "columnType": "BASE_AXIS", |
| } |
| |
| # Should not raise ColumnNotFoundException |
| result = table.adhoc_column_to_sqla(date_col) |
| |
| assert result is not None |
| result_str = str(result) |
| |
| # Verify the column name is present (may be quoted depending on dialect) |
| assert "local_date" in result_str |
| |
| |
| def test_adhoc_column_to_sqla_with_temporal_column_types(database: Database) -> None: |
| """ |
| Test that adhoc_column_to_sqla correctly handles different temporal column types. |
| |
| This verifies that for different temporal types (DATE, DATETIME, TIMESTAMP), |
| the column metadata is properly found and the column type is preserved, |
| allowing time grain operations to work correctly. |
| """ |
| from superset.connectors.sqla.models import SqlaTable, TableColumn |
| |
| # Test different temporal types |
| temporal_types = ["DATE", "DATETIME", "TIMESTAMP"] |
| |
| for type_name in temporal_types: |
| table = SqlaTable( |
| table_name="test_table", |
| database=database, |
| columns=[ |
| TableColumn( |
| column_name="time_col", |
| type=type_name, |
| is_dttm=True, |
| ) |
| ], |
| ) |
| |
| time_col: AdhocColumn = { |
| "sqlExpression": "time_col", |
| "label": "time_col", |
| "isColumnReference": True, |
| "timeGrain": "P1D", |
| "columnType": "BASE_AXIS", |
| } |
| |
| result = table.adhoc_column_to_sqla(time_col) |
| |
| assert result is not None |
| result_str = str(result) |
| |
| # Verify the column name is present |
| assert "time_col" in result_str |
| |
| |
| def test_adhoc_column_with_spaces_generates_quoted_sql(database: Database) -> None: |
| """ |
| Test that column names with spaces are properly quoted in the generated SQL. |
| |
| This verifies that even though we look up columns using unquoted names, |
| the final SQL still properly quotes column names that need quoting (like those with |
| spaces). |
| """ |
| |
| from superset.connectors.sqla.models import SqlaTable, TableColumn |
| |
| table = SqlaTable( |
| table_name="test_table", |
| database=database, |
| columns=[ |
| TableColumn(column_name="Customer Name", type="TEXT"), |
| TableColumn(column_name="Order Total", type="NUMERIC"), |
| ], |
| ) |
| |
| # Test column reference with spaces |
| col_with_spaces: AdhocColumn = { |
| "sqlExpression": "Customer Name", |
| "label": "Customer Name", |
| "isColumnReference": True, |
| } |
| |
| result = table.adhoc_column_to_sqla(col_with_spaces) |
| |
| # Compile the column to SQL to see how it's rendered |
| with database.get_sqla_engine() as engine: |
| sql = str( |
| result.compile( |
| dialect=engine.dialect, compile_kwargs={"literal_binds": True} |
| ) |
| ) |
| |
| # The SQL should quote the column name (SQLite uses double quotes) |
| # Column names with spaces MUST be quoted in SQL |
| assert '"Customer Name"' in sql, f"Expected quoted column name in SQL: {sql}" |
| |
| # Also test that it works in a query context |
| col_numeric: AdhocColumn = { |
| "sqlExpression": "Order Total", |
| "label": "Order Total", |
| "isColumnReference": True, |
| } |
| |
| result_numeric = table.adhoc_column_to_sqla(col_numeric) |
| |
| with database.get_sqla_engine() as engine: |
| sql_numeric = str( |
| result_numeric.compile( |
| dialect=engine.dialect, compile_kwargs={"literal_binds": True} |
| ) |
| ) |
| |
| assert '"Order Total"' in sql_numeric, ( |
| f"Expected quoted column name in SQL: {sql_numeric}" |
| ) |
| |
| |
| def test_adhoc_column_with_spaces_in_full_query(database: Database) -> None: |
| """ |
| Test that column names with spaces work correctly in a full SELECT query. |
| |
| This demonstrates that the fix properly handles column names with spaces |
| throughout the entire query generation process, with proper quoting in the final |
| SQL. |
| """ |
| import sqlalchemy as sa |
| |
| from superset.connectors.sqla.models import SqlaTable, TableColumn |
| |
| table = SqlaTable( |
| table_name="test_table", |
| database=database, |
| columns=[ |
| TableColumn(column_name="Customer Name", type="TEXT"), |
| TableColumn(column_name="Order Total", type="NUMERIC"), |
| ], |
| ) |
| |
| # Create adhoc columns for both columns with spaces |
| customer_col: AdhocColumn = { |
| "sqlExpression": "Customer Name", |
| "label": "Customer Name", |
| "isColumnReference": True, |
| } |
| |
| order_col: AdhocColumn = { |
| "sqlExpression": "Order Total", |
| "label": "Order Total", |
| "isColumnReference": True, |
| } |
| |
| # Get SQLAlchemy columns |
| customer_sqla = table.adhoc_column_to_sqla(customer_col) |
| order_sqla = table.adhoc_column_to_sqla(order_col) |
| |
| # Build a full query |
| tbl = table.get_sqla_table() |
| query = sa.select(customer_sqla, order_sqla).select_from(tbl) |
| |
| # Compile to SQL |
| with database.get_sqla_engine() as engine: |
| sql = str( |
| query.compile( |
| dialect=engine.dialect, compile_kwargs={"literal_binds": True} |
| ) |
| ) |
| |
| # Verify both column names are quoted in the final SQL |
| assert '"Customer Name"' in sql, f"Customer Name not properly quoted in SQL: {sql}" |
| assert '"Order Total"' in sql, f"Order Total not properly quoted in SQL: {sql}" |
| |
| # Verify SELECT and FROM clauses are present |
| assert "SELECT" in sql |
| assert "FROM" in sql |