| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # pylint: disable=invalid-name, redefined-outer-name, too-many-lines |
| |
| |
| import pytest |
| from pytest_mock import MockerFixture |
| from sqlglot import Dialects, exp, parse_one |
| |
| from superset.exceptions import QueryClauseValidationException, SupersetParseError |
| from superset.jinja_context import JinjaTemplateProcessor |
| from superset.sql.parse import ( |
| CTASMethod, |
| extract_tables_from_statement, |
| JinjaSQLResult, |
| KQLTokenType, |
| KustoKQLStatement, |
| LimitMethod, |
| process_jinja_sql, |
| remove_quotes, |
| RLSMethod, |
| sanitize_clause, |
| split_kql, |
| SQLGLOT_DIALECTS, |
| SQLScript, |
| SQLStatement, |
| Table, |
| tokenize_kql, |
| ) |
| from tests.integration_tests.conftest import with_feature_flags |
| |
| |
def test_table() -> None:
    """
    Check how `Table` instances are rendered as strings.

    The last case uses names containing `.`, `/` and a newline, which are
    expected to come out percent-escaped.
    """
    cases = [
        (Table("tbname"), "tbname"),
        (Table("tbname", "schemaname"), "schemaname.tbname"),
        (
            Table("tbname", "schemaname", "catalogname"),
            "catalogname.schemaname.tbname",
        ),
        (
            Table("table.name", "schema/name", "catalog\nname"),
            "catalog%0Aname.schema%2Fname.table%2Ename",
        ),
    ]
    for table, rendered in cases:
        assert str(table) == rendered
| |
| |
def extract_tables_from_sql(sql: str, engine: str = "postgresql") -> set[Table]:
    """
    Collect the tables referenced by every statement in `sql`.

    The script is split with `SQLScript` for the given engine and the tables found
    in each parsed statement are merged into a single set.
    """
    dialect = SQLGLOT_DIALECTS.get(engine)
    found: set[Table] = set()
    for statement in SQLScript(sql, engine).statements:
        found.update(extract_tables_from_statement(statement._parsed, dialect))
    return found
| |
| |
def test_extract_tables_from_sql() -> None:
    """
    Check that simple queries yield the tables they reference.
    """
    # plain name, implicit alias, explicit alias, quoted name
    for sql in (
        "SELECT * FROM tbname",
        "SELECT * FROM tbname foo",
        "SELECT * FROM tbname AS foo",
        'SELECT * FROM "tbname"',
    ):
        assert extract_tables_from_sql(sql) == {Table("tbname")}

    # underscore in the name, unicode in the predicate, explicit column list
    for sql in (
        "SELECT * FROM tb_name",
        'SELECT * FROM "tb_name" WHERE city = "Lübeck"',
        "SELECT field1, field2 FROM tb_name",
    ):
        assert extract_tables_from_sql(sql) == {Table("tb_name")}

    # columns qualified by their table
    two_tables = extract_tables_from_sql("SELECT t1.f1, t2.f2 FROM t1, t2")
    assert two_tables == {Table("t1"), Table("t2")}

    # named table
    aliased = extract_tables_from_sql(
        "SELECT a.date, a.field FROM left_table a LIMIT 10"
    )
    assert aliased == {Table("left_table")}

    # a subquery alias that reuses the name of the underlying table
    for sql in (
        "SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;",
        "select * from (select * from forbidden_table) forbidden_table",
    ):
        assert extract_tables_from_sql(sql) == {Table("forbidden_table")}
| |
| |
def test_extract_tables_subselect() -> None:
    """
    Check that tables referenced inside subselects are found.
    """
    subselect_plus_table = """
SELECT sub.*
FROM (
SELECT *
FROM s1.t1
WHERE day_of_week = 'Friday'
) sub, s2.t2
WHERE sub.resolution = 'NONE'
"""
    assert extract_tables_from_sql(subselect_plus_table) == {
        Table("t1", "s1"),
        Table("t2", "s2"),
    }

    subselect_only = """
SELECT sub.*
FROM (
SELECT *
FROM s1.t1
WHERE day_of_week = 'Friday'
) sub
WHERE sub.resolution = 'NONE'
"""
    assert extract_tables_from_sql(subselect_only) == {Table("t1", "s1")}

    deeply_nested = """
SELECT * FROM t1
WHERE s11 > ANY (
SELECT COUNT(*) /* no hint */ FROM t2
WHERE NOT EXISTS (
SELECT * FROM t3
WHERE ROW(5*t2.s1,77)=(
SELECT 50,11*s1 FROM t4
)
)
)
"""
    assert extract_tables_from_sql(deeply_nested) == {
        Table("t1"),
        Table("t2"),
        Table("t3"),
        Table("t4"),
    }
| |
| |
def test_extract_tables_select_in_expression() -> None:
    """
    Check that a `SELECT` used as a column expression contributes its tables.
    """
    unaliased = "SELECT f1, (SELECT count(1) FROM t2) FROM t1"
    aliased = "SELECT f1, (SELECT count(1) FROM t2) as f2 FROM t1"

    for sql in (unaliased, aliased):
        assert extract_tables_from_sql(sql) == {Table("t1"), Table("t2")}
| |
| |
def test_extract_tables_parenthesis() -> None:
    """
    Check that a parenthesized expression is not mistaken for a table source.
    """
    tables = extract_tables_from_sql("SELECT f1, (x + y) AS f2 FROM t1")
    assert tables == {Table("t1")}
| |
| |
def test_extract_tables_with_schema() -> None:
    """
    Check that schema-qualified names are split into schema and table.
    """
    # unquoted, quoted, quoted with implicit alias, quoted with explicit alias
    queries = (
        "SELECT * FROM schemaname.tbname",
        'SELECT * FROM "schemaname"."tbname"',
        'SELECT * FROM "schemaname"."tbname" foo',
        'SELECT * FROM "schemaname"."tbname" AS foo',
    )
    for sql in queries:
        assert extract_tables_from_sql(sql) == {Table("tbname", "schemaname")}
| |
| |
def test_extract_tables_union() -> None:
    """
    Check that both sides of a set operation (`UNION`, `INTERSECT`) are visited.
    """
    queries = (
        "SELECT * FROM t1 UNION SELECT * FROM t2",
        "SELECT * FROM t1 UNION ALL SELECT * FROM t2",
        "SELECT * FROM t1 INTERSECT ALL SELECT * FROM t2",
    )
    for sql in queries:
        assert extract_tables_from_sql(sql) == {Table("t1"), Table("t2")}
| |
| |
def test_extract_tables_select_from_values() -> None:
    """
    Check that a `VALUES` source is not reported as a table.
    """
    tables = extract_tables_from_sql("SELECT * FROM VALUES (13, 42)")
    assert tables == set()
| |
| |
def test_extract_tables_select_array() -> None:
    """
    Check that an `ARRAY[...]` literal in the projection does not disturb extraction.
    """
    sql = """
SELECT ARRAY[1, 2, 3] AS my_array
FROM t1 LIMIT 10
"""
    assert extract_tables_from_sql(sql) == {Table("t1")}
| |
| |
def test_extract_tables_select_if() -> None:
    """
    Check that an `IF(...)` call in the projection does not disturb extraction.
    """
    sql = """
SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
FROM t1 LIMIT 10
"""
    assert extract_tables_from_sql(sql) == {Table("t1")}
| |
| |
def test_extract_tables_with_catalog() -> None:
    """
    Check that a three-part name is split into catalog, schema and table.
    """
    tables = extract_tables_from_sql("SELECT * FROM catalogname.schemaname.tbname")
    assert tables == {Table("tbname", "schemaname", "catalogname")}
| |
| |
def test_extract_tables_illdefined() -> None:
    """
    Check that ill-defined table references raise `SupersetParseError`.

    Each malformed query is paired with the exact error message expected from the
    parser. One odd-looking reference (`catalogname..tbname`) is accepted.
    """
    failures = [
        ("SELECT * FROM schemaname.", "Error parsing near '.' at line 1:25"),
        (
            "SELECT * FROM catalogname.schemaname.",
            "Error parsing near '.' at line 1:37",
        ),
        ("SELECT * FROM catalogname..", "Error parsing near '.' at line 1:27"),
        ('SELECT * FROM "tbname', "Unable to parse script"),
    ]
    for sql, message in failures:
        with pytest.raises(SupersetParseError) as excinfo:
            extract_tables_from_sql(sql)
        # checked after the `with` block so the assertion actually runs
        assert str(excinfo.value) == message

    # odd edge case that works
    tables = extract_tables_from_sql("SELECT * FROM catalogname..tbname")
    assert tables == {Table(table="tbname", schema=None, catalog="catalogname")}
| |
| |
def test_extract_tables_show_tables_from() -> None:
    """
    Check that MySQL's `SHOW TABLES FROM` reports no tables.
    """
    tables = extract_tables_from_sql("SHOW TABLES FROM s1 like '%order%'", "mysql")
    assert tables == set()
| |
| |
def test_format_show_tables() -> None:
    """
    Check the formatted output of a MySQL `SHOW TABLES` statement.

    Presumably this covers the path taken when `ast.sql()` raises an exception;
    confirm against the `format` implementation.
    """
    script = SQLScript("SHOW TABLES FROM s1 like '%order%'", "mysql")
    assert script.format() == "SHOW TABLES FROM s1 LIKE '%order%'"
| |
| |
def test_format_no_dialect() -> None:
    """
    Test format with an engine that has no corresponding dialect.

    The script is built for the "dremio" engine and its formatted output is
    compared with a multi-line literal (surrounding newlines stripped).
    """
    # NOTE(review): the comparison is whitespace-sensitive; the expected literal
    # below may have lost its indentation in this copy of the file — confirm
    # against the formatter's real output.
    assert (
        SQLScript("SELECT col FROM t WHERE col NOT IN (1, 2)", "dremio").format()
        == """
SELECT
col
FROM t
WHERE
NOT col IN (1, 2)
""".strip()
    )
