blob: 4c4d33ecefc1202a6dd2968f3af02cf243317234 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
**/
#include "query_optimizer/rules/ExtractCommonSubexpression.hpp"
#include <cstddef>
#include <memory>
#include <unordered_set>
#include <vector>
#include "query_optimizer/OptimizerContext.hpp"
#include "query_optimizer/expressions/AggregateFunction.hpp"
#include "query_optimizer/expressions/Alias.hpp"
#include "query_optimizer/expressions/CommonSubexpression.hpp"
#include "query_optimizer/expressions/ExpressionType.hpp"
#include "query_optimizer/expressions/NamedExpression.hpp"
#include "query_optimizer/expressions/PatternMatcher.hpp"
#include "query_optimizer/expressions/Scalar.hpp"
#include "query_optimizer/physical/Aggregate.hpp"
#include "query_optimizer/physical/HashJoin.hpp"
#include "query_optimizer/physical/NestedLoopsJoin.hpp"
#include "query_optimizer/physical/Physical.hpp"
#include "query_optimizer/physical/PhysicalType.hpp"
#include "query_optimizer/physical/Selection.hpp"
#include "utility/HashError.hpp"
#include "glog/logging.h"
namespace quickstep {
namespace optimizer {
namespace E = ::quickstep::optimizer::expressions;
namespace P = ::quickstep::optimizer::physical;
/**
 * @brief Constructs the rule.
 *
 * Records the set of "homogeneous" expression types. visitAndCount() only
 * descends into the children of expressions of these types. visitAndTransform()
 * transforms the children of these types as part of the enclosing pass. It
 * transforms the children of every other type as isolated sub-problems.
 *
 * @param optimizer_context The optimizer context; used to obtain fresh
 *        expression ids (nextExprId()) for generated CommonSubexpression nodes.
 **/
ExtractCommonSubexpression::ExtractCommonSubexpression(
    OptimizerContext *optimizer_context)
    : optimizer_context_(optimizer_context) {
  for (const E::ExpressionType homogeneous_type :
           {E::ExpressionType::kAlias,
            E::ExpressionType::kAttributeReference,
            E::ExpressionType::kBinaryExpression,
            E::ExpressionType::kCast,
            E::ExpressionType::kCommonSubexpression,
            E::ExpressionType::kScalarLiteral,
            E::ExpressionType::kUnaryExpression}) {
    homogeneous_expression_types_.emplace(homogeneous_type);
  }
}
/**
 * @brief Labels common subexpressions within the expressions of one physical
 *        node.
 *
 * Handled node types:
 *  - Aggregate: the grouping expressions and the argument expressions of all
 *    aggregate functions are transformed together in a single pass, so they
 *    may share common subexpressions.
 *  - Selection, HashJoin, NestedLoopsJoin: the project expressions are
 *    transformed.
 * Every other node type is returned unchanged.
 *
 * @param input The physical node to rewrite.
 * @return A newly created node carrying the transformed expressions, or
 *         \p input itself when the transformation changed nothing.
 **/
P::PhysicalPtr ExtractCommonSubexpression::applyToNode(
    const P::PhysicalPtr &input) {
  switch (input->getPhysicalType()) {
    case P::PhysicalType::kAggregate: {
      const P::AggregatePtr aggregate =
          std::static_pointer_cast<const P::Aggregate>(input);

      // Gather grouping expressions first, then every aggregate function's
      // argument expressions (in order). The reconstruction below relies on
      // exactly this layout.
      std::vector<E::ExpressionPtr> expressions;
      for (const auto &expr : aggregate->grouping_expressions()) {
        expressions.emplace_back(expr);
      }
      for (const auto &expr : aggregate->aggregate_expressions()) {
        // Held by value: static_pointer_cast returns a temporary.
        const E::AggregateFunctionPtr func =
            std::static_pointer_cast<const E::AggregateFunction>(expr->expression());
        for (const auto &arg : func->getArguments()) {
          expressions.emplace_back(arg);
        }
      }

      // Transform the expressions so that common subexpressions are labelled.
      const std::vector<E::ExpressionPtr> new_expressions =
          transformExpressions(expressions);
      DCHECK_EQ(expressions.size(), new_expressions.size());

      if (new_expressions != expressions) {
        const std::size_t num_grouping_expressions =
            aggregate->grouping_expressions().size();

        // Reconstruct grouping expressions (the leading entries).
        std::vector<E::NamedExpressionPtr> new_grouping_expressions;
        new_grouping_expressions.reserve(num_grouping_expressions);
        for (std::size_t i = 0; i < num_grouping_expressions; ++i) {
          DCHECK(E::SomeNamedExpression::Matches(new_expressions[i]));
          new_grouping_expressions.emplace_back(
              std::static_pointer_cast<const E::NamedExpression>(new_expressions[i]));
        }

        // Reconstruct aggregate expressions. Only aggregate functions whose
        // arguments actually changed are rebuilt; the rest are reused as-is.
        std::vector<E::AliasPtr> new_aggregate_expressions;
        new_aggregate_expressions.reserve(aggregate->aggregate_expressions().size());
        auto it = new_expressions.begin() + num_grouping_expressions;
        for (const auto &expr : aggregate->aggregate_expressions()) {
          const E::AggregateFunctionPtr func =
              std::static_pointer_cast<const E::AggregateFunction>(
                  expr->expression());
          const std::size_t num_arguments = func->getArguments().size();

          std::vector<E::ScalarPtr> new_arguments;
          new_arguments.reserve(num_arguments);
          for (std::size_t i = 0; i < num_arguments; ++i, ++it) {
            DCHECK(E::SomeScalar::Matches(*it));
            new_arguments.emplace_back(std::static_pointer_cast<const E::Scalar>(*it));
          }

          if (new_arguments == func->getArguments()) {
            new_aggregate_expressions.emplace_back(expr);
          } else {
            const E::AggregateFunctionPtr new_func =
                E::AggregateFunction::Create(func->getAggregate(),
                                             new_arguments,
                                             func->is_vector_aggregate(),
                                             func->is_distinct());
            // Keep the original alias id and names so downstream references
            // to this aggregate remain valid.
            new_aggregate_expressions.emplace_back(
                E::Alias::Create(expr->id(),
                                 new_func,
                                 expr->attribute_name(),
                                 expr->attribute_alias(),
                                 expr->relation_name()));
          }
        }
        // Every transformed expression must have been consumed.
        DCHECK(it == new_expressions.end());

        return P::Aggregate::Create(aggregate->input(),
                                    new_grouping_expressions,
                                    new_aggregate_expressions,
                                    aggregate->filter_predicate());
      }
      break;
    }
    case P::PhysicalType::kSelection: {
      const P::SelectionPtr selection =
          std::static_pointer_cast<const P::Selection>(input);

      // Transform Selection's project expressions.
      const std::vector<E::NamedExpressionPtr> new_expressions =
          DownCast<E::NamedExpression>(
              transformExpressions(UpCast(selection->project_expressions())));

      if (new_expressions != selection->project_expressions()) {
        // The rewrite keeps the same input, filter and output attributes, so
        // it must keep this Selection's own output partition scheme rather
        // than inheriting the (possibly different) one of its input.
        return P::Selection::Create(selection->input(),
                                    new_expressions,
                                    selection->filter_predicate(),
                                    selection->cloneOutputPartitionSchemeHeader());
      }
      break;
    }
    case P::PhysicalType::kHashJoin: {
      const P::HashJoinPtr hash_join =
          std::static_pointer_cast<const P::HashJoin>(input);

      // Transform HashJoin's project expressions.
      const std::vector<E::NamedExpressionPtr> new_expressions =
          DownCast<E::NamedExpression>(
              transformExpressions(UpCast(hash_join->project_expressions())));

      if (new_expressions != hash_join->project_expressions()) {
        return P::HashJoin::Create(hash_join->left(),
                                   hash_join->right(),
                                   hash_join->left_join_attributes(),
                                   hash_join->right_join_attributes(),
                                   hash_join->residual_predicate(),
                                   new_expressions,
                                   hash_join->join_type());
      }
      break;
    }
    case P::PhysicalType::kNestedLoopsJoin: {
      const P::NestedLoopsJoinPtr nested_loops_join =
          std::static_pointer_cast<const P::NestedLoopsJoin>(input);

      // Transform NestedLoopsJoin's project expressions.
      const std::vector<E::NamedExpressionPtr> new_expressions =
          DownCast<E::NamedExpression>(
              transformExpressions(UpCast(nested_loops_join->project_expressions())));

      if (new_expressions != nested_loops_join->project_expressions()) {
        return P::NestedLoopsJoin::Create(nested_loops_join->left(),
                                          nested_loops_join->right(),
                                          nested_loops_join->join_predicate(),
                                          new_expressions);
      }
      break;
    }
    default:
      break;
  }
  return input;
}
/**
 * @brief Transforms a group of expressions so that subexpressions occurring
 *        more than once across the group are labelled with shared
 *        CommonSubexpression nodes.
 *
 * @param expressions The expressions to transform together.
 * @return The transformed expressions, in the same order as \p expressions.
 **/
std::vector<E::ExpressionPtr> ExtractCommonSubexpression::transformExpressions(
    const std::vector<E::ExpressionPtr> &expressions) {
  // Step 1. Count the occurrences of every hashable subexpression across the
  // whole group.
  ScalarCounter occurrence_counts;
  ScalarHashable hashable_scalars;
  for (const E::ExpressionPtr &expression : expressions) {
    visitAndCount(expression, &occurrence_counts, &hashable_scalars);
  }

  // Step 2. Substitute common subexpressions.
  //
  // Any subexpression counted more than once is a common subexpression, but
  // not all of them deserve their own CommonSubexpression node. For example,
  // in
  // --
  // SELECT (x+1)*(x+2), (x+1)*(x+2)*3 FROM s;
  // --
  // only the single *dominant* node (x+1)*(x+2) is needed; separate nodes for
  // its children (x+1) and (x+2) would be redundant.
  //
  // To capture this, a subtree S is said to *dominate* a descendant subtree
  // (or leaf node) T if and only if counter[S] >= counter[T]. A
  // CommonSubexpression node is created for every subexpression that no
  // ancestor dominates.
  ScalarMap substitutions;
  std::vector<E::ExpressionPtr> transformed;
  transformed.reserve(expressions.size());
  for (const E::ExpressionPtr &expression : expressions) {
    transformed.emplace_back(visitAndTransform(expression,
                                               1,
                                               occurrence_counts,
                                               hashable_scalars,
                                               &substitutions));
  }
  return transformed;
}
/**
 * @brief Transforms a single expression in isolation, i.e. as a group of one
 *        passed to transformExpressions().
 *
 * @param expression The expression to transform.
 * @return The transformed expression.
 **/
E::ExpressionPtr ExtractCommonSubexpression::transformExpression(
    const E::ExpressionPtr &expression) {
  const std::vector<E::ExpressionPtr> singleton_group = {expression};
  const std::vector<E::ExpressionPtr> transformed_group =
      transformExpressions(singleton_group);
  return transformed_group.front();
}
/**
 * @brief Recursively counts occurrences of hashable scalar subexpressions.
 *
 * Children are visited only when \p expression has a homogeneous type. An
 * expression is counted (and recorded in \p hashable) only when it is a Scalar,
 * all visited children were hashable, and its hash() computation does not
 * throw HashNotSupported.
 *
 * @param expression The expression to visit.
 * @param counter Occurrence counts, incremented for each counted expression.
 * @param hashable Receives every expression that was successfully counted.
 * @return True if \p expression was counted, false otherwise.
 **/
bool ExtractCommonSubexpression::visitAndCount(
    const E::ExpressionPtr &expression,
    ScalarCounter *counter,
    ScalarHashable *hashable) const {
  // Visit every child (deliberately no short-circuit, so that all children
  // get counted) and remember whether any of them was unhashable. If one was,
  // hashing this expression would fail anyway, so we skip that work.
  bool every_child_hashable = true;
  const bool is_homogeneous =
      homogeneous_expression_types_.find(expression->getExpressionType())
          != homogeneous_expression_types_.end();
  if (is_homogeneous) {
    for (const auto &child : expression->children()) {
      if (!visitAndCount(child, counter, hashable)) {
        every_child_hashable = false;
      }
    }
  }

  if (!every_child_hashable) {
    return false;
  }

  E::ScalarPtr scalar;
  if (!E::SomeScalar::MatchesWithConditionalCast(expression, &scalar)) {
    return false;
  }

  // Not every scalar supports hash(); unsupported ones throw
  // HashNotSupported, in which case this expression (and therefore all of its
  // ancestors) is simply left out.
  try {
    ++(*counter)[scalar];
  } catch (const HashNotSupported &) {
    return false;
  }
  hashable->emplace(scalar);
  return true;
}
/**
 * @brief Recursively rewrites \p expression, wrapping dominant common
 *        subexpressions in CommonSubexpression nodes.
 *
 * @param expression The expression to rewrite.
 * @param max_reference_count The occurrence count of the closest hashable
 *        ancestor (1 at the top level). A hashable subexpression is wrapped
 *        only if its own count exceeds this value, i.e. it is not dominated
 *        by an ancestor.
 * @param counter Occurrence counts produced by visitAndCount().
 * @param hashable The set of expressions visitAndCount() managed to hash.
 * @param substitution_map Maps already-wrapped subexpressions to their
 *        CommonSubexpression node, so later occurrences reuse the same node.
 * @return The rewritten expression (possibly \p expression itself).
 **/
E::ExpressionPtr ExtractCommonSubexpression::visitAndTransform(
    const E::ExpressionPtr &expression,
    const std::size_t max_reference_count,
    const ScalarCounter &counter,
    const ScalarHashable &hashable,
    ScalarMap *substitution_map) {
  // TODO(jianqiao): Currently it is hardly beneficial to make AttributeReference
  // a common subexpression due to the inefficiency of ScalarAttribute's
  // size-not-known-at-compile-time std::memcpy() calls, compared to copy-elision
  // at selection level. Even in the case of compressed column store, it requires
  // an attribute to occur at least 4 times for the common subexpression version
  // to outperform the direct decoding version. We may look into ScalarAttribute
  // and modify the heuristic here later.
  if (expression->getExpressionType() == E::ExpressionType::kScalarLiteral ||
      expression->getExpressionType() == E::ExpressionType::kAttributeReference) {
    return expression;
  }

  E::ScalarPtr scalar;
  const bool is_hashable =
      E::SomeScalar::MatchesWithConditionalCast(expression, &scalar)
      && hashable.find(scalar) != hashable.end();

  std::size_t new_max_reference_count;
  if (is_hashable) {
    // CommonSubexpression node already generated somewhere. Just refer to that
    // and return.
    const auto substitution_map_it = substitution_map->find(scalar);
    if (substitution_map_it != substitution_map->end()) {
      return substitution_map_it->second;
    }
    // Otherwise, update the dominance count.
    const auto counter_it = counter.find(scalar);
    DCHECK(counter_it != counter.end());
    DCHECK_LE(max_reference_count, counter_it->second);
    new_max_reference_count = counter_it->second;
  } else {
    new_max_reference_count = max_reference_count;
  }

  // Process children.
  std::vector<E::ExpressionPtr> new_children;
  const auto homogeneous_expression_types_it =
      homogeneous_expression_types_.find(expression->getExpressionType());
  if (homogeneous_expression_types_it == homogeneous_expression_types_.end()) {
    // If child subexpressions cannot be hoisted through the current expression,
    // treat child expressions as isolated sub-optimizations.
    for (const auto &child : expression->children()) {
      new_children.emplace_back(transformExpression(child));
    }
  } else {
    for (const auto &child : expression->children()) {
      new_children.emplace_back(
          visitAndTransform(child,
                            new_max_reference_count,
                            counter,
                            hashable,
                            substitution_map));
    }
  }

  // Keep the copy as a plain ExpressionPtr. Expressions reaching this point
  // include children of non-homogeneous expressions, passed via
  // transformExpression(). Those are only typed as ExpressionPtr and are not
  // guaranteed to be Scalars. Statically downcasting them to Scalar would be
  // undefined behavior.
  E::ExpressionPtr output;
  if (new_children == expression->children()) {
    output = expression;
  } else {
    output = expression->copyWithNewChildren(new_children);
  }

  // Wrap it with a new CommonSubexpression node if the current expression is a
  // dominant subexpression. Only hashable expressions reach this branch, and
  // those were matched as Scalars above.
  if (is_hashable && new_max_reference_count > max_reference_count) {
    DCHECK(E::SomeScalar::Matches(output));
    const E::CommonSubexpressionPtr common_subexpression =
        E::CommonSubexpression::Create(
            optimizer_context_->nextExprId(),
            std::static_pointer_cast<const E::Scalar>(output));
    substitution_map->emplace(scalar, common_subexpression);
    output = common_subexpression;
  }
  return output;
}
} // namespace optimizer
} // namespace quickstep