blob: 80430ab0881845d12c9f587dc75a3e00065240eb [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-outside-toplevel
from __future__ import annotations
from contextlib import contextmanager
from typing import TYPE_CHECKING
from unittest.mock import patch
import pytest
from pytest_mock import MockerFixture
from sqlalchemy import create_engine
from sqlalchemy.orm.session import Session
from sqlalchemy.pool import StaticPool
if TYPE_CHECKING:
from superset.models.core import Database
@pytest.fixture
def database(mocker: MockerFixture, session: Session) -> Database:
    """
    Build a `Database` whose `get_sqla_engine` always yields one SQLite engine.

    The fixture creates the `SqlaTable` metadata on the session's bind, seeds
    an in-memory SQLite table `t` with two rows (one of them with a NULL in
    column `a`), and patches `get_sqla_engine` on the returned object so that
    every caller receives the engine holding the seeded table.
    """
    from superset.connectors.sqla.models import SqlaTable
    from superset.models.core import Database

    SqlaTable.metadata.create_all(session.get_bind())

    sqlite_engine = create_engine(
        "sqlite://",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    db = Database(database_name="db", sqlalchemy_uri="sqlite://")

    # Seed the table through the raw DBAPI connection, then commit once.
    seed_statements = (
        "CREATE TABLE t (a INTEGER, b TEXT)",
        "INSERT INTO t VALUES (1, 'Alice')",
        "INSERT INTO t VALUES (NULL, 'Bob')",
    )
    raw_conn = sqlite_engine.raw_connection()
    for statement in seed_statements:
        raw_conn.execute(statement)
    raw_conn.commit()

    # The SQLite database lives in memory, so the model must always hand back
    # the very engine in which the table was created.
    @contextmanager
    def _yield_seeded_engine(catalog=None, schema=None, **kwargs):
        yield sqlite_engine

    mocker.patch.object(db, "get_sqla_engine", new=_yield_seeded_engine)

    return db
def test_values_for_column(database: Database) -> None:
    """
    Check `values_for_column` on a column that contains a NULL.

    The mocked query result carries `np.nan`; the method is expected to hand
    back `None` in its place, because NaN cannot be serialized to JSON.
    """
    import numpy as np
    import pandas as pd

    from superset.connectors.sqla.models import SqlaTable, TableColumn

    column_a = TableColumn(column_name="a")
    table = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[column_a],
    )

    # Stand-in for the dataframe that pd.read_sql_query would produce.
    mocked_frame = pd.DataFrame({"column_values": [1, np.nan]})
    read_sql_patch = patch("pandas.read_sql_query", return_value=mocked_frame)

    with read_sql_patch:
        assert table.values_for_column("a") == [1, None]
def test_values_for_column_with_rls(database: Database) -> None:
    """
    Check `values_for_column` when a row-level-security filter is present.

    Both the RLS lookup and `pandas.read_sql_query` are mocked; the test pins
    that the single mocked value is returned as a plain list.
    """
    import pandas as pd
    from sqlalchemy.sql.elements import TextClause

    from superset.connectors.sqla.models import SqlaTable, TableColumn

    table = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[TableColumn(column_name="a")],
    )

    # RLS filter limiting rows to a = 1.
    rls_patch = patch.object(
        table,
        "get_sqla_row_level_filters",
        return_value=[TextClause("a = 1")],
    )
    # Query result the mocked pandas call hands back.
    read_sql_patch = patch(
        "pandas.read_sql_query",
        return_value=pd.DataFrame({"column_values": [1]}),
    )

    with rls_patch, read_sql_patch:
        assert table.values_for_column("a") == [1]