| |
| |
def test_split_no_dialect() -> None:
    """
    Check statement splitting for an engine that has no corresponding dialect.
    """
    script = SQLScript(
        "SELECT col FROM t WHERE col NOT IN (1, 2); SELECT * FROM t; SELECT foo",
        "dremio",
    )
    statements = script.statements
    expected = [
        "SELECT\n col\nFROM t\nWHERE\n NOT col IN (1, 2)",
        "SELECT\n *\nFROM t",
        "SELECT\n foo",
    ]

    assert len(statements) == 3
    for statement, formatted in zip(statements, expected):
        assert statement.format() == formatted
| |
| |
def test_extract_tables_show_columns_from() -> None:
    """
    Check that `SHOW COLUMNS FROM` reports the table it inspects.
    """
    tables = extract_tables_from_sql("SHOW COLUMNS FROM t1")
    assert tables == {Table("t1")}
| |
| |
def test_extract_tables_where_subquery() -> None:
    """
    Check that tables used by subqueries in a `WHERE` clause are found.
    """
    scalar_comparison = """
SELECT name
FROM t1
WHERE regionkey = (SELECT max(regionkey) FROM t2)
"""
    in_subquery = """
SELECT name
FROM t1
WHERE regionkey IN (SELECT regionkey FROM t2)
"""
    exists_subquery = """
SELECT name
FROM t1
WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.regionkey = t2.regionkey);
"""
    for sql in (scalar_comparison, in_subquery, exists_subquery):
        assert extract_tables_from_sql(sql) == {Table("t1"), Table("t2")}
| |
| |
def test_extract_tables_describe() -> None:
    """
    Check that `DESCRIBE` reports the table it describes.
    """
    tables = extract_tables_from_sql("DESCRIBE t1")
    assert tables == {Table("t1")}
| |
| |
def test_extract_tables_show_partitions() -> None:
    """
    Check that `SHOW PARTITIONS` reports the table whose partitions are listed.
    """
    sql = """
SHOW PARTITIONS FROM orders
WHERE ds >= '2013-01-01' ORDER BY ds DESC
"""
    assert extract_tables_from_sql(sql) == {Table("orders")}
| |
| |
def test_extract_tables_join() -> None:
    """
    Check that tables on both sides of a join are found, for several join types.
    """
    simple_join = "SELECT t1.*, t2.* FROM t1 JOIN t2 ON t1.a = t2.a;"
    assert extract_tables_from_sql(simple_join) == {Table("t1"), Table("t2")}

    # the remaining queries join a table against a subquery over `right_table`
    plain_join = """
SELECT a.date, b.name
FROM left_table a
JOIN (
SELECT
CAST((b.year) as VARCHAR) date,
name
FROM right_table
) b
ON a.date = b.date
"""
    left_inner_join = """
SELECT a.date, b.name
FROM left_table a
LEFT INNER JOIN (
SELECT
CAST((b.year) as VARCHAR) date,
name
FROM right_table
) b
ON a.date = b.date
"""
    right_outer_join = """
SELECT a.date, b.name
FROM left_table a
RIGHT OUTER JOIN (
SELECT
CAST((b.year) as VARCHAR) date,
name
FROM right_table
) b
ON a.date = b.date
"""
    full_outer_join = """
SELECT a.date, b.name
FROM left_table a
FULL OUTER JOIN (
SELECT
CAST((b.year) as VARCHAR) date,
name
FROM right_table
) b
ON a.date = b.date
"""
    for sql in (plain_join, left_inner_join, right_outer_join, full_outer_join):
        assert extract_tables_from_sql(sql) == {
            Table("left_table"),
            Table("right_table"),
        }
| |
| |
def test_extract_tables_semi_join() -> None:
    """
    Check that both sides of a `LEFT SEMI JOIN` are found.
    """
    sql = """
SELECT a.date, b.name
FROM left_table a
LEFT SEMI JOIN (
SELECT
CAST((b.year) as VARCHAR) date,
name
FROM right_table
) b
ON a.data = b.date
"""
    tables = extract_tables_from_sql(sql)
    assert tables == {Table("left_table"), Table("right_table")}
| |
| |
def test_extract_tables_combinations() -> None:
    """
    Check extraction on queries that combine nesting, set operations and joins.
    """
    mixed = """
SELECT * FROM t1
WHERE s11 > ANY (
SELECT * FROM t1 UNION ALL SELECT * FROM (
SELECT t6.*, t3.* FROM t6 JOIN t3 ON t6.a = t3.a
) tmp_join
WHERE NOT EXISTS (
SELECT * FROM t3
WHERE ROW(5*t3.s1,77)=(
SELECT 50,11*s1 FROM t4
)
)
)
"""
    assert extract_tables_from_sql(mixed) == {
        Table("t1"),
        Table("t3"),
        Table("t4"),
        Table("t6"),
    }

    # only the innermost source is a real table; S1..S3 are subquery aliases
    nested_aliases = """
SELECT * FROM (
SELECT * FROM (
SELECT * FROM (
SELECT * FROM EmployeeS
) AS S1
) AS S2
) AS S3
"""
    assert extract_tables_from_sql(nested_aliases) == {Table("EmployeeS")}
| |
| |
def test_extract_tables_with() -> None:
    """
    Check that CTE names are not reported, only the tables the CTEs read from.
    """
    independent_ctes = """
WITH
x AS (SELECT a FROM t1),
y AS (SELECT a AS b FROM t2),
z AS (SELECT b AS c FROM t3)
SELECT c FROM z
"""
    assert extract_tables_from_sql(independent_ctes) == {
        Table("t1"),
        Table("t2"),
        Table("t3"),
    }

    # each CTE reads from the previous one, so only `t1` is a real table
    chained_ctes = """
WITH
x AS (SELECT a FROM t1),
y AS (SELECT a AS b FROM x),
z AS (SELECT b AS c FROM y)
SELECT c FROM z
"""
    assert extract_tables_from_sql(chained_ctes) == {Table("t1")}
| |
| |
def test_extract_tables_reusing_aliases() -> None:
    """
    Check that references to CTE aliases are resolved rather than reported.
    """
    forward_reference = """
with q1 as ( select key from q2 where key = '5'),
q2 as ( select key from src where key = '5')
select * from (select key from q1) a
"""
    assert extract_tables_from_sql(forward_reference) == {Table("src")}

    # weird query with circular dependency
    circular = """
with src as ( select key from q2 where key = '5'),
q2 as ( select key from src where key = '5')
select * from (select key from src) a
"""
    assert extract_tables_from_sql(circular) == set()
| |
| |
def test_extract_tables_multistatement() -> None:
    """
    Check that tables from every statement in a script are collected.
    """
    # with and without a trailing semicolon
    for sql in (
        "SELECT * FROM t1; SELECT * FROM t2",
        "SELECT * FROM t1; SELECT * FROM t2;",
    ):
        assert extract_tables_from_sql(sql) == {Table("t1"), Table("t2")}

    hive_tables = extract_tables_from_sql(
        "ADD JAR file:///hive.jar; SELECT * FROM t1;",
        engine="hive",
    )
    assert hive_tables == {Table("t1")}
| |
| |
def test_extract_tables_complex() -> None:
    """
    Check extraction on a few larger, more involved queries.
    """
    aggregate_over_joins = """
SELECT sum(m_examples) AS "sum__m_example"
FROM (
SELECT
COUNT(DISTINCT id_userid) AS m_examples,
some_more_info
FROM my_b_table b
JOIN my_t_table t ON b.ds=t.ds
JOIN my_l_table l ON b.uid=l.uid
WHERE
b.rid IN (
SELECT other_col
FROM inner_table
)
AND l.bla IN ('x', 'y')
GROUP BY 2
ORDER BY 2 ASC
) AS "meh"
ORDER BY "sum__m_example" DESC
LIMIT 10;
"""
    assert extract_tables_from_sql(aggregate_over_joins) == {
        Table("my_l_table"),
        Table("my_b_table"),
        Table("my_t_table"),
        Table("inner_table"),
    }

    comma_join = """
SELECT *
FROM table_a AS a, table_b AS b, table_c as c
WHERE a.id = b.id and b.id = c.id
"""
    assert extract_tables_from_sql(comma_join) == {
        Table("table_a"),
        Table("table_b"),
        Table("table_c"),
    }

    ctes_inside_subquery = """
SELECT somecol AS somecol
FROM (
WITH bla AS (
SELECT col_a
FROM a
WHERE
1=1
AND column_of_choice NOT IN (
SELECT interesting_col
FROM b
)
),
rb AS (
SELECT yet_another_column
FROM (
SELECT a
FROM c
GROUP BY the_other_col
) not_table
LEFT JOIN bla foo
ON foo.prop = not_table.bad_col0
WHERE 1=1
GROUP BY
not_table.bad_col1 ,
not_table.bad_col2 ,
ORDER BY not_table.bad_col_3 DESC ,
not_table.bad_col4 ,
not_table.bad_col5
)
SELECT random_col
FROM d
WHERE 1=1
UNION ALL SELECT even_more_cols
FROM e
WHERE 1=1
UNION ALL SELECT lets_go_deeper
FROM f
WHERE 1=1
GROUP BY last_col
LIMIT 50000
)
"""
    assert extract_tables_from_sql(ctes_inside_subquery) == {
        Table("a"),
        Table("b"),
        Table("c"),
        Table("d"),
        Table("e"),
        Table("f"),
    }
| |
| |
def test_extract_tables_mixed_from_clause() -> None:
    """
    Check a `FROM` clause that mixes plain tables with a subselect.
    """
    sql = """
SELECT *
FROM table_a AS a, (select * from table_b) AS b, table_c as c
WHERE a.id = b.id and b.id = c.id
"""
    tables = extract_tables_from_sql(sql)
    assert tables == {Table("table_a"), Table("table_b"), Table("table_c")}
| |
| |
def test_extract_tables_nested_select() -> None:
    """
    Check that a `SELECT` nested inside function calls is still inspected.
    """
    by_schema = """
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
"""
    by_table_name = """
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_NAME="bi_achievement_daily"),0x7e)));
"""
    for sql in (by_schema, by_table_name):
        assert extract_tables_from_sql(sql, "mysql") == {
            Table("COLUMNS", "INFORMATION_SCHEMA")
        }
