| |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import trac.db.util |
| from trac.util import concurrency |
| |
| import sqlparse |
| import sqlparse.tokens as Tokens |
| import sqlparse.sql as Types |
| |
| from multiproduct.cache import lru_cache |
| |
| __all__ = ['BloodhoundIterableCursor', 'BloodhoundConnectionWrapper', 'ProductEnvContextManager'] |
| |
# Tables that are never product-scoped: SQL touching these is passed through
# untranslated (global/session state, repository data, the multiproduct
# bookkeeping tables themselves, and SQLite internals).
SKIP_TABLES = ['auth_cookie',
               'session', 'session_attribute',
               'cache',
               'repository', 'revision', 'node_change',
               'bloodhound_product', 'bloodhound_productresourcemap', 'bloodhound_productconfig',
               'sqlite_master', 'bloodhound_relations'
               ]
# Tables that carry a per-product column: queries against these get filtered
# (and inserts extended) on PRODUCT_COLUMN by BloodhoundProductSQLTranslate.
TRANSLATE_TABLES = ['system',
                    'ticket', 'ticket_change', 'ticket_custom',
                    'attachment',
                    'enum', 'component', 'milestone', 'version',
                    'permission',
                    'wiki',
                    'report',
                    ]
# Name of the product-discriminator column in TRANSLATE_TABLES.
PRODUCT_COLUMN = 'product'
# Prefix value representing the global (no-product) scope.
GLOBAL_PRODUCT = ''
| |
# Singleton used to mark translator as unset; distinct from None so that a
# legitimately-absent translator can be told apart from "not initialized yet".
class empty_translator(object):
    pass

translator_not_set = empty_translator()
| |
@lru_cache(maxsize=1000)
def translate_sql(env, sql):
    """Return `sql` rewritten to the product scope of `env`.

    Results are memoized (keyed on ``(env, sql)``) via ``lru_cache``, so
    repeated statements skip re-parsing.

    :param env: product environment used to derive the product prefix, or
                `None` to disable translation entirely (direct DB access)
    :param sql: original SQL statement
    :return: the translated SQL, or `sql` unchanged when `env` is `None`
    """
    translator = None
    log = None
    if env is not None:
        # FIXME: This is the right way to do it but breaks translation
        # if trac.db.api.DatabaseManager(self.env).debug_sql:
        if (env.parent or env).config['trac'].get('debug_sql', False):
            log = env.log
        product_prefix = env.product.prefix if env.product else GLOBAL_PRODUCT
        translator = BloodhoundProductSQLTranslate(SKIP_TABLES,
                                                   TRANSLATE_TABLES,
                                                   PRODUCT_COLUMN,
                                                   product_prefix)
    if log:
        # Bug fix: log message previously read 'Original SQl'.
        log.debug('Original SQL: %s', sql)
    realsql = translator.translate(sql) if (translator is not None) else sql
    if log:
        log.debug('SQL: %s', realsql)
    return realsql
| |
class BloodhoundIterableCursor(trac.db.util.IterableCursor):
    """Iterable cursor that product-translates SQL before executing it.

    The environment driving the translation is kept in thread-local
    storage so concurrent requests scoped to different products do not
    interfere with each other.
    """
    __slots__ = trac.db.util.IterableCursor.__slots__ + ['_translator']
    _tls = concurrency.ThreadLocal(env=None)

    def __init__(self, cursor, log=None):
        super(BloodhoundIterableCursor, self).__init__(cursor, log=log)

    def execute(self, sql, args=None):
        # Translate first, then delegate to the wrapped cursor.
        realsql = translate_sql(self.env, sql)
        return super(BloodhoundIterableCursor, self).execute(realsql, args=args)

    def executemany(self, sql, args=None):
        realsql = translate_sql(self.env, sql)
        return super(BloodhoundIterableCursor, self).executemany(realsql, args=args)

    @property
    def env(self):
        """Product environment bound to the current thread (may be None)."""
        return self._tls.env

    @classmethod
    def set_env(cls, env):
        """Bind `env` to the current thread for subsequent queries."""
        cls._tls.env = env

    @classmethod
    def get_env(cls):
        """Return the environment bound to the current thread."""
        return cls._tls.env

    @classmethod
    def cache_reset(cls):
        """Discard every memoized SQL translation."""
        translate_sql.clear()
| |
# Monkey-patch trac so every cursor it hands out is a
# BloodhoundIterableCursor, routing all queries through SQL translation.
trac.db.util.IterableCursor = BloodhoundIterableCursor
| |
class BloodhoundConnectionWrapper(object):
    """Pair a raw database connection with a product environment.

    Each statement execution first publishes the environment to
    `BloodhoundIterableCursor` (thread-local), which performs the actual
    product translation; everything else is delegated untouched.
    """

    def __init__(self, connection, env):
        self.connection = connection
        self.env = env

    def __getattr__(self, name):
        # Any attribute not defined on the wrapper is served by the
        # underlying connection object.
        return getattr(self.connection, name)

    def execute(self, query, params=None):
        BloodhoundIterableCursor.set_env(self.env)
        conn = self.connection
        return conn.execute(query, params=params)

    # Calling the wrapper directly is the same as calling execute().
    __call__ = execute

    def executemany(self, query, params=None):
        BloodhoundIterableCursor.set_env(self.env)
        conn = self.connection
        return conn.executemany(query, params=params)

    def cursor(self):
        """Return a cursor wrapper carrying the same environment."""
        raw_cursor = self.connection.cursor()
        return BloodhoundCursorWrapper(raw_cursor, self.env)
