| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| from inspect import getmro |
| from logging import getLogger |
| from re import sub |
| from sqlparse import format |
| |
| from tests.comparison.common import StructColumn, CollectionColumn |
| from tests.comparison.db_types import ( |
| Char, |
| Decimal, |
| Double, |
| Float, |
| Int, |
| String, |
| Timestamp, |
| TinyInt, |
| VarChar) |
| from tests.comparison.query import InsertClause, Query |
| from tests.comparison.query_flattener import QueryFlattener |
| |
# Module-level logger, named after this module; used below to report SQL that could
# not be pretty-printed.
LOG = getLogger(__name__)
| |
| |
class SqlWriter(object):
    """Render query-model objects as SQL text for one database dialect.

    Subclasses of SqlWriter take a Query (or InsertStatement) and provide the SQL
    representation for a specific database such as Impala or MySQL. The
    SqlWriter.create([dialect=]) factory method may be used instead of specifying
    the concrete class.

    Rendering is driven by _write(), which picks a method from the class name of
    the object being written: an instance of FooBar is handled by _write_foo_bar().
    """

    @staticmethod
    def create(dialect='impala', nulls_order_asc='DEFAULT'):
        """Create and return a new SqlWriter appropriate for the given sql dialect.

        "dialect" refers to database specific deviations of sql. It is matched
        case-insensitively and should be one of "IMPALA", "MYSQL", "ORACLE",
        "POSTGRESQL", or "HIVE". "nulls_order_asc" is handed to the writer's
        constructor unchanged.

        Raises ValueError (a subclass of Exception) if the dialect is unknown.
        """
        writer_classes = {
            'IMPALA': ImpalaSqlWriter,
            'MYSQL': MySQLSqlWriter,
            'ORACLE': OracleSqlWriter,
            'POSTGRESQL': PostgresqlSqlWriter,
            'HIVE': HiveSqlWriter,
        }
        dialect = dialect.upper()
        if dialect not in writer_classes:
            raise ValueError('Unknown dialect: %s' % dialect)
        return writer_classes[dialect](nulls_order_asc)

    def __init__(self, nulls_order_asc):
        """Store the NULL ordering mode and set up the operator templates.

        "nulls_order_asc" must be 'BEFORE', 'AFTER', or 'DEFAULT'; see
        get_nulls_order() for how it is used. Raises ValueError (a subclass of
        Exception) for any other value.
        """
        if nulls_order_asc not in ('BEFORE', 'AFTER', 'DEFAULT'):
            raise ValueError('Unknown nulls order: %s' % nulls_order_asc)
        self.nulls_order_asc = nulls_order_asc

        # Functions that don't follow the usual call syntax of foo(bar, baz) can be
        # listed here. Parentheses were added everywhere to avoid problems with
        # operator precedence.
        # TODO: Account for operator precedence...
        self.operator_funcs = {
            'And': '({0}) AND ({1})',
            'Or': '({0}) OR ({1})',
            'Plus': '({0}) + ({1})',
            'Minus': '({0}) - ({1})',
            'Multiply': '({0}) * ({1})',
            'Divide': '({0}) / ({1})',
            'Equals': '({0}) = ({1})',
            'NotEquals': '({0}) != ({1})',
            'IsNotDistinctFrom': '({0}) IS NOT DISTINCT FROM ({1})',
            # If a database supports the operator version of 'IS NOT DISTINCT FROM',
            # it should overwrite the value of 'IsNotDistinctFromOp'.
            'IsNotDistinctFromOp': '({0}) IS NOT DISTINCT FROM ({1})',
            'IsDistinctFrom': '({0}) IS DISTINCT FROM ({1})',
            'LessThan': '({0}) < ({1})',
            'GreaterThan': '({0}) > ({1})',
            'LessThanOrEquals': '({0}) <= ({1})',
            'GreaterThanOrEquals': '({0}) >= ({1})',
            'IsNull': '({0}) IS NULL',
            'IsNotNull': '({0}) IS NOT NULL'}

    def write_query(self, statement, pretty=False):
        """Return SQL as a string for the given statement.

        If "pretty" is True, the SQL will be formatted (though not very well) with
        new lines and indentation.
        """
        sql = self._write(statement)
        if pretty:
            sql = self.make_pretty_sql(sql)
        return sql

    def _write_query(self, query):
        """Return the SQL for a Query object in this writer's dialect.

        This is just another dispatch destination of self._write(). When
        self._write(Query) is called, that's dispatched to self._write_query(Query).
        """
        # Clauses are listed in the order they must appear in the statement; unset
        # (falsy) clauses are skipped.
        clauses = (
            query.with_clause,
            query.select_clause,
            query.from_clause,
            query.where_clause,
            query.group_by_clause,
            query.having_clause,
            query.union_clause,
            query.order_by_clause,
            query.limit_clause)
        return '\n'.join(self._write(clause) for clause in clauses if clause)

    def _write_insert_statement(self, insert_statement):
        """Return the SQL for an InsertStatement object in this writer's dialect.

        Raises Exception if the insert clause is missing, or unless exactly one of
        select_query and values_clause is set.
        """
        parts = []
        if insert_statement.with_clause:
            parts.append(self._write(insert_statement.with_clause))

        if not insert_statement.insert_clause:
            raise Exception('InsertStatement is missing insert_clause attribute')
        parts.append(self._write(insert_statement.insert_clause))

        has_select = bool(insert_statement.select_query)
        has_values = bool(insert_statement.values_clause)
        if has_select == has_values:
            raise Exception('InsertStatement must have a select_query xor a values clause')
        if has_select:
            parts.append(self._write(insert_statement.select_query))
        else:
            parts.append(self._write(insert_statement.values_clause))

        return '\n'.join(parts)

    def make_pretty_sql(self, sql):
        """Return "sql" reindented, or unchanged if it could not be formatted."""
        try:
            # "format" is the function imported from sqlparse at the top of this
            # file, not the builtin.
            return format(sql, reindent=True)
        except Exception as e:
            # Pretty printing is best effort only, so fall back to the raw SQL.
            LOG.warning('Unable to format sql: %s', e)
            return sql

    def write_create_table_as(self, query, name, pretty=False):
        """Return a CREATE TABLE ... AS statement that stores the query's result."""
        return 'CREATE TABLE %s AS %s' % (name, self.write_query(query, pretty=pretty))

    def write_create_view(self, query, name, pretty=False):
        """Return a CREATE VIEW ... AS statement wrapping the query."""
        return 'CREATE VIEW %s AS %s' % (name, self.write_query(query, pretty=pretty))

    def _write_with_clause(self, with_clause):
        """Write "WITH a AS (...), b AS (...)" from the clause's inline views."""
        return 'WITH ' + ',\n'.join(
            '%s AS (%s)' % (view.identifier, self._write(view.query))
            for view in with_clause.with_clause_inline_views)

    def _write_select_clause(self, select_clause):
        """Write the SELECT keyword followed by one select item per line."""
        if hasattr(select_clause, 'star_prefix'):
            # This is a little hack to get query flattening to work (look at
            # query_flattener.py).
            # TODO: Add proper support for SELECT *, and SELECT table_name.*
            return 'SELECT %s.*' % select_clause.star_prefix
        keyword = 'SELECT DISTINCT' if select_clause.distinct else 'SELECT'
        items = ',\n'.join(self._write(item) for item in select_clause.items)
        return keyword + '\n' + items

    def _write_select_item(self, select_item):
        """Write a select item's expression, with "AS alias" if it has an alias."""
        sql = self._write(select_item.val_expr)
        if select_item.alias:
            return '{0} AS {1}'.format(sql, select_item.alias)
        return sql

    def _write_struct_column(self, struct_col):
        """Write a struct column as "<owner>.<name>".

        The owner is written recursively when it is a struct, or a collection
        without an alias; otherwise the owner's identifier is used.
        """
        owner = struct_col.owner
        if isinstance(owner, StructColumn) or \
                (isinstance(owner, CollectionColumn) and not owner.alias):
            return '%s.%s' % (self._write(owner), struct_col.name)
        return '%s.%s' % (owner.identifier, struct_col.name)

    def _write_collection_column(self, collection_col):
        """Write a collection column as "<owner>.<name>[ <alias>]".

        The owner is written recursively when it is a struct or collection without
        an alias; otherwise the owner's identifier is used.
        """
        owner = collection_col.owner
        if isinstance(owner, (StructColumn, CollectionColumn)) and not owner.alias:
            owner_sql = self._write(owner)
        else:
            owner_sql = owner.identifier
        sql = '%s.%s' % (owner_sql, collection_col.name)
        if collection_col.alias:
            sql = '%s %s' % (sql, collection_col.alias)
        return sql

    def _write_column(self, col):
        """Write a column as "<owner>.<name>"; struct owners are written recursively."""
        if isinstance(col.owner, StructColumn):
            return '%s.%s' % (self._write(col.owner), col.name)
        return '%s.%s' % (col.owner.identifier, col.name)

    def _write_from_clause(self, from_clause):
        """Write FROM with its table expression, then one join clause per line."""
        sql = 'FROM %s' % self._write(from_clause.table_expr)
        if from_clause.join_clauses:
            sql += '\n' + '\n'.join(
                self._write(join) for join in from_clause.join_clauses)
        return sql

    def _write_table(self, table):
        """Write a table name, followed by its identifier when it is aliased."""
        if table.alias:
            return '%s %s' % (table.name, table.identifier)
        return table.name

    def _write_inline_view(self, inline_view):
        """Write "(<query>) <identifier>"; raises Exception without an identifier."""
        if not inline_view.identifier:
            raise Exception('An inline view requires an identifier')
        return '(\n%s\n) %s' % (self._write(inline_view.query), inline_view.identifier)

    def _write_with_clause_inline_view(self, with_clause_inline_view):
        """Write a reference to a WITH clause entry: its name and optional alias."""
        if not with_clause_inline_view.with_clause_alias:
            raise Exception('An with clause entry requires an identifier')
        sql = with_clause_inline_view.with_clause_alias
        if with_clause_inline_view.alias:
            sql += ' ' + with_clause_inline_view.alias
        return sql

    def _write_join_clause(self, join_clause):
        """Write "<type> JOIN <table expr>" with an ON condition if one is set."""
        sql = '%s JOIN %s' % (join_clause.join_type, self._write(join_clause.table_expr))
        if join_clause.boolean_expr:
            sql += ' ON ' + self._write(join_clause.boolean_expr)
        return sql

    def _write_where_clause(self, where_clause):
        """Write WHERE followed by the clause's boolean expression."""
        return 'WHERE\n' + self._write(where_clause.boolean_expr)

    def _write_group_by_clause(self, group_by_clause):
        """Write GROUP BY followed by one grouping expression per line."""
        return 'GROUP BY\n' + ',\n'.join(
            self._write(item.val_expr) for item in group_by_clause.group_by_items)

    def _write_having_clause(self, having_clause):
        """Write HAVING followed by the clause's boolean expression."""
        return 'HAVING\n' + self._write(having_clause.boolean_expr)

    def _write_union_clause(self, union_clause):
        """Write UNION [ALL] followed by the query being unioned."""
        keyword = 'UNION ALL' if union_clause.all else 'UNION'
        return keyword + '\n' + self._write(union_clause.query)

    def _write_data_type(self, data_type):
        """Write a literal value."""
        if data_type.val is None:
            return 'NULL'
        if data_type.returns_char:
            return "'{0}'".format(data_type.val)
        if data_type.returns_timestamp:
            return "CAST('{0}' AS {1})".format(
                data_type.val, self._write_data_type_metaclass(Timestamp))
        return str(data_type.val)

    def _write_func(self, func):
        """Write a function call, or an operator expression for operator funcs."""
        name = func.name()
        if name in self.operator_funcs:
            return self.operator_funcs[name].format(
                *[self._write(arg) for arg in func.args])
        return '%s(%s)' % (self._to_sql_name(name), self._write_as_comma_list(func.args))

    def _write_exists(self, func):
        """Write EXISTS followed by the function's first argument."""
        return 'EXISTS ' + self._write(func.args[0])

    def _write_not_exists(self, func):
        """Write NOT EXISTS followed by the function's first argument."""
        return 'NOT EXISTS ' + self._write(func.args[0])

    def _write_in(self, func, use_not=False):
        """Write "(<expr>) [NOT ]IN ..." against a subquery or a value list."""
        sql = '(%s) ' % self._write(func.args[0])
        if use_not:
            sql += 'NOT '
        sql += 'IN '
        if func.signature.args[1].is_subquery:
            sql += self._write(func.args[1])
        else:
            sql += '(' + self._write_as_comma_list(func.args[1:]) + ')'
        return sql

    def _write_not_in(self, func):
        """Write the NOT IN form of _write_in()."""
        return self._write_in(func, use_not=True)

    def _write_as_comma_list(self, items):
        """Write each item and join the results with ", "."""
        return ', '.join([self._write(item) for item in items])

    def _write_cast_as_char(self, func):
        """Write a CAST of the first argument to this dialect's String type."""
        return 'CAST(%s AS %s)' % (self._write(func.args[0]), self._write(String))

    def _write_cast(self, arg, type):
        """Write a CAST of "arg" to "type", where "type" is already SQL text."""
        # NOTE: the parameter name shadows the builtin "type"; it is kept because
        # it is part of the method's signature.
        return 'CAST(%s AS %s)' % (self._write(arg), type)

    def _write_cast_func(self, func):
        """Write a CAST of the first argument to the type given as the second."""
        val_expr = func.args[0]
        type_ = func.args[1]
        return 'CAST({val_expr} AS {type_})'.format(
            val_expr=self._write(val_expr), type_=self._write(type_))

    def _interval_add_sql(self, func, unit):
        """Write "<arg0> + INTERVAL <arg1> <unit>".

        Named without the "_write_" prefix so that _write() can never select it as
        a dispatch target.
        """
        return '%s + INTERVAL %s %s' % (
            self._write(func.args[0]), self._write(func.args[1]), unit)

    def _write_date_add_year(self, func):
        return self._interval_add_sql(func, 'YEAR')

    def _write_date_add_month(self, func):
        return self._interval_add_sql(func, 'MONTH')

    def _write_date_add_day(self, func):
        return self._interval_add_sql(func, 'DAY')

    def _write_date_add_hour(self, func):
        return self._interval_add_sql(func, 'HOUR')

    def _write_date_add_minute(self, func):
        return self._interval_add_sql(func, 'MINUTE')

    def _write_date_add_second(self, func):
        return self._interval_add_sql(func, 'SECOND')

    def _extract_sql(self, func, field):
        """Write "EXTRACT(<field> FROM <arg0>)".

        Named without the "_write_" prefix so that _write() can never select it as
        a dispatch target.
        """
        return 'EXTRACT(%s FROM %s)' % (field, self._write(func.args[0]))

    def _write_extract_year(self, func):
        return self._extract_sql(func, 'YEAR')

    def _write_extract_month(self, func):
        return self._extract_sql(func, 'MONTH')

    def _write_extract_day(self, func):
        return self._extract_sql(func, 'DAY')

    def _write_extract_hour(self, func):
        return self._extract_sql(func, 'HOUR')

    def _write_extract_minute(self, func):
        return self._extract_sql(func, 'MINUTE')

    def _write_extract_second(self, func):
        return self._extract_sql(func, 'SECOND')

    def _write_implicit_cast(self, func):
        """Write only the wrapped argument; the cast itself produces no SQL."""
        return self._write(func.args[0])

    def _write_analytic_func(self, func):
        """Write "<NAME>(<args>) OVER (<partition by> <order by> <window>)"."""
        sql = self._to_sql_name(func.name()) \
            + '(' + self._write_as_comma_list(func.args) \
            + ') OVER ('
        options = []
        if func.partition_by_clause:
            options.append(self._write(func.partition_by_clause))
        if func.order_by_clause:
            options.append(self._write(func.order_by_clause))
        if func.window_clause:
            options.append(self._write(func.window_clause))
        return sql + ' '.join(options) + ')'

    def _write_partition_by_clause(self, partition_by_clause):
        """Write PARTITION BY followed by a comma separated expression list."""
        return 'PARTITION BY ' + \
            ', '.join(self._write(expr) for expr in partition_by_clause.val_exprs)

    def _write_window_clause(self, window_clause):
        """Write a window frame such as "ROWS BETWEEN <start> AND <end>".

        BETWEEN ... AND is only used when an end boundary is set. Without one the
        result keeps a trailing space after the start boundary type.
        """
        start = window_clause.start_boundary
        end = window_clause.end_boundary
        sql = window_clause.range_or_rows.upper() + ' '
        if end:
            sql += 'BETWEEN '
        if start.val_expr:
            sql += self._write(start.val_expr) + ' '
        sql += start.boundary_type.upper() + ' '
        if end:
            sql += 'AND '
            if end.val_expr:
                sql += self._write(end.val_expr) + ' '
            sql += end.boundary_type.upper()
        return sql

    def _write_agg_func(self, func):
        """Write "<NAME>([DISTINCT ]<arg>)"."""
        distinct = 'DISTINCT ' if func.distinct else ''
        # All agg funcs only have a single arg
        return '%s(%s%s)' % (
            self._to_sql_name(func.name()), distinct, self._write(func.args[0]))

    def _write_data_type_metaclass(self, data_type_class):
        """Write a data type class such as Int, Boolean, or Decimal(4, 2)."""
        if data_type_class == Char:
            return 'CHAR({0})'.format(data_type_class.MAX)
        if data_type_class == VarChar:
            return 'VARCHAR({0})'.format(data_type_class.MAX)
        if data_type_class == Decimal:
            # SQL declares DECIMAL(precision, scale): total digits first, then the
            # number of fractional digits.
            return 'DECIMAL({precision},{scale})'.format(
                precision=data_type_class.MAX_DIGITS,
                scale=data_type_class.MAX_FRACTIONAL_DIGITS)
        return data_type_class.__name__.upper()

    def _write_subquery(self, subquery):
        """Write the subquery's query wrapped in parentheses."""
        return '({0})'.format(self._write(subquery.query))

    def _write_order_by_clause(self, order_by_clause):
        """Write ORDER BY with each expression's direction and NULLs placement."""
        items = []
        for expr, order in order_by_clause.exprs_to_order:
            parts = [self._write(expr)]
            if order:
                parts.append(order)
            nulls_order = self.get_nulls_order(order)
            if nulls_order is not None:
                parts.append(nulls_order)
            items.append(' '.join(parts))
        return 'ORDER BY ' + ', '.join(items)

    def _write_limit_clause(self, limit_clause):
        """Write LIMIT followed by the clause's limit."""
        return 'LIMIT {0}'.format(limit_clause.limit)

    def _write_insert_clause(self, insert_clause):
        """Return "INSERT INTO <table>[ (<columns>)]" for the given InsertClause.

        The InsertClause object may have the column_list attribute set, which is a
        sequence of columns. When it is None no column list is written.
        """
        if insert_clause.column_list is None:
            column_list = ''
        else:
            column_list = ' ({column_list})'.format(
                column_list=', '.join([col.name for col in insert_clause.column_list]))
        return 'INSERT INTO {table_name}{column_list}'.format(
            table_name=insert_clause.table.name, column_list=column_list)

    def _write_values_row(self, values_row):
        """Return a string representing 1 row of a VALUES clause."""
        return '({values_row})'.format(
            values_row=', '.join([self._write(item) for item in values_row.items]))

    def _write_values_clause(self, values_clause):
        """Return a string representing the VALUES clause of an INSERT query."""
        return 'VALUES\n{values_rows}'.format(
            values_rows=',\n'.join([self._write(values_row)
                                    for values_row in values_clause.values_rows]))

    def _write(self, object_):
        """Return a sql string representation of the given object.

        Raises Exception if no writer method matches the object's class or any of
        its base classes.
        """
        # What's below is effectively a giant switch statement. It works based on a
        # func naming and signature convention. It should match the incoming object
        # with the corresponding func defined, then call the func and return the
        # result.
        #
        # Ex:
        #   a = model.And(...)
        #   _write(a) should call _write_func(a) because "And" is a subclass of
        #   "Func" and no other _write_<class name> methods have been defined higher
        #   up the method resolution order (MRO). If _write_and(...) were to be
        #   defined, it would be called instead.
        #
        # Query objects need no special case: they are matched by _write_query().
        for type_ in getmro(type(object_)):
            writer_func = getattr(self, '_write_' + self._to_py_name(type_.__name__), None)
            if writer_func:
                return writer_func(object_)

        raise Exception('Unsupported object: %s<%s>' % (type(object_).__name__, object_))

    def get_nulls_order(self, order):
        """Return 'NULLS FIRST', 'NULLS LAST', or None for a sort direction.

        self.nulls_order_asc says where NULLs go in an ascending sort ('BEFORE' or
        'AFTER' the other values); a descending sort gets the opposite placement.
        None is returned when the mode is 'DEFAULT' (or None), or when "order" is
        neither 'ASC' nor 'DESC'.
        """
        nulls_order_by_sort = {
            ('ASC', 'BEFORE'): 'NULLS FIRST',
            ('ASC', 'AFTER'): 'NULLS LAST',
            ('DESC', 'BEFORE'): 'NULLS LAST',
            ('DESC', 'AFTER'): 'NULLS FIRST',
        }
        return nulls_order_by_sort.get((order, self.nulls_order_asc))

    def _to_py_name(self, name):
        """Convert a CamelCase name to snake_case, e.g. "IsNull" -> "is_null"."""
        return sub('([A-Z])', r'_\1', name).lower().lstrip('_')

    def _to_sql_name(self, name):
        """Convert a CamelCase name to UPPER_SNAKE_CASE, e.g. "IsNull" -> "IS_NULL"."""
        return self._to_py_name(name).upper()