| |
| |
def test_extract_tables_complex_cte_with_prefix() -> None:
    """
    Check a CTE declared with a column list and a prefixed name (`CTE__test`).
    """
    sql = """
WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear)
AS (
SELECT SalesPersonID, SalesOrderID, YEAR(OrderDate) AS SalesYear
FROM SalesOrderHeader
WHERE SalesPersonID IS NOT NULL
)
SELECT SalesPersonID, COUNT(SalesOrderID) AS TotalSales, SalesYear
FROM CTE__test
GROUP BY SalesYear, SalesPersonID
ORDER BY SalesPersonID, SalesYear;
"""
    assert extract_tables_from_sql(sql) == {Table("SalesOrderHeader")}
| |
| |
def test_extract_tables_identifier_list_with_keyword_as_alias() -> None:
    """
    Check that a CTE named after a keyword (`match`) is handled like any alias.
    """
    sql = """
WITH
f AS (SELECT * FROM foo),
match AS (SELECT * FROM f)
SELECT * FROM match
"""
    assert extract_tables_from_sql(sql) == {Table("foo")}
| |
| |
def test_sqlscript() -> None:
    """
    Exercise `SQLScript`: statement splitting, formatting and settings lookup.
    """
    two_selects = SQLScript("SELECT 1; SELECT 2;", "sqlite")
    assert len(two_selects.statements) == 2
    assert two_selects.format() == "SELECT\n 1;\nSELECT\n 2"
    first = two_selects.statements[0]
    assert first.format() == "SELECT\n 1"

    # `a` is set twice; the reported value is the later one
    with_sets = SQLScript("SET a=1; SET a=2; SELECT 3;", "sqlite")
    assert with_sets.get_settings() == {"a": "2"}

    # a KQL `set` without a value is reported as `True`
    kql = """set querytrace;
Events | take 100"""
    assert SQLScript(kql, "kustokql").get_settings() == {"querytrace": True}
| |
| |
@pytest.mark.parametrize(
    "sql, engine, expected",
    [
        (
            " SELECT foo FROM tbl ; ",
            "postgresql",
            ["SELECT\n foo\nFROM tbl"],
        ),
        (
            "SELECT foo FROM tbl1; SELECT bar FROM tbl2;",
            "postgresql",
            ["SELECT\n foo\nFROM tbl1", "SELECT\n bar\nFROM tbl2"],
        ),
        (
            "let foo = 1; tbl | where bar == foo",
            "kustokql",
            ["let foo = 1", "tbl | where bar == foo"],
        ),
        (
            "SELECT 1; -- extraneous comment",
            "postgresql",
            ["SELECT\n 1 /* extraneous comment */"],
        ),
        (
            "SHOW TABLES FROM s1 like '%order%';",
            "mysql",
            ["SHOW TABLES FROM s1 LIKE '%order%'"],
        ),
        (
            "SELECT 1; SELECT 2; SELECT 3;",
            "unknown-engine",
            [
                "SELECT\n 1",
                "SELECT\n 2",
                "SELECT\n 3",
            ],
        ),
    ],
)
def test_sqlscript_split(sql: str, engine: str, expected: list[str]) -> None:
    """
    Check how `SQLScript` splits a script into statements.

    Each case supplies a script, an engine, and the formatted text expected for
    every statement, in order.
    """
    statements = SQLScript(sql, engine).statements
    formatted = [statement.format() for statement in statements]
    assert formatted == expected
| |
| |
def test_sqlstatement() -> None:
    """
    Exercise `SQLStatement`: formatting, tables, predicates, settings, constructor.
    """
    union = SQLStatement(
        "SELECT * FROM table1 UNION ALL SELECT * FROM table2",
        "sqlite",
    )

    pretty = "SELECT\n *\nFROM table1\nUNION ALL\nSELECT\n *\nFROM table2"
    assert union.format() == pretty
    assert str(union) == union.format()

    assert union.tables == {
        Table(table="table1", schema=None, catalog=None),
        Table(table="table2", schema=None, catalog=None),
    }

    # a predicate is parsed into a sqlglot expression tree
    predicate = union.parse_predicate("a > 1")
    assert predicate == exp.GT(
        this=exp.Column(this=exp.Identifier(this="a", quoted=False)),
        expression=exp.Literal(this="1", is_string=False),
    )

    setter = SQLStatement("SET a=1", "sqlite")
    assert setter.get_settings() == {"a": "1"}

    # constructing without arguments is rejected
    with pytest.raises(
        ValueError,
        match="Either statement or ast must be provided",
    ):
        SQLStatement()
| |
| |
def test_kustokqlstatement() -> None:
    """
    Exercise `KustoKQLStatement`: formatting, no-op methods and constructor errors.
    """
    kql = "foo | take 100"
    statement = KustoKQLStatement(kql, "kustokql")

    assert statement.format() == "foo | take 100"
    assert str(statement) == statement.format()

    # doesn't support table extraction
    assert statement.tables == set()

    # optimize is a no-op
    optimized = statement.optimize()
    assert optimized.format() == "foo | take 100"

    # predicate parsing is also no-op
    assert statement.parse_predicate("a > 1") == "a > 1"

    # only the KQL engine is accepted
    with pytest.raises(SupersetParseError, match="Invalid engine: invalid-engine"):
        KustoKQLStatement(kql, "invalid-engine")

    # a script holding two statements is rejected
    with pytest.raises(
        SupersetParseError,
        match="KustoKQLStatement should have exactly one statement",
    ):
        KustoKQLStatement("foo | take 1; bar | take 2", "kustokql")
| |
| |
def test_kustokqlstatement_split_script() -> None:
    """
    Check that `KustoKQLStatement.split_script` splits a multi-statement script.

    The script holds three `let` statements followed by one query.
    """
    script = """
let totalPagesPerDay = PageViews
| summarize by Page, Day = startofday(Timestamp)
| summarize count() by Day;
let materializedScope = PageViews
| summarize by Page, Day = startofday(Timestamp);
let cachedResult = materialize(materializedScope);
cachedResult
| project Page, Day1 = Day
| join kind = inner
(
cachedResult
| project Page, Day2 = Day
)
on Page
| where Day2 > Day1
| summarize count() by Day1, Day2
| join kind = inner
totalPagesPerDay
on $left.Day1 == $right.Day
| project Day1, Day2, Percentage = count_*100.0/count_1
"""
    statements = KustoKQLStatement.split_script(script, "kustokql")
    assert len(statements) == 4
| |
| |
def test_kustokqlstatement_with_program() -> None:
    """
    Check that semicolons inside a triple-backtick block do not split the script.
    """
    script = """
print program = ```
public class Program {
public static void Main() {
System.Console.WriteLine("Hello!");
}
}```
"""
    statements = KustoKQLStatement.split_script(script, "kustokql")
    assert len(statements) == 1
| |
| |
def test_kustokqlstatement_with_set() -> None:
    """
    Check that a leading `set` command is split off as its own statement.
    """
    script = """
set querytrace;
Events | take 100
"""
    statements = KustoKQLStatement.split_script(script, "kustokql")

    assert len(statements) == 2
    set_command, query = statements[0], statements[1]
    assert set_command.format() == "set querytrace"
    assert query.format() == "Events | take 100"
| |
| |
@pytest.mark.parametrize(
    "kql,statements",
    [
        ('print banner=strcat("Hello", ", ", "World!")', 1),
        (r"print 'O\'Malley\'s'", 1),
        (r"print 'O\'Mal;ley\'s'", 1),
        ("print ```foo;\nbar;\nbaz;```\n", 1),
    ],
)
def test_kustokql_statement_split_special(kql: str, statements: int) -> None:
    """
    Check the statement count for KQL containing quotes, escapes or code blocks.

    Every case is expected to stay a single statement, including the ones with a
    semicolon inside a string or a triple-backtick block.
    """
    split = KustoKQLStatement.split_script(kql, "kustokql")
    assert len(split) == statements
| |
| |
# Cases: a leading semicolon, leading and trailing semicolons, and a multi-statement
# script whose expected fragments keep the whitespace that follows each separator.
# NOTE(review): the comparison is whitespace-sensitive; the input and expected
# literals are reproduced exactly as they appear in this copy of the file.
@pytest.mark.parametrize(
    "kql, expected",
    [
        (";Table | take 5", ["Table | take 5"]),
        (";Table | take 5;", ["Table | take 5"]),
        (
            """
let totalPagesPerDay = PageViews
| summarize by Page, Day = startofday(Timestamp)
| summarize count() by Day;
let materializedScope = PageViews
| summarize by Page, Day = startofday(Timestamp);
let cachedResult = materialize(materializedScope);
cachedResult
| project Page, Day1 = Day
| join kind = inner
(
cachedResult
| project Page, Day2 = Day
)
on Page
| where Day2 > Day1
| summarize count() by Day1, Day2
| join kind = inner
totalPagesPerDay
on $left.Day1 == $right.Day
| project Day1, Day2, Percentage = count_*100.0/count_1
""",
            [
                """
let totalPagesPerDay = PageViews
| summarize by Page, Day = startofday(Timestamp)
| summarize count() by Day""",
                """
let materializedScope = PageViews
| summarize by Page, Day = startofday(Timestamp)""",
                """
let cachedResult = materialize(materializedScope)""",
                """
cachedResult
| project Page, Day1 = Day
| join kind = inner
(
cachedResult
| project Page, Day2 = Day
)
on Page
| where Day2 > Day1
| summarize count() by Day1, Day2
| join kind = inner
totalPagesPerDay
on $left.Day1 == $right.Day
| project Day1, Day2, Percentage = count_*100.0/count_1
""",
            ],
        ),
    ],
)
def test_split_kql(kql: str, expected: list[str]) -> None:
    """
    Test the `split_kql` function.

    Each case pairs a KQL script with the list of statement strings that
    `split_kql` is expected to return for it.
    """
    assert split_kql(kql) == expected
