// blob: ccf32284e74fc0285fc0299dafbfbeff1ad38b57 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.calcite.rel.rules;
import org.apache.calcite.plan.Contexts;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.Aggregate.Group;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.runtime.PairList;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.Optionality;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.immutables.value.Value;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.NavigableSet;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;
/**
* Planner rule that expands distinct aggregates
* (such as {@code COUNT(DISTINCT x)}) from a
* {@link org.apache.calcite.rel.core.Aggregate}.
*
* <p>How this is done depends upon the arguments to the function. If all
* functions have the same argument
* (e.g. {@code COUNT(DISTINCT x), SUM(DISTINCT x)} both have the argument
* {@code x}) then one extra {@link org.apache.calcite.rel.core.Aggregate} is
* sufficient.
*
* <p>If there are multiple arguments
* (e.g. {@code COUNT(DISTINCT x), COUNT(DISTINCT y)})
* the rule creates separate {@code Aggregate}s and combines using a
* {@link org.apache.calcite.rel.core.Join}.
*
* @see CoreRules#AGGREGATE_EXPAND_DISTINCT_AGGREGATES
* @see CoreRules#AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN
*/
@Value.Enclosing
public final class AggregateExpandDistinctAggregatesRule
extends RelRule<AggregateExpandDistinctAggregatesRule.Config>
implements TransformationRule {
/** Creates an AggregateExpandDistinctAggregatesRule from its configuration. */
AggregateExpandDistinctAggregatesRule(Config config) {
  super(config);
}

/**
 * Creates a rule that matches {@code clazz} with any inputs and uses the
 * given builder factory.
 *
 * @param clazz Aggregate class to match
 * @param useGroupingSets Value for {@link Config#isUsingGroupingSets()}
 * @param relBuilderFactory Builder factory for the rule
 *
 * @deprecated Create the rule from a {@link Config} via
 * {@link Config#toRule()} instead.
 */
@Deprecated // to be removed before 2.0
public AggregateExpandDistinctAggregatesRule(
    Class<? extends Aggregate> clazz,
    boolean useGroupingSets,
    RelBuilderFactory relBuilderFactory) {
  this(legacyConfig(clazz, useGroupingSets, relBuilderFactory));
}

/**
 * Creates a rule that matches {@code clazz} and builds joins with the given
 * join factory.
 *
 * @deprecated Create the rule from a {@link Config} via
 * {@link Config#toRule()} instead.
 */
@Deprecated // to be removed before 2.0
public AggregateExpandDistinctAggregatesRule(
    Class<? extends LogicalAggregate> clazz,
    boolean useGroupingSets,
    RelFactories.JoinFactory joinFactory) {
  this(clazz, useGroupingSets, RelBuilder.proto(Contexts.of(joinFactory)));
}

/**
 * Creates a rule that matches {@code clazz}, builds joins with the given
 * join factory, and does not use grouping sets.
 *
 * @deprecated Create the rule from a {@link Config} via
 * {@link Config#toRule()} instead.
 */
@Deprecated // to be removed before 2.0
public AggregateExpandDistinctAggregatesRule(
    Class<? extends LogicalAggregate> clazz,
    RelFactories.JoinFactory joinFactory) {
  // Same as the three-argument join-factory constructor with
  // useGroupingSets = false.
  this(clazz, false, joinFactory);
}

/** Builds the configuration equivalent to the deprecated constructor
 * arguments. */
private static Config legacyConfig(Class<? extends Aggregate> clazz,
    boolean useGroupingSets, RelBuilderFactory relBuilderFactory) {
  return Config.DEFAULT.withRelBuilderFactory(relBuilderFactory)
      .withOperandSupplier(b ->
          b.operand(clazz).anyInputs())
      .as(Config.class)
      .withUsingGroupingSets(useGroupingSets);
}
//~ Methods ----------------------------------------------------------------
/**
* Expands the DISTINCT aggregate calls of the matched {@link Aggregate}.
*
* <p>Does nothing if the aggregate has no distinct call, or if it has more
* than one grouping set and the rule is not configured to use grouping sets.
* Otherwise the first applicable strategy below is used and its result is
* registered with {@code call.transformTo}:
*
* <ol>
* <li>{@code convertMonopole}: every call is distinct or ignores DISTINCT,
* all such calls share a single (arguments, filter) pair, and the group
* type is {@code SIMPLE};
* <li>{@code rewriteUsingGroupingSets}: the rule is configured to use
* grouping sets;
* <li>{@code convertSingletonDistinct}: exactly one distinct call, no
* filters, and one or more non-distinct calls that are all
* COUNT/SUM/SUM0/MIN/MAX;
* <li>otherwise one {@code doRewrite} step per distinct (arguments, filter)
* pair, followed by a project that restores the original field order and
* names.
* </ol>
*
* @param call Rule call; operand 0 is the aggregate to rewrite
*/
@Override public void onMatch(RelOptRuleCall call) {
final Aggregate aggregate = call.rel(0);
if (!aggregate.containsDistinctCall()) {
return;
}
if (!config.isUsingGroupingSets()
&& aggregate.groupSets.size() > 1) {
// Grouping sets are not handled correctly
// when generating joins.
return;
}
// Collect the aggregate calls, in their original order.
final List<AggregateCall> aggCalls = aggregate.getAggCallList();
// Find all aggregate calls with distinct
final List<AggregateCall> distinctAggCalls = aggCalls.stream()
.filter(AggregateCall::isDistinct).collect(Collectors.toList());
// Find all aggregate calls without distinct
final List<AggregateCall> nonDistinctAggCalls = aggCalls.stream()
.filter(aggCall -> !aggCall.isDistinct()).collect(Collectors.toList());
// Number of calls (distinct or not) that carry a FILTER argument.
final long filterCount = aggCalls.stream()
.filter(aggCall -> aggCall.filterArg >= 0).count();
final long unsupportedNonDistinctAggCallCount = nonDistinctAggCalls.stream()
.filter(aggCall -> {
final SqlKind aggCallKind = aggCall.getAggregation().getKind();
// We only support COUNT/SUM/SUM0/MIN/MAX for the "single" count distinct
// optimization
switch (aggCallKind) {
case COUNT:
case SUM:
case SUM0:
case MIN:
case MAX:
return false;
default:
return true;
}
}).count();
// (Argument list, filter) pairs of the distinct agg calls. A LinkedHashSet
// removes duplicates while keeping first-seen order.
final Set<Pair<List<Integer>, Integer>> distinctCallArgLists = distinctAggCalls.stream()
.map(aggCall -> Pair.of(aggCall.getArgList(), aggCall.filterArg))
.collect(Collectors.toCollection(LinkedHashSet::new));
checkState(!distinctCallArgLists.isEmpty(), "containsDistinctCall lied");
// If all of the agg expressions are distinct and have the same
// arguments then we can use a more efficient form.
// MAX, MIN, BIT_AND, BIT_OR always ignore distinct attribute,
// when they are mixed in with other distinct agg calls,
// we can still use this promotion.
// Treat the agg expression with Optionality.IGNORED as distinct and
// recompute the non-distinct agg call count and the distinct agg
// call arguments.
final List<AggregateCall> nonDistinctAggCallsOfIgnoredOptionality =
nonDistinctAggCalls.stream().filter(aggCall ->
aggCall.getAggregation().getDistinctOptionality() == Optionality.IGNORED)
.collect(Collectors.toList());
// Unlike distinctCallArgLists, this set also contains the args of
// agg calls that can ignore the distinct constraint.
final Set<Pair<List<Integer>, Integer>> distinctCallArgLists2 =
Stream.of(distinctAggCalls, nonDistinctAggCallsOfIgnoredOptionality)
.flatMap(Collection::stream)
.map(aggCall -> Pair.of(aggCall.getArgList(), aggCall.filterArg))
.collect(Collectors.toCollection(LinkedHashSet::new));
if ((nonDistinctAggCalls.size() - nonDistinctAggCallsOfIgnoredOptionality.size()) == 0
&& distinctCallArgLists2.size() == 1
&& aggregate.getGroupType() == Group.SIMPLE) {
final Pair<List<Integer>, Integer> pair =
Iterables.getOnlyElement(distinctCallArgLists2);
final RelBuilder relBuilder = call.builder();
convertMonopole(relBuilder, aggregate, pair.left, pair.right);
call.transformTo(relBuilder.build());
return;
}
if (config.isUsingGroupingSets()) {
rewriteUsingGroupingSets(call, aggregate);
return;
}
// If only one distinct aggregate and one or more non-distinct aggregates,
// we can generate multi-phase aggregates
if (distinctAggCalls.size() == 1 // one distinct aggregate
&& filterCount == 0 // no filter
&& unsupportedNonDistinctAggCallCount == 0 // sum/min/max/count in non-distinct aggregate
&& nonDistinctAggCalls.size() > 0) { // one or more non-distinct aggregates
final RelBuilder relBuilder = call.builder();
convertSingletonDistinct(relBuilder, aggregate, distinctCallArgLists);
call.transformTo(relBuilder.build());
return;
}
// Create a list of the expressions which will yield the final result.
// Initially, the expressions point to the input field.
final List<RelDataTypeField> aggFields =
aggregate.getRowType().getFieldList();
final List<@Nullable RexInputRef> refs = new ArrayList<>();
final List<String> fieldNames = aggregate.getRowType().getFieldNames();
final ImmutableBitSet groupSet = aggregate.getGroupSet();
final int groupCount = aggregate.getGroupCount();
for (int i : Util.range(groupCount)) {
refs.add(RexInputRef.of(i, aggFields));
}
// Aggregate the original relation, including any non-distinct aggregates.
final List<AggregateCall> newAggCallList = new ArrayList<>();
// Ordinal of the current call within the original call list.
int i = -1;
for (AggregateCall aggCall : aggregate.getAggCallList()) {
++i;
if (aggCall.isDistinct()) {
// Placeholder; filled in by doRewrite below.
refs.add(null);
continue;
}
refs.add(
new RexInputRef(
groupCount + newAggCallList.size(),
aggFields.get(groupCount + i).getType()));
newAggCallList.add(aggCall);
}
// In the case where there are no non-distinct aggregates (regardless of
// whether there are group bys), there's no need to generate the
// extra aggregate and join.
final RelBuilder relBuilder = call.builder();
relBuilder.push(aggregate.getInput());
// Ordinal passed to doRewrite; non-zero once a left input exists.
int n = 0;
if (!newAggCallList.isEmpty()) {
final RelBuilder.GroupKey groupKey =
relBuilder.groupKey(groupSet, aggregate.getGroupSets());
relBuilder.aggregate(groupKey, newAggCallList);
++n;
}
// For each set of operands, find and rewrite all calls which have that
// set of operands.
for (Pair<List<Integer>, Integer> argList : distinctCallArgLists) {
doRewrite(relBuilder, aggregate, n++, argList.left, argList.right, refs);
}
// It is assumed doRewrite above replaces nulls in refs
@SuppressWarnings("assignment.type.incompatible")
List<RexInputRef> nonNullRefs = refs;
relBuilder.project(nonNullRefs, fieldNames);
call.transformTo(relBuilder.build());
}
/**
* Converts an aggregate with one distinct aggregate and one or more
* non-distinct aggregates to multi-phase aggregates (see reference example
* below).
*
* <p>The bottom aggregate groups by the original group columns plus the
* distinct call's arguments and computes the non-distinct calls; the top
* aggregate groups by the original group columns and combines the partial
* results. The result is finally converted to the original row type.
*
* @param relBuilder Relational expression builder; the result is left on top
* of its stack
* @param aggregate Original aggregate
* @param argLists Arguments and filters to the distinct aggregate function;
* must contain exactly one element
* @return {@code relBuilder}
*/
private static RelBuilder convertSingletonDistinct(RelBuilder relBuilder,
Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
// In this case, we are assuming that there is a single distinct function.
// So make sure that argLists is of size one.
checkArgument(argLists.size() == 1);
// For example,
// SELECT deptno, COUNT(*), SUM(bonus), MIN(DISTINCT sal)
// FROM emp
// GROUP BY deptno
//
// becomes
//
// SELECT deptno, SUM(cnt), SUM(bonus), MIN(sal)
// FROM (
// SELECT deptno, COUNT(*) as cnt, SUM(bonus), sal
// FROM EMP
// GROUP BY deptno, sal) // Aggregate B
// GROUP BY deptno // Aggregate A
relBuilder.push(aggregate.getInput());
final List<AggregateCall> originalAggCalls = aggregate.getAggCallList();
final ImmutableBitSet originalGroupSet = aggregate.getGroupSet();
// Add the distinct aggregate column(s) to the group-by columns,
// if not already a part of the group-by
final NavigableSet<Integer> bottomGroups = new TreeSet<>(aggregate.getGroupSet().asList());
for (AggregateCall aggCall : originalAggCalls) {
if (aggCall.isDistinct()) {
bottomGroups.addAll(aggCall.getArgList());
break; // since we only have single distinct call
}
}
final ImmutableBitSet bottomGroupSet = ImmutableBitSet.of(bottomGroups);
// Generate the intermediate aggregate B, the one on the bottom that converts
// a distinct call to group by call.
// Bottom aggregate is the same as the original aggregate, except that
// the bottom aggregate has converted the DISTINCT aggregate to a group by clause.
final List<AggregateCall> bottomAggregateCalls = new ArrayList<>();
for (AggregateCall aggCall : originalAggCalls) {
// Project the column corresponding to the distinct aggregate. Project
// as-is all the non-distinct aggregates
if (!aggCall.isDistinct()) {
// The null type argument leaves the call's type to be derived by
// AggregateCall.create.
final AggregateCall newCall =
AggregateCall.create(aggCall.getAggregation(), false,
aggCall.isApproximate(), aggCall.ignoreNulls(), aggCall.rexList,
aggCall.getArgList(), -1, aggCall.distinctKeys,
aggCall.collation, bottomGroupSet.cardinality(),
relBuilder.peek(), null, aggCall.name);
bottomAggregateCalls.add(newCall);
}
}
// Generate the aggregate B (see the reference example above)
relBuilder.push(
aggregate.copy(aggregate.getTraitSet(), relBuilder.build(),
bottomGroupSet, null, bottomAggregateCalls));
// Add aggregate A (see the reference example above), the top aggregate
// to handle the rest of the aggregation that the bottom aggregate hasn't handled
final List<AggregateCall> topAggregateCalls = new ArrayList<>();
// Use the remapped arguments for the (non)distinct aggregate calls
int nonDistinctAggCallProcessedSoFar = 0;
for (AggregateCall aggCall : originalAggCalls) {
final AggregateCall newCall;
if (aggCall.isDistinct()) {
List<Integer> newArgList = new ArrayList<>();
for (int arg : aggCall.getArgList()) {
// Position of the argument among aggregate B's (sorted) group columns:
// the number of bottom group columns strictly less than it.
newArgList.add(bottomGroups.headSet(arg, false).size());
}
newCall =
AggregateCall.create(aggCall.getAggregation(),
false,
aggCall.isApproximate(),
aggCall.ignoreNulls(),
aggCall.rexList,
newArgList,
-1,
aggCall.distinctKeys,
aggCall.collation,
originalGroupSet.cardinality(),
relBuilder.peek(),
aggCall.getType(),
aggCall.name);
} else {
// If aggregate B had a COUNT aggregate call the corresponding aggregate at
// aggregate A must be SUM. For other aggregates, it remains the same.
// Aggregate B outputs its group columns first, then its aggregate calls
// in order, so this is the ordinal of the matching partial result.
final int arg = bottomGroups.size() + nonDistinctAggCallProcessedSoFar;
final List<Integer> newArgs = ImmutableList.of(arg);
if (aggCall.getAggregation().getKind() == SqlKind.COUNT) {
newCall =
AggregateCall.create(new SqlSumEmptyIsZeroAggFunction(), false,
aggCall.isApproximate(), aggCall.ignoreNulls(),
aggCall.rexList, newArgs, -1, aggCall.distinctKeys,
aggCall.collation, originalGroupSet.cardinality(),
relBuilder.peek(), null, aggCall.getName());
} else {
newCall =
AggregateCall.create(aggCall.getAggregation(), false,
aggCall.isApproximate(), aggCall.ignoreNulls(),
aggCall.rexList, newArgs, -1, aggCall.distinctKeys,
aggCall.collation, originalGroupSet.cardinality(),
relBuilder.peek(), null, aggCall.name);
}
nonDistinctAggCallProcessedSoFar++;
}
topAggregateCalls.add(newCall);
}
// Populate the group-by keys with the remapped arguments for aggregate A
// The top groupset is basically an identity (first X fields of aggregate B's
// output), minus the distinct aggCall's input.
final Set<Integer> topGroupSet = new HashSet<>();
// Ordinal of the current bottom group column in aggregate B's output.
int groupSetToAdd = 0;
for (int bottomGroup : bottomGroups) {
if (originalGroupSet.get(bottomGroup)) {
topGroupSet.add(groupSetToAdd);
}
groupSetToAdd++;
}
relBuilder.push(
aggregate.copy(aggregate.getTraitSet(), relBuilder.build(),
ImmutableBitSet.of(topGroupSet), null, topAggregateCalls));
// Add projection node for case: SUM of COUNT(*):
// Type of the SUM may be larger than type of COUNT.
// CAST to original type must be added.
relBuilder.convert(aggregate.getRowType(), true);
return relBuilder;
}
/**
* Rewrites an aggregate with distinct calls as two stacked aggregates, the
* lower of which uses grouping sets, and registers the result with
* {@code call.transformTo}.
*
* <p>Steps, as performed below:
*
* <ol>
* <li>Compute one group set per aggregate call: the original group set for
* a non-distinct call; the original group set plus the call's arguments
* (and filter argument, if any) for a distinct call.
* <li>Build a bottom aggregate over the union of those group sets, with
* those group sets as grouping sets, that computes the non-distinct calls
* and a {@code GROUPING} call named {@code $g}.
* <li>Replace {@code $g} by one boolean column per (group set, filter)
* pair, true on rows that belong to that group set (and pass that filter).
* <li>Build a top aggregate with the original grouping in which every call
* is filtered by its boolean column: non-distinct calls become {@code MIN}
* over the bottom aggregate's result, distinct calls become the same
* function, without DISTINCT, over the remapped arguments.
* <li>Convert the result to the original row type.
* </ol>
*
* @param call Rule call used to obtain the builder and register the result
* @param aggregate Original aggregate
*/
private static void rewriteUsingGroupingSets(RelOptRuleCall call,
Aggregate aggregate) {
final Set<ImmutableBitSet> groupSetTreeSet =
new TreeSet<>(ImmutableBitSet.ORDERING);
// GroupSet to distinct filter arg map,
// filterArg will be -1 for non-distinct agg call.
// Using `Set` here because it's possible that two agg calls
// have different filterArgs but same groupSet.
final Map<ImmutableBitSet, Set<Integer>> distinctFilterArgMap = new HashMap<>();
for (AggregateCall aggCall : aggregate.getAggCallList()) {
ImmutableBitSet groupSet;
int filterArg;
if (!aggCall.isDistinct()) {
filterArg = -1;
groupSet = aggregate.getGroupSet();
groupSetTreeSet.add(aggregate.getGroupSet());
} else {
filterArg = aggCall.filterArg;
groupSet =
ImmutableBitSet.of(aggCall.getArgList())
.setIf(filterArg, filterArg >= 0)
.union(aggregate.getGroupSet());
groupSetTreeSet.add(groupSet);
}
Set<Integer> filterList = distinctFilterArgMap
.computeIfAbsent(groupSet, g -> new HashSet<>());
filterList.add(filterArg);
}
final ImmutableList<ImmutableBitSet> groupSets =
ImmutableList.copyOf(groupSetTreeSet);
final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets);
// Despite its name, this list holds the calls of the bottom aggregate: the
// NON-distinct calls of the original aggregate, plus (added further down)
// the GROUPING call.
final List<AggregateCall> distinctAggCalls = new ArrayList<>();
for (Pair<AggregateCall, String> aggCall : aggregate.getNamedAggCalls()) {
if (!aggCall.left.isDistinct()) {
AggregateCall newAggCall =
aggCall.left.adaptTo(aggregate.getInput(),
aggCall.left.getArgList(), aggCall.left.filterArg,
aggregate.getGroupCount(), fullGroupSet.cardinality());
distinctAggCalls.add(newAggCall.withName(aggCall.right));
}
}
final RelBuilder relBuilder = call.builder();
relBuilder.push(aggregate.getInput());
final int groupCount = fullGroupSet.cardinality();
// Get the base ordinal of filter args for different groupSets.
// Each (group set, filter) pair is assigned the ordinal of the boolean
// column that the project below creates for it; the first such column
// takes the place of the trailing GROUPING column.
final Map<Pair<ImmutableBitSet, Integer>, Integer> filters = new LinkedHashMap<>();
int z = groupCount + distinctAggCalls.size();
for (ImmutableBitSet groupSet : groupSets) {
Set<Integer> filterArgList = distinctFilterArgMap.get(groupSet);
for (Integer filterArg : requireNonNull(filterArgList, "filterArgList")) {
filters.put(Pair.of(groupSet, filterArg), z);
z += 1;
}
}
distinctAggCalls.add(
AggregateCall.create(SqlStdOperatorTable.GROUPING, false, false, false,
ImmutableList.of(), ImmutableIntList.copyOf(fullGroupSet), -1,
null, RelCollations.EMPTY,
groupSets.size(), relBuilder.peek(), null, "$g"));
relBuilder.aggregate(
relBuilder.groupKey(fullGroupSet, groupSets),
distinctAggCalls);
// The last field ($g) holds the GROUPING value of the row. Add a project
// that replaces it with one BOOLEAN column per (group set, filter) pair,
// comparing $g with the value expected for that group set.
if (!filters.isEmpty()) {
final List<RexNode> nodes = new ArrayList<>(relBuilder.fields());
final RexNode nodeZ = nodes.remove(nodes.size() - 1);
for (Map.Entry<Pair<ImmutableBitSet, Integer>, Integer> entry : filters.entrySet()) {
final long v = groupValue(fullGroupSet.asList(), entry.getKey().left);
int distinctFilterArg = remap(fullGroupSet, entry.getKey().right);
RexNode expr = relBuilder.equals(nodeZ, relBuilder.literal(v));
if (distinctFilterArg > -1) {
// 'AND' the filter of the distinct aggregate call and the group value.
expr =
relBuilder.and(expr,
relBuilder.call(SqlStdOperatorTable.IS_TRUE,
relBuilder.field(distinctFilterArg)));
}
// "f" means filter.
nodes.add(
relBuilder.alias(expr,
"$g_" + v + (distinctFilterArg < 0 ? "" : "_f_" + distinctFilterArg)));
}
relBuilder.project(nodes);
}
// Ordinal of the next non-distinct result column in the bottom aggregate's
// output; those columns follow the group columns.
int x = groupCount;
final ImmutableBitSet groupSet = aggregate.getGroupSet();
final List<AggregateCall> newCalls = new ArrayList<>();
for (AggregateCall aggCall : aggregate.getAggCallList()) {
final int newFilterArg;
final List<Integer> newArgList;
final SqlAggFunction aggregation;
if (!aggCall.isDistinct()) {
// The bottom aggregate already computed this call; pick its value from
// the rows of the original group set using MIN.
aggregation = SqlStdOperatorTable.MIN;
newArgList = ImmutableIntList.of(x++);
newFilterArg =
requireNonNull(filters.get(Pair.of(groupSet, -1)),
"filters.get(Pair.of(groupSet, -1))");
} else {
aggregation = aggCall.getAggregation();
newArgList = remap(fullGroupSet, aggCall.getArgList());
// Same group set as computed for this call in the first loop above.
final ImmutableBitSet newGroupSet = ImmutableBitSet.of(aggCall.getArgList())
.setIf(aggCall.filterArg, aggCall.filterArg >= 0)
.union(groupSet);
newFilterArg =
requireNonNull(filters.get(Pair.of(newGroupSet, aggCall.filterArg)),
"filters.get(of(newGroupSet, aggCall.filterArg))");
}
final AggregateCall newCall =
AggregateCall.create(aggregation, false,
aggCall.isApproximate(), aggCall.ignoreNulls(),
aggCall.rexList, newArgList, newFilterArg,
aggCall.distinctKeys, aggCall.collation,
aggregate.getGroupCount(), relBuilder.peek(), null, aggCall.name);
newCalls.add(newCall);
}
relBuilder.aggregate(
relBuilder.groupKey(
remap(fullGroupSet, groupSet),
remap(fullGroupSet, aggregate.getGroupSets())),
newCalls);
relBuilder.convert(aggregate.getRowType(), true);
call.transformTo(relBuilder.build());
}
/** Returns the value that "GROUPING(fullGroupSet)" will return for
 * "groupSet".
 *
 * <p>The result has one bit per element of {@code fullGroupSet}, the first
 * element being the most significant; a bit is set if the element is
 * <em>not</em> in {@code groupSet}.
 *
 * <p>It is important that {@code fullGroupSet} is not an
 * {@link ImmutableBitSet}; the order of the bits matters.
 *
 * @param fullGroupSet Columns passed to GROUPING, in argument order
 * @param groupSet Group set to compute the value for; must be a subset of
 *                 {@code fullGroupSet}
 * @return Bit mask as described above
 */