| |
| |
class ImpalaSqlWriter(SqlWriter):
    """SqlWriter for the Impala dialect."""

    DIALECT = 'IMPALA'

    def __init__(self, *args, **kwargs):
        """Set up the base writer, then switch 'IsNotDistinctFromOp' to '<=>'."""
        super(ImpalaSqlWriter, self).__init__(*args, **kwargs)
        self.operator_funcs['IsNotDistinctFromOp'] = '({0}) <=> ({1})'

    def _write_column(self, col):
        """Write a column reference, wrapping Char columns in TRIM()."""
        column_sql = super(ImpalaSqlWriter, self)._write_column(col)
        if col.exact_type != Char:
            return column_sql
        # TRIM is a temporary workaround for IMPALA-1652
        return 'TRIM(%s)' % column_sql

    def _write_insert_clause(self, insert_clause):
        """Write the insert clause, using UPSERT when conflicts should update."""
        insert_sql = super(ImpalaSqlWriter, self)._write_insert_clause(insert_clause)
        if insert_clause.conflict_action != InsertClause.CONFLICT_ACTION_UPDATE:
            return insert_sql
        # At this point insert_sql looks like:
        #
        #   INSERT INTO <table name> [(column list)]
        #
        # The table name or column list may itself contain the text INSERT inside
        # an identifier. Limiting replace() to the first occurrence changes only
        # the leading INSERT keyword to UPSERT and leaves identifiers untouched.
        return insert_sql.replace('INSERT', 'UPSERT', 1)
