| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| import logging |
| from dataclasses import dataclass # pylint: disable=wrong-import-order |
| from enum import Enum |
| from typing import List, Optional, Set |
| from urllib import parse |
| |
| import sqlparse |
| from sqlparse.sql import Identifier, IdentifierList, remove_quotes, Token, TokenList |
| from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace |
| from sqlparse.utils import imt |
| |
# NOTE(review): RESULT_OPERATIONS and ON_KEYWORD are not referenced anywhere
# in this module chunk; presumably consumed by callers elsewhere — confirm
# before removing.
RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"}
ON_KEYWORD = "ON"
# Keywords after which _extract_from_token expects the following
# identifier(s) to be table names.
PRECEDES_TABLE_NAME = {"FROM", "JOIN", "DESCRIBE", "WITH", "LEFT JOIN", "RIGHT JOIN"}
# Tables whose names start with this prefix are skipped during table
# extraction (see ParsedQuery._process_tokenlist).
CTE_PREFIX = "CTE__"
logger = logging.getLogger(__name__)
| |
| |
class CtasMethod(str, Enum):
    """
    Target object kind for CREATE ... AS statements.

    Mixes in ``str`` so members interpolate as their plain value when used
    in f-strings (see ParsedQuery.as_create_table, which emits
    ``CREATE {method} ...``).
    """

    TABLE = "TABLE"
    VIEW = "VIEW"
| |
| |
def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
    """
    Extract the LIMIT value from a parsed SQL statement.

    Handles both ``LIMIT <limit>`` and ``LIMIT <offset>, <limit>`` forms.

    :param statement: parsed SQL statement
    :return: the integer limit, or None when no limit is present
    """
    limit_idx, _ = statement.token_next_by(m=(Keyword, "LIMIT"))
    if limit_idx is None:
        return None

    _, value_token = statement.token_next(idx=limit_idx)
    if value_token is None:
        return None

    if isinstance(value_token, IdentifierList):
        # "LIMIT <offset>, <limit>": locate the comma, then take the first
        # non-whitespace token after it as the limit value.
        comma_idx, _ = value_token.token_next_by(
            m=(sqlparse.tokens.Punctuation, ",")
        )
        _, value_token = value_token.token_next(idx=comma_idx)

    if (
        value_token is not None
        and value_token.ttype == sqlparse.tokens.Literal.Number.Integer
    ):
        return int(value_token.value)
    return None
| |
| |
def strip_comments_from_sql(statement: str) -> str:
    """
    Strips comments from a SQL statement, does a simple test first
    to avoid always instantiating the expensive ParsedQuery constructor

    This is useful for engines that don't support comments

    :param statement: A string with the SQL statement
    :return: SQL statement without comments
    """
    # Cheap pre-check: only parse when a "--" comment marker is present.
    if "--" not in statement:
        return statement
    return ParsedQuery(statement).strip_comments()
| |
| |
@dataclass(eq=True, frozen=True)
class Table:  # pylint: disable=too-few-public-methods
    """
    A fully qualified SQL table conforming to [[catalog.]schema.]table.
    """

    table: str
    schema: Optional[str] = None
    catalog: Optional[str] = None

    def __str__(self) -> str:
        """
        Return the fully qualified SQL table name.

        Each present part is percent-encoded; literal dots inside a part
        are encoded as ``%2E`` so they cannot be confused with the
        catalog/schema/table separators.
        """

        present_parts = [
            part for part in (self.catalog, self.schema, self.table) if part
        ]
        encoded_parts = [
            parse.quote(part, safe="").replace(".", "%2E")
            for part in present_parts
        ]
        return ".".join(encoded_parts)
