perf: Simplify CASE for any WHEN TRUE (#17602)
* Extend case simplify expr
* Add tests
* cargo fmt
* Remove copying vector based on PR feedback
* Remove unnecessary if conditional (pr feedback)
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 97dfc09..3c96f95 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -1400,13 +1400,35 @@
//
// CASE WHEN true THEN A ... END --> A
+ // CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END
Expr::Case(Case {
expr: None,
mut when_then_expr,
else_expr: _,
- }) if !when_then_expr.is_empty() && is_true(when_then_expr[0].0.as_ref()) => {
- let (_, then_) = when_then_expr.swap_remove(0);
- Transformed::yes(*then_)
+ // if let guard is not stabilized so we can't use it yet: https://github.com/rust-lang/rust/issues/51114
+ // Once it's supported we can avoid searching through when_then_expr twice in the below .any() and .position() calls
+ // }) if let Some(i) = when_then_expr.iter().position(|(when, _)| is_true(when.as_ref())) => {
+ }) if when_then_expr
+ .iter()
+ .any(|(when, _)| is_true(when.as_ref())) =>
+ {
+ let i = when_then_expr
+ .iter()
+ .position(|(when, _)| is_true(when.as_ref()))
+ .unwrap();
+ let (_, then_) = when_then_expr.swap_remove(i);
+ // CASE WHEN true THEN A ... END --> A
+ if i == 0 {
+ return Ok(Transformed::yes(*then_));
+ }
+
+ // CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END
+ when_then_expr.truncate(i);
+ Transformed::yes(Expr::Case(Case {
+ expr: None,
+ when_then_expr,
+ else_expr: Some(then_),
+ }))
}
// CASE
@@ -3563,7 +3585,7 @@
}
#[test]
- fn simplify_expr_case_when_true() {
+ fn simplify_expr_case_when_first_true() {
// CASE WHEN true THEN 1 ELSE x END --> 1
assert_eq!(
simplify(Expr::Case(Case::new(
@@ -3632,6 +3654,82 @@
assert_eq!(simplify(expr.clone()), expr);
}
+ #[test]
+ fn simplify_expr_case_when_any_true() {
+ // CASE WHEN x > 0 THEN a WHEN true THEN b ELSE c END --> CASE WHEN x > 0 THEN a ELSE b END
+ assert_eq!(
+ simplify(Expr::Case(Case::new(
+ None,
+ vec![
+ (Box::new(col("x").gt(lit(0))), Box::new(col("a"))),
+ (Box::new(lit(true)), Box::new(col("b"))),
+ ],
+ Some(Box::new(col("c"))),
+ ))),
+ Expr::Case(Case::new(
+ None,
+ vec![(Box::new(col("x").gt(lit(0))), Box::new(col("a")))],
+ Some(Box::new(col("b"))),
+ ))
+ );
+
+ // CASE WHEN x > 0 THEN a WHEN y < 0 THEN b WHEN true THEN c WHEN z = 0 THEN d ELSE e END
+ // --> CASE WHEN x > 0 THEN a WHEN y < 0 THEN b ELSE c END
+ assert_eq!(
+ simplify(Expr::Case(Case::new(
+ None,
+ vec![
+ (Box::new(col("x").gt(lit(0))), Box::new(col("a"))),
+ (Box::new(col("y").lt(lit(0))), Box::new(col("b"))),
+ (Box::new(lit(true)), Box::new(col("c"))),
+ (Box::new(col("z").eq(lit(0))), Box::new(col("d"))),
+ ],
+ Some(Box::new(col("e"))),
+ ))),
+ Expr::Case(Case::new(
+ None,
+ vec![
+ (Box::new(col("x").gt(lit(0))), Box::new(col("a"))),
+ (Box::new(col("y").lt(lit(0))), Box::new(col("b"))),
+ ],
+ Some(Box::new(col("c"))),
+ ))
+ );
+
+ // CASE WHEN x > 0 THEN a WHEN y < 0 THEN b WHEN true THEN c END (no else)
+ // --> CASE WHEN x > 0 THEN a WHEN y < 0 THEN b ELSE c END
+ assert_eq!(
+ simplify(Expr::Case(Case::new(
+ None,
+ vec![
+ (Box::new(col("x").gt(lit(0))), Box::new(col("a"))),
+ (Box::new(col("y").lt(lit(0))), Box::new(col("b"))),
+ (Box::new(lit(true)), Box::new(col("c"))),
+ ],
+ None,
+ ))),
+ Expr::Case(Case::new(
+ None,
+ vec![
+ (Box::new(col("x").gt(lit(0))), Box::new(col("a"))),
+ (Box::new(col("y").lt(lit(0))), Box::new(col("b"))),
+ ],
+ Some(Box::new(col("c"))),
+ ))
+ );
+
+ // Negative test: CASE WHEN x > 0 THEN a WHEN y < 0 THEN b ELSE c END should not be simplified
+ let expr = Expr::Case(Case::new(
+ None,
+ vec![
+ (Box::new(col("x").gt(lit(0))), Box::new(col("a"))),
+ (Box::new(col("y").lt(lit(0))), Box::new(col("b"))),
+ ],
+ Some(Box::new(col("c"))),
+ ));
+ assert_eq!(simplify(expr.clone()), expr);
+ }
+
fn distinct_from(left: impl Into<Expr>, right: impl Into<Expr>) -> Expr {
Expr::BinaryExpr(BinaryExpr {
left: Box::new(left.into()),