def test_values_for_column_with_rls_no_values(database: Database) -> None:
    """
    Check `values_for_column` when the RLS filter leaves no rows.

    The mocked query result is an empty dataframe, and the method is expected
    to return an empty list.
    """
    import pandas as pd
    from sqlalchemy.sql.elements import TextClause

    from superset.connectors.sqla.models import SqlaTable, TableColumn

    table = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[TableColumn(column_name="a")],
    )

    # RLS filter on a = 2, paired below with an empty mocked result.
    rls_patch = patch.object(
        table,
        "get_sqla_row_level_filters",
        return_value=[TextClause("a = 2")],
    )
    # Empty dataframe standing in for the query result.
    read_sql_patch = patch(
        "pandas.read_sql_query",
        return_value=pd.DataFrame({"column_values": []}),
    )

    with rls_patch, read_sql_patch:
        assert table.values_for_column("a") == []
def test_values_for_column_calculated(
    mocker: MockerFixture,
    database: Database,
) -> None:
    """
    Check that `values_for_column` works for a calculated column.

    The column is defined by a SQL `CASE` expression rather than a physical
    column; `pandas.read_sql_query` is mocked to supply the values.
    """
    import pandas as pd

    from superset.connectors.sqla.models import SqlaTable, TableColumn

    calculated_column = TableColumn(
        column_name="starts_with_A",
        expression="CASE WHEN b LIKE 'A%' THEN 'yes' ELSE 'nope' END",
    )
    table = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[calculated_column],
    )

    # Values the mocked query returns for the calculated column.
    mocked_frame = pd.DataFrame({"column_values": ["yes", "nope"]})

    with patch("pandas.read_sql_query", return_value=mocked_frame):
        assert table.values_for_column("starts_with_A") == ["yes", "nope"]
def test_values_for_column_double_percents(
    mocker: MockerFixture,
    database: Database,
) -> None:
    """
    Check the SQL sent to pandas when the dialect doubles percent signs.

    `_double_percents` is switched on for the engine's identifier preparer,
    and the test then verifies that the SQL handed to `pandas.read_sql_query`
    contains `LIKE 'A%'` with a single percent, not `LIKE 'A%%'`.
    """
    import pandas as pd

    from superset.connectors.sqla.models import SqlaTable, TableColumn

    with database.get_sqla_engine() as engine:
        engine.dialect.identifier_preparer._double_percents = "pyformat"

    calculated_column = TableColumn(
        column_name="starts_with_A",
        expression="CASE WHEN b LIKE 'A%' THEN 'yes' ELSE 'nope' END",
    )
    table = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[calculated_column],
    )

    # Patch pandas so the SQL it receives can be inspected afterwards.
    mocked_frame = pd.DataFrame({"column_values": ["yes", "nope"]})
    read_sql_mock = mocker.patch(
        "pandas.read_sql_query",
        return_value=mocked_frame,
    )

    values = table.values_for_column("starts_with_A")

    assert values == ["yes", "nope"]
    read_sql_mock.assert_called_once()

    # The statement is passed by keyword; stringify it for substring checks.
    sql_sent = str(read_sql_mock.call_args.kwargs["sql"])

    # Percents must be single again by the time the SQL reaches pandas.
    assert "LIKE 'A%'" in sql_sent
    assert "LIKE 'A%%'" not in sql_sent
