Added physical rule for partitioned aggregations.
diff --git a/query_optimizer/rules/CMakeLists.txt b/query_optimizer/rules/CMakeLists.txt
index 0cd7212..7abe0d1 100644
--- a/query_optimizer/rules/CMakeLists.txt
+++ b/query_optimizer/rules/CMakeLists.txt
@@ -163,13 +163,19 @@
 target_link_libraries(quickstep_queryoptimizer_rules_Partition
                       glog
                       gtest
+                      quickstep_expressions_aggregation_AggregateFunction
+                      quickstep_expressions_aggregation_AggregateFunctionFactory
+                      quickstep_expressions_aggregation_AggregationID
                       quickstep_queryoptimizer_OptimizerContext
                       quickstep_queryoptimizer_costmodel_StarSchemaSimpleCostModel
+                      quickstep_queryoptimizer_expressions_AggregateFunction
                       quickstep_queryoptimizer_expressions_AttributeReference
+                      quickstep_queryoptimizer_expressions_BinaryExpression
                       quickstep_queryoptimizer_expressions_ExprId
                       quickstep_queryoptimizer_expressions_ExpressionUtil
                       quickstep_queryoptimizer_expressions_NamedExpression
                       quickstep_queryoptimizer_expressions_PatternMatcher
+                      quickstep_queryoptimizer_physical_Aggregate
                       quickstep_queryoptimizer_physical_HashJoin
                       quickstep_queryoptimizer_physical_PartitionSchemeHeader
                       quickstep_queryoptimizer_physical_PatternMatcher
@@ -179,6 +185,9 @@
                       quickstep_queryoptimizer_physical_TableReference
                       quickstep_queryoptimizer_physical_TopLevelPlan
                       quickstep_queryoptimizer_rules_BottomUpRule
+                      quickstep_types_operations_binaryoperations_BinaryOperation
+                      quickstep_types_operations_binaryoperations_BinaryOperationFactory
+                      quickstep_types_operations_binaryoperations_BinaryOperationID
                       quickstep_utility_Cast
                       quickstep_utility_EqualsAnyConstant
                       quickstep_utility_Macros
diff --git a/query_optimizer/rules/Partition.cpp b/query_optimizer/rules/Partition.cpp
index 39546c6..41b2f1f 100644
--- a/query_optimizer/rules/Partition.cpp
+++ b/query_optimizer/rules/Partition.cpp
@@ -22,17 +22,24 @@
 #include <cstddef>
 #include <cstdint>
 #include <memory>
+#include <tuple>
 #include <unordered_set>
 #include <utility>
 #include <vector>
 
+#include "expressions/aggregation/AggregateFunction.hpp"
+#include "expressions/aggregation/AggregateFunctionFactory.hpp"
+#include "expressions/aggregation/AggregationID.hpp"
 #include "query_optimizer/OptimizerContext.hpp"
 #include "query_optimizer/cost_model/StarSchemaSimpleCostModel.hpp"
+#include "query_optimizer/expressions/AggregateFunction.hpp"
 #include "query_optimizer/expressions/AttributeReference.hpp"
+#include "query_optimizer/expressions/BinaryExpression.hpp"
 #include "query_optimizer/expressions/ExprId.hpp"
 #include "query_optimizer/expressions/ExpressionUtil.hpp"
 #include "query_optimizer/expressions/NamedExpression.hpp"
 #include "query_optimizer/expressions/PatternMatcher.hpp"
+#include "query_optimizer/physical/Aggregate.hpp"
 #include "query_optimizer/physical/HashJoin.hpp"
 #include "query_optimizer/physical/PartitionSchemeHeader.hpp"
 #include "query_optimizer/physical/PatternMatcher.hpp"
@@ -41,12 +48,16 @@
 #include "query_optimizer/physical/Selection.hpp"
 #include "query_optimizer/physical/TableReference.hpp"
 #include "query_optimizer/physical/TopLevelPlan.hpp"
+#include "types/operations/binary_operations/BinaryOperation.hpp"
+#include "types/operations/binary_operations/BinaryOperationFactory.hpp"
+#include "types/operations/binary_operations/BinaryOperationID.hpp"
 #include "utility/Cast.hpp"
 #include "utility/EqualsAnyConstant.hpp"
 
 #include "gflags/gflags.h"
 #include "glog/logging.h"
 
+using std::get;
 using std::make_unique;
 using std::move;
 using std::size_t;
@@ -84,6 +95,18 @@
                                        P::PhysicalType::kUnionAll);
 }
 