| |
| |
@pytest.mark.parametrize(
    ("engine", "sql", "expected"),
    [
        ("sqlite", "SELECT 1", False),
        ("sqlite", "INSERT INTO foo VALUES (1)", True),
        ("sqlite", "UPDATE foo SET bar = 2 WHERE id = 1", True),
        ("sqlite", "DELETE FROM foo WHERE id = 1", True),
        ("sqlite", "CREATE TABLE foo (id INT, bar TEXT)", True),
        ("sqlite", "DROP TABLE foo", True),
        ("sqlite", "EXPLAIN SELECT * FROM foo", False),
        ("sqlite", "PRAGMA table_info(foo)", False),
        ("postgresql", "SELECT 1", False),
        ("postgresql", "INSERT INTO foo (id, bar) VALUES (1, 'test')", True),
        ("postgresql", "UPDATE foo SET bar = 'new' WHERE id = 1", True),
        ("postgresql", "DELETE FROM foo WHERE id = 1", True),
        ("postgresql", "CREATE TABLE foo (id SERIAL PRIMARY KEY, bar TEXT)", True),
        ("postgresql", "DROP TABLE foo", True),
        ("postgresql", "EXPLAIN ANALYZE SELECT * FROM foo", False),
        ("postgresql", "EXPLAIN ANALYZE DELETE FROM foo", True),
        ("postgresql", "SHOW search_path", False),
        ("postgresql", "SET search_path TO public", False),
        (
            "postgres",
            """
with source as (
select 1 as one
)
select * from source
""",
            False,
        ),
        ("trino", "SELECT 1", False),
        ("trino", "INSERT INTO foo VALUES (1, 'bar')", True),
        ("trino", "UPDATE foo SET bar = 'baz' WHERE id = 1", True),
        ("trino", "DELETE FROM foo WHERE id = 1", True),
        ("trino", "CREATE TABLE foo (id INT, bar VARCHAR)", True),
        ("trino", "DROP TABLE foo", True),
        ("trino", "EXPLAIN SELECT * FROM foo", False),
        ("trino", "SHOW SCHEMAS", False),
        ("trino", "SET SESSION optimization_level = '3'", False),
        ("kustokql", "tbl | limit 100", False),
        ("kustokql", "let foo = 1; tbl | where bar == foo", False),
        ("kustokql", ".show tables", False),
        ("kustokql", "print 1", False),
        ("kustokql", "set querytrace; Events | take 100", False),
        ("kustokql", ".drop table foo", True),
        ("kustokql", ".set-or-append table foo <| bar", True),
        ("base", "SHOW LOCKS test EXTENDED", False),
        ("base", "SET hivevar:desc='Legislators'", False),
        ("base", "UPDATE t1 SET col1 = NULL", True),
        ("base", "EXPLAIN SELECT 1", False),
        ("base", "SELECT 1", False),
        ("base", "WITH bla AS (SELECT 1) SELECT * FROM bla", False),
        ("base", "SHOW CATALOGS", False),
        ("base", "SHOW TABLES", False),
        ("hive", "UPDATE t1 SET col1 = NULL", True),
        ("hive", "INSERT OVERWRITE TABLE tabB SELECT a.Age FROM TableA", True),
        ("hive", "SHOW LOCKS test EXTENDED", False),
        ("hive", "SET hivevar:desc='Legislators'", False),
        ("hive", "EXPLAIN SELECT 1", False),
        ("hive", "SELECT 1", False),
        ("hive", "WITH bla AS (SELECT 1) SELECT * FROM bla", False),
        ("presto", "SET hivevar:desc='Legislators'", False),
        ("presto", "UPDATE t1 SET col1 = NULL", True),
        ("presto", "INSERT OVERWRITE TABLE tabB SELECT a.Age FROM TableA", True),
        ("presto", "SHOW LOCKS test EXTENDED", False),
        ("presto", "EXPLAIN SELECT 1", False),
        ("presto", "SELECT 1", False),
        ("presto", "WITH bla AS (SELECT 1) SELECT * FROM bla", False),
    ],
)
def test_has_mutation(engine: str, sql: str, expected: bool) -> None:
    """
    Check `SQLScript.has_mutation` across engines and statement types.

    Each case gives an engine, a script, and whether the script is expected to be
    reported as mutating.
    """
    script = SQLScript(sql, engine)
    assert script.has_mutation() == expected
| |
| |
def test_get_settings() -> None:
    """
    Check `get_settings` when comments are interleaved with a `SET` statement.
    """
    script = SQLScript(
        """
set
-- this is a tricky comment
search_path -- another one
= bar;
SELECT * FROM some_table;
""",
        "postgresql",
    )
    settings = script.get_settings()
    assert settings == {"search_path": "bar"}
| |
| |
@pytest.mark.parametrize(
    "app",
    [{"SQLGLOT_DIALECTS_EXTENSIONS": {"custom": Dialects.MYSQL}}],
    indirect=True,
)
def test_custom_dialect(app: None) -> None:
    """
    Check that a dialect configured through `SQLGLOT_DIALECTS_EXTENSIONS` is
    available in `SQLGLOT_DIALECTS`.
    """
    registered = SQLGLOT_DIALECTS.get("custom")
    assert registered == Dialects.MYSQL
| |
| |
@pytest.mark.parametrize(
    "engine",
    [
        "ascend",
        "awsathena",
        "base",
        "bigquery",
        "clickhouse",
        "clickhousedb",
        "cockroachdb",
        "couchbase",
        "crate",
        "databend",
        "databricks",
        "db2",
        "denodo",
        "dremio",
        "drill",
        "druid",
        "duckdb",
        "dynamodb",
        "elasticsearch",
        "exa",
        "firebird",
        "firebolt",
        "gsheets",
        "hana",
        "hive",
        "ibmi",
        "impala",
        "kustokql",
        "kustosql",
        "kylin",
        "mariadb",
        "motherduck",
        "mssql",
        "mysql",
        "netezza",
        "oceanbase",
        "ocient",
        "odelasticsearch",
        "oracle",
        "pinot",
        "postgresql",
        "presto",
        "pydoris",
        "redshift",
        "risingwave",
        "shillelagh",
        "snowflake",
        "solr",
        "sqlite",
        "starrocks",
        "superset",
        "teradatasql",
        "trino",
        "vertica",
    ],
)
@pytest.mark.parametrize(
    "sql, expected",
    [
        ("SELECT 1", False),
        ("with source as ( select 1 as one ) select * from source", False),
        ("ALTER TABLE foo ADD COLUMN bar INT", True),
    ],
)
def test_is_mutating(sql: str, engine: str, expected: bool) -> None:
    """
    Check `SQLStatement.is_mutating` for every listed engine.

    The two parametrizations are stacked, so each statement is run against each
    engine in the list.
    """
    statement = SQLStatement(sql, engine)
    assert statement.is_mutating() == expected
| |
| |
@pytest.mark.parametrize(
    "sql, expected",
    [
        (
            """
DO $$
BEGIN
INSERT INTO public.users (name, real_name)
VALUES ('SQLLab bypass DML', 'SQLLab bypass DML');
END;
$$;
""",
            True,
        ),
        (
            """
DO $$
BEGIN
IF (SELECT COUNT(*) FROM orders WHERE status = 'pending') > 100 THEN
RAISE NOTICE 'High pending order volume detected';
END IF;
END;
$$;
""",
            True,
        ),
    ],
)
def test_is_mutating_anonymous_block(sql: str, expected: bool) -> None:
    """
    Check `is_mutating` on Postgres anonymous (`DO $$ ... $$`) blocks.

    The PL/pgSQL inside the block can't be parsed, so the block is always assumed
    to be mutating — both cases expect `True`, including the one that only reads.
    """
    statement = SQLStatement(sql, "postgresql")
    assert statement.is_mutating() == expected
| |
| |
def test_optimize() -> None:
    """
    Exercise `optimize` on both a single statement and a script.

    Optimization only takes effect for engines that have a corresponding dialect:
    with "sqlite" the outer predicate is expected inside the subquery, while with
    "crate" the query is expected back unchanged apart from formatting.
    """
    query = """
SELECT anon_1.a, anon_1.b
FROM (SELECT some_table.a AS a, some_table.b AS b, some_table.c AS c
FROM some_table) AS anon_1
WHERE anon_1.a > 1 AND anon_1.b = 2
"""

    # expected output when the optimizer runs
    pushed_down = """
SELECT
anon_1.a,
anon_1.b
FROM (
SELECT
some_table.a AS a,
some_table.b AS b,
some_table.c AS c
FROM some_table
WHERE
some_table.a > 1 AND some_table.b = 2
) AS anon_1
WHERE
TRUE AND TRUE
""".strip()

    # expected output when the query is only reformatted
    untouched = """
SELECT
anon_1.a,
anon_1.b
FROM (
SELECT
some_table.a AS a,
some_table.b AS b,
some_table.c AS c
FROM some_table
) AS anon_1
WHERE
anon_1.a > 1 AND anon_1.b = 2
""".strip()

    sqlite_statement = SQLStatement(query, "sqlite")
    assert sqlite_statement.optimize().format() == pushed_down

    crate_statement = SQLStatement(query, "crate")
    assert crate_statement.optimize().format() == untouched

    # scripts expose the same method
    sqlite_script = SQLScript(query, "sqlite")
    assert sqlite_script.optimize().format() == pushed_down
| |
| |
def test_firebolt() -> None:
    """
    Check the output of the 3rd party dialect registered for "firebolt".

    Firebolt needs a custom dialect because it parses `NOT col IN (1, 2)` as
    `(NOT col) IN (1, 2)` rather than `NOT (col IN (1, 2))`, which fails when `col`
    is not a boolean.

    `NOT col = 1`, on the other hand, already parses as `NOT (col = 1)` in Firebolt,
    so no extra parentheses are expected in that case.
    """
    negated_in = SQLStatement("SELECT col NOT IN (1, 2) FROM tbl", "firebolt")
    expected_negated_in = """
SELECT
NOT (
col IN (1, 2)
)
FROM tbl
""".strip()
    assert negated_in.format() == expected_negated_in

    negated_eq = SQLStatement("SELECT NOT col = 1 FROM tbl", "firebolt")
    expected_negated_eq = """
SELECT
NOT col = 1
FROM tbl
""".strip()
    assert negated_eq.format() == expected_negated_eq
