[CALCITE-4726] Support aggregate calls with a FILTER clause in AggregateExpandWithinDistinctRule (Will Noble)

Close apache/calcite#2483
diff --git a/core/src/main/java/org/apache/calcite/jdbc/SimpleCalciteSchema.java b/core/src/main/java/org/apache/calcite/jdbc/SimpleCalciteSchema.java
index dc67c2a..9630bc4 100644
--- a/core/src/main/java/org/apache/calcite/jdbc/SimpleCalciteSchema.java
+++ b/core/src/main/java/org/apache/calcite/jdbc/SimpleCalciteSchema.java
@@ -76,7 +76,7 @@
     return calciteSchema;
   }
 
-  private @Nullable String caseInsensitiveLookup(Set<String> candidates, String name) {
+  private static @Nullable String caseInsensitiveLookup(Set<String> candidates, String name) {
     // Exact string lookup
     if (candidates.contains(name)) {
       return name;
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandWithinDistinctRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandWithinDistinctRule.java
index 91cca97..0e6c621 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandWithinDistinctRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandWithinDistinctRule.java
@@ -31,6 +31,7 @@
 import org.apache.calcite.util.ImmutableBitSet;
 import org.apache.calcite.util.ImmutableIntList;
 import org.apache.calcite.util.Util;
+import org.apache.calcite.util.mapping.IntPair;
 
 import com.google.common.collect.ArrayListMultimap;
 import com.google.common.collect.ImmutableList;
@@ -113,8 +114,6 @@
         // Wait until AggregateReduceFunctionsRule has dealt with AVG etc.
         && aggregate.getAggCallList().stream()
            .noneMatch(CoreRules.AGGREGATE_REDUCE_FUNCTIONS::canReduce)
-        // Don't know that we can handle FILTER yet
-        && aggregate.getAggCallList().stream().noneMatch(c -> c.filterArg >= 0)
         // Don't think we can handle GROUPING SETS yet
         && aggregate.getGroupType() == Aggregate.Group.SIMPLE;
   }
@@ -132,7 +131,7 @@
     //
     // or in algebra,
     //
-    //   Aggregate($0, SUM($2), SUM($3) WITHIN DISTINCT ($4))
+    //   Aggregate($0, SUM($2), SUM($2) WITHIN DISTINCT ($4))
     //     Scan(emp)
     //
     // We plan to generate the following:
@@ -154,8 +153,6 @@
     //   SUM(sal) WITHIN DISTINCT (sal)
     //
 
-    // TODO: handle "agg(x) filter (b)"
-
     final List<AggregateCall> aggCallList =
         aggregate.getAggCallList()
             .stream()
@@ -179,27 +176,31 @@
           //   sum(x) within distinct (y, z) ... group by y
           // can be simplified to
           //   sum(x) within distinct (z) ... group by y
+          // Note that this assumes a single grouping set for the original agg.
           distinctKeys = distinctKeys.rebuild()
               .removeAll(aggregate.getGroupSet()).build();
         }
       }
       argLists.put(distinctKeys, aggCall);
-      assert aggCall.filterArg < 0;
     }
 
+    // Compute the set of all grouping sets that will be used in the output
+    // query. For each WITHIN DISTINCT aggregate call, we will need a grouping
+    // set that is the union of the aggregate call's unique keys and the input
+    // query's overall grouping. Redundant grouping sets can be reused for
+    // multiple aggregate calls.
     final Set<ImmutableBitSet> groupSetTreeSet =
         new TreeSet<>(ImmutableBitSet.ORDERING);
-    groupSetTreeSet.add(aggregate.getGroupSet());
     for (ImmutableBitSet key : argLists.keySet()) {
-      if (key == notDistinct) {
-        continue;
-      }
       groupSetTreeSet.add(
-          ImmutableBitSet.of(key).union(aggregate.getGroupSet()));
+          (key == notDistinct)
+              ? aggregate.getGroupSet()
+              : ImmutableBitSet.of(key).union(aggregate.getGroupSet()));
     }
 
     final ImmutableList<ImmutableBitSet> groupSets =
         ImmutableList.copyOf(groupSetTreeSet);
+    final boolean hasMultipleGroupSets = groupSets.size() > 1;
     final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets);
     final Set<Integer> fullGroupOrderedSet = new LinkedHashSet<>();
     fullGroupOrderedSet.addAll(aggregate.getGroupSet().asSet());
@@ -216,40 +217,89 @@
     //
     // or in algebra,
     //
-    //   Aggregate([($0), ($0, $2)], SUM($2), MIN($2), MAX($2), GROUPING($0, $4))
+    //   Aggregate([($0), ($0, $4)], SUM($2), MIN($2), MAX($2), GROUPING($0, $4))
     //     Scan(emp)
 
     final RelBuilder b = call.builder();
     b.push(aggregate.getInput());
     final List<RelBuilder.AggCall> aggCalls = new ArrayList<>();
 
