blob: ec9689dd17c9945a3d92827de27647ba4c7f7b7a [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-outside-toplevel, protected-access
from __future__ import annotations
import json # noqa: TID251
from textwrap import dedent
from typing import Any
import pytest
from pytest_mock import MockerFixture
from sqlalchemy import types
from sqlalchemy.dialects import sqlite
from sqlalchemy.engine.url import make_url, URL
from sqlalchemy.sql import sqltypes
from superset.sql.parse import Table
from superset.superset_typing import ResultSetColumnType, SQLAColumnType
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import assert_column_spec
def test_get_text_clause_with_colon() -> None:
    """
    Check that ``get_text_clause`` escapes a colon in the SQL as ``\\:`` in the
    text of the returned clause.
    """
    from superset.db_engine_specs.base import BaseEngineSpec

    raw_sql = "SELECT foo FROM tbl WHERE foo = '123:456')"
    escaped_sql = "SELECT foo FROM tbl WHERE foo = '123\\:456')"

    clause = BaseEngineSpec.get_text_clause(raw_sql)

    assert clause.text == escaped_sql
def test_validate_db_uri(mocker: MockerFixture) -> None:
    """
    Check that ``validate_database_uri`` invokes the callable configured under
    ``DB_SQLA_URI_VALIDATOR`` and lets the exception it raises propagate.
    """

    def always_reject(sqlalchemy_uri: URL) -> None:
        raise ValueError("Invalid URI")

    # Install a validator that rejects every URI.
    mocker.patch(
        "flask.current_app.config",
        {"DB_SQLA_URI_VALIDATOR": always_reject},
    )

    from superset.db_engine_specs.base import BaseEngineSpec

    sqlite_uri = URL.create("sqlite")
    with pytest.raises(ValueError):  # noqa: PT011
        BaseEngineSpec.validate_database_uri(sqlite_uri)
@pytest.mark.parametrize(
    "original,expected",
    [
        (
            dedent(
                """
with currency as
(
select 'INR' as cur
)
select * from currency
"""
            ),
            None,
        ),
        (
            "SELECT 1 as cnt",
            None,
        ),
        (
            dedent(
                """
select 'INR' as cur
union
select 'AUD' as cur
union
select 'USD' as cur
"""
            ),
            None,
        ),
    ],
)
def test_cte_query_parsing(original: str, expected: str | None) -> None:
    """
    Check ``BaseEngineSpec.get_cte_query`` on raw SQL strings.

    :param original: the SQL text passed to ``get_cte_query``
    :param expected: the value ``get_cte_query`` should return; every case in
        the parametrization above (including one starting with ``WITH``)
        expects ``None`` for the base engine spec
    """
    from superset.db_engine_specs.base import BaseEngineSpec

    actual = BaseEngineSpec.get_cte_query(original)
    assert actual == expected
@pytest.mark.parametrize(
    "native_type,sqla_type,attrs,generic_type,is_dttm",
    [
        # Numeric
        ("SMALLINT", types.SmallInteger, None, GenericDataType.NUMERIC, False),
        ("INTEGER", types.Integer, None, GenericDataType.NUMERIC, False),
        ("BIGINT", types.BigInteger, None, GenericDataType.NUMERIC, False),
        ("DECIMAL", types.Numeric, None, GenericDataType.NUMERIC, False),
        ("NUMERIC", types.Numeric, None, GenericDataType.NUMERIC, False),
        ("REAL", types.REAL, None, GenericDataType.NUMERIC, False),
        ("DOUBLE PRECISION", types.Float, None, GenericDataType.NUMERIC, False),
        ("MONEY", types.Numeric, None, GenericDataType.NUMERIC, False),
        # String
        ("CHAR", types.String, None, GenericDataType.STRING, False),
        ("VARCHAR", types.String, None, GenericDataType.STRING, False),
        ("TEXT", types.String, None, GenericDataType.STRING, False),
        # Temporal
        ("DATE", types.Date, None, GenericDataType.TEMPORAL, True),
        ("TIMESTAMP", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True),
        ("TIME", types.Time, None, GenericDataType.TEMPORAL, True),
        # Boolean
        ("BOOLEAN", types.Boolean, None, GenericDataType.BOOLEAN, False),
    ],
)
def test_get_column_spec(
    native_type: str,
    sqla_type: type[types.TypeEngine],
    attrs: dict[str, Any] | None,
    generic_type: GenericDataType,
    is_dttm: bool,
) -> None:
    """
    Check the column spec that ``DatabricksNativeEngineSpec`` produces for a
    native column type: its SQLAlchemy type, attributes, generic type and
    whether it is flagged as temporal.
    """
    from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec

    assert_column_spec(
        DatabricksNativeEngineSpec,
        native_type,
        sqla_type,
        attrs,
        generic_type,
        is_dttm,
    )
@pytest.mark.parametrize(
    "cols, expected_result",
    [
        (
            [SQLAColumnType(name="John", type="integer", is_dttm=False)],
            [
                ResultSetColumnType(
                    column_name="John", name="John", type="integer", is_dttm=False
                )
            ],
        ),
        (
            [SQLAColumnType(name="hugh", type="integer", is_dttm=False)],
            [
                ResultSetColumnType(
                    column_name="hugh", name="hugh", type="integer", is_dttm=False
                )
            ],
        ),
    ],
)
def test_convert_inspector_columns(
    cols: list[SQLAColumnType], expected_result: list[ResultSetColumnType]
) -> None:
    """
    Check that ``convert_inspector_columns`` turns inspector column dicts into
    result-set column dicts whose ``column_name`` matches the input ``name``.

    :param cols: columns as described by the SQLAlchemy inspector
    :param expected_result: the converted result-set columns
    """
    from superset.db_engine_specs.base import convert_inspector_columns

    assert convert_inspector_columns(cols) == expected_result
def test_select_star(mocker: MockerFixture) -> None:
    """
    Check the SQL built by ``select_star`` for a single string column, with
    columns shown, indentation enabled and a row limit, compiled with the
    SQLite dialect.
    """
    from superset.db_engine_specs.base import BaseEngineSpec

    column_a: ResultSetColumnType = {
        "column_name": "a",
        "name": "a",
        "type": sqltypes.String(),
        "nullable": True,
        "comment": None,
        "default": None,
        "precision": None,
        "scale": None,
        "max_length": None,
        "is_dttm": False,
    }

    def compile_with_sqlite(query: Any, catalog: Any, schema: Any) -> str:
        # Render the query with the SQLite dialect so it can be compared as text.
        return str(query.compile(dialect=sqlite.dialect()))

    database = mocker.MagicMock()
    database.compile_sqla_query = compile_with_sqlite

    engine = mocker.MagicMock()
    engine.dialect = sqlite.dialect()

    generated = BaseEngineSpec.select_star(
        database=database,
        table=Table("my_table", "my_schema", "my_catalog"),
        engine=engine,
        limit=100,
        show_cols=True,
        indent=True,
        latest_partition=False,
        cols=[column_a],
    )

    expected = "SELECT\n a\nFROM my_schema.my_table\nLIMIT ?\nOFFSET ?"
    assert generated == expected
def test_extra_table_metadata(mocker: MockerFixture) -> None:
    """
    Check that ``get_extra_table_metadata`` still honours a third-party spec
    that overrides the deprecated ``extra_table_metadata`` method, and that a
    warning is emitted along the way.
    """
    from superset.db_engine_specs.base import BaseEngineSpec
    from superset.models.core import Database

    class ThirdPartyDBEngineSpec(BaseEngineSpec):
        @classmethod
        def extra_table_metadata(
            cls,
            database: Database,
            table_name: str,
            schema_name: str | None,
        ) -> dict[str, Any]:
            return {"table": table_name, "schema": schema_name}

    database = mocker.MagicMock()
    warnings = mocker.patch("superset.db_engine_specs.base.warnings")

    # A table without a catalog yields the legacy override's output.
    no_catalog_metadata = ThirdPartyDBEngineSpec.get_extra_table_metadata(
        database,
        Table("table", "schema"),
    )
    assert no_catalog_metadata == {"table": "table", "schema": "schema"}

    # A table with a catalog yields an empty dict.
    with_catalog_metadata = ThirdPartyDBEngineSpec.get_extra_table_metadata(
        database,
        Table("table", "schema", "catalog"),
    )
    assert with_catalog_metadata == {}

    warnings.warn.assert_called()
def test_get_default_catalog(mocker: MockerFixture) -> None:
    """
    Check that the base engine spec reports no default catalog.
    """
    from superset.db_engine_specs.base import BaseEngineSpec

    default_catalog = BaseEngineSpec.get_default_catalog(mocker.MagicMock())

    assert default_catalog is None
def test_quote_table() -> None:
    """
    Check how ``quote_table`` renders a table name under the SQLite dialect,
    with and without schema/catalog, and with identifiers that need quoting.
    """
    from superset.db_engine_specs.base import BaseEngineSpec

    dialect = sqlite.dialect()
    cases = [
        (Table("table"), '"table"'),
        (Table("table", "schema"), 'schema."table"'),
        (Table("table", "schema", "catalog"), 'catalog.schema."table"'),
        (
            Table("ta ble", "sche.ma", 'cata"log'),
            '"cata""log"."sche.ma"."ta ble"',
        ),
    ]

    for table, quoted in cases:
        assert BaseEngineSpec.quote_table(table, dialect) == quoted
def test_mask_encrypted_extra() -> None:
    """
    Check that ``mask_encrypted_extra`` replaces each top-level value of the
    encrypted extra (including nested objects holding the private key) with
    the mask placeholder.
    """
    from superset.db_engine_specs.base import BaseEngineSpec

    mask = "XXXXXXXXXX"
    secret_config = {
        "foo": "bar",
        "service_account_info": {
            "project_id": "black-sanctum-314419",
            "private_key": "SECRET",
        },
    }
    masked_config = {
        "foo": mask,
        "service_account_info": mask,
    }

    masked = BaseEngineSpec.mask_encrypted_extra(json.dumps(secret_config))

    assert masked == json.dumps(masked_config)
def test_unmask_encrypted_extra() -> None:
    """
    Check that ``unmask_encrypted_extra`` restores masked values (including the
    private key) from the previously stored ``encrypted_extra``.
    """
    from superset.db_engine_specs.base import BaseEngineSpec

    mask = "XXXXXXXXXX"
    stored = json.dumps(
        {
            "foo": "bar",
            "service_account_info": {
                "project_id": "black-sanctum-314419",
                "private_key": "SECRET",
            },
        }
    )
    submitted = json.dumps(
        {
            "foo": mask,
            "service_account_info": mask,
        }
    )

    # Every submitted field is masked, so the result must equal the stored
    # configuration exactly.
    assert BaseEngineSpec.unmask_encrypted_extra(stored, submitted) == stored
def test_impersonate_user_backwards_compatible(mocker: MockerFixture) -> None:
    """
    Check that ``impersonate_user`` delegates to ``get_url_for_impersonation``
    and ``update_impersonation_config``. When the (patched) signature of the
    latter lists a ``database`` parameter, the database is passed through.
    """
    from superset.db_engine_specs.base import BaseEngineSpec

    database = mocker.MagicMock()
    username, access_token = "alice", "SECRET"
    source_url = make_url("sqlite://foo.db")
    impersonated_url = make_url("sqlite://bar.db")
    engine_kwargs = {"connect_args": {"user": "alice"}}

    mock_get_url = mocker.patch.object(
        BaseEngineSpec,
        "get_url_for_impersonation",
        return_value=impersonated_url,
    )
    mock_update_config = mocker.patch.object(
        BaseEngineSpec,
        "update_impersonation_config",
    )

    # Make the signature inspection report the newer parameter list, which
    # includes ``database``.
    mocked_signature = mocker.patch("superset.db_engine_specs.base.signature")
    mocked_signature().parameters = [
        "cls",
        "database",
        "connect_args",
        "uri",
        "username",
        "access_token",
    ]

    BaseEngineSpec.impersonate_user(
        database, username, access_token, source_url, engine_kwargs
    )

    mock_get_url.assert_called_once_with(source_url, True, username, access_token)
    mock_update_config.assert_called_once_with(
        database,
        {"user": "alice"},
        impersonated_url,
        username,
        access_token,
    )
def test_impersonate_user_no_database(mocker: MockerFixture) -> None:
    """
    Check ``impersonate_user`` when the (patched) signature of
    ``update_impersonation_config`` has no ``database`` parameter: the hook
    must then be called without the database argument.
    """
    from superset.db_engine_specs.base import BaseEngineSpec

    database = mocker.MagicMock()
    username, access_token = "alice", "SECRET"
    source_url = make_url("sqlite://foo.db")
    impersonated_url = make_url("sqlite://bar.db")
    engine_kwargs = {"connect_args": {"user": "alice"}}

    mock_get_url = mocker.patch.object(
        BaseEngineSpec,
        "get_url_for_impersonation",
        return_value=impersonated_url,
    )
    mock_update_config = mocker.patch.object(
        BaseEngineSpec,
        "update_impersonation_config",
    )

    # Make the signature inspection report the legacy parameter list,
    # i.e. without ``database``.
    mocked_signature = mocker.patch("superset.db_engine_specs.base.signature")
    mocked_signature().parameters = [
        "cls",
        "connect_args",
        "uri",
        "username",
        "access_token",
    ]

    BaseEngineSpec.impersonate_user(
        database, username, access_token, source_url, engine_kwargs
    )

    mock_get_url.assert_called_once_with(source_url, True, username, access_token)
    mock_update_config.assert_called_once_with(
        {"user": "alice"},
        impersonated_url,
        username,
        access_token,
    )
def test_handle_boolean_filter_default_behavior() -> None:
    """
    Test that BaseEngineSpec uses IS operators for boolean filters by default.

    ``IS`` and ``=`` comparisons are both ``BinaryExpression`` objects with
    ``left``/``right`` attributes. The SQL operator is therefore asserted
    directly so that the two can actually be told apart.
    """
    from sqlalchemy import Boolean, Column
    from sqlalchemy.sql import operators
    from sqlalchemy.sql.elements import BinaryExpression

    from superset.db_engine_specs.base import BaseEngineSpec

    bool_col = Column("test_col", Boolean)

    for op, value in (("IS TRUE", True), ("IS FALSE", False)):
        result = BaseEngineSpec.handle_boolean_filter(bool_col, op, value)
        assert isinstance(result, BinaryExpression), op
        # Default behavior must be ``col IS <bool>``, not ``col = <bool>``.
        assert result.operator is operators.is_, op
def test_handle_boolean_filter_with_equality() -> None:
    """
    Test that BaseEngineSpec can use equality operators when configured.

    A subclass sets ``use_equality_for_boolean_filters = True``. The resulting
    expressions must use ``=`` rather than ``IS``.
    """
    from sqlalchemy import Boolean, Column
    from sqlalchemy.sql import operators
    from sqlalchemy.sql.elements import BinaryExpression

    from superset.db_engine_specs.base import BaseEngineSpec

    class EqualityEngineSpec(BaseEngineSpec):
        use_equality_for_boolean_filters = True

    bool_col = Column("test_col", Boolean)

    for op, value in (("IS TRUE", True), ("IS FALSE", False)):
        result = EqualityEngineSpec.handle_boolean_filter(bool_col, op, value)
        assert isinstance(result, BinaryExpression), op
        # With the flag enabled the comparison must be ``col = <bool>``.
        assert result.operator is operators.eq, op
def test_handle_null_filter() -> None:
    """
    Test null/not null filter handling.

    ``IS_NULL`` must produce an ``IS`` comparison and ``IS_NOT_NULL`` an
    ``IS NOT`` comparison. Any other operator must raise ``ValueError``.
    """
    from sqlalchemy import Boolean, Column
    from sqlalchemy.sql import operators
    from sqlalchemy.sql.elements import BinaryExpression

    from superset.db_engine_specs.base import BaseEngineSpec
    from superset.utils.core import FilterOperator

    bool_col = Column("test_col", Boolean)

    result_null = BaseEngineSpec.handle_null_filter(bool_col, FilterOperator.IS_NULL)
    assert isinstance(result_null, BinaryExpression)
    assert result_null.operator is operators.is_

    result_not_null = BaseEngineSpec.handle_null_filter(
        bool_col, FilterOperator.IS_NOT_NULL
    )
    assert isinstance(result_not_null, BinaryExpression)
    assert result_not_null.operator is operators.is_not

    # Test invalid operator
    with pytest.raises(ValueError, match="Invalid null filter operator"):
        BaseEngineSpec.handle_null_filter(bool_col, "INVALID")  # type: ignore[arg-type]
def test_handle_comparison_filter() -> None:
    """
    Test comparison filter handling for all operators.

    Each ``FilterOperator`` must map to the matching SQL comparison operator.
    Unsupported operators must raise ``ValueError``.
    """
    from sqlalchemy import Column, Integer
    from sqlalchemy.sql import operators
    from sqlalchemy.sql.elements import BinaryExpression

    from superset.db_engine_specs.base import BaseEngineSpec
    from superset.utils.core import FilterOperator

    int_col = Column("test_col", Integer)
    value = 5

    expected_operators = [
        (FilterOperator.EQUALS, operators.eq),
        (FilterOperator.NOT_EQUALS, operators.ne),
        (FilterOperator.GREATER_THAN, operators.gt),
        (FilterOperator.LESS_THAN, operators.lt),
        (FilterOperator.GREATER_THAN_OR_EQUALS, operators.ge),
        (FilterOperator.LESS_THAN_OR_EQUALS, operators.le),
    ]
    for op, sql_operator in expected_operators:
        result = BaseEngineSpec.handle_comparison_filter(int_col, op, value)
        assert isinstance(result, BinaryExpression), op
        assert result.operator is sql_operator, op

    # Test invalid operator
    with pytest.raises(ValueError, match="Invalid comparison filter operator"):
        BaseEngineSpec.handle_comparison_filter(int_col, "INVALID", 5)  # type: ignore[arg-type]
def test_use_equality_for_boolean_filters_property() -> None:
    """
    Check that ``BaseEngineSpec`` does not opt into equality-based boolean
    filters by default, i.e. IS operators remain the default.
    """
    from superset.db_engine_specs.base import BaseEngineSpec

    default_flag = BaseEngineSpec.use_equality_for_boolean_filters

    assert default_flag is False