[CALCITE-6317] Incorrect constant replacement when group keys are NULL
Signed-off-by: Mihai Budiu <mbudiu@feldera.com>
diff --git a/core/src/main/java/org/apache/calcite/rel/metadata/RelMdPredicates.java b/core/src/main/java/org/apache/calcite/rel/metadata/RelMdPredicates.java
index dd10eed..7473293 100644
--- a/core/src/main/java/org/apache/calcite/rel/metadata/RelMdPredicates.java
+++ b/core/src/main/java/org/apache/calcite/rel/metadata/RelMdPredicates.java
@@ -344,6 +344,28 @@
return joinInference.inferPredicates(false);
}
+ /** Check whether the fields specified by the predicateColumns appear in all
+ * the groupSets of the aggregate.
+ *
+ * @param predicateColumns A list of columns used in a pulled predicate.
+ * @param aggregate An aggregation operation.
+ * @return Whether all columns appear in all groupsets.
+ */
+ boolean allGroupSetsOverlap(ImmutableBitSet predicateColumns, Aggregate aggregate) {
+ // Consider this example:
+ // select deptno, sal, count(*)
+ // from emp where deptno = 10
+ // group by rollup(sal, deptno)
+ // Because of the ROLLUP, we cannot assume
+ // that deptno = 10 in the result: deptno may be NULL as well.
+ for (ImmutableBitSet groupSet : aggregate.groupSets) {
+ if (!groupSet.contains(predicateColumns)) {
+ return false;
+ }
+ }
+ return true;
+ }
+
/**
* Infers predicates for an Aggregate.
*
@@ -382,7 +404,8 @@
for (RexNode r : inputInfo.pulledUpPredicates) {
ImmutableBitSet rCols = RelOptUtil.InputFinder.bits(r);
- if (groupKeys.contains(rCols)) {
+
+ if (groupKeys.contains(rCols) && this.allGroupSetsOverlap(rCols, agg)) {
r = r.accept(new RexPermuteInputsShuttle(m, input));
aggPullUpPredicates.add(r);
}
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index d2b28cb..51a13e0 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -320,6 +320,19 @@
/**
* Test case for
+ * <a href="https://issues.apache.org/jira/projects/CALCITE/issues/CALCITE-6317">
+ * [CALCITE-6317] Incorrect constant replacement when group keys are NULL</a>. */
+ @Test void testPredicatePull() {
+ final String sql = "select deptno, sal "
+ + "from emp "
+ + "where deptno = 10 "
+ + "group by rollup(sal, deptno)";
+ sql(sql).withRule(CoreRules.PROJECT_REDUCE_EXPRESSIONS)
+ .check();
+ }
+
+ /**
+ * Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-5971">[CALCITE-5971]
* Add the RelRule to rewrite the bernoulli sample as Filter</a>. */
@Test void testSampleToFilterWithSeed() {
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index 7242a8f..bc940b8 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -7436,6 +7436,29 @@
<![CDATA[select * from emp where MGR > 0 and case when MGR > 0 then deptno / MGR else null end > 1]]>
</Resource>
</TestCase>
+ <TestCase name="testPredicatePull">
+ <Resource name="sql">
+ <![CDATA[select deptno, sal from emp where deptno = 10 group by rollup(sal, deptno)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(DEPTNO=[$1], SAL=[$0])
+ LogicalAggregate(group=[{0, 1}], groups=[[{0, 1}, {0}, {}]])
+ LogicalProject(SAL=[$5], DEPTNO=[$7])
+ LogicalFilter(condition=[=($7, 10)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalProject(DEPTNO=[$1], SAL=[$0])
+ LogicalAggregate(group=[{0, 1}], groups=[[{0, 1}, {0}, {}]])
+ LogicalProject(SAL=[$5], DEPTNO=[10])
+ LogicalFilter(condition=[=($7, 10)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ </TestCase>
<TestCase name="testProjectAggregateMerge">
<Resource name="sql">
<![CDATA[select deptno + ss