blob: 2a119131e02c6e902dabd85b75ffd19a509d9942 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from copy import deepcopy
from logging import getLogger
from tests.comparison.common import (
CollectionColumn,
Column,
StructColumn,
Table)
from tests.comparison.db_types import BigInt, Boolean
from tests.comparison.funcs import Equals, And
from tests.comparison.query import (
FromClause,
InlineView,
JoinClause,
Query,
SelectClause,
WhereClause)
LOG = getLogger(__name__)
class QueryFlattener(object):
'''Converts a query that contains references to nested types, to an equivalent query for
for a flattened dataset. This class depends on the dataset flattener implementation.
'''
def __init__(self):
self.clear_state()
def clear_state(self):
self.tmp_alias = 0
# Elements such as join clauses or columns that are not present in the original query.
self.for_flattening = set()
def contains_jump(self, table_expr):
'''If some ancestor CollectionColumn or Table has an alias, but there is a closer
ancestor CollectionColumn without an alias, then this function will return True.
For example, suppose we have customer t1 and t1.orders.lineitems t2. If we call
this function with t2 CollectionColumn as parameter, it will return True.
'''
result = False
while True:
if table_expr.owner.alias:
return result
if isinstance(table_expr.owner, Table):
# We reached the root and did not encounter an ancestor with an alias
return False
if isinstance(table_expr.owner, CollectionColumn):
# We encountered a CollectionColumn with no alias
result = True
table_expr = table_expr.owner
def get_first_aliased_ancestor(self, table_expr):
'''Finds the first ancestor that is not a struct. It is returned if it has an alias,
otherwise, None is returned.
'''
while True:
if table_expr.owner.alias:
return table_expr.owner
elif isinstance(table_expr.owner, StructColumn):
table_expr = table_expr.owner
else:
return None
def flatten_join_clause(self, join_clause, query):
if join_clause.is_lateral_join:
if isinstance(join_clause.table_expr, CollectionColumn):
# All laterally joined Collecitons are converted to an inline view.
join_clause.table_expr = self.convert_correlated_collection_to_inline_view(
join_clause.table_expr)
self.flatten(join_clause.table_expr.query, inner=True)
join_clause.boolean_expr = join_clause.boolean_expr or Boolean(True)
elif join_clause not in self.for_flattening:
if isinstance(join_clause.table_expr, CollectionColumn) and \
self.contains_jump(join_clause.table_expr):
join_clause.table_expr = self.convert_correlated_collection_to_inline_view(
join_clause.table_expr)
if isinstance(join_clause.table_expr, CollectionColumn):
aliased_ancestor = self.get_first_aliased_ancestor(join_clause.table_expr)
if aliased_ancestor:
predicate = self.create_join_predicate(aliased_ancestor, join_clause.table_expr)
if query.where_clause:
query.where_clause.boolean_expr = And.create_from_args(
predicate, query.where_clause.boolean_expr)
else:
query.where_clause = WhereClause(predicate)
elif isinstance(join_clause.table_expr, InlineView):
self.flatten(join_clause.table_expr.query, inner=True)
def flatten_from_clause(self, from_clause, query):
if isinstance(from_clause.table_expr, CollectionColumn) and \
self.contains_jump(from_clause.table_expr):
from_clause.table_expr = self.convert_correlated_collection_to_inline_view(
from_clause.table_expr)
if isinstance(from_clause.table_expr, CollectionColumn):
aliased_ancestor = self.get_first_aliased_ancestor(from_clause.table_expr)
if aliased_ancestor:
predicate = self.create_join_predicate(aliased_ancestor, from_clause.table_expr)
if query.where_clause:
query.where_clause.boolean_expr = And.create_from_args(
predicate, query.where_clause.boolean_expr)
else:
query.where_clause = WhereClause(predicate)
elif isinstance(from_clause.table_expr, InlineView):
self.flatten(from_clause.table_expr.query, inner=True)
for join_clause in from_clause.join_clauses:
self.flatten_join_clause(join_clause, query)
def flatten(self, query, inner=False):
'''This function is idempotent.'''
if not inner:
self.clear_state()
if not getattr(query, 'flattened', False):
self.flatten_from_clause(query.from_clause, query)
for nested_query in query.nested_queries:
self.flatten(nested_query, inner=True)
query.flattened = True
def convert_correlated_collection_to_inline_view(self, original_collection):
'''Converts the given collection to an inline view. The collection must be correlated,
ie. it must have an aliased ancestor. This function is able to handle cases where
there are unaliased collections between the given collection and the aliased
ancestor.
For example,
SELECT ...
FROM customer t1 INNER JOIN t1.orders.lineitem
should be converted to:
SELECT ...
FROM customer t1 INNER JOIN LATERAL (
SELECT
t2.*
FROM
customer.orders tmp_alias_1
INNER JOIN customer.orders.lineitems t2 ON ([tmp_alias_1 is parent of t2])
WHERE
[t1 is parent of tmp_alias_1]) t2 ON True
This function does not add a where clause to the inline view in order to connect
the first table expression in the from clause with it's parent. This is done in
flatten_from_clause.
'''
def replace_col(container, new_col):
for i, col in enumerate(container._cols):
if col.name == new_col.name:
container._cols[i] = new_col
new_col.owner = container
return
def create_inner_join(parent, child):
predicate = self.create_join_predicate(parent, child)
join_clause = JoinClause('INNER', child)
join_clause.boolean_expr = predicate
self.for_flattening.add(join_clause)
return join_clause
query = Query()
query.select_clause = SelectClause(None)
query.select_clause.star_prefix = original_collection.alias
# Create a list containing original_collection along with all of it's unaliased
# ancestors
cur = original_collection
all_from_elements = []
while True:
all_from_elements.append(cur)
if cur.owner.alias:
break
else:
cur = cur.owner
all_from_elements.reverse()
num_collections = sum(1 for e in all_from_elements if isinstance(e, CollectionColumn))
if num_collections == 1:
query.from_clause = FromClause(original_collection)
elif num_collections > 1:
# Add multiple elements to the from clause which are joined together.
for i in range(len(all_from_elements)):
if isinstance(all_from_elements[i], CollectionColumn):
first_collection = deepcopy(all_from_elements[i])
first_collection.alias = self.get_tmp_alias()
all_from_elements = all_from_elements[i + 1:]
break
query.from_clause = FromClause(first_collection)
prev = first_collection
cur = first_collection
for table_expr in all_from_elements[:-1]:
cur = cur.get_col_by_name(table_expr.name)
if isinstance(cur, CollectionColumn):
join_clause = create_inner_join(prev, cur)
join_clause.table_expr.alias = self.get_tmp_alias()
query.from_clause.join_clauses.append(join_clause)
prev = cur
# The last original_collection to be the original one.
replace_col(cur, original_collection)
join_clause = create_inner_join(prev, original_collection)
query.from_clause.join_clauses.append(join_clause)
else:
# num_collections < 1
assert False
inline_view = InlineView(query)
inline_view.alias = original_collection.alias
return inline_view
def create_join_predicate(self, parent_table, child_table):
for col in parent_table.cols:
if col.name == 'id':
parent_id_col = col
break
else:
parent_id_col = Column(parent_table, 'id', BigInt)
parent_id_col.for_flattening = True
parent_table.add_col(parent_id_col)
child_col_name = self.flat_collection_name(parent_table) + '_id'
child_col = Column(None, child_col_name, BigInt)
child_table.add_col(child_col)
return Equals.create_from_args(parent_id_col, child_col)
def get_tmp_alias(self):
self.tmp_alias += 1
return 'tmp_alias_' + str(self.tmp_alias)
@classmethod
def flat_column_name(cls, col):
if isinstance(col, StructColumn):
name = 'item' if col.name == 'value' else col.name
if isinstance(col.owner, StructColumn):
return cls.flat_column_name(col.owner) + '_' + col.name
return name
elif isinstance(col, Column):
name = col.name
if col.name == 'item':
name = 'value'
elif col.name == 'pos':
name = 'idx'
if isinstance(col.owner, StructColumn) \
and not getattr(col, 'for_flattening', False):
# This is a struct field. To get the name in the flattened table, concatenate the
# name of the struct with the field struct field name.
return cls.flat_column_name(col.owner) + '_' + name
return name
@classmethod
def flat_collection_name(cls, entity):
'''Figures out the flat collection name for some descendent CollectionColumn.
For example, we have <Table (table1): StructColumn(structcol1): ArrayCol(arrcol)>,
and we want to know the name of the flattened arrcol table. This method should
return "table1_arrcol". Notice StructColumn name is not included in the table name.
This implementation depends on the implementation of the dataset flattener.
'''
if isinstance(entity, StructColumn):
return cls.flat_collection_name(entity.owner)
if isinstance(entity, CollectionColumn):
# For example, customer.orders array column is converted to customer_orders table.
name = '_values' if entity.name in ('item', 'value') else entity.name
return cls.flat_collection_name(entity.owner) + '_' + name
if isinstance(entity, Table):
return entity.name