[CALCITE-4726] Support aggregate calls with a FILTER clause in AggregateExpandWithinDistinctRule (Will Noble)
Close apache/calcite#2483
diff --git a/core/src/main/java/org/apache/calcite/jdbc/SimpleCalciteSchema.java b/core/src/main/java/org/apache/calcite/jdbc/SimpleCalciteSchema.java
index dc67c2a..9630bc4 100644
--- a/core/src/main/java/org/apache/calcite/jdbc/SimpleCalciteSchema.java
+++ b/core/src/main/java/org/apache/calcite/jdbc/SimpleCalciteSchema.java
@@ -76,7 +76,7 @@
return calciteSchema;
}
- private @Nullable String caseInsensitiveLookup(Set<String> candidates, String name) {
+ private static @Nullable String caseInsensitiveLookup(Set<String> candidates, String name) {
// Exact string lookup
if (candidates.contains(name)) {
return name;
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandWithinDistinctRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandWithinDistinctRule.java
index 91cca97..0e6c621 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandWithinDistinctRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandWithinDistinctRule.java
@@ -31,6 +31,7 @@
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.Util;
+import org.apache.calcite.util.mapping.IntPair;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableList;
@@ -113,8 +114,6 @@
// Wait until AggregateReduceFunctionsRule has dealt with AVG etc.
&& aggregate.getAggCallList().stream()
.noneMatch(CoreRules.AGGREGATE_REDUCE_FUNCTIONS::canReduce)
- // Don't know that we can handle FILTER yet
- && aggregate.getAggCallList().stream().noneMatch(c -> c.filterArg >= 0)
// Don't think we can handle GROUPING SETS yet
&& aggregate.getGroupType() == Aggregate.Group.SIMPLE;
}
@@ -132,7 +131,7 @@
//
// or in algebra,
//
- // Aggregate($0, SUM($2), SUM($3) WITHIN DISTINCT ($4))
+ // Aggregate($0, SUM($2), SUM($2) WITHIN DISTINCT ($4))
// Scan(emp)
//
// We plan to generate the following:
@@ -154,8 +153,6 @@
// SUM(sal) WITHIN DISTINCT (sal)
//
- // TODO: handle "agg(x) filter (b)"
-
final List<AggregateCall> aggCallList =
aggregate.getAggCallList()
.stream()
@@ -179,27 +176,31 @@
// sum(x) within distinct (y, z) ... group by y
// can be simplified to
// sum(x) within distinct (z) ... group by y
+ // Note that this assumes a single grouping set for the original agg.
distinctKeys = distinctKeys.rebuild()
.removeAll(aggregate.getGroupSet()).build();
}
}
argLists.put(distinctKeys, aggCall);
- assert aggCall.filterArg < 0;
}
+ // Compute the set of all grouping sets that will be used in the output
+ // query. For each WITHIN DISTINCT aggregate call, we will need a grouping
+ // set that is the union of the aggregate call's unique keys and the input
+ // query's overall grouping. Redundant grouping sets can be reused for
+ // multiple aggregate calls.
final Set<ImmutableBitSet> groupSetTreeSet =
new TreeSet<>(ImmutableBitSet.ORDERING);
- groupSetTreeSet.add(aggregate.getGroupSet());
for (ImmutableBitSet key : argLists.keySet()) {
- if (key == notDistinct) {
- continue;
- }
groupSetTreeSet.add(
- ImmutableBitSet.of(key).union(aggregate.getGroupSet()));
+ (key == notDistinct)
+ ? aggregate.getGroupSet()
+ : ImmutableBitSet.of(key).union(aggregate.getGroupSet()));
}
final ImmutableList<ImmutableBitSet> groupSets =
ImmutableList.copyOf(groupSetTreeSet);
+ final boolean hasMultipleGroupSets = groupSets.size() > 1;
final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets);
final Set<Integer> fullGroupOrderedSet = new LinkedHashSet<>();
fullGroupOrderedSet.addAll(aggregate.getGroupSet().asSet());
@@ -216,40 +217,89 @@
//
// or in algebra,
//
- // Aggregate([($0), ($0, $2)], SUM($2), MIN($2), MAX($2), GROUPING($0, $4))
+ // Aggregate([($0), ($0, $4)], SUM($2), MIN($2), MAX($2), GROUPING($0, $4))
// Scan(emp)
final RelBuilder b = call.builder();
b.push(aggregate.getInput());
final List<RelBuilder.AggCall> aggCalls = new ArrayList<>();
+ // Helper class for building the inner query.
// CHECKSTYLE: IGNORE 1
class Registrar {
final int g = fullGroupSet.cardinality();
- final Map<Integer, Integer> args = new HashMap<>();
+ /** Map of input fields (below the original aggregation) and filter args
+ * to inner query aggregate calls. */
+ final Map<IntPair, Integer> args = new HashMap<>();
+ /** Map of aggregate calls from the original aggregation to inner query
+ * aggregate calls. */
final Map<Integer, Integer> aggs = new HashMap<>();
+ /** Map of filter args from the original aggregation to inner-query
+ * {@code COUNT(*)} calls, which are used to filter the outer aggregate
+ * so that groups whose rows were all removed by the filter are
+ * ignored. */
+ final Map<Integer, Integer> counts = new HashMap<>();
- List<Integer> fields(List<Integer> fields) {
- return Util.transform(fields, this::field);
+ List<Integer> fields(List<Integer> fields, int filterArg) {
+ return Util.transform(fields, f -> this.field(f, filterArg));
}
- int field(int field) {
- return Objects.requireNonNull(args.get(field));
+ int field(int field, int filterArg) {
+ return Objects.requireNonNull(args.get(IntPair.of(field, filterArg)));
}
- int register(int field) {
- return args.computeIfAbsent(field, j -> {
+ /** Computes an aggregate call argument's values for a
+ * {@code WITHIN DISTINCT} aggregate call.
+ *
+ * <p>For example, to compute
+ * {@code SUM(x) WITHIN DISTINCT (y) GROUP BY (z)},
+ * the inner aggregate must first group {@code x} by {@code (y, z)}
+ * — using {@code MIN} to select the (hopefully) unique value of
+ * {@code x} for each {@code (y, z)} group. Actually summing over the
+ * grouped {@code x} values must occur in an outer aggregate.
+ *
+ * @param field Index of an input field that's used in a
+ * {@code WITHIN DISTINCT} aggregate call
+ * @param filterArg Filter arg used in the original aggregate call, or
+ * {@code -1} if there is no filter. We use the same filter in
+ * the inner query.
+ * @return Index of the inner query aggregate call representing the
+ * grouped field, which can be referenced in the outer query
+ * aggregate call
+ */
+ int register(int field, int filterArg) {
+ return args.computeIfAbsent(IntPair.of(field, filterArg), j -> {
final int ordinal = g + aggCalls.size();
+ RelBuilder.AggCall groupedField =
+ b.aggregateCall(SqlStdOperatorTable.MIN, b.field(field));
aggCalls.add(
- b.aggregateCall(SqlStdOperatorTable.MIN, b.field(j)));
+ filterArg < 0
+ ? groupedField
+ : groupedField.filter(b.field(filterArg)));
if (config.throwIfNotUnique()) {
+ groupedField =
+ b.aggregateCall(SqlStdOperatorTable.MAX, b.field(field));
aggCalls.add(
- b.aggregateCall(SqlStdOperatorTable.MAX, b.field(j)));
+ filterArg < 0
+ ? groupedField
+ : groupedField.filter(b.field(filterArg)));
}
return ordinal;
});
}
+ /** Registers an aggregate call that is <em>not</em> a
+ * {@code WITHIN DISTINCT} call.
+ *
+ * <p>Unlike the case handled by {@link #register(int, int)} above,
+ * aggregate calls without any distinct keys do not need a second round
+ * of aggregation in the outer query, so they can be computed "as-is" in
+ * the inner query.
+ *
+ * @param i Index of the aggregate call in the original aggregation
+ * @param aggregateCall Original aggregate call
+ * @return Index of the aggregate call in the computed inner query
+ */
int registerAgg(int i, RelBuilder.AggCall aggregateCall) {
final int ordinal = g + aggCalls.size();
aggs.put(i, ordinal);
@@ -260,6 +310,33 @@
int getAgg(int i) {
return Objects.requireNonNull(aggs.get(i));
}
+
+ /** Registers an extra {@code COUNT} aggregate call when it's needed to
+ * filter out null inputs in the outer aggregate.
+ *
+ * <p>This should only be called for aggregate calls with filters. It's
+ * possible that the filter would eliminate all input rows to the
+ * {@code MIN} call in the inner query, so calls in the outer
+ * aggregate may need to be aware of this. See usage of
+ * {@link AggregateExpandWithinDistinctRule#mustBeCounted(AggregateCall)}.
+ *
+ * @param filterArg The original aggregate call's filter; must be
+ * non-negative
+ * @return Index of the {@code COUNT} call in the computed inner query
+ */
+ int registerCount(int filterArg) {
+ assert filterArg >= 0;
+ return counts.computeIfAbsent(filterArg, i -> {
+ final int ordinal = g + aggCalls.size();
+ aggCalls.add(b.aggregateCall(SqlStdOperatorTable.COUNT)
+ .filter(b.field(filterArg)));
+ return ordinal;
+ });
+ }
+
+ int getCount(int filterArg) {
+ return Objects.requireNonNull(counts.get(filterArg));
+ }
}
final Registrar registrar = new Registrar();
@@ -269,13 +346,25 @@
b.aggregateCall(c.getAggregation(),
b.fields(c.getArgList())));
} else {
- c.getArgList().forEach(registrar::register);
+ for (int inputIdx : c.getArgList()) {
+ registrar.register(inputIdx, c.filterArg);
+ }
+ if (mustBeCounted(c)) {
+ registrar.registerCount(c.filterArg);
+ }
}
});
+ // Add an additional GROUPING() aggregate call so we can select only the
+ // relevant inner-aggregate rows from the outer aggregate. If there is only
+ // 1 grouping set (i.e. every aggregate call has the same distinct keys),
+ // no GROUPING() call is necessary.
final int grouping =
- registrar.registerAgg(-1,
- b.aggregateCall(SqlStdOperatorTable.GROUPING,
- b.fields(fullGroupList)));
+ hasMultipleGroupSets
+ ? registrar.registerAgg(-1,
+ b.aggregateCall(
+ SqlStdOperatorTable.GROUPING,
+ b.fields(fullGroupList)))
+ : -1;
b.aggregate(
b.groupKey(fullGroupSet,
(Iterable<ImmutableBitSet>) groupSets), aggCalls);
@@ -304,32 +393,56 @@
aggCalls.clear();
Ord.forEach(aggCallList, (c, i) -> {
final List<RexNode> filters = new ArrayList<>();
- final RexNode groupFilter = b.equals(b.field(grouping),
- b.literal(
- groupValue(fullGroupList,
- union(aggregate.getGroupSet(), c.distinctKeys))));
- filters.add(groupFilter);
- final RelBuilder.AggCall aggCall;
+ RexNode groupFilter = null;
+ if (hasMultipleGroupSets) {
+ groupFilter =
+ b.equals(
+ b.field(grouping),
+ b.literal(
+ groupValue(fullGroupList, union(aggregate.getGroupSet(), c.distinctKeys))));
+ filters.add(groupFilter);
+ }
+ RelBuilder.AggCall aggCall;
if (c.distinctKeys == null) {
aggCall = b.aggregateCall(SqlStdOperatorTable.MIN,
b.field(registrar.getAgg(i)));
} else {
- aggCall = b.aggregateCall(c.getAggregation(),
- b.fields(registrar.fields(c.getArgList())));
+ // The inputs to this aggregate are outputs from MIN() calls from the
+ // inner agg, and MIN() returns null iff it has no non-null inputs.
+ // When the original aggregate has a filter, that filter may discard
+ // every input row for a particular group in the inner aggregate. In
+ // that case, the group should be ignored by the outer aggregate as
+ // well. Because the aggregate call might not naturally ignore null
+ // inputs, we add a filter based on a COUNT() in the inner aggregate
+ // whenever the original call has a filter (see mustBeCounted).
+ aggCall =
+ b.aggregateCall(
+ c.getAggregation(),
+ b.fields(registrar.fields(c.getArgList(), c.filterArg)));
+
+ if (mustBeCounted(c)) {
+ filters.add(b.greaterThan(b.field(registrar.getCount(c.filterArg)), b.literal(0)));
+ }
if (config.throwIfNotUnique()) {
for (int j : c.getArgList()) {
+ RexNode isUniqueCondition =
+ b.isNotDistinctFrom(
+ b.field(registrar.field(j, c.filterArg)),
+ b.field(registrar.field(j, c.filterArg) + 1));
+ if (groupFilter != null) {
+ isUniqueCondition = b.or(b.not(groupFilter), isUniqueCondition);
+ }
String message = "more than one distinct value in agg UNIQUE_VALUE";
filters.add(
- b.call(SqlInternalOperators.THROW_UNLESS,
- b.or(b.not(groupFilter),
- b.isNotDistinctFrom(b.field(registrar.field(j)),
- b.field(registrar.field(j) + 1))),
- b.literal(message)));
+ b.call(SqlInternalOperators.THROW_UNLESS, isUniqueCondition, b.literal(message)));
}
}
}
- aggCalls.add(aggCall.filter(b.and(filters)));
+ if (filters.size() > 0) {
+ aggCall = aggCall.filter(b.and(filters));
+ }
+ aggCalls.add(aggCall);
});
b.aggregate(
@@ -342,6 +455,22 @@
call.transformTo(b.build());
}
+ private static boolean mustBeCounted(AggregateCall aggCall) {
+ // Always count filtered inner aggregates to be safe.
+ //
+ // It's possible that, for some aggregate calls (namely, those that
+ // completely ignore null inputs), we could neglect counting the
+ // grouped-and-filtered rows of the inner aggregate and filtering the empty
+ // ones out from the outer aggregate, since those empty groups would produce
+ // null values as the result of MIN and thus be ignored by the outer
+ // aggregate anyway.
+ //
+ // Note that using "aggCall.ignoreNulls()" is not sufficient to determine
+ // when it's safe to do this, since for COUNT the value of ignoreNulls()
+ // should generally be true even though COUNT(*) will never ignore anything.
+ return aggCall.hasFilter();
+ }
+
/** Converts a {@code DISTINCT} aggregate call into an equivalent one with
* {@code WITHIN DISTINCT}.
*
@@ -362,6 +491,7 @@
.stream()
.filter(i ->
aggregateCall.getAggregation().getKind() != SqlKind.COUNT
+ || aggregateCall.hasFilter()
|| isNullable.test(i))
.collect(Collectors.toList());
return aggregateCall.withDistinct(false)
diff --git a/core/src/main/java/org/apache/calcite/rex/RexSimplify.java b/core/src/main/java/org/apache/calcite/rex/RexSimplify.java
index 8629b58..96338ab 100644
--- a/core/src/main/java/org/apache/calcite/rex/RexSimplify.java
+++ b/core/src/main/java/org/apache/calcite/rex/RexSimplify.java
@@ -400,7 +400,7 @@
* Try to find a literal with the given value in the input list.
* The type of the literal must be one of the numeric types.
*/
- private int findLiteralIndex(List<RexNode> operands, BigDecimal value) {
+ private static int findLiteralIndex(List<RexNode> operands, BigDecimal value) {
for (int i = 0; i < operands.size(); i++) {
if (operands.get(i).isA(SqlKind.LITERAL)) {
Comparable comparable = ((RexLiteral) operands.get(i)).getValue();
diff --git a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
index 40268b4..6fbdde0 100644
--- a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
+++ b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
@@ -731,6 +731,11 @@
return call(SqlStdOperatorTable.EQUALS, operand0, operand1);
}
+ /** Creates a {@code >}. */
+ public RexNode greaterThan(RexNode operand0, RexNode operand1) {
+ return call(SqlStdOperatorTable.GREATER_THAN, operand0, operand1);
+ }
+
/** Creates a {@code <>}. */
public RexNode notEquals(RexNode operand0, RexNode operand1) {
return call(SqlStdOperatorTable.NOT_EQUALS, operand0, operand1);
@@ -3543,13 +3548,12 @@
if (distinct) {
b.append("DISTINCT ");
}
- final int iMax = operands.size() - 1;
- for (int i = 0; ; i++) {
- b.append(operands.get(i));
- if (i == iMax) {
- break;
+ if (operands.size() > 0) {
+ b.append(operands.get(0));
+ for (int i = 1; i < operands.size(); i++) {
+ b.append(", ");
+ b.append(operands.get(i));
}
- b.append(", ");
}
b.append(')');
if (filter != null) {
diff --git a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
index 9656adb..4cbc11e 100644
--- a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
@@ -330,6 +330,22 @@
assertThat(root, hasTree(expected));
}
+ @Test void testScanFilterGreaterThan() {
+ // Equivalent SQL:
+ // SELECT *
+ // FROM emp
+ // WHERE deptno > 20
+ final RelBuilder builder = RelBuilder.create(config().build());
+ RelNode root =
+ builder.scan("EMP")
+ .filter(
+ builder.greaterThan(builder.field("DEPTNO"), builder.literal(20)))
+ .build();
+ final String expected = "LogicalFilter(condition=[>($7, 20)])\n"
+ + " LogicalTableScan(table=[[scott, EMP]])\n";
+ assertThat(root, hasTree(expected));
+ }
+
@Test void testSnapshotTemporalTable() {
// Equivalent SQL:
// SELECT *
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index 809289b..93e2893 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -1461,14 +1461,49 @@
sql(sql).with(program).check();
}
+ /** Tests {@link AggregateExpandWithinDistinctRule}. If all aggregate calls
+ * have the same distinct keys, there is no need for multiple grouping
+ * sets. */
+ @Test void testWithinDistinctUniformDistinctKeys() {
+ final String sql = "SELECT deptno,\n"
+ + " SUM(sal) WITHIN DISTINCT (job),\n"
+ + " AVG(comm) WITHIN DISTINCT (job)\n"
+ + "FROM emp\n"
+ + "GROUP BY deptno";
+ HepProgram program = new HepProgramBuilder()
+ .addRuleInstance(CoreRules.AGGREGATE_REDUCE_FUNCTIONS)
+ .addRuleInstance(CoreRules.AGGREGATE_EXPAND_WITHIN_DISTINCT)
+ .build();
+ sql(sql).with(program).check();
+ }
+
+ /** Tests {@link AggregateExpandWithinDistinctRule}. If all aggregate calls
+ * have the same distinct keys, and we're not checking for true uniqueness,
+ * there is no need for filtering in the outer aggregate. */
+ @Test void testWithinDistinctUniformDistinctKeysNoThrow() {
+ final String sql = "SELECT deptno,\n"
+ + " SUM(sal) WITHIN DISTINCT (job),\n"
+ + " AVG(comm) WITHIN DISTINCT (job)\n"
+ + "FROM emp\n"
+ + "GROUP BY deptno";
+ HepProgram program = new HepProgramBuilder()
+ .addRuleInstance(CoreRules.AGGREGATE_REDUCE_FUNCTIONS)
+ .addRuleInstance(
+ CoreRules.AGGREGATE_EXPAND_WITHIN_DISTINCT.config
+ .withThrowIfNotUnique(false).toRule())
+ .build();
+ sql(sql).with(program).check();
+ }
+
/** Tests that {@link AggregateExpandWithinDistinctRule} treats
* "COUNT(DISTINCT x)" as if it were "COUNT(x) WITHIN DISTINCT (x)". */
@Test void testWithinDistinctCountDistinct() {
final String sql = "SELECT deptno,\n"
- + " SUM(sal) WITHIN DISTINCT (job) AS ss_j,\n"
+ + " SUM(sal) WITHIN DISTINCT (comm) AS ss_c,\n"
+ " COUNT(DISTINCT job) cdj,\n"
+ " COUNT(job) WITHIN DISTINCT (job) AS cj_j,\n"
- + " COUNT(DISTINCT job) WITHIN DISTINCT (job) AS cdj_j\n"
+ + " COUNT(DISTINCT job) WITHIN DISTINCT (job) AS cdj_j,\n"
+ + " COUNT(DISTINCT job) FILTER (WHERE sal > 1000) AS cdj_filtered\n"
+ "FROM emp\n"
+ "GROUP BY deptno";
HepProgram program = new HepProgramBuilder()
@@ -1479,6 +1514,78 @@
sql(sql).with(program).check();
}
+ /** Test case for
+ * <a href="https://issues.apache.org/jira/browse/CALCITE-4726">[CALCITE-4726]
+ * Support aggregate calls with a FILTER clause in
+ * AggregateExpandWithinDistinctRule</a>.
+ *
+ * <p>Tests {@link AggregateExpandWithinDistinctRule} with different
+ * distinct keys and different filters for each aggregate call. */
+ @Test void testWithinDistinctFilteredAggs() {
+ final String sql = "SELECT deptno,\n"
+ + " SUM(sal) WITHIN DISTINCT (job) FILTER (WHERE comm > 10),\n"
+ + " AVG(comm) WITHIN DISTINCT (sal) FILTER (WHERE ename LIKE '%ok%')\n"
+ + "FROM emp\n"
+ + "GROUP BY deptno";
+ HepProgram program = new HepProgramBuilder()
+ .addRuleInstance(CoreRules.AGGREGATE_REDUCE_FUNCTIONS)
+ .addRuleInstance(CoreRules.AGGREGATE_EXPAND_WITHIN_DISTINCT)
+ .build();
+ sql(sql).with(program).check();
+ }
+
+ /** Tests {@link AggregateExpandWithinDistinctRule}. Includes multiple
+ * different filters for the aggregate calls, and all aggregate calls have the
+ * same distinct keys, so there is no need to filter based on
+ * {@code GROUPING()}. */
+ @Test void testWithinDistinctFilteredAggsUniformDistinctKeys() {
+ final String sql = "SELECT deptno,\n"
+ + " SUM(sal) WITHIN DISTINCT (job) FILTER (WHERE comm > 10),\n"
+ + " AVG(comm) WITHIN DISTINCT (job) FILTER (WHERE ename LIKE '%ok%')\n"
+ + "FROM emp\n"
+ + "GROUP BY deptno";
+ HepProgram program = new HepProgramBuilder()
+ .addRuleInstance(CoreRules.AGGREGATE_REDUCE_FUNCTIONS)
+ .addRuleInstance(CoreRules.AGGREGATE_EXPAND_WITHIN_DISTINCT)
+ .build();
+ sql(sql).with(program).check();
+ }
+
+ /** Tests {@link AggregateExpandWithinDistinctRule}. Includes multiple
+ * different filters for the aggregate calls, and all aggregate calls have the
+ * same distinct keys, so there is no need to filter based on
+ * {@code GROUPING()}. Does <em>not</em> throw if not unique. */
+ @Test void testWithinDistinctFilteredAggsUniformDistinctKeysNoThrow() {
+ final String sql = "SELECT deptno,\n"
+ + " SUM(sal) WITHIN DISTINCT (job) FILTER (WHERE comm > 10),\n"
+ + " AVG(comm) WITHIN DISTINCT (job) FILTER (WHERE ename LIKE '%ok%')\n"
+ + "FROM emp\n"
+ + "GROUP BY deptno";
+ HepProgram program = new HepProgramBuilder()
+ .addRuleInstance(CoreRules.AGGREGATE_REDUCE_FUNCTIONS)
+ .addRuleInstance(
+ CoreRules.AGGREGATE_EXPAND_WITHIN_DISTINCT.config
+ .withThrowIfNotUnique(false).toRule())
+ .build();
+ sql(sql).with(program).check();
+ }
+
+ /** Tests {@link AggregateExpandWithinDistinctRule}. Includes multiple
+ * identical filters for the aggregate calls. The filters should be
+ * re-used. */
+ @Test void testWithinDistinctFilteredAggsSameFilter() {
+ final String sql = "SELECT deptno,\n"
+ + " SUM(sal) WITHIN DISTINCT (job) FILTER (WHERE ename LIKE '%ok%'),\n"
+ + " AVG(comm) WITHIN DISTINCT (sal) FILTER (WHERE ename LIKE '%ok%')\n"
+ + "FROM emp\n"
+ + "GROUP BY deptno";
+ HepProgram program = new HepProgramBuilder()
+ .addRuleInstance(CoreRules.AGGREGATE_REDUCE_FUNCTIONS)
+ .addRuleInstance(CoreRules.AGGREGATE_EXPAND_WITHIN_DISTINCT)
+ .build();
+ sql(sql).with(program).check();
+ }
+
@Test void testPushProjectPastFilter() {
final String sql = "select empno + deptno from emp where sal = 10 * comm\n"
+ "and upper(ename) = 'FOO'";
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index 3ed9020..a99fa46 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -13500,31 +13500,140 @@
<TestCase name="testWithinDistinctCountDistinct">
<Resource name="sql">
<![CDATA[SELECT deptno,
- SUM(sal) WITHIN DISTINCT (job) AS ss_j,
+ SUM(sal) WITHIN DISTINCT (comm) AS ss_c,
COUNT(DISTINCT job) cdj,
COUNT(job) WITHIN DISTINCT (job) AS cj_j,
- COUNT(DISTINCT job) WITHIN DISTINCT (job) AS cdj_j
+ COUNT(DISTINCT job) WITHIN DISTINCT (job) AS cdj_j,
+ COUNT(DISTINCT job) FILTER (WHERE sal > 1000) AS cdj_filtered
FROM emp
GROUP BY deptno]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
-LogicalAggregate(group=[{0}], SS_J=[SUM($1) WITHIN DISTINCT ($2)], CDJ=[COUNT(DISTINCT $2)], CJ_J=[COUNT() WITHIN DISTINCT ($2)], CDJ_J=[COUNT(DISTINCT $2) WITHIN DISTINCT ($2)])
- LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2])
+LogicalAggregate(group=[{0}], SS_C=[SUM($1) WITHIN DISTINCT ($2)], CDJ=[COUNT(DISTINCT $3)], CJ_J=[COUNT() WITHIN DISTINCT ($3)], CDJ_J=[COUNT(DISTINCT $3) WITHIN DISTINCT ($3)], CDJ_FILTERED=[COUNT(DISTINCT $3) FILTER $4])
+ LogicalProject(DEPTNO=[$7], SAL=[$5], COMM=[$6], JOB=[$2], $f4=[>($5, 1000)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
<Resource name="planAfter">
<![CDATA[
-LogicalProject(DEPTNO=[$0], $f1=[$1], $f2=[$2], $f20=[$2], $f21=[$2])
- LogicalAggregate(group=[{0}], agg#0=[$SUM0($1) FILTER $2], agg#1=[COUNT() FILTER $2])
- LogicalProject(DEPTNO=[$0], $f2=[$2], $f4=[=($3, 0)])
- LogicalAggregate(group=[{0, 2}], groups=[[{0, 2}, {0}]], agg#0=[MIN($1)], agg#1=[GROUPING($0, $2)])
- LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2])
+LogicalProject(DEPTNO=[$0], $f1=[$1], $f2=[$2], $f20=[$2], $f21=[$2], $f3=[$3])
+ LogicalAggregate(group=[{0}], agg#0=[$SUM0($1) FILTER $3], agg#1=[COUNT() FILTER $4], agg#2=[COUNT($2) FILTER $5])
+ LogicalProject(DEPTNO=[$0], $f3=[$3], $f4=[$4], $f7=[=($6, 1)], $f8=[=($6, 2)], $f9=[AND(=($6, 2), >($5, 0))])
+ LogicalAggregate(group=[{0, 2, 3}], groups=[[{0, 2}, {0, 3}]], agg#0=[MIN($1)], agg#1=[MIN($3) FILTER $4], agg#2=[COUNT() FILTER $4], agg#3=[GROUPING($0, $2, $3)])
+ LogicalProject(DEPTNO=[$7], SAL=[$5], COMM=[$6], JOB=[$2], $f4=[>($5, 1000)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
</TestCase>
+ <TestCase name="testWithinDistinctFilteredAggs">
+ <Resource name="sql">
+ <![CDATA[SELECT deptno,
+ SUM(sal) WITHIN DISTINCT (job) FILTER (WHERE comm > 10),
+ AVG(comm) WITHIN DISTINCT (sal) FILTER (WHERE ename LIKE '%ok%')
+FROM emp
+GROUP BY deptno]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[SUM($1) WITHIN DISTINCT ($3) FILTER $2], EXPR$2=[AVG($4) WITHIN DISTINCT ($1) FILTER $5])
+ LogicalProject(DEPTNO=[$7], SAL=[$5], $f2=[>($6, 10)], JOB=[$2], COMM=[$6], $f5=[LIKE($1, '%ok%')])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalProject(DEPTNO=[$0], EXPR$1=[CASE(=($2, 0), null:INTEGER, $1)], EXPR$2=[CAST(/($3, $4)):INTEGER])
+ LogicalProject(DEPTNO=[$0], EXPR$1=[$1], $f2=[$2], $f3=[CASE(=($4, 0), null:INTEGER, $3)], $f4=[$4])
+ LogicalAggregate(group=[{0}], agg#0=[$SUM0($1) FILTER $3], agg#1=[COUNT() FILTER $4], agg#2=[$SUM0($2) FILTER $5], agg#3=[COUNT() FILTER $6])
+ LogicalProject(DEPTNO=[$0], $f3=[$3], $f6=[$6], $f10=[AND(=($9, 2), >($5, 0), $THROW_UNLESS(OR(<>($9, 2), AND(IS NULL($3), IS NULL($4)), IS TRUE(=($3, $4))), 'more than one distinct value in agg UNIQUE_VALUE'))], $f11=[AND(=($9, 2), >($5, 0))], $f12=[AND(=($9, 1), >($8, 0), $THROW_UNLESS(OR(<>($9, 1), AND(IS NULL($6), IS NULL($7)), IS TRUE(=($6, $7))), 'more than one distinct value in agg UNIQUE_VALUE'))], $f13=[AND(=($9, 1), >($8, 0))])
+ LogicalAggregate(group=[{0, 1, 3}], groups=[[{0, 1}, {0, 3}]], agg#0=[MIN($1) FILTER $2], agg#1=[MAX($1) FILTER $2], agg#2=[COUNT() FILTER $2], agg#3=[MIN($4) FILTER $5], agg#4=[MAX($4) FILTER $5], agg#5=[COUNT() FILTER $5], agg#6=[GROUPING($0, $1, $3)])
+ LogicalProject(DEPTNO=[$7], SAL=[$5], $f2=[>($6, 10)], JOB=[$2], COMM=[$6], $f5=[LIKE($1, '%ok%')])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testWithinDistinctFilteredAggsSameFilter">
+ <Resource name="sql">
+ <![CDATA[SELECT deptno,
+ SUM(sal) WITHIN DISTINCT (job) FILTER (WHERE ename LIKE '%ok%'),
+ AVG(comm) WITHIN DISTINCT (sal) FILTER (WHERE ename LIKE '%ok%')
+FROM emp
+GROUP BY deptno]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[SUM($1) WITHIN DISTINCT ($3) FILTER $2], EXPR$2=[AVG($4) WITHIN DISTINCT ($1) FILTER $2])
+ LogicalProject(DEPTNO=[$7], SAL=[$5], $f2=[LIKE($1, '%ok%')], JOB=[$2], COMM=[$6])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalProject(DEPTNO=[$0], EXPR$1=[CASE(=($2, 0), null:INTEGER, $1)], EXPR$2=[CAST(/($3, $4)):INTEGER])
+ LogicalProject(DEPTNO=[$0], EXPR$1=[$1], $f2=[$2], $f3=[CASE(=($4, 0), null:INTEGER, $3)], $f4=[$4])
+ LogicalAggregate(group=[{0}], agg#0=[$SUM0($1) FILTER $3], agg#1=[COUNT() FILTER $4], agg#2=[$SUM0($2) FILTER $5], agg#3=[COUNT() FILTER $6])
+ LogicalProject(DEPTNO=[$0], $f3=[$3], $f6=[$6], $f9=[AND(=($8, 2), >($5, 0), $THROW_UNLESS(OR(<>($8, 2), AND(IS NULL($3), IS NULL($4)), IS TRUE(=($3, $4))), 'more than one distinct value in agg UNIQUE_VALUE'))], $f10=[AND(=($8, 2), >($5, 0))], $f11=[AND(=($8, 1), >($5, 0), $THROW_UNLESS(OR(<>($8, 1), AND(IS NULL($6), IS NULL($7)), IS TRUE(=($6, $7))), 'more than one distinct value in agg UNIQUE_VALUE'))], $f12=[AND(=($8, 1), >($5, 0))])
+ LogicalAggregate(group=[{0, 1, 3}], groups=[[{0, 1}, {0, 3}]], agg#0=[MIN($1) FILTER $2], agg#1=[MAX($1) FILTER $2], agg#2=[COUNT() FILTER $2], agg#3=[MIN($4) FILTER $2], agg#4=[MAX($4) FILTER $2], agg#5=[GROUPING($0, $1, $3)])
+ LogicalProject(DEPTNO=[$7], SAL=[$5], $f2=[LIKE($1, '%ok%')], JOB=[$2], COMM=[$6])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testWithinDistinctFilteredAggsUniformDistinctKeys">
+ <Resource name="sql">
+ <![CDATA[SELECT deptno,
+ SUM(sal) WITHIN DISTINCT (job) FILTER (WHERE comm > 10),
+ AVG(comm) WITHIN DISTINCT (job) FILTER (WHERE ename LIKE '%ok%')
+FROM emp
+GROUP BY deptno]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[SUM($1) WITHIN DISTINCT ($3) FILTER $2], EXPR$2=[AVG($4) WITHIN DISTINCT ($3) FILTER $5])
+ LogicalProject(DEPTNO=[$7], SAL=[$5], $f2=[>($6, 10)], JOB=[$2], COMM=[$6], $f5=[LIKE($1, '%ok%')])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalProject(DEPTNO=[$0], EXPR$1=[CASE(=($2, 0), null:INTEGER, $1)], EXPR$2=[CAST(/($3, $4)):INTEGER])
+ LogicalProject(DEPTNO=[$0], EXPR$1=[$1], $f2=[$2], $f3=[CASE(=($4, 0), null:INTEGER, $3)], $f4=[$4])
+ LogicalAggregate(group=[{0}], agg#0=[$SUM0($1) FILTER $3], agg#1=[COUNT() FILTER $4], agg#2=[$SUM0($2) FILTER $5], agg#3=[COUNT() FILTER $6])
+ LogicalProject(DEPTNO=[$0], $f2=[$2], $f5=[$5], $f8=[AND(>($4, 0), $THROW_UNLESS(OR(AND(IS NULL($2), IS NULL($3)), IS TRUE(=($2, $3))), 'more than one distinct value in agg UNIQUE_VALUE'))], $f9=[>($4, 0)], $f10=[AND(>($7, 0), $THROW_UNLESS(OR(AND(IS NULL($5), IS NULL($6)), IS TRUE(=($5, $6))), 'more than one distinct value in agg UNIQUE_VALUE'))], $f11=[>($7, 0)])
+ LogicalAggregate(group=[{0, 3}], agg#0=[MIN($1) FILTER $2], agg#1=[MAX($1) FILTER $2], agg#2=[COUNT() FILTER $2], agg#3=[MIN($4) FILTER $5], agg#4=[MAX($4) FILTER $5], agg#5=[COUNT() FILTER $5])
+ LogicalProject(DEPTNO=[$7], SAL=[$5], $f2=[>($6, 10)], JOB=[$2], COMM=[$6], $f5=[LIKE($1, '%ok%')])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testWithinDistinctFilteredAggsUniformDistinctKeysNoThrow">
+ <Resource name="sql">
+ <![CDATA[SELECT deptno,
+ SUM(sal) WITHIN DISTINCT (job) FILTER (WHERE comm > 10),
+ AVG(comm) WITHIN DISTINCT (job) FILTER (WHERE ename LIKE '%ok%')
+FROM emp
+GROUP BY deptno]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[SUM($1) WITHIN DISTINCT ($3) FILTER $2], EXPR$2=[AVG($4) WITHIN DISTINCT ($3) FILTER $5])
+ LogicalProject(DEPTNO=[$7], SAL=[$5], $f2=[>($6, 10)], JOB=[$2], COMM=[$6], $f5=[LIKE($1, '%ok%')])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalProject(DEPTNO=[$0], EXPR$1=[CASE(=($2, 0), null:INTEGER, $1)], EXPR$2=[CAST(/($3, $4)):INTEGER])
+ LogicalProject(DEPTNO=[$0], EXPR$1=[$1], $f2=[$2], $f3=[CASE(=($4, 0), null:INTEGER, $3)], $f4=[$4])
+ LogicalAggregate(group=[{0}], agg#0=[$SUM0($1) FILTER $3], agg#1=[COUNT() FILTER $3], agg#2=[$SUM0($2) FILTER $4], agg#3=[COUNT() FILTER $4])
+ LogicalProject(DEPTNO=[$0], $f2=[$2], $f4=[$4], $f6=[>($3, 0)], $f7=[>($5, 0)])
+ LogicalAggregate(group=[{0, 3}], agg#0=[MIN($1) FILTER $2], agg#1=[COUNT() FILTER $2], agg#2=[MIN($4) FILTER $5], agg#3=[COUNT() FILTER $5])
+ LogicalProject(DEPTNO=[$7], SAL=[$5], $f2=[>($6, 10)], JOB=[$2], COMM=[$6], $f5=[LIKE($1, '%ok%')])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ </TestCase>
<TestCase name="testWithinDistinctNoThrow">
<Resource name="sql">
<![CDATA[SELECT deptno, SUM(sal), SUM(sal) WITHIN DISTINCT (job)
@@ -13549,4 +13658,55 @@
]]>
</Resource>
</TestCase>
+ <TestCase name="testWithinDistinctUniformDistinctKeys">
+ <Resource name="sql">
+ <![CDATA[SELECT deptno,
+ SUM(sal) WITHIN DISTINCT (job),
+ AVG(comm) WITHIN DISTINCT (job)
+FROM emp
+GROUP BY deptno]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[SUM($1) WITHIN DISTINCT ($2)], EXPR$2=[AVG($3) WITHIN DISTINCT ($2)])
+ LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2], COMM=[$6])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalProject(DEPTNO=[$0], EXPR$1=[$1], EXPR$2=[CAST(/($2, $3)):INTEGER NOT NULL])
+ LogicalAggregate(group=[{0}], agg#0=[$SUM0($1) FILTER $3], agg#1=[$SUM0($2) FILTER $4], agg#2=[COUNT()])
+ LogicalProject(DEPTNO=[$0], $f2=[$2], $f4=[$4], $f6=[$THROW_UNLESS(=($2, $3), 'more than one distinct value in agg UNIQUE_VALUE')], $f7=[$THROW_UNLESS(=($4, $5), 'more than one distinct value in agg UNIQUE_VALUE')])
+ LogicalAggregate(group=[{0, 2}], agg#0=[MIN($1)], agg#1=[MAX($1)], agg#2=[MIN($3)], agg#3=[MAX($3)])
+ LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2], COMM=[$6])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testWithinDistinctUniformDistinctKeysNoThrow">
+ <Resource name="sql">
+ <![CDATA[SELECT deptno,
+ SUM(sal) WITHIN DISTINCT (job),
+ AVG(comm) WITHIN DISTINCT (job)
+FROM emp
+GROUP BY deptno]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[SUM($1) WITHIN DISTINCT ($2)], EXPR$2=[AVG($3) WITHIN DISTINCT ($2)])
+ LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2], COMM=[$6])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalProject(DEPTNO=[$0], EXPR$1=[$1], EXPR$2=[CAST(/($2, $3)):INTEGER NOT NULL])
+ LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], agg#1=[$SUM0($3)], agg#2=[COUNT()])
+ LogicalAggregate(group=[{0, 2}], agg#0=[MIN($1)], agg#1=[MIN($3)])
+ LogicalProject(DEPTNO=[$7], SAL=[$5], JOB=[$2], COMM=[$6])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ </TestCase>
</Root>
diff --git a/core/src/test/resources/sql/within-distinct.iq b/core/src/test/resources/sql/within-distinct.iq
index 2697208..ea8c5ec 100644
--- a/core/src/test/resources/sql/within-distinct.iq
+++ b/core/src/test/resources/sql/within-distinct.iq
@@ -891,4 +891,51 @@
more than one distinct value in agg UNIQUE_VALUE
!error
+# Since all of the people from WY are filtered out, make sure both "COUNT(*)"
+# and "AVG(age)" ignore that entire group. Also, filters can be used to
+# manufacture uniqueness within a distinct key set. Without filters on these
+# aggregate calls, the query would throw due to non-unique ages in each state.
+WITH FriendStates
+AS (SELECT * FROM (VALUES
+ ('Alice', 789, 'UT'),
+ ('Bob', 25, 'UT'),
+ ('Carlos', 25, 'UT'),
+ ('Dan', 12, 'UT'),
+ ('Erin', 567, 'WY'),
+ ('Frank', 456, 'WY')) AS FriendStates (name, age, state))
+SELECT AVG(age) WITHIN DISTINCT (state) FILTER (WHERE age < 100 AND age > 18) AS aa_s,
+ COUNT(*) WITHIN DISTINCT (state) FILTER (WHERE age < 100 AND age > 18) AS c_s
+FROM FriendStates;
++------+-----+
+| AA_S | C_S |
++------+-----+
+| 25 | 1 |
++------+-----+
+(1 row)
+
+!ok
+
+# Unlike the previous example with FriendStates, this one should count the null
+# age of 'Forest' in WY; however, that age should be left out of the average
+# because it is null.
+WITH FriendStates
+AS (SELECT * FROM (VALUES
+ ('Alice', 789, 'UT'),
+ ('Bob', 25, 'UT'),
+ ('Carlos', 25, 'UT'),
+ ('Dan', 678, 'UT'),
+ ('Erin', 567, 'WY'),
+ ('Forest', NULL, 'WY')) AS FriendStates (name, age, state))
+SELECT AVG(age) WITHIN DISTINCT (state) FILTER (WHERE name LIKE '%o%') AS aa_s,
+ COUNT(*) WITHIN DISTINCT (state) FILTER (WHERE name LIKE '%o%') AS c_s
+FROM FriendStates;
++------+-----+
+| AA_S | C_S |
++------+-----+
+| 25 | 2 |
++------+-----+
+(1 row)
+
+!ok
+
# End within-distinct.iq