# (source viewer residue) blob: 4911d4c0e6c14d0943bfd1092f5c75784de273aa
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, redefined-outer-name, too-many-lines
import pytest
from sqlglot import Dialects
from superset.exceptions import SupersetParseError
from superset.sql.parse import (
extract_tables_from_statement,
KustoKQLStatement,
split_kql,
SQLGLOT_DIALECTS,
SQLScript,
SQLStatement,
Table,
)
def test_table() -> None:
    """
    Check how `Table` instances are rendered by `str()`.

    Covers table-only, schema-qualified and catalog-qualified names, plus names
    containing ".", "/" and a newline, which are expected to come out
    percent-encoded.
    """
    cases = [
        (Table("tbname"), "tbname"),
        (Table("tbname", "schemaname"), "schemaname.tbname"),
        (
            Table("tbname", "schemaname", "catalogname"),
            "catalogname.schemaname.tbname",
        ),
        (
            Table("table.name", "schema/name", "catalog\nname"),
            "catalog%0Aname.schema%2Fname.table%2Ename",
        ),
    ]
    for table, rendered in cases:
        assert str(table) == rendered
def extract_tables_from_sql(sql: str, engine: str = "postgresql") -> set[Table]:
    """
    Collect the tables referenced by every statement in a SQL script.

    The script is split with `SQLScript(sql, engine)`; each statement's parsed tree
    is handed to `extract_tables_from_statement` together with the dialect looked up
    in `SQLGLOT_DIALECTS` for `engine` (`None` when the engine has no entry).
    """
    dialect = SQLGLOT_DIALECTS.get(engine)
    found: set[Table] = set()
    for statement in SQLScript(sql, engine).statements:
        # `_parsed` is the statement's parsed representation (private attribute).
        found.update(extract_tables_from_statement(statement._parsed, dialect))
    return found
def test_extract_tables_from_sql() -> None:
    """
    Check that simple queries yield the expected set of referenced tables.

    Each case pairs a query with the tables it should report.
    """
    cases = [
        ("SELECT * FROM tbname", {Table("tbname")}),
        ("SELECT * FROM tbname foo", {Table("tbname")}),
        ("SELECT * FROM tbname AS foo", {Table("tbname")}),
        # underscore
        ("SELECT * FROM tb_name", {Table("tb_name")}),
        # quotes
        ('SELECT * FROM "tbname"', {Table("tbname")}),
        # unicode
        (
            'SELECT * FROM "tb_name" WHERE city = "Lübeck"',
            {Table("tb_name")},
        ),
        # columns
        ("SELECT field1, field2 FROM tb_name", {Table("tb_name")}),
        ("SELECT t1.f1, t2.f2 FROM t1, t2", {Table("t1"), Table("t2")}),
        # named table
        (
            "SELECT a.date, a.field FROM left_table a LIMIT 10",
            {Table("left_table")},
        ),
        # subquery alias that shadows the real table name
        (
            "SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;",
            {Table("forbidden_table")},
        ),
        (
            "select * from (select * from forbidden_table) forbidden_table",
            {Table("forbidden_table")},
        ),
    ]
    for sql, expected in cases:
        assert extract_tables_from_sql(sql) == expected
def test_extract_tables_subselect() -> None:
    """
    Check that tables referenced inside subselects are reported.
    """
    # Subselect in `FROM` next to a plain schema-qualified table.
    subselect_and_table = """
SELECT sub.*
FROM (
SELECT *
FROM s1.t1
WHERE day_of_week = 'Friday'
) sub, s2.t2
WHERE sub.resolution = 'NONE'
"""
    assert extract_tables_from_sql(subselect_and_table) == {
        Table("t1", "s1"),
        Table("t2", "s2"),
    }

    # Subselect alone; its alias must not be reported as a table.
    subselect_only = """
SELECT sub.*
FROM (
SELECT *
FROM s1.t1
WHERE day_of_week = 'Friday'
) sub
WHERE sub.resolution = 'NONE'
"""
    assert extract_tables_from_sql(subselect_only) == {Table("t1", "s1")}

    # Several levels of nesting inside `WHERE`.
    nested_where = """
SELECT * FROM t1
WHERE s11 > ANY (
SELECT COUNT(*) /* no hint */ FROM t2
WHERE NOT EXISTS (
SELECT * FROM t3
WHERE ROW(5*t2.s1,77)=(
SELECT 50,11*s1 FROM t4
)
)
)
"""
    assert extract_tables_from_sql(nested_where) == {
        Table("t1"),
        Table("t2"),
        Table("t3"),
        Table("t4"),
    }