| |
| |
class OracleSqlWriter(SqlWriter):
    # SqlWriter for the Oracle dialect. Nothing is overridden here, so all SQL is
    # produced by the SqlWriter base class.

    DIALECT = 'ORACLE'
| |
| |
class HiveSqlWriter(SqlWriter):
    """SqlWriter for the Hive dialect."""

    DIALECT = 'HIVE'

    def __init__(self, *args, **kwargs):
        """Set up the base writer, then use '<=>' for the distinct-from operators."""
        super(HiveSqlWriter, self).__init__(*args, **kwargs)

        self.operator_funcs.update({
            'IsNotDistinctFrom': '({0}) <=> ({1})',
            'IsNotDistinctFromOp': '({0}) <=> ({1})',
            'IsDistinctFrom': 'NOT(({0}) <=> ({1}))'
        })

    # The Hive greatest and least UDFs are strict on type equality.
    # Hive Profile already restricts to signatures with the same types,
    # but sometimes an expression with UDFs like 'count'
    # returns an unpredictable type like 'bigint' unlike
    # the query model, so a cast is still necessary.
    def _write_greatest(self, func):
        return self._strict_type_func_sql(func)

    def _write_least(self, func):
        return self._strict_type_func_sql(func)

    def _strict_type_func_sql(self, func):
        """Write a function call with every argument cast to the first one's type.

        The casts are only added when the first argument's type is Int, Decimal, or
        Float; any other call is written like an ordinary function. Every argument
        is cast, so none are dropped if a signature has more than two.

        Named without the "_write_" prefix so that the _write() dispatcher can
        never select it as a dispatch target.
        """
        args = func.args
        first_type = args[0].type
        if first_type not in (Int, Decimal, Float):
            return self._write_func(func)
        cast_type = first_type.__name__.lower()
        return '%s(%s)' % (
            self._to_sql_name(func.name()),
            ', '.join(self._write_cast(arg, cast_type) for arg in args))

    # Workaround for tinyint casting issues when run against RefDb
    # that might only have larger integers.
    # This does all the arithmetic operations in terms of bigints.
    def _write_plus(self, func):
        return self.arithmetic_cast(func, '+')

    def _write_minus(self, func):
        return self.arithmetic_cast(func, '-')

    def _write_multiply(self, func):
        return self.arithmetic_cast(func, '*')

    def arithmetic_cast(self, func, symbol):
        """Write a binary arithmetic expression, casting Int operands to BIGINT.

        The casts are only added when both operands have type Int; otherwise the
        function is written like any other.
        """
        left, right = func.args[0], func.args[1]
        if left.type is Int and right.type is Int:
            return 'CAST (%s AS BIGINT) %s CAST (%s AS BIGINT)' % (
                self._write(left), symbol, self._write(right))
        return self._write_func(func)

    # Hive partition by clause throws exception if sorted by more than one key, unless
    # 'rows unbounded preceding' added.
    def _write_analytic_func(self, func):
        """Write an analytic function call with its OVER (...) options."""
        sql = self._to_sql_name(func.name()) \
            + '(' + self._write_as_comma_list(func.args) \
            + ') OVER ('
        options = []
        if func.partition_by_clause:
            options.append(self._write(func.partition_by_clause))
        if func.order_by_clause:
            options.append(self._write(func.order_by_clause))
        if func.window_clause:
            options.append(self._write(func.window_clause))
        needs_default_window = (
            func.partition_by_clause
            and func.order_by_clause
            and len(func.order_by_clause.exprs_to_order) > 1
            and func.SUPPORTS_WINDOWING
            and func.window_clause is None)
        if needs_default_window:
            options.append('rows unbounded preceding')
        return sql + ' '.join(options) + ')'

    # Hive uses plain functions instead of the EXTRACT(<field> FROM ...) syntax.
    def _write_extract_year(self, func):
        return 'YEAR(%s)' % self._write(func.args[0])

    def _write_extract_month(self, func):
        return 'MONTH(%s)' % self._write(func.args[0])

    def _write_extract_day(self, func):
        return 'DAY(%s)' % self._write(func.args[0])

    def _write_extract_hour(self, func):
        return 'HOUR(%s)' % self._write(func.args[0])

    def _write_extract_minute(self, func):
        return 'MINUTE(%s)' % self._write(func.args[0])

    def _write_extract_second(self, func):
        return 'SECOND(%s)' % self._write(func.args[0])