static long groupValue(Collection<Integer> fullGroupSet,
    ImmutableBitSet groupSet) {
  assert ImmutableBitSet.of(fullGroupSet).contains(groupSet);
  long v = 0;
  for (int i : fullGroupSet) {
    // Shift the bits gathered so far and append this column's bit. Unlike a
    // mask that starts at "1L << (size - 1)" and is moved with a signed right
    // shift, this stays correct when there are exactly 64 columns (the signed
    // shift would smear the sign bit into the lower positions).
    v <<= 1;
    if (!groupSet.get(i)) {
      v |= 1L;
    }
  }
  return v;
}
/** Returns the bit set obtained by replacing every bit of {@code bitSet}
 * with its remapped ordinal relative to {@code groupSet}. */
static ImmutableBitSet remap(ImmutableBitSet groupSet,
    ImmutableBitSet bitSet) {
  final ImmutableBitSet.Builder remapped = ImmutableBitSet.builder();
  bitSet.forEach(bit -> remapped.set(remap(groupSet, bit)));
  return remapped.build();
}
/** Remaps each bit set of {@code bitSets} relative to {@code groupSet},
 * preserving their order. */
static ImmutableList<ImmutableBitSet> remap(ImmutableBitSet groupSet,
    Iterable<ImmutableBitSet> bitSets) {
  final List<ImmutableBitSet> remapped = new ArrayList<>();
  for (ImmutableBitSet original : bitSets) {
    remapped.add(remap(groupSet, original));
  }
  return ImmutableList.copyOf(remapped);
}
/**
 * Remaps each ordinal of {@code argList} relative to {@code groupSet},
 * preserving order and duplicates.
 *
 * @param groupSet Bit set the ordinals are resolved against
 * @param argList Ordinals to translate; negative ordinals map to -1
 * @return Immutable list of remapped ordinals, same length as {@code argList}
 */
