Skip predicate pushdown if the node's input is a base relation.
diff --git a/query_optimizer/rules/PushDownLowCostDisjunctivePredicate.cpp b/query_optimizer/rules/PushDownLowCostDisjunctivePredicate.cpp
index 69cc299..e73442d 100644
--- a/query_optimizer/rules/PushDownLowCostDisjunctivePredicate.cpp
+++ b/query_optimizer/rules/PushDownLowCostDisjunctivePredicate.cpp
@@ -79,27 +79,30 @@
}
void PushDownLowCostDisjunctivePredicate::collectApplicablePredicates(
- const physical::PhysicalPtr &input) {
+ const physical::PhysicalPtr &node) {
P::TableReferencePtr table_reference;
- if (P::SomeTableReference::MatchesWithConditionalCast(input, &table_reference)) {
- applicable_nodes_.emplace_back(input, &table_reference->attribute_list());
+ if (P::SomeTableReference::MatchesWithConditionalCast(node, &table_reference)) {
+ applicable_nodes_.emplace_back(node, &table_reference->attribute_list());
return;
}
- for (const auto &child : input->children()) {
+ for (const auto &child : node->children()) {
collectApplicablePredicates(child);
}
+ physical::PhysicalPtr input;
E::PredicatePtr filter_predicate = nullptr;
- switch (input->getPhysicalType()) {
+ switch (node->getPhysicalType()) {
case P::PhysicalType::kAggregate: {
- filter_predicate =
- std::static_pointer_cast<const P::Aggregate>(input)->filter_predicate();
+ const P::AggregatePtr aggregate =
+ std::static_pointer_cast<const P::Aggregate>(node);
+ input = aggregate->input();
+ filter_predicate = aggregate->filter_predicate();
break;
}
case P::PhysicalType::kHashJoin: {
const P::HashJoinPtr hash_join =
- std::static_pointer_cast<const P::HashJoin>(input);
+ std::static_pointer_cast<const P::HashJoin>(node);
if (hash_join->join_type() == P::HashJoin::JoinType::kInnerJoin) {
filter_predicate = hash_join->residual_predicate();
}
@@ -107,18 +110,24 @@
}
case P::PhysicalType::kNestedLoopsJoin: {
filter_predicate =
- std::static_pointer_cast<const P::NestedLoopsJoin>(input)->join_predicate();
+ std::static_pointer_cast<const P::NestedLoopsJoin>(node)->join_predicate();
break;
}
case P::PhysicalType::kSelection: {
- filter_predicate =
- std::static_pointer_cast<const P::Selection>(input)->filter_predicate();
+ const P::SelectionPtr selection =
+ std::static_pointer_cast<const P::Selection>(node);
+ input = selection->input();
+ filter_predicate = selection->filter_predicate();
break;
}
default:
break;
}
+ if (input && input->getPhysicalType() == P::PhysicalType::kTableReference) {
+ return;
+ }
+
E::LogicalOrPtr disjunctive_predicate;
if (filter_predicate == nullptr ||
!E::SomeLogicalOr::MatchesWithConditionalCast(filter_predicate, &disjunctive_predicate)) {