def test_extract_tables_select_in_expression() -> None:
    """
    Check table extraction when a `SELECT` is used as an expression in the
    projection, with and without an alias.
    """
    expected = {Table("t1"), Table("t2")}
    queries = (
        "SELECT f1, (SELECT count(1) FROM t2) FROM t1",
        "SELECT f1, (SELECT count(1) FROM t2) as f2 FROM t1",
    )
    for sql in queries:
        assert extract_tables_from_sql(sql) == expected
def test_extract_tables_parenthesis() -> None:
    """
    Check that a parenthesized arithmetic expression does not affect the
    reported tables.
    """
    tables = extract_tables_from_sql("SELECT f1, (x + y) AS f2 FROM t1")
    assert tables == {Table("t1")}
def test_extract_tables_with_schema() -> None:
    """
    Check that schema-qualified names are reported with their schema, whether
    unquoted, quoted, or aliased.
    """
    expected = {Table("tbname", "schemaname")}
    queries = (
        "SELECT * FROM schemaname.tbname",
        'SELECT * FROM "schemaname"."tbname"',
        'SELECT * FROM "schemaname"."tbname" foo',
        'SELECT * FROM "schemaname"."tbname" AS foo',
    )
    for sql in queries:
        assert extract_tables_from_sql(sql) == expected
def test_extract_tables_union() -> None:
    """
    Check that both sides of `UNION`, `UNION ALL` and `INTERSECT ALL` queries
    contribute their tables.
    """
    expected = {Table("t1"), Table("t2")}
    queries = (
        "SELECT * FROM t1 UNION SELECT * FROM t2",
        "SELECT * FROM t1 UNION ALL SELECT * FROM t2",
        "SELECT * FROM t1 INTERSECT ALL SELECT * FROM t2",
    )
    for sql in queries:
        assert extract_tables_from_sql(sql) == expected
def test_extract_tables_select_from_values() -> None:
    """
    Check that selecting from a `VALUES` list reports no tables.
    """
    tables = extract_tables_from_sql("SELECT * FROM VALUES (13, 42)")
    assert tables == set()
def test_extract_tables_select_array() -> None:
    """
    Check that an `ARRAY[...]` literal in the projection does not affect the
    reported tables.
    """
    sql = """
SELECT ARRAY[1, 2, 3] AS my_array
FROM t1 LIMIT 10
"""
    assert extract_tables_from_sql(sql) == {Table("t1")}
def test_extract_tables_select_if() -> None:
    """
    Check that an `IF(...)` expression with array indexing does not affect the
    reported tables.
    """
    sql = """
SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
FROM t1 LIMIT 10
"""
    assert extract_tables_from_sql(sql) == {Table("t1")}
def test_extract_tables_with_catalog() -> None:
    """
    Check that a three-part name is reported with table, schema and catalog.
    """
    tables = extract_tables_from_sql("SELECT * FROM catalogname.schemaname.tbname")
    assert tables == {Table("tbname", "schemaname", "catalogname")}
def test_extract_tables_illdefined() -> None:
    """
    Test that ill-defined table references raise a parse error.

    Dangling qualifiers (`schemaname.`, `catalogname.schemaname.`, `catalogname..`)
    and an unterminated quoted identifier must raise `SupersetParseError` with the
    expected message. A catalog followed by an empty schema and a table name
    (`catalogname..tbname`) is accepted and reported with `schema=None`.
    """
    failing = [
        ("SELECT * FROM schemaname.", "Error parsing near '.' at line 1:25"),
        (
            "SELECT * FROM catalogname.schemaname.",
            "Error parsing near '.' at line 1:37",
        ),
        ("SELECT * FROM catalogname..", "Error parsing near '.' at line 1:27"),
        ('SELECT * FROM "tbname', "Unable to parse script"),
    ]
    for sql, message in failing:
        with pytest.raises(SupersetParseError) as excinfo:
            extract_tables_from_sql(sql)
        # Checked after the `with` block: inside it, code following the raising
        # call would never run.
        assert str(excinfo.value) == message

    # odd edge case that works
    assert extract_tables_from_sql("SELECT * FROM catalogname..tbname") == {
        Table(table="tbname", schema=None, catalog="catalogname")
    }
