Collapse Selections with predicates.
diff --git a/query_optimizer/rules/CMakeLists.txt b/query_optimizer/rules/CMakeLists.txt
index f8df32b..0cd7212 100644
--- a/query_optimizer/rules/CMakeLists.txt
+++ b/query_optimizer/rules/CMakeLists.txt
@@ -86,7 +86,11 @@
quickstep_queryoptimizer_rules_RuleHelper
quickstep_utility_Macros)
target_link_libraries(quickstep_queryoptimizer_rules_CollapseSelection
+ quickstep_queryoptimizer_expressions_Expression
+ quickstep_queryoptimizer_expressions_LogicalAnd
quickstep_queryoptimizer_expressions_NamedExpression
+ quickstep_queryoptimizer_expressions_PatternMatcher
+ quickstep_queryoptimizer_expressions_Predicate
quickstep_queryoptimizer_physical_PatternMatcher
quickstep_queryoptimizer_physical_Physical
quickstep_queryoptimizer_physical_Selection
diff --git a/query_optimizer/rules/CollapseSelection.cpp b/query_optimizer/rules/CollapseSelection.cpp
index fc45ffd..40f9375 100644
--- a/query_optimizer/rules/CollapseSelection.cpp
+++ b/query_optimizer/rules/CollapseSelection.cpp
@@ -21,7 +21,11 @@
#include <vector>
+#include "query_optimizer/expressions/Expression.hpp"
+#include "query_optimizer/expressions/LogicalAnd.hpp"
#include "query_optimizer/expressions/NamedExpression.hpp"
+#include "query_optimizer/expressions/PatternMatcher.hpp"
+#include "query_optimizer/expressions/Predicate.hpp"
#include "query_optimizer/physical/PatternMatcher.hpp"
#include "query_optimizer/physical/Physical.hpp"
#include "query_optimizer/physical/Selection.hpp"
@@ -37,20 +41,36 @@
P::SelectionPtr selection;
P::SelectionPtr child_selection;
- // TODO(jianqiao): Handle the case where filter predicates are present.
if (P::SomeSelection::MatchesWithConditionalCast(input, &selection) &&
- P::SomeSelection::MatchesWithConditionalCast(selection->input(), &child_selection) &&
- selection->filter_predicate() == nullptr &&
- child_selection->filter_predicate() == nullptr) {
+ P::SomeSelection::MatchesWithConditionalCast(selection->input(), &child_selection)) {
+ E::PredicatePtr filter_predicate = selection->filter_predicate();
+
+ std::vector<E::ExpressionPtr> non_project_expressions;
+ if (filter_predicate) {
+ non_project_expressions.push_back(filter_predicate);
+ }
+
std::vector<E::NamedExpressionPtr> project_expressions =
selection->project_expressions();
PullUpProjectExpressions(child_selection->project_expressions(),
- {} /* non_project_expression_lists */,
- { &project_expressions } /* project_expression_lists */);
+ { &non_project_expressions }, { &project_expressions });
+
+ const E::PredicatePtr &child_filter_predicate = child_selection->filter_predicate();
+ if (filter_predicate) {
+ CHECK(E::SomePredicate::MatchesWithConditionalCast(non_project_expressions[0],
+ &filter_predicate))
+ << non_project_expressions[0]->toString();
+ if (child_filter_predicate) {
+ filter_predicate = E::LogicalAnd::Create({ filter_predicate, child_filter_predicate });
+ }
+ } else {
+ filter_predicate = child_filter_predicate;
+ }
+
return P::Selection::Create(child_selection->input(),
project_expressions,
- selection->filter_predicate(),
- child_selection->input()->cloneOutputPartitionSchemeHeader());
+ filter_predicate,
+ selection->cloneOutputPartitionSchemeHeader());
}
return input;
diff --git a/query_optimizer/tests/physical_generator/Select.test b/query_optimizer/tests/physical_generator/Select.test
index 614347b..dc923ae 100644
--- a/query_optimizer/tests/physical_generator/Select.test
+++ b/query_optimizer/tests/physical_generator/Select.test
@@ -1022,51 +1022,34 @@
[Physical Plan]
TopLevelPlan
+-plan=Selection
-| +-input=Selection
-| | +-input=Aggregate
-| | | +-input=TableReference[relation=Test,alias=test]
-| | | | +-AttributeReference[id=0,name=int_col,relation=test,type=Int NULL]
-| | | | +-AttributeReference[id=1,name=long_col,relation=test,type=Long]
-| | | | +-AttributeReference[id=2,name=float_col,relation=test,type=Float]
-| | | | +-AttributeReference[id=3,name=double_col,relation=test,type=Double NULL]
-| | | | +-AttributeReference[id=4,name=char_col,relation=test,type=Char(20)]
-| | | | +-AttributeReference[id=5,name=vchar_col,relation=test,
-| | | | type=VarChar(20) NULL]
-| | | +-grouping_expressions=
-| | | | +-[]
-| | | +-aggregate_expressions=
-| | | +-Alias[id=6,name=,alias=$aggregate0,relation=$aggregate,type=Long]
-| | | | +-AggregateFunction[function=COUNT]
-| | | | +-[]
-| | | +-Alias[id=7,name=,alias=$aggregate1,relation=$aggregate,type=Long]
-| | | | +-AggregateFunction[function=COUNT]
-| | | | +-AttributeReference[id=0,name=int_col,relation=test,type=Int NULL]
-| | | +-Alias[id=8,name=,alias=$aggregate2,relation=$aggregate,type=Long NULL]
-| | | | +-AggregateFunction[function=SUM]
-| | | | +-AttributeReference[id=1,name=long_col,relation=test,type=Long]
-| | | +-Alias[id=12,name=,alias=$aggregate3,relation=$aggregate,type=Long NULL]
-| | | | +-AggregateFunction[function=SUM]
-| | | | +-AttributeReference[id=0,name=int_col,relation=test,type=Int NULL]
-| | | +-Alias[id=11,name=,alias=$aggregate4,relation=$aggregate,
-| | | type=Double NULL]
-| | | +-AggregateFunction[function=MAX]
-| | | +-AttributeReference[id=3,name=double_col,relation=test,
-| | | type=Double NULL]
-| | +-project_expressions=
-| | +-AttributeReference[id=6,name=,alias=$aggregate0,relation=$aggregate,
-| | | type=Long]
-| | +-AttributeReference[id=7,name=,alias=$aggregate1,relation=$aggregate,
-| | | type=Long]
-| | +-AttributeReference[id=8,name=,alias=$aggregate2,relation=$aggregate,
-| | | type=Long NULL]
-| | +-Alias[id=9,name=,alias=$aggregate3,relation=$aggregate,type=Long NULL]
-| | | +-Divide
-| | | +-AttributeReference[id=12,name=,alias=$aggregate3,
-| | | | relation=$aggregate,type=Long NULL]
-| | | +-AttributeReference[id=7,name=,alias=$aggregate1,relation=$aggregate,
-| | | type=Long]
-| | +-AttributeReference[id=11,name=,alias=$aggregate4,relation=$aggregate,
-| | type=Double NULL]
+| +-input=Aggregate
+| | +-input=TableReference[relation=Test,alias=test]
+| | | +-AttributeReference[id=0,name=int_col,relation=test,type=Int NULL]
+| | | +-AttributeReference[id=1,name=long_col,relation=test,type=Long]
+| | | +-AttributeReference[id=2,name=float_col,relation=test,type=Float]
+| | | +-AttributeReference[id=3,name=double_col,relation=test,type=Double NULL]
+| | | +-AttributeReference[id=4,name=char_col,relation=test,type=Char(20)]
+| | | +-AttributeReference[id=5,name=vchar_col,relation=test,
+| | | type=VarChar(20) NULL]
+| | +-grouping_expressions=
+| | | +-[]
+| | +-aggregate_expressions=
+| | +-Alias[id=6,name=,alias=$aggregate0,relation=$aggregate,type=Long]
+| | | +-AggregateFunction[function=COUNT]
+| | | +-[]
+| | +-Alias[id=7,name=,alias=$aggregate1,relation=$aggregate,type=Long]
+| | | +-AggregateFunction[function=COUNT]
+| | | +-AttributeReference[id=0,name=int_col,relation=test,type=Int NULL]
+| | +-Alias[id=8,name=,alias=$aggregate2,relation=$aggregate,type=Long NULL]
+| | | +-AggregateFunction[function=SUM]
+| | | +-AttributeReference[id=1,name=long_col,relation=test,type=Long]
+| | +-Alias[id=12,name=,alias=$aggregate3,relation=$aggregate,type=Long NULL]
+| | | +-AggregateFunction[function=SUM]
+| | | +-AttributeReference[id=0,name=int_col,relation=test,type=Int NULL]
+| | +-Alias[id=11,name=,alias=$aggregate4,relation=$aggregate,type=Double NULL]
+| | +-AggregateFunction[function=MAX]
+| | +-AttributeReference[id=3,name=double_col,relation=test,
+| | type=Double NULL]
| +-filter_predicate=Greater
| | +-Add
| | | +-AttributeReference[id=11,name=,alias=$aggregate4,relation=$aggregate,
@@ -1081,10 +1064,13 @@
| | | | type=Long]
| | | +-AttributeReference[id=8,name=,alias=$aggregate2,relation=$aggregate,
| | | type=Long NULL]
-| | +-AttributeReference[id=9,name=,alias=$aggregate3,relation=$aggregate,
-| | type=Double NULL]
+| | +-Divide
+| | +-AttributeReference[id=12,name=,alias=$aggregate3,
+| | | relation=$aggregate,type=Long NULL]
+| | +-AttributeReference[id=7,name=,alias=$aggregate1,relation=$aggregate,
+| | type=Long]
| +-project_expressions=
-| +-Alias[id=10,name=col,relation=,type=Double NULL]
+| +-Alias[id=10,name=col,relation=,type=Long NULL]
| +-Add
| +-AttributeReference[id=6,name=,alias=$aggregate0,relation=$aggregate,
| | type=Long]
@@ -1094,10 +1080,13 @@
| | | relation=$aggregate,type=Long]
| | +-AttributeReference[id=8,name=,alias=$aggregate2,
| | relation=$aggregate,type=Long NULL]
-| +-AttributeReference[id=9,name=,alias=$aggregate3,relation=$aggregate,
-| type=Double NULL]
+| +-Divide
+| +-AttributeReference[id=12,name=,alias=$aggregate3,
+| | relation=$aggregate,type=Long NULL]
+| +-AttributeReference[id=7,name=,alias=$aggregate1,
+| relation=$aggregate,type=Long]
+-output_attributes=
- +-AttributeReference[id=10,name=col,relation=,type=Double NULL]
+ +-AttributeReference[id=10,name=col,relation=,type=Long NULL]
==
select long_col as col1, count(*) as col2
@@ -3445,4 +3434,3 @@
| +-[]
+-output_attributes=
+-AttributeReference[id=0,name=int_col,relation=test,type=Int NULL]
-==