private static List<Integer> remap(ImmutableBitSet groupSet,
    List<Integer> argList) {
  // Fill one array and wrap it once, instead of appending to an immutable
  // list element by element (each append creates a new list).
  final int[] remapped = new int[argList.size()];
  int next = 0;
  for (int arg : argList) {
    remapped[next++] = remap(groupSet, arg);
  }
  return ImmutableIntList.of(remapped);
}
/** Returns {@code groupSet.indexOf(arg)}, or -1 if {@code arg} is
 * negative. */
private static int remap(ImmutableBitSet groupSet, int arg) {
  if (arg >= 0) {
    return groupSet.indexOf(arg);
  }
  return -1;
}
/**
 * Converts an aggregate relational expression whose distinct aggregate
 * functions all have the same arguments and which has no aggregate function
 * that needs non-distinct input.
 *
 * <p>For example,
 *
 * <blockquote><pre>
 * SELECT deptno, COUNT(DISTINCT sal), SUM(DISTINCT sal)
 * FROM emp
 * GROUP BY deptno</pre></blockquote>
 *
 * <p>becomes
 *
 * <blockquote><pre>
 * SELECT deptno, COUNT(distinct_sal), SUM(distinct_sal)
 * FROM (
 *   SELECT DISTINCT deptno, sal AS distinct_sal
 *   FROM emp)
 * GROUP BY deptno</pre></blockquote>
 *
 * @param relBuilder Relational expression builder; the result is left on top
 *                   of its stack
 * @param aggregate Original aggregate
 * @param argList Arguments shared by the aggregate functions to rewrite
 * @param filterArg Ordinal of the filter column, or -1
 * @return {@code relBuilder}
 */