+P::PhysicalPtr Repartition(const P::PhysicalPtr &node, P::PartitionSchemeHeader *repartition_scheme_header) {
+  if (needsSelection(node->getPhysicalType())) {
+    // Add a Selection node.
+    return P::Selection::Create(node,
+                                CastSharedPtrVector<E::NamedExpression>(node->getOutputAttributes()),
+                                nullptr /* filter_predicate */, repartition_scheme_header);
+  } else {
+    // Overwrite the output partition scheme header of the node.
+    return node->copyWithNewOutputPartitionSchemeHeader(repartition_scheme_header);
+  }
+}
+
 P::PhysicalPtr HashRepartition(const P::PhysicalPtr &node,
                                const vector<E::AttributeReferencePtr> &repartition_attributes,
                                const size_t num_repartitions) {
@@ -94,24 +117,229 @@
   auto repartition_scheme_header = make_unique<P::PartitionSchemeHeader>(
       P::PartitionSchemeHeader::PartitionType::kHash, num_repartitions, move(repartition_expr_ids));
 
-  if (needsSelection(node->getPhysicalType())) {
-    // Add a Selection node.
-    return P::Selection::Create(node,
-                                CastSharedPtrVector<E::NamedExpression>(node->getOutputAttributes()),
-                                nullptr /* filter_predicate */, repartition_scheme_header.release());
-  } else {
-    // Overwrite the output partition scheme header of the node.
-    return node->copyWithNewOutputPartitionSchemeHeader(repartition_scheme_header.release());
+  return Repartition(node, repartition_scheme_header.release());
+}
+
+E::AliasPtr GetReaggregateExpression(const E::AliasPtr &aggr_alias) {
+  E::ExpressionPtr aggr = aggr_alias->expression();
+
+  E::AggregateFunctionPtr aggr_fn;
+  CHECK(E::SomeAggregateFunction::MatchesWithConditionalCast(aggr, &aggr_fn))
+      << aggr->toString();
+
+  AggregationID reaggr_id;
+  switch (aggr_fn->getAggregate().getAggregationID()) {
+    case AggregationID::kCount:
+    case AggregationID::kSum: {
+      reaggr_id = AggregationID::kSum;
+      break;
+    }
+    case AggregationID::kMax: {
+      reaggr_id = AggregationID::kMax;
+      break;
+    }
+    case AggregationID::kMin: {
+      reaggr_id = AggregationID::kMin;
+      break;
+    }
+    default:
+      LOG(FATAL) << "Unsupported aggregation id for re-aggregate";
   }
+
+  const E::AggregateFunctionPtr reaggr =
+      E::AggregateFunction::Create(AggregateFunctionFactory::Get(reaggr_id),
+                                   { E::ToRef(aggr_alias) },
+                                   aggr_fn->is_vector_aggregate(),
+                                   aggr_fn->is_distinct());
+
+  return E::Alias::Create(aggr_alias->id(),
+                          reaggr,
+                          aggr_alias->attribute_name(),
+                          aggr_alias->attribute_alias(),
+                          aggr_alias->relation_name());
 }
 
 }  // namespace
 
 P::PhysicalPtr Partition::applyToNode(const P::PhysicalPtr &node) {
-  // Will be used for aggregations.
-  (void) optimizer_context_;
-
   switch (node->getPhysicalType()) {
+    case P::PhysicalType::kAggregate: {
+      const P::AggregatePtr aggregate = static_pointer_cast<const P::Aggregate>(node);
+
+      const P::PhysicalPtr input = aggregate->input();
+      const P::PartitionSchemeHeader *input_partition_scheme_header =
+          input->getOutputPartitionSchemeHeader();
+
+      if (!input_partition_scheme_header) {
+        break;
+      }
+
+      std::unique_ptr<P::PartitionSchemeHeader> output_partition_scheme_header;
+      const vector<E::NamedExpressionPtr> &grouping_expressions = aggregate->grouping_expressions();
+
+      unordered_set<E::ExprId> grouping_expr_ids;
+      for (const E::NamedExpressionPtr &grouping_expression : grouping_expressions) {
+        grouping_expr_ids.insert(grouping_expression->id());
+      }
+
+      if (!grouping_expressions.empty()) {
+        if (input_partition_scheme_header->reusablePartitionScheme(grouping_expr_ids)) {
+          // We do not need to reaggregate iff the list of partition attributes is
+          // a subset of the group by list.
+          P::PartitionSchemeHeader::PartitionExprIds output_partition_expr_ids;
+          for (const E::ExprId grouping_expr_id : grouping_expr_ids) {
+            output_partition_expr_ids.push_back({ grouping_expr_id });
+          }
+
+          output_partition_scheme_header = make_unique<P::PartitionSchemeHeader>(
+              P::PartitionSchemeHeader::PartitionType::kHash,
+              input_partition_scheme_header->num_partitions,
+              move(output_partition_expr_ids));
+
+          return aggregate->copyWithNewOutputPartitionSchemeHeader(output_partition_scheme_header.release());
+        }
+      }
+
+      const vector<E::AliasPtr> &aggregate_expressions = aggregate->aggregate_expressions();
+      const E::PredicatePtr &filter_predicate = aggregate->filter_predicate();
+
+      vector<E::AliasPtr> partial_aggregate_expressions;
+      vector<E::AttributeReferencePtr> non_recompute_aggregate_expressions;
+      // tuple<Avg, Sum, Count>.
+      vector<std::tuple<E::AliasPtr, E::AttributeReferencePtr, E::AttributeReferencePtr>> avg_recompute_expressions;
+      for (const E::AliasPtr &aggregate_expression : aggregate_expressions) {
+        E::AggregateFunctionPtr aggr_func;
+        CHECK(E::SomeAggregateFunction::MatchesWithConditionalCast(aggregate_expression->expression(), &aggr_func));
+
+        bool uses_partial_aggregate = false;
+        if (aggr_func->is_distinct()) {
+          const vector<E::AttributeReferencePtr> distinct_referenced_attributes =
+              aggr_func->getReferencedAttributes();
+          DCHECK_EQ(1u, distinct_referenced_attributes.size());
+
+          if (grouping_expr_ids.find(distinct_referenced_attributes.front()->id()) == grouping_expr_ids.end()) {
+            // Create a new aggregate whose input has no partitions.
+            return P::Aggregate::Create(Repartition(input, nullptr), grouping_expressions, aggregate_expressions,
+                                        filter_predicate);
+          }
+
+          uses_partial_aggregate = true;
+        } else if (aggr_func->getAggregate().getAggregationID() != AggregationID::kAvg) {
+          uses_partial_aggregate = true;
+        }
+
+        if (uses_partial_aggregate) {
+          partial_aggregate_expressions.push_back(aggregate_expression);
+          non_recompute_aggregate_expressions.push_back(E::ToRef(aggregate_expression));
+
+          continue;
+        }
+
+        DCHECK(aggr_func->getAggregate().getAggregationID() == AggregationID::kAvg);
+        const auto &arguments = aggr_func->getArguments();
+        DCHECK_EQ(1u, arguments.size());
+
+        // Sum
+        const E::AggregateFunctionPtr sum_expr =
+            E::AggregateFunction::Create(AggregateFunctionFactory::Get(AggregationID::kSum),
+                                         arguments,
+                                         aggr_func->is_vector_aggregate(),
+                                         aggr_func->is_distinct());
+        partial_aggregate_expressions.push_back(
+            E::Alias::Create(optimizer_context_->nextExprId(),
+                             sum_expr,
+                             aggregate_expression->attribute_name(),
+                             aggregate_expression->attribute_alias(),
+                             aggregate_expression->relation_name()));
+        const E::AttributeReferencePtr sum_attr = E::ToRef(partial_aggregate_expressions.back());
+
+        // Count
+        const E::AggregateFunctionPtr count_expr =
+            E::AggregateFunction::Create(AggregateFunctionFactory::Get(AggregationID::kCount),
+                                         arguments,
+                                         aggr_func->is_vector_aggregate(),
+                                         aggr_func->is_distinct());
+        partial_aggregate_expressions.push_back(
+            E::Alias::Create(optimizer_context_->nextExprId(),
+                             count_expr,
+                             aggregate_expression->attribute_name(),
+                             aggregate_expression->attribute_alias(),
+                             aggregate_expression->relation_name()));
+        avg_recompute_expressions.emplace_back(aggregate_expression, sum_attr,
+                                               E::ToRef(partial_aggregate_expressions.back()));
+      }
+
+      if (!grouping_expressions.empty()) {
+        P::PartitionSchemeHeader::PartitionExprIds output_partition_expr_ids;
+        for (const E::NamedExpressionPtr &grouping_expression : grouping_expressions) {
+          output_partition_expr_ids.push_back({ grouping_expression->id() });
+        }
+        output_partition_scheme_header = make_unique<P::PartitionSchemeHeader>(
+            P::PartitionSchemeHeader::PartitionType::kHash,
+            input_partition_scheme_header->num_partitions,
+            move(output_partition_expr_ids));
+      }
+      const P::PhysicalPtr partial_aggregate =
+          avg_recompute_expressions.empty()
+              ? aggregate->copyWithNewOutputPartitionSchemeHeader(output_partition_scheme_header.release())
+              : P::Aggregate::Create(input, grouping_expressions, partial_aggregate_expressions,
+                                     filter_predicate, output_partition_scheme_header.release());
+
+      vector<E::AliasPtr> reaggregate_expressions;
+      for (const auto &aggregate_expr : partial_aggregate_expressions) {
+        reaggregate_expressions.push_back(GetReaggregateExpression(aggregate_expr));
+      }
+
+      if (!grouping_expressions.empty()) {
+        P::PartitionSchemeHeader::PartitionExprIds output_partition_expr_ids;
+        for (const E::NamedExpressionPtr &grouping_expression : grouping_expressions) {
+          output_partition_expr_ids.push_back({ grouping_expression->id() });
+        }
+        output_partition_scheme_header = make_unique<P::PartitionSchemeHeader>(
+            P::PartitionSchemeHeader::PartitionType::kHash,
+            input_partition_scheme_header->num_partitions,
+            move(output_partition_expr_ids));
+      }
+      const P::AggregatePtr reaggregate =
+          P::Aggregate::Create(partial_aggregate, grouping_expressions, reaggregate_expressions,
+                               nullptr /* filter_predicate */, output_partition_scheme_header.release());
+
+      if (avg_recompute_expressions.empty()) {
+        return reaggregate;
+      }
+
+      vector<E::NamedExpressionPtr> project_expressions;
+      for (const auto &grouping_expr : grouping_expressions) {
+        project_expressions.emplace_back(E::ToRef(grouping_expr));
+      }
+
+      for (const E::AttributeReferencePtr &non_recompute_aggregate_expression : non_recompute_aggregate_expressions) {
+        project_expressions.emplace_back(non_recompute_aggregate_expression);
+      }
+      for (const auto &avg_recompute_expression : avg_recompute_expressions) {
+        const auto &avg_expr = get<0>(avg_recompute_expression);
+        // Obtain AVG by evaluating SUM/COUNT in Selection.
+        const BinaryOperation &divide_op =
+            BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kDivide);
+        const E::BinaryExpressionPtr new_avg_expr =
+            E::BinaryExpression::Create(divide_op,
+                                        get<1>(avg_recompute_expression),
+                                        get<2>(avg_recompute_expression));
+        project_expressions.emplace_back(
+            E::Alias::Create(avg_expr->id(),
+                             new_avg_expr,
+                             avg_expr->attribute_name(),
+                             avg_expr->attribute_alias(),
+                             avg_expr->relation_name()));
+      }
+
+      if (!grouping_expressions.empty()) {
+        output_partition_scheme_header =
+            make_unique<P::PartitionSchemeHeader>(*reaggregate->getOutputPartitionSchemeHeader());
+      }
+      return P::Selection::Create(reaggregate, project_expressions, nullptr /* filter_predicate */,
+                                  output_partition_scheme_header.release());
+    }
     case P::PhysicalType::kHashJoin: {
       const P::HashJoinPtr hash_join = static_pointer_cast<const P::HashJoin>(node);
 
@@ -191,7 +419,11 @@
 
       const P::PartitionSchemeHeader *input_partition_scheme_header =
           selection->input()->getOutputPartitionSchemeHeader();
-      if (input_partition_scheme_header && input_partition_scheme_header->isHashPartition()) {
+      if (!input_partition_scheme_header) {
+        break;
+      }
+
+      if (input_partition_scheme_header->isHashPartition()) {
         unordered_set<E::ExprId> project_expr_ids;
         for (const E::AttributeReferencePtr &project_expression : selection->getOutputAttributes()) {
           project_expr_ids.insert(project_expression->id());
diff --git a/query_optimizer/tests/execution_generator/Partition.test b/query_optimizer/tests/execution_generator/Partition.test
index eb3ec98..850a981 100644
--- a/query_optimizer/tests/execution_generator/Partition.test
+++ b/query_optimizer/tests/execution_generator/Partition.test
@@ -127,3 +127,99 @@
 |         22|         22 4.690416|
 |         24|         24 4.898979|
 +-----------+--------------------+
+==
+
+# Partitioned aggregation.
+SELECT COUNT(*)
+FROM dim_4_hash_partitions;
+--
++--------------------+
+|COUNT(*)            |
++--------------------+
+|                  22|
++--------------------+
+==
+
+# Partitioned aggregation where the partition attributes are the subset of the group-by keys.
+SELECT id, COUNT(*)
+FROM dim_4_hash_partitions
+WHERE id > 0
+GROUP BY id;
+--
++-----------+--------------------+
+|id         |COUNT(*)            |
++-----------+--------------------+
+|          4|                   1|
+|          8|                   1|
+|         12|                   1|
+|         16|                   1|
+|         24|                   1|
+|          2|                   1|
+|          6|                   1|
+|         14|                   1|
+|         18|                   1|
+|         22|                   1|
++-----------+--------------------+
+==
+
+SELECT char_col, COUNT(*)
+FROM dim_4_hash_partitions
+WHERE id < 0
+GROUP BY char_col;
+--
++--------------------+--------------------+
+|char_col            |COUNT(*)            |
++--------------------+--------------------+
+|         -3 1.732051|                   1|
+|        -11 3.316625|                   1|
+|        -19 4.358899|                   1|
+|        -17 4.123106|                   1|
+|        -15 3.872983|                   1|
+|         -7 2.645751|                   1|
+|         -1 1.000000|                   1|
+|        -23 4.795832|                   1|
+|        -13 3.605551|                   1|
+|         -9 3.000000|                   1|
+|        -21 4.582576|                   1|
+|         -5 2.236068|                   1|
++--------------------+--------------------+
+==
+
+SELECT fact.score, COUNT(*)
+FROM dim_4_hash_partitions JOIN fact ON dim_4_hash_partitions.id = fact.id
+GROUP BY fact.score;
+--
++------------------------+--------------------+
+|score                   |COUNT(*)            |
++------------------------+--------------------+
+|      41.569219381653056|                   1|
+|      76.367532368147124|                   1|
+|                      64|                   1|
+|       52.38320341483518|                   1|
+|                       8|                   1|
+|      2.8284271247461903|                   1|
+|      14.696938456699067|                   1|
+|      22.627416997969522|                   1|
+|      117.57550765359254|                   1|
+|      103.18914671611546|                   1|
++------------------------+--------------------+
+==
+
+SELECT fact.id, AVG(fact.score)
+FROM dim_4_hash_partitions JOIN fact ON dim_4_hash_partitions.id = fact.id
+GROUP BY fact.id;
+--
++-----------+------------------------+
+|id         |AVG(fact.score)         |
++-----------+------------------------+
+|          4|                       8|
+|          8|      22.627416997969522|
+|         12|      41.569219381653056|
+|         16|                      64|
+|         24|      117.57550765359254|
+|          2|      2.8284271247461903|
+|          6|      14.696938456699067|
+|         14|       52.38320341483518|
+|         18|      76.367532368147124|
+|         22|      103.18914671611546|
++-----------+------------------------+