| |
| |
def test_firebolt_old(monkeypatch: pytest.MonkeyPatch) -> None:
    """
    Test the dialect for the old Firebolt syntax.

    The "firebolt" entry of `SQLGLOT_DIALECTS` is replaced with `FireboltOld` through
    `monkeypatch`, so the override is reverted when the test finishes instead of
    leaking into tests that run afterwards.
    """
    from superset.sql.dialects import FireboltOld

    # `SQLGLOT_DIALECTS` is already imported at module level; patch that mapping
    # in place rather than assigning to it permanently.
    monkeypatch.setitem(SQLGLOT_DIALECTS, "firebolt", FireboltOld)

    sql = "SELECT * FROM t1 UNNEST(col1 AS foo)"
    assert (
        SQLStatement(sql, "firebolt").format()
        == """
SELECT
*
FROM t1 UNNEST(col1 AS foo)
""".strip()
    )
| |
| |
def test_firebolt_old_escape_string(monkeypatch: pytest.MonkeyPatch) -> None:
    """
    Test string escaping in the dialect for the old Firebolt syntax.

    Both `''` and `\\'` are accepted as an escaped quote, and both are expected to be
    normalized to `''` when the statement is formatted.

    The "firebolt" entry of `SQLGLOT_DIALECTS` is replaced with `FireboltOld` through
    `monkeypatch`, so the override is reverted when the test finishes instead of
    leaking into tests that run afterwards.
    """
    from superset.sql.dialects import FireboltOld

    # `SQLGLOT_DIALECTS` is already imported at module level; patch that mapping
    # in place rather than assigning to it permanently.
    monkeypatch.setitem(SQLGLOT_DIALECTS, "firebolt", FireboltOld)

    # both '' and \' are valid escape sequences
    sql = r"SELECT 'foo''bar', 'foo\'bar'"

    # but they normalize to ''
    assert (
        SQLStatement(sql, "firebolt").format()
        == """
SELECT
'foo''bar',
'foo''bar'
""".strip()
    )
| |
| |
@pytest.mark.parametrize(
    "sql, engine, expected",
    [
        ("SELECT * FROM users LIMIT 10", "postgresql", 10),
        (
            """
WITH cte_example AS (
SELECT * FROM my_table
LIMIT 100
)
SELECT * FROM cte_example
LIMIT 10;
""",
            "postgresql",
            10,
        ),
        ("SELECT * FROM users ORDER BY id DESC LIMIT 25", "postgresql", 25),
        ("SELECT * FROM users", "postgresql", None),
        ("SELECT TOP 5 name FROM employees", "teradatasql", 5),
        ("SELECT TOP (42) * FROM table_name", "teradatasql", 42),
        ("select * from table", "postgresql", None),
        ("select * from mytable limit 10", "postgresql", 10),
        (
            "select * from (select * from my_subquery limit 10) where col=1 limit 20",
            "postgresql",
            20,
        ),
        ("select * from (select * from my_subquery limit 10);", "postgresql", None),
        (
            "select * from (select * from my_subquery limit 10) where col=1 limit 20;",
            "postgresql",
            20,
        ),
        ("select * from mytable limit 20, 10", "postgresql", 10),
        ("select * from mytable limit 10 offset 20", "postgresql", 10),
        (
            """
SELECT id, value, i
FROM (SELECT * FROM my_table LIMIT 10),
LATERAL generate_series(1, value) AS i;
""",
            "postgresql",
            None,
        ),
        # not really valid SQL, but let's roll with it
        ("SELECT * FROM my_table LIMIT invalid", "postgresql", None),
    ],
)
def test_get_limit_value(sql: str, engine: str, expected: int | None) -> None:
    """
    Test the `get_limit_value` method of `SQLStatement`.

    The expected value is the limit of the outermost query as an integer, or `None`
    when the outer query has no limit (limits in CTEs and subqueries don't count) or
    when the limit is not a number.
    """
    assert SQLStatement(sql, engine).get_limit_value() == expected
| |
| |
@pytest.mark.parametrize(
    "kql, expected",
    [
        ("StormEvents | take 10", 10),
        ("StormEvents | limit 20", 20),
        ("StormEvents | where State == 'FL' | summarize count()", None),
        ("StormEvents | where name has 'limit 10'", None),
        ("AnotherTable | take 5", 5),
        ("datatable(x:int) [1, 2, 3] | take 100", 100),
        (
            """
Table1 | where msg contains 'abc;xyz'
| limit 5
""",
            5,
        ),
        ("table | take five", None),
    ],
)
def test_get_kql_limit_value(kql: str, expected: int | None) -> None:
    """
    Test the `get_limit_value` method of `KustoKQLStatement`.

    The expected value is the integer given to `take`/`limit`, or `None` when the
    query has no such operator (a "limit 10" inside a string literal doesn't count)
    or when its argument is not a number.
    """
    assert KustoKQLStatement(kql, "kustokql").get_limit_value() == expected
| |
| |
@pytest.mark.parametrize(
    "sql, engine, limit, method, expected",
    [
        (
            "SELECT * FROM t",
            "postgresql",
            10,
            LimitMethod.FORCE_LIMIT,
            "SELECT\n *\nFROM t\nLIMIT 10",
        ),
        (
            "SELECT * FROM t LIMIT 1000",
            "postgresql",
            10,
            LimitMethod.FORCE_LIMIT,
            "SELECT\n *\nFROM t\nLIMIT 10",
        ),
        (
            "SELECT * FROM t",
            "mssql",
            10,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 10\n *\nFROM t",
        ),
        (
            "SELECT * FROM t",
            "teradatasql",
            10,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 10\n *\nFROM t",
        ),
        (
            "SELECT * FROM t",
            "oracle",
            10,
            LimitMethod.FORCE_LIMIT,
            "SELECT\n *\nFROM t\nFETCH FIRST 10 ROWS ONLY",
        ),
        (
            "SELECT * FROM t",
            "db2",
            10,
            LimitMethod.WRAP_SQL,
            "SELECT\n *\nFROM (\n SELECT\n *\n FROM t\n)\nLIMIT 10",
        ),
        (
            "SEL TOP 1000 * FROM My_table",
            "teradatasql",
            100,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 100\n *\nFROM My_table",
        ),
        (
            "SEL TOP 1000 * FROM My_table;",
            "teradatasql",
            100,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 100\n *\nFROM My_table",
        ),
        (
            "SEL TOP 1000 * FROM My_table;",
            "teradatasql",
            1000,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 1000\n *\nFROM My_table",
        ),
        (
            "SELECT TOP 1000 * FROM My_table;",
            "teradatasql",
            100,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 100\n *\nFROM My_table",
        ),
        (
            "SELECT TOP 1000 * FROM My_table;",
            "teradatasql",
            10000,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 10000\n *\nFROM My_table",
        ),
        (
            "SELECT TOP 1000 * FROM My_table",
            "mssql",
            100,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 100\n *\nFROM My_table",
        ),
        (
            "SELECT TOP 1000 * FROM My_table;",
            "mssql",
            100,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 100\n *\nFROM My_table",
        ),
        (
            "SELECT TOP 1000 * FROM My_table;",
            "mssql",
            10000,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 10000\n *\nFROM My_table",
        ),
        (
            "SELECT TOP 1000 * FROM My_table;",
            "mssql",
            1000,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 1000\n *\nFROM My_table",
        ),
        (
            """
with abc as (select * from test union select * from test1)
select TOP 100 * from currency
""",
            "mssql",
            1000,
            LimitMethod.FORCE_LIMIT,
            """
WITH abc AS (
SELECT
*
FROM test
UNION
SELECT
*
FROM test1
)
SELECT
TOP 1000
*
FROM currency
""".strip(),
        ),
        (
            "SELECT DISTINCT x from tbl",
            "mssql",
            100,
            LimitMethod.FORCE_LIMIT,
            "SELECT DISTINCT\nTOP 100\n x\nFROM tbl",
        ),
        (
            "SELECT 1 as cnt",
            "mssql",
            10,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 10\n 1 AS cnt",
        ),
        (
            "select TOP 1000 * from abc where id=1",
            "mssql",
            10,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 10\n *\nFROM abc\nWHERE\n id = 1",
        ),
        (
            "SELECT * FROM birth_names -- SOME COMMENT",
            "postgresql",
            1000,
            LimitMethod.FORCE_LIMIT,
            "SELECT\n *\nFROM birth_names /* SOME COMMENT */\nLIMIT 1000",
        ),
        (
            "SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555",
            "postgresql",
            1000,
            LimitMethod.FORCE_LIMIT,
            """
SELECT
*
FROM birth_names /* SOME COMMENT WITH LIMIT 555 */
LIMIT 1000
""".strip(),
        ),
        (
            "SELECT * FROM birth_names LIMIT 555",
            "postgresql",
            1000,
            LimitMethod.FORCE_LIMIT,
            "SELECT\n *\nFROM birth_names\nLIMIT 1000",
        ),
        (
            "SELECT * FROM birth_names LIMIT 555",
            "postgresql",
            1000,
            LimitMethod.FETCH_MANY,
            "SELECT\n *\nFROM birth_names\nLIMIT 555",
        ),
    ],
)
def test_set_limit_value(
    sql: str,
    engine: str,
    limit: int,
    method: LimitMethod,
    expected: str,
) -> None:
    """
    Check the SQL produced after calling `set_limit_value` on a `SQLStatement`.

    The statement is modified in place with the given limit and method, and its
    formatted output is compared with the expected SQL for the engine.
    """
    stmt = SQLStatement(sql, engine)
    stmt.set_limit_value(limit, method)
    formatted = stmt.format()
    assert formatted == expected
| |
| |
@pytest.mark.parametrize(
    "kql, limit, expected",
    [
        ("StormEvents | take 10", 100, "StormEvents | take 100"),
        ("StormEvents | limit 20", 10, "StormEvents | limit 10"),
        (
            "StormEvents | where State == 'FL' | summarize count()",
            10,
            "StormEvents | where State == 'FL' | summarize count() | take 10",
        ),
        (
            "StormEvents | where name has 'limit 10'",
            10,
            "StormEvents | where name has 'limit 10' | take 10",
        ),
        ("AnotherTable | take 5", 50, "AnotherTable | take 50"),
        (
            "datatable(x:int) [1, 2, 3] | take 100",
            10,
            "datatable(x:int) [1, 2, 3] | take 10",
        ),
        (
            """
Table1 | where msg contains 'abc;xyz'
| limit 5
""",
            10,
            """Table1 | where msg contains 'abc;xyz'
| limit 10""",
        ),
    ],
)
def test_set_kql_limit_value(kql: str, limit: int, expected: str) -> None:
    """
    Check the KQL produced after calling `set_limit_value` on a `KustoKQLStatement`.

    The cases cover replacing the value of an existing `take`/`limit` and appending
    `| take N` when the query has none.
    """
    kql_statement = KustoKQLStatement(kql, "kustokql")
    kql_statement.set_limit_value(limit)
    formatted = kql_statement.format()
    assert formatted == expected