+    // Helper class for building the inner query.
     // CHECKSTYLE: IGNORE 1
     class Registrar {
       final int g = fullGroupSet.cardinality();
-      final Map<Integer, Integer> args = new HashMap<>();
+      /** Map of input fields (below the original aggregation) and filter args
+       * to inner query aggregate calls. */
+      final Map<IntPair, Integer> args = new HashMap<>();
+      /** Map of aggregate calls from the original aggregation to inner query
+       * aggregate calls. */
       final Map<Integer, Integer> aggs = new HashMap<>();
+      /** Map of aggregate calls from the original aggregation to inner-query
+       * {@code COUNT(*)} calls, which are only needed for filters in the outer
+       * aggregate when the original aggregate call does not ignore null
+       * inputs. */
+      final Map<Integer, Integer> counts = new HashMap<>();
 
-      List<Integer> fields(List<Integer> fields) {
-        return Util.transform(fields, this::field);
+      List<Integer> fields(List<Integer> fields, int filterArg) {
+        return Util.transform(fields, f -> this.field(f, filterArg));
       }
 
-      int field(int field) {
-        return Objects.requireNonNull(args.get(field));
+      int field(int field, int filterArg) {
+        return Objects.requireNonNull(args.get(IntPair.of(field, filterArg)));
       }
 
-      int register(int field) {
-        return args.computeIfAbsent(field, j -> {
+      /** Computes an aggregate call argument's values for a
+       * {@code WITHIN DISTINCT} aggregate call.
+       *
+       * <p>For example, to compute
+       * {@code SUM(x) WITHIN DISTINCT (y) GROUP BY (z)},
+       * the inner aggregate must first group {@code x} by {@code (y, z)}
+       * &mdash; using {@code MIN} to select the (hopefully) unique value of
+       * {@code x} for each {@code (y, z)} group. Actually summing over the
+       * grouped {@code x} values must occur in an outer aggregate.
+       *
+       * @param field Index of an input field that's used in a
+       *         {@code WITHIN DISTINCT} aggregate call
+       * @param filterArg Filter arg used in the original aggregate call, or
+       *         {@code -1} if there is no filter. We use the same filter in
+       *         the inner query.
+       * @return Index of the inner query aggregate call representing the
+       *         grouped field, which can be referenced in the outer query
+       *         aggregate call
+       */
+      int register(int field, int filterArg) {
+        return args.computeIfAbsent(IntPair.of(field, filterArg), j -> {
           final int ordinal = g + aggCalls.size();
+          RelBuilder.AggCall groupedField =
+              b.aggregateCall(SqlStdOperatorTable.MIN, b.field(field));
           aggCalls.add(
-              b.aggregateCall(SqlStdOperatorTable.MIN, b.field(j)));
+              filterArg < 0
+                  ? groupedField
+                  : groupedField.filter(b.field(filterArg)));
           if (config.throwIfNotUnique()) {
+            groupedField =
+                b.aggregateCall(SqlStdOperatorTable.MAX, b.field(field));
             aggCalls.add(
-                b.aggregateCall(SqlStdOperatorTable.MAX, b.field(j)));
+                filterArg < 0
+                    ? groupedField
+                    : groupedField.filter(b.field(filterArg)));
           }
           return ordinal;
         });
       }
 
+      /** Registers an aggregate call that is <em>not</em> a
+       * {@code WITHIN DISTINCT} call.
+       *
+       * <p>Unlike the case handled by {@link #register(int, int)} above,
+       * aggregate calls without any distinct keys do not need a second round
+       * of aggregation in the outer query, so they can be computed "as-is" in
+       * the inner query.
+       *
+       * @param i Index of the aggregate call in the original aggregation
+       * @param aggregateCall Original aggregate call
+       * @return Index of the aggregate call in the computed inner query
+       */
       int registerAgg(int i, RelBuilder.AggCall aggregateCall) {
         final int ordinal = g + aggCalls.size();
         aggs.put(i, ordinal);
@@ -260,6 +310,33 @@
       int getAgg(int i) {
         return Objects.requireNonNull(aggs.get(i));
       }
+
+      /** Registers an extra {@code COUNT} aggregate call when it's needed to
+       * filter out null inputs in the outer aggregate.
+       *
+       * <p>This should only be called for aggregate calls with filters. It's
+       * possible that the filter would eliminate all input rows to the
+       * {@code MIN} call in the inner query, so calls in the outer
+       * aggregate may need to be aware of this. See usage of
+       * {@link AggregateExpandWithinDistinctRule#mustBeCounted(AggregateCall)}.
+       *
+       * @param filterArg The original aggregate call's filter; must be
+       *                 non-negative
+       * @return Index of the {@code COUNT} call in the computed inner query
+       */
+      int registerCount(int filterArg) {
+        assert filterArg >= 0;
+        return counts.computeIfAbsent(filterArg, i -> {
+          final int ordinal = g + aggCalls.size();
+          aggCalls.add(b.aggregateCall(SqlStdOperatorTable.COUNT)
+              .filter(b.field(filterArg)));
+          return ordinal;
+        });
+      }
+
+      int getCount(int filterArg) {
+        return Objects.requireNonNull(counts.get(filterArg));
+      }
     }
 
     final Registrar registrar = new Registrar();
@@ -269,13 +346,25 @@
             b.aggregateCall(c.getAggregation(),
                 b.fields(c.getArgList())));
       } else {
-        c.getArgList().forEach(registrar::register);
+        for (int inputIdx : c.getArgList()) {
+          registrar.register(inputIdx, c.filterArg);
+        }
+        if (mustBeCounted(c)) {
+          registrar.registerCount(c.filterArg);
+        }
       }
     });
+    // Add an additional GROUPING() aggregate call so we can select only the
+    // relevant inner-aggregate rows from the outer aggregate. If there is only
+    // 1 grouping set (i.e. every aggregate call has the same distinct keys),
+    // no GROUPING() call is necessary.
     final int grouping =
-        registrar.registerAgg(-1,
-            b.aggregateCall(SqlStdOperatorTable.GROUPING,
-                b.fields(fullGroupList)));
+        hasMultipleGroupSets
+            ? registrar.registerAgg(-1,
+                b.aggregateCall(
+                    SqlStdOperatorTable.GROUPING,
+                    b.fields(fullGroupList)))
+            : -1;
     b.aggregate(
         b.groupKey(fullGroupSet,
             (Iterable<ImmutableBitSet>) groupSets), aggCalls);
