| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| from copy import deepcopy |
| from logging import getLogger |
| |
| from tests.comparison.common import ( |
| CollectionColumn, |
| Column, |
| StructColumn, |
| Table) |
| from tests.comparison.db_types import BigInt, Boolean |
| from tests.comparison.funcs import Equals, And |
| from tests.comparison.query import ( |
| FromClause, |
| InlineView, |
| JoinClause, |
| Query, |
| SelectClause, |
| WhereClause) |
| |
| LOG = getLogger(__name__) |
| |
class QueryFlattener(object):
    '''Converts a query that contains references to nested types to an equivalent query
    for a flattened dataset. This class depends on the dataset flattener implementation.
    '''

    def __init__(self):
        self.clear_state()

    def clear_state(self):
        '''Resets the temporary alias counter and the set of flattener-created elements.'''
        self.tmp_alias = 0
        # Elements such as join clauses or columns that are not present in the original
        # query.
        self.for_flattening = set()

    def contains_jump(self, table_expr):
        '''If some ancestor CollectionColumn or Table has an alias, but there is a closer
        ancestor CollectionColumn without an alias, then this function will return True.
        For example, suppose we have customer t1 and t1.orders.lineitems t2. If we call
        this function with t2 CollectionColumn as parameter, it will return True.

        Returns False if the root Table is reached without finding an aliased ancestor.
        '''
        result = False
        while True:
            if table_expr.owner.alias:
                # Found the aliased ancestor. There was a jump only if an unaliased
                # collection was passed on the way up.
                return result
            if isinstance(table_expr.owner, Table):
                # We reached the root and did not encounter an ancestor with an alias.
                return False
            if isinstance(table_expr.owner, CollectionColumn):
                # We encountered a CollectionColumn with no alias.
                result = True
            table_expr = table_expr.owner

    def get_first_aliased_ancestor(self, table_expr):
        '''Returns the nearest ancestor of 'table_expr' that has an alias, looking past
        unaliased StructColumns only. If an unaliased ancestor that is not a
        StructColumn is reached first, None is returned.
        '''
        while True:
            if table_expr.owner.alias:
                return table_expr.owner
            elif isinstance(table_expr.owner, StructColumn):
                table_expr = table_expr.owner
            else:
                return None

    def _flatten_table_expr(self, clause, query):
        '''Flattens the table expression of 'clause' in place. 'clause' is either the
        from clause of 'query' or one of its non-lateral join clauses.

        - A CollectionColumn for which contains_jump() is True is first replaced by an
          inline view (see convert_correlated_collection_to_inline_view).
        - A CollectionColumn with an aliased ancestor gets a parent/child predicate
          AND-ed into the where clause of 'query' (the where clause is created if the
          query does not have one).
        - An InlineView has its query flattened recursively.
        '''
        if isinstance(clause.table_expr, CollectionColumn) \
                and self.contains_jump(clause.table_expr):
            clause.table_expr = self.convert_correlated_collection_to_inline_view(
                clause.table_expr)

        # Re-read the table expression: it may just have been replaced by an inline view.
        table_expr = clause.table_expr
        if isinstance(table_expr, CollectionColumn):
            aliased_ancestor = self.get_first_aliased_ancestor(table_expr)
            if aliased_ancestor:
                predicate = self.create_join_predicate(aliased_ancestor, table_expr)
                if query.where_clause:
                    query.where_clause.boolean_expr = And.create_from_args(
                        predicate, query.where_clause.boolean_expr)
                else:
                    query.where_clause = WhereClause(predicate)
        elif isinstance(table_expr, InlineView):
            self.flatten(table_expr.query, inner=True)

    def flatten_join_clause(self, join_clause, query):
        '''Flattens a single join clause of 'query' in place.

        Lateral joins: a CollectionColumn table expression is converted to an inline
        view, the query of the table expression is flattened, and a missing join
        condition is replaced by Boolean(True).
        Other joins: flattened like the first element of the from clause, except that
        join clauses created by this flattener (tracked in self.for_flattening) are
        left alone.
        '''
        if join_clause.is_lateral_join:
            if isinstance(join_clause.table_expr, CollectionColumn):
                # All laterally joined collections are converted to an inline view.
                join_clause.table_expr = \
                    self.convert_correlated_collection_to_inline_view(
                        join_clause.table_expr)
            self.flatten(join_clause.table_expr.query, inner=True)
            join_clause.boolean_expr = join_clause.boolean_expr or Boolean(True)
        elif join_clause not in self.for_flattening:
            self._flatten_table_expr(join_clause, query)

    def flatten_from_clause(self, from_clause, query):
        '''Flattens the first table expression of 'from_clause' and then each of its
        join clauses, in order. 'query' is the query that owns 'from_clause'; its where
        clause receives any parent/child predicates that have to be added.
        '''
        self._flatten_table_expr(from_clause, query)
        for join_clause in from_clause.join_clauses:
            self.flatten_join_clause(join_clause, query)

    def flatten(self, query, inner=False):
        '''Flattens 'query' and its nested queries in place. Unless 'inner' is True, the
        flattener state is reset first. A query whose 'flattened' attribute is already
        truthy is skipped, which makes this function idempotent.
        '''
        if not inner:
            self.clear_state()
        if not getattr(query, 'flattened', False):
            self.flatten_from_clause(query.from_clause, query)
            for nested_query in query.nested_queries:
                self.flatten(nested_query, inner=True)
            query.flattened = True

    def convert_correlated_collection_to_inline_view(self, original_collection):
        '''Converts the given collection to an inline view. The collection must be
        correlated, i.e. it must have an aliased ancestor. This function is able to
        handle cases where there are unaliased collections between the given collection
        and the aliased ancestor.
        For example,

          SELECT ...
          FROM customer t1 INNER JOIN t1.orders.lineitems t2

        should be converted to:

          SELECT ...
          FROM customer t1 INNER JOIN LATERAL (
              SELECT
                t2.*
              FROM
                customer.orders tmp_alias_1
                INNER JOIN customer.orders.lineitems t2 ON ([tmp_alias_1 is parent of t2])
              WHERE
                [t1 is parent of tmp_alias_1]) t2 ON True

        This function does not add a where clause to the inline view in order to connect
        the first table expression in the from clause with its parent. This is done in
        flatten_from_clause.

        Returns the new InlineView, which carries the alias of 'original_collection'.
        Raises AssertionError if no CollectionColumn is found between the aliased
        ancestor and 'original_collection' (inclusive).
        '''

        def replace_col(container, new_col):
            # Swaps 'new_col' in for the column of 'container' that has the same name
            # and re-parents it. Does nothing if there is no column with that name.
            for i, col in enumerate(container._cols):
                if col.name == new_col.name:
                    container._cols[i] = new_col
                    new_col.owner = container
                    return

        def create_inner_join(parent, child):
            # Builds "INNER JOIN child ON <parent is parent of child>" and registers
            # the clause as flattener-created so flatten_join_clause() skips it.
            join_clause = JoinClause('INNER', child)
            join_clause.boolean_expr = self.create_join_predicate(parent, child)
            self.for_flattening.add(join_clause)
            return join_clause

        query = Query()
        query.select_clause = SelectClause(None)
        query.select_clause.star_prefix = original_collection.alias

        # Create a list containing original_collection along with all of its unaliased
        # ancestors, ordered from the outermost ancestor down to original_collection.
        cur = original_collection
        all_from_elements = []
        while True:
            all_from_elements.append(cur)
            if cur.owner.alias:
                break
            cur = cur.owner
        all_from_elements.reverse()

        num_collections = sum(
            1 for e in all_from_elements if isinstance(e, CollectionColumn))
        if num_collections < 1:
            # Raised explicitly rather than via 'assert' so that the check is not
            # stripped when Python runs with -O.
            raise AssertionError(
                'Expected at least one CollectionColumn among the unaliased ancestors')

        if num_collections == 1:
            query.from_clause = FromClause(original_collection)
        else:
            # Add multiple elements to the from clause which are joined together. The
            # outermost collection becomes the first element; it is a deep copy with a
            # temporary alias.
            first_idx = next(
                i for i, e in enumerate(all_from_elements)
                if isinstance(e, CollectionColumn))
            first_collection = deepcopy(all_from_elements[first_idx])
            first_collection.alias = self.get_tmp_alias()
            remaining = all_from_elements[first_idx + 1:]

            query.from_clause = FromClause(first_collection)

            # 'cur' walks down the copied hierarchy one element at a time, while
            # 'prev' trails behind as the most recent collection that was joined.
            prev = first_collection
            cur = first_collection
            for table_expr in remaining[:-1]:
                cur = cur.get_col_by_name(table_expr.name)
                if isinstance(cur, CollectionColumn):
                    join_clause = create_inner_join(prev, cur)
                    join_clause.table_expr.alias = self.get_tmp_alias()
                    query.from_clause.join_clauses.append(join_clause)
                    prev = cur

            # The last element of the chain has to be the original collection itself
            # rather than its copy, so swap it into the copied hierarchy before joining.
            replace_col(cur, original_collection)
            join_clause = create_inner_join(prev, original_collection)
            query.from_clause.join_clauses.append(join_clause)

        inline_view = InlineView(query)
        inline_view.alias = original_collection.alias

        return inline_view

    def create_join_predicate(self, parent_table, child_table):
        '''Returns the predicate that ties 'child_table' to 'parent_table' in the
        flattened dataset: <parent 'id' column> = <child '<flat parent name>_id' column>.

        The parent's column named 'id' is reused if there is one. Otherwise a BigInt
        'id' column marked with for_flattening = True is created and added to the
        parent. A new BigInt column named flat_collection_name(parent_table) + '_id' is
        always created and added to 'child_table'.
        '''
        for col in parent_table.cols:
            if col.name == 'id':
                parent_id_col = col
                break
        else:
            # The loop finished without finding an 'id' column: create one.
            parent_id_col = Column(parent_table, 'id', BigInt)
            parent_id_col.for_flattening = True
            parent_table.add_col(parent_id_col)

        child_col_name = self.flat_collection_name(parent_table) + '_id'
        # The column is created without an owner; add_col() presumably sets it.
        # NOTE(review): confirm against the add_col implementation.
        child_col = Column(None, child_col_name, BigInt)
        child_table.add_col(child_col)

        return Equals.create_from_args(parent_id_col, child_col)

    def get_tmp_alias(self):
        '''Returns the next unused temporary alias: 'tmp_alias_1', 'tmp_alias_2', ...'''
        self.tmp_alias += 1
        return 'tmp_alias_' + str(self.tmp_alias)

    @classmethod
    def flat_column_name(cls, col):
        '''Returns the name of 'col' in the flattened dataset.

        - StructColumn: for a struct nested in another struct, the flat name of the
          owner joined with the struct's own name by '_'. Otherwise the struct's own
          name, with 'value' renamed to 'item'.
        - Column: its own name, with 'item' renamed to 'value' and 'pos' renamed to
          'idx'. For a struct field that was not added by the flattener, the result is
          prefixed with the flat name of the owning struct and '_'.
        - Anything else: None.
        '''
        if isinstance(col, StructColumn):
            name = 'item' if col.name == 'value' else col.name
            if isinstance(col.owner, StructColumn):
                # NOTE(review): this uses col.name, not the possibly renamed 'name'
                # computed above. Confirm that this matches the dataset flattener.
                return cls.flat_column_name(col.owner) + '_' + col.name
            return name
        elif isinstance(col, Column):
            name = col.name
            if col.name == 'item':
                name = 'value'
            elif col.name == 'pos':
                name = 'idx'
            if isinstance(col.owner, StructColumn) \
                    and not getattr(col, 'for_flattening', False):
                # This is a struct field. To get the name in the flattened table,
                # concatenate the name of the struct with the struct field name.
                return cls.flat_column_name(col.owner) + '_' + name
            return name
        return None

    @classmethod
    def flat_collection_name(cls, entity):
        '''Figures out the flat collection name for some descendant CollectionColumn.
        For example, we have <Table (table1): StructColumn(structcol1): ArrayCol(arrcol)>,
        and we want to know the name of the flattened arrcol table. This method should
        return "table1_arrcol". Notice StructColumn name is not included in the table
        name. This implementation depends on the implementation of the dataset
        flattener.

        Returns None if 'entity' is not a StructColumn, CollectionColumn or Table.
        '''
        if isinstance(entity, StructColumn):
            # Structs do not contribute to the name; continue with the owner.
            return cls.flat_collection_name(entity.owner)
        if isinstance(entity, CollectionColumn):
            # For example, customer.orders array column is converted to customer_orders
            # table.
            name = '_values' if entity.name in ('item', 'value') else entity.name
            return cls.flat_collection_name(entity.owner) + '_' + name
        if isinstance(entity, Table):
            return entity.name
        return None