[CALCITE-6317] Incorrect constant replacement when group keys are NULL

Signed-off-by: Mihai Budiu <mbudiu@feldera.com>
diff --git a/core/src/main/java/org/apache/calcite/rel/metadata/RelMdPredicates.java b/core/src/main/java/org/apache/calcite/rel/metadata/RelMdPredicates.java
index dd10eed..7473293 100644
--- a/core/src/main/java/org/apache/calcite/rel/metadata/RelMdPredicates.java
+++ b/core/src/main/java/org/apache/calcite/rel/metadata/RelMdPredicates.java
@@ -344,6 +344,28 @@
     return joinInference.inferPredicates(false);
   }
 
+  /** Check whether the fields specified by the predicateColumns appear in all
+   * the groupSets of the aggregate.
+   *
+   * @param predicateColumns  A list of columns used in a pulled predicate.
+   * @param aggregate         An aggregation operation.
+   * @return                  Whether all columns appear in all groupsets.
+   */
+  boolean allGroupSetsOverlap(ImmutableBitSet predicateColumns, Aggregate aggregate) {
+    // Consider this example:
+    // select deptno, sal, count(*)
+    // from emp where deptno = 10
+    // group by rollup(sal, deptno)
+    // Because of the ROLLUP, we cannot assume
+    // that deptno = 10 in the result: deptno may be NULL as well.
+    for (ImmutableBitSet groupSet : aggregate.groupSets) {
+      if (!groupSet.contains(predicateColumns)) {
+        return false;
+      }
+    }
+    return true;
+  }
+
   /**
    * Infers predicates for an Aggregate.
    *
@@ -382,7 +404,8 @@
 
     for (RexNode r : inputInfo.pulledUpPredicates) {
       ImmutableBitSet rCols = RelOptUtil.InputFinder.bits(r);
-      if (groupKeys.contains(rCols)) {
+
+      if (groupKeys.contains(rCols) && this.allGroupSetsOverlap(rCols, agg)) {
         r = r.accept(new RexPermuteInputsShuttle(m, input));
         aggPullUpPredicates.add(r);
       }
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index d2b28cb..51a13e0 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -320,6 +320,19 @@
 
   /**
    * Test case for
+   * <a href="https://issues.apache.org/jira/projects/CALCITE/issues/CALCITE-6317">
+   * [CALCITE-6317] Incorrect constant replacement when group keys are NULL</a>. */
+  @Test void testPredicatePull() {
+    final String sql = "select deptno, sal "
+                 + "from emp "
+                + "where deptno = 10 "
+                + "group by rollup(sal, deptno)";
+    sql(sql).withRule(CoreRules.PROJECT_REDUCE_EXPRESSIONS)
+        .check();
+  }
+
+  /**
+   * Test case for
    * <a href="https://issues.apache.org/jira/browse/CALCITE-5971">[CALCITE-5971]
    * Add the RelRule to rewrite the bernoulli sample as Filter</a>. */
   @Test void testSampleToFilterWithSeed() {
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index 7242a8f..bc940b8 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -7436,6 +7436,29 @@
       <![CDATA[select * from emp where MGR > 0 and case when MGR > 0 then deptno / MGR else null end > 1]]>
     </Resource>
   </TestCase>
+  <TestCase name="testPredicatePull">
+    <Resource name="sql">
+      <![CDATA[select deptno, sal from emp where deptno = 10 group by rollup(sal, deptno)]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalProject(DEPTNO=[$1], SAL=[$0])
+  LogicalAggregate(group=[{0, 1}], groups=[[{0, 1}, {0}, {}]])
+    LogicalProject(SAL=[$5], DEPTNO=[$7])
+      LogicalFilter(condition=[=($7, 10)])
+        LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+    <Resource name="planAfter">
+      <![CDATA[
+LogicalProject(DEPTNO=[$1], SAL=[$0])
+  LogicalAggregate(group=[{0, 1}], groups=[[{0, 1}, {0}, {}]])
+    LogicalProject(SAL=[$5], DEPTNO=[10])
+      LogicalFilter(condition=[=($7, 10)])
+        LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+  </TestCase>
   <TestCase name="testProjectAggregateMerge">
     <Resource name="sql">
       <![CDATA[select deptno + ss