| |
class BloodhoundCursorWrapper(object):
    """Pair a raw database cursor with a product environment.

    Mirrors `BloodhoundConnectionWrapper`: publishes the environment to
    `BloodhoundIterableCursor` before each execution and forwards
    everything else to the wrapped cursor.
    """

    def __init__(self, cursor, env):
        self.cursor = cursor
        self.env = env

    def __getattr__(self, name):
        # Fall back to the wrapped cursor for unknown attributes.
        return getattr(self.cursor, name)

    def __iter__(self):
        return iter(self.cursor)

    def execute(self, sql, args=None):
        BloodhoundIterableCursor.set_env(self.env)
        wrapped = self.cursor
        return wrapped.execute(sql, args=args)

    def executemany(self, sql, args=None):
        BloodhoundIterableCursor.set_env(self.env)
        wrapped = self.cursor
        return wrapped.executemany(sql, args=args)
| |
class ProductEnvContextManager(object):
    """Wrap an underlying database context manager so every query issued
    through it is scoped to a particular product environment.
    """

    def __init__(self, context, env=None):
        """Initialize product database context.

        :param context: Inner database context (e.g. `QueryContextManager`,
                        `TransactionContextManager`)
        :param env: An instance of either `trac.env.Environment` or
                    `multiproduct.env.ProductEnvironment` used to
                    reduce the scope of database queries. If set
                    to `None` then SQL queries will not be translated,
                    which is equivalent to having direct database access.
        """
        self.db_context = context
        self.env = env

    def __enter__(self):
        """Enter the inner database context and hand back the resulting
        connection wrapped with this product environment.
        """
        connection = self.db_context.__enter__()
        return BloodhoundConnectionWrapper(connection, self.env)

    def __exit__(self, et, ev, tb):
        """Leave the inner database context, propagating its result."""
        return self.db_context.__exit__(et, ev, tb)

    def __call__(self, *args, **kwargs):
        """Invoke the inner database context after publishing this
        product environment for SQL translation.
        """
        BloodhoundIterableCursor.set_env(self.env)
        return self.db_context(*args, **kwargs)

    def __getattr__(self, attrnm):
        """Forward attribute access to the inner database context."""
        return getattr(self.db_context, attrnm)

    def execute(self, sql, params=None):
        BloodhoundIterableCursor.set_env(self.env)
        inner = self.db_context
        return inner.execute(sql, params=params)

    def executemany(self, sql, params=None):
        BloodhoundIterableCursor.set_env(self.env)
        inner = self.db_context
        return inner.executemany(sql, params=params)