@@ -304,32 +393,56 @@
     aggCalls.clear();
     Ord.forEach(aggCallList, (c, i) -> {
       final List<RexNode> filters = new ArrayList<>();
-      final RexNode groupFilter = b.equals(b.field(grouping),
-          b.literal(
-              groupValue(fullGroupList,
-                  union(aggregate.getGroupSet(), c.distinctKeys))));
-      filters.add(groupFilter);
-      final RelBuilder.AggCall aggCall;
+      RexNode groupFilter = null;
+      if (hasMultipleGroupSets) {
+        groupFilter =
+            b.equals(
+                b.field(grouping),
+                b.literal(
+                    groupValue(fullGroupList, union(aggregate.getGroupSet(), c.distinctKeys))));
+        filters.add(groupFilter);
+      }
+      RelBuilder.AggCall aggCall;
       if (c.distinctKeys == null) {
         aggCall = b.aggregateCall(SqlStdOperatorTable.MIN,
             b.field(registrar.getAgg(i)));
       } else {
-        aggCall = b.aggregateCall(c.getAggregation(),
-            b.fields(registrar.fields(c.getArgList())));
+        // The inputs to this aggregate are outputs from MIN() calls from the
+        // inner agg, and MIN() returns null iff it has no non-null inputs,
+        // which can only happen if an original aggregate's filter causes all
+        // non-null input rows to be discarded for a particular group in the
+        // inner aggregate. In this case, it should be ignored by the outer
+        // aggregate as well. In case the aggregate call does not naturally
+        // ignore null inputs, we add a filter based on a COUNT() in the inner
+        // aggregate.
+        aggCall =
+            b.aggregateCall(
+                c.getAggregation(),
+                b.fields(registrar.fields(c.getArgList(), c.filterArg)));
+
+        if (mustBeCounted(c)) {
+          filters.add(b.greaterThan(b.field(registrar.getCount(c.filterArg)), b.literal(0)));
+        }
 
         if (config.throwIfNotUnique()) {
           for (int j : c.getArgList()) {
+            RexNode isUniqueCondition =
+                b.isNotDistinctFrom(
+                    b.field(registrar.field(j, c.filterArg)),
+                    b.field(registrar.field(j, c.filterArg) + 1));
+            if (groupFilter != null) {
+              isUniqueCondition = b.or(b.not(groupFilter), isUniqueCondition);
+            }
             String message = "more than one distinct value in agg UNIQUE_VALUE";
             filters.add(
-                b.call(SqlInternalOperators.THROW_UNLESS,
-                    b.or(b.not(groupFilter),
-                        b.isNotDistinctFrom(b.field(registrar.field(j)),
-                            b.field(registrar.field(j) + 1))),
-                    b.literal(message)));
+                b.call(SqlInternalOperators.THROW_UNLESS, isUniqueCondition, b.literal(message)));
           }
         }
       }
-      aggCalls.add(aggCall.filter(b.and(filters)));
+      if (filters.size() > 0) {
+        aggCall = aggCall.filter(b.and(filters));
+      }
+      aggCalls.add(aggCall);
     });
 
     b.aggregate(
@@ -342,6 +455,22 @@
     call.transformTo(b.build());
   }
 
+  private static boolean mustBeCounted(AggregateCall aggCall) {
+    // Always count filtered inner aggregates to be safe.
+    //
+    // It's possible that, for some aggregate calls (namely, those that
+    // completely ignore null inputs), we could neglect counting the
+    // grouped-and-filtered rows of the inner aggregate and filtering the empty
+    // ones out from the outer aggregate, since those empty groups would produce
+    // null values as the result of MIN and thus be ignored by the outer
+    // aggregate anyway.
+    //
+    // Note that using "aggCall.ignoreNulls()" is not sufficient to determine
+    // when it's safe to do this, since for COUNT the value of ignoreNulls()
+    // should generally be true even though COUNT(*) will never ignore anything.
+    return aggCall.hasFilter();
+  }
+
   /** Converts a {@code DISTINCT} aggregate call into an equivalent one with
    * {@code WITHIN DISTINCT}.
    *
@@ -362,6 +491,7 @@
           .stream()
           .filter(i ->
               aggregateCall.getAggregation().getKind() != SqlKind.COUNT
+                  || aggregateCall.hasFilter()
                   || isNullable.test(i))
           .collect(Collectors.toList());
       return aggregateCall.withDistinct(false)
diff --git a/core/src/main/java/org/apache/calcite/rex/RexSimplify.java b/core/src/main/java/org/apache/calcite/rex/RexSimplify.java
index 8629b58..96338ab 100644
--- a/core/src/main/java/org/apache/calcite/rex/RexSimplify.java
+++ b/core/src/main/java/org/apache/calcite/rex/RexSimplify.java
@@ -400,7 +400,7 @@
    * Try to find a literal with the given value in the input list.
    * The type of the literal must be one of the numeric types.
    */