def test_extract_tables_show_tables_from() -> None:
    """
    Check that `SHOW TABLES FROM` under the MySQL engine reports no tables.
    """
    tables = extract_tables_from_sql("SHOW TABLES FROM s1 like '%order%'", "mysql")
    assert tables == set()
def test_format_show_tables() -> None:
    """
    Check the formatted output of a `SHOW TABLES FROM ... like ...` script under
    the MySQL engine.

    Per the original author's note, this targets the case where `ast.sql()` raises
    an exception and sqlparse is used for formatting instead.
    """
    script = SQLScript("SHOW TABLES FROM s1 like '%order%'", "mysql")
    formatted = script.format()
    assert formatted == "SHOW TABLES FROM s1 LIKE '%order%'"
def test_format_no_dialect() -> None:
    """
    Check formatting for an engine ("firebolt") that, per the original author's
    note, has no corresponding dialect.
    """
    script = SQLScript("SELECT col FROM t WHERE col NOT IN (1, 2)", "firebolt")
    formatted = script.format()
    # NOTE(review): whitespace in this file looks collapsed; the run of spaces
    # after the last "\n" may originally have been longer -- confirm against the
    # formatter's actual output.
    assert formatted == "SELECT col\nFROM t\nWHERE col NOT IN (1,\n 2)"
def test_split_no_dialect() -> None:
    """
    Check that a script is split into statements for an engine ("firebolt") that,
    per the original author's note, has no corresponding dialect.
    """
    sql = "SELECT col FROM t WHERE col NOT IN (1, 2); SELECT * FROM t; SELECT foo"
    parts = SQLScript(sql, "firebolt").statements
    assert len(parts) == 3

    # Each statement keeps its original text, without the separator.
    expected_sql = (
        "SELECT col FROM t WHERE col NOT IN (1, 2)",
        "SELECT * FROM t",
        "SELECT foo",
    )
    for statement, text in zip(parts, expected_sql):
        assert statement._sql == text
def test_extract_tables_show_columns_from() -> None:
    """
    Check that `SHOW COLUMNS FROM` reports the named table.
    """
    tables = extract_tables_from_sql("SHOW COLUMNS FROM t1")
    assert tables == {Table("t1")}
def test_extract_tables_where_subquery() -> None:
    """
    Check that tables inside a `WHERE` subquery are reported, for scalar
    comparison, `IN` and `EXISTS` forms.
    """
    expected = {Table("t1"), Table("t2")}
    queries = (
        """
SELECT name
FROM t1
WHERE regionkey = (SELECT max(regionkey) FROM t2)
""",
        """
SELECT name
FROM t1
WHERE regionkey IN (SELECT regionkey FROM t2)
""",
        """
SELECT name
FROM t1
WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.regionkey = t2.regionkey);
""",
    )
    for sql in queries:
        assert extract_tables_from_sql(sql) == expected
def test_extract_tables_describe() -> None:
    """
    Check that `DESCRIBE` reports the described table.
    """
    tables = extract_tables_from_sql("DESCRIBE t1")
    assert tables == {Table("t1")}
def test_extract_tables_show_partitions() -> None:
    """
    Check that `SHOW PARTITIONS FROM` reports the named table.
    """
    sql = """
SHOW PARTITIONS FROM orders
WHERE ds >= '2013-01-01' ORDER BY ds DESC
"""
    assert extract_tables_from_sql(sql) == {Table("orders")}
def test_extract_tables_join() -> None:
    """
    Check that tables on both sides of a join are reported, including when the
    right side is a subselect.
    """
    # Plain join between two tables.
    simple = extract_tables_from_sql(
        "SELECT t1.*, t2.* FROM t1 JOIN t2 ON t1.a = t2.a;"
    )
    assert simple == {Table("t1"), Table("t2")}

    # Same shape of query with different join kinds against a subselect.
    expected = {Table("left_table"), Table("right_table")}
    subselect_joins = (
        """
SELECT a.date, b.name
FROM left_table a
JOIN (
SELECT
CAST((b.year) as VARCHAR) date,
name
FROM right_table
) b
ON a.date = b.date
""",
        """
SELECT a.date, b.name
FROM left_table a
LEFT INNER JOIN (
SELECT
CAST((b.year) as VARCHAR) date,
name
FROM right_table
) b
ON a.date = b.date
""",
        """
SELECT a.date, b.name
FROM left_table a
RIGHT OUTER JOIN (
SELECT
CAST((b.year) as VARCHAR) date,
name
FROM right_table
) b
ON a.date = b.date
""",
        """
SELECT a.date, b.name
FROM left_table a
FULL OUTER JOIN (
SELECT
CAST((b.year) as VARCHAR) date,
name
FROM right_table
) b
ON a.date = b.date
""",
    )
    for sql in subselect_joins:
        assert extract_tables_from_sql(sql) == expected
