Added physical rule for partitioned aggregations.
diff --git a/query_optimizer/rules/CMakeLists.txt b/query_optimizer/rules/CMakeLists.txt
index 0cd7212..7abe0d1 100644
--- a/query_optimizer/rules/CMakeLists.txt
+++ b/query_optimizer/rules/CMakeLists.txt
@@ -163,13 +163,19 @@
target_link_libraries(quickstep_queryoptimizer_rules_Partition
glog
gtest
+ quickstep_expressions_aggregation_AggregateFunction
+ quickstep_expressions_aggregation_AggregateFunctionFactory
+ quickstep_expressions_aggregation_AggregationID
quickstep_queryoptimizer_OptimizerContext
quickstep_queryoptimizer_costmodel_StarSchemaSimpleCostModel
+ quickstep_queryoptimizer_expressions_AggregateFunction
quickstep_queryoptimizer_expressions_AttributeReference
+ quickstep_queryoptimizer_expressions_BinaryExpression
quickstep_queryoptimizer_expressions_ExprId
quickstep_queryoptimizer_expressions_ExpressionUtil
quickstep_queryoptimizer_expressions_NamedExpression
quickstep_queryoptimizer_expressions_PatternMatcher
+ quickstep_queryoptimizer_physical_Aggregate
quickstep_queryoptimizer_physical_HashJoin
quickstep_queryoptimizer_physical_PartitionSchemeHeader
quickstep_queryoptimizer_physical_PatternMatcher
@@ -179,6 +185,9 @@
quickstep_queryoptimizer_physical_TableReference
quickstep_queryoptimizer_physical_TopLevelPlan
quickstep_queryoptimizer_rules_BottomUpRule
+ quickstep_types_operations_binaryoperations_BinaryOperation
+ quickstep_types_operations_binaryoperations_BinaryOperationFactory
+ quickstep_types_operations_binaryoperations_BinaryOperationID
quickstep_utility_Cast
quickstep_utility_EqualsAnyConstant
quickstep_utility_Macros
diff --git a/query_optimizer/rules/Partition.cpp b/query_optimizer/rules/Partition.cpp
index 39546c6..41b2f1f 100644
--- a/query_optimizer/rules/Partition.cpp
+++ b/query_optimizer/rules/Partition.cpp
@@ -22,17 +22,24 @@
#include <cstddef>
#include <cstdint>
#include <memory>
+#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>
+#include "expressions/aggregation/AggregateFunction.hpp"
+#include "expressions/aggregation/AggregateFunctionFactory.hpp"
+#include "expressions/aggregation/AggregationID.hpp"
#include "query_optimizer/OptimizerContext.hpp"
#include "query_optimizer/cost_model/StarSchemaSimpleCostModel.hpp"
+#include "query_optimizer/expressions/AggregateFunction.hpp"
#include "query_optimizer/expressions/AttributeReference.hpp"
+#include "query_optimizer/expressions/BinaryExpression.hpp"
#include "query_optimizer/expressions/ExprId.hpp"
#include "query_optimizer/expressions/ExpressionUtil.hpp"
#include "query_optimizer/expressions/NamedExpression.hpp"
#include "query_optimizer/expressions/PatternMatcher.hpp"
+#include "query_optimizer/physical/Aggregate.hpp"
#include "query_optimizer/physical/HashJoin.hpp"
#include "query_optimizer/physical/PartitionSchemeHeader.hpp"
#include "query_optimizer/physical/PatternMatcher.hpp"
@@ -41,12 +48,16 @@
#include "query_optimizer/physical/Selection.hpp"
#include "query_optimizer/physical/TableReference.hpp"
#include "query_optimizer/physical/TopLevelPlan.hpp"
+#include "types/operations/binary_operations/BinaryOperation.hpp"
+#include "types/operations/binary_operations/BinaryOperationFactory.hpp"
+#include "types/operations/binary_operations/BinaryOperationID.hpp"
#include "utility/Cast.hpp"
#include "utility/EqualsAnyConstant.hpp"
#include "gflags/gflags.h"
#include "glog/logging.h"
+using std::get;
using std::make_unique;
using std::move;
using std::size_t;
@@ -84,6 +95,18 @@
P::PhysicalType::kUnionAll);
}
+P::PhysicalPtr Repartition(const P::PhysicalPtr &node, P::PartitionSchemeHeader *repartition_scheme_header) {
+ if (needsSelection(node->getPhysicalType())) {
+ // Add a Selection node.
+ return P::Selection::Create(node,
+ CastSharedPtrVector<E::NamedExpression>(node->getOutputAttributes()),
+ nullptr /* filter_predicate */, repartition_scheme_header);
+ } else {
+ // Overwrite the output partition scheme header of the node.
+ return node->copyWithNewOutputPartitionSchemeHeader(repartition_scheme_header);
+ }
+}
+
P::PhysicalPtr HashRepartition(const P::PhysicalPtr &node,
const vector<E::AttributeReferencePtr> &repartition_attributes,
const size_t num_repartitions) {
@@ -94,24 +117,229 @@
auto repartition_scheme_header = make_unique<P::PartitionSchemeHeader>(
P::PartitionSchemeHeader::PartitionType::kHash, num_repartitions, move(repartition_expr_ids));
- if (needsSelection(node->getPhysicalType())) {
- // Add a Selection node.
- return P::Selection::Create(node,
- CastSharedPtrVector<E::NamedExpression>(node->getOutputAttributes()),
- nullptr /* filter_predicate */, repartition_scheme_header.release());
- } else {
- // Overwrite the output partition scheme header of the node.
- return node->copyWithNewOutputPartitionSchemeHeader(repartition_scheme_header.release());
+ return Repartition(node, repartition_scheme_header.release());
+}
+
+E::AliasPtr GetReaggregateExpression(const E::AliasPtr &aggr_alias) {
+ E::ExpressionPtr aggr = aggr_alias->expression();
+
+ E::AggregateFunctionPtr aggr_fn;
+ CHECK(E::SomeAggregateFunction::MatchesWithConditionalCast(aggr, &aggr_fn))
+ << aggr->toString();
+
+ AggregationID reaggr_id;
+ switch (aggr_fn->getAggregate().getAggregationID()) {
+ case AggregationID::kCount:
+ case AggregationID::kSum: {
+ reaggr_id = AggregationID::kSum;
+ break;
+ }
+ case AggregationID::kMax: {
+ reaggr_id = AggregationID::kMax;
+ break;
+ }
+ case AggregationID::kMin: {
+ reaggr_id = AggregationID::kMin;
+ break;
+ }
+ default:
+ LOG(FATAL) << "Unsupported aggregation id for re-aggregate";
}
+
+ const E::AggregateFunctionPtr reaggr =
+ E::AggregateFunction::Create(AggregateFunctionFactory::Get(reaggr_id),
+ { E::ToRef(aggr_alias) },
+ aggr_fn->is_vector_aggregate(),
+ aggr_fn->is_distinct());
+
+ return E::Alias::Create(aggr_alias->id(),
+ reaggr,
+ aggr_alias->attribute_name(),
+ aggr_alias->attribute_alias(),
+ aggr_alias->relation_name());
}
} // namespace
P::PhysicalPtr Partition::applyToNode(const P::PhysicalPtr &node) {
- // Will be used for aggregations.
- (void) optimizer_context_;
-
switch (node->getPhysicalType()) {
+ case P::PhysicalType::kAggregate: {
+ const P::AggregatePtr aggregate = static_pointer_cast<const P::Aggregate>(node);
+
+ const P::PhysicalPtr input = aggregate->input();
+ const P::PartitionSchemeHeader *input_partition_scheme_header =
+ input->getOutputPartitionSchemeHeader();
+
+ if (!input_partition_scheme_header) {
+ break;
+ }
+
+ std::unique_ptr<P::PartitionSchemeHeader> output_partition_scheme_header;
+ const vector<E::NamedExpressionPtr> &grouping_expressions = aggregate->grouping_expressions();
+
+ unordered_set<E::ExprId> grouping_expr_ids;
+ for (const E::NamedExpressionPtr &grouping_expression : grouping_expressions) {
+ grouping_expr_ids.insert(grouping_expression->id());
+ }
+
+ if (!grouping_expressions.empty()) {
+ if (input_partition_scheme_header->reusablePartitionScheme(grouping_expr_ids)) {
+ // We do not need to reaggregate iff the list of partition attributes is
+ // a subset of the group by list.
+ P::PartitionSchemeHeader::PartitionExprIds output_partition_expr_ids;
+ for (const E::ExprId grouping_expr_id : grouping_expr_ids) {
+ output_partition_expr_ids.push_back({ grouping_expr_id });
+ }
+
+ output_partition_scheme_header = make_unique<P::PartitionSchemeHeader>(
+ P::PartitionSchemeHeader::PartitionType::kHash,
+ input_partition_scheme_header->num_partitions,
+ move(output_partition_expr_ids));
+
+ return aggregate->copyWithNewOutputPartitionSchemeHeader(output_partition_scheme_header.release());
+ }
+ }
+
+ const vector<E::AliasPtr> &aggregate_expressions = aggregate->aggregate_expressions();
+ const E::PredicatePtr &filter_predicate = aggregate->filter_predicate();
+
+ vector<E::AliasPtr> partial_aggregate_expressions;
+ vector<E::AttributeReferencePtr> non_recompute_aggregate_expressions;
+ // tuple<Avg, Sum, Count>.
+ vector<std::tuple<E::AliasPtr, E::AttributeReferencePtr, E::AttributeReferencePtr>> avg_recompute_expressions;
+ for (const E::AliasPtr &aggregate_expression : aggregate_expressions) {
+ E::AggregateFunctionPtr aggr_func;
+ CHECK(E::SomeAggregateFunction::MatchesWithConditionalCast(aggregate_expression->expression(), &aggr_func));
+
+ bool uses_partial_aggregate = false;
+ if (aggr_func->is_distinct()) {
+ const vector<E::AttributeReferencePtr> distinct_referenced_attributes =
+ aggr_func->getReferencedAttributes();
+ DCHECK_EQ(1u, distinct_referenced_attributes.size());
+
+ if (grouping_expr_ids.find(distinct_referenced_attributes.front()->id()) == grouping_expr_ids.end()) {
+ // Create a new aggregate whose input has no partitions.
+ return P::Aggregate::Create(Repartition(input, nullptr), grouping_expressions, aggregate_expressions,
+ filter_predicate);
+ }
+
+ uses_partial_aggregate = true;
+ } else if (aggr_func->getAggregate().getAggregationID() != AggregationID::kAvg) {
+ uses_partial_aggregate = true;
+ }
+
+ if (uses_partial_aggregate) {
+ partial_aggregate_expressions.push_back(aggregate_expression);
+ non_recompute_aggregate_expressions.push_back(E::ToRef(aggregate_expression));
+
+ continue;
+ }
+
+ DCHECK(aggr_func->getAggregate().getAggregationID() == AggregationID::kAvg);
+ const auto &arguments = aggr_func->getArguments();
+ DCHECK_EQ(1u, arguments.size());
+
+ // Sum
+ const E::AggregateFunctionPtr sum_expr =
+ E::AggregateFunction::Create(AggregateFunctionFactory::Get(AggregationID::kSum),
+ arguments,
+ aggr_func->is_vector_aggregate(),
+ aggr_func->is_distinct());
+ partial_aggregate_expressions.push_back(
+ E::Alias::Create(optimizer_context_->nextExprId(),
+ sum_expr,
+ aggregate_expression->attribute_name(),
+ aggregate_expression->attribute_alias(),
+ aggregate_expression->relation_name()));
+ const E::AttributeReferencePtr sum_attr = E::ToRef(partial_aggregate_expressions.back());
+
+ // Count
+ const E::AggregateFunctionPtr count_expr =
+ E::AggregateFunction::Create(AggregateFunctionFactory::Get(AggregationID::kCount),
+ arguments,
+ aggr_func->is_vector_aggregate(),
+ aggr_func->is_distinct());
+ partial_aggregate_expressions.push_back(
+ E::Alias::Create(optimizer_context_->nextExprId(),
+ count_expr,
+ aggregate_expression->attribute_name(),
+ aggregate_expression->attribute_alias(),
+ aggregate_expression->relation_name()));
+ avg_recompute_expressions.emplace_back(aggregate_expression, sum_attr,
+ E::ToRef(partial_aggregate_expressions.back()));
+ }
+
+ if (!grouping_expressions.empty()) {
+ P::PartitionSchemeHeader::PartitionExprIds output_partition_expr_ids;
+ for (const E::NamedExpressionPtr &grouping_expression : grouping_expressions) {
+ output_partition_expr_ids.push_back({ grouping_expression->id() });
+ }
+ output_partition_scheme_header = make_unique<P::PartitionSchemeHeader>(
+ P::PartitionSchemeHeader::PartitionType::kHash,
+ input_partition_scheme_header->num_partitions,
+ move(output_partition_expr_ids));
+ }
+ const P::PhysicalPtr partial_aggregate =
+ avg_recompute_expressions.empty()
+ ? aggregate->copyWithNewOutputPartitionSchemeHeader(output_partition_scheme_header.release())
+ : P::Aggregate::Create(input, grouping_expressions, partial_aggregate_expressions,
+ filter_predicate, output_partition_scheme_header.release());
+
+ vector<E::AliasPtr> reaggregate_expressions;
+ for (const auto &aggregate_expr : partial_aggregate_expressions) {
+ reaggregate_expressions.push_back(GetReaggregateExpression(aggregate_expr));
+ }
+
+ if (!grouping_expressions.empty()) {
+ P::PartitionSchemeHeader::PartitionExprIds output_partition_expr_ids;
+ for (const E::NamedExpressionPtr &grouping_expression : grouping_expressions) {
+ output_partition_expr_ids.push_back({ grouping_expression->id() });
+ }
+ output_partition_scheme_header = make_unique<P::PartitionSchemeHeader>(
+ P::PartitionSchemeHeader::PartitionType::kHash,
+ input_partition_scheme_header->num_partitions,
+ move(output_partition_expr_ids));
+ }
+ const P::AggregatePtr reaggregate =
+ P::Aggregate::Create(partial_aggregate, grouping_expressions, reaggregate_expressions,
+ nullptr /* filter_predicate */, output_partition_scheme_header.release());
+
+ if (avg_recompute_expressions.empty()) {
+ return reaggregate;
+ }
+
+ vector<E::NamedExpressionPtr> project_expressions;
+ for (const auto &grouping_expr : grouping_expressions) {
+ project_expressions.emplace_back(E::ToRef(grouping_expr));
+ }
+
+ for (const E::AttributeReferencePtr &non_recompute_aggregate_expression : non_recompute_aggregate_expressions) {
+ project_expressions.emplace_back(non_recompute_aggregate_expression);
+ }
+ for (const auto &avg_recompute_expression : avg_recompute_expressions) {
+ const auto &avg_expr = get<0>(avg_recompute_expression);
+ // Obtain AVG by evaluating SUM/COUNT in Selection.
+ const BinaryOperation ÷_op =
+ BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kDivide);
+ const E::BinaryExpressionPtr new_avg_expr =
+ E::BinaryExpression::Create(divide_op,
+ get<1>(avg_recompute_expression),
+ get<2>(avg_recompute_expression));
+ project_expressions.emplace_back(
+ E::Alias::Create(avg_expr->id(),
+ new_avg_expr,
+ avg_expr->attribute_name(),
+ avg_expr->attribute_alias(),
+ avg_expr->relation_name()));
+ }
+
+ if (!grouping_expressions.empty()) {
+ output_partition_scheme_header =
+ make_unique<P::PartitionSchemeHeader>(*reaggregate->getOutputPartitionSchemeHeader());
+ }
+ return P::Selection::Create(reaggregate, project_expressions, nullptr /* filter_predicate */,
+ output_partition_scheme_header.release());
+ }
case P::PhysicalType::kHashJoin: {
const P::HashJoinPtr hash_join = static_pointer_cast<const P::HashJoin>(node);
@@ -191,7 +419,11 @@
const P::PartitionSchemeHeader *input_partition_scheme_header =
selection->input()->getOutputPartitionSchemeHeader();
- if (input_partition_scheme_header && input_partition_scheme_header->isHashPartition()) {
+ if (!input_partition_scheme_header) {
+ break;
+ }
+
+ if (input_partition_scheme_header->isHashPartition()) {
unordered_set<E::ExprId> project_expr_ids;
for (const E::AttributeReferencePtr &project_expression : selection->getOutputAttributes()) {
project_expr_ids.insert(project_expression->id());
diff --git a/query_optimizer/tests/execution_generator/Partition.test b/query_optimizer/tests/execution_generator/Partition.test
index eb3ec98..850a981 100644
--- a/query_optimizer/tests/execution_generator/Partition.test
+++ b/query_optimizer/tests/execution_generator/Partition.test
@@ -127,3 +127,99 @@
| 22| 22 4.690416|
| 24| 24 4.898979|
+-----------+--------------------+
+==
+
+# Partitioned aggregation.
+SELECT COUNT(*)
+FROM dim_4_hash_partitions;
+--
++--------------------+
+|COUNT(*) |
++--------------------+
+| 22|
++--------------------+
+==
+
+# Partitioned aggregation where the partition attributes are the subset of the group-by keys.
+SELECT id, COUNT(*)
+FROM dim_4_hash_partitions
+WHERE id > 0
+GROUP BY id;
+--
++-----------+--------------------+
+|id |COUNT(*) |
++-----------+--------------------+
+| 4| 1|
+| 8| 1|
+| 12| 1|
+| 16| 1|
+| 24| 1|
+| 2| 1|
+| 6| 1|
+| 14| 1|
+| 18| 1|
+| 22| 1|
++-----------+--------------------+
+==
+
+SELECT char_col, COUNT(*)
+FROM dim_4_hash_partitions
+WHERE id < 0
+GROUP BY char_col;
+--
++--------------------+--------------------+
+|char_col |COUNT(*) |
++--------------------+--------------------+
+| -3 1.732051| 1|
+| -11 3.316625| 1|
+| -19 4.358899| 1|
+| -17 4.123106| 1|
+| -15 3.872983| 1|
+| -7 2.645751| 1|
+| -1 1.000000| 1|
+| -23 4.795832| 1|
+| -13 3.605551| 1|
+| -9 3.000000| 1|
+| -21 4.582576| 1|
+| -5 2.236068| 1|
++--------------------+--------------------+
+==
+
+SELECT fact.score, COUNT(*)
+FROM dim_4_hash_partitions JOIN fact ON dim_4_hash_partitions.id = fact.id
+GROUP BY fact.score;
+--
++------------------------+--------------------+
+|score |COUNT(*) |
++------------------------+--------------------+
+| 41.569219381653056| 1|
+| 76.367532368147124| 1|
+| 64| 1|
+| 52.38320341483518| 1|
+| 8| 1|
+| 2.8284271247461903| 1|
+| 14.696938456699067| 1|
+| 22.627416997969522| 1|
+| 117.57550765359254| 1|
+| 103.18914671611546| 1|
++------------------------+--------------------+
+==
+
+SELECT fact.id, AVG(fact.score)
+FROM dim_4_hash_partitions JOIN fact ON dim_4_hash_partitions.id = fact.id
+GROUP BY fact.id;
+--
++-----------+------------------------+
+|id |AVG(fact.score) |
++-----------+------------------------+
+| 4| 8|
+| 8| 22.627416997969522|
+| 12| 41.569219381653056|
+| 16| 64|
+| 24| 117.57550765359254|
+| 2| 2.8284271247461903|
+| 6| 14.696938456699067|
+| 14| 52.38320341483518|
+| 18| 76.367532368147124|
+| 22| 103.18914671611546|
++-----------+------------------------+