| # -*- coding: utf-8 -*- |
| |
| import re |
| |
| from os.path import abspath, join |
| |
| from sqlparse import sql, tokens as T |
| from sqlparse.engine import FilterStack |
| from sqlparse.lexer import tokenize |
| from sqlparse.pipeline import Pipeline |
| from sqlparse.tokens import (Comment, Comparison, Keyword, Name, Punctuation, |
| String, Whitespace) |
from sqlparse.utils import memoize_generator, split_unquoted_newlines
| |
| |
| # -------------------------- |
| # token process |
| |
| class _CaseFilter: |
| |
| ttype = None |
| |
| def __init__(self, case=None): |
| if case is None: |
| case = 'upper' |
| assert case in ['lower', 'upper', 'capitalize'] |
| self.convert = getattr(unicode, case) |
| |
| def process(self, stack, stream): |
| for ttype, value in stream: |
| if ttype in self.ttype: |
| value = self.convert(value) |
| yield ttype, value |
| |
| |
| class KeywordCaseFilter(_CaseFilter): |
| ttype = T.Keyword |
| |
| |
| class IdentifierCaseFilter(_CaseFilter): |
| ttype = (T.Name, T.String.Symbol) |
| |
| def process(self, stack, stream): |
| for ttype, value in stream: |
            # Leave double-quoted identifiers untouched
            if ttype in self.ttype and value.strip()[0] != '"':
| value = self.convert(value) |
| yield ttype, value |
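
# A minimal usage sketch for the case filters, assuming the (ttype, value)
# tuple stream produced by sqlparse.lexer.tokenize (the ``stack`` argument
# is unused by these filters, so None is passed):
#
#   stream = tokenize(u'select * from foo')
#   stream = KeywordCaseFilter('upper').process(None, stream)
#   u''.join(value for _, value in stream)  # -> u'SELECT * FROM foo'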
| |
| |
| class TruncateStringFilter: |
| |
| def __init__(self, width, char): |
| self.width = max(width, 1) |
| self.char = unicode(char) |
| |
| def process(self, stack, stream): |
| for ttype, value in stream: |
| if ttype is T.Literal.String.Single: |
                if value[:2] == "''":
                    inner = value[2:-2]
                    quote = u"''"
                else:
                    inner = value[1:-1]
                    quote = u"'"
| if len(inner) > self.width: |
| value = u''.join((quote, inner[:self.width], self.char, |
| quote)) |
| yield ttype, value |
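
# Usage sketch; this filter backs the ``truncate_strings``/``truncate_char``
# options of the public sqlparse.format() API (hedged: the option wiring
# lives in sqlparse.formatter):
#
#   stream = tokenize(u"select 'abcdefg' from foo")
#   stream = TruncateStringFilter(width=3, char=u'...').process(None, stream)
#   u''.join(value for _, value in stream)  # -> u"select 'abc...' from foo"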
| |
| |
| class GetComments: |
| """Get the comments from a stack""" |
| def process(self, stack, stream): |
| for token_type, value in stream: |
| if token_type in Comment: |
| yield token_type, value |
| |
| |
| class StripComments: |
| """Strip the comments from a stack""" |
| def process(self, stack, stream): |
| for token_type, value in stream: |
| if token_type not in Comment: |
| yield token_type, value |
| |
| |
| def StripWhitespace(stream): |
| "Strip the useless whitespaces from a stream leaving only the minimal ones" |
| last_type = None |
| has_space = False |
| ignore_group = frozenset((Comparison, Punctuation)) |
| |
| for token_type, value in stream: |
        # A previous token was already emitted (we are past any leading tokens)
| if last_type: |
| if token_type in Whitespace: |
| has_space = True |
| continue |
| |
        # Skip leading whitespace and leading comparison/punctuation tokens
        elif token_type in Whitespace or token_type in ignore_group:
            continue
| |
| # Yield a whitespace if it can't be ignored |
| if has_space: |
| if not ignore_group.intersection((last_type, token_type)): |
| yield Whitespace, ' ' |
| has_space = False |
| |
| # Yield the token and set its type for checking with the next one |
| yield token_type, value |
| last_type = token_type |
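
# Usage sketch (Tokens2Unicode is defined later in this module):
#
#   Tokens2Unicode(StripWhitespace(tokenize(u'SELECT  *\nFROM   foo ;')))
#   # -> u'SELECT * FROM foo;'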
| |
| |
| class IncludeStatement: |
| """Filter that enable a INCLUDE statement""" |
| |
| def __init__(self, dirpath=".", maxrecursive=10, raiseexceptions=False): |
| if maxrecursive <= 0: |
| raise ValueError('Max recursion limit reached') |
| |
| self.dirpath = abspath(dirpath) |
| self.maxRecursive = maxrecursive |
| self.raiseexceptions = raiseexceptions |
| |
| self.detected = False |
| |
| @memoize_generator |
| def process(self, stack, stream): |
| # Run over all tokens in the stream |
| for token_type, value in stream: |
| # INCLUDE statement found, set detected mode |
| if token_type in Name and value.upper() == 'INCLUDE': |
| self.detected = True |
| continue |
| |
| # INCLUDE statement was found, parse it |
| elif self.detected: |
| # Omit whitespaces |
| if token_type in Whitespace: |
| continue |
| |
| # Found file path to include |
| if token_type in String.Symbol: |
| # Get path of file to include |
| path = join(self.dirpath, value[1:-1]) |
| |
                    try:
                        with open(path) as f:
                            raw_sql = f.read()

                    # There was a problem loading the include file
                    except IOError as err:
| # Raise the exception to the interpreter |
| if self.raiseexceptions: |
| raise |
| |
| # Put the exception as a comment on the SQL code |
| yield Comment, u'-- IOError: %s\n' % err |
| |
| else: |
                        # Create a new FilterStack to parse the included file
                        # and add all of its tokens to the main stack
                        # recursively
| try: |
| filtr = IncludeStatement(self.dirpath, |
| self.maxRecursive - 1, |
| self.raiseexceptions) |
| |
| # Max recursion limit reached |
                        except ValueError as err:
| # Raise the exception to the interpreter |
| if self.raiseexceptions: |
| raise |
| |
| # Put the exception as a comment on the SQL code |
| yield Comment, u'-- ValueError: %s\n' % err |
| |
| stack = FilterStack() |
| stack.preprocess.append(filtr) |
| |
| for tv in stack.run(raw_sql): |
| yield tv |
| |
| # Set normal mode |
| self.detected = False |
| |
| # Don't include any token while in detected mode |
| continue |
| |
| # Normal token |
| yield token_type, value |
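
# Usage sketch, mirroring the recursive FilterStack construction above
# (assumptions: ./sql/header.sql exists, the INCLUDE argument is
# double-quoted so it lexes as String.Symbol, and with only a preprocessor
# configured FilterStack.run() yields (ttype, value) tuples):
#
#   stack = FilterStack()
#   stack.preprocess.append(IncludeStatement('./sql'))
#   Tokens2Unicode(stack.run(u'INCLUDE "header.sql"; SELECT 1;'))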
| |
| |
| # ---------------------- |
| # statement process |
| |
| class StripCommentsFilter: |
| |
| def _get_next_comment(self, tlist): |
| # TODO(andi) Comment types should be unified, see related issue38 |
| token = tlist.token_next_by_instance(0, sql.Comment) |
| if token is None: |
| token = tlist.token_next_by_type(0, T.Comment) |
| return token |
| |
| def _process(self, tlist): |
| token = self._get_next_comment(tlist) |
| while token: |
| tidx = tlist.token_index(token) |
| prev = tlist.token_prev(tidx, False) |
| next_ = tlist.token_next(tidx, False) |
            # Replace by whitespace if prev and next exist and if they're not
            # whitespaces. This doesn't apply if prev or next is a parenthesis.
| if (prev is not None and next_ is not None |
| and not prev.is_whitespace() and not next_.is_whitespace() |
| and not (prev.match(T.Punctuation, '(') |
| or next_.match(T.Punctuation, ')'))): |
| tlist.tokens[tidx] = sql.Token(T.Whitespace, ' ') |
| else: |
| tlist.tokens.pop(tidx) |
| token = self._get_next_comment(tlist) |
| |
| def process(self, stack, stmt): |
| [self.process(stack, sgroup) for sgroup in stmt.get_sublists()] |
| self._process(stmt) |
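
# This filter backs the ``strip_comments`` option of the public API; a rough
# sketch (surrounding whitespace in the output may vary slightly):
#
#   import sqlparse
#   sqlparse.format('select * from foo -- a comment', strip_comments=True)
#   # -> 'select * from foo'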
| |
| |
| class StripWhitespaceFilter: |
| |
| def _stripws(self, tlist): |
| func_name = '_stripws_%s' % tlist.__class__.__name__.lower() |
| func = getattr(self, func_name, self._stripws_default) |
| func(tlist) |
| |
| def _stripws_default(self, tlist): |
| last_was_ws = False |
| for token in tlist.tokens: |
| if token.is_whitespace(): |
| if last_was_ws: |
| token.value = '' |
| else: |
| token.value = ' ' |
| last_was_ws = token.is_whitespace() |
| |
| def _stripws_identifierlist(self, tlist): |
| # Removes newlines before commas, see issue140 |
| last_nl = None |
| for token in tlist.tokens[:]: |
| if (token.ttype is T.Punctuation |
| and token.value == ',' |
| and last_nl is not None): |
| tlist.tokens.remove(last_nl) |
| if token.is_whitespace(): |
| last_nl = token |
| else: |
| last_nl = None |
| return self._stripws_default(tlist) |
| |
| def _stripws_parenthesis(self, tlist): |
| if tlist.tokens[1].is_whitespace(): |
| tlist.tokens.pop(1) |
| if tlist.tokens[-2].is_whitespace(): |
| tlist.tokens.pop(-2) |
| self._stripws_default(tlist) |
| |
| def process(self, stack, stmt, depth=0): |
| [self.process(stack, sgroup, depth + 1) |
| for sgroup in stmt.get_sublists()] |
| self._stripws(stmt) |
| if ( |
| depth == 0 |
| and stmt.tokens |
| and stmt.tokens[-1].is_whitespace() |
| ): |
| stmt.tokens.pop(-1) |
| |
| |
| class ReindentFilter: |
| |
| def __init__(self, width=2, char=' ', line_width=None): |
| self.width = width |
| self.char = char |
| self.indent = 0 |
| self.offset = 0 |
| self.line_width = line_width |
| self._curr_stmt = None |
| self._last_stmt = None |
| |
    def _flatten_up_to_token(self, token):
        """Yields all tokens up to and including *token*."""
        # helper for _get_offset
        iterator = self._curr_stmt.flatten()
        for t in iterator:
            yield t
            if t == token:
                return
| |
| def _get_offset(self, token): |
| raw = ''.join(map(unicode, self._flatten_up_to_token(token))) |
| line = raw.splitlines()[-1] |
| # Now take current offset into account and return relative offset. |
| full_offset = len(line) - len(self.char * (self.width * self.indent)) |
| return full_offset - self.offset |
| |
| def nl(self): |
| # TODO: newline character should be configurable |
| space = (self.char * ((self.indent * self.width) + self.offset)) |
| # Detect runaway indenting due to parsing errors |
| if len(space) > 200: |
| # something seems to be wrong, flip back |
| self.indent = self.offset = 0 |
| space = (self.char * ((self.indent * self.width) + self.offset)) |
| ws = '\n' + space |
| return sql.Token(T.Whitespace, ws) |
| |
| def _split_kwds(self, tlist): |
| split_words = ('FROM', 'STRAIGHT_JOIN$', 'JOIN$', 'AND', 'OR', |
| 'GROUP', 'ORDER', 'UNION', 'VALUES', |
| 'SET', 'BETWEEN', 'EXCEPT', 'HAVING') |
| |
| def _next_token(i): |
| t = tlist.token_next_match(i, T.Keyword, split_words, |
| regex=True) |
| if t and t.value.upper() == 'BETWEEN': |
| t = _next_token(tlist.token_index(t) + 1) |
| if t and t.value.upper() == 'AND': |
| t = _next_token(tlist.token_index(t) + 1) |
| return t |
| |
| idx = 0 |
| token = _next_token(idx) |
| added = set() |
| while token: |
| prev = tlist.token_prev(tlist.token_index(token), False) |
| offset = 1 |
| if prev and prev.is_whitespace() and prev not in added: |
| tlist.tokens.pop(tlist.token_index(prev)) |
| offset += 1 |
| uprev = unicode(prev) |
| if (prev and (uprev.endswith('\n') or uprev.endswith('\r'))): |
| nl = tlist.token_next(token) |
| else: |
| nl = self.nl() |
| added.add(nl) |
| tlist.insert_before(token, nl) |
| offset += 1 |
| token = _next_token(tlist.token_index(nl) + offset) |
| |
| def _split_statements(self, tlist): |
| idx = 0 |
| token = tlist.token_next_by_type(idx, (T.Keyword.DDL, T.Keyword.DML)) |
| while token: |
| prev = tlist.token_prev(tlist.token_index(token), False) |
| if prev and prev.is_whitespace(): |
| tlist.tokens.pop(tlist.token_index(prev)) |
| # only break if it's not the first token |
| if prev: |
| nl = self.nl() |
| tlist.insert_before(token, nl) |
| token = tlist.token_next_by_type(tlist.token_index(token) + 1, |
| (T.Keyword.DDL, T.Keyword.DML)) |
| |
| def _process(self, tlist): |
| func_name = '_process_%s' % tlist.__class__.__name__.lower() |
| func = getattr(self, func_name, self._process_default) |
| func(tlist) |
| |
| def _process_where(self, tlist): |
| token = tlist.token_next_match(0, T.Keyword, 'WHERE') |
| try: |
| tlist.insert_before(token, self.nl()) |
| except ValueError: # issue121, errors in statement |
| pass |
| self.indent += 1 |
| self._process_default(tlist) |
| self.indent -= 1 |
| |
| def _process_having(self, tlist): |
| token = tlist.token_next_match(0, T.Keyword, 'HAVING') |
| try: |
| tlist.insert_before(token, self.nl()) |
| except ValueError: # issue121, errors in statement |
| pass |
| self.indent += 1 |
| self._process_default(tlist) |
| self.indent -= 1 |
| |
| def _process_parenthesis(self, tlist): |
| first = tlist.token_next(0) |
| indented = False |
| if first and first.ttype in (T.Keyword.DML, T.Keyword.DDL): |
| self.indent += 1 |
| tlist.tokens.insert(0, self.nl()) |
| indented = True |
| num_offset = self._get_offset( |
| tlist.token_next_match(0, T.Punctuation, '(')) |
| self.offset += num_offset |
| self._process_default(tlist, stmts=not indented) |
| if indented: |
| self.indent -= 1 |
| self.offset -= num_offset |
| |
| def _process_identifierlist(self, tlist): |
| identifiers = list(tlist.get_identifiers()) |
| if len(identifiers) > 1 and not tlist.within(sql.Function): |
| first = list(identifiers[0].flatten())[0] |
| if self.char == '\t': |
| # when using tabs we don't count the actual word length |
| # in spaces. |
| num_offset = 1 |
| else: |
| num_offset = self._get_offset(first) - len(first.value) |
| self.offset += num_offset |
| for token in identifiers[1:]: |
| tlist.insert_before(token, self.nl()) |
| self.offset -= num_offset |
| self._process_default(tlist) |
| |
| def _process_case(self, tlist): |
| is_first = True |
| num_offset = None |
| case = tlist.tokens[0] |
| outer_offset = self._get_offset(case) - len(case.value) |
| self.offset += outer_offset |
| for cond, value in tlist.get_cases(): |
| if is_first: |
| tcond = list(cond[0].flatten())[0] |
| is_first = False |
| num_offset = self._get_offset(tcond) - len(tcond.value) |
| self.offset += num_offset |
| continue |
| if cond is None: |
| token = value[0] |
| else: |
| token = cond[0] |
| tlist.insert_before(token, self.nl()) |
        # Line breaks on group level are done. Now add an offset of
        # 5 (the length of "when ", "then " or "else ", including the
        # trailing space) and process subgroups.
| self.offset += 5 |
| self._process_default(tlist) |
| self.offset -= 5 |
| if num_offset is not None: |
| self.offset -= num_offset |
| end = tlist.token_next_match(0, T.Keyword, 'END') |
| tlist.insert_before(end, self.nl()) |
| self.offset -= outer_offset |
| |
| def _process_default(self, tlist, stmts=True, kwds=True): |
| if stmts: |
| self._split_statements(tlist) |
| if kwds: |
| self._split_kwds(tlist) |
| [self._process(sgroup) for sgroup in tlist.get_sublists()] |
| |
| def process(self, stack, stmt): |
| if isinstance(stmt, sql.Statement): |
| self._curr_stmt = stmt |
| self._process(stmt) |
| if isinstance(stmt, sql.Statement): |
| if self._last_stmt is not None: |
| if unicode(self._last_stmt).endswith('\n'): |
| nl = '\n' |
| else: |
| nl = '\n\n' |
| stmt.tokens.insert( |
| 0, sql.Token(T.Whitespace, nl)) |
| if self._last_stmt != stmt: |
| self._last_stmt = stmt |
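
# ReindentFilter backs the ``reindent`` option of the public API; a rough
# sketch (exact whitespace may vary between sqlparse versions):
#
#   import sqlparse
#   print sqlparse.format('select * from foo where bar = 1', reindent=True)
#   # select *
#   # from foo
#   # where bar = 1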
| |
| |
| # FIXME: Doesn't work ;) |
| class RightMarginFilter: |
| |
| keep_together = ( |
| # sql.TypeCast, sql.Identifier, sql.Alias, |
| ) |
| |
| def __init__(self, width=79): |
| self.width = width |
| self.line = '' |
| |
| def _process(self, stack, group, stream): |
| for token in stream: |
| if token.is_whitespace() and '\n' in token.value: |
| if token.value.endswith('\n'): |
| self.line = '' |
| else: |
| self.line = token.value.splitlines()[-1] |
| elif (token.is_group() |
| and not token.__class__ in self.keep_together): |
| token.tokens = self._process(stack, token, token.tokens) |
| else: |
| val = unicode(token) |
| if len(self.line) + len(val) > self.width: |
| match = re.search('^ +', self.line) |
| if match is not None: |
| indent = match.group() |
| else: |
| indent = '' |
| yield sql.Token(T.Whitespace, '\n%s' % indent) |
| self.line = indent |
| self.line += val |
| yield token |
| |
    def process(self, stack, group):
        # Intentionally a no-op (see the FIXME above): return before the
        # broken implementation below can mangle the token stream.
        return
        group.tokens = self._process(stack, group, group.tokens)
| |
| |
| class ColumnsSelect: |
| """Get the columns names of a SELECT query""" |
| def process(self, stack, stream): |
| mode = 0 |
| oldValue = "" |
| parenthesis = 0 |
| |
| for token_type, value in stream: |
| # Ignore comments |
| if token_type in Comment: |
| continue |
| |
| # We have not detected a SELECT statement |
| if mode == 0: |
| if token_type in Keyword and value == 'SELECT': |
| mode = 1 |
| |
| # We have detected a SELECT statement |
| elif mode == 1: |
| if value == 'FROM': |
| if oldValue: |
| yield oldValue |
| |
| mode = 3 # Columns have been checked |
| |
| elif value == 'AS': |
| oldValue = "" |
| mode = 2 |
| |
| elif (token_type == Punctuation |
| and value == ',' and not parenthesis): |
| if oldValue: |
| yield oldValue |
| oldValue = "" |
| |
| elif token_type not in Whitespace: |
| if value == '(': |
| parenthesis += 1 |
| elif value == ')': |
| parenthesis -= 1 |
| |
| oldValue += value |
| |
| # We are processing an AS keyword |
| elif mode == 2: |
            # Also accept Keyword tokens here, working around a sqlparse
            # lexer bug that tokenizes some alias names as keywords
| if token_type == Name or token_type == Keyword: |
| yield value |
| mode = 1 |
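
# Usage sketch through a Pipeline (note that the filter compares the
# keywords 'SELECT', 'FROM' and 'AS' case-sensitively):
#
#   pipe = Pipeline()
#   pipe.append(tokenize)
#   pipe.append(ColumnsSelect())
#   list(pipe(u'SELECT a, b AS c FROM t'))  # -> [u'a', u'c']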
| |
| |
| # --------------------------- |
| # postprocess |
| |
| class SerializerUnicode: |
| |
| def process(self, stack, stmt): |
| raw = unicode(stmt) |
| lines = split_unquoted_newlines(raw) |
| res = '\n'.join(line.rstrip() for line in lines) |
| return res |
| |
| |
def Tokens2Unicode(stream):
    """Concatenate the values of all tokens in the stream into one string"""
    return u''.join(unicode(value) for _, value in stream)
| |
| |
| class OutputFilter: |
| varname_prefix = '' |
| |
| def __init__(self, varname='sql'): |
| self.varname = self.varname_prefix + varname |
| self.count = 0 |
| |
| def _process(self, stream, varname, has_nl): |
| raise NotImplementedError |
| |
| def process(self, stack, stmt): |
| self.count += 1 |
| if self.count > 1: |
| varname = '%s%d' % (self.varname, self.count) |
| else: |
| varname = self.varname |
| |
| has_nl = len(unicode(stmt).strip().splitlines()) > 1 |
| stmt.tokens = self._process(stmt.tokens, varname, has_nl) |
| return stmt |
| |
| |
| class OutputPythonFilter(OutputFilter): |
| def _process(self, stream, varname, has_nl): |
        # Assign the SQL query to varname
| if self.count > 1: |
| yield sql.Token(T.Whitespace, '\n') |
| yield sql.Token(T.Name, varname) |
| yield sql.Token(T.Whitespace, ' ') |
| yield sql.Token(T.Operator, '=') |
| yield sql.Token(T.Whitespace, ' ') |
| if has_nl: |
| yield sql.Token(T.Operator, '(') |
| yield sql.Token(T.Text, "'") |
| |
        # Emit the tokens, wrapped in the quoted string
| for token in stream: |
| # Token is a new line separator |
| if token.is_whitespace() and '\n' in token.value: |
| # Close quote and add a new line |
| yield sql.Token(T.Text, " '") |
| yield sql.Token(T.Whitespace, '\n') |
| |
| # Quote header on secondary lines |
| yield sql.Token(T.Whitespace, ' ' * (len(varname) + 4)) |
| yield sql.Token(T.Text, "'") |
| |
| # Indentation |
| after_lb = token.value.split('\n', 1)[1] |
| if after_lb: |
| yield sql.Token(T.Whitespace, after_lb) |
| continue |
| |
| # Token has escape chars |
| elif "'" in token.value: |
| token.value = token.value.replace("'", "\\'") |
| |
| # Put the token |
| yield sql.Token(T.Text, token.value) |
| |
| # Close quote |
| yield sql.Token(T.Text, "'") |
| if has_nl: |
| yield sql.Token(T.Operator, ')') |
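
# OutputPythonFilter backs ``output_format='python'`` in the public API:
#
#   import sqlparse
#   sqlparse.format('select * from foo;', output_format='python')
#   # -> "sql = 'select * from foo;'"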
| |
| |
| class OutputPHPFilter(OutputFilter): |
| varname_prefix = '$' |
| |
| def _process(self, stream, varname, has_nl): |
        # Assign the SQL query to varname (quote header)
| if self.count > 1: |
| yield sql.Token(T.Whitespace, '\n') |
| yield sql.Token(T.Name, varname) |
| yield sql.Token(T.Whitespace, ' ') |
| if has_nl: |
| yield sql.Token(T.Whitespace, ' ') |
| yield sql.Token(T.Operator, '=') |
| yield sql.Token(T.Whitespace, ' ') |
| yield sql.Token(T.Text, '"') |
| |
        # Emit the tokens, wrapped in the quoted string
| for token in stream: |
| # Token is a new line separator |
| if token.is_whitespace() and '\n' in token.value: |
| # Close quote and add a new line |
| yield sql.Token(T.Text, ' ";') |
| yield sql.Token(T.Whitespace, '\n') |
| |
| # Quote header on secondary lines |
| yield sql.Token(T.Name, varname) |
| yield sql.Token(T.Whitespace, ' ') |
| yield sql.Token(T.Operator, '.=') |
| yield sql.Token(T.Whitespace, ' ') |
| yield sql.Token(T.Text, '"') |
| |
| # Indentation |
| after_lb = token.value.split('\n', 1)[1] |
| if after_lb: |
| yield sql.Token(T.Whitespace, after_lb) |
| continue |
| |
| # Token has escape chars |
| elif '"' in token.value: |
| token.value = token.value.replace('"', '\\"') |
| |
| # Put the token |
| yield sql.Token(T.Text, token.value) |
| |
| # Close quote |
| yield sql.Token(T.Text, '"') |
| yield sql.Token(T.Punctuation, ';') |
| |
| |
| class Limit: |
| """Get the LIMIT of a query. |
| |
| If not defined, return -1 (SQL specification for no LIMIT query) |
| """ |
| def process(self, stack, stream): |
| index = 7 |
| stream = list(stream) |
| stream.reverse() |
| |
| # Run over all tokens in the stream from the end |
| for token_type, value in stream: |
| index -= 1 |
| |
            # Only the last few tokens are inspected (index counts down from
            # 7); when LIMIT is found, stream[4 - index] holds its value
            if index and token_type in Keyword and value == 'LIMIT':
                return stream[4 - index][1]
| |
| return -1 |
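
# Usage sketch (note the literal, case-sensitive 'LIMIT' comparison above):
#
#   pipe = Pipeline()
#   pipe.append(tokenize)
#   pipe.append(Limit())
#   pipe(u'SELECT * FROM foo LIMIT 10')  # -> u'10'
#   pipe(u'SELECT * FROM foo')           # -> -1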
| |
| |
| def compact(stream): |
| """Function that return a compacted version of the stream""" |
| pipe = Pipeline() |
| |
| pipe.append(StripComments()) |
| pipe.append(StripWhitespace) |
| |
| return pipe(stream) |
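
# Usage sketch:
#
#   Tokens2Unicode(compact(tokenize(u'SELECT  *  -- all columns\nFROM   foo')))
#   # -> u'SELECT * FROM foo'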