def test_extract_tables_semi_join() -> None:
    """
    Check that `LEFT SEMI JOIN` against a subselect reports both tables.
    """
    sql = """
SELECT a.date, b.name
FROM left_table a
LEFT SEMI JOIN (
SELECT
CAST((b.year) as VARCHAR) date,
name
FROM right_table
) b
ON a.data = b.date
"""
    tables = extract_tables_from_sql(sql)
    assert tables == {Table("left_table"), Table("right_table")}
def test_extract_tables_combinations() -> None:
    """
    Check table extraction on queries that combine several nesting constructs.
    """
    # `ANY` subquery containing a union, a joined subselect and `NOT EXISTS`.
    mixed = """
SELECT * FROM t1
WHERE s11 > ANY (
SELECT * FROM t1 UNION ALL SELECT * FROM (
SELECT t6.*, t3.* FROM t6 JOIN t3 ON t6.a = t3.a
) tmp_join
WHERE NOT EXISTS (
SELECT * FROM t3
WHERE ROW(5*t3.s1,77)=(
SELECT 50,11*s1 FROM t4
)
)
)
"""
    assert extract_tables_from_sql(mixed) == {
        Table("t1"),
        Table("t3"),
        Table("t4"),
        Table("t6"),
    }

    # Three levels of aliased subselects around a single table.
    deeply_nested = """
SELECT * FROM (
SELECT * FROM (
SELECT * FROM (
SELECT * FROM EmployeeS
) AS S1
) AS S2
) AS S3
"""
    assert extract_tables_from_sql(deeply_nested) == {Table("EmployeeS")}
def test_extract_tables_with() -> None:
    """
    Check table extraction for `WITH` queries: CTE names are not reported, only
    the tables the CTEs read from.
    """
    # Independent CTEs, each reading a different table.
    independent = """
WITH
x AS (SELECT a FROM t1),
y AS (SELECT a AS b FROM t2),
z AS (SELECT b AS c FROM t3)
SELECT c FROM z
"""
    assert extract_tables_from_sql(independent) == {
        Table("t1"),
        Table("t2"),
        Table("t3"),
    }

    # Chained CTEs: only the first one touches a real table.
    chained = """
WITH
x AS (SELECT a FROM t1),
y AS (SELECT a AS b FROM x),
z AS (SELECT b AS c FROM y)
SELECT c FROM z
"""
    assert extract_tables_from_sql(chained) == {Table("t1")}
def test_extract_tables_reusing_aliases() -> None:
    """
    Check that CTE names are followed to the underlying tables, including a
    forward reference and a circular definition.
    """
    # `q1` refers to `q2`, which is defined after it.
    forward_reference = """
with q1 as ( select key from q2 where key = '5'),
q2 as ( select key from src where key = '5')
select * from (select key from q1) a
"""
    assert extract_tables_from_sql(forward_reference) == {Table("src")}

    # weird query with circular dependency
    circular = """
with src as ( select key from q2 where key = '5'),
q2 as ( select key from src where key = '5')
select * from (select key from src) a
"""
    assert extract_tables_from_sql(circular) == set()
def test_extract_tables_multistatement() -> None:
    """
    Check that tables are collected across all statements of a script.
    """
    both = {Table("t1"), Table("t2")}
    # With and without a trailing semicolon.
    for sql in (
        "SELECT * FROM t1; SELECT * FROM t2",
        "SELECT * FROM t1; SELECT * FROM t2;",
    ):
        assert extract_tables_from_sql(sql) == both

    # A non-query Hive statement followed by a query.
    hive_tables = extract_tables_from_sql(
        "ADD JAR file:///hive.jar; SELECT * FROM t1;",
        engine="hive",
    )
    assert hive_tables == {Table("t1")}
