ARROW-10760: [Rust] [DataFusion] Fixed error in filter push down over joins

This fixes a bug in which a predicate depending on columns from both sides of a join was being pushed down through the join, causing incorrect plans.

With this change, each filter is independently pushed to whichever side of the join it can be applied to, while any predicate that cannot be pushed (e.g. because it depends on columns from both sides of the join) is kept in a filter above the join.
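
For illustration, a minimal sketch of the intended behaviour (mirroring the new
filter_join_on_one_side test added below; table_scan is the test helper used
there): the predicate #b <= 1 only references the left side, so it ends up
below the join on that side, while a predicate such as #c <= #b (see the
filter_join_on_common_dependent test) references both sides and stays above the join.

    let left = LogicalPlanBuilder::from(&table_scan)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let right = LogicalPlanBuilder::from(&table_scan)
        .project(vec![col("a"), col("c")])?
        .build()?;
    let plan = LogicalPlanBuilder::from(&left)
        .join(&right, JoinType::Inner, &["a"], &["a"])?
        // only references the left side, so it can be pushed below the join
        .filter(col("b").lt_eq(lit(1i64)))?
        .build()?;
    // after this rule runs, the filter sits under the left projection only:
    //   Join: a = a
    //     Projection: #a, #b
    //       Filter: #b LtEq Int64(1)
    //         TableScan: test projection=None
    //     Projection: #a, #c
    //       TableScan: test projection=None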

Closes #8797 from jorgecarleitao/fix_push

Authored-by: Jorge C. Leitao <jorgecarleitao@gmail.com>
Signed-off-by: Andy Grove <andygrove73@gmail.com>
diff --git a/rust/datafusion/src/optimizer/filter_push_down.rs b/rust/datafusion/src/optimizer/filter_push_down.rs
index 71af928..26d4410 100644
--- a/rust/datafusion/src/optimizer/filter_push_down.rs
+++ b/rust/datafusion/src/optimizer/filter_push_down.rs
@@ -14,6 +14,8 @@
 
 //! Filter Push Down optimizer rule ensures that filters are applied as early as possible in the plan
 
+use arrow::datatypes::Schema;
+
 use crate::error::Result;
 use crate::logical_plan::Expr;
 use crate::logical_plan::{and, LogicalPlan};
@@ -57,38 +59,104 @@
     filters: Vec<(Expr, HashSet<String>)>,
 }
 
-/// builds a new [LogicalPlan] from `plan` by issuing new [LogicalPlan::Filter] if any of the filters
-/// in `state` depend on the columns `used_columns`.
-fn issue_filters(
-    mut state: State,
-    used_columns: HashSet<String>,
-    plan: &LogicalPlan,
-) -> Result<LogicalPlan> {
-    // pick all filters in the current state that depend on any of `used_columns`
-    let (predicates, predicate_columns): (Vec<_>, Vec<_>) = state
+type Predicates<'a> = (Vec<&'a Expr>, Vec<&'a HashSet<String>>);
+
+/// returns all predicates in `state` that depend on any of `used_columns`
+fn get_predicates<'a>(
+    state: &'a State,
+    used_columns: &HashSet<String>,
+) -> Predicates<'a> {
+    state
         .filters
         .iter()
         .filter(|(_, columns)| {
             columns
-                .intersection(&used_columns)
+                .intersection(used_columns)
                 .collect::<HashSet<_>>()
                 .len()
                 > 0
         })
         .map(|&(ref a, ref b)| (a, b))
+        .unzip()
+}
+
+// returns 3 (potentially overlapping) sets of predicates:
+// * pushable to left: its columns are all on the left
+// * pushable to right: its columns are all on the right
+// * keep: its columns are neither all on the left nor all on the right
+// Note that a predicate can be both pushed to the left and to the right.
+fn get_join_predicates<'a>(
+    state: &'a State,
+    left: &Schema,
+    right: &Schema,
+) -> (
+    Vec<&'a HashSet<String>>,
+    Vec<&'a HashSet<String>>,
+    Predicates<'a>,
+) {
+    let left_columns = &left
+        .fields()
+        .iter()
+        .map(|f| f.name().clone())
+        .collect::<HashSet<_>>();
+    let right_columns = &right
+        .fields()
+        .iter()
+        .map(|f| f.name().clone())
+        .collect::<HashSet<_>>();
+
+    let filters = state
+        .filters
+        .iter()
+        .map(|(predicate, columns)| {
+            (
+                (predicate, columns),
+                (
+                    columns,
+                    left_columns.intersection(columns).collect::<HashSet<_>>(),
+                    right_columns.intersection(columns).collect::<HashSet<_>>(),
+                ),
+            )
+        })
+        .collect::<Vec<_>>();
+
+    let pushable_to_left = filters
+        .iter()
+        .filter(|(_, (columns, left, _))| left.len() == columns.len())
+        .map(|((_, b), _)| *b)
+        .collect();
+    let pushable_to_right = filters
+        .iter()
+        .filter(|(_, (columns, _, right))| right.len() == columns.len())
+        .map(|((_, b), _)| *b)
+        .collect();
+    let keep = filters
+        .iter()
+        .filter(|(_, (columns, left, right))| {
+            // predicates whose columns are not all on just one side of the join must remain above it
+            let all_in_left = left.len() == columns.len();
+            let all_in_right = right.len() == columns.len();
+            !all_in_left && !all_in_right
+        })
+        .map(|((ref a, ref b), _)| (a, b))
         .unzip();
+    (pushable_to_left, pushable_to_right, keep)
+}
 
-    if predicates.is_empty() {
-        // all filters can be pushed down => optimize inputs and return new plan
-        let new_inputs = utils::inputs(&plan)
-            .iter()
-            .map(|input| optimize(input, state.clone()))
-            .collect::<Result<Vec<_>>>()?;
+/// optimizes all inputs of `plan` with the filters in `state` and rebuilds `plan` from the optimized inputs
+fn push_down(state: &State, plan: &LogicalPlan) -> Result<LogicalPlan> {
+    let new_inputs = utils::inputs(&plan)
+        .iter()
+        .map(|input| optimize(input, state.clone()))
+        .collect::<Result<Vec<_>>>()?;
 
-        let expr = utils::expressions(&plan);
-        return utils::from_plan(&plan, &expr, &new_inputs);
-    }
+    let expr = utils::expressions(&plan);
+    utils::from_plan(&plan, &expr, &new_inputs)
+}
 
+/// returns a new [LogicalPlan] that wraps `plan` in a [LogicalPlan::Filter] with
+/// its predicate being the conjunction (AND) of all `predicates`.
+fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> LogicalPlan {
     // reduce filters to a single filter with an AND
     let predicate = predicates
         .iter()
@@ -97,28 +165,56 @@
             and(acc, (*predicate).to_owned())
         });
 
-    // add a new filter node with the predicates
-    let plan = LogicalPlan::Filter {
+    LogicalPlan::Filter {
         predicate,
-        input: Arc::new(plan.clone()),
-    };
+        input: Arc::new(plan),
+    }
+}
 
-    // remove all filters from the state that cannot be pushed further down
-    state.filters = state
-        .filters
+// returns `filters` without the filters whose column sets appear in `predicate_columns`
+fn remove_filters(
+    filters: &[(Expr, HashSet<String>)],
+    predicate_columns: &[&HashSet<String>],
+) -> Vec<(Expr, HashSet<String>)> {
+    filters
         .iter()
         .filter(|(_, columns)| !predicate_columns.contains(&columns))
         .cloned()
-        .collect::<Vec<_>>();
+        .collect::<Vec<_>>()
+}
+
+// returns only the filters from `filters` whose column sets appear in `predicate_columns`
+fn keep_filters(
+    filters: &[(Expr, HashSet<String>)],
+    predicate_columns: &[&HashSet<String>],
+) -> Vec<(Expr, HashSet<String>)> {
+    filters
+        .iter()
+        .filter(|(_, columns)| predicate_columns.contains(&columns))
+        .cloned()
+        .collect::<Vec<_>>()
+}
+
+/// builds a new [LogicalPlan] from `plan` by issuing new [LogicalPlan::Filter] if any of the filters
+/// in `state` depend on the columns `used_columns`.
+fn issue_filters(
+    mut state: State,
+    used_columns: HashSet<String>,
+    plan: &LogicalPlan,
+) -> Result<LogicalPlan> {
+    let (predicates, predicate_columns) = get_predicates(&state, &used_columns);
+
+    if predicates.is_empty() {
+        // all filters can be pushed down => optimize inputs and return new plan
+        return push_down(&state, plan);
+    }
+
+    let plan = add_filter(plan.clone(), &predicates);
+
+    state.filters = remove_filters(&state.filters, &predicate_columns);
 
     // continue optimization over all input nodes by cloning the current state (i.e. each node is independent)
-    let new_inputs = utils::inputs(&plan)
-        .iter()
-        .map(|input| optimize(input, state.clone()))
-        .collect::<Result<Vec<_>>>()?;
-
-    let expr = utils::expressions(&plan);
-    utils::from_plan(&plan, &expr, &new_inputs)
+    push_down(&state, &plan)
 }
 
 fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
@@ -183,7 +279,7 @@
         }
         LogicalPlan::Sort { .. } => {
             // sort is filter-commutable
-            issue_filters(state, HashSet::new(), plan)
+            push_down(&state, plan)
         }
         LogicalPlan::Limit { input, .. } => {
             // limit is _not_ filter-commutable => collect all columns from its input
@@ -195,9 +291,31 @@
                 .collect::<HashSet<_>>();
             issue_filters(state, used_columns, plan)
         }
-        LogicalPlan::Join { .. } => {
-            // join is filter-commutable
-            issue_filters(state, HashSet::new(), plan)
+        LogicalPlan::Join { left, right, .. } => {
+            let (pushable_to_left, pushable_to_right, keep) =
+                get_join_predicates(&state, &left.schema(), &right.schema());
+
+            let mut left_state = state.clone();
+            left_state.filters = keep_filters(&left_state.filters, &pushable_to_left);
+            let left = optimize(left, left_state)?;
+
+            let mut right_state = state.clone();
+            right_state.filters = keep_filters(&right_state.filters, &pushable_to_right);
+            let right = optimize(right, right_state)?;
+
+            // create a new Join with the new `left` and `right`
+            let expr = utils::expressions(&plan);
+            let plan = utils::from_plan(&plan, &expr, &vec![left, right])?;
+
+            if keep.0.is_empty() {
+                Ok(plan)
+            } else {
+                // wrap the join on the filter whose predicates must be kept
+                let plan = add_filter(plan, &keep.0);
+                state.filters = remove_filters(&state.filters, &keep.1);
+
+                Ok(plan)
+            }
         }
         _ => {
             // all other plans are _not_ filter-commutable
@@ -594,8 +712,9 @@
         Ok(())
     }
 
+    /// post-join predicates on a column common to both sides are pushed to both sides
     #[test]
-    fn filters_join() -> Result<()> {
+    fn filter_join_on_common_independent() -> Result<()> {
         let table_scan = test_table_scan()?;
         let left = LogicalPlanBuilder::from(&table_scan).build()?;
         let right = LogicalPlanBuilder::from(&table_scan)
@@ -628,4 +747,77 @@
         assert_optimized_plan_eq(&plan, expected);
         Ok(())
     }
+
+    /// post-join predicates with columns from both sides are not pushed
+    #[test]
+    fn filter_join_on_common_dependent() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let left = LogicalPlanBuilder::from(&table_scan)
+            .project(vec![col("a"), col("c")])?
+            .build()?;
+        let right = LogicalPlanBuilder::from(&table_scan)
+            .project(vec![col("a"), col("b")])?
+            .build()?;
+        let plan = LogicalPlanBuilder::from(&left)
+            .join(&right, JoinType::Inner, &["a"], &["a"])?
+            // "b" and "c" are not shared by either side: they are only available together after the join
+            .filter(col("c").lt_eq(col("b")))?
+            .build()?;
+
+        // not part of the test, just good to know:
+        assert_eq!(
+            format!("{:?}", plan),
+            "\
+            Filter: #c LtEq #b\
+            \n  Join: a = a\
+            \n    Projection: #a, #c\
+            \n      TableScan: test projection=None\
+            \n    Projection: #a, #b\
+            \n      TableScan: test projection=None"
+        );
+
+        // expected is equal: no push-down
+        let expected = &format!("{:?}", plan);
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
+    /// post-join predicates with columns from one side of a join are pushed only to that side
+    #[test]
+    fn filter_join_on_one_side() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let left = LogicalPlanBuilder::from(&table_scan)
+            .project(vec![col("a"), col("b")])?
+            .build()?;
+        let right = LogicalPlanBuilder::from(&table_scan)
+            .project(vec![col("a"), col("c")])?
+            .build()?;
+        let plan = LogicalPlanBuilder::from(&left)
+            .join(&right, JoinType::Inner, &["a"], &["a"])?
+            .filter(col("b").lt_eq(lit(1i64)))?
+            .build()?;
+
+        // not part of the test, just good to know:
+        assert_eq!(
+            format!("{:?}", plan),
+            "\
+            Filter: #b LtEq Int64(1)\
+            \n  Join: a = a\
+            \n    Projection: #a, #b\
+            \n      TableScan: test projection=None\
+            \n    Projection: #a, #c\
+            \n      TableScan: test projection=None"
+        );
+
+        let expected = "\
+        Join: a = a\
+        \n  Projection: #a, #b\
+        \n    Filter: #b LtEq Int64(1)\
+        \n      TableScan: test projection=None\
+        \n  Projection: #a, #c\
+        \n    TableScan: test projection=None";
+
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
 }