def test_apply_series_others_grouping(database: Database) -> None:
    """
    Exercise `_apply_series_others_grouping` with a mix of column kinds.

    Only `category` is a series column, so it is the only entry expected to be
    replaced (in both the SELECT list and the GROUP BY mapping); `metric_col`
    and `other_col` are expected to come back untouched.
    """
    from unittest.mock import Mock

    from superset.connectors.sqla.models import SqlaTable, TableColumn

    def named_mock(label):
        # `.name` is assigned after construction so it is a plain attribute.
        stub = Mock()
        stub.name = label
        return stub

    def relabel(expr, name=None):
        # Stand-in for make_sqla_column_compatible: a new mock with that name.
        return named_mock(name)

    def keep_every_series(col_name: str, expr) -> bool:
        # Condition factory that always returns True.
        return True

    table = SqlaTable(
        database=database,
        schema=None,
        table_name="test_table",
        columns=[
            TableColumn(column_name="category", type="TEXT"),
            TableColumn(column_name="metric_col", type="INTEGER"),
            TableColumn(column_name="other_col", type="TEXT"),
        ],
    )

    # Mocked SELECT expressions.
    category_expr = named_mock("category")
    metric_expr = named_mock("metric_col")
    other_expr = named_mock("other_col")
    select_exprs = [category_expr, metric_expr, other_expr]

    # GROUP BY covers two columns, but only `category` is a series column.
    groupby_all_columns = {"category": category_expr, "other_col": other_expr}
    groupby_series_columns = {"category": category_expr}

    compat_patch = patch.object(
        table, "make_sqla_column_compatible", side_effect=relabel
    )
    with compat_patch:
        outcome = table._apply_series_others_grouping(
            select_exprs,
            groupby_all_columns,
            groupby_series_columns,
            keep_every_series,
        )
    result_select_exprs, result_groupby_columns = outcome

    # SELECT: same length, the series column relabelled, the rest unchanged.
    assert len(result_select_exprs) == 3
    replaced_category = result_select_exprs[0]
    assert replaced_category.name == "category"
    assert result_select_exprs[1] == metric_expr
    assert result_select_exprs[2] == other_expr

    # GROUP BY: same keys, with the series column swapped for a new expression.
    assert len(result_groupby_columns) == 2
    assert "category" in result_groupby_columns
    grouped_category = result_groupby_columns["category"]
    # GROUP BY expressions are expected NOT to go through
    # make_sqla_column_compatible, so this should be the raw CASE expression;
    # here we only verify that it is no longer the original expression.
    assert grouped_category != category_expr
    assert result_groupby_columns["other_col"] == other_expr
def test_apply_series_others_grouping_with_false_condition(database: Database) -> None:
    """
    Exercise `_apply_series_others_grouping` with a condition that is False.

    The intent is that the resulting CASE expressions always fall through to
    "Others"; this unit test only verifies that the expressions were replaced.
    """
    from unittest.mock import Mock

    from superset.connectors.sqla.models import SqlaTable, TableColumn

    def relabel(expr, name=None):
        # Stand-in for make_sqla_column_compatible: a new mock with that name.
        stub = Mock()
        stub.name = name
        return stub

    def reject_every_series(col_name: str, expr) -> bool:
        # Condition factory that always returns False.
        return False

    table = SqlaTable(
        database=database,
        schema=None,
        table_name="test_table",
        columns=[TableColumn(column_name="category", type="TEXT")],
    )

    # One mocked expression serves as SELECT, GROUP BY and series column.
    category_expr = Mock()
    category_expr.name = "category"
    select_exprs = [category_expr]
    groupby_all_columns = {"category": category_expr}
    groupby_series_columns = {"category": category_expr}

    compat_patch = patch.object(
        table, "make_sqla_column_compatible", side_effect=relabel
    )
    with compat_patch:
        outcome = table._apply_series_others_grouping(
            select_exprs,
            groupby_all_columns,
            groupby_series_columns,
            reject_every_series,
        )
    result_select_exprs, result_groupby_columns = outcome

    # SQL generation is not checked here; only that the structure changed.
    assert len(result_select_exprs) == 1
    assert result_select_exprs[0].name == "category"
    assert len(result_groupby_columns) == 1
    assert "category" in result_groupby_columns
    # The GROUP BY entry must no longer be the original expression.
    assert result_groupby_columns["category"] != category_expr