| |
| |
@pytest.mark.parametrize("method", [LimitMethod.WRAP_SQL, LimitMethod.FETCH_MANY])
def test_set_kql_limit_value_invalid_method(method: LimitMethod) -> None:
    """
    Check that `set_limit_value` on a KQL statement rejects unsupported methods.

    Both `WRAP_SQL` and `FETCH_MANY` must raise a `SupersetParseError`.
    """
    kql_statement = KustoKQLStatement("foo", "kustokql")
    expected_message = "Kusto KQL only supports the FORCE_LIMIT method."

    with pytest.raises(SupersetParseError, match=expected_message):
        kql_statement.set_limit_value(10, method)
| |
| |
@pytest.mark.parametrize(
    "sql, engine, expected",
    [
        ("SELECT 1", "postgresql", False),
        ("SELECT 1 AS cnt", "postgresql", False),
        (
            """
SELECT 'INR' AS cur
UNION
SELECT 'USD' AS cur
UNION
SELECT 'EUR' AS cur
""",
            "postgresql",
            False,
        ),
        ("WITH cte AS (SELECT 1) SELECT * FROM cte", "postgresql", True),
        (
            """
WITH
x AS (SELECT a FROM t1),
y AS (SELECT a AS b FROM t2),
z AS (SELECT b AS c FROM t3)
SELECT c FROM z
""",
            "postgresql",
            True,
        ),
        (
            """
WITH
x AS (SELECT a FROM t1),
y AS (SELECT a AS b FROM x),
z AS (SELECT b AS c FROM y)
SELECT c FROM z
""",
            "postgresql",
            True,
        ),
        (
            """
WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear)
AS (
SELECT SalesPersonID, SalesOrderID, YEAR(OrderDate) AS SalesYear
FROM SalesOrderHeader
WHERE SalesPersonID IS NOT NULL
)
SELECT SalesPersonID, COUNT(SalesOrderID) AS TotalSales, SalesYear
FROM CTE__test
GROUP BY SalesYear, SalesPersonID
ORDER BY SalesPersonID, SalesYear;
""",
            "postgresql",
            True,
        ),
    ],
)
def test_has_cte(sql: str, engine: str, expected: bool) -> None:
    """
    Check that `has_cte` reports whether the statement has a `WITH` clause.

    Plain selects and unions are expected to return `False`; single, multiple and
    chained CTEs are expected to return `True`.
    """
    statement = SQLStatement(sql, engine)
    assert statement.has_cte() == expected
| |
| |
@pytest.mark.parametrize(
    "sql, engine, expected",
    [
        (
            "SELECT 1",
            "postgresql",
            "WITH __cte AS (\n SELECT\n 1\n)",
        ),
        (
            """
WITH currency AS (SELECT 'INR' AS cur),
currency_2 AS (SELECT 'USD' AS cur)
SELECT * FROM currency
UNION ALL
SELECT * FROM currency_2
""",
            "postgresql",
            """
WITH currency AS (
SELECT
'INR' AS cur
), currency_2 AS (
SELECT
'USD' AS cur
), __cte AS (
SELECT
*
FROM currency
UNION ALL
SELECT
*
FROM currency_2
)
""".strip(),
        ),
    ],
)
def test_as_cte(sql: str, engine: str, expected: str) -> None:
    """
    Check that `as_cte` converts a select into a CTE named `__cte`.

    When the query already has CTEs, the new one is expected after the existing ones.
    """
    converted = SQLStatement(sql, engine).as_cte()
    assert converted.format() == expected