def test_extract_tables_complex() -> None:
    """
    Check table extraction on a few larger queries.
    """
    # Aggregation over a subselect with joins and an `IN` subquery.
    aggregated = """
SELECT sum(m_examples) AS "sum__m_example"
FROM (
SELECT
COUNT(DISTINCT id_userid) AS m_examples,
some_more_info
FROM my_b_table b
JOIN my_t_table t ON b.ds=t.ds
JOIN my_l_table l ON b.uid=l.uid
WHERE
b.rid IN (
SELECT other_col
FROM inner_table
)
AND l.bla IN ('x', 'y')
GROUP BY 2
ORDER BY 2 ASC
) AS "meh"
ORDER BY "sum__m_example" DESC
LIMIT 10;
"""
    assert extract_tables_from_sql(aggregated) == {
        Table("my_l_table"),
        Table("my_b_table"),
        Table("my_t_table"),
        Table("inner_table"),
    }

    # Comma-separated aliased tables.
    comma_join = """
SELECT *
FROM table_a AS a, table_b AS b, table_c as c
WHERE a.id = b.id and b.id = c.id
"""
    assert extract_tables_from_sql(comma_join) == {
        Table("table_a"),
        Table("table_b"),
        Table("table_c"),
    }

    # CTEs declared inside a subselect, followed by chained unions.
    nested_ctes = """
SELECT somecol AS somecol
FROM (
WITH bla AS (
SELECT col_a
FROM a
WHERE
1=1
AND column_of_choice NOT IN (
SELECT interesting_col
FROM b
)
),
rb AS (
SELECT yet_another_column
FROM (
SELECT a
FROM c
GROUP BY the_other_col
) not_table
LEFT JOIN bla foo
ON foo.prop = not_table.bad_col0
WHERE 1=1
GROUP BY
not_table.bad_col1 ,
not_table.bad_col2 ,
ORDER BY not_table.bad_col_3 DESC ,
not_table.bad_col4 ,
not_table.bad_col5
)
SELECT random_col
FROM d
WHERE 1=1
UNION ALL SELECT even_more_cols
FROM e
WHERE 1=1
UNION ALL SELECT lets_go_deeper
FROM f
WHERE 1=1
WHERE 2=2
GROUP BY last_col
LIMIT 50000
)
"""
    assert extract_tables_from_sql(nested_ctes) == {
        Table("a"),
        Table("b"),
        Table("c"),
        Table("d"),
        Table("e"),
        Table("f"),
    }
def test_extract_tables_mixed_from_clause() -> None:
    """
    Check a `FROM` clause that mixes plain tables with a subselect.
    """
    sql = """
SELECT *
FROM table_a AS a, (select * from table_b) AS b, table_c as c
WHERE a.id = b.id and b.id = c.id
"""
    tables = extract_tables_from_sql(sql)
    assert tables == {Table("table_a"), Table("table_b"), Table("table_c")}
def test_extract_tables_nested_select() -> None:
    """
    Check that a `SELECT` nested inside function calls is still inspected
    (MySQL engine).
    """
    expected = {Table("COLUMNS", "INFORMATION_SCHEMA")}
    queries = (
        """
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
""",
        """
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_NAME="bi_achievement_daily"),0x7e)));
""",
    )
    for sql in queries:
        assert extract_tables_from_sql(sql, "mysql") == expected
def test_extract_tables_complex_cte_with_prefix() -> None:
    """
    Check a CTE declared with an explicit column list: only the table the CTE
    reads from is reported.
    """
    sql = """
WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear)
AS (
SELECT SalesPersonID, SalesOrderID, YEAR(OrderDate) AS SalesYear
FROM SalesOrderHeader
WHERE SalesPersonID IS NOT NULL
)
SELECT SalesPersonID, COUNT(SalesOrderID) AS TotalSales, SalesYear
FROM CTE__test
GROUP BY SalesYear, SalesPersonID
ORDER BY SalesPersonID, SalesYear;
"""
    tables = extract_tables_from_sql(sql)
    assert tables == {Table("SalesOrderHeader")}
def test_extract_tables_identifier_list_with_keyword_as_alias() -> None:
    """
    Check that a CTE named after a keyword (`match`) is handled like any other
    CTE name.
    """
    sql = """
WITH
f AS (SELECT * FROM foo),
match AS (SELECT * FROM f)
SELECT * FROM match
"""
    tables = extract_tables_from_sql(sql)
    assert tables == {Table("foo")}