| |
| |
class BloodhoundProductSQLTranslate(object):
    """Rewrite SQL statements so they operate on a single product's slice
    of the database.

    Tables listed in `translate_tables` are replaced with inline views
    that filter on the product column; other non-skipped tables are
    renamed with a per-product prefix.  All rewriting is done in place on
    the token tree produced by `sqlparse`.

    NOTE(review): this code targets Python 2 (`except Exception, ex` in
    `translate`, list-returning `filter`) and an older sqlparse API
    (`token_next_match`, `to_unicode`, callable `is_whitespace`) — verify
    versions before reuse.
    """
    # JOIN spellings recognized while scanning a FROM clause.
    _join_statements = ['LEFT JOIN', 'LEFT OUTER JOIN',
                        'RIGHT JOIN', 'RIGHT OUTER JOIN',
                        'JOIN', 'INNER JOIN']
    # Keywords that terminate the FROM section of a SELECT.
    _from_end_words = ['WHERE', 'GROUP', 'HAVING', 'ORDER', 'UNION', 'LIMIT']

    def __init__(self, skip_tables, translate_tables, product_column, product_prefix):
        """Store translation configuration.

        :param skip_tables: table names passed through untranslated
        :param translate_tables: table names filtered on the product column
        :param product_column: name of the product-discriminator column
        :param product_prefix: product prefix value to scope queries to
        """
        self._skip_tables = skip_tables
        self._translate_tables = translate_tables
        self._product_column = product_column
        self._product_prefix = product_prefix

    def _sqlparse_underline_hack(self, token):
        """Work around sqlparse splitting a leading '_' off identifiers as
        an Error token: glue the underscore run back onto the following
        identifier.  Returns a leftover token the caller must re-attach to
        its own parent (when an expression ends with '_'), else None.
        """
        underline_token = lambda token: token.ttype == Tokens.Token.Error and token.value == '_'
        identifier_token = lambda token: isinstance(token, Types.Identifier) or isinstance(token, Types.Token)
        def prefix_token(token, prefix):
            # Prepend the collected underscore prefix onto the identifier,
            # keeping `normalized` consistent with `value`.
            if identifier_token(token):
                if isinstance(token, Types.IdentifierList):
                    token = token.tokens[0]
                token.value = prefix + token.value
                token.normalized = token.value.upper() if token.ttype in Tokens.Keyword \
                    else token.value
                if hasattr(token, 'tokens'):
                    if len(token.tokens) != 1:
                        raise Exception("Internal error, invalid token list")
                    token.tokens[0].value, token.tokens[0].normalized = token.value, token.normalized
            return

        if hasattr(token, 'tokens') and token.tokens and len(token.tokens):
            current = self._token_first(token)
            while current:
                leftover = None
                if underline_token(current):
                    prefix = ''
                    while underline_token(current):
                        prefix += current.value
                        prev = current
                        current = self._token_next(token, current)
                        self._token_delete(token, prev)
                    # expression ends with _ ... push the token to parent
                    if not current:
                        return prev
                    prefix_token(current, prefix)
                else:
                    # Recurse into group tokens; re-attach any leftover.
                    leftover = self._sqlparse_underline_hack(current)
                    if leftover:
                        leftover.parent = token
                        self._token_insert_after(token, current, leftover)
                current = leftover if leftover else self._token_next(token, current)
        return None

    # Both helpers drop the AS keyword, yielding [name] or [name, alias].
    # NOTE(review): Python 2 `filter` returns a list here; callers index
    # into the result — this would break on Python 3.
    def _select_table_name_alias(self, tokens):
        return filter(lambda t: t.upper() != 'AS', [t.value for t in tokens if t.value.strip()])
    def _column_expression_name_alias(self, tokens):
        return filter(lambda t: t.upper() != 'AS', [t.value for t in tokens if t.value.strip()])

    def _select_alias_sql(self, alias):
        """Return the ' AS <alias>' SQL fragment."""
        return ' AS %s' % alias

    def _translated_table_view_sql(self, name, alias=None):
        """Inline view of `name` filtered to the current product."""
        sql = "(SELECT * FROM %s WHERE %s='%s')" % (name, self._product_column, self._product_prefix)
        if alias:
            sql += self._select_alias_sql(alias)
        return sql

    def _prefixed_table_entity_name(self, tablename):
        """Quoted '<prefix>_<table>' name, or the plain name for the
        global (empty-prefix) scope."""
        return '"%s_%s"' % (self._product_prefix, tablename) if self._product_prefix else tablename

    def _prefixed_table_view_sql(self, name, alias):
        """Aliased inline view over the product-prefixed table."""
        return '(SELECT * FROM %s) AS %s' % (self._prefixed_table_entity_name(name),
                                             alias)

    # Thin wrappers around the (old) sqlparse TokenList navigation API so
    # the rest of the class works with tokens rather than indices.
    def _token_first(self, parent):
        return parent.token_first()
    def _token_next_match(self, parent, start_token, ttype, token):
        return parent.token_next_match(self._token_idx(parent, start_token), ttype, token)
    def _token_next(self, parent, from_token):
        return parent.token_next(self._token_idx(parent, from_token))
    def _token_prev(self, parent, from_token):
        return parent.token_prev(self._token_idx(parent, from_token))
    def _token_next_by_instance(self, parent, start_token, klass):
        return parent.token_next_by_instance(self._token_idx(parent, start_token), klass)
    def _token_next_by_type(self, parent, start_token, ttype):
        return parent.token_next_by_type(self._token_idx(parent, start_token), ttype)
    def _token_insert_before(self, parent, where, token):
        return parent.insert_before(where, token)
    def _token_insert_after(self, parent, where, token):
        return parent.insert_after(where, token)
    def _token_idx(self, parent, token):
        return parent.token_index(token)
    def _token_delete(self, parent, token):
        idx = self._token_idx(parent, token)
        del parent.tokens[idx]
        return idx
    def _token_insert(self, parent, idx, token):
        parent.tokens.insert(idx, token)

    def _eval_expression_value(self, parent, token):
        """If `token` is a parenthesized subselect, translate it in place."""
        if isinstance(token, Types.Parenthesis):
            t = self._token_first(token)
            if t.match(Tokens.Punctuation, '('):
                t = self._token_next(token, t)
                if t.match(Tokens.DML, 'SELECT'):
                    self._select(token, t)

    def _expression_token_unwind_hack(self, parent, token, start_token):
        # hack to workaround sqlparse bug that wrongly presents list of tokens
        # as IdentifierList in certain situations
        if isinstance(token, Types.IdentifierList):
            idx = self._token_delete(parent, token)
            for t in token.tokens:
                self._token_insert(parent, idx, t)
                idx += 1
            token = self._token_next(parent, start_token)
        return token

    def _where(self, parent, where_token):
        """Translate any subselects appearing inside a WHERE clause."""
        if isinstance(where_token, Types.Where):
            token = self._token_first(where_token)
            if not token.match(Tokens.Keyword, 'WHERE'):
                raise Exception("Invalid WHERE statement")
            while token:
                self._eval_expression_value(where_token, token)
                token = self._token_next(where_token, token)
        return

    def _select_expression_tokens(self, parent, first_token, end_words):
        """Collect the SELECT column-expression tokens up to `end_words`.
        Returns (terminating token, list of per-column token lists)."""
        if isinstance(first_token, Types.IdentifierList):
            return first_token, [list(first_token.flatten())]
        tokens = list()
        current_list = list()
        current_token = first_token
        while current_token and not current_token.match(Tokens.Keyword, end_words):
            if current_token.match(Tokens.Punctuation, ','):
                # Column separator: flush the accumulated expression.
                if current_list:
                    tokens.append(current_list)
                    current_list = list()
            elif current_token.is_whitespace():
                pass
            else:
                current_list.append(current_token)
            current_token = self._token_next(parent, current_token)
        if current_list:
            tokens.append(current_list)
        return current_token, tokens

    def _select_join(self, parent, start_token, end_words):
        """Translate a JOIN clause: the joined table (aliased), then skip
        over the ON condition tokens until `end_words`/WHERE."""
        current_token = self._select_from(parent, start_token, ['ON'], force_alias=True)
        tokens = list()
        if current_token:
            current_token = self._token_next(parent, current_token)
            while current_token and \
                  not current_token.match(Tokens.Keyword, end_words) and \
                  not isinstance(current_token, Types.Where):
                tokens.append(current_token)
                current_token = self._token_next(parent, current_token)
        return current_token

    def _select_from(self, parent, start_token, end_words, table_name_callback=None, force_alias=False):
        """Walk a FROM clause starting after `start_token`, replacing each
        table reference with its product-scoped view and recursing into
        subselects and JOINs.  Returns the token that ended the clause.
        """
        def inject_table_view(token, name, alias):
            # Swap the table token for a translated/prefixed inline view.
            if name in self._skip_tables:
                pass
            elif name in self._translate_tables:
                if force_alias and not alias:
                    alias = name
                parent.tokens[self._token_idx(parent, token)] = sqlparse.parse(self._translated_table_view_sql(name,
                                                                                                              alias=alias))[0]
                if table_name_callback:
                    table_name_callback(name)
            else:
                if not alias:
                    alias = name
                parent.tokens[self._token_idx(parent, token)] = sqlparse.parse(self._prefixed_table_view_sql(name,
                                                                                                            alias))[0]
                if table_name_callback:
                    table_name_callback(name)

        def inject_table_alias(token, alias):
            parent.tokens[self._token_idx(parent, token)] = sqlparse.parse(self._select_alias_sql(alias))[0]

        def process_table_name_tokens(nametokens):
            # Flush collected name/alias tokens for one table reference.
            if nametokens:
                l = self._select_table_name_alias(nametokens)
                if not l:
                    raise Exception("Invalid FROM table name")
                name, alias = l[0], None
                alias = l[1] if len(l) > 1 else name
                if not name in self._skip_tables:
                    token = nametokens[0]
                    for t in nametokens[1:]:
                        self._token_delete(parent, t)
                    inject_table_view(token, name, alias)
            return list()

        current_token = self._token_next(parent, start_token)
        prev_token = start_token
        table_name_tokens = list()
        join_tokens = list()
        while current_token and \
              not current_token.match(Tokens.Keyword, end_words) and \
              not isinstance(current_token, Types.Where):
            next_token = self._token_next(parent, current_token)
            if current_token.is_whitespace():
                pass
            elif isinstance(current_token, Types.IdentifierList):
                current_token = self._expression_token_unwind_hack(parent, current_token, prev_token)
                continue
            elif isinstance(current_token, Types.Identifier):
                parenthesis = filter(lambda t: isinstance(t, Types.Parenthesis), current_token.tokens)
                if parenthesis:
                    # Aliased subselect: translate each nested SELECT.
                    for p in parenthesis:
                        t = self._token_next(p, self._token_first(p))
                        if not t.match(Tokens.DML, 'SELECT'):
                            raise Exception("Invalid subselect statement")
                        self._select(p, t)
                else:
                    tablename = current_token.value.strip()
                    tablealias = current_token.get_name().strip()
                    if tablename == tablealias:
                        table_name_tokens.append(current_token)
                    else:
                        inject_table_view(current_token, tablename, tablealias)
            elif isinstance(current_token, Types.Parenthesis):
                t = self._token_next(current_token, self._token_first(current_token))
                if t.match(Tokens.DML, 'SELECT'):
                    # "(SELECT ...) [AS] alias" — translate subselect and
                    # rewrite the alias token.
                    identifier_token = self._token_next(parent, current_token)
                    as_token = None
                    if identifier_token.match(Tokens.Keyword, 'AS'):
                        as_token = identifier_token
                        identifier_token = self._token_next(parent, identifier_token)
                    if not isinstance(identifier_token, Types.Identifier):
                        raise Exception("Invalid subselect statement")
                    next_token = self._token_next(parent, identifier_token)
                    self._select(current_token, t)
                    if as_token:
                        self._token_delete(parent, as_token)
                    inject_table_alias(identifier_token, identifier_token.value)
            elif current_token.ttype == Tokens.Punctuation:
                # Comma between tables: flush the pending table reference.
                if table_name_tokens:
                    next_token = self._token_next(parent, current_token)
                    table_name_tokens = process_table_name_tokens(table_name_tokens)
            elif current_token.match(Tokens.Keyword, ['JOIN', 'LEFT', 'RIGHT', 'INNER', 'OUTER'] + self._join_statements):
                # Accumulate multi-word join keywords until they form a
                # complete join statement, then descend into the join.
                join_tokens.append(current_token.value.strip().upper())
                join = ' '.join(join_tokens)
                if join in self._join_statements:
                    join_tokens = list()
                    table_name_tokens = process_table_name_tokens(table_name_tokens)
                    next_token = self._select_join(parent,
                                                   current_token,
                                                   ['JOIN', 'LEFT', 'RIGHT', 'INNER', 'OUTER']
                                                   + self._join_statements
                                                   + self._from_end_words)
            elif current_token.ttype == Tokens.Keyword or \
                 current_token.ttype == Tokens.Token.Literal.Number.Integer:
                table_name_tokens.append(current_token)
            else:
                raise Exception("Failed to parse FROM table name")
            prev_token = current_token
            current_token = next_token
        if prev_token:
            process_table_name_tokens(table_name_tokens)
        return current_token

    def _select(self, parent, start_token, insert_table=None):
        """Translate a SELECT statement rooted at `start_token`.

        When `insert_table` names a translated table (INSERT ... SELECT),
        the product prefix literal is appended to the selected columns.
        Returns the token following the translated statement, or None.
        """
        token = self._token_next(parent, start_token)
        fields_token = self._token_next(parent, token) if token.match(Tokens.Keyword, ['ALL', 'DISTINCT']) else token
        current_token, field_lists = self._select_expression_tokens(parent, fields_token, ['FROM'] + self._from_end_words)
        def handle_insert_table(table_name):
            # Append ", '<prefix>'" after the last selected column so the
            # SELECT feeds the product column added to the INSERT.
            if insert_table and insert_table in self._translate_tables:
                if not field_lists or not field_lists[-1]:
                    raise Exception("Invalid SELECT field list")
                last_token = list(field_lists[-1][-1].flatten())[-1]
                for keyword in ["'", self._product_prefix, "'", ' ', ',']:
                    self._token_insert_after(last_token.parent, last_token, Types.Token(Tokens.Keyword, keyword))
            return
        table_name_callback = handle_insert_table if insert_table else None
        from_token = self._token_next_match(parent, start_token, Tokens.Keyword, 'FROM')
        if not from_token:
            # FROM not always required, example would be SELECT CURRVAL('"ticket_id_seq"')
            return current_token
        current_token = self._select_from(parent,
                                          from_token, self._from_end_words,
                                          table_name_callback=table_name_callback)
        if not current_token:
            return None
        while current_token:
            if isinstance(current_token, Types.Where) or \
               current_token.match(Tokens.Keyword, ['GROUP', 'HAVING', 'ORDER', 'LIMIT']):
                if isinstance(current_token, Types.Where):
                    self._where(parent, current_token)
                start_token = self._token_next(parent, current_token)
                next_token = self._token_next_match(parent,
                                                    start_token,
                                                    Tokens.Keyword,
                                                    self._from_end_words) if start_token else None
            elif current_token.match(Tokens.Keyword, ['UNION']):
                # Recursively translate the SELECT after the UNION.
                token = self._token_next(parent, current_token)
                if not token.match(Tokens.DML, 'SELECT'):
                    raise Exception("Invalid SELECT UNION statement")
                token = self._select(parent, current_token, insert_table=insert_table)
                next_token = self._token_next(parent, token) if token else None
            else:
                raise Exception("Unsupported SQL statement")
            current_token = next_token
        return current_token

    def _replace_table_entity_name(self, parent, token, table_name, entity_name=None):
        """Replace `entity_name` (default: the table name itself) in place
        with its product-prefixed form, unless the table is skipped or
        translated.  Returns the token after `token`.
        """
        if not entity_name:
            entity_name = table_name
        next_token = self._token_next(parent, token)
        if not table_name in self._skip_tables + self._translate_tables:
            token_to_replace = parent.tokens[self._token_idx(parent, token)]
            if isinstance(token_to_replace, Types.Function):
                t = self._token_first(token_to_replace)
                if isinstance(t, Types.Identifier):
                    token_to_replace.tokens[self._token_idx(token_to_replace, t)] = Types.Token(Tokens.Keyword,
                                                                                                self._prefixed_table_entity_name(entity_name))
            elif isinstance(token_to_replace, Types.Identifier) or isinstance(token_to_replace, Types.Token):
                parent.tokens[self._token_idx(parent, token_to_replace)] = Types.Token(Tokens.Keyword,
                                                                                       self._prefixed_table_entity_name(entity_name))
            else:
                raise Exception("Internal error, invalid table entity token type")
        return next_token

    def _insert(self, parent, start_token):
        """Translate an INSERT statement: rewrite the table name and, for
        translated tables, add the product column and its value (unless a
        product column is already present in the column list).
        """
        token = self._token_next(parent, start_token)
        if not token.match(Tokens.Keyword, 'INTO'):
            raise Exception("Invalid INSERT statement")
        def insert_extra_column(tablename, columns_token):
            # Append ", product" to the column list; returns True when the
            # statement already names a product column.
            if tablename in self._translate_tables and \
               isinstance(columns_token, Types.Parenthesis):
                ptoken = self._token_first(columns_token)
                if not ptoken.match(Tokens.Punctuation, '('):
                    raise Exception("Invalid INSERT statement, expected parenthesis around columns")
                ptoken = self._token_next(columns_token, ptoken)
                last_token = ptoken
                while ptoken:
                    if isinstance(ptoken, Types.IdentifierList):
                        if any(i.get_name() == 'product'
                               for i in ptoken.get_identifiers()
                               if isinstance(i, Types.Identifier)):
                            return True
                    last_token = ptoken
                    ptoken = self._token_next(columns_token, ptoken)
                if not last_token or \
                   not last_token.match(Tokens.Punctuation, ')'):
                    raise Exception("Invalid INSERT statement, unable to find column parenthesis end")
                for keyword in [',', ' ', self._product_column]:
                    self._token_insert_before(columns_token, last_token, Types.Token(Tokens.Keyword, keyword))
            return False
        def insert_extra_column_value(tablename, ptoken, before_token):
            # Append ", '<prefix>'" to a VALUES tuple.
            if tablename in self._translate_tables:
                for keyword in [',', "'", self._product_prefix, "'"]:
                    self._token_insert_before(ptoken, before_token, Types.Token(Tokens.Keyword, keyword))
            return
        tablename = None
        table_name_token = self._token_next(parent, token)
        has_product_column = False
        if isinstance(table_name_token, Types.Function):
            # sqlparse groups "table(col, ...)" as a Function token.
            token = self._token_first(table_name_token)
            if isinstance(token, Types.Identifier):
                tablename = token.get_name()
                columns_token = self._replace_table_entity_name(table_name_token, token, tablename)
                if columns_token.match(Tokens.Keyword, 'VALUES'):
                    token = columns_token
                else:
                    has_product_column = insert_extra_column(tablename, columns_token)
                    token = self._token_next(parent, table_name_token)
        else:
            tablename = table_name_token.value
            columns_token = self._replace_table_entity_name(parent, table_name_token, tablename)
            if columns_token.match(Tokens.Keyword, 'VALUES'):
                token = columns_token
            else:
                has_product_column = insert_extra_column(tablename, columns_token)
                token = self._token_next(parent, columns_token)
        if has_product_column:
            pass # INSERT already has product, no translation needed
        elif token.match(Tokens.Keyword, 'VALUES'):
            # Translate each "(v1, v2, ...)" tuple after VALUES.
            separators = [',', '(', ')']
            token = self._token_next(parent, token)
            while token:
                if isinstance(token, Types.Parenthesis):
                    ptoken = self._token_first(token)
                    if not ptoken.match(Tokens.Punctuation, '('):
                        raise Exception("Invalid INSERT statement")
                    last_token = ptoken
                    while ptoken:
                        if not ptoken.match(Tokens.Punctuation, separators) and \
                           not ptoken.match(Tokens.Keyword, separators) and \
                           not ptoken.is_whitespace():
                            ptoken = self._expression_token_unwind_hack(token, ptoken, self._token_prev(token, ptoken))
                            self._eval_expression_value(token, ptoken)
                        last_token = ptoken
                        ptoken = self._token_next(token, ptoken)
                    if not last_token or \
                       not last_token.match(Tokens.Punctuation, ')'):
                        raise Exception("Invalid INSERT statement, unable to find column value parenthesis end")
                    insert_extra_column_value(tablename, token, last_token)
                elif not token.match(Tokens.Punctuation, separators) and\
                     not token.match(Tokens.Keyword, separators) and\
                     not token.is_whitespace():
                    raise Exception("Invalid INSERT statement, unable to parse VALUES section")
                token = self._token_next(parent, token)
        elif token.match(Tokens.DML, 'SELECT'):
            # INSERT ... SELECT: the SELECT supplies the product value.
            self._select(parent, token, insert_table=tablename)
        else:
            raise Exception("Invalid INSERT statement")
        return

    def _update_delete_where_limit(self, table_name, parent, start_token):
        """For UPDATE/DELETE on translated tables, constrain the statement
        to the current product: extend an existing WHERE with
        "product='<prefix>' AND", or append a new WHERE (before LIMIT when
        present, else at the end).
        """
        if not start_token:
            return
        where_token = start_token if isinstance(start_token, Types.Where) \
            else self._token_next_by_instance(parent, start_token, Types.Where)
        if where_token:
            self._where(parent, where_token)
        if not table_name in self._translate_tables:
            return
        if where_token:
            # Inserted one-by-one after WHERE, hence the reversed order.
            keywords = [self._product_column, '=', "'", self._product_prefix, "'", ' ', 'AND', ' ']
            keywords.reverse()
            token = self._token_first(where_token)
            if not token.match(Tokens.Keyword, 'WHERE'):
                token = self._token_next_match(where_token, token, Tokens.Keyword, 'WHERE')
            if not token:
                raise Exception("Invalid UPDATE statement, failed to parse WHERE")
            for keyword in keywords:
                self._token_insert_after(where_token, token, Types.Token(Tokens.Keyword, keyword))
        else:
            keywords = ['WHERE', ' ', self._product_column, '=', "'", self._product_prefix, "'"]
            limit_token = self._token_next_match(parent, start_token, Tokens.Keyword, 'LIMIT')
            if limit_token:
                for keyword in keywords:
                    self._token_insert_before(parent, limit_token, Types.Token(Tokens.Keyword, keyword))
                self._token_insert_before(parent, limit_token, Types.Token(Tokens.Keyword, ' '))
            else:
                # No WHERE and no LIMIT: append after the last token.
                last_token = token = start_token
                while token:
                    last_token = token
                    token = self._token_next(parent, token)
                keywords.reverse()
                for keyword in keywords:
                    self._token_insert_after(parent, last_token, Types.Token(Tokens.Keyword, keyword))
        return

    def _get_entity_name_from_token(self, parent, token):
        """Extract a table/index/constraint name from an Identifier,
        Function or plain token; returns None when not recognizable."""
        tablename = None
        if isinstance(token, Types.Identifier):
            tablename = token.get_name()
        elif isinstance(token, Types.Function):
            token = self._token_first(token)
            if isinstance(token, Types.Identifier):
                tablename = token.get_name()
        elif isinstance(token, Types.Token):
            tablename = token.value
        return tablename

    def _update(self, parent, start_token):
        """Translate an UPDATE: rewrite the table name, translate any
        subselects in the SET expressions, then product-scope the WHERE."""
        table_name_token = self._token_next(parent, start_token)
        tablename = self._get_entity_name_from_token(parent, table_name_token)
        if not tablename:
            raise Exception("Invalid UPDATE statement, expected table name")
        token = self._replace_table_entity_name(parent, table_name_token, tablename)
        set_token = self._token_next_match(parent, token, Tokens.Keyword, 'SET')
        if set_token:
            token = set_token
            while token and \
                  not isinstance(token, Types.Where) and \
                  not token.match(Tokens.Keyword, 'LIMIT'):
                if not token.match(Tokens.Keyword, 'SET') and \
                   not token.match(Tokens.Punctuation, ','):
                    raise Exception("Invalid UPDATE statement, failed to match separator")
                column_token = self._token_next(parent, token)
                if isinstance(column_token, Types.Comparison):
                    # sqlparse grouped "col = expr" for us; skip over it.
                    token = self._token_next(parent, column_token)
                    continue
                equals_token = self._token_next(parent, column_token)
                if not equals_token.match(Tokens.Token.Operator.Comparison, '='):
                    raise Exception("Invalid UPDATE statement, SET equals token mismatch")
                expression_token = self._token_next(parent, equals_token)
                expression_token = self._expression_token_unwind_hack(parent, expression_token, equals_token)
                self._eval_expression_value(parent, expression_token)
                token = self._token_next(parent, expression_token)
        start_token = token
        self._update_delete_where_limit(tablename, parent, start_token)
        return

    def _delete(self, parent, start_token):
        """Translate a DELETE: rewrite the table name and product-scope
        the WHERE clause."""
        token = self._token_next(parent, start_token)
        if not token.match(Tokens.Keyword, 'FROM'):
            raise Exception("Invalid DELETE statement")
        table_name_token = self._token_next(parent, token)
        tablename = self._get_entity_name_from_token(parent, table_name_token)
        if not tablename:
            raise Exception("Invalid DELETE statement, expected table name")
        start_token = self._replace_table_entity_name(parent, table_name_token, tablename)
        self._update_delete_where_limit(tablename, parent, start_token)
        return

    def _create(self, parent, start_token):
        """Translate CREATE [TEMPORARY] TABLE (incl. CREATE TABLE ... AS
        SELECT) and CREATE [UNIQUE] INDEX statements by prefixing the
        table (and index) names."""
        token = self._token_next(parent, start_token)
        if token.match(Tokens.Keyword, 'TEMPORARY'):
            token = self._token_next(parent, token)
        if token.match(Tokens.Keyword, 'TABLE'):
            token = self._token_next(parent, token)
            while token.match(Tokens.Keyword, ['IF', 'NOT', 'EXIST']) or \
                  token.is_whitespace():
                token = self._token_next(parent, token)
            table_name = self._get_entity_name_from_token(parent, token)
            if not table_name:
                raise Exception("Invalid CREATE TABLE statement, expected table name")

            as_token = self._token_next_match(parent, token,
                                              Tokens.Keyword, 'AS')
            self._replace_table_entity_name(parent, token, table_name)

            if as_token:
                select_token = self._token_next_match(parent, as_token,
                                                      Tokens.DML, 'SELECT')
                if select_token:
                    return self._select(parent, select_token)
        elif token.match(Tokens.Keyword, ['UNIQUE', 'INDEX']):
            if token.match(Tokens.Keyword, 'UNIQUE'):
                token = self._token_next(parent, token)
            if token.match(Tokens.Keyword, 'INDEX'):
                index_token = self._token_next(parent, token)
                index_name = self._get_entity_name_from_token(parent, index_token)
                if not index_name:
                    raise Exception("Invalid CREATE INDEX statement, expected index name")
                on_token = self._token_next_match(parent, index_token, Tokens.Keyword, 'ON')
                if not on_token:
                    raise Exception("Invalid CREATE INDEX statement, expected ON specifier")
                table_name_token = self._token_next(parent, on_token)
                table_name = self._get_entity_name_from_token(parent, table_name_token)
                if not table_name:
                    raise Exception("Invalid CREATE INDEX statement, expected table name")
                self._replace_table_entity_name(parent, table_name_token, table_name)
                # The index name itself also gets the product prefix.
                self._replace_table_entity_name(parent, index_token, table_name, entity_name=index_name)
        return

    def _alter(self, parent, start_token):
        """Translate ALTER TABLE, prefixing the table and any ADD/DROP
        CONSTRAINT names.

        NOTE(review): the error message below says 'CREATE TABLE' — looks
        like a copy-paste from `_create`; should read 'ALTER TABLE'.
        """
        token = self._token_next(parent, start_token)
        if token.match(Tokens.Keyword, 'TABLE'):
            token = self._token_next(parent, token)
            table_name = self._get_entity_name_from_token(parent, token)
            if not table_name:
                raise Exception("Invalid CREATE TABLE statement, expected table name")
            token = self._replace_table_entity_name(parent, token, table_name)
            if token.match(Tokens.Keyword.DDL, ['ADD', 'DROP']) or\
               token.match(Tokens.Keyword, ['ADD', 'DROP']):
                token = self._token_next(parent, token)
                if token.match(Tokens.Keyword, 'CONSTRAINT'):
                    token = self._token_next(parent, token)
                    constraint_name = self._get_entity_name_from_token(parent, token)
                    if not constraint_name:
                        raise Exception("Invalid ALTER TABLE statement, expected constraint name")
                    self._replace_table_entity_name(parent, token, table_name, constraint_name)
        return

    def _drop(self, parent, start_token):
        """Translate DROP TABLE [IF EXISTS], prefixing the table name."""
        token = self._token_next(parent, start_token)
        if token.match(Tokens.Keyword, 'TABLE'):
            token = self._token_next(parent, token)
            while token.match(Tokens.Keyword, ['IF', 'EXIST']) or\
                  token.is_whitespace():
                token = self._token_next(parent, token)
            table_name = self._get_entity_name_from_token(parent, token)
            if not table_name:
                raise Exception("Invalid DROP TABLE statement, expected table name")
            self._replace_table_entity_name(parent, token, table_name)
        return

    def translate(self, sql):
        """Translate `sql` to its product-scoped form and return it.

        Dispatches on the leading DML/DDL keyword; statements of any other
        kind are returned unchanged.  Any parser/translation failure is
        re-raised as a generic Exception carrying the original SQL.
        """
        dml_handlers = {'SELECT': self._select,
                        'INSERT': self._insert,
                        'UPDATE': self._update,
                        'DELETE': self._delete,
                        }
        ddl_handlers = {'CREATE': self._create,
                        'ALTER': self._alter,
                        'DROP': self._drop,
                        }
        try:
            # Old sqlparse API: Statement.to_unicode() serializes the tree.
            formatted_sql = lambda sql: sql.to_unicode()
            sql_statement = sqlparse.parse(sql)[0]
            if '_' in sql:
                self._sqlparse_underline_hack(sql_statement)
            t = sql_statement.token_first()
            if t.match(Tokens.DML, dml_handlers.keys()):
                dml_handlers[t.normalized](sql_statement, t)
                sql = formatted_sql(sql_statement)
            elif t.match(Tokens.DDL, ddl_handlers.keys()):
                ddl_handlers[t.normalized](sql_statement, t)
                sql = formatted_sql(sql_statement)
            else:
                pass
        except Exception, ex:
            raise Exception("Failed to translate SQL '%s', exception '%s'" % (sql, ex.message))
        return sql