| |
| |
@pytest.mark.parametrize(
    "sql, rules, expected",
    [
        (
            "SELECT t.foo FROM some_table AS t",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
t.foo
FROM (
SELECT
*
FROM some_table
WHERE
id = 42
) AS t
""".strip(),
        ),
        (
            "SELECT t.foo FROM some_table AS t",
            {},
            """
SELECT
t.foo
FROM some_table AS t
""".strip(),
        ),
        (
            "SELECT t.foo FROM some_table AS t WHERE bar = 'baz'",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
t.foo
FROM (
SELECT
*
FROM some_table
WHERE
id = 42
) AS t
WHERE
bar = 'baz'
""".strip(),
        ),
        (
            "SELECT t.foo FROM schema1.some_table AS t",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
t.foo
FROM (
SELECT
*
FROM schema1.some_table
WHERE
id = 42
) AS t
""".strip(),
        ),
        (
            "SELECT t.foo FROM schema1.some_table AS t",
            {Table("some_table", "schema2"): "id = 42"},
            "SELECT\n t.foo\nFROM schema1.some_table AS t",
        ),
        (
            "SELECT t.foo FROM catalog1.schema1.some_table AS t",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
t.foo
FROM (
SELECT
*
FROM catalog1.schema1.some_table
WHERE
id = 42
) AS t
""".strip(),
        ),
        (
            "SELECT t.foo FROM catalog1.schema1.some_table AS t",
            {Table("some_table", "schema1", "catalog2"): "id = 42"},
            "SELECT\n t.foo\nFROM catalog1.schema1.some_table AS t",
        ),
        (
            "SELECT * FROM some_table WHERE 1=1",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
*
FROM (
SELECT
*
FROM some_table
WHERE
id = 42
) AS "some_table"
WHERE
1 = 1
""".strip(),
        ),
        (
            "SELECT * FROM table WHERE 1=1",
            {Table("table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
*
FROM (
SELECT
*
FROM table
WHERE
id = 42
) AS "table"
WHERE
1 = 1
""".strip(),
        ),
        (
            'SELECT * FROM "table" WHERE 1=1',
            {Table("table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
*
FROM (
SELECT
*
FROM "table"
WHERE
id = 42
) AS "table"
WHERE
1 = 1
""".strip(),
        ),
        (
            "SELECT * FROM table WHERE 1=1",
            {Table("other_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
*
FROM table
WHERE
1 = 1
""".strip(),
        ),
        (
            "SELECT * FROM other_table WHERE 1=1",
            {Table("table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
*
FROM other_table
WHERE
1 = 1
""".strip(),
        ),
        (
            "SELECT * FROM table JOIN other_table ON table.id = other_table.id",
            {Table("other_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
*
FROM table
JOIN (
SELECT
*
FROM other_table
WHERE
id = 42
) AS "other_table"
ON table.id = other_table.id
""".strip(),
        ),
        (
            'SELECT * FROM "table" JOIN other_table ON "table".id = other_table.id',
            {Table("table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
*
FROM (
SELECT
*
FROM "table"
WHERE
id = 42
) AS "table"
JOIN other_table
ON "table".id = other_table.id
""".strip(),
        ),
        (
            "SELECT * FROM (SELECT * FROM some_table)",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
*
FROM (
SELECT
*
FROM (
SELECT
*
FROM some_table
WHERE
id = 42
) AS "some_table"
)
""".strip(),
        ),
        (
            "SELECT * FROM table UNION ALL SELECT * FROM other_table",
            {Table("table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
*
FROM (
SELECT
*
FROM table
WHERE
id = 42
) AS "table"
UNION ALL
SELECT
*
FROM other_table
""".strip(),
        ),
        (
            "SELECT * FROM table UNION ALL SELECT * FROM other_table",
            {Table("other_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
*
FROM table
UNION ALL
SELECT
*
FROM (
SELECT
*
FROM other_table
WHERE
id = 42
) AS "other_table"
""".strip(),
        ),
        (
            "SELECT a.*, b.* FROM tbl_a AS a INNER JOIN tbl_b AS b ON a.col = b.col",
            {Table("tbl_a", "schema1", "catalog1"): "id = 42"},
            """
SELECT
a.*,
b.*
FROM (
SELECT
*
FROM tbl_a
WHERE
id = 42
) AS a
INNER JOIN tbl_b AS b
ON a.col = b.col
""".strip(),
        ),
        (
            "SELECT a.*, b.* FROM tbl_a a INNER JOIN tbl_b b ON a.col = b.col",
            {Table("tbl_a", "schema1", "catalog1"): "id = 42"},
            """
SELECT
a.*,
b.*
FROM (
SELECT
*
FROM tbl_a
WHERE
id = 42
) AS a
INNER JOIN tbl_b AS b
ON a.col = b.col
""".strip(),
        ),
        (
            "SELECT * FROM public.flights LIMIT 100",
            {Table("flights", "public", "catalog1"): "\"AIRLINE\" like 'A%'"},
            """
SELECT
*
FROM (
SELECT
*
FROM public.flights
WHERE
"AIRLINE" LIKE 'A%'
) AS "public.flights"
LIMIT 100
""".strip(),
        ),
    ],
)
def test_rls_subquery_transformer(
    sql: str,
    rules: dict[Table, str],
    expected: str,
) -> None:
    """
    Check RLS applied with `RLSMethod.AS_SUBQUERY` (`RLSAsSubqueryTransformer`).

    Each rule maps a table to a predicate string, which is parsed with `parse_one`
    before being handed to `apply_rls` together with "catalog1" and "schema1". Tables
    that match a rule are expected to be replaced by a filtered subquery; all other
    tables are expected to be left alone.
    """
    statement = SQLStatement(sql)
    predicates = {table: [parse_one(clause)] for table, clause in rules.items()}
    statement.apply_rls("catalog1", "schema1", predicates, RLSMethod.AS_SUBQUERY)

    formatted = statement.format()
    assert formatted == expected
| |
| |
def test_rls_invalid_method(mocker: MockerFixture) -> None:
    """
    Check that `apply_rls` raises a `ValueError` for an unknown RLS method.
    """
    stmt = SQLStatement("SELECT 1", "postgresql")
    fake_predicates = mocker.MagicMock()
    expected_message = "Invalid RLS method: invalid"

    with pytest.raises(ValueError, match=expected_message):
        stmt.apply_rls("catalog1", "schema1", fake_predicates, "invalid")  # type: ignore
| |
| |
@pytest.mark.parametrize(
    "sql, rules, expected",
    [
        (
            "SELECT t.foo FROM some_table AS t",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
t.foo
FROM some_table AS t
WHERE
t.id = 42
""".strip(),
        ),
        (
            "SELECT t.foo FROM schema2.some_table AS t",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
t.foo
FROM schema2.some_table AS t
""".strip(),
        ),
        (
            "SELECT t.foo FROM catalog2.schema1.some_table AS t",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
t.foo
FROM catalog2.schema1.some_table AS t
""".strip(),
        ),
        (
            "SELECT t.foo FROM some_table AS t WHERE bar = 'baz'",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
t.foo
FROM some_table AS t
WHERE
t.id = 42 AND (
bar = 'baz'
)
""".strip(),
        ),
        (
            "SELECT t.foo FROM some_table AS t WHERE bar = 'baz' OR foo = 'qux'",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
t.foo
FROM some_table AS t
WHERE
t.id = 42 AND (
bar = 'baz' OR foo = 'qux'
)
""".strip(),
        ),
        (
            "SELECT * FROM some_table WHERE 1=1",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
*
FROM some_table
WHERE
some_table.id = 42 AND (
1 = 1
)
""".strip(),
        ),
        (
            "SELECT * FROM some_table WHERE TRUE OR FALSE",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
*
FROM some_table
WHERE
some_table.id = 42 AND (
TRUE OR FALSE
)
""".strip(),
        ),
        (
            "SELECT * FROM table WHERE 1=1",
            {Table("table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
*
FROM table
WHERE
table.id = 42 AND (
1 = 1
)
""".strip(),
        ),
        (
            'SELECT * FROM "table" WHERE 1=1',
            {Table("table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
*
FROM "table"
WHERE
"table".id = 42 AND (
1 = 1
)
""".strip(),
        ),
        (
            "SELECT * FROM table WHERE 1=1",
            {Table("other_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
*
FROM table
WHERE
1 = 1
""".strip(),
        ),
        (
            "SELECT * FROM other_table WHERE 1=1",
            {Table("table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
*
FROM other_table
WHERE
1 = 1
""".strip(),
        ),
        (
            "SELECT * FROM table",
            {Table("table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
*
FROM table
WHERE
table.id = 42
""".strip(),
        ),
        (
            "SELECT * FROM some_table",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
*
FROM some_table
WHERE
some_table.id = 42
""".strip(),
        ),
        (
            "SELECT * FROM table ORDER BY id",
            {Table("table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
*
FROM table
WHERE
table.id = 42
ORDER BY
id
""".strip(),
        ),
        (
            "SELECT * FROM table WHERE 1=1 AND table.id=42",
            {Table("table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
*
FROM table
WHERE
table.id = 42 AND (
1 = 1 AND table.id = 42
)
""".strip(),
        ),
        (
            """
SELECT * FROM table
JOIN other_table
ON table.id = other_table.id
AND other_table.id=42
""",
            {Table("other_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
*
FROM table
JOIN other_table
ON other_table.id = 42 AND (
table.id = other_table.id AND other_table.id = 42
)
""".strip(),
        ),
        (
            "SELECT * FROM table WHERE 1=1 AND id=42",
            {Table("table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
*
FROM table
WHERE
table.id = 42 AND (
1 = 1 AND id = 42
)
""".strip(),
        ),
        (
            "SELECT * FROM table JOIN other_table ON table.id = other_table.id",
            {Table("other_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
*
FROM table
JOIN other_table
ON other_table.id = 42 AND (
table.id = other_table.id
)
""".strip(),
        ),
        (
            "SELECT * FROM table JOIN other_table",
            {Table("other_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
*
FROM table
JOIN other_table
ON other_table.id = 42
""".strip(),
        ),
        (
            """
SELECT *
FROM table
JOIN other_table
ON table.id = other_table.id
WHERE 1=1
""",
            {Table("other_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
*
FROM table
JOIN other_table
ON other_table.id = 42 AND (
table.id = other_table.id
)
WHERE
1 = 1
""".strip(),
        ),
        (
            "SELECT * FROM (SELECT * FROM other_table)",
            {Table("other_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
*
FROM (
SELECT
*
FROM other_table
WHERE
other_table.id = 42
)
""".strip(),
        ),
        (
            "SELECT * FROM table UNION ALL SELECT * FROM other_table",
            {Table("table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
*
FROM table
WHERE
table.id = 42
UNION ALL
SELECT
*
FROM other_table
""".strip(),
        ),
        (
            "SELECT * FROM table UNION ALL SELECT * FROM other_table",
            {Table("other_table", "schema1", "catalog1"): "id = 42"},
            """
SELECT
*
FROM table
UNION ALL
SELECT
*
FROM other_table
WHERE
other_table.id = 42
""".strip(),
        ),
        (
            "INSERT INTO some_table (col1, col2) VALUES (1, 2)",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            """
INSERT INTO some_table (
col1,
col2
)
VALUES
(1, 2)
""".strip(),
        ),
    ],
)
def test_rls_predicate_transformer(
    sql: str,
    rules: dict[Table, str],
    expected: str,
) -> None:
    """
    Check RLS applied with `RLSMethod.AS_PREDICATE` (`RLSPredicateTransformer`).

    Each rule maps a table to a predicate string, which is parsed with `parse_one`
    before being handed to `apply_rls` together with "catalog1" and "schema1". For
    tables that match a rule the predicate is expected in the `WHERE` (or join `ON`)
    clause, qualified with the table name or alias; everything else is expected to
    be left alone.
    """
    statement = SQLStatement(sql)
    predicates = {table: [parse_one(clause)] for table, clause in rules.items()}
    statement.apply_rls("catalog1", "schema1", predicates, RLSMethod.AS_PREDICATE)

    formatted = statement.format()
    assert formatted == expected
| |
| |
@pytest.mark.parametrize(
    "sql, table, expected",
    [
        (
            "SELECT * FROM some_table",
            Table("some_table"),
            """
CREATE TABLE some_table AS
SELECT
*
FROM some_table
""".strip(),
        ),
        (
            "SELECT * FROM some_table",
            Table("some_table", "schema1", "catalog1"),
            """
CREATE TABLE catalog1.schema1.some_table AS
SELECT
*
FROM some_table
""".strip(),
        ),
    ],
)
def test_as_create_table(sql: str, table: Table, expected: str) -> None:
    """
    Check that `as_create_table` wraps a select in `CREATE TABLE ... AS`.

    The target table is expected to be rendered with its schema and catalog when
    they are set.
    """
    ctas = SQLStatement(sql).as_create_table(table, CTASMethod.TABLE)
    assert ctas.format() == expected
| |
| |
@pytest.mark.parametrize(
    "sql, engine, expected",
    [
        ("SELECT * FROM table", "postgresql", True),
        (
            """
-- comment
SELECT * FROM table
-- comment 2
""",
            "mysql",
            True,
        ),
        (
            """
-- comment
SET @value = 42;
SELECT @value as foo;
-- comment 2
""",
            "mysql",
            True,
        ),
        (
            """
-- comment
EXPLAIN SELECT * FROM table
-- comment 2
""",
            "mysql",
            False,
        ),
        (
            """
SELECT * FROM table;
INSERT INTO TABLE (foo) VALUES (42);
""",
            "mysql",
            False,
        ),
    ],
)
def test_is_valid_ctas(sql: str, engine: str, expected: bool) -> None:
    """
    Check `SQLScript.is_valid_ctas` on single and multi-statement scripts.

    Scripts ending in a `SELECT` are expected to be valid, even when preceded by
    other statements; scripts ending in `EXPLAIN` or `INSERT` are not.
    """
    script = SQLScript(sql, engine)
    assert script.is_valid_ctas() == expected
| |
| |
@pytest.mark.parametrize(
    "sql, engine, expected",
    [
        ("SELECT * FROM table", "postgresql", True),
        (
            """
-- comment
SELECT * FROM table
-- comment 2
""",
            "mysql",
            True,
        ),
        (
            """
-- comment
SET @value = 42;
SELECT @value as foo;
-- comment 2
""",
            "mysql",
            False,
        ),
        (
            """
-- comment
SELECT value as foo;
-- comment 2
""",
            "mysql",
            True,
        ),
        (
            """
SELECT * FROM table;
INSERT INTO TABLE (foo) VALUES (42);
""",
            "mysql",
            False,
        ),
    ],
)
def test_is_valid_cvas(sql: str, engine: str, expected: bool) -> None:
    """
    Check `SQLScript.is_valid_cvas` on single and multi-statement scripts.

    Only scripts made of a single `SELECT` are expected to be valid here; the two
    multi-statement scripts are expected to be rejected.
    """
    script = SQLScript(sql, engine)
    assert script.is_valid_cvas() == expected
| |
| |
@pytest.mark.parametrize(
    "sql, expected, engine",
    [
        ("col = 1", "col = 1", "base"),
        ("1=\t\n1", "1 = 1", "base"),
        ("(col = 1)", "(\n col = 1\n)", "base"),
        ("(col1 = 1) AND (col2 = 2)", "(\n col1 = 1\n) AND (\n col2 = 2\n)", "base"),
        ("col = 'abc' -- comment", "col = 'abc' /* comment */", "base"),
        ("col = 'col1 = 1) AND (col2 = 2'", "col = 'col1 = 1) AND (col2 = 2'", "base"),
        ("col = 'select 1; select 2'", "col = 'select 1; select 2'", "base"),
        ("col = 'abc -- comment'", "col = 'abc -- comment'", "base"),
        ("col1 = 1) AND (col2 = 2)", QueryClauseValidationException, "base"),
        ("(col1 = 1) AND (col2 = 2", QueryClauseValidationException, "base"),
        ("col1 = 1) AND (col2 = 2", QueryClauseValidationException, "base"),
        ("(col1 = 1)) AND ((col2 = 2)", QueryClauseValidationException, "base"),
        ("TRUE; SELECT 1", QueryClauseValidationException, "base"),
    ],
)
def test_sanitize_clause(
    sql: str,
    expected: str | type[Exception],
    engine: str,
) -> None:
    """
    Test the `sanitize_clause` function.

    `expected` is either the sanitized clause, or the exception class that must be
    raised for clauses that are rejected (unbalanced parentheses, more than one
    statement).
    """
    if isinstance(expected, str):
        assert sanitize_clause(sql, engine) == expected
    else:
        # `expected` is an exception class, not an instance
        with pytest.raises(expected):
            sanitize_clause(sql, engine)
| |
| |
@pytest.mark.parametrize(
    "engine",
    [
        "hive",
        "presto",
        "trino",
    ],
)
@pytest.mark.parametrize(
    "macro, expected",
    [
        (
            "latest_partition('foo.bar')",
            {Table(table="bar", schema="foo")},
        ),
        (
            "latest_partition(' foo.bar ')",  # Non-atypical user error which works
            {Table(table="bar", schema="foo")},
        ),
        (
            "latest_partition('foo.%s'|format('bar'))",
            {Table(table="bar", schema="foo")},
        ),
        (
            "latest_sub_partition('foo.bar', baz='qux')",
            {Table(table="bar", schema="foo")},
        ),
        (
            "latest_partition('foo.%s'|format(str('bar')))",
            set(),
        ),
        (
            "latest_partition('foo.{}'.format('bar'))",
            set(),
        ),
    ],
)
def test_extract_tables_from_jinja_sql(
    mocker: MockerFixture,
    engine: str,
    macro: str,
    expected: set[Table],
) -> None:
    """
    Check the tables reported by `process_jinja_sql` for partition macros.

    The macro is rendered as `'{{ <engine>.<macro> }}'` against a mocked database
    whose backend is the engine under test.
    """
    database = mocker.MagicMock(backend=engine)
    templated_sql = f"'{{{{ {engine}.{macro} }}}}'"

    result = process_jinja_sql(sql=templated_sql, database=database)
    assert result.tables == expected
| |
| |
@with_feature_flags(ENABLE_TEMPLATE_PROCESSING=False)
def test_extract_tables_from_jinja_sql_disabled(mocker: MockerFixture) -> None:
    """
    Check `process_jinja_sql` with `ENABLE_TEMPLATE_PROCESSING` turned off.

    Tables are still expected to be extracted from the SQL.
    """
    mssql_database = mocker.MagicMock()
    mssql_database.db_engine_spec.engine = "mssql"

    result = process_jinja_sql(sql="SELECT 1 FROM t", database=mssql_database)
    assert result.tables == {Table("t")}
| |
| |
def test_extract_tables_from_jinja_sql_invalid_function(mocker: MockerFixture) -> None:
    """
    Check `process_jinja_sql` when the template calls a custom function.

    A `my_table` function returning "t" is added to the template processor globals,
    and the table it renders is expected in the extracted tables.
    """
    postgres_database = mocker.MagicMock(backend="postgresql")

    template_processor = JinjaTemplateProcessor(postgres_database)
    template_processor.env.globals["my_table"] = lambda: "t"
    mocker.patch(
        "superset.jinja_context.get_template_processor",
        return_value=template_processor,
    )

    result = process_jinja_sql(
        sql="SELECT * FROM {{ my_table() }}",
        database=postgres_database,
    )
    assert result.tables == {Table("t")}
| |
| |
def test_process_jinja_sql_result_object_structure(mocker: MockerFixture) -> None:
    """
    Check the shape of the object returned by `process_jinja_sql`.

    The result must be a `JinjaSQLResult` exposing a `script` (a `SQLScript`) and
    `tables` (a set of `Table`).
    """
    postgres_database = mocker.MagicMock()
    postgres_database.db_engine_spec.engine = "postgresql"

    outcome = process_jinja_sql(
        sql="SELECT id FROM users WHERE active = true",
        database=postgres_database,
    )

    # the result itself
    assert isinstance(outcome, JinjaSQLResult)

    # the parsed script
    assert hasattr(outcome, "script")
    assert isinstance(outcome.script, SQLScript)

    # the extracted tables
    assert hasattr(outcome, "tables")
    assert isinstance(outcome.tables, set)
    assert outcome.tables == {Table("users")}

    # the formatted script should still contain the original query
    rendered = outcome.script.format()
    assert "users" in rendered
    assert "active = TRUE" in rendered
| |
| |
def test_process_jinja_sql_template_params_parameter(mocker: MockerFixture) -> None:
    """
    Check that `process_jinja_sql` accepts a `template_params` argument.
    """
    mock_db = mocker.MagicMock()
    mock_db.db_engine_spec.engine = "postgresql"

    mocker.patch(
        "superset.jinja_context.get_template_processor",
        return_value=JinjaTemplateProcessor(mock_db),
    )

    params = {"param1": "value1"}
    outcome = process_jinja_sql(
        sql="SELECT * FROM table_name",
        database=mock_db,
        template_params=params,
    )

    # The call must succeed and still report the referenced table.
    assert isinstance(outcome, JinjaSQLResult)
    assert outcome.tables == {Table("table_name")}
| |
| |
@pytest.mark.parametrize(
    ("sql", "engine", "expected"),
    [
        ("SELECT * FROM users", "postgresql", True),
        ("WITH cte AS (SELECT * FROM users) SELECT * FROM cte", "postgresql", True),
        ("CREATE TABLE users AS SELECT * FROM users", "postgresql", False),
        ("ALTER TABLE users ADD COLUMN age INT", "postgresql", False),
        ("SET @value = 42", "postgresql", False),
    ],
)
def test_sqlstatement_is_select(sql: str, engine: str, expected: bool) -> None:
    """
    Check `SQLStatement.is_select()` for queries, DDL, and a SET statement.
    """
    statement = SQLStatement(sql, engine)
    assert statement.is_select() == expected
| |
| |
@pytest.mark.parametrize(
    ("kql", "expected"),
    [
        ("StormEvents | take 10", True),
        ("StormEvents | limit 20", True),
        ("StormEvents | where State == 'FL' | summarize count()", True),
        ("StormEvents | where name has 'limit 10'", True),
        ("AnotherTable | take 5", True),
        ("datatable(x:int) [1, 2, 3] | take 100", True),
        (".create table StormEvents (x:int)", False),
        (".ingest inline into table StormEvents <| StormEvents | take 10", False),
    ],
)
def test_kqlstatement_is_select(kql: str, expected: bool) -> None:
    """
    Check `KustoKQLStatement.is_select()`: plain queries count as selects, while
    the dot-prefixed commands in the cases above do not.
    """
    statement = KustoKQLStatement(kql, "kustokql")
    assert statement.is_select() == expected
| |
| |
def test_remove_quotes() -> None:
    """
    Check `remove_quotes`: matching outer quotes are stripped, while `None` and
    mismatched delimiters come back unchanged.
    """
    assert remove_quotes(None) is None

    cases = {
        '"foo"': "foo",
        "'foo'": "foo",
        "`foo`": "foo",
        # Opening and closing delimiters differ, so nothing is removed.
        "'foo`": "'foo`",
    }
    for quoted, unquoted in cases.items():
        assert remove_quotes(quoted) == unquoted
| |
| |
@pytest.mark.parametrize(
    ("sql", "engine", "expected"),
    [
        ("SELECT * FROM table", "postgresql", False),
        ("SELECT VERSION()", "postgresql", True),
        ("SELECT query_to_xml()", "postgresql", True),
        ("WITH cte AS (SELECT * FROM table) SELECT * FROM cte", "postgresql", False),
        (
            """
SELECT *
FROM query_to_xml('SELECT * from some_table WHERE id = 42')
""",
            "postgresql",
            True,
        ),
        ("Table | limit 10", "kustokql", False),
    ],
)
def test_check_functions_present(sql: str, engine: str, expected: bool) -> None:
    """
    Check `SQLScript.check_functions_present` against a fixed set of function names.
    """
    target_functions = {"version", "query_to_xml"}
    script = SQLScript(sql, engine)
    assert script.check_functions_present(target_functions) == expected
| |
| |
@pytest.mark.parametrize(
    ("kql", "expected"),
    [
        (
            "StormEvents | take 10",
            [
                (KQLTokenType.WORD, "StormEvents"),
                (KQLTokenType.WHITESPACE, " "),
                (KQLTokenType.OTHER, "|"),
                (KQLTokenType.WHITESPACE, " "),
                (KQLTokenType.WORD, "take"),
                (KQLTokenType.WHITESPACE, " "),
                (KQLTokenType.NUMBER, "10"),
            ],
        ),
        ("'test'", [(KQLTokenType.STRING, "'test'")]),
        ("```test```", [(KQLTokenType.STRING, "```test```")]),
    ],
)
def test_tokenize_kql(kql: str, expected: list[tuple[KQLTokenType, str]]) -> None:
    """
    Check that `tokenize_kql` yields the expected (token type, text) pairs.
    """
    tokens = tokenize_kql(kql)
    assert tokens == expected
| |
| |
@pytest.mark.parametrize(
    ("sql", "engine", "expected"),
    [
        ("a = 1", "postgresql", False),
        ("(SELECT * FROM table)", "postgresql", True),
        ("SELECT * FROM table", "postgresql", False),
        ("SELECT * FROM (SELECT 1)", "postgresql", True),
        ("SELECT * FROM (SELECT 1) AS subquery", "postgresql", True),
        ("WITH cte AS (SELECT 1) SELECT * FROM cte", "postgresql", True),
        ("SELECT * FROM table WHERE EXISTS (SELECT 1)", "postgresql", True),
        ("SELECT * FROM table WHERE NOT EXISTS (SELECT 1)", "postgresql", True),
        (
            "SELECT * FROM table WHERE id IN (SELECT id FROM other_table)",
            "postgresql",
            True,
        ),
    ],
)
def test_has_subquery(sql: str, engine: str, expected: bool) -> None:
    """
    Check `SQLStatement.has_subquery()` across the parametrized statements.
    """
    statement = SQLStatement(sql, engine)
    assert statement.has_subquery() == expected