def test_sqlscript() -> None:
    """
    Exercise `SQLScript`: statement splitting, formatting and `get_settings`.
    """
    two_selects = SQLScript("SELECT 1; SELECT 2;", "sqlite")
    assert len(two_selects.statements) == 2
    # NOTE(review): whitespace in this file looks collapsed; the indentation after
    # each "\n" in these expected strings may originally have been wider --
    # confirm against the formatter's actual output.
    assert two_selects.format() == "SELECT\n 1;\nSELECT\n 2"
    assert two_selects.statements[0].format() == "SELECT\n 1"

    # When the same setting is assigned twice, the later value is reported.
    with_settings = SQLScript("SET a=1; SET a=2; SELECT 3;", "sqlite")
    assert with_settings.get_settings() == {"a": "2"}

    # A KQL `set` without a value is reported as `True`.
    kql_script = SQLScript(
        """set querytrace;
Events | take 100""",
        "kustokql",
    )
    assert kql_script.get_settings() == {"querytrace": True}
def test_sqlstatement() -> None:
    """
    Exercise `SQLStatement`: the `tables` attribute, `format` and `get_settings`.
    """
    union = SQLStatement(
        "SELECT * FROM table1 UNION ALL SELECT * FROM table2",
        "sqlite",
    )
    expected_tables = {
        Table(table="table1", schema=None, catalog=None),
        Table(table="table2", schema=None, catalog=None),
    }
    assert union.tables == expected_tables

    # NOTE(review): whitespace in this file looks collapsed; the indentation after
    # "SELECT\n" may originally have been wider -- confirm against the formatter's
    # actual output.
    expected_format = "SELECT\n *\nFROM table1\nUNION ALL\nSELECT\n *\nFROM table2"
    assert union.format() == expected_format

    setter = SQLStatement("SET a=1", "sqlite")
    assert setter.get_settings() == {"a": "1"}
def test_kustokqlstatement_split_script() -> None:
    """
    Check that `KustoKQLStatement.split_script` splits a script made of three
    `let` statements and a final query into four statements.
    """
    kql = """
let totalPagesPerDay = PageViews
| summarize by Page, Day = startofday(Timestamp)
| summarize count() by Day;
let materializedScope = PageViews
| summarize by Page, Day = startofday(Timestamp);
let cachedResult = materialize(materializedScope);
cachedResult
| project Page, Day1 = Day
| join kind = inner
(
cachedResult
| project Page, Day2 = Day
)
on Page
| where Day2 > Day1
| summarize count() by Day1, Day2
| join kind = inner
totalPagesPerDay
on $left.Day1 == $right.Day
| project Day1, Day2, Percentage = count_*100.0/count_1
"""
    parts = KustoKQLStatement.split_script(kql, "kustokql")
    assert len(parts) == 4
def test_kustokqlstatement_with_program() -> None:
    """
    Check that semicolons inside a triple-backtick block do not split the KQL
    script: the whole input is a single statement.
    """
    kql = """
print program = ```
public class Program {
public static void Main() {
System.Console.WriteLine("Hello!");
}
}```
"""
    parts = KustoKQLStatement.split_script(kql, "kustokql")
    assert len(parts) == 1
def test_kustokqlstatement_with_set() -> None:
    """
    Check that a KQL script with a `set` command followed by a query is split
    into two statements, each formatted without surrounding whitespace or the
    separator.
    """
    kql = """
set querytrace;
Events | take 100
"""
    parts = KustoKQLStatement.split_script(kql, "kustokql")
    assert len(parts) == 2

    set_command, query = parts[0], parts[1]
    assert set_command.format() == "set querytrace"
    assert query.format() == "Events | take 100"
@pytest.mark.parametrize(
    "kql,statements",
    [
        ('print banner=strcat("Hello", ", ", "World!")', 1),
        (r"print 'O\'Malley\'s'", 1),
        (r"print 'O\'Mal;ley\'s'", 1),
        ("print ```foo;\nbar;\nbaz;```\n", 1),
    ],
)
def test_kustokql_statement_split_special(kql: str, statements: int) -> None:
    """
    Check the number of statements produced by `KustoKQLStatement.split_script`
    for inputs with quoted strings, escaped quotes, and semicolons inside string
    or triple-backtick literals.
    """
    parts = KustoKQLStatement.split_script(kql, "kustokql")
    assert len(parts) == statements
