WIP
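
This is a work-in-progress refactor of SQL Lab execution: `execute_sql_statement` moves from `ParsedQuery` to the sqlglot-based `SQLStatement` API, which gains `apply_rls()`, `apply_limit()`, `is_dml()`, `is_select()`, and `as_create_table()`. RLS is injected via two sqlglot transformers (`RLSAsSubquery` and `RLSAsPredicate`), and the MSSQL and Teradata engine specs switch from `WRAP_SQL` to `FORCE_LIMIT`.

A simplified sketch of the new flow in `execute_sql_statement` (condensed from the diff below; details may still change while this is WIP):

    parsed_statement = SQLStatement(sql_statement, engine=db_engine_spec.engine)

    if is_feature_enabled("RLS_IN_SQLLAB"):
        default_schema = database.get_default_schema_for_query(query)
        parsed_statement = parsed_statement.apply_rls(
            database, query.catalog, default_schema
        )

    if query.limit:
        # fetch one extra row so we can tell whether more rows exist
        parsed_statement = parsed_statement.apply_limit(query.limit + 1)

    sql = parsed_statement.format(strip=True)
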
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index c38a008..05be7d3 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -172,7 +172,9 @@
PHYSICAL = "physical"
-class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable=too-many-public-methods
+class BaseDatasource(
+ AuditMixinNullable, ImportExportMixin
+): # pylint: disable=too-many-public-methods
"""A common interface to objects that are queryable
(tables and datasources)"""
diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py
index d5cc86c..3521dee 100644
--- a/superset/db_engine_specs/mssql.py
+++ b/superset/db_engine_specs/mssql.py
@@ -49,7 +49,7 @@
class MssqlEngineSpec(BaseEngineSpec):
engine = "mssql"
engine_name = "Microsoft SQL Server"
- limit_method = LimitMethod.WRAP_SQL
+ limit_method = LimitMethod.FORCE_LIMIT
max_column_name_length = 128
allows_cte_in_subquery = False
allow_limit_clause = False
diff --git a/superset/db_engine_specs/teradata.py b/superset/db_engine_specs/teradata.py
index 887add2..910ac94 100644
--- a/superset/db_engine_specs/teradata.py
+++ b/superset/db_engine_specs/teradata.py
@@ -23,7 +23,7 @@
engine = "teradatasql"
engine_name = "Teradata"
- limit_method = LimitMethod.WRAP_SQL
+ limit_method = LimitMethod.FORCE_LIMIT
max_column_name_length = 30 # since 14.10 this is 128
allow_limit_clause = False
select_keywords = {"SELECT", "SEL"}
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index b841426..cd8b8d8 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -17,6 +17,8 @@
# pylint: disable=too-many-lines
"""a collection of model-related helper classes and functions"""
+from __future__ import annotations
+
import builtins
import dataclasses
import logging
@@ -806,7 +808,7 @@
def get_sqla_row_level_filters(
self,
- template_processor: Optional[BaseTemplateProcessor] = None,
+ template_processor: BaseTemplateProcessor | None = None,
) -> list[TextClause]:
"""
Return the appropriate row level security filters for this table and the
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index f20bff3..f032ce1 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -53,9 +53,8 @@
from superset.result_set import SupersetResultSet
from superset.sql_parse import (
CtasMethod,
- insert_rls_as_subquery,
- insert_rls_in_predicate,
ParsedQuery,
+ SQLStatement,
Table,
)
from superset.sqllab.limiting_factor import LimitingFactor
@@ -205,67 +204,49 @@
database: Database = query.database
db_engine_spec = database.db_engine_spec
- parsed_query = ParsedQuery(sql_statement, engine=db_engine_spec.engine)
+ parsed_statement = SQLStatement(sql_statement, engine=db_engine_spec.engine)
+
if is_feature_enabled("RLS_IN_SQLLAB"):
- # There are two ways to insert RLS: either replacing the table with a subquery
- # that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is
- # safer, but not supported in all databases.
- insert_rls = (
- insert_rls_as_subquery
- if database.db_engine_spec.allows_subqueries
- and database.db_engine_spec.allows_alias_in_select
- else insert_rls_in_predicate
- )
+ default_schema = database.get_default_schema_for_query(query)
+ parsed_statement = parsed_statement.apply_rls(
+ database, query.catalog, default_schema
+ )
- # Insert any applicable RLS predicates
- parsed_query = ParsedQuery(
- str(
- insert_rls(
- parsed_query._parsed[0], # pylint: disable=protected-access
- database.id,
- query.schema,
- )
- ),
- engine=db_engine_spec.engine,
- )
-
- sql = parsed_query.stripped()
-
- # This is a test to see if the query is being
- # limited by either the dropdown or the sql.
- # We are testing to see if more rows exist than the limit.
- increased_limit = None if query.limit is None else query.limit + 1
-
- if not db_engine_spec.is_readonly_query(parsed_query) and not database.allow_dml:
+ if parsed_statement.is_dml() and not database.allow_dml:
raise SupersetErrorException(
SupersetError(
- message=__("Only SELECT statements are allowed against this database."),
+ message=__("DML statements are not allowed in this database."),
error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR,
level=ErrorLevel.ERROR,
)
)
+
if apply_ctas:
if not query.tmp_table_name:
start_dttm = datetime.fromtimestamp(query.start_time)
query.tmp_table_name = (
f'tmp_{query.user_id}_table_{start_dttm.strftime("%Y_%m_%d_%H_%M_%S")}'
)
- sql = parsed_query.as_create_table(
- query.tmp_table_name,
- schema_name=query.tmp_schema_name,
+ parsed_statement = parsed_statement.as_create_table(
+ Table(query.tmp_table_name, query.tmp_schema_name, query.catalog),
method=query.ctas_method,
)
query.select_as_cta_used = True
+ increased_limit = None if query.limit is None else query.limit + 1
+
# Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true
- if db_engine_spec.is_select_query(parsed_query) and not (
+ if parsed_statement.is_select() and not (
query.select_as_cta_used and SQLLAB_CTAS_NO_LIMIT
):
if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW):
query.limit = SQL_MAX_ROW
- sql = apply_limit_if_exists(database, increased_limit, query, sql)
+
+ if query.limit:
+ # Increase limit by one so we can test if there are more rows when the
+ # database returns exactly the number of rows requested by the user.
+ parsed_statement = parsed_statement.apply_limit(increased_limit)
+ sql = parsed_statement.format(strip=True)
+
# Hook to allow environment-specific mutation (usually comments) to the SQL
sql = database.mutate_sql_based_on_config(sql)
try:
query.executed_sql = sql
@@ -333,19 +314,6 @@
return SupersetResultSet(data, cursor_description, db_engine_spec)
-def apply_limit_if_exists(
- database: Database, increased_limit: Optional[int], query: Query, sql: str
-) -> str:
- if query.limit and increased_limit:
- # We are fetching one more than the requested limit in order
- # to test whether there are more rows than the limit. According to the DB
- # Engine support it will choose top or limit parse
- # Later, the extra row will be dropped before sending
- # the results back to the user.
- sql = database.apply_limit_to_sql(sql, increased_limit, force=True)
- return sql
-
-
def _serialize_payload(
payload: dict[Any, Any], use_msgpack: Optional[bool] = False
) -> Union[bytes, str]:
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index 8046e8f..4a041f5 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -32,7 +32,7 @@
from flask_babel import gettext as __
from jinja2 import nodes
from sqlalchemy import and_
-from sqlglot import exp, parse, parse_one
+from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, Dialects
from sqlglot.errors import ParseError, SqlglotError
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
@@ -308,7 +308,7 @@
return set()
try:
- pseudo_query = parse_one(f"SELECT {literal.this}", dialect=dialect)
+ pseudo_query = sqlglot.parse_one(f"SELECT {literal.this}", dialect=dialect)
except ParseError:
return set()
sources = pseudo_query.find_all(exp.Table)
@@ -433,7 +433,7 @@
"""
raise NotImplementedError()
- def format(self, comments: bool = True) -> str:
+ def format(self, comments: bool = True, strip: bool = False) -> str:
"""
Format the statement, optionally ommitting comments.
"""
@@ -451,10 +451,93 @@
"""
raise NotImplementedError()
+ def apply_rls(
+ self,
+ database: Database,
+ catalog: str | None,
+ schema: str | None,
+ ) -> InternalRepresentation:
+ """
+ Apply Row Level Security to the SQL.
+
+ :param database: The database where the SQL will run
+ :param catalog: The default catalog for non-qualified table names
+ :param schema: The default schema for non-qualified table names
+ :return: The SQL with RLS applied
+ """
+ raise NotImplementedError()
+
def __str__(self) -> str:
return self.format()
+class RLSAsPredicate:
+ """
+ Apply Row Level Security rules as predicates.
+
+ This transformer will apply any RLS predicates to the relevant tables. For example,
+ given the RLS rule:
+
+ table: some_table
+ clause: id = 42
+
+ If a user subject to the rule runs the following query:
+
+ SELECT foo FROM some_table WHERE bar = 'baz'
+
+ The query will be modified to:
+
+ SELECT foo FROM some_table WHERE bar = 'baz' AND id = 42
+
+ This approach is probably less secure than using subqueries, so it's only used for
+ databases without support for subqueries.
+ """
+
+ def __init__(self, rules: dict[Table, str]) -> None:
+ self.rules = rules
+
+ def __call__(self, node: exp.Expression) -> exp.Expression:
+ if not isinstance(node, exp.Select):
+ return node
+
+ table_node = node.find(exp.Table)
+ if not table_node:
+ return node
+
+ table = Table(
+ str(table_node.this),
+ str(table_node.db) if table_node.db else None,
+ str(table_node.catalog) if table_node.catalog else None,
+ )
+ if predicate := self.rules.get(table):
+ if where := node.args.get("where"):
+ predicate = exp.And(this=predicate, expression=where.this)
+
+ node.set("where", exp.Where(this=predicate))
+
+ return node
+
+
+class RLSAsSubquery:
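+ """
+ Apply Row Level Security rules by replacing tables with subqueries.
+
+ This transformer wraps each table that has an RLS rule in a subquery that
+ applies the rule. For example, given the RLS rule:
+
+ table: some_table
+ clause: id = 42
+
+ If a user subject to the rule runs the following query:
+
+ SELECT t.foo FROM some_table AS t WHERE bar = 'baz'
+
+ The query will be modified to:
+
+ SELECT t.foo FROM (SELECT * FROM some_table WHERE id = 42) AS t
+ WHERE bar = 'baz'
+
+ This is the safer approach, used whenever the database supports subqueries.
+ """
+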
+ def __init__(self, rules: dict[Table, str]) -> None:
+ self.rules = rules
+
+ def __call__(self, node: exp.Expression) -> exp.Expression:
+ if not isinstance(node, exp.Table):
+ return node
+
+ table = Table(
+ str(node.this),
+ str(node.db) if node.db else None,
+ str(node.catalog) if node.catalog else None,
+ )
+ if predicate := self.rules.get(table):
+ alias = node.alias
+ node.set("alias", None)
+ return f"(SELECT * FROM {node} WHERE {predicate}) AS {alias}"
+
+ return node
+
+
class SQLStatement(BaseSQLStatement[exp.Expression]):
"""
A SQL statement.
@@ -521,12 +604,19 @@
dialect = SQLGLOT_DIALECTS.get(engine)
return extract_tables_from_statement(parsed, dialect)
- def format(self, comments: bool = True) -> str:
+ def format(self, comments: bool = True, strip: bool = False) -> str:
"""
Pretty-format the SQL statement.
"""
write = Dialect.get_or_raise(self._dialect)
- return write.generate(self._parsed, copy=False, comments=comments, pretty=True)
+ output = write.generate(
+ self._parsed,
+ copy=False,
+ comments=comments,
+ pretty=True,
+ )
+
+ return output.strip(" \t\r\n;") if strip else output
def get_settings(self) -> dict[str, str | bool]:
"""
@@ -543,6 +633,186 @@
for eq in set_item.find_all(exp.EQ)
}
+ def apply_rls(
+ self,
+ database: Database,
+ catalog: str | None,
+ schema: str | None,
+ ) -> SQLStatement:
+ """
+ Apply Row Level Security to the SQL.
+
+ :param database: The database where the SQL will run
+ :param catalog: The default catalog for non-qualified table names
+ :param schema: The default schema for non-qualified table names
+ :return: The SQL with RLS applied
+ """
+ from superset.db_engine_specs import load_engine_specs
+
+ statement = self._parsed.copy()
+
+ # collect all relevant RLS rules
+ rules = {}
+ for table in self.tables:
+ if rls := self._get_rls_for_table(database, table, catalog, schema):
+ rules[table] = rls
+
+ if not rules:
+ return SQLStatement(statement, self.engine)
+
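+ # There are two ways to insert RLS: either replacing the table with a
+ # subquery that has the RLS, or appending the RLS to the ``WHERE`` clause.
+ # The former is safer, but not supported in all databases.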
+ use_subquery = all(
+ engine_spec.allows_subqueries
+ for engine_spec in load_engine_specs()
+ if engine_spec.engine == self.engine
+ )
+ transformer = RLSAsSubquery(rules) if use_subquery else RLSAsPredicate(rules)
+
+ return SQLStatement(statement.transform(transformer), self.engine)
+
+ def _get_rls_for_table(
+ self,
+ database: Database,
+ table: Table,
+ catalog: str | None,
+ schema: str | None,
+ ) -> exp.Expression | None:
+ """
+ Get the RLS for a table.
+
+ :param table: The table to get the RLS for
+ :param catalog: The default catalog for non-qualified table names
+ :param schema: The default schema for non-qualified table names
+ :return: The RLS for the table
+ """
+ # pylint: disable=import-outside-toplevel
+ from superset import db
+ from superset.connectors.sqla.models import SqlaTable
+
+ dataset = (
+ db.session.query(SqlaTable)
+ .filter(
+ and_(
+ SqlaTable.database_id == database.id,
+ SqlaTable.catalog == (table.catalog or catalog),
+ SqlaTable.schema == (table.schema or schema),
+ SqlaTable.table_name == table.table,
+ )
+ )
+ .one_or_none()
+ )
+ if not dataset:
+ return None
+
+ filters = dataset.get_sqla_row_level_filters()
+ if not filters:
+ return None
+
+ rls = and_(*filters).compile(
+ dialect=database.get_dialect(),
+ compile_kwargs={"literal_binds": True},
+ )
+
+ return sqlglot.parse_one(str(rls), dialect=self._dialect)
+
+ def is_dml(self) -> bool:
+ """
+ Check if the statement is DML.
+
+ :return: True if the statement is DML
+ """
+ for node in self._parsed.walk():
+ if isinstance(
+ node,
+ (
+ exp.Insert,
+ exp.Update,
+ exp.Delete,
+ exp.Merge,
+ exp.Create,
+ exp.Alter,
+ exp.Drop,
+ exp.TruncateTable,
+ ),
+ ):
+ return True
+
+ return False
+
+ def as_create_table(self, table: Table, method: CtasMethod) -> SQLStatement:
+ """
+ Convert the statement to a CREATE TABLE statement.
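+
+ For example, ``SELECT * FROM some_table`` with ``CtasMethod.TABLE`` and a
+ target of ``tmp_table`` becomes roughly::
+
+ CREATE TABLE tmp_table AS SELECT * FROM some_table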
+ """
+ create_table = exp.Create(
+ this=exp.table_(table.table, db=table.schema, catalog=table.catalog),
+ kind=method.value,
+ expression=self._parsed.copy(),
+ )
+
+ return SQLStatement(create_table, self.engine)
+
+ def is_select(self) -> bool:
+ """
+ Check if the statement is a SELECT statement.
+
+ :return: True if the statement is a SELECT statement
+ """
+ return isinstance(self._parsed, exp.Select)
+
+ def apply_limit(self, limit: int, force: bool = False) -> SQLStatement:
+ """
+ Apply a limit to the SQL.
+
+ There are 3 strategies to limit queries, defined in the DB engine spec:
+
+ 1. `FORCE_LIMIT`: a limit is added to the query, or the existing one is
+ replaced. This is the most efficient, since the database will produce at
+ most the number of rows that Superset will display.
+ 2. `WRAP_SQL`: the query is wrapped in a subquery, and the limit is applied
+ to the outer query. This might be inneficient, since the database
+ optimizer might not be able to push the limit down to the inner query.
+ 3. `FETCH_MANY`: no limit is applied, but only `LIMIT` rows are fetched from
+ the database. This is the least efficient, unless the database computes
+ rows as they are read by the cursor, which is unlikely.
+
+ :param limit: The limit to apply
+ :param force: Apply limit even when a lower one is present
+ :return: The SQL with the limit applied
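+
+ For example (output shown without pretty-printing, based on the unit tests
+ in this change):
+
+ SQLStatement("SELECT TOP 10 * FROM Customers", "teradatasql").apply_limit(5)
+ # roughly: SELECT TOP 5 * FROM Customers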
+ """
+ from superset.db_engine_specs import load_engine_specs
+ from superset.db_engine_specs.base import LimitMethod
+
+ methods = {
+ engine_spec.limit_method
+ for engine_spec in load_engine_specs()
+ if engine_spec.engine == self.engine
+ }
+ if not methods:
+ methods = {LimitMethod.FETCH_MANY}
+
+ # When multiple methods are supported, we prefer the more generic one --
+ # usually less efficient.
+ preference = [
+ LimitMethod.FETCH_MANY,
+ LimitMethod.WRAP_SQL,
+ LimitMethod.FORCE_LIMIT,
+ ]
+ method = sorted(methods, key=preference.index)[0]
+
+ if not self.is_select() or method == LimitMethod.FETCH_MANY:
+ return SQLStatement(self._parsed.copy(), self.engine)
+
+ if method == LimitMethod.WRAP_SQL:
+ limited = (
+ exp.select("*")
+ .from_(self._parsed.subquery("inner_qry"))
+ .limit(limit)
+ )
+ return SQLStatement(limited, self.engine)
+
+ current_limit: int | None = None
+ for node in self._parsed.find_all(exp.Limit):
+ current_limit = int(node.expression.this)
+ break
+
+ if force or current_limit is None or limit < current_limit:
+ return SQLStatement(self._parsed.limit(limit), self.engine)
+
+ return SQLStatement(self._parsed.copy(), self.engine)
+
class KQLSplitState(enum.Enum):
"""
@@ -666,11 +936,11 @@
)
return set()
- def format(self, comments: bool = True) -> str:
+ def format(self, comments: bool = True, strip: bool = False) -> str:
"""
Pretty-format the SQL statement.
"""
- return self._parsed
+ return self._parsed.strip(" \t\r\n;") if strip else self._parsed
def get_settings(self) -> dict[str, str | bool]:
"""
@@ -712,7 +982,10 @@
"""
Pretty-format the SQL query.
"""
- return ";\n".join(statement.format(comments) for statement in self.statements)
+ return (
+ ";\n".join(statement.format(comments) for statement in self.statements)
+ + ";"
+ )
def get_settings(self) -> dict[str, str | bool]:
"""
@@ -792,7 +1065,7 @@
Note: this uses sqlglot, since it's better at catching more edge cases.
"""
try:
- statements = parse(self.stripped(), dialect=self._dialect)
+ statements = sqlglot.parse(self.stripped(), dialect=self._dialect)
except SqlglotError as ex:
logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql)
@@ -846,7 +1119,7 @@
return set()
try:
- pseudo_query = parse_one(
+ pseudo_query = sqlglot.parse_one(
f"SELECT {literal.this}",
dialect=self._dialect,
)
diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py
index 6259d62..204455c 100644
--- a/tests/unit_tests/sql_parse_tests.py
+++ b/tests/unit_tests/sql_parse_tests.py
@@ -20,6 +20,7 @@
from unittest.mock import Mock
import pytest
+import sqlglot
import sqlparse
from pytest_mock import MockerFixture
from sqlalchemy import text
@@ -41,6 +42,8 @@
insert_rls_in_predicate,
KustoKQLStatement,
ParsedQuery,
+ RLSAsPredicate,
+ RLSAsSubquery,
sanitize_clause,
split_kql,
SQLScript,
@@ -119,8 +122,9 @@
"""
Test that tables inside subselects are parsed correctly.
"""
- assert extract_tables(
- """
+ assert (
+ extract_tables(
+ """
SELECT sub.*
FROM (
SELECT *
@@ -129,10 +133,13 @@
) sub, s2.t2
WHERE sub.resolution = 'NONE'
"""
- ) == {Table("t1", "s1"), Table("t2", "s2")}
+ )
+ == {Table("t1", "s1"), Table("t2", "s2")}
+ )
- assert extract_tables(
- """
+ assert (
+ extract_tables(
+ """
SELECT sub.*
FROM (
SELECT *
@@ -141,10 +148,13 @@
) sub
WHERE sub.resolution = 'NONE'
"""
- ) == {Table("t1", "s1")}
+ )
+ == {Table("t1", "s1")}
+ )
- assert extract_tables(
- """
+ assert (
+ extract_tables(
+ """
SELECT * FROM t1
WHERE s11 > ANY (
SELECT COUNT(*) /* no hint */ FROM t2
@@ -156,7 +166,9 @@
)
)
"""
- ) == {Table("t1"), Table("t2"), Table("t3"), Table("t4")}
+ )
+ == {Table("t1"), Table("t2"), Table("t3"), Table("t4")}
+ )
def test_extract_tables_select_in_expression() -> None:
@@ -227,24 +239,30 @@
"""
Test that queries selecting arrays work as expected.
"""
- assert extract_tables(
- """
+ assert (
+ extract_tables(
+ """
SELECT ARRAY[1, 2, 3] AS my_array
FROM t1 LIMIT 10
"""
- ) == {Table("t1")}
+ )
+ == {Table("t1")}
+ )
def test_extract_tables_select_if() -> None:
"""
Test that queries with an ``IF`` work as expected.
"""
- assert extract_tables(
- """
+ assert (
+ extract_tables(
+ """
SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
FROM t1 LIMIT 10
"""
- ) == {Table("t1")}
+ )
+ == {Table("t1")}
+ )
def test_extract_tables_with_catalog() -> None:
@@ -312,29 +330,38 @@
"""
Test that tables in a ``WHERE`` subquery are parsed correctly.
"""
- assert extract_tables(
- """
+ assert (
+ extract_tables(
+ """
SELECT name
FROM t1
WHERE regionkey = (SELECT max(regionkey) FROM t2)
"""
- ) == {Table("t1"), Table("t2")}
+ )
+ == {Table("t1"), Table("t2")}
+ )
- assert extract_tables(
- """
+ assert (
+ extract_tables(
+ """
SELECT name
FROM t1
WHERE regionkey IN (SELECT regionkey FROM t2)
"""
- ) == {Table("t1"), Table("t2")}
+ )
+ == {Table("t1"), Table("t2")}
+ )
- assert extract_tables(
- """
+ assert (
+ extract_tables(
+ """
SELECT name
FROM t1
WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.regionkey = t2.regionkey);
"""
- ) == {Table("t1"), Table("t2")}
+ )
+ == {Table("t1"), Table("t2")}
+ )
def test_extract_tables_describe() -> None:
@@ -348,12 +375,15 @@
"""
Test ``SHOW PARTITIONS``.
"""
- assert extract_tables(
- """
+ assert (
+ extract_tables(
+ """
SHOW PARTITIONS FROM orders
WHERE ds >= '2013-01-01' ORDER BY ds DESC
"""
- ) == {Table("orders")}
+ )
+ == {Table("orders")}
+ )
def test_extract_tables_join() -> None:
@@ -365,8 +395,9 @@
Table("t2"),
}
- assert extract_tables(
- """
+ assert (
+ extract_tables(
+ """
SELECT a.date, b.name
FROM left_table a
JOIN (
@@ -377,10 +408,13 @@
) b
ON a.date = b.date
"""
- ) == {Table("left_table"), Table("right_table")}
+ )
+ == {Table("left_table"), Table("right_table")}
+ )
- assert extract_tables(
- """
+ assert (
+ extract_tables(
+ """
SELECT a.date, b.name
FROM left_table a
LEFT INNER JOIN (
@@ -391,10 +425,13 @@
) b
ON a.date = b.date
"""
- ) == {Table("left_table"), Table("right_table")}
+ )
+ == {Table("left_table"), Table("right_table")}
+ )
- assert extract_tables(
- """
+ assert (
+ extract_tables(
+ """
SELECT a.date, b.name
FROM left_table a
RIGHT OUTER JOIN (
@@ -405,10 +442,13 @@
) b
ON a.date = b.date
"""
- ) == {Table("left_table"), Table("right_table")}
+ )
+ == {Table("left_table"), Table("right_table")}
+ )
- assert extract_tables(
- """
+ assert (
+ extract_tables(
+ """
SELECT a.date, b.name
FROM left_table a
FULL OUTER JOIN (
@@ -419,15 +459,18 @@
) b
ON a.date = b.date
"""
- ) == {Table("left_table"), Table("right_table")}
+ )
+ == {Table("left_table"), Table("right_table")}
+ )
def test_extract_tables_semi_join() -> None:
"""
Test ``LEFT SEMI JOIN``.
"""
- assert extract_tables(
- """
+ assert (
+ extract_tables(
+ """
SELECT a.date, b.name
FROM left_table a
LEFT SEMI JOIN (
@@ -438,15 +481,18 @@
) b
ON a.data = b.date
"""
- ) == {Table("left_table"), Table("right_table")}
+ )
+ == {Table("left_table"), Table("right_table")}
+ )
def test_extract_tables_combinations() -> None:
"""
Test a complex case with nested queries.
"""
- assert extract_tables(
- """
+ assert (
+ extract_tables(
+ """
SELECT * FROM t1
WHERE s11 > ANY (
SELECT * FROM t1 UNION ALL SELECT * FROM (
@@ -460,10 +506,13 @@
)
)
"""
- ) == {Table("t1"), Table("t3"), Table("t4"), Table("t6")}
+ )
+ == {Table("t1"), Table("t3"), Table("t4"), Table("t6")}
+ )
- assert extract_tables(
- """
+ assert (
+ extract_tables(
+ """
SELECT * FROM (
SELECT * FROM (
SELECT * FROM (
@@ -472,45 +521,56 @@
) AS S2
) AS S3
"""
- ) == {Table("EmployeeS")}
+ )
+ == {Table("EmployeeS")}
+ )
def test_extract_tables_with() -> None:
"""
Test ``WITH``.
"""
- assert extract_tables(
- """
+ assert (
+ extract_tables(
+ """
WITH
x AS (SELECT a FROM t1),
y AS (SELECT a AS b FROM t2),
z AS (SELECT b AS c FROM t3)
SELECT c FROM z
"""
- ) == {Table("t1"), Table("t2"), Table("t3")}
+ )
+ == {Table("t1"), Table("t2"), Table("t3")}
+ )
- assert extract_tables(
- """
+ assert (
+ extract_tables(
+ """
WITH
x AS (SELECT a FROM t1),
y AS (SELECT a AS b FROM x),
z AS (SELECT b AS c FROM y)
SELECT c FROM z
"""
- ) == {Table("t1")}
+ )
+ == {Table("t1")}
+ )
def test_extract_tables_reusing_aliases() -> None:
"""
Test that the parser follows aliases.
"""
- assert extract_tables(
- """
+ assert (
+ extract_tables(
+ """
with q1 as ( select key from q2 where key = '5'),
q2 as ( select key from src where key = '5')
select * from (select key from q1) a
"""
- ) == {Table("src")}
+ )
+ == {Table("src")}
+ )
# weird query with circular dependency
assert (
@@ -547,8 +607,9 @@
"""
Test a few complex queries.
"""
- assert extract_tables(
- """
+ assert (
+ extract_tables(
+ """
SELECT sum(m_examples) AS "sum__m_example"
FROM (
SELECT
@@ -569,23 +630,29 @@
ORDER BY "sum__m_example" DESC
LIMIT 10;
"""
- ) == {
- Table("my_l_table"),
- Table("my_b_table"),
- Table("my_t_table"),
- Table("inner_table"),
- }
+ )
+ == {
+ Table("my_l_table"),
+ Table("my_b_table"),
+ Table("my_t_table"),
+ Table("inner_table"),
+ }
+ )
- assert extract_tables(
- """
+ assert (
+ extract_tables(
+ """
SELECT *
FROM table_a AS a, table_b AS b, table_c as c
WHERE a.id = b.id and b.id = c.id
"""
- ) == {Table("table_a"), Table("table_b"), Table("table_c")}
+ )
+ == {Table("table_a"), Table("table_b"), Table("table_c")}
+ )
- assert extract_tables(
- """
+ assert (
+ extract_tables(
+ """
SELECT somecol AS somecol
FROM (
WITH bla AS (
@@ -629,51 +696,63 @@
LIMIT 50000
)
"""
- ) == {Table("a"), Table("b"), Table("c"), Table("d"), Table("e"), Table("f")}
+ )
+ == {Table("a"), Table("b"), Table("c"), Table("d"), Table("e"), Table("f")}
+ )
def test_extract_tables_mixed_from_clause() -> None:
"""
Test that the parser handles a ``FROM`` clause with table and subselect.
"""
- assert extract_tables(
- """
+ assert (
+ extract_tables(
+ """
SELECT *
FROM table_a AS a, (select * from table_b) AS b, table_c as c
WHERE a.id = b.id and b.id = c.id
"""
- ) == {Table("table_a"), Table("table_b"), Table("table_c")}
+ )
+ == {Table("table_a"), Table("table_b"), Table("table_c")}
+ )
def test_extract_tables_nested_select() -> None:
"""
Test that the parser handles selects inside functions.
"""
- assert extract_tables(
- """
+ assert (
+ extract_tables(
+ """
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
""",
- "mysql",
- ) == {Table("COLUMNS", "INFORMATION_SCHEMA")}
+ "mysql",
+ )
+ == {Table("COLUMNS", "INFORMATION_SCHEMA")}
+ )
- assert extract_tables(
- """
+ assert (
+ extract_tables(
+ """
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_NAME="bi_achievement_daily"),0x7e)));
""",
- "mysql",
- ) == {Table("COLUMNS", "INFORMATION_SCHEMA")}
+ "mysql",
+ )
+ == {Table("COLUMNS", "INFORMATION_SCHEMA")}
+ )
def test_extract_tables_complex_cte_with_prefix() -> None:
"""
Test that the parser handles CTEs with prefixes.
"""
- assert extract_tables(
- """
+ assert (
+ extract_tables(
+ """
WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear)
AS (
SELECT SalesPersonID, SalesOrderID, YEAR(OrderDate) AS SalesYear
@@ -685,21 +764,26 @@
GROUP BY SalesYear, SalesPersonID
ORDER BY SalesPersonID, SalesYear;
"""
- ) == {Table("SalesOrderHeader")}
+ )
+ == {Table("SalesOrderHeader")}
+ )
def test_extract_tables_identifier_list_with_keyword_as_alias() -> None:
"""
Test that aliases that are keywords are parsed correctly.
"""
- assert extract_tables(
- """
+ assert (
+ extract_tables(
+ """
WITH
f AS (SELECT * FROM foo),
match AS (SELECT * FROM f)
SELECT * FROM match
"""
- ) == {Table("foo")}
+ )
+ == {Table("foo")}
+ )
def test_update() -> None:
@@ -1841,7 +1925,7 @@
script = SQLScript("SELECT 1; SELECT 2;", "sqlite")
assert len(script.statements) == 2
- assert script.format() == "SELECT\n 1;\nSELECT\n 2"
+ assert script.format() == "SELECT\n 1;\nSELECT\n 2;"
assert script.statements[0].format() == "SELECT\n 1"
script = SQLScript("SET a=1; SET a=2; SELECT 3;", "sqlite")
@@ -2058,3 +2142,120 @@
| project Day1, Day2, Percentage = count_*100.0/count_1
""",
]
+
+
+@pytest.mark.parametrize(
+ "sql,rules,expected",
+ [
+ (
+ "SELECT t.foo FROM some_table AS t",
+ {Table("some_table"): "id = 42"},
+ "SELECT t.foo FROM (SELECT * FROM some_table WHERE id = 42) AS t",
+ ),
+ (
+ "SELECT t.foo FROM some_table AS t WHERE bar = 'baz'",
+ {Table("some_table"): "id = 42"},
+ (
+ "SELECT t.foo FROM (SELECT * FROM some_table WHERE id = 42) AS t "
+ "WHERE bar = 'baz'"
+ ),
+ ),
+ (
+ "SELECT t.foo FROM schema1.some_table AS t",
+ {Table("some_table", "schema1"): "id = 42"},
+ "SELECT t.foo FROM (SELECT * FROM schema1.some_table WHERE id = 42) AS t",
+ ),
+ (
+ "SELECT t.foo FROM schema1.some_table AS t",
+ {Table("some_table", "schema2"): "id = 42"},
+ "SELECT t.foo FROM schema1.some_table AS t",
+ ),
+ (
+ "SELECT t.foo FROM catalog1.schema1.some_table AS t",
+ {Table("some_table", "schema1", "catalog1"): "id = 42"},
+ "SELECT t.foo FROM (SELECT * FROM catalog1.schema1.some_table WHERE id = 42) AS t",
+ ),
+ (
+ "SELECT t.foo FROM catalog1.schema1.some_table AS t",
+ {Table("some_table", "schema1", "catalog2"): "id = 42"},
+ "SELECT t.foo FROM catalog1.schema1.some_table AS t",
+ ),
+ ],
+)
+def test_RLSAsSubquery(sql: str, rules: dict[Table, str], expected: str) -> None:
+ """
+ Test the `RLSAsSubquery` transformer.
+ """
+ statement = sqlglot.parse_one(sql)
+ transformer = RLSAsSubquery(rules)
+ assert str(statement.transform(transformer)) == expected
+
+
+@pytest.mark.parametrize(
+ "sql,rules,expected",
+ [
+ (
+ "SELECT t.foo FROM some_table AS t",
+ {Table("some_table"): "id = 42"},
+ "SELECT t.foo FROM some_table AS t WHERE id = 42",
+ ),
+ (
+ "SELECT t.foo FROM some_table AS t WHERE bar = 'baz'",
+ {Table("some_table"): "id = 42"},
+ "SELECT t.foo FROM some_table AS t WHERE id = 42 AND bar = 'baz'",
+ ),
+ ],
+)
+def test_RLSAsPredicate(sql: str, rules: dict[Table, str], expected: str) -> None:
+ """
+ Test the `RLSAsPredicate` transformer.
+ """
+ statement = sqlglot.parse_one(sql)
+ transformer = RLSAsPredicate(rules)
+ assert str(statement.transform(transformer)) == expected
+
+
+@pytest.mark.parametrize(
+ "sql,engine,limit,force,expected",
+ [
+ (
+ "SELECT TOP 10 * FROM Customers",
+ "teradatasql",
+ 5,
+ False,
+ "SELECT\nTOP 5\n *\nFROM Customers",
+ ),
+ (
+ "SELECT TOP 10 * FROM Customers",
+ "teradatasql",
+ 15,
+ False,
+ "SELECT\nTOP 10\n *\nFROM Customers",
+ ),
+ (
+ "SELECT TOP 10 * FROM Customers",
+ "teradatasql",
+ 15,
+ True,
+ "SELECT\nTOP 15\n *\nFROM Customers",
+ ),
+ (
+ "SELECT TOP 10 * FROM Customers",
+ "mssql",
+ 15,
+ True,
+ "SELECT\nTOP 15\n *\nFROM Customers",
+ ),
+ ],
+)
+def test_apply_limit(
+ sql: str,
+ engine: str,
+ limit: int,
+ force: bool,
+ expected: str,
+) -> None:
+ """
+ Test the `apply_limit` function.
+ """
+ assert SQLStatement(sql, engine).apply_limit(limit, force).format() == expected