private static RelBuilder convertMonopole(RelBuilder relBuilder, Aggregate aggregate,
    List<Integer> argList, int filterArg) {
  // Bottom: the distinct combinations of the GROUP BY columns and the
  // arguments. The map records where each input column landed.
  final Map<Integer, Integer> outputOrdinals = new HashMap<>();
  createSelectDistinct(relBuilder, aggregate, argList, filterArg,
      outputOrdinals);

  // Top: the original calls, re-pointed at the bottom's columns.
  final List<AggregateCall> rewrittenCalls =
      new ArrayList<>(aggregate.getAggCallList());
  rewriteAggCalls(rewrittenCalls, argList, outputOrdinals);

  // The GROUP BY columns are the leading columns of the bottom's output.
  final ImmutableBitSet topGroupSet =
      ImmutableBitSet.range(aggregate.getGroupSet().cardinality());
  relBuilder.push(
      aggregate.copy(aggregate.getTraitSet(), relBuilder.build(),
          topGroupSet, null, rewrittenCalls));
  return relBuilder;
}
/**
* Converts all distinct aggregate calls to a given set of arguments.
*
* <p>This method is called several times, one for each set of arguments.
* Each time it is called, it generates a JOIN to a new SELECT DISTINCT
* relational expression, and modifies the set of top-level calls.
*
* <p>On return, the top of {@code relBuilder}'s stack is either the new
* aggregate alone (if {@code n} is 0) or the join of the previous top with
* the new aggregate.
*
* @param relBuilder Relational expression builder. Unless {@code n} is 0, its
* top is the left input of the join: either the aggregate
* of the non-distinct calls or the output of the previous
* call to this method
* @param aggregate Original aggregate
* @param n Ordinal of this in a join. {@code n} is 0 if we're
* converting the first distinct aggregate in a query with
* no non-distinct aggregates, in which case no join is
* generated
* @param argList Arguments to the distinct aggregate function
* @param filterArg Argument that filters input to aggregate function, or -1
* @param refs Array of expressions which will be the projected by the
* result of this rule. Those relating to this arg list will
* be modified
*/
private static void doRewrite(RelBuilder relBuilder, Aggregate aggregate, int n,
List<Integer> argList, int filterArg, List<@Nullable RexInputRef> refs) {
final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
// Fields of the left input of the join, or null if there is no left input.
final List<RelDataTypeField> leftFields;
if (n == 0) {
leftFields = null;
} else {
leftFields = relBuilder.peek().getRowType().getFieldList();
}
// Aggregate(
// child,
// {COUNT(DISTINCT 1), SUM(DISTINCT 1), SUM(2)})
//
// becomes
//
// Aggregate(
// Join(
// child,
// Aggregate(child, < all columns > {}),
// INNER,
// <f2 = f5>))
//
// E.g.
// SELECT deptno, SUM(DISTINCT sal), COUNT(DISTINCT gender), MAX(age)
// FROM Emps
// GROUP BY deptno
//
// becomes
//
// SELECT e.deptno, adsal.sum_sal, adgender.count_gender, e.max_age
// FROM (
// SELECT deptno, MAX(age) as max_age
// FROM Emps GROUP BY deptno) AS e
// JOIN (
// SELECT deptno, COUNT(gender) AS count_gender FROM (
// SELECT DISTINCT deptno, gender FROM Emps) AS dgender
// GROUP BY deptno) AS adgender
// ON e.deptno = adgender.deptno
// JOIN (
// SELECT deptno, SUM(sal) AS sum_sal FROM (
// SELECT DISTINCT deptno, sal FROM Emps) AS dsal
// GROUP BY deptno) AS adsal
// ON e.deptno = adsal.deptno
// GROUP BY e.deptno
//
// Note that if a query contains no non-distinct aggregates, then the
// very first join/group by is omitted. In the example above, if
// MAX(age) is removed, then the sub-select of "e" is not needed, and
// instead the two other group by's are joined to one another.
// Project the columns of the GROUP BY plus the arguments
// to the agg function.
final Map<Integer, Integer> sourceOf = new HashMap<>();
createSelectDistinct(relBuilder, aggregate, argList, filterArg, sourceOf);
// Now compute the aggregate functions on top of the distinct dataset.
// Each distinct agg becomes a non-distinct call to the corresponding
// field from the right; for example,
// "COUNT(DISTINCT e.sal)"
// becomes
// "COUNT(distinct_e.sal)".
final List<AggregateCall> aggCallList = new ArrayList<>();
final List<AggregateCall> aggCalls = aggregate.getAggCallList();
final int groupCount = aggregate.getGroupCount();
// Index into refs of the current call: group columns come first.
int i = groupCount - 1;
for (AggregateCall aggCall : aggCalls) {
++i;
// Ignore agg calls which are not distinct or have the wrong set
// arguments. If we're rewriting aggs whose args are {sal}, we will
// rewrite COUNT(DISTINCT sal) and SUM(DISTINCT sal) but ignore
// COUNT(DISTINCT gender) or SUM(sal).
if (!aggCall.isDistinct()) {
continue;
}
// NOTE(review): calls are matched on argList only, whereas the caller
// invokes this method once per (argList, filterArg) pair. Confirm the
// intended behavior when two distinct calls share arguments but have
// different filters.
if (!aggCall.getArgList().equals(argList)) {
continue;
}
// Re-map arguments.
final int argCount = aggCall.getArgList().size();
final List<Integer> newArgs = new ArrayList<>(argCount);
for (Integer arg : aggCall.getArgList()) {
newArgs.add(requireNonNull(sourceOf.get(arg), () -> "sourceOf.get(" + arg + ")"));
}
// NOTE(review): this requires the filter column to be present in sourceOf,
// which createSelectDistinct only populates for group columns and
// arguments -- verify for filtered distinct calls.
final int newFilterArg =
aggCall.filterArg < 0 ? -1
: requireNonNull(sourceOf.get(aggCall.filterArg),
() -> "sourceOf.get(" + aggCall.filterArg + ")");
final AggregateCall newAggCall =
AggregateCall.create(aggCall.getAggregation(), false,
aggCall.isApproximate(), aggCall.ignoreNulls(), aggCall.rexList,
newArgs, newFilterArg, aggCall.distinctKeys, aggCall.collation,
aggCall.getType(), aggCall.getName());
assert refs.get(i) == null;
// Point the final projection at the new call's output: it follows the
// group columns of the new aggregate, which in turn follow all fields of
// the left input when there is a join.
if (leftFields == null) {
refs.set(i,
new RexInputRef(groupCount + aggCallList.size(),
newAggCall.getType()));
} else {
refs.set(i,
new RexInputRef(leftFields.size() + groupCount
+ aggCallList.size(), newAggCall.getType()));
}
aggCallList.add(newAggCall);
}
// Map each group column to its position in the SELECT DISTINCT output,
// i.e. 0, 1, 2, ... in group-set order.
final Map<Integer, Integer> map = new HashMap<>();
for (Integer key : aggregate.getGroupSet()) {
map.put(key, map.size());
}
final ImmutableBitSet newGroupSet = aggregate.getGroupSet().permute(map);
assert newGroupSet
.equals(ImmutableBitSet.range(aggregate.getGroupSet().cardinality()));
ImmutableList<ImmutableBitSet> newGroupingSets = null;
relBuilder.push(
aggregate.copy(aggregate.getTraitSet(), relBuilder.build(),
newGroupSet, newGroupingSets, aggCallList));
// If there's no left child yet, no need to create the join
if (leftFields == null) {
return;
}
// Create the join condition. It is of the form
// 'left.f0 = right.f0 and left.f1 = right.f1 and ...'
// where {f0, f1, ...} are the GROUP BY fields.
final List<RelDataTypeField> distinctFields =
relBuilder.peek().getRowType().getFieldList();
final List<RexNode> conditions = new ArrayList<>();
for (i = 0; i < groupCount; ++i) {
// null values form its own group
// use "is not distinct from" so that the join condition
// allows null values to match.
conditions.add(
rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_DISTINCT_FROM,
RexInputRef.of(i, leftFields),
new RexInputRef(leftFields.size() + i,
distinctFields.get(i).getType())));
}
// Join in the new 'select distinct' relation.
relBuilder.join(JoinRelType.INNER, conditions);
}
/**
 * Rewrites, in place, the aggregate calls that operate on {@code argList}.
 *
 * <p>A call is rewritten if it is distinct, or if its function ignores
 * DISTINCT, and its arguments equal {@code argList}. The replacement is a
 * non-distinct, unfiltered call to the same function whose arguments point
 * at the columns given by {@code sourceOf}; for example,
 * "COUNT(DISTINCT e.sal)" becomes "COUNT(distinct_e.sal)". All other calls
 * (such as COUNT(DISTINCT gender) or SUM(sal) when rewriting {sal}) are left
 * untouched.
 *
 * @param newAggCalls Calls to rewrite; modified in place
 * @param argList Arguments of the calls to rewrite
 * @param sourceOf Map from input ordinal to ordinal in the distinct input
 */