-  private int findLiteralIndex(List<RexNode> operands, BigDecimal value) {
+  private static int findLiteralIndex(List<RexNode> operands, BigDecimal value) {
     for (int i = 0; i < operands.size(); i++) {
       if (operands.get(i).isA(SqlKind.LITERAL)) {
         Comparable comparable = ((RexLiteral) operands.get(i)).getValue();
diff --git a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
index 40268b4..6fbdde0 100644
--- a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
+++ b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
@@ -731,6 +731,11 @@
     return call(SqlStdOperatorTable.EQUALS, operand0, operand1);
   }
 
+  /** Creates a {@code >}. */
+  public RexNode greaterThan(RexNode operand0, RexNode operand1) {
+    return call(SqlStdOperatorTable.GREATER_THAN, operand0, operand1);
+  }
+
   /** Creates a {@code <>}. */
   public RexNode notEquals(RexNode operand0, RexNode operand1) {
     return call(SqlStdOperatorTable.NOT_EQUALS, operand0, operand1);
@@ -3543,13 +3548,12 @@
       if (distinct) {
         b.append("DISTINCT ");
       }
-      final int iMax = operands.size() - 1;
-      for (int i = 0; ; i++) {
-        b.append(operands.get(i));
-        if (i == iMax) {
-          break;
+      if (operands.size() > 0) {
+        b.append(operands.get(0));
+        for (int i = 1; i < operands.size(); i++) {
+          b.append(", ");
+          b.append(operands.get(i));
         }
-        b.append(", ");
       }
       b.append(')');
       if (filter != null) {
diff --git a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
index 9656adb..4cbc11e 100644
--- a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
@@ -330,6 +330,22 @@
     assertThat(root, hasTree(expected));
   }
 
+  @Test void testScanFilterGreaterThan() {
+    // Equivalent SQL:
+    //   SELECT *
+    //   FROM emp
+    //   WHERE deptno > 20
+    final RelBuilder builder = RelBuilder.create(config().build());
+    RelNode root =
+        builder.scan("EMP")
+            .filter(
+                builder.greaterThan(builder.field("DEPTNO"), builder.literal(20)))
+            .build();
+    final String expected = "LogicalFilter(condition=[>($7, 20)])\n"
+        + "  LogicalTableScan(table=[[scott, EMP]])\n";
+    assertThat(root, hasTree(expected));
+  }
+
   @Test void testSnapshotTemporalTable() {
     // Equivalent SQL:
     //   SELECT *
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index 809289b..93e2893 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -1461,14 +1461,49 @@
     sql(sql).with(program).check();
   }
 
+  /** Tests {@link AggregateExpandWithinDistinctRule}. If all aggregate calls
+   * have the same distinct keys, there is no need for multiple grouping
+   * sets. */
+  @Test void testWithinDistinctUniformDistinctKeys() {
+    final String sql = "SELECT deptno,\n"
+        + " SUM(sal) WITHIN DISTINCT (job),\n"
+        + " AVG(comm) WITHIN DISTINCT (job)\n"
+        + "FROM emp\n"
+        + "GROUP BY deptno";
+    HepProgram program = new HepProgramBuilder()
+        .addRuleInstance(CoreRules.AGGREGATE_REDUCE_FUNCTIONS)
+        .addRuleInstance(CoreRules.AGGREGATE_EXPAND_WITHIN_DISTINCT)
+        .build();
+    sql(sql).with(program).check();
+  }
+
+  /** Tests {@link AggregateExpandWithinDistinctRule}. If all aggregate calls
+   * have the same distinct keys, and we're not checking for true uniqueness,
+   * there is no need for filtering in the outer aggregate. */
+  @Test void testWithinDistinctUniformDistinctKeysNoThrow() {
+    final String sql = "SELECT deptno,\n"
+        + " SUM(sal) WITHIN DISTINCT (job),\n"
+        + " AVG(comm) WITHIN DISTINCT (job)\n"
+        + "FROM emp\n"
+        + "GROUP BY deptno";
+    HepProgram program = new HepProgramBuilder()
+        .addRuleInstance(CoreRules.AGGREGATE_REDUCE_FUNCTIONS)
+        .addRuleInstance(
+            CoreRules.AGGREGATE_EXPAND_WITHIN_DISTINCT.config
+                .withThrowIfNotUnique(false).toRule())
+        .build();
+    sql(sql).with(program).check();
+  }
+
   /** Tests that {@link AggregateExpandWithinDistinctRule} treats
    * "COUNT(DISTINCT x)" as if it were "COUNT(x) WITHIN DISTINCT (x)". */
   @Test void testWithinDistinctCountDistinct() {
     final String sql = "SELECT deptno,\n"
-        + "  SUM(sal) WITHIN DISTINCT (job) AS ss_j,\n"
+        + "  SUM(sal) WITHIN DISTINCT (comm) AS ss_c,\n"
         + "  COUNT(DISTINCT job) cdj,\n"
         + "  COUNT(job) WITHIN DISTINCT (job) AS cj_j,\n"
-        + "  COUNT(DISTINCT job) WITHIN DISTINCT (job) AS cdj_j\n"
+        + "  COUNT(DISTINCT job) WITHIN DISTINCT (job) AS cdj_j,\n"
+        + "  COUNT(DISTINCT job) FILTER (WHERE sal > 1000) AS cdj_filtered\n"
         + "FROM emp\n"
         + "GROUP BY deptno";
     HepProgram program = new HepProgramBuilder()
@@ -1479,6 +1514,78 @@
     sql(sql).with(program).check();
   }
 
+  /** Test case for
+   * <a href="https://issues.apache.org/jira/browse/CALCITE-4726">[CALCITE-4726]
+   * Support aggregate calls with a FILTER clause in
+   * AggregateExpandWithinDistinctRule</a>.
+   *
+   * <p>Tests {@link AggregateExpandWithinDistinctRule} with different
+   * distinct keys and different filters for each aggregate call. */
+  @Test void testWithinDistinctFilteredAggs() {
+    final String sql = "SELECT deptno,\n"
+        + " SUM(sal) WITHIN DISTINCT (job) FILTER (WHERE comm > 10),\n"
+        + " AVG(comm) WITHIN DISTINCT (sal) FILTER (WHERE ename LIKE '%ok%')\n"
+        + "FROM emp\n"
+        + "GROUP BY deptno";
+    HepProgram program = new HepProgramBuilder()
+        .addRuleInstance(CoreRules.AGGREGATE_REDUCE_FUNCTIONS)
+        .addRuleInstance(CoreRules.AGGREGATE_EXPAND_WITHIN_DISTINCT)
+        .build();
+    sql(sql).with(program).check();
+  }
+
+  /** Tests {@link AggregateExpandWithinDistinctRule}. Includes multiple
+   * different filters for the aggregate calls, and all aggregate calls have the
+   * same distinct keys, so there is no need to filter based on
+   * {@code GROUPING()}. */
+  @Test void testWithinDistinctFilteredAggsUniformDistinctKeys() {
+    final String sql = "SELECT deptno,\n"
+        + " SUM(sal) WITHIN DISTINCT (job) FILTER (WHERE comm > 10),\n"
+        + " AVG(comm) WITHIN DISTINCT (job) FILTER (WHERE ename LIKE '%ok%')\n"
+        + "FROM emp\n"
+        + "GROUP BY deptno";
+    HepProgram program = new HepProgramBuilder()
+        .addRuleInstance(CoreRules.AGGREGATE_REDUCE_FUNCTIONS)
+        .addRuleInstance(CoreRules.AGGREGATE_EXPAND_WITHIN_DISTINCT)
+        .build();
+    sql(sql).with(program).check();
+  }
+
+  /** Tests {@link AggregateExpandWithinDistinctRule}. Includes multiple
+   * different filters for the aggregate calls, and all aggregate calls have the
+   * same distinct keys, so there is no need to filter based on
+   * {@code GROUPING()}. Does <em>not</em> throw if not unique. */
+  @Test void testWithinDistinctFilteredAggsUniformDistinctKeysNoThrow() {
+    final String sql = "SELECT deptno,\n"
+        + " SUM(sal) WITHIN DISTINCT (job) FILTER (WHERE comm > 10),\n"
+        + " AVG(comm) WITHIN DISTINCT (job) FILTER (WHERE ename LIKE '%ok%')\n"
+        + "FROM emp\n"
+        + "GROUP BY deptno";
+    HepProgram program = new HepProgramBuilder()
+        .addRuleInstance(CoreRules.AGGREGATE_REDUCE_FUNCTIONS)
+        .addRuleInstance(
+            CoreRules.AGGREGATE_EXPAND_WITHIN_DISTINCT.config
+                .withThrowIfNotUnique(false).toRule())
+        .build();
+    sql(sql).with(program).check();
+  }
+
+  /** Tests {@link AggregateExpandWithinDistinctRule}. Includes multiple
+   * identical filters for the aggregate calls. The filters should be
+   * re-used. */
+  @Test void testWithinDistinctFilteredAggsSameFilter() {
+    final String sql = "SELECT deptno,\n"
+        + " SUM(sal) WITHIN DISTINCT (job) FILTER (WHERE ename LIKE '%ok%'),\n"
+        + " AVG(comm) WITHIN DISTINCT (sal) FILTER (WHERE ename LIKE '%ok%')\n"
+        + "FROM emp\n"
+        + "GROUP BY deptno";
+    HepProgram program = new HepProgramBuilder()
+        .addRuleInstance(CoreRules.AGGREGATE_REDUCE_FUNCTIONS)
+        .addRuleInstance(CoreRules.AGGREGATE_EXPAND_WITHIN_DISTINCT)
+        .build();
+    sql(sql).with(program).check();
+  }
+
   @Test void testPushProjectPastFilter() {
     final String sql = "select empno + deptno from emp where sal = 10 * comm\n"
         + "and upper(ename) = 'FOO'";
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index 3ed9020..a99fa46 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -13500,31 +13500,140 @@
   <TestCase name="testWithinDistinctCountDistinct">
     <Resource name="sql">
       <![CDATA[SELECT deptno,
-  SUM(sal) WITHIN DISTINCT (job) AS ss_j,
+  SUM(sal) WITHIN DISTINCT (comm) AS ss_c,
   COUNT(DISTINCT job) cdj,
   COUNT(job) WITHIN DISTINCT (job) AS cj_j,
-  COUNT(DISTINCT job) WITHIN DISTINCT (job) AS cdj_j
+  COUNT(DISTINCT job) WITHIN DISTINCT (job) AS cdj_j,
+  COUNT(DISTINCT job) FILTER (WHERE sal > 1000) AS cdj_filtered
 FROM emp
 GROUP BY deptno]]>
     </Resource>
     <Resource name="planBefore">
       <![CDATA[
-LogicalAggregate(group=[{0}], SS_J=[SUM($1) WITHIN DISTINCT ($2)], CDJ=[COUNT(DISTINCT $2)], CJ_J=[COUNT() WITHIN DISTINCT ($2)], CDJ_J=[COUNT(DISTINCT $2) WITHIN DISTINCT ($2)])
-  LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2])
+LogicalAggregate(group=[{0}], SS_C=[SUM($1) WITHIN DISTINCT ($2)], CDJ=[COUNT(DISTINCT $3)], CJ_J=[COUNT() WITHIN DISTINCT ($3)], CDJ_J=[COUNT(DISTINCT $3) WITHIN DISTINCT ($3)], CDJ_FILTERED=[COUNT(DISTINCT $3) FILTER $4])
+  LogicalProject(DEPTNO=[$7], SAL=[$5], COMM=[$6], JOB=[$2], $f4=[>($5, 1000)])
     LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
     </Resource>
     <Resource name="planAfter">
       <![CDATA[
-LogicalProject(DEPTNO=[$0], $f1=[$1], $f2=[$2], $f20=[$2], $f21=[$2])
-  LogicalAggregate(group=[{0}], agg#0=[$SUM0($1) FILTER $2], agg#1=[COUNT() FILTER $2])
-    LogicalProject(DEPTNO=[$0], $f2=[$2], $f4=[=($3, 0)])
-      LogicalAggregate(group=[{0, 2}], groups=[[{0, 2}, {0}]], agg#0=[MIN($1)], agg#1=[GROUPING($0, $2)])
-        LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2])
+LogicalProject(DEPTNO=[$0], $f1=[$1], $f2=[$2], $f20=[$2], $f21=[$2], $f3=[$3])
+  LogicalAggregate(group=[{0}], agg#0=[$SUM0($1) FILTER $3], agg#1=[COUNT() FILTER $4], agg#2=[COUNT($2) FILTER $5])
+    LogicalProject(DEPTNO=[$0], $f3=[$3], $f4=[$4], $f7=[=($6, 1)], $f8=[=($6, 2)], $f9=[AND(=($6, 2), >($5, 0))])
+      LogicalAggregate(group=[{0, 2, 3}], groups=[[{0, 2}, {0, 3}]], agg#0=[MIN($1)], agg#1=[MIN($3) FILTER $4], agg#2=[COUNT() FILTER $4], agg#3=[GROUPING($0, $2, $3)])
+        LogicalProject(DEPTNO=[$7], SAL=[$5], COMM=[$6], JOB=[$2], $f4=[>($5, 1000)])
           LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
     </Resource>
   </TestCase>
+  <TestCase name="testWithinDistinctFilteredAggs">
+    <Resource name="sql">
+      <![CDATA[SELECT deptno,
+ SUM(sal) WITHIN DISTINCT (job) FILTER (WHERE comm > 10),
+ AVG(comm) WITHIN DISTINCT (sal) FILTER (WHERE ename LIKE '%ok%')
+FROM emp
+GROUP BY deptno]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[SUM($1) WITHIN DISTINCT ($3) FILTER $2], EXPR$2=[AVG($4) WITHIN DISTINCT ($1) FILTER $5])
+  LogicalProject(DEPTNO=[$7], SAL=[$5], $f2=[>($6, 10)], JOB=[$2], COMM=[$6], $f5=[LIKE($1, '%ok%')])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+    <Resource name="planAfter">
+      <![CDATA[
+LogicalProject(DEPTNO=[$0], EXPR$1=[CASE(=($2, 0), null:INTEGER, $1)], EXPR$2=[CAST(/($3, $4)):INTEGER])
+  LogicalProject(DEPTNO=[$0], EXPR$1=[$1], $f2=[$2], $f3=[CASE(=($4, 0), null:INTEGER, $3)], $f4=[$4])
+    LogicalAggregate(group=[{0}], agg#0=[$SUM0($1) FILTER $3], agg#1=[COUNT() FILTER $4], agg#2=[$SUM0($2) FILTER $5], agg#3=[COUNT() FILTER $6])
+      LogicalProject(DEPTNO=[$0], $f3=[$3], $f6=[$6], $f10=[AND(=($9, 2), >($5, 0), $THROW_UNLESS(OR(<>($9, 2), AND(IS NULL($3), IS NULL($4)), IS TRUE(=($3, $4))), 'more than one distinct value in agg UNIQUE_VALUE'))], $f11=[AND(=($9, 2), >($5, 0))], $f12=[AND(=($9, 1), >($8, 0), $THROW_UNLESS(OR(<>($9, 1), AND(IS NULL($6), IS NULL($7)), IS TRUE(=($6, $7))), 'more than one distinct value in agg UNIQUE_VALUE'))], $f13=[AND(=($9, 1), >($8, 0))])
+        LogicalAggregate(group=[{0, 1, 3}], groups=[[{0, 1}, {0, 3}]], agg#0=[MIN($1) FILTER $2], agg#1=[MAX($1) FILTER $2], agg#2=[COUNT() FILTER $2], agg#3=[MIN($4) FILTER $5], agg#4=[MAX($4) FILTER $5], agg#5=[COUNT() FILTER $5], agg#6=[GROUPING($0, $1, $3)])
+          LogicalProject(DEPTNO=[$7], SAL=[$5], $f2=[>($6, 10)], JOB=[$2], COMM=[$6], $f5=[LIKE($1, '%ok%')])
+            LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testWithinDistinctFilteredAggsSameFilter">
+    <Resource name="sql">
+      <![CDATA[SELECT deptno,
+ SUM(sal) WITHIN DISTINCT (job) FILTER (WHERE ename LIKE '%ok%'),
+ AVG(comm) WITHIN DISTINCT (sal) FILTER (WHERE ename LIKE '%ok%')
+FROM emp
+GROUP BY deptno]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[SUM($1) WITHIN DISTINCT ($3) FILTER $2], EXPR$2=[AVG($4) WITHIN DISTINCT ($1) FILTER $2])
+  LogicalProject(DEPTNO=[$7], SAL=[$5], $f2=[LIKE($1, '%ok%')], JOB=[$2], COMM=[$6])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+    <Resource name="planAfter">
+      <![CDATA[
+LogicalProject(DEPTNO=[$0], EXPR$1=[CASE(=($2, 0), null:INTEGER, $1)], EXPR$2=[CAST(/($3, $4)):INTEGER])
+  LogicalProject(DEPTNO=[$0], EXPR$1=[$1], $f2=[$2], $f3=[CASE(=($4, 0), null:INTEGER, $3)], $f4=[$4])
+    LogicalAggregate(group=[{0}], agg#0=[$SUM0($1) FILTER $3], agg#1=[COUNT() FILTER $4], agg#2=[$SUM0($2) FILTER $5], agg#3=[COUNT() FILTER $6])
+      LogicalProject(DEPTNO=[$0], $f3=[$3], $f6=[$6], $f9=[AND(=($8, 2), >($5, 0), $THROW_UNLESS(OR(<>($8, 2), AND(IS NULL($3), IS NULL($4)), IS TRUE(=($3, $4))), 'more than one distinct value in agg UNIQUE_VALUE'))], $f10=[AND(=($8, 2), >($5, 0))], $f11=[AND(=($8, 1), >($5, 0), $THROW_UNLESS(OR(<>($8, 1), AND(IS NULL($6), IS NULL($7)), IS TRUE(=($6, $7))), 'more than one distinct value in agg UNIQUE_VALUE'))], $f12=[AND(=($8, 1), >($5, 0))])
+        LogicalAggregate(group=[{0, 1, 3}], groups=[[{0, 1}, {0, 3}]], agg#0=[MIN($1) FILTER $2], agg#1=[MAX($1) FILTER $2], agg#2=[COUNT() FILTER $2], agg#3=[MIN($4) FILTER $2], agg#4=[MAX($4) FILTER $2], agg#5=[GROUPING($0, $1, $3)])
+          LogicalProject(DEPTNO=[$7], SAL=[$5], $f2=[LIKE($1, '%ok%')], JOB=[$2], COMM=[$6])
+            LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testWithinDistinctFilteredAggsUniformDistinctKeys">
+    <Resource name="sql">
+      <![CDATA[SELECT deptno,
+ SUM(sal) WITHIN DISTINCT (job) FILTER (WHERE comm > 10),
+ AVG(comm) WITHIN DISTINCT (job) FILTER (WHERE ename LIKE '%ok%')
+FROM emp
+GROUP BY deptno]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[SUM($1) WITHIN DISTINCT ($3) FILTER $2], EXPR$2=[AVG($4) WITHIN DISTINCT ($3) FILTER $5])
+  LogicalProject(DEPTNO=[$7], SAL=[$5], $f2=[>($6, 10)], JOB=[$2], COMM=[$6], $f5=[LIKE($1, '%ok%')])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+    <Resource name="planAfter">
+      <![CDATA[
+LogicalProject(DEPTNO=[$0], EXPR$1=[CASE(=($2, 0), null:INTEGER, $1)], EXPR$2=[CAST(/($3, $4)):INTEGER])
+  LogicalProject(DEPTNO=[$0], EXPR$1=[$1], $f2=[$2], $f3=[CASE(=($4, 0), null:INTEGER, $3)], $f4=[$4])
+    LogicalAggregate(group=[{0}], agg#0=[$SUM0($1) FILTER $3], agg#1=[COUNT() FILTER $4], agg#2=[$SUM0($2) FILTER $5], agg#3=[COUNT() FILTER $6])
+      LogicalProject(DEPTNO=[$0], $f2=[$2], $f5=[$5], $f8=[AND(>($4, 0), $THROW_UNLESS(OR(AND(IS NULL($2), IS NULL($3)), IS TRUE(=($2, $3))), 'more than one distinct value in agg UNIQUE_VALUE'))], $f9=[>($4, 0)], $f10=[AND(>($7, 0), $THROW_UNLESS(OR(AND(IS NULL($5), IS NULL($6)), IS TRUE(=($5, $6))), 'more than one distinct value in agg UNIQUE_VALUE'))], $f11=[>($7, 0)])
+        LogicalAggregate(group=[{0, 3}], agg#0=[MIN($1) FILTER $2], agg#1=[MAX($1) FILTER $2], agg#2=[COUNT() FILTER $2], agg#3=[MIN($4) FILTER $5], agg#4=[MAX($4) FILTER $5], agg#5=[COUNT() FILTER $5])
+          LogicalProject(DEPTNO=[$7], SAL=[$5], $f2=[>($6, 10)], JOB=[$2], COMM=[$6], $f5=[LIKE($1, '%ok%')])
+            LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testWithinDistinctFilteredAggsUniformDistinctKeysNoThrow">
+    <Resource name="sql">
+      <![CDATA[SELECT deptno,
+ SUM(sal) WITHIN DISTINCT (job) FILTER (WHERE comm > 10),
+ AVG(comm) WITHIN DISTINCT (job) FILTER (WHERE ename LIKE '%ok%')
+FROM emp
+GROUP BY deptno]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[SUM($1) WITHIN DISTINCT ($3) FILTER $2], EXPR$2=[AVG($4) WITHIN DISTINCT ($3) FILTER $5])
+  LogicalProject(DEPTNO=[$7], SAL=[$5], $f2=[>($6, 10)], JOB=[$2], COMM=[$6], $f5=[LIKE($1, '%ok%')])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+    <Resource name="planAfter">
+      <![CDATA[
+LogicalProject(DEPTNO=[$0], EXPR$1=[CASE(=($2, 0), null:INTEGER, $1)], EXPR$2=[CAST(/($3, $4)):INTEGER])
+  LogicalProject(DEPTNO=[$0], EXPR$1=[$1], $f2=[$2], $f3=[CASE(=($4, 0), null:INTEGER, $3)], $f4=[$4])
+    LogicalAggregate(group=[{0}], agg#0=[$SUM0($1) FILTER $3], agg#1=[COUNT() FILTER $3], agg#2=[$SUM0($2) FILTER $4], agg#3=[COUNT() FILTER $4])
+      LogicalProject(DEPTNO=[$0], $f2=[$2], $f4=[$4], $f6=[>($3, 0)], $f7=[>($5, 0)])
+        LogicalAggregate(group=[{0, 3}], agg#0=[MIN($1) FILTER $2], agg#1=[COUNT() FILTER $2], agg#2=[MIN($4) FILTER $5], agg#3=[COUNT() FILTER $5])
+          LogicalProject(DEPTNO=[$7], SAL=[$5], $f2=[>($6, 10)], JOB=[$2], COMM=[$6], $f5=[LIKE($1, '%ok%')])
+            LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+  </TestCase>
   <TestCase name="testWithinDistinctNoThrow">
     <Resource name="sql">
       <![CDATA[SELECT deptno, SUM(sal), SUM(sal) WITHIN DISTINCT (job)
@@ -13549,4 +13658,55 @@
 ]]>
     </Resource>
   </TestCase>
+  <TestCase name="testWithinDistinctUniformDistinctKeys">
+    <Resource name="sql">
+      <![CDATA[SELECT deptno,
+ SUM(sal) WITHIN DISTINCT (job),
+ AVG(comm) WITHIN DISTINCT (job)
+FROM emp
+GROUP BY deptno]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[SUM($1) WITHIN DISTINCT ($2)], EXPR$2=[AVG($3) WITHIN DISTINCT ($2)])
+  LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2], COMM=[$6])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+    <Resource name="planAfter">
+      <![CDATA[
+LogicalProject(DEPTNO=[$0], EXPR$1=[$1], EXPR$2=[CAST(/($2, $3)):INTEGER NOT NULL])
+  LogicalAggregate(group=[{0}], agg#0=[$SUM0($1) FILTER $3], agg#1=[$SUM0($2) FILTER $4], agg#2=[COUNT()])
+    LogicalProject(DEPTNO=[$0], $f2=[$2], $f4=[$4], $f6=[$THROW_UNLESS(=($2, $3), 'more than one distinct value in agg UNIQUE_VALUE')], $f7=[$THROW_UNLESS(=($4, $5), 'more than one distinct value in agg UNIQUE_VALUE')])
+      LogicalAggregate(group=[{0, 2}], agg#0=[MIN($1)], agg#1=[MAX($1)], agg#2=[MIN($3)], agg#3=[MAX($3)])
+        LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2], COMM=[$6])
+          LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testWithinDistinctUniformDistinctKeysNoThrow">
+    <Resource name="sql">
+      <![CDATA[SELECT deptno,
+ SUM(sal) WITHIN DISTINCT (job),
+ AVG(comm) WITHIN DISTINCT (job)
+FROM emp
+GROUP BY deptno]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[SUM($1) WITHIN DISTINCT ($2)], EXPR$2=[AVG($3) WITHIN DISTINCT ($2)])
+  LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2], COMM=[$6])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+    <Resource name="planAfter">
+      <![CDATA[
+LogicalProject(DEPTNO=[$0], EXPR$1=[$1], EXPR$2=[CAST(/($2, $3)):INTEGER NOT NULL])
+  LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], agg#1=[$SUM0($3)], agg#2=[COUNT()])
+    LogicalAggregate(group=[{0, 2}], agg#0=[MIN($1)], agg#1=[MIN($3)])
+      LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2], COMM=[$6])
+        LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+  </TestCase>
 </Root>
diff --git a/core/src/test/resources/sql/within-distinct.iq b/core/src/test/resources/sql/within-distinct.iq
index 2697208..ea8c5ec 100644
--- a/core/src/test/resources/sql/within-distinct.iq
+++ b/core/src/test/resources/sql/within-distinct.iq
@@ -891,4 +891,51 @@
 more than one distinct value in agg UNIQUE_VALUE
 !error
 
+# Since all of the people from WY are filtered out, make sure both "COUNT(*)"
+# and "AVG(age)" ignore that entire group. Also, filters can be used to
+# manufacture uniqueness within a distinct key set. Without filters on these
+# aggregate calls, the query would throw due to non-unique ages in each state.
+WITH FriendStates
+AS (SELECT * FROM (VALUES
+  ('Alice', 789, 'UT'),
+  ('Bob', 25, 'UT'),
+  ('Carlos', 25, 'UT'),
+  ('Dan', 12, 'UT'),
+  ('Erin', 567, 'WY'),
+  ('Frank', 456, 'WY')) AS FriendStates (name, age, state))
+SELECT AVG(age) WITHIN DISTINCT (state) FILTER (WHERE age < 100 AND age > 18) AS aa_s,
+  COUNT(*) WITHIN DISTINCT (state) FILTER (WHERE age < 100 AND age > 18) AS c_s
+FROM FriendStates;
++------+-----+
+| AA_S | C_S |
++------+-----+
+|   25 |   1 |
++------+-----+
+(1 row)
+
+!ok
+
+# Unlike the previous example with FriendStates, this one should count the null
+# age of 'Forest' in WY, however it should also be left out of the average
+# because it's null.
+WITH FriendStates
+AS (SELECT * FROM (VALUES
+  ('Alice', 789, 'UT'),
+  ('Bob', 25, 'UT'),
+  ('Carlos', 25, 'UT'),
+  ('Dan', 678, 'UT'),
+  ('Erin', 567, 'WY'),
+  ('Forest', NULL, 'WY')) AS FriendStates (name, age, state))
+SELECT AVG(age) WITHIN DISTINCT (state) FILTER (WHERE name LIKE '%o%') AS aa_s,
+  COUNT(*) WITHIN DISTINCT (state) FILTER (WHERE name LIKE '%o%') AS c_s
+FROM FriendStates;
++------+-----+
+| AA_S | C_S |
++------+-----+
+|   25 |   2 |
++------+-----+
+(1 row)
+
+!ok
+
 # End within-distinct.iq