def test_apply_series_others_grouping_sql_compilation(database: Database) -> None:
    """
    Compile the output of `_apply_series_others_grouping` and inspect the SQL.

    Guards against the bug where the 'Others' literal was left unquoted in the
    GROUP BY clause, which produced SQL syntax errors. Both the SELECT and the
    GROUP BY expression are compiled with literal binds and checked for a
    quoted 'Others', a CASE/ELSE structure, and no label on the GROUP BY side.
    """
    import sqlalchemy as sa

    from superset.connectors.sqla.models import SqlaTable, TableColumn

    table = SqlaTable(
        database=database,
        schema=None,
        table_name="test_table",
        columns=[
            TableColumn(column_name="name", type="TEXT"),
            TableColumn(column_name="value", type="INTEGER"),
        ],
    )

    # Real SQLAlchemy column expressions, so the result can be compiled.
    name_col = sa.column("name")
    value_col = sa.column("value")

    def condition_factory(col_name: str, expr):
        # Condition on a subquery column being non-null.
        return sa.column("series_limit.name__").is_not(None)

    result_select_exprs, result_groupby_columns = table._apply_series_others_grouping(
        [name_col, value_col],
        {"name": name_col},
        {"name": name_col},
        condition_factory,
    )

    # Compile against the dialect of the database under test.
    with database.get_sqla_engine() as engine:
        dialect = engine.dialect

    def render(expression) -> str:
        # Inline bound values so the 'Others' literal appears in the text.
        compiled = expression.compile(
            dialect=dialect, compile_kwargs={"literal_binds": True}
        )
        return str(compiled)

    select_sql = render(result_select_exprs[0])
    groupby_sql = render(result_groupby_columns["name"])
    labelled_sql = (("SELECT", select_sql), ("GROUP BY", groupby_sql))

    # 'Others' must never show up as a bare, unquoted token.
    for clause, sql in labelled_sql:
        assert " Others " not in sql, f"Found unquoted 'Others' in {clause}"

    # Quote characters can differ between databases; accept single or double
    # quotes as long as both clauses use the same kind.
    has_single_quotes = all("'Others'" in sql for _, sql in labelled_sql)
    has_double_quotes = all('"Others"' in sql for _, sql in labelled_sql)
    assert has_single_quotes or has_double_quotes, (
        "Others literal should be quoted with either single or double quotes"
    )

    # Both clauses must be CASE expressions with an ELSE branch.
    for _, sql in labelled_sql:
        assert "CASE WHEN" in sql
    for _, sql in labelled_sql:
        assert "ELSE " in sql

    # The GROUP BY expression must not carry a label; the SELECT one may.
    assert " AS " not in groupby_sql
    if " AS " in select_sql:
        # A labelled SELECT must differ from the unlabelled GROUP BY.
        assert select_sql != groupby_sql
def test_apply_series_others_grouping_no_label_in_groupby(database: Database) -> None:
"""
Test that GROUP BY expressions don't get wrapped with make_sqla_column_compatible.
This is a specific test for the bug fix where make_sqla_column_compatible
was causing issues with literal quoting in GROUP BY clauses.
"""
from unittest.mock import ANY, call, Mock, patch
from superset.connectors.sqla.models import SqlaTable, TableColumn
# Create a table instance
table = SqlaTable(
database=database,
schema=None,
table_name="test_table",
columns=[TableColumn(column_name="category", type="TEXT")],
)
# Mock expressions
category_expr = Mock()
category_expr.name = "category"
select_exprs = [category_expr]
groupby_all_columns = {"category": category_expr}
groupby_series_columns = {"category": category_expr}
def condition_factory(col_name: str, expr):
return True
# Track calls to make_sqla_column_compatible
with patch.object(
table, "make_sqla_column_compatible", side_effect=lambda expr, name: expr
) as mock_make_compatible:
result_select_exprs, result_groupby_columns = (
table._apply_series_others_grouping(
select_exprs,
groupby_all_columns,
groupby_series_columns,
condition_factory,
)
)
# Verify make_sqla_column_compatible was called for SELECT expressions
# but NOT for GROUP BY expressions
calls = mock_make_compatible.call_args_list
# Should have exactly one call (for the SELECT expression)
assert len(calls) == 1
# The call should be for the SELECT expression with the column name
# Using unittest.mock.ANY to match any CASE expression
assert calls[0] == call(ANY, "category")
# Verify the GROUP BY expression was NOT passed through
# make_sqla_column_compatible - it should be the raw CASE expression
assert "category" in result_groupby_columns
# The GROUP BY expression should be different from the SELECT expression
# because only SELECT gets make_sqla_column_compatible applied