private static void rewriteAggCalls(
    List<AggregateCall> newAggCalls,
    List<Integer> argList,
    Map<Integer, Integer> sourceOf) {
  final int callCount = newAggCalls.size();
  for (int ordinal = 0; ordinal < callCount; ordinal++) {
    final AggregateCall original = newAggCalls.get(ordinal);
    final boolean treatedAsDistinct = original.isDistinct()
        || original.getAggregation().getDistinctOptionality() == Optionality.IGNORED;
    if (treatedAsDistinct && original.getArgList().equals(argList)) {
      // Re-map arguments.
      final List<Integer> remappedArgs =
          new ArrayList<>(original.getArgList().size());
      for (Integer arg : original.getArgList()) {
        remappedArgs.add(
            requireNonNull(sourceOf.get(arg),
                () -> "sourceOf.get(" + arg + ")"));
      }
      newAggCalls.set(ordinal,
          AggregateCall.create(original.getAggregation(), false,
              original.isApproximate(), original.ignoreNulls(),
              original.rexList, remappedArgs, -1,
              original.distinctKeys, original.collation,
              original.getType(), original.getName()));
    }
  }
}
/**
 * Given an {@link org.apache.calcite.rel.core.Aggregate}
 * and the ordinals of the arguments to a
 * particular call to an aggregate function, pushes a 'select distinct'
 * relational expression which projects the group columns and those
 * arguments but nothing else.
 *
 * <p>For example, given
 *
 * <blockquote>
 * <pre>select f0, count(distinct f1), count(distinct f2)
 * from t group by f0</pre>
 * </blockquote>
 *
 * <p>and the argument list
 *
 * <blockquote>{2}</blockquote>
 *
 * <p>the result is
 *
 * <blockquote>
 * <pre>select distinct f0, f2 from t</pre>
 * </blockquote>
 *
 * <p>The <code>sourceOf</code> map is populated with the output position of
 * each projected input column; in this case sourceOf.get(0) = 0, and
 * sourceOf.get(2) = 1.
 *
 * <p>If there is a filter, each argument is projected as
 * {@code CASE WHEN filter THEN arg ELSE NULL END} rather than as a plain
 * column reference.
 *
 * @param relBuilder Relational expression builder
 * @param aggregate Aggregate relational expression
 * @param argList   Ordinals of columns to make distinct
 * @param filterArg Ordinal of column to filter on, or -1
 * @param sourceOf  Out parameter, is populated with a map from input column
 *                  ordinal to the ordinal of the output field holding it
 * @return {@code relBuilder}, with the aggregate that projects the required
 * columns on top of its stack
 */