| |
| |
class PostgresqlSqlWriter(SqlWriter):
    """SqlWriter for the PostgreSQL dialect."""

    DIALECT = 'POSTGRESQL'

    def _write_insert_statement(self, insert_statement):
        """Write the insert statement plus the ON CONFLICT clause it calls for.

        Raises Exception if the statement's conflict_action is not one of the
        InsertClause.CONFLICT_ACTION_* values handled here.
        """
        sql = SqlWriter._write_insert_statement(self, insert_statement)
        action = insert_statement.conflict_action
        if action == InsertClause.CONFLICT_ACTION_DEFAULT:
            return sql
        if action == InsertClause.CONFLICT_ACTION_IGNORE:
            return sql + '\nON CONFLICT DO NOTHING'
        if action != InsertClause.CONFLICT_ACTION_UPDATE:
            raise Exception('InsertStatement has unsupported conflict_action: {0}'.format(
                insert_statement.conflict_action))

        column_names = insert_statement.updatable_column_names
        if not column_names:
            # With nothing to update, a conflicting row is simply skipped.
            return sql + '\nON CONFLICT DO NOTHING'
        assignments = ',\n'.join(
            '{name} = EXCLUDED.{name}'.format(name=name) for name in column_names)
        return sql + '\nON CONFLICT {primary_keys}\nDO UPDATE SET\n{columns}'.format(
            primary_keys=insert_statement.primary_key_string, columns=assignments)

    def _scaled_interval_add(self, func, unit):
        """Write "<arg0> + (<arg1>) * INTERVAL '1' <unit>".

        Named without the "_write_" prefix so that the _write() dispatcher can
        never select it as a dispatch target.
        """
        return "%s + (%s) * INTERVAL '1' %s" % (
            self._write(func.args[0]), self._write(func.args[1]), unit)

    def _write_date_add_year(self, func):
        return self._scaled_interval_add(func, 'YEAR')

    def _write_date_add_month(self, func):
        return self._scaled_interval_add(func, 'MONTH')

    def _write_date_add_day(self, func):
        return self._scaled_interval_add(func, 'DAY')

    def _write_date_add_hour(self, func):
        return self._scaled_interval_add(func, 'HOUR')

    def _write_date_add_minute(self, func):
        return self._scaled_interval_add(func, 'MINUTE')

    def _write_date_add_second(self, func):
        return self._scaled_interval_add(func, 'SECOND')

    def _write_join_clause(self, join_clause):
        """Write a join, with the LATERAL keyword for lateral joins."""
        lateral = 'LATERAL' if join_clause.is_lateral_join else ''
        # When the join is not lateral the keyword slot is empty, which leaves two
        # spaces between JOIN and the table expression.
        sql = '{0} JOIN {1} {2}'.format(
            join_clause.join_type, lateral, self._write(join_clause.table_expr))
        if join_clause.boolean_expr:
            sql += ' ON ' + self._write(join_clause.boolean_expr)
        return sql

    def _write_column(self, col):
        """Write a column as "<first non-struct ancestor>.<flattened name>"."""
        ancestor = col.owner
        while isinstance(ancestor, StructColumn):
            ancestor = ancestor.owner
        return '%s.%s' % (ancestor.identifier, QueryFlattener.flat_column_name(col))

    def _write_collection_column(self, collection_col):
        """Write a collection as "<flattened name> <identifier>"."""
        flat_name = QueryFlattener.flat_collection_name(collection_col)
        return '%s %s' % (flat_name, collection_col.identifier)

    def _write_extract_second(self, func):
        # For some reason Postgresql decided that extracting second should return a
        # FLOAT, so the result is floored.
        return 'FLOOR(EXTRACT(SECOND FROM %s))' % self._write(func.args[0])

    def _write_if(self, func):
        """Write the function as a CASE WHEN ... THEN ... ELSE ... END expression."""
        written_args = [self._write(arg) for arg in func.args]
        return 'CASE WHEN {0} THEN {1} ELSE {2} END'.format(*written_args)

    def _write_data_type_metaclass(self, data_type_class):
        """Write a data type class such as Int or Boolean."""
        if data_type_class == Double:
            return 'DOUBLE PRECISION'
        if data_type_class == Float:
            return 'REAL'
        if data_type_class == String:
            return 'VARCHAR(%s)' % data_type_class.MAX
        if data_type_class == Timestamp:
            return 'TIMESTAMP WITHOUT TIME ZONE'
        if data_type_class == TinyInt:
            return 'SMALLINT'
        return super(PostgresqlSqlWriter, self)._write_data_type_metaclass(data_type_class)

    def _write_order_by_clause(self, order_by_clause):
        """Write ORDER BY, casting character expressions to BYTEA."""
        items = []
        for expr, order in order_by_clause.exprs_to_order:
            written = self._write(expr)
            if expr.returns_char:
                written = 'CAST({0} AS BYTEA)'.format(written)
            parts = [written]
            if order:
                parts.append(order)
            nulls_order = self.get_nulls_order(order)
            if nulls_order is not None:
                parts.append(nulls_order)
            items.append(' '.join(parts))
        return 'ORDER BY ' + ', '.join(items)

    def _write_data_type(self, data_type):
        """Write a literal value."""
        if data_type.val is None:
            return 'NULL'
        if not data_type.returns_char:
            return SqlWriter._write_data_type(self, data_type)
        # Literals sometimes produce an error 'could not determine polymorphic type
        # because input has type "unknown"', adding an "|| ''" avoids the problem.
        return "'%s' || ''" % data_type.val

    def _write_concat(self, func):
        """Write the arguments joined with the || operator."""
        # PostgreSQL CONCAT() doesn't behave like Impala CONCAT(). PostgreSQL || does.
        written_items = ['({written_item})'.format(written_item=self._write(item))
                         for item in func.args]
        return '({concat_list})'.format(concat_list=' || '.join(written_items))