def test_split_kql() -> None:
    """
    Test the `split_kql` function.

    The input holds three `let` statements and a final query, separated by
    semicolons. The expected pieces are the text between the separators,
    unchanged: each keeps its leading newline, the separators are dropped, and
    the last piece keeps its trailing newline.
    """
    # The literals below are compared character for character, so their lines
    # must not be re-indented independently of the expected values.
    kql = """
let totalPagesPerDay = PageViews
| summarize by Page, Day = startofday(Timestamp)
| summarize count() by Day;
let materializedScope = PageViews
| summarize by Page, Day = startofday(Timestamp);
let cachedResult = materialize(materializedScope);
cachedResult
| project Page, Day1 = Day
| join kind = inner
(
cachedResult
| project Page, Day2 = Day
)
on Page
| where Day2 > Day1
| summarize count() by Day1, Day2
| join kind = inner
totalPagesPerDay
on $left.Day1 == $right.Day
| project Day1, Day2, Percentage = count_*100.0/count_1
"""
    assert split_kql(kql) == [
        # first `let`
        """
let totalPagesPerDay = PageViews
| summarize by Page, Day = startofday(Timestamp)
| summarize count() by Day""",
        # second `let`
        """
let materializedScope = PageViews
| summarize by Page, Day = startofday(Timestamp)""",
        # third `let`
        """
let cachedResult = materialize(materializedScope)""",
        # final query, including the trailing newline of the input
        """
cachedResult
| project Page, Day1 = Day
| join kind = inner
(
cachedResult
| project Page, Day2 = Day
)
on Page
| where Day2 > Day1
| summarize count() by Day1, Day2
| join kind = inner
totalPagesPerDay
on $left.Day1 == $right.Day
| project Day1, Day2, Percentage = count_*100.0/count_1
""",
    ]
@pytest.mark.parametrize(
    ("engine", "sql", "expected"),
    [
        ("sqlite", "SELECT 1", False),
        ("sqlite", "INSERT INTO foo VALUES (1)", True),
        ("sqlite", "UPDATE foo SET bar = 2 WHERE id = 1", True),
        ("sqlite", "DELETE FROM foo WHERE id = 1", True),
        ("sqlite", "CREATE TABLE foo (id INT, bar TEXT)", True),
        ("sqlite", "DROP TABLE foo", True),
        ("sqlite", "EXPLAIN SELECT * FROM foo", False),
        ("sqlite", "PRAGMA table_info(foo)", False),
        ("postgresql", "SELECT 1", False),
        ("postgresql", "INSERT INTO foo (id, bar) VALUES (1, 'test')", True),
        ("postgresql", "UPDATE foo SET bar = 'new' WHERE id = 1", True),
        ("postgresql", "DELETE FROM foo WHERE id = 1", True),
        ("postgresql", "CREATE TABLE foo (id SERIAL PRIMARY KEY, bar TEXT)", True),
        ("postgresql", "DROP TABLE foo", True),
        ("postgresql", "EXPLAIN ANALYZE SELECT * FROM foo", False),
        ("postgresql", "EXPLAIN ANALYZE DELETE FROM foo", True),
        ("postgresql", "SHOW search_path", False),
        ("postgresql", "SET search_path TO public", False),
        (
            "postgres",
            """
with source as (
select 1 as one
)
select * from source
""",
            False,
        ),
        ("trino", "SELECT 1", False),
        ("trino", "INSERT INTO foo VALUES (1, 'bar')", True),
        ("trino", "UPDATE foo SET bar = 'baz' WHERE id = 1", True),
        ("trino", "DELETE FROM foo WHERE id = 1", True),
        ("trino", "CREATE TABLE foo (id INT, bar VARCHAR)", True),
        ("trino", "DROP TABLE foo", True),
        ("trino", "EXPLAIN SELECT * FROM foo", False),
        ("trino", "SHOW SCHEMAS", False),
        ("trino", "SET SESSION optimization_level = '3'", False),
        ("kustokql", "tbl | limit 100", False),
        ("kustokql", "let foo = 1; tbl | where bar == foo", False),
        ("kustokql", ".show tables", False),
        ("kustokql", "print 1", False),
        ("kustokql", "set querytrace; Events | take 100", False),
        ("kustokql", ".drop table foo", True),
        ("kustokql", ".set-or-append table foo <| bar", True),
        ("base", "SHOW LOCKS test EXTENDED", False),
        ("base", "SET hivevar:desc='Legislators'", False),
        ("base", "UPDATE t1 SET col1 = NULL", True),
        ("base", "EXPLAIN SELECT 1", False),
        ("base", "SELECT 1", False),
        ("base", "WITH bla AS (SELECT 1) SELECT * FROM bla", False),
        ("base", "SHOW CATALOGS", False),
        ("base", "SHOW TABLES", False),
        ("hive", "UPDATE t1 SET col1 = NULL", True),
        ("hive", "INSERT OVERWRITE TABLE tabB SELECT a.Age FROM TableA", True),
        ("hive", "SHOW LOCKS test EXTENDED", False),
        ("hive", "SET hivevar:desc='Legislators'", False),
        ("hive", "EXPLAIN SELECT 1", False),
        ("hive", "SELECT 1", False),
        ("hive", "WITH bla AS (SELECT 1) SELECT * FROM bla", False),
        ("presto", "SET hivevar:desc='Legislators'", False),
        ("presto", "UPDATE t1 SET col1 = NULL", True),
        ("presto", "INSERT OVERWRITE TABLE tabB SELECT a.Age FROM TableA", True),
        ("presto", "SHOW LOCKS test EXTENDED", False),
        ("presto", "EXPLAIN SELECT 1", False),
        ("presto", "SELECT 1", False),
        ("presto", "WITH bla AS (SELECT 1) SELECT * FROM bla", False),
    ],
)
def test_has_mutation(engine: str, sql: str, expected: bool) -> None:
    """
    Check `SQLScript.has_mutation` for each engine/script pair.

    `expected` is `True` for scripts that modify data or schema (including
    `EXPLAIN ANALYZE` of a `DELETE` on PostgreSQL) and `False` for reads, `SET`,
    `SHOW`, `EXPLAIN` and similar.
    """
    script = SQLScript(sql, engine)
    assert script.has_mutation() == expected