private static RelBuilder createSelectDistinct(RelBuilder relBuilder,
    Aggregate aggregate, List<Integer> argList, int filterArg,
    Map<Integer, Integer> sourceOf) {
  relBuilder.push(aggregate.getInput());
  final List<RelDataTypeField> inputFields =
      relBuilder.peek().getRowType().getFieldList();
  final PairList<RexNode, String> projection = PairList.of();

  // Group columns come first, in group-set order.
  for (int groupColumn : aggregate.getGroupSet()) {
    sourceOf.put(groupColumn, projection.size());
    RexInputRef.add2(projection, groupColumn, inputFields);
  }

  final boolean filtered = filterArg >= 0;
  for (Integer arg : argList) {
    if (filtered) {
      // Implement
      //   agg(DISTINCT arg) FILTER $f
      // by generating
      //   SELECT DISTINCT ... CASE WHEN $f THEN arg ELSE NULL END AS arg
      // and then applying
      //   agg(arg)
      // as usual.
      //
      // It works except for (rare) agg functions that need to see null
      // values.
      final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
      final RexInputRef filterRef = RexInputRef.of(filterArg, inputFields);
      final Pair<RexNode, String> argRef = RexInputRef.of2(arg, inputFields);
      final RexNode guardedArg =
          rexBuilder.makeCall(SqlStdOperatorTable.CASE, filterRef,
              argRef.left,
              rexBuilder.makeNullLiteral(argRef.left.getType()));
      sourceOf.put(arg, projection.size());
      projection.add(guardedArg, "i$" + argRef.right);
    } else if (sourceOf.putIfAbsent(arg, projection.size()) == null) {
      // First occurrence of this column: project it. Columns that are
      // already projected (e.g. group columns) are not repeated.
      RexInputRef.add2(projection, arg, inputFields);
    }
  }
  relBuilder.project(projection.leftList(), projection.rightList());

  // Get the distinct values of the GROUP BY fields and the arguments
  // to the agg functions: group by every projected column, with no calls.
  relBuilder.push(
      aggregate.copy(aggregate.getTraitSet(), relBuilder.build(),
          ImmutableBitSet.range(projection.size()), null, ImmutableList.of()));
  return relBuilder;
}
/** Rule configuration. */
@Value.Immutable
public interface Config extends RelRule.Config {
/** Default configuration: matches a {@link LogicalAggregate} with any
* inputs and leaves {@link #isUsingGroupingSets()} at its default. */
Config DEFAULT = ImmutableAggregateExpandDistinctAggregatesRule.Config.of()
.withOperandSupplier(b ->
b.operand(LogicalAggregate.class).anyInputs());
/** Same as {@link #DEFAULT} but with {@link #isUsingGroupingSets()} set to
* false. */
Config JOIN = DEFAULT.withUsingGroupingSets(false);
/** Creates the rule described by this configuration. */
@Override default AggregateExpandDistinctAggregatesRule toRule() {
return new AggregateExpandDistinctAggregatesRule(this);
}
/** Whether to use GROUPING SETS, default true. */
@Value.Default default boolean isUsingGroupingSets() {
return true;
}
/** Sets {@link #isUsingGroupingSets()}. */
Config withUsingGroupingSets(boolean usingGroupingSets);
}
}