| |
| |
class ParsedQuery:
    """
    Wrapper around one or more SQL statements parsed with ``sqlparse``.

    Exposes helpers to classify the query (SELECT / EXPLAIN / SHOW / SET),
    lazily extract the set of referenced tables, and rewrite the query
    (CTAS/CVAS generation, LIMIT injection).
    """

    def __init__(self, sql_statement: str, strip_comments: bool = False):
        """
        :param sql_statement: raw SQL text, possibly several statements
        :param strip_comments: if True, strip SQL comments before parsing
        """
        if strip_comments:
            sql_statement = sqlparse.format(sql_statement, strip_comments=True)

        self.sql: str = sql_statement
        self._tables: Set[Table] = set()
        self._alias_names: Set[str] = set()
        self._limit: Optional[int] = None

        logger.debug("Parsing with sqlparse statement: %s", self.sql)
        self._parsed = sqlparse.parse(self.stripped())
        # NOTE(review): when several statements are present, the LIMIT of
        # the last statement wins — confirm this is intended.
        for statement in self._parsed:
            self._limit = _extract_limit_from_query(statement)

    @property
    def tables(self) -> Set[Table]:
        """Set of tables referenced by the query (lazily computed, cached)."""
        if not self._tables:
            for statement in self._parsed:
                self._extract_from_token(statement)

            # Drop names that turned out to be aliases, not real tables.
            self._tables = {
                table for table in self._tables if str(table) not in self._alias_names
            }
        return self._tables

    @property
    def limit(self) -> Optional[int]:
        """LIMIT value extracted from the query, or None when absent."""
        return self._limit

    def is_select(self) -> bool:
        """Return True when the first statement is a SELECT."""
        return self._parsed[0].get_type() == "SELECT"

    def is_valid_ctas(self) -> bool:
        """CTAS is valid when the final statement is a SELECT."""
        return self._parsed[-1].get_type() == "SELECT"

    def is_valid_cvas(self) -> bool:
        """CVAS is valid only for a single SELECT statement."""
        return len(self._parsed) == 1 and self._parsed[0].get_type() == "SELECT"

    def is_explain(self) -> bool:
        """Return True when the query is an EXPLAIN statement."""
        # Remove comments
        statements_without_comments = sqlparse.format(
            self.stripped(), strip_comments=True
        )

        # Explain statements will only be the first statement.
        # Fix: compare case-insensitively, consistent with is_show() and
        # is_set(); previously a lowercase "explain ..." was not detected.
        return statements_without_comments.upper().startswith("EXPLAIN")

    def is_show(self) -> bool:
        """Return True when the query is a SHOW statement."""
        # Remove comments
        statements_without_comments = sqlparse.format(
            self.stripped(), strip_comments=True
        )
        # Show statements will only be the first statement
        return statements_without_comments.upper().startswith("SHOW")

    def is_set(self) -> bool:
        """Return True when the query is a SET statement."""
        # Remove comments
        statements_without_comments = sqlparse.format(
            self.stripped(), strip_comments=True
        )
        # Set statements will only be the first statement
        return statements_without_comments.upper().startswith("SET")

    def is_unknown(self) -> bool:
        """Return True when sqlparse cannot classify the first statement."""
        return self._parsed[0].get_type() == "UNKNOWN"

    def stripped(self) -> str:
        """Return the SQL with surrounding whitespace and semicolons removed."""
        return self.sql.strip(" \t\n;")

    def strip_comments(self) -> str:
        """Return the SQL with comments removed."""
        return sqlparse.format(self.stripped(), strip_comments=True)

    def get_statements(self) -> List[str]:
        """Returns a list of SQL statements as strings, stripped"""
        statements = []
        for statement in self._parsed:
            if statement:
                sql = str(statement).strip(" \n;\t")
                if sql:
                    statements.append(sql)
        return statements

    @staticmethod
    def _get_table(tlist: TokenList) -> Optional[Table]:
        """
        Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
        construct.

        :param tlist: The SQL tokens
        :returns: The table if the name conforms
        """

        # Strip the alias if present.
        idx = len(tlist.tokens)

        if tlist.has_alias():
            ws_idx, _ = tlist.token_next_by(t=Whitespace)

            if ws_idx != -1:
                idx = ws_idx

        tokens = tlist.tokens[:idx]

        # A valid name is 1, 3 or 5 tokens: Name/String tokens at even
        # positions, separated by "." punctuation at odd positions.
        if (
            len(tokens) in (1, 3, 5)
            and all(imt(token, t=[Name, String]) for token in tokens[::2])
            and all(imt(token, m=(Punctuation, ".")) for token in tokens[1::2])
        ):
            # tokens[::-2] walks the name tokens right-to-left, matching
            # Table's (table, schema, catalog) positional order.
            return Table(*[remove_quotes(token.value) for token in tokens[::-2]])

        return None

    @staticmethod
    def _is_identifier(token: Token) -> bool:
        """Return True for Identifier/IdentifierList tokens."""
        return isinstance(token, (IdentifierList, Identifier))

    def _process_tokenlist(self, token_list: TokenList) -> None:
        """
        Add table names to table set

        :param token_list: TokenList to be processed
        """
        # exclude subselects
        if "(" not in str(token_list):
            table = self._get_table(token_list)
            # Names carrying CTE_PREFIX are excluded from the table set.
            if table and not table.table.startswith(CTE_PREFIX):
                self._tables.add(table)
            return

        # store aliases
        if token_list.has_alias():
            self._alias_names.add(token_list.get_alias())

        # some aliases are not parsed properly
        if token_list.tokens[0].ttype == Name:
            self._alias_names.add(token_list.tokens[0].value)
        self._extract_from_token(token_list)

    def as_create_table(
        self,
        table_name: str,
        schema_name: Optional[str] = None,
        overwrite: bool = False,
        method: CtasMethod = CtasMethod.TABLE,
    ) -> str:
        """Reformats the query into the create table as query.

        Works only for the single select SQL statements, in all other cases
        the sql query is not modified.
        :param table_name: table that will contain the results of the query execution
        :param schema_name: schema name for the target table
        :param overwrite: table_name will be dropped if true
        :param method: method for the CTA query, currently view or table creation
        :return: Create table as query
        """
        exec_sql = ""
        sql = self.stripped()
        # TODO(bkyryliuk): quote full_table_name
        full_table_name = f"{schema_name}.{table_name}" if schema_name else table_name
        if overwrite:
            exec_sql = f"DROP {method} IF EXISTS {full_table_name};\n"
        exec_sql += f"CREATE {method} {full_table_name} AS \n{sql}"
        return exec_sql

    def _extract_from_token(  # pylint: disable=too-many-branches
        self, token: Token
    ) -> None:
        """
        <Identifier> store a list of subtokens and <IdentifierList> store lists of
        subtoken list.

        It extracts <IdentifierList> and <Identifier> from :param token: and loops
        through all subtokens recursively. It finds table_name_preceding_token and
        passes <IdentifierList> and <Identifier> to self._process_tokenlist to populate
        self._tables.

        :param token: instance of Token or child class, e.g. TokenList, to be processed
        """
        # Leaf tokens have no "tokens" attribute; nothing to recurse into.
        if not hasattr(token, "tokens"):
            return

        table_name_preceding_token = False

        for item in token.tokens:
            if item.is_group and not self._is_identifier(item):
                self._extract_from_token(item)

            # Keywords such as FROM/JOIN (or any "... JOIN") signal that the
            # next identifier(s) are table names.
            if item.ttype in Keyword and (
                item.normalized in PRECEDES_TABLE_NAME
                or item.normalized.endswith(" JOIN")
            ):
                table_name_preceding_token = True
                continue

            # Any other keyword resets the flag.
            if item.ttype in Keyword:
                table_name_preceding_token = False
                continue

            if table_name_preceding_token:
                if isinstance(item, Identifier):
                    self._process_tokenlist(item)
                elif isinstance(item, IdentifierList):
                    for token2 in item.get_identifiers():
                        if isinstance(token2, TokenList):
                            self._process_tokenlist(token2)
            elif isinstance(item, IdentifierList):
                # Mixed lists (not purely identifiers) may still contain
                # nested groups worth recursing into.
                if any(not self._is_identifier(token2) for token2 in item.tokens):
                    self._extract_from_token(item)

    def set_or_update_query_limit(self, new_limit: int) -> str:
        """Returns the query with the specified limit.

        Does not change the underlying query if user did not apply the limit,
        otherwise replaces the limit with the lower value between existing limit
        in the query and new_limit.

        :param new_limit: Limit to be incorporated into returned query
        :return: The original query with new limit
        """
        if not self._limit:
            return f"{self.stripped()}\nLIMIT {new_limit}"
        limit_pos = None
        statement = self._parsed[0]
        # Add all items to before_str until there is a limit
        for pos, item in enumerate(statement.tokens):
            if item.ttype in Keyword and item.value.lower() == "limit":
                limit_pos = pos
                break
        # NOTE(review): if the LIMIT found during __init__ was nested (not a
        # top-level token of the first statement), limit_pos stays None here
        # — confirm token_next(idx=None) is safe in that case.
        _, limit = statement.token_next(idx=limit_pos)
        # Override the limit only when it exceeds the configured value.
        if limit.ttype == sqlparse.tokens.Literal.Number.Integer and new_limit < int(
            limit.value
        ):
            limit.value = new_limit
        elif limit.is_group:
            # "LIMIT <offset>, <limit>" form: keep the offset, swap the limit.
            limit.value = f"{next(limit.get_identifiers())}, {new_limit}"

        str_res = ""
        for i in statement.tokens:
            str_res += str(i.value)
        return str_res