def test_get_settings() -> None:
    """
    Check `get_settings` on a `set` statement spread over several lines with
    interleaved `--` comments.
    """
    sql = """
set
-- this is a tricky comment
search_path -- another one
= bar;
SELECT * FROM some_table;
"""
    settings = SQLScript(sql, "postgresql").get_settings()
    assert settings == {"search_path": "bar"}
@pytest.mark.parametrize(
    "app",
    [{"SQLGLOT_DIALECTS_EXTENSIONS": {"custom": Dialects.MYSQL}}],
    indirect=True,
)
def test_custom_dialect(app: None) -> None:
    """
    Check that a dialect registered through `SQLGLOT_DIALECTS_EXTENSIONS` is
    visible in `SQLGLOT_DIALECTS`.

    The `app` fixture is parametrized indirectly with the extension mapping; it is
    requested only for its setup and is not used in the body.
    """
    registered = SQLGLOT_DIALECTS.get("custom")
    assert registered == Dialects.MYSQL
@pytest.mark.parametrize(
    "engine",
    [
        "ascend",
        "awsathena",
        "base",
        "bigquery",
        "clickhouse",
        "clickhousedb",
        "cockroachdb",
        "couchbase",
        "crate",
        "databend",
        "databricks",
        "db2",
        "denodo",
        "dremio",
        "drill",
        "druid",
        "duckdb",
        "dynamodb",
        "elasticsearch",
        "exa",
        "firebird",
        "firebolt",
        "gsheets",
        "hana",
        "hive",
        "ibmi",
        "impala",
        "kustokql",
        "kustosql",
        "kylin",
        "mariadb",
        "motherduck",
        "mssql",
        "mysql",
        "netezza",
        "oceanbase",
        "ocient",
        "odelasticsearch",
        "oracle",
        "pinot",
        "postgresql",
        "presto",
        "pydoris",
        "redshift",
        "risingwave",
        "rockset",
        "shillelagh",
        "snowflake",
        "solr",
        "sqlite",
        "starrocks",
        "superset",
        "teradatasql",
        "trino",
        "vertica",
    ],
)
def test_is_mutating(engine: str) -> None:
    """
    Check, for every listed engine, that a `SELECT` reading from a CTE is not
    reported as mutating by `SQLStatement.is_mutating`.
    """
    statement = SQLStatement(
        "with source as ( select 1 as one ) select * from source",
        engine=engine,
    )
    assert not statement.is_mutating()