# blob: 960a8d1c46a49f3ce7d19846483fa912bf9bc083 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import datetime
from unittest import mock
import pytest
from superset.db_engine_specs import engines
from superset.db_engine_specs.base import (
BaseEngineSpec,
builtin_time_grains,
LimitMethod,
)
from superset.db_engine_specs.sqlite import SqliteEngineSpec
from superset.sql_parse import ParsedQuery
from superset.utils.core import get_example_database
from tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.test_app import app
from ..fixtures.energy_dashboard import load_energy_table_with_slice
from ..fixtures.pyodbcRow import Row
class TestDbEngineSpecs(TestDbEngineSpec):
    """Tests for the generic behaviour exposed by ``BaseEngineSpec``.

    Most LIMIT-rewriting tests delegate to ``self.sql_limit_regex`` (inherited
    from ``TestDbEngineSpec``), passing the input SQL and the SQL expected
    after a limit of 1000 has been applied.
    """

    def test_extract_limit_from_query(self, engine_spec_class=BaseEngineSpec):
        """Check ``get_limit_from_sql`` on a table of representative queries.

        The expected value is the integer row count of the outermost LIMIT
        clause, or ``None`` when there is no such clause or its row count is
        not an integer literal.
        """
        cases = [
            ("select * from table", None),
            ("select * from mytable limit 10", 10),
            (
                "select * from (select * from my_subquery limit 10) "
                "where col=1 limit 20",
                20,
            ),
            # A LIMIT that only appears inside a subquery is not reported.
            ("select * from (select * from my_subquery limit 10);", None),
            (
                "select * from (select * from my_subquery limit 10) "
                "where col=1 limit 20;",
                20,
            ),
            # "LIMIT <offset>, <count>": the count is the second number.
            ("select * from mytable limit 20, 10", 10),
            ("select * from mytable limit 10 offset 20", 10),
            # Incomplete or non-integer limits yield None.
            ("select * from mytable limit", None),
            ("select * from mytable limit 10.0", None),
            ("select * from mytable limit x", None),
            ("select * from mytable limit 20, x", None),
            ("select * from mytable limit x offset 20", None),
        ]
        for sql, expected in cases:
            # The SQL is passed as the message so a failure names the query.
            self.assertEqual(engine_spec_class.get_limit_from_sql(sql), expected, sql)

    def test_wrapped_semi_tabs(self):
        """Trailing whitespace and a semicolon are dropped before limiting."""
        self.sql_limit_regex(
            "SELECT * FROM a \t \n ; \t \n ", "SELECT * FROM a\nLIMIT 1000"
        )

    def test_simple_limit_query(self):
        """A query without a LIMIT gets one appended on a new line."""
        self.sql_limit_regex("SELECT * FROM a", "SELECT * FROM a\nLIMIT 1000")

    def test_modify_limit_query(self):
        """An existing larger LIMIT is replaced in place."""
        self.sql_limit_regex("SELECT * FROM a LIMIT 9999", "SELECT * FROM a LIMIT 1000")

    def test_limit_query_with_limit_subquery(self):  # pylint: disable=invalid-name
        """Only the outer LIMIT is rewritten; the subquery LIMIT is kept."""
        self.sql_limit_regex(
            "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999",
            "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 1000",
        )

    def test_limit_with_expr(self):
        """A 'LIMIT 777' string literal in the select list is left untouched."""
        self.sql_limit_regex(
            "\n"
            "SELECT\n"
            "'LIMIT 777' AS a\n"
            ", b\n"
            "FROM\n"
            "table\n"
            "LIMIT 99990",
            "SELECT\n"
            "'LIMIT 777' AS a\n"
            ", b\n"
            "FROM\n"
            "table\n"
            "LIMIT 1000",
        )

    def test_limit_expr_and_semicolon(self):
        """Same as ``test_limit_with_expr`` but with a trailing ' ;'."""
        self.sql_limit_regex(
            "\n"
            "SELECT\n"
            "'LIMIT 777' AS a\n"
            ", b\n"
            "FROM\n"
            "table\n"
            "LIMIT 99990 ;",
            "SELECT\n"
            "'LIMIT 777' AS a\n"
            ", b\n"
            "FROM\n"
            "table\n"
            "LIMIT 1000",
        )

    def test_get_datatype(self):
        """``get_datatype`` returns a string type name unchanged."""
        self.assertEqual("VARCHAR", BaseEngineSpec.get_datatype("VARCHAR"))

    def test_limit_with_implicit_offset(self):
        """With 'LIMIT <offset>, <count>' only the count is rewritten."""
        self.sql_limit_regex(
            "\n"
            "SELECT\n"
            "'LIMIT 777' AS a\n"
            ", b\n"
            "FROM\n"
            "table\n"
            "LIMIT 99990, 999999",
            "SELECT\n"
            "'LIMIT 777' AS a\n"
            ", b\n"
            "FROM\n"
            "table\n"
            "LIMIT 99990, 1000",
        )

    def test_limit_with_explicit_offset(self):
        """With 'LIMIT <count> OFFSET <n>' the OFFSET clause is preserved."""
        self.sql_limit_regex(
            "\n"
            "SELECT\n"
            "'LIMIT 777' AS a\n"
            ", b\n"
            "FROM\n"
            "table\n"
            "LIMIT 99990\n"
            "OFFSET 999999",
            "SELECT\n"
            "'LIMIT 777' AS a\n"
            ", b\n"
            "FROM\n"
            "table\n"
            "LIMIT 1000\n"
            "OFFSET 999999",
        )

    def test_limit_with_non_token_limit(self):
        """'LIMIT' inside a string literal does not count as a LIMIT clause."""
        self.sql_limit_regex("SELECT 'LIMIT 777'", "SELECT 'LIMIT 777'\nLIMIT 1000")

    def test_limit_with_fetch_many(self):
        """Specs using ``LimitMethod.FETCH_MANY`` leave the SQL unchanged."""

        class DummyEngineSpec(BaseEngineSpec):
            limit_method = LimitMethod.FETCH_MANY

        self.sql_limit_regex(
            "SELECT * FROM table", "SELECT * FROM table", DummyEngineSpec
        )

    def test_time_grain_denylist(self):
        """Grains listed in ``TIME_GRAIN_DENYLIST`` are not offered."""
        overrides = {"TIME_GRAIN_DENYLIST": ["PT1M"]}
        # patch.dict restores app.config on exit (even on failure), so the
        # override cannot leak into other tests.
        with app.app_context(), mock.patch.dict(app.config, overrides):
            time_grain_functions = SqliteEngineSpec.get_time_grain_expressions()
            self.assertNotIn("PT1M", time_grain_functions)

    def test_time_grain_addons(self):
        """Add-on grains from the config show up last in ``get_time_grains``."""
        overrides = {
            "TIME_GRAIN_ADDONS": {"PTXM": "x seconds"},
            "TIME_GRAIN_ADDON_EXPRESSIONS": {"sqlite": {"PTXM": "ABC({col})"}},
        }
        # patch.dict puts the previous values back even if an assertion fails.
        with app.app_context(), mock.patch.dict(app.config, overrides):
            time_grains = SqliteEngineSpec.get_time_grains()
            time_grain_addon = time_grains[-1]
            self.assertEqual("PTXM", time_grain_addon.duration)
            self.assertEqual("x seconds", time_grain_addon.label)

    def test_engine_time_grain_validity(self):
        """Every engine defines time grains, all of them built-in ones."""
        time_grains = set(builtin_time_grains.keys())
        # loop over all registered engine specs except the base class itself
        for engine in engines.values():
            if engine is BaseEngineSpec:
                continue
            # make sure time grain functions have been defined
            self.assertGreater(len(engine.get_time_grain_expressions()), 0)
            # make sure all defined time grains are supported
            defined_grains = {grain.duration for grain in engine.get_time_grains()}
            intersection = time_grains.intersection(defined_grains)
            self.assertSetEqual(defined_grains, intersection, engine)

    def test_get_table_names(self):
        """``get_table_names`` strips the schema prefix from table names."""
        inspector = mock.Mock()
        inspector.get_table_names = mock.Mock(return_value=["schema.table", "table_2"])
        inspector.get_foreign_table_names = mock.Mock(return_value=["table_3"])

        # Make sure base engine spec removes schema name from table name,
        # i.e. when try_remove_schema_from_table_name == True.
        base_result_expected = ["table", "table_2"]
        base_result = BaseEngineSpec.get_table_names(
            database=mock.ANY, schema="schema", inspector=inspector
        )
        self.assertListEqual(base_result_expected, base_result)

    @pytest.mark.usefixtures("load_energy_table_with_slice")
    def test_column_datatype_to_string(self):
        """Column types of ``energy_usage`` render as backend-specific strings."""
        example_db = get_example_database()
        sqla_table = example_db.get_table("energy_usage")
        dialect = example_db.get_dialect()

        # TODO: fix column type conversion for presto.
        if example_db.backend == "presto":
            return

        col_types = [
            example_db.db_engine_spec.column_datatype_to_string(c.type, dialect)
            for c in sqla_table.columns
        ]
        if example_db.backend == "postgresql":
            expected = ["VARCHAR(255)", "VARCHAR(255)", "DOUBLE PRECISION"]
        elif example_db.backend == "hive":
            expected = ["STRING", "STRING", "FLOAT"]
        else:
            expected = ["VARCHAR(255)", "VARCHAR(255)", "FLOAT"]
        self.assertEqual(col_types, expected)

    def test_convert_dttm(self):
        """``convert_dttm`` on the base spec returns None for an empty type."""
        dttm = self.get_dttm()
        self.assertIsNone(BaseEngineSpec.convert_dttm("", dttm))

    def test_pyodbc_rows_to_tuples(self):
        """``Row`` objects (as returned by the odbc driver) become tuples."""
        data = [
            Row((1, 1, datetime.datetime(2017, 10, 19, 23, 39, 16, 660000))),
            Row((2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000))),
        ]
        expected = [
            (1, 1, datetime.datetime(2017, 10, 19, 23, 39, 16, 660000)),
            (2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000)),
        ]
        result = BaseEngineSpec.pyodbc_rows_to_tuples(data)
        self.assertListEqual(result, expected)

    def test_pyodbc_rows_to_tuples_passthrough(self):
        """Rows that are already tuples are returned as an equal list."""
        data = [
            (1, 1, datetime.datetime(2017, 10, 19, 23, 39, 16, 660000)),
            (2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000)),
        ]
        result = BaseEngineSpec.pyodbc_rows_to_tuples(data)
        self.assertListEqual(result, data)
def test_is_readonly():
    """Check ``BaseEngineSpec.is_readonly_query`` on a table of statements.

    Each entry pairs a SQL statement with whether the base engine spec is
    expected to classify it as read-only.
    """
    expectations = (
        ("SHOW LOCKS test EXTENDED", True),
        ("SET hivevar:desc='Legislators'", False),
        ("UPDATE t1 SET col1 = NULL", False),
        ("EXPLAIN SELECT 1", True),
        ("SELECT 1", True),
        ("WITH (SELECT 1) bla SELECT * from bla", True),
        ("SHOW CATALOGS", True),
        ("SHOW TABLES", True),
    )
    for statement, should_be_readonly in expectations:
        verdict = BaseEngineSpec.is_readonly_query(ParsedQuery(statement))
        if should_be_readonly:
            assert verdict
        else:
            assert not verdict