# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
from __future__ import annotations
import copy
from collections import namedtuple
from datetime import datetime
from typing import Any, Optional
from unittest.mock import MagicMock, Mock, patch
import pandas as pd
import pytest
from flask import g, has_app_context
from pytest_mock import MockerFixture
from requests.exceptions import ConnectionError as RequestsConnectionError
from sqlalchemy import column, sql, text, types
from sqlalchemy.dialects import sqlite
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import NoSuchTableError
from trino.exceptions import TrinoExternalError, TrinoInternalError, TrinoUserError
from trino.sqlalchemy import datatype
from trino.sqlalchemy.dialect import TrinoDialect
import superset.config
from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY
from superset.db_engine_specs.exceptions import (
SupersetDBAPIConnectionError,
SupersetDBAPIDatabaseError,
SupersetDBAPIOperationalError,
SupersetDBAPIProgrammingError,
)
from superset.sql.parse import Table
from superset.superset_typing import (
OAuth2ClientConfig,
ResultSetColumnType,
SQLAColumnType,
SQLType,
)
from superset.utils import json
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
assert_column_spec,
assert_convert_dttm,
)
from tests.unit_tests.fixtures.common import dttm # noqa: F401
def _assert_columns_equal(actual_cols, expected_cols) -> None:
"""
    Assert equality of the given cols, bearing in mind that SQLAlchemy type
    instances can't be compared for equality, so they are converted to
    strings first.
"""
actual = copy.deepcopy(actual_cols)
expected = copy.deepcopy(expected_cols)
for col in actual:
col["type"] = str(col["type"])
for col in expected:
col["type"] = str(col["type"])
assert actual == expected
@pytest.mark.parametrize(
"extra,expected",
[
({}, {"engine_params": {"connect_args": {"source": "Apache Superset"}}}),
(
{
"first": 1,
"engine_params": {
"second": "two",
"connect_args": {"source": "foobar", "third": "three"},
},
},
{
"first": 1,
"engine_params": {
"second": "two",
"connect_args": {"source": "foobar", "third": "three"},
},
},
),
],
)
def test_get_extra_params(extra: dict[str, Any], expected: dict[str, Any]) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
database = Mock()
database.extra = json.dumps(extra)
database.server_cert = None
assert TrinoEngineSpec.get_extra_params(database) == expected
@patch("superset.db_engine_specs.trino.create_ssl_cert_file")
def test_get_extra_params_with_server_cert(mock_create_ssl_cert_file: Mock) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
database = Mock()
database.extra = json.dumps({})
database.server_cert = "TEST_CERT"
database.db_engine_spec = TrinoEngineSpec
mock_create_ssl_cert_file.return_value = "/path/to/tls.crt"
extra = TrinoEngineSpec.get_extra_params(database)
connect_args = extra.get("engine_params", {}).get("connect_args", {})
assert connect_args.get("http_scheme") == "https"
assert connect_args.get("verify") == "/path/to/tls.crt"
mock_create_ssl_cert_file.assert_called_once_with(database.server_cert)
@patch("trino.auth.BasicAuthentication")
def test_auth_basic(mock_auth: Mock) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
database = Mock()
auth_params = {"username": "username", "password": "password"}
database.encrypted_extra = json.dumps(
{"auth_method": "basic", "auth_params": auth_params}
)
params: dict[str, Any] = {}
TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
connect_args = params.setdefault("connect_args", {})
assert connect_args.get("http_scheme") == "https"
mock_auth.assert_called_once_with(**auth_params)
@patch("trino.auth.KerberosAuthentication")
def test_auth_kerberos(mock_auth: Mock) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
database = Mock()
auth_params = {
"service_name": "superset",
"mutual_authentication": False,
"delegate": True,
}
database.encrypted_extra = json.dumps(
{"auth_method": "kerberos", "auth_params": auth_params}
)
params: dict[str, Any] = {}
TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
connect_args = params.setdefault("connect_args", {})
assert connect_args.get("http_scheme") == "https"
mock_auth.assert_called_once_with(**auth_params)
@patch("trino.auth.CertificateAuthentication")
def test_auth_certificate(mock_auth: Mock) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
database = Mock()
auth_params = {"cert": "/path/to/cert.pem", "key": "/path/to/key.pem"}
database.encrypted_extra = json.dumps(
{"auth_method": "certificate", "auth_params": auth_params}
)
params: dict[str, Any] = {}
TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
connect_args = params.setdefault("connect_args", {})
assert connect_args.get("http_scheme") == "https"
mock_auth.assert_called_once_with(**auth_params)
@patch("trino.auth.JWTAuthentication")
def test_auth_jwt(mock_auth: Mock) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
database = Mock()
auth_params = {"token": "jwt-token-string"}
database.encrypted_extra = json.dumps(
{"auth_method": "jwt", "auth_params": auth_params}
)
params: dict[str, Any] = {}
TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
connect_args = params.setdefault("connect_args", {})
assert connect_args.get("http_scheme") == "https"
mock_auth.assert_called_once_with(**auth_params)
def test_auth_custom_auth() -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
database = Mock()
auth_class = Mock()
auth_method = "custom_auth"
auth_params = {"params1": "params1", "params2": "params2"}
database.encrypted_extra = json.dumps(
{"auth_method": auth_method, "auth_params": auth_params}
)
with patch.dict(
"superset.config.ALLOWED_EXTRA_AUTHENTICATIONS",
{"trino": {"custom_auth": auth_class}},
clear=True,
):
params: dict[str, Any] = {}
TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
connect_args = params.setdefault("connect_args", {})
assert connect_args.get("http_scheme") == "https"
auth_class.assert_called_once_with(**auth_params)
def test_auth_custom_auth_denied() -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
database = Mock()
auth_method = "my.module:TrinoAuthClass"
auth_params = {"params1": "params1", "params2": "params2"}
database.encrypted_extra = json.dumps(
{"auth_method": auth_method, "auth_params": auth_params}
)
superset.config.ALLOWED_EXTRA_AUTHENTICATIONS = {}
with pytest.raises(ValueError) as excinfo: # noqa: PT011
TrinoEngineSpec.update_params_from_encrypted_extra(database, {})
assert str(excinfo.value) == (
f"For security reason, custom authentication '{auth_method}' "
f"must be listed in 'ALLOWED_EXTRA_AUTHENTICATIONS' config"
)
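# The auth tests above all exercise the same mechanism: ``encrypted_extra`` carries a
# JSON payload of the form ``{"auth_method": ..., "auth_params": {...}}``, and
# ``update_params_from_encrypted_extra`` turns it into a ``trino.auth`` object and sets
# ``http_scheme: https`` in ``connect_args``. A minimal sketch of such a payload
# (the values are illustrative, not taken from any real deployment):
#
#     encrypted_extra = json.dumps(
#         {
#             "auth_method": "basic",
#             "auth_params": {"username": "alice", "password": "secret"},
#         }
#     )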
@pytest.mark.parametrize(
"native_type,sqla_type,attrs,generic_type,is_dttm",
[
("BOOLEAN", types.Boolean, None, GenericDataType.BOOLEAN, False),
("TINYINT", types.Integer, None, GenericDataType.NUMERIC, False),
("SMALLINT", types.SmallInteger, None, GenericDataType.NUMERIC, False),
("INTEGER", types.Integer, None, GenericDataType.NUMERIC, False),
("BIGINT", types.BigInteger, None, GenericDataType.NUMERIC, False),
("REAL", types.FLOAT, None, GenericDataType.NUMERIC, False),
("DOUBLE", types.FLOAT, None, GenericDataType.NUMERIC, False),
("DECIMAL", types.DECIMAL, None, GenericDataType.NUMERIC, False),
("VARCHAR", types.String, None, GenericDataType.STRING, False),
("VARCHAR(20)", types.VARCHAR, {"length": 20}, GenericDataType.STRING, False),
("CHAR", types.String, None, GenericDataType.STRING, False),
("CHAR(2)", types.CHAR, {"length": 2}, GenericDataType.STRING, False),
("JSON", types.JSON, None, GenericDataType.STRING, False),
("TIMESTAMP", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True),
("TIMESTAMP(3)", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True),
(
"TIMESTAMP WITH TIME ZONE",
types.TIMESTAMP,
None,
GenericDataType.TEMPORAL,
True,
),
(
"TIMESTAMP(3) WITH TIME ZONE",
types.TIMESTAMP,
None,
GenericDataType.TEMPORAL,
True,
),
("DATE", types.Date, None, GenericDataType.TEMPORAL, True),
],
)
def test_get_column_spec(
native_type: str,
sqla_type: type[types.TypeEngine],
attrs: Optional[dict[str, Any]],
generic_type: GenericDataType,
is_dttm: bool,
) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec as spec # noqa: N813
assert_column_spec(
spec,
native_type,
sqla_type,
attrs,
generic_type,
is_dttm,
)
@pytest.mark.parametrize(
"target_type,expected_result",
[
("TimeStamp", "TIMESTAMP '2019-01-02 03:04:05.678900'"),
("TimeStamp(3)", "TIMESTAMP '2019-01-02 03:04:05.678900'"),
("TimeStamp With Time Zone", "TIMESTAMP '2019-01-02 03:04:05.678900'"),
("TimeStamp(3) With Time Zone", "TIMESTAMP '2019-01-02 03:04:05.678900'"),
("Date", "DATE '2019-01-02'"),
("Other", None),
],
)
def test_convert_dttm(
target_type: str,
expected_result: Optional[str],
dttm: datetime, # noqa: F811
) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
assert_convert_dttm(TrinoEngineSpec, target_type, expected_result, dttm)
def test_get_extra_table_metadata(mocker: MockerFixture) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
db_mock = mocker.MagicMock()
db_mock.get_indexes = Mock(
return_value=[{"column_names": ["ds", "hour"], "name": "partition"}]
)
db_mock.get_extra = Mock(return_value={})
db_mock.has_view = Mock(return_value=None)
db_mock.get_df = Mock(return_value=pd.DataFrame({"ds": ["01-01-19"], "hour": [1]}))
result = TrinoEngineSpec.get_extra_table_metadata(
db_mock,
Table("test_table", "test_schema"),
)
assert result["partitions"]["cols"] == ["ds", "hour"]
assert result["partitions"]["latest"] == {"ds": "01-01-19", "hour": 1}
@patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query_success(engine_mock: Mock) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
from superset.models.sql_lab import Query
query = Query()
cursor_mock = engine_mock.return_value.__enter__.return_value
assert TrinoEngineSpec.cancel_query(cursor_mock, query, "123") is True
@patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query_failed(engine_mock: Mock) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
from superset.models.sql_lab import Query
query = Query()
    # Assigning a bare Exception instance as the cursor makes any cursor call
    # raise, so cancel_query is expected to swallow the error and return False.
    cursor_mock = engine_mock.raiseError.side_effect = Exception()
assert TrinoEngineSpec.cancel_query(cursor_mock, query, "123") is False
@pytest.mark.parametrize(
"initial_extra,final_extra",
[
({}, {QUERY_EARLY_CANCEL_KEY: True}),
({QUERY_CANCEL_KEY: "my_key"}, {QUERY_CANCEL_KEY: "my_key"}),
],
)
def test_prepare_cancel_query(
initial_extra: dict[str, Any],
final_extra: dict[str, Any],
mocker: MockerFixture,
) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
from superset.models.sql_lab import Query
query = Query(extra_json=json.dumps(initial_extra))
TrinoEngineSpec.prepare_cancel_query(query=query)
assert query.extra == final_extra
@pytest.mark.parametrize("cancel_early", [True, False])
@patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query")
@patch("sqlalchemy.engine.Engine.connect")
def test_handle_cursor_early_cancel(
engine_mock: Mock,
cancel_query_mock: Mock,
cancel_early: bool,
mocker: MockerFixture,
) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
from superset.models.sql_lab import Query
query_id = "myQueryId"
cursor_mock = engine_mock.return_value.__enter__.return_value
cursor_mock.query_id = query_id
query = Query()
if cancel_early:
TrinoEngineSpec.prepare_cancel_query(query=query)
TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query)
if cancel_early:
assert cancel_query_mock.call_args[1]["cancel_query_id"] == query_id
else:
assert cancel_query_mock.call_args is None
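# The two tests above cover the query-cancellation handshake: ``prepare_cancel_query``
# marks a query with ``QUERY_EARLY_CANCEL_KEY`` when no Trino query ID is known yet,
# and ``handle_cursor`` later reads the cursor's ``query_id`` and passes it to
# ``cancel_query``. A hedged sketch of the calling order (names are illustrative):
#
#     TrinoEngineSpec.prepare_cancel_query(query=query)          # stop requested early
#     TrinoEngineSpec.handle_cursor(cursor=cursor, query=query)  # cancels via cursor.query_id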
def test_execute_with_cursor_in_parallel(app, mocker: MockerFixture):
"""Test that `execute_with_cursor` fetches query ID from the cursor"""
from superset.db_engine_specs.trino import TrinoEngineSpec
query_id = "myQueryId"
mock_cursor = mocker.MagicMock()
mock_cursor.query_id = None
mock_query = mocker.MagicMock()
def _mock_execute(*args, **kwargs):
mock_cursor.query_id = query_id
with app.test_request_context("/some/place/"):
mock_cursor.execute.side_effect = _mock_execute
with patch.dict(
"superset.config.DISALLOWED_SQL_FUNCTIONS",
{},
clear=True,
):
TrinoEngineSpec.execute_with_cursor(
cursor=mock_cursor,
sql="SELECT 1 FROM foo",
query=mock_query,
)
mock_query.set_extra_json_key.assert_called_once_with(
key=QUERY_CANCEL_KEY, value=query_id
)
def test_execute_with_cursor_app_context(app, mocker: MockerFixture):
"""Test that `execute_with_cursor` still contains the current app context"""
from superset.db_engine_specs.trino import TrinoEngineSpec
mock_cursor = mocker.MagicMock()
mock_cursor.query_id = None
mock_query = mocker.MagicMock()
def _mock_execute(*args, **kwargs):
assert has_app_context()
assert g.some_value == "some_value"
with app.test_request_context("/some/place/"):
g.some_value = "some_value"
with patch.object(TrinoEngineSpec, "execute", side_effect=_mock_execute):
with patch.dict(
"superset.config.DISALLOWED_SQL_FUNCTIONS",
{},
clear=True,
):
TrinoEngineSpec.execute_with_cursor(
cursor=mock_cursor,
sql="SELECT 1 FROM foo",
query=mock_query,
)
def test_get_columns(mocker: MockerFixture):
"""Test that ROW columns are not expanded without expand_rows"""
from superset.db_engine_specs.trino import TrinoEngineSpec
field1_type = datatype.parse_sqltype("row(a varchar, b date)")
field2_type = datatype.parse_sqltype("row(r1 row(a varchar, b varchar))")
field3_type = datatype.parse_sqltype("int")
sqla_columns = [
SQLAColumnType(name="field1", type=field1_type, is_dttm=False),
SQLAColumnType(name="field2", type=field2_type, is_dttm=False),
SQLAColumnType(name="field3", type=field3_type, is_dttm=False),
]
mock_inspector = mocker.MagicMock()
mock_inspector.get_columns.return_value = sqla_columns
actual = TrinoEngineSpec.get_columns(mock_inspector, Table("table", "schema"))
expected = [
ResultSetColumnType(
name="field1", column_name="field1", type=field1_type, is_dttm=False
),
ResultSetColumnType(
name="field2", column_name="field2", type=field2_type, is_dttm=False
),
ResultSetColumnType(
name="field3", column_name="field3", type=field3_type, is_dttm=False
),
]
_assert_columns_equal(actual, expected)
def test_get_columns_error(mocker: MockerFixture):
"""
    Test that we fall back to a `SHOW COLUMNS FROM ...` query.
"""
from superset.db_engine_specs.trino import TrinoEngineSpec
field1_type = datatype.parse_sqltype("row(a varchar, b date)")
field2_type = datatype.parse_sqltype("row(r1 row(a varchar, b varchar))")
field3_type = datatype.parse_sqltype("int")
mock_inspector = mocker.MagicMock()
mock_inspector.engine.dialect = sqlite.dialect()
mock_inspector.get_columns.side_effect = NoSuchTableError(
"The specified table does not exist."
)
Row = namedtuple("Row", ["Column", "Type"])
mock_inspector.bind.execute().fetchall.return_value = [
Row("field1", "row(a varchar, b date)"),
Row("field2", "row(r1 row(a varchar, b varchar))"),
Row("field3", "int"),
]
actual = TrinoEngineSpec.get_columns(mock_inspector, Table("table", "schema"))
expected = [
ResultSetColumnType(
name="field1",
column_name="field1",
type=field1_type,
is_dttm=None,
type_generic=None,
default=None,
nullable=True,
),
ResultSetColumnType(
name="field2",
column_name="field2",
type=field2_type,
is_dttm=None,
type_generic=None,
default=None,
nullable=True,
),
ResultSetColumnType(
name="field3",
column_name="field3",
type=field3_type,
is_dttm=None,
type_generic=None,
default=None,
nullable=True,
),
]
_assert_columns_equal(actual, expected)
mock_inspector.bind.execute.assert_called_with('SHOW COLUMNS FROM schema."table"')
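# When the inspector raises ``NoSuchTableError``, the spec falls back to issuing the
# raw query asserted above, quoting the table name but not the schema:
#
#     SHOW COLUMNS FROM schema."table"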
def test_get_columns_expand_rows(mocker: MockerFixture):
"""Test that ROW columns are correctly expanded with expand_rows"""
from superset.db_engine_specs.trino import TrinoEngineSpec
field1_type = datatype.parse_sqltype("row(a varchar, b date)")
field2_type = datatype.parse_sqltype("row(r1 row(a varchar, b varchar))")
field3_type = datatype.parse_sqltype("int")
sqla_columns = [
SQLAColumnType(name="field1", type=field1_type, is_dttm=False),
SQLAColumnType(name="field2", type=field2_type, is_dttm=False),
SQLAColumnType(name="field3", type=field3_type, is_dttm=False),
]
mock_inspector = mocker.MagicMock()
mock_inspector.get_columns.return_value = sqla_columns
actual = TrinoEngineSpec.get_columns(
mock_inspector,
Table("table", "schema"),
{"expand_rows": True},
)
expected = [
ResultSetColumnType(
name="field1", column_name="field1", type=field1_type, is_dttm=False
),
ResultSetColumnType(
name="field1.a",
column_name="field1.a",
type=types.VARCHAR(),
is_dttm=False,
query_as='"field1"."a" AS "field1.a"',
),
ResultSetColumnType(
name="field1.b",
column_name="field1.b",
type=types.DATE(),
is_dttm=True,
query_as='"field1"."b" AS "field1.b"',
),
ResultSetColumnType(
name="field2", column_name="field2", type=field2_type, is_dttm=False
),
ResultSetColumnType(
name="field2.r1",
column_name="field2.r1",
type=datatype.parse_sqltype("row(a varchar, b varchar)"),
is_dttm=False,
query_as='"field2"."r1" AS "field2.r1"',
),
ResultSetColumnType(
name="field2.r1.a",
column_name="field2.r1.a",
type=types.VARCHAR(),
is_dttm=False,
query_as='"field2"."r1"."a" AS "field2.r1.a"',
),
ResultSetColumnType(
name="field2.r1.b",
column_name="field2.r1.b",
type=types.VARCHAR(),
is_dttm=False,
query_as='"field2"."r1"."b" AS "field2.r1.b"',
),
ResultSetColumnType(
name="field3", column_name="field3", type=field3_type, is_dttm=False
),
]
_assert_columns_equal(actual, expected)
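# With ``expand_rows`` enabled, nested ROW fields are flattened into extra pseudo-columns
# whose names join the path with dots and whose ``query_as`` quotes each path segment,
# e.g. a column ``field1 row(a varchar, b date)`` yields (per the expectations above):
#
#     {"column_name": "field1.a", "query_as": '"field1"."a" AS "field1.a"'}
#     {"column_name": "field1.b", "query_as": '"field1"."b" AS "field1.b"'}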
def test_get_indexes_no_table():
from superset.db_engine_specs.trino import TrinoEngineSpec
db_mock = Mock()
inspector_mock = Mock()
inspector_mock.get_indexes = Mock(
side_effect=NoSuchTableError("The specified table does not exist.")
)
result = TrinoEngineSpec.get_indexes(
db_mock,
inspector_mock,
Table("test_table", "test_schema"),
)
assert result == []
def test_get_dbapi_exception_mapping():
from superset.db_engine_specs.trino import TrinoEngineSpec
mapping = TrinoEngineSpec.get_dbapi_exception_mapping()
assert mapping.get(TrinoUserError) == SupersetDBAPIProgrammingError
assert mapping.get(TrinoInternalError) == SupersetDBAPIDatabaseError
assert mapping.get(TrinoExternalError) == SupersetDBAPIOperationalError
assert mapping.get(RequestsConnectionError) == SupersetDBAPIConnectionError
assert mapping.get(Exception) is None
def test_adjust_engine_params_fully_qualified() -> None:
"""
Test the ``adjust_engine_params`` method when the URL has catalog and schema.
"""
from superset.db_engine_specs.trino import TrinoEngineSpec
url = make_url("trino://user:pass@localhost:8080/system/default")
uri = TrinoEngineSpec.adjust_engine_params(url, {})[0]
assert str(uri) == "trino://user:pass@localhost:8080/system/default"
uri = TrinoEngineSpec.adjust_engine_params(
url,
{},
schema="new_schema",
)[0]
assert str(uri) == "trino://user:pass@localhost:8080/system/new_schema"
uri = TrinoEngineSpec.adjust_engine_params(
url,
{},
catalog="new_catalog",
)[0]
assert str(uri) == "trino://user:pass@localhost:8080/new_catalog/default"
uri = TrinoEngineSpec.adjust_engine_params(
url,
{},
catalog="new_catalog",
schema="new_schema",
)[0]
assert str(uri) == "trino://user:pass@localhost:8080/new_catalog/new_schema"
def test_adjust_engine_params_catalog_only() -> None:
"""
Test the ``adjust_engine_params`` method when the URL has only the catalog.
"""
from superset.db_engine_specs.trino import TrinoEngineSpec
url = make_url("trino://user:pass@localhost:8080/system")
uri = TrinoEngineSpec.adjust_engine_params(url, {})[0]
assert str(uri) == "trino://user:pass@localhost:8080/system"
uri = TrinoEngineSpec.adjust_engine_params(
url,
{},
schema="new_schema",
)[0]
assert str(uri) == "trino://user:pass@localhost:8080/system/new_schema"
uri = TrinoEngineSpec.adjust_engine_params(
url,
{},
catalog="new_catalog",
)[0]
assert str(uri) == "trino://user:pass@localhost:8080/new_catalog"
uri = TrinoEngineSpec.adjust_engine_params(
url,
{},
catalog="new_catalog",
schema="new_schema",
)[0]
assert str(uri) == "trino://user:pass@localhost:8080/new_catalog/new_schema"
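# ``adjust_engine_params`` only rewrites the catalog/schema portion of the SQLAlchemy URL:
# the database path is ``/<catalog>[/<schema>]``, so overriding the schema keeps the
# existing catalog, while overriding the catalog keeps whatever schema (if any) the URL
# already carried, as the assertions above show. For example (illustrative only):
#
#     url = make_url("trino://user:pass@localhost:8080/system/default")
#     TrinoEngineSpec.adjust_engine_params(url, {}, schema="tmp")[0]
#     # -> trino://user:pass@localhost:8080/system/tmp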
@pytest.mark.parametrize(
"sqlalchemy_uri,result",
[
("trino://user:pass@localhost:8080/system", "system"),
("trino://user:pass@localhost:8080/system/default", "system"),
("trino://trino@localhost:8081", None),
],
)
def test_get_default_catalog(sqlalchemy_uri: str, result: str | None) -> None:
"""
Test the ``get_default_catalog`` method.
"""
from superset.db_engine_specs.trino import TrinoEngineSpec
from superset.models.core import Database
database = Database(
database_name="my_db",
sqlalchemy_uri=sqlalchemy_uri,
)
assert TrinoEngineSpec.get_default_catalog(database) == result
@patch("superset.db_engine_specs.trino.TrinoEngineSpec.latest_partition")
@pytest.mark.parametrize(
["column_type", "column_value", "expected_value"],
[
(types.DATE(), "2023-05-01", "DATE '2023-05-01'"),
(types.TIMESTAMP(), "2023-05-01", "TIMESTAMP '2023-05-01'"),
(types.VARCHAR(), "2023-05-01", "'2023-05-01'"),
(types.INT(), 1234, "1234"),
],
)
def test_where_latest_partition(
mock_latest_partition,
column_type: SQLType,
column_value: Any,
expected_value: str,
) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
mock_latest_partition.return_value = (["partition_key"], [column_value])
assert (
str(
TrinoEngineSpec.where_latest_partition( # type: ignore
database=MagicMock(),
table=Table("table"),
query=sql.select(text("* FROM table")),
columns=[
{
"column_name": "partition_key",
"name": "partition_key",
"type": column_type,
"is_dttm": False,
}
],
).compile(
dialect=TrinoDialect(),
compile_kwargs={"literal_binds": True},
)
)
== f"""SELECT * FROM table \nWHERE partition_key = {expected_value}""" # noqa: S608
)
@pytest.fixture
def oauth2_config() -> OAuth2ClientConfig:
"""
Config for Trino OAuth2.
"""
return {
"id": "trino",
"secret": "very-secret",
"scope": "",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"authorization_request_uri": "https://trino.auth.server.example/realms/master/protocol/openid-connect/auth",
"token_request_uri": "https://trino.auth.server.example/master/protocol/openid-connect/token",
"request_content_type": "data",
}
def test_get_oauth2_token(
mocker: MockerFixture,
oauth2_config: OAuth2ClientConfig,
) -> None:
"""
Test `get_oauth2_token`.
"""
from superset.db_engine_specs.trino import TrinoEngineSpec
requests = mocker.patch("superset.db_engine_specs.base.requests")
requests.post().json.return_value = {
"access_token": "access-token",
"expires_in": 3600,
"scope": "scope",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
assert TrinoEngineSpec.get_oauth2_token(oauth2_config, "code") == {
"access_token": "access-token",
"expires_in": 3600,
"scope": "scope",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
requests.post.assert_called_with(
"https://trino.auth.server.example/master/protocol/openid-connect/token",
data={
"code": "code",
"client_id": "trino",
"client_secret": "very-secret",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"grant_type": "authorization_code",
},
timeout=30.0,
)
@pytest.mark.parametrize(
"time_grain,expected_result",
[
("PT1S", "date_trunc('second', CAST(col AS TIMESTAMP))"),
(
"PT5S",
"date_trunc('second', CAST(col AS TIMESTAMP)) - interval '1' second * (second(CAST(col AS TIMESTAMP)) % 5)", # noqa: E501
),
(
"PT30S",
"date_trunc('second', CAST(col AS TIMESTAMP)) - interval '1' second * (second(CAST(col AS TIMESTAMP)) % 30)", # noqa: E501
),
("PT1M", "date_trunc('minute', CAST(col AS TIMESTAMP))"),
(
"PT5M",
"date_trunc('minute', CAST(col AS TIMESTAMP)) - interval '1' minute * (minute(CAST(col AS TIMESTAMP)) % 5)", # noqa: E501
),
(
"PT10M",
"date_trunc('minute', CAST(col AS TIMESTAMP)) - interval '1' minute * (minute(CAST(col AS TIMESTAMP)) % 10)", # noqa: E501
),
(
"PT15M",
"date_trunc('minute', CAST(col AS TIMESTAMP)) - interval '1' minute * (minute(CAST(col AS TIMESTAMP)) % 15)", # noqa: E501
),
(
"PT0.5H",
"date_trunc('minute', CAST(col AS TIMESTAMP)) - interval '1' minute * (minute(CAST(col AS TIMESTAMP)) % 30)", # noqa: E501
),
("PT1H", "date_trunc('hour', CAST(col AS TIMESTAMP))"),
(
"PT6H",
"date_trunc('hour', CAST(col AS TIMESTAMP)) - interval '1' hour * (hour(CAST(col AS TIMESTAMP)) % 6)", # noqa: E501
),
("P1D", "date_trunc('day', CAST(col AS TIMESTAMP))"),
("P1W", "date_trunc('week', CAST(col AS TIMESTAMP))"),
("P1M", "date_trunc('month', CAST(col AS TIMESTAMP))"),
("P3M", "date_trunc('quarter', CAST(col AS TIMESTAMP))"),
("P1Y", "date_trunc('year', CAST(col AS TIMESTAMP))"),
(
"1969-12-28T00:00:00Z/P1W",
"date_trunc('week', CAST(col AS TIMESTAMP) + interval '1' day) - interval '1' day", # noqa: E501
),
("1969-12-29T00:00:00Z/P1W", "date_trunc('week', CAST(col AS TIMESTAMP))"),
(
"P1W/1970-01-03T00:00:00Z",
"date_trunc('week', CAST(col AS TIMESTAMP) + interval '1' day) + interval '5' day", # noqa: E501
),
(
"P1W/1970-01-04T00:00:00Z",
"date_trunc('week', CAST(col AS TIMESTAMP)) + interval '6' day",
),
],
)
def test_timegrain_expressions(time_grain: str, expected_result: str) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec as spec # noqa: N813
actual = str(
spec.get_timestamp_expr(col=column("col"), pdf=None, time_grain=time_grain)
)
assert actual == expected_result
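# Note on the last four parametrized grains: the ISO-interval grains anchored to dates in
# late 1969 / early 1970 are Superset's week variants (week starting Sunday vs. Monday,
# week ending Saturday vs. Sunday, respectively, as far as the convention goes), which is
# why the expected SQL shifts the ``date_trunc('week', ...)`` result by one or more days.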