| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to you under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.flink.table.plan.rules.logical; |
| |
| import org.apache.flink.table.api.TableException; |
| import org.apache.flink.table.errorcode.TableErrors; |
| import org.apache.flink.util.Preconditions; |
| |
| import org.apache.calcite.linq4j.Ord; |
| import org.apache.calcite.plan.Contexts; |
| import org.apache.calcite.plan.RelOptRule; |
| import org.apache.calcite.plan.RelOptRuleCall; |
| import org.apache.calcite.rel.RelNode; |
| import org.apache.calcite.rel.core.Aggregate; |
| import org.apache.calcite.rel.core.Aggregate.Group; |
| import org.apache.calcite.rel.core.AggregateCall; |
| import org.apache.calcite.rel.core.JoinRelType; |
| import org.apache.calcite.rel.core.RelFactories; |
| import org.apache.calcite.rel.logical.LogicalAggregate; |
| import org.apache.calcite.rel.type.RelDataTypeField; |
| import org.apache.calcite.rex.RexBuilder; |
| import org.apache.calcite.rex.RexInputRef; |
| import org.apache.calcite.rex.RexNode; |
| import org.apache.calcite.sql.SqlAggFunction; |
| import org.apache.calcite.sql.SqlKind; |
| import org.apache.calcite.sql.fun.SqlStdOperatorTable; |
| import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction; |
| import org.apache.calcite.tools.RelBuilder; |
| import org.apache.calcite.tools.RelBuilderFactory; |
| import org.apache.calcite.util.ImmutableBitSet; |
| import org.apache.calcite.util.ImmutableIntList; |
| import org.apache.calcite.util.Pair; |
| import org.apache.calcite.util.Util; |
| |
| import java.math.BigDecimal; |
| import java.util.ArrayList; |
| import java.util.HashMap; |
| import java.util.HashSet; |
| import java.util.LinkedHashMap; |
| import java.util.LinkedHashSet; |
| import java.util.List; |
| import java.util.Map; |
| import java.util.Set; |
| import java.util.SortedSet; |
| import java.util.TreeSet; |
| |
| /** |
| * This rule is copied from Calcite's {@link org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule}. |
| * Modifications: |
| * - Throws an exception if the aggregate contains both an approximate distinct aggregate call and |
| * an accurate distinct aggregate call. |
| * - Excludes non-simple aggregates (e.g. CUBE, ROLLUP). |
| * - Fix bug: Some aggregate functions (e.g. COUNT) have a non-null result even without any input. |
| * - Fix bug: Add the filter argument to the rewritten aggregate call if the filter arg is not -1. |
| */ |
| |
| /** |
| * Planner rule that expands distinct aggregates |
| * (such as {@code COUNT(DISTINCT x)}) from a |
| * {@link org.apache.calcite.rel.core.Aggregate}. |
| * |
| * <p>How this is done depends upon the arguments to the function. If all |
| * functions have the same argument |
| * (e.g. {@code COUNT(DISTINCT x), SUM(DISTINCT x)} both have the argument |
| * {@code x}) then one extra {@link org.apache.calcite.rel.core.Aggregate} is |
| * sufficient. |
| * |
| * <p>If there are multiple arguments |
| * (e.g. {@code COUNT(DISTINCT x), COUNT(DISTINCT y)}) |
| * the rule creates separate {@code Aggregate}s and combines using a |
| * {@link org.apache.calcite.rel.core.Join}. |
| */ |
| public final class FlinkAggregateExpandDistinctAggregatesRule extends RelOptRule { |
| //~ Static fields/initializers --------------------------------------------- |
| |
| /** The default instance of the rule; matches {@link LogicalAggregate} and is created with {@code useGroupingSets = true}. */ |
| public static final FlinkAggregateExpandDistinctAggregatesRule INSTANCE = |
| new FlinkAggregateExpandDistinctAggregatesRule(LogicalAggregate.class, true, |
| RelFactories.LOGICAL_BUILDER); |
| |
| /** Instance of the rule that matches {@link LogicalAggregate}, is created with |
| * {@code useGroupingSets = false} and therefore generates a join. */ |
| public static final FlinkAggregateExpandDistinctAggregatesRule JOIN = |
| new FlinkAggregateExpandDistinctAggregatesRule(LogicalAggregate.class, false, |
| RelFactories.LOGICAL_BUILDER); |
| /** |
| * Selects the general rewrite strategy used by {@code onMatch}: when {@code true} the aggregate |
| * is rewritten by {@code rewriteUsingGroupingSets}; when {@code false} the multi-phase or the |
| * join-based rewrite is used instead. |
| */ |
| public final boolean useGroupingSets; |
| |
| //~ Constructors ----------------------------------------------------------- |
| /** |
| * Creates a rule whose operand matches aggregates of the given class with any inputs. |
| * |
| * @param clazz Class of {@link Aggregate} to match |
| * @param useGroupingSets Value stored in {@link #useGroupingSets} |
| * @param relBuilderFactory Builder factory passed to the {@link RelOptRule} constructor |
| * (the third super-constructor argument is {@code null}) |
| */ |
| public FlinkAggregateExpandDistinctAggregatesRule( |
| Class<? extends Aggregate> clazz, |
| boolean useGroupingSets, |
| RelBuilderFactory relBuilderFactory) { |
| super(operand(clazz, any()), relBuilderFactory, null); |
| this.useGroupingSets = useGroupingSets; |
| } |
| /** |
| * Creates a rule from a join factory; the factory is wrapped with |
| * {@code RelBuilder.proto(Contexts.of(joinFactory))} and passed to the primary constructor. |
| * |
| * @param clazz Class of {@link LogicalAggregate} to match |
| * @param useGroupingSets Value stored in {@link #useGroupingSets} |
| * @param joinFactory Join factory made available to the rule's {@link RelBuilder} |
| * @deprecated Use the constructor that takes a {@link RelBuilderFactory}. |
| */ |
| @Deprecated // to be removed before 2.0 |
| public FlinkAggregateExpandDistinctAggregatesRule( |
| Class<? extends LogicalAggregate> clazz, |
| boolean useGroupingSets, |
| RelFactories.JoinFactory joinFactory) { |
| this(clazz, useGroupingSets, RelBuilder.proto(Contexts.of(joinFactory))); |
| } |
| /** |
| * Creates a rule from a join factory with {@code useGroupingSets} fixed to {@code false}. |
| * |
| * @param clazz Class of {@link LogicalAggregate} to match |
| * @param joinFactory Join factory made available to the rule's {@link RelBuilder} |
| * @deprecated Use the constructor that takes a {@link RelBuilderFactory}. |
| */ |
| @Deprecated // to be removed before 2.0 |
| public FlinkAggregateExpandDistinctAggregatesRule( |
| Class<? extends LogicalAggregate> clazz, |
| RelFactories.JoinFactory joinFactory) { |
| this(clazz, false, RelBuilder.proto(Contexts.of(joinFactory))); |
| } |
| |
| //~ Methods ---------------------------------------------------------------- |
| |
| public void onMatch(RelOptRuleCall call) { |
| final Aggregate aggregate = call.rel(0); |
| if (!aggregate.containsAccurateDistinctCall()) { |
| return; |
| } |
| // Check unsupported aggregate which contains both approximate distinct call and |
| // accurate distinct call. |
| if (aggregate.containsApproximateDistinctCall()) { |
| throw new TableException(TableErrors.INST.sqlDistinctConflict()); |
| } |
| |
| // If this aggregate is a non-simple aggregate(e.g. CUBE, ROLLUP) |
| // and contains distinct calls, it should be transformed to simple aggregate first |
| // by DecomposeGroupingSetsRule. Then this rule expands it's distinct aggregates. |
| if (aggregate.getGroupSets().size() > 1) { |
| return; |
| } |
| |
| // Find all of the agg expressions. We use a LinkedHashSet to ensure determinism. |
| int nonDistinctAggCallCount = 0; // find all aggregate calls without distinct |
| int filterCount = 0; |
| int unsupportedNonDistinctAggCallCount = 0; |
| final Set<Pair<List<Integer>, Integer>> argLists = new LinkedHashSet<>(); |
| for (AggregateCall aggCall : aggregate.getAggCallList()) { |
| if (aggCall.filterArg >= 0) { |
| ++filterCount; |
| } |
| if (!aggCall.isDistinct()) { |
| ++nonDistinctAggCallCount; |
| final SqlKind aggCallKind = aggCall.getAggregation().getKind(); |
| // We only support COUNT/SUM/MIN/MAX for the "single" count distinct optimization |
| switch (aggCallKind) { |
| case COUNT: |
| case SUM: |
| case SUM0: |
| case MIN: |
| case MAX: |
| break; |
| default: |
| ++unsupportedNonDistinctAggCallCount; |
| } |
| } else { |
| argLists.add(Pair.of(aggCall.getArgList(), aggCall.filterArg)); |
| } |
| } |
| |
| final int distinctAggCallCount = |
| aggregate.getAggCallList().size() - nonDistinctAggCallCount; |
| Preconditions.checkState(argLists.size() > 0, "containsDistinctCall lied"); |
| |
| // If all of the agg expressions are distinct and have the same |
| // arguments then we can use a more efficient form. |
| if (nonDistinctAggCallCount == 0 |
| && argLists.size() == 1 |
| && aggregate.getGroupType() == Group.SIMPLE) { |
| final Pair<List<Integer>, Integer> pair = |
| com.google.common.collect.Iterables.getOnlyElement(argLists); |
| final RelBuilder relBuilder = call.builder(); |
| convertMonopole(relBuilder, aggregate, pair.left, pair.right); |
| call.transformTo(relBuilder.build()); |
| return; |
| } |
| |
| if (useGroupingSets) { |
| rewriteUsingGroupingSets(call, aggregate); |
| return; |
| } |
| |
| // If only one distinct aggregate and one or more non-distinct aggregates, |
| // we can generate multi-phase aggregates |
| if (distinctAggCallCount == 1 // one distinct aggregate |
| && filterCount == 0 // no filter |
| && unsupportedNonDistinctAggCallCount == 0 // sum/min/max/count in non-distinct aggregate |
| && nonDistinctAggCallCount > 0) { // one or more non-distinct aggregates |
| final RelBuilder relBuilder = call.builder(); |
| convertSingletonDistinct(relBuilder, aggregate, argLists); |
| call.transformTo(relBuilder.build()); |
| return; |
| } |
| |
| // Create a list of the expressions which will yield the final result. |
| // Initially, the expressions point to the input field. |
| final List<RelDataTypeField> aggFields = |
| aggregate.getRowType().getFieldList(); |
| final List<RexInputRef> refs = new ArrayList<>(); |
| final List<String> fieldNames = aggregate.getRowType().getFieldNames(); |
| final ImmutableBitSet groupSet = aggregate.getGroupSet(); |
| final int groupAndIndicatorCount = |
| aggregate.getGroupCount() + aggregate.getIndicatorCount(); |
| for (int i : Util.range(groupAndIndicatorCount)) { |
| refs.add(RexInputRef.of(i, aggFields)); |
| } |
| |
| // Aggregate the original relation, including any non-distinct aggregates. |
| final List<AggregateCall> newAggCallList = new ArrayList<>(); |
| int i = -1; |
| for (AggregateCall aggCall : aggregate.getAggCallList()) { |
| ++i; |
| if (aggCall.isDistinct()) { |
| refs.add(null); |
| continue; |
| } |
| refs.add( |
| new RexInputRef( |
| groupAndIndicatorCount + newAggCallList.size(), |
| aggFields.get(groupAndIndicatorCount + i).getType())); |
| newAggCallList.add(aggCall); |
| } |
| |
| // In the case where there are no non-distinct aggregates (regardless of |
| // whether there are group bys), there's no need to generate the |
| // extra aggregate and join. |
| final RelBuilder relBuilder = call.builder(); |
| relBuilder.push(aggregate.getInput()); |
| int n = 0; |
| if (!newAggCallList.isEmpty()) { |
| final RelBuilder.GroupKey groupKey = |
| relBuilder.groupKey(groupSet, aggregate.getGroupSets()); |
| relBuilder.aggregate(groupKey, newAggCallList); |
| ++n; |
| } |
| |
| // For each set of operands, find and rewrite all calls which have that |
| // set of operands. |
| for (Pair<List<Integer>, Integer> argList : argLists) { |
| doRewrite(relBuilder, aggregate, n++, argList.left, argList.right, refs); |
| } |
| |
| relBuilder.project(refs, fieldNames); |
| call.transformTo(relBuilder.build()); |
| } |
| |
| /** |
| * Rewrites an aggregate that has a single distinct aggregate call plus one or more |
| * non-distinct calls into two stacked aggregates (the example inside the method shows |
| * the resulting shape). |
| * |
| * @param relBuilder Builder used to assemble the plan; the rewritten expression is left on it |
| * @param aggregate Original aggregate |
| * @param argLists Arguments and filters of the distinct aggregate function; must hold |
| * exactly one entry |
| * @return the {@code relBuilder} that was passed in |
| */ |
| private RelBuilder convertSingletonDistinct(RelBuilder relBuilder, |
| Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) { |
| |
| // The rewrite below assumes a single distinct function, hence a single entry. |
| Preconditions.checkArgument(argLists.size() == 1); |
| |
| // Example: |
| // SELECT deptno, COUNT(*), SUM(bonus), MIN(DISTINCT sal) |
| // FROM emp |
| // GROUP BY deptno |
| // |
| // is turned into |
| // |
| // SELECT deptno, SUM(cnt), SUM(bonus), MIN(sal) |
| // FROM ( |
| // SELECT deptno, COUNT(*) as cnt, SUM(bonus), sal |
| // FROM EMP |
| // GROUP BY deptno, sal) // bottom aggregate |
| // GROUP BY deptno // top aggregate |
| final List<AggregateCall> aggCalls = aggregate.getAggCallList(); |
| final ImmutableBitSet inputGroupSet = aggregate.getGroupSet(); |
| |
| relBuilder.push(aggregate.getInput()); |
| final RelNode input = relBuilder.peek(); |
| |
| // Keys of the bottom aggregate: the original keys plus the arguments of the |
| // (first and only) distinct call, kept sorted by input ordinal. |
| final SortedSet<Integer> bottomKeys = new TreeSet<>(aggregate.getGroupSet().asList()); |
| for (AggregateCall aggCall : aggCalls) { |
| if (aggCall.isDistinct()) { |
| bottomKeys.addAll(aggCall.getArgList()); |
| break; |
| } |
| } |
| final ImmutableBitSet bottomGroupSet = ImmutableBitSet.of(bottomKeys); |
| final int bottomGroupCount = bottomGroupSet.cardinality(); |
| |
| // Bottom aggregate: every non-distinct call is re-created without filter on the input; |
| // the distinct call is represented by its arguments being part of the keys. |
| final List<AggregateCall> bottomCalls = new ArrayList<>(); |
| for (AggregateCall aggCall : aggCalls) { |
| if (aggCall.isDistinct()) { |
| continue; |
| } |
| bottomCalls.add( |
| AggregateCall.create(aggCall.getAggregation(), false, |
| aggCall.isApproximate(), aggCall.getArgList(), -1, |
| bottomGroupCount, input, null, aggCall.name)); |
| } |
| relBuilder.push( |
| aggregate.copy( |
| aggregate.getTraitSet(), relBuilder.build(), |
| false, bottomGroupSet, null, bottomCalls)); |
| final RelNode bottomAggregate = relBuilder.peek(); |
| |
| // Top aggregate: one call per original call, with arguments remapped to the |
| // output fields of the bottom aggregate. |
| final int topGroupCount = inputGroupSet.cardinality(); |
| final List<AggregateCall> topCalls = new ArrayList<>(); |
| int nonDistinctSeen = 0; |
| for (AggregateCall aggCall : aggCalls) { |
| final SqlAggFunction function; |
| final List<Integer> args = new ArrayList<>(); |
| if (aggCall.isDistinct()) { |
| // An argument's new ordinal is its rank among the bottom keys. |
| function = aggCall.getAggregation(); |
| for (int arg : aggCall.getArgList()) { |
| args.add(bottomKeys.headSet(arg).size()); |
| } |
| } else { |
| // A COUNT computed by the bottom aggregate must be summed up by the top |
| // aggregate; every other function is applied again unchanged. |
| function = aggCall.getAggregation().getKind() == SqlKind.COUNT |
| ? new SqlSumEmptyIsZeroAggFunction() |
| : aggCall.getAggregation(); |
| args.add(bottomKeys.size() + nonDistinctSeen); |
| nonDistinctSeen++; |
| } |
| topCalls.add( |
| AggregateCall.create(function, false, aggCall.isApproximate(), args, -1, |
| topGroupCount, bottomAggregate, aggCall.getType(), aggCall.name)); |
| } |
| |
| // Keys of the top aggregate: the positions, within the bottom keys, of those keys |
| // that belong to the original group set. |
| final Set<Integer> topKeys = new HashSet<>(); |
| int position = 0; |
| for (int key : bottomKeys) { |
| if (inputGroupSet.get(key)) { |
| topKeys.add(position); |
| } |
| position++; |
| } |
| relBuilder.push( |
| aggregate.copy(aggregate.getTraitSet(), |
| relBuilder.build(), aggregate.indicator, |
| ImmutableBitSet.of(topKeys), null, topCalls)); |
| return relBuilder; |
| } |
| |
| /** |
| * Rewrites the aggregate as two stacked aggregates. The lower aggregate groups by one |
| * grouping set per distinct (arguments, filter) combination plus the original group set and |
| * evaluates the non-distinct calls together with a {@code GROUPING} call; the upper aggregate |
| * evaluates every original call with a filter that selects the rows of its grouping set. |
| * The result is handed to {@code call.transformTo}. |
| * |
| * @param call Rule call that supplies the builder and receives the result |
| * @param aggregate Original aggregate |
| */ |
| private void rewriteUsingGroupingSets(RelOptRuleCall call, |
| Aggregate aggregate) { |
| final ImmutableBitSet originalGroupSet = aggregate.getGroupSet(); |
| final List<AggregateCall> originalCalls = aggregate.getAggCallList(); |
| |
| // One grouping set per call, de-duplicated and sorted. |
| final Set<ImmutableBitSet> sortedGroupSets = |
| new TreeSet<>(ImmutableBitSet.ORDERING); |
| for (AggregateCall aggCall : originalCalls) { |
| sortedGroupSets.add( |
| aggCall.isDistinct() ? distinctGroupSet(aggregate, aggCall) : originalGroupSet); |
| } |
| final com.google.common.collect.ImmutableList<ImmutableBitSet> groupSets = |
| com.google.common.collect.ImmutableList.copyOf(sortedGroupSets); |
| final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets); |
| final int groupCount = fullGroupSet.cardinality(); |
| |
| // Calls of the lower aggregate: the non-distinct calls, adapted to the full group set. |
| final List<AggregateCall> lowerCalls = new ArrayList<>(); |
| for (Pair<AggregateCall, String> named : aggregate.getNamedAggCalls()) { |
| final AggregateCall aggCall = named.left; |
| if (aggCall.isDistinct()) { |
| continue; |
| } |
| final AggregateCall adapted = aggCall.adaptTo( |
| aggregate.getInput(), |
| aggCall.getArgList(), |
| aggCall.filterArg, |
| aggregate.getGroupCount(), |
| groupCount); |
| lowerCalls.add(adapted.rename(named.right)); |
| } |
| |
| final RelBuilder relBuilder = call.builder(); |
| relBuilder.push(aggregate.getInput()); |
| |
| // The GROUPING call is appended last; firstFilterField is its ordinal in the lower |
| // aggregate's output and, after the projection below, the ordinal of the first filter. |
| final int firstFilterField = groupCount + lowerCalls.size(); |
| lowerCalls.add( |
| AggregateCall.create(SqlStdOperatorTable.GROUPING, false, false, |
| ImmutableIntList.copyOf(fullGroupSet), -1, groupSets.size(), |
| relBuilder.peek(), null, "$g")); |
| final Map<ImmutableBitSet, Integer> filters = new LinkedHashMap<>(); |
| for (Ord<ImmutableBitSet> indexed : Ord.zip(groupSets)) { |
| filters.put(indexed.e, firstFilterField + indexed.i); |
| } |
| |
| relBuilder.aggregate(relBuilder.groupKey(fullGroupSet, groupSets), |
| lowerCalls); |
| final RelNode distinct = relBuilder.peek(); |
| |
| // GROUPING yields an integer. Replace that field by one BOOLEAN field per grouping set |
| // which compares it with the value expected for the grouping set. |
| if (!filters.isEmpty()) { |
| final List<RexNode> projects = new ArrayList<>(relBuilder.fields()); |
| final RexNode groupingRef = projects.remove(projects.size() - 1); |
| for (ImmutableBitSet groupSet : filters.keySet()) { |
| final long value = groupValue(fullGroupSet, groupSet); |
| projects.add( |
| relBuilder.alias( |
| relBuilder.equals(groupingRef, relBuilder.literal(value)), |
| "$g_" + value)); |
| } |
| relBuilder.project(projects); |
| } |
| |
| // Calls of the upper aggregate. |
| // TODO supports more aggCalls (currently only supports COUNT) |
| // Some aggregate functions (e.g. COUNT) can return a non-null result without any input; |
| // the indexes of such calls are remembered so that a default value can be supplied below. |
| int nextNonDistinctField = groupCount; |
| final List<AggregateCall> upperCalls = new ArrayList<>(); |
| final List<Integer> countCallIndexes = new ArrayList<>(); |
| for (int idx = 0; idx < originalCalls.size(); idx++) { |
| final AggregateCall aggCall = originalCalls.get(idx); |
| final SqlAggFunction function; |
| final List<Integer> args; |
| final int filter; |
| if (aggCall.isDistinct()) { |
| function = aggCall.getAggregation(); |
| args = remap(fullGroupSet, aggCall.getArgList()); |
| filter = filters.get(distinctGroupSet(aggregate, aggCall)); |
| } else { |
| // The value was already computed by the lower aggregate; MIN just picks it up. |
| function = SqlStdOperatorTable.MIN; |
| args = ImmutableIntList.of(nextNonDistinctField++); |
| filter = filters.get(originalGroupSet); |
| if (aggCall.getAggregation().getKind() == SqlKind.COUNT) { |
| countCallIndexes.add(idx); |
| } |
| } |
| upperCalls.add( |
| AggregateCall.create(function, false, aggCall.isApproximate(), |
| args, filter, aggregate.getGroupCount(), distinct, |
| null, aggCall.name)); |
| } |
| |
| relBuilder.aggregate( |
| relBuilder.groupKey( |
| remap(fullGroupSet, originalGroupSet), |
| remap(fullGroupSet, aggregate.getGroupSets())), |
| upperCalls); |
| |
| // Without group keys, a non-distinct COUNT evaluated through MIN may be NULL; |
| // replace NULL by zero in that case. |
| if (!countCallIndexes.isEmpty() && aggregate.getGroupCount() == 0) { |
| final Aggregate upperAggregate = (Aggregate) relBuilder.peek(); |
| final int upperGroupCount = upperAggregate.getGroupCount(); |
| final List<RexNode> projects = new ArrayList<>(); |
| for (int i = 0; i < upperGroupCount; ++i) { |
| projects.add(RexInputRef.of(i, upperAggregate.getRowType())); |
| } |
| for (int i = 0; i < upperAggregate.getAggCallList().size(); ++i) { |
| final RexNode inputRef = |
| RexInputRef.of(upperGroupCount + i, upperAggregate.getRowType()); |
| final boolean needsDefault = countCallIndexes.contains(i) |
| && originalCalls.get(i).getAggregation().getKind() == SqlKind.COUNT; |
| if (needsDefault) { |
| projects.add( |
| relBuilder.call( |
| SqlStdOperatorTable.CASE, |
| relBuilder.isNotNull(inputRef), |
| inputRef, |
| relBuilder.literal(BigDecimal.ZERO))); |
| } else { |
| projects.add(inputRef); |
| } |
| } |
| relBuilder.project(projects); |
| } |
| |
| relBuilder.convert(aggregate.getRowType(), true); |
| call.transformTo(relBuilder.build()); |
| } |
| |
| /** |
| * Returns the grouping set used for a distinct call: its arguments, its filter argument |
| * (if not negative) and the group set of the aggregate. |
| */ |
| private static ImmutableBitSet distinctGroupSet(Aggregate aggregate, AggregateCall aggCall) { |
| return ImmutableBitSet.of(aggCall.getArgList()) |
| .setIf(aggCall.filterArg, aggCall.filterArg >= 0) |
| .union(aggregate.getGroupSet()); |
| } |
| |
| private static long groupValue(ImmutableBitSet fullGroupSet, |
| ImmutableBitSet groupSet) { |
| long v = 0; |
| long x = 1L << (fullGroupSet.cardinality() - 1); |
| assert fullGroupSet.contains(groupSet); |
| for (int i : fullGroupSet) { |
| if (!groupSet.get(i)) { |
| v |= x; |
| } |
| x >>= 1; |
| } |
| return v; |
| } |
| |
| /** Maps every bit of {@code bitSet} to its index within {@code groupSet}. */ |
| private static ImmutableBitSet remap(ImmutableBitSet groupSet, |
| ImmutableBitSet bitSet) { |
| final ImmutableBitSet.Builder mapped = ImmutableBitSet.builder(); |
| for (int bit : bitSet) { |
| final int target = remap(groupSet, bit); |
| mapped.set(target); |
| } |
| return mapped.build(); |
| } |
| |
| /** Applies the bit-set remapping to each element of {@code bitSets}, keeping their order. */ |
| private static com.google.common.collect.ImmutableList<ImmutableBitSet> remap( |
| ImmutableBitSet groupSet, |
| Iterable<ImmutableBitSet> bitSets) { |
| final List<ImmutableBitSet> mapped = new ArrayList<>(); |
| for (ImmutableBitSet each : bitSets) { |
| mapped.add(remap(groupSet, each)); |
| } |
| return com.google.common.collect.ImmutableList.copyOf(mapped); |
| } |
| |
| private static List<Integer> remap(ImmutableBitSet groupSet, |
| List<Integer> argList) { |
| ImmutableIntList list = ImmutableIntList.of(); |
| for (int arg : argList) { |
| list = list.append(remap(groupSet, arg)); |
| } |
| return list; |
| } |
| |
| /** Returns the index of {@code arg} within {@code groupSet}; a negative {@code arg} yields -1. */ |
| private static int remap(ImmutableBitSet groupSet, int arg) { |
| if (arg < 0) { |
| return -1; |
| } |
| return groupSet.indexOf(arg); |
| } |
| |
| /** |
| * Rewrites an aggregate whose calls are all distinct and all use the same arguments |
| * (one function, or several functions over the same arguments). |
| * |
| * @param relBuilder Builder used to assemble the plan; the rewritten expression is left on it |
| * @param aggregate Original aggregate |
| * @param argList Arguments shared by the distinct calls |
| * @param filterArg Ordinal of the filter column, or -1 |
| * @return the {@code relBuilder} that was passed in |
| */ |
| private RelBuilder convertMonopole(RelBuilder relBuilder, Aggregate aggregate, |
| List<Integer> argList, int filterArg) { |
| // Example: |
| // SELECT deptno, COUNT(DISTINCT sal), SUM(DISTINCT sal) |
| // FROM emp |
| // GROUP BY deptno |
| // |
| // is turned into |
| // |
| // SELECT deptno, COUNT(distinct_sal), SUM(distinct_sal) |
| // FROM ( |
| // SELECT DISTINCT deptno, sal AS distinct_sal |
| // FROM EMP) |
| // GROUP BY deptno |
| |
| // Step 1: distinct rows of the GROUP BY columns plus the call arguments. |
| final Map<Integer, Integer> sourceOf = new HashMap<>(); |
| createSelectDistinct(relBuilder, aggregate, argList, filterArg, sourceOf); |
| |
| // Step 2: the original calls, made non-distinct and pointed at the distinct rows. |
| final List<AggregateCall> rewrittenCalls = new ArrayList<>(aggregate.getAggCallList()); |
| rewriteAggCalls(rewrittenCalls, argList, sourceOf); |
| |
| // Step 3: aggregate on top; the group columns are the leading fields of the distinct rows. |
| final ImmutableBitSet topGroupSet = |
| ImmutableBitSet.range(aggregate.getGroupSet().cardinality()); |
| relBuilder.push( |
| aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), |
| aggregate.indicator, topGroupSet, null, rewrittenCalls)); |
| return relBuilder; |
| } |
| |
| /** |
| * Converts all distinct aggregate calls over one combination of arguments and filter. |
| * |
| * <p>This method is called several times, once for each combination. Each time it is |
| * called, it generates a new SELECT DISTINCT relational expression, aggregates on top of |
| * it, joins the result to the expression built so far (unless {@code n} is 0) and modifies |
| * the set of top-level expressions in {@code refs}. |
| * |
| * <p>Only calls whose argument list <em>and</em> filter argument match are converted. Calls |
| * that share the arguments but use another filter belong to a different invocation: the |
| * SELECT DISTINCT built here projects only {@code filterArg}, so another filter column could |
| * not be resolved, and an unfiltered call must not read the filtered argument column. |
| * |
| * @param relBuilder Contains the input relational expression (either the aggregate of the |
| * non-distinct calls or the output from the previous call to this method); |
| * receives the result |
| * @param aggregate Original aggregate |
| * @param n Ordinal of this in a join. {@code n} is 0 if we're converting the |
| * first distinct aggregate in a query with no non-distinct aggregates |
| * @param argList Arguments to the distinct aggregate function |
| * @param filterArg Argument that filters input to aggregate function, or -1 |
| * @param refs Array of expressions which will be projected by the result of this |
| * rule. Those relating to this combination will be modified |
| */ |
| private void doRewrite(RelBuilder relBuilder, Aggregate aggregate, int n, |
| List<Integer> argList, int filterArg, List<RexInputRef> refs) { |
| final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder(); |
| // Fields of the expression built so far; there is none yet when n is 0. |
| final List<RelDataTypeField> leftFields; |
| if (n == 0) { |
| leftFields = null; |
| } else { |
| leftFields = relBuilder.peek().getRowType().getFieldList(); |
| } |
| |
| // Aggregate( |
| // child, |
| // {COUNT(DISTINCT 1), SUM(DISTINCT 1), SUM(2)}) |
| // |
| // becomes |
| // |
| // Aggregate( |
| // Join( |
| // child, |
| // Aggregate(child, < all columns > {}), |
| // INNER, |
| // <f2 = f5>)) |
| // |
| // E.g. |
| // SELECT deptno, SUM(DISTINCT sal), COUNT(DISTINCT gender), MAX(age) |
| // FROM Emps |
| // GROUP BY deptno |
| // |
| // becomes |
| // |
| // SELECT e.deptno, adsal.sum_sal, adgender.count_gender, e.max_age |
| // FROM ( |
| // SELECT deptno, MAX(age) as max_age |
| // FROM Emps GROUP BY deptno) AS e |
| // JOIN ( |
| // SELECT deptno, COUNT(gender) AS count_gender FROM ( |
| // SELECT DISTINCT deptno, gender FROM Emps) AS dgender |
| // GROUP BY deptno) AS adgender |
| // ON e.deptno = adgender.deptno |
| // JOIN ( |
| // SELECT deptno, SUM(sal) AS sum_sal FROM ( |
| // SELECT DISTINCT deptno, sal FROM Emps) AS dsal |
| // GROUP BY deptno) AS adsal |
| // ON e.deptno = adsal.deptno |
| // GROUP BY e.deptno |
| // |
| // Note that if a query contains no non-distinct aggregates, then the |
| // very first join/group by is omitted. In the example above, if |
| // MAX(age) is removed, then the sub-select of "e" is not needed, and |
| // instead the two other group by's are joined to one another. |
| |
| // Project the columns of the GROUP BY plus the arguments |
| // to the agg function. |
| final Map<Integer, Integer> sourceOf = new HashMap<>(); |
| createSelectDistinct(relBuilder, aggregate, argList, filterArg, sourceOf); |
| |
| // Now compute the aggregate functions on top of the distinct dataset. |
| // Each distinct agg becomes a non-distinct call to the corresponding |
| // field from the right; for example, |
| // "COUNT(DISTINCT e.sal)" |
| // becomes |
| // "COUNT(distinct_e.sal)". |
| final List<AggregateCall> aggCallList = new ArrayList<>(); |
| final List<AggregateCall> aggCalls = aggregate.getAggCallList(); |
| |
| final int groupAndIndicatorCount = |
| aggregate.getGroupCount() + aggregate.getIndicatorCount(); |
| int i = groupAndIndicatorCount - 1; |
| for (AggregateCall aggCall : aggCalls) { |
| ++i; |
| |
| // Ignore agg calls which are not distinct or have the wrong set of |
| // arguments. If we're rewriting aggs whose args are {sal}, we will |
| // rewrite COUNT(DISTINCT sal) and SUM(DISTINCT sal) but ignore |
| // COUNT(DISTINCT gender) or SUM(sal). |
| if (!aggCall.isDistinct()) { |
| continue; |
| } |
| if (!aggCall.getArgList().equals(argList)) { |
| continue; |
| } |
| // Also ignore calls over the same arguments but with a different filter; they are |
| // rewritten by the invocation for their own (arguments, filter) combination. |
| if (aggCall.filterArg != filterArg) { |
| continue; |
| } |
| |
| // Re-map arguments. |
| final int argCount = aggCall.getArgList().size(); |
| final List<Integer> newArgs = new ArrayList<>(argCount); |
| for (int j = 0; j < argCount; j++) { |
| final Integer arg = aggCall.getArgList().get(j); |
| newArgs.add(sourceOf.get(arg)); |
| } |
| final int newFilterArg = |
| aggCall.filterArg >= 0 ? sourceOf.get(aggCall.filterArg) : -1; |
| final AggregateCall newAggCall = |
| AggregateCall.create(aggCall.getAggregation(), false, |
| aggCall.isApproximate(), newArgs, |
| newFilterArg, aggCall.getType(), aggCall.getName()); |
| assert refs.get(i) == null; |
| if (n == 0) { |
| refs.set(i, |
| new RexInputRef(groupAndIndicatorCount + aggCallList.size(), |
| newAggCall.getType())); |
| } else { |
| // After the join, the fields of this aggregate follow the left fields. |
| refs.set(i, |
| new RexInputRef(leftFields.size() + groupAndIndicatorCount |
| + aggCallList.size(), newAggCall.getType())); |
| } |
| aggCallList.add(newAggCall); |
| } |
| |
| // The group columns are the leading fields of the distinct dataset. |
| final Map<Integer, Integer> map = new HashMap<>(); |
| for (Integer key : aggregate.getGroupSet()) { |
| map.put(key, map.size()); |
| } |
| final ImmutableBitSet newGroupSet = aggregate.getGroupSet().permute(map); |
| assert newGroupSet |
| .equals(ImmutableBitSet.range(aggregate.getGroupSet().cardinality())); |
| com.google.common.collect.ImmutableList<ImmutableBitSet> newGroupingSets = null; |
| if (aggregate.indicator) { |
| newGroupingSets = |
| ImmutableBitSet.ORDERING.immutableSortedCopy( |
| ImmutableBitSet.permute(aggregate.getGroupSets(), map)); |
| } |
| |
| relBuilder.push( |
| aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), |
| aggregate.indicator, newGroupSet, newGroupingSets, aggCallList)); |
| |
| // If there's no left child yet, no need to create the join |
| if (n == 0) { |
| return; |
| } |
| |
| // Create the join condition. It is of the form |
| // 'left.f0 = right.f0 and left.f1 = right.f1 and ...' |
| // where {f0, f1, ...} are the GROUP BY fields. |
| final List<RelDataTypeField> distinctFields = |
| relBuilder.peek().getRowType().getFieldList(); |
| final List<RexNode> conditions = com.google.common.collect.Lists.newArrayList(); |
| for (i = 0; i < groupAndIndicatorCount; ++i) { |
| // null values form its own group |
| // use "is not distinct from" so that the join condition |
| // allows null values to match. |
| conditions.add( |
| rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_DISTINCT_FROM, |
| RexInputRef.of(i, leftFields), |
| new RexInputRef(leftFields.size() + i, |
| distinctFields.get(i).getType()))); |
| } |
| |
| // Join in the new 'select distinct' relation. |
| relBuilder.join(JoinRelType.INNER, conditions); |
| } |
| |
| /** |
| * Replaces, in place, every distinct call over {@code argList} by a non-distinct, unfiltered |
| * call whose arguments are looked up in {@code sourceOf}; for example |
| * "COUNT(DISTINCT e.sal)" becomes "COUNT(distinct_e.sal)". Calls that are not distinct or |
| * use other arguments (e.g. COUNT(DISTINCT gender) or SUM(sal) when rewriting {sal}) are |
| * left untouched. |
| * |
| * @param newAggCalls Calls to rewrite; modified by this method |
| * @param argList Arguments of the distinct calls to rewrite |
| * @param sourceOf Maps an argument ordinal to the ordinal that replaces it |
| */ |
| private static void rewriteAggCalls( |
| List<AggregateCall> newAggCalls, |
| List<Integer> argList, |
| Map<Integer, Integer> sourceOf) { |
| for (int pos = 0; pos < newAggCalls.size(); pos++) { |
| final AggregateCall current = newAggCalls.get(pos); |
| final boolean selected = |
| current.isDistinct() && current.getArgList().equals(argList); |
| if (!selected) { |
| continue; |
| } |
| |
| // Look up the replacement of every argument. |
| final List<Integer> mappedArgs = new ArrayList<>(current.getArgList().size()); |
| for (Integer arg : current.getArgList()) { |
| mappedArgs.add(sourceOf.get(arg)); |
| } |
| newAggCalls.set(pos, |
| AggregateCall.create(current.getAggregation(), false, |
| current.isApproximate(), mappedArgs, -1, |
| current.getType(), current.getName())); |
| } |
| } |
| |
| /** |
| * Builds, on top of the input of an {@link org.apache.calcite.rel.core.Aggregate}, a |
| * 'select distinct' relational expression that projects the group columns, the filter |
| * column (if any) and the given argument columns, but nothing else. |
| * |
| * <p>For example, given |
| * |
| * <blockquote> |
| * <pre>select f0, count(distinct f1), count(distinct f2) |
| * from t group by f0</pre> |
| * </blockquote> |
| * |
| * <p>and the argument list |
| * |
| * <blockquote>{2}</blockquote> |
| * |
| * <p>the result is |
| * |
| * <blockquote> |
| * <pre>select distinct f0, f2 from t</pre> |
| * </blockquote> |
| * |
| * <p>The <code>sourceOf</code> map receives, for each projected input column, the position of |
| * that column in the result; in this case sourceOf.get(0) = 0, and sourceOf.get(2) = 1. |
| * |
| * @param relBuilder Relational expression builder; the result is left on it |
| * @param aggregate Aggregate relational expression |
| * @param argList Ordinals of columns to make distinct |
| * @param filterArg Ordinal of column to filter on, or -1 |
| * @param sourceOf Out parameter, is populated with the result position of each |
| * projected input column |
| * @return the {@code relBuilder} that was passed in |
| */ |
| private RelBuilder createSelectDistinct(RelBuilder relBuilder, |
| Aggregate aggregate, List<Integer> argList, int filterArg, |
| Map<Integer, Integer> sourceOf) { |
| relBuilder.push(aggregate.getInput()); |
| final List<RelDataTypeField> inputFields = |
| relBuilder.peek().getRowType().getFieldList(); |
| final List<Pair<RexNode, String>> projections = new ArrayList<>(); |
| final boolean hasFilter = filterArg >= 0; |
| |
| // Group columns come first, followed by the filter column if there is one. |
| for (int groupColumn : aggregate.getGroupSet()) { |
| sourceOf.put(groupColumn, projections.size()); |
| projections.add(RexInputRef.of2(groupColumn, inputFields)); |
| } |
| if (hasFilter) { |
| sourceOf.put(filterArg, projections.size()); |
| projections.add(RexInputRef.of2(filterArg, inputFields)); |
| } |
| |
| for (Integer arg : argList) { |
| if (hasFilter) { |
| // agg(DISTINCT arg) FILTER $f is implemented by projecting |
| // CASE WHEN $f THEN arg ELSE NULL END AS arg |
| // into the SELECT DISTINCT and applying agg(arg) as usual. |
| // This works except for (rare) agg functions that need to see null values. |
| final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder(); |
| final Pair<RexNode, String> argRef = RexInputRef.of2(arg, inputFields); |
| final RexNode typedNull = |
| rexBuilder.ensureType(argRef.left.getType(), |
| rexBuilder.makeCast(argRef.left.getType(), |
| rexBuilder.constantNull()), |
| true); |
| final RexNode guardedArg = |
| rexBuilder.makeCall(SqlStdOperatorTable.CASE, |
| RexInputRef.of(filterArg, inputFields), |
| argRef.left, |
| typedNull); |
| sourceOf.put(arg, projections.size()); |
| projections.add(Pair.of(guardedArg, "i$" + argRef.right)); |
| } else if (sourceOf.get(arg) == null) { |
| // Not projected yet (an argument may also be a group column). |
| sourceOf.put(arg, projections.size()); |
| projections.add(RexInputRef.of2(arg, inputFields)); |
| } |
| } |
| relBuilder.project(Pair.left(projections), Pair.right(projections)); |
| |
| // Grouping by all projected fields, without aggregate calls, yields the distinct rows. |
| final ImmutableBitSet allFields = ImmutableBitSet.range(projections.size()); |
| relBuilder.push( |
| aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), false, |
| allFields, |
| null, com.google.common.collect.ImmutableList.<AggregateCall>of())); |
| return relBuilder; |
| } |
| } |
| |
| // End FlinkAggregateExpandDistinctAggregatesRule.java |