| |
| |
class MySQLSqlWriter(SqlWriter):
    """SqlWriter for the MySQL dialect.

    WITH clauses are never written; references to their entries are replaced with
    inline views instead.
    """

    DIALECT = 'MYSQL'

    def write_query(self, query, pretty=False):
        """Return SQL as a string for the given query.

        If "pretty" is True, the SQL is passed through make_pretty_sql().
        """
        # Go through the _write() dispatcher so the top level query and any nested
        # query are both rendered by _write_query() below.
        sql = self._write(query)
        if pretty:
            sql = self.make_pretty_sql(sql)
        return sql

    def _write_query(self, query):
        """Return the SQL for a Query with its WITH clause entries inlined.

        This overrides the dispatch destination used by _write() for Query objects,
        so queries nested in inline views or subqueries get the same treatment as
        the top level query.
        """
        # MySQL doesn't support WITH clauses so they need to be converted into
        # inline views. We are going to cheat by making use of the fact that the
        # query generator creates with clause entries with unique aliases even
        # considering nested queries.
        clauses = (
            query.select_clause,
            query.from_clause,
            query.where_clause,
            query.group_by_clause,
            query.having_clause,
            query.union_clause,
            query.order_by_clause,
            query.limit_clause)
        sql = '\n'.join(self._write(clause) for clause in clauses if clause)
        if query.with_clause:
            # Just replace the named references with inline views. Go in reverse
            # order because entries at the bottom of the WITH clause definition may
            # reference entries above.
            # NOTE(review): this is a plain text replacement, so it relies on the
            # identifiers not occurring anywhere else in the SQL.
            for view in reversed(query.with_clause.with_clause_inline_views):
                replacement_sql = '(' + self._write(view.query) + ')'
                sql = sql.replace(view.identifier, replacement_sql)
        return sql

    def _write_data_type_metaclass(self, data_type_class):
        """Write a data type class such as Int or Boolean."""
        if issubclass(data_type_class, Int):
            return 'INTEGER'
        if issubclass(data_type_class, Float):
            return 'DECIMAL(65, 15)'
        if issubclass(data_type_class, VarChar):
            return 'CHAR'
        if hasattr(data_type_class, 'MYSQL'):
            return data_type_class.MYSQL[0]
        return data_type_class.__name__.upper()

    def _write_data_type(self, data_type):
        """Write a literal value."""
        if data_type.val is None:
            # Checked first so that a NULL timestamp or boolean is written as NULL
            # rather than CAST('None' AS DATETIME) or (1 = 0).
            return 'NULL'
        if data_type.returns_timestamp:
            return "CAST('{0}' AS DATETIME)".format(data_type.val)
        if data_type.returns_boolean:
            # MySQL will error if a data_type "FALSE" is used as a GROUP BY field
            return '(0 = 0)' if data_type.val else '(1 = 0)'
        return SqlWriter._write_data_type(self, data_type)

    def _write_order_by_clause(self, order_by_clause):
        """Write ORDER BY with an ISNULL() sort key before each expression."""
        # NOTE(review): the NULLS FIRST/LAST text from get_nulls_order() is appended
        # as in the other writers; confirm MySQL accepts it before using a nulls
        # order other than 'DEFAULT' with this writer.
        items = []
        for expr, order in order_by_clause.exprs_to_order:
            parts = ['ISNULL({0}), {0}'.format(self._write(expr))]
            if order:
                parts.append(order)
            nulls_order = self.get_nulls_order(order)
            if nulls_order is not None:
                parts.append(nulls_order)
            items.append(' '.join(parts))
        return 'ORDER BY ' + ', '.join(items)