/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to you under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.calcite.rel.rules;

import org.apache.calcite.plan.Contexts;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.Aggregate.Group;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.runtime.PairList;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.Optionality;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;

import org.checkerframework.checker.nullness.qual.Nullable;
import org.immutables.value.Value;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.NavigableSet;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;

import static java.util.Objects.requireNonNull;

/**
 * Planner rule that expands distinct aggregates
 * (such as {@code COUNT(DISTINCT x)}) from a
 * {@link org.apache.calcite.rel.core.Aggregate}.
 *
 * <p>How this is done depends upon the arguments to the function. If all
 * functions have the same argument
 * (e.g. {@code COUNT(DISTINCT x), SUM(DISTINCT x)} both have the argument
 * {@code x}) then one extra {@link org.apache.calcite.rel.core.Aggregate} is
 * sufficient.
 *
 * <p>If there are multiple arguments
 * (e.g. {@code COUNT(DISTINCT x), COUNT(DISTINCT y)})
 * the rule creates separate {@code Aggregate}s and combines using a
 * {@link org.apache.calcite.rel.core.Join}.
 *
 * @see CoreRules#AGGREGATE_EXPAND_DISTINCT_AGGREGATES
 * @see CoreRules#AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN
 */
@Value.Enclosing
public final class AggregateExpandDistinctAggregatesRule
    extends RelRule<AggregateExpandDistinctAggregatesRule.Config>
    implements TransformationRule {

  /** Creates an AggregateExpandDistinctAggregatesRule. */
  AggregateExpandDistinctAggregatesRule(Config config) {
    super(config);
  }

  @Deprecated // to be removed before 2.0
  public AggregateExpandDistinctAggregatesRule(
      Class<? extends Aggregate> clazz,
      boolean useGroupingSets,
      RelBuilderFactory relBuilderFactory) {
    this(Config.DEFAULT.withRelBuilderFactory(relBuilderFactory)
        .withOperandSupplier(b ->
            b.operand(clazz).anyInputs())
        .as(Config.class)
        .withUsingGroupingSets(useGroupingSets));
  }

  @Deprecated // to be removed before 2.0
  public AggregateExpandDistinctAggregatesRule(
      Class<? extends LogicalAggregate> clazz,
      boolean useGroupingSets,
      RelFactories.JoinFactory joinFactory) {
    this(clazz, useGroupingSets, RelBuilder.proto(Contexts.of(joinFactory)));
  }

  @Deprecated // to be removed before 2.0
  public AggregateExpandDistinctAggregatesRule(
      Class<? extends LogicalAggregate> clazz,
      RelFactories.JoinFactory joinFactory) {
    this(clazz, false, RelBuilder.proto(Contexts.of(joinFactory)));
  }

  //~ Methods ----------------------------------------------------------------

  @Override public void onMatch(RelOptRuleCall call) {
    final Aggregate aggregate = call.rel(0);
    if (!aggregate.containsDistinctCall()) {
      return;
    }

    if (!config.isUsingGroupingSets()
        && aggregate.groupSets.size() > 1) {
      // Grouping sets are not handled correctly
      // when generating joins.
      return;
    }

    // Find all of the agg expressions. We use a LinkedHashSet to ensure determinism.
    final List<AggregateCall> aggCalls = aggregate.getAggCallList();
    // Find all aggregate calls with distinct
    final List<AggregateCall> distinctAggCalls = aggCalls.stream()
        .filter(AggregateCall::isDistinct).collect(Collectors.toList());
    // Find all aggregate calls without distinct
    final List<AggregateCall> nonDistinctAggCalls = aggCalls.stream()
        .filter(aggCall -> !aggCall.isDistinct()).collect(Collectors.toList());
    final long filterCount = aggCalls.stream()
        .filter(aggCall -> aggCall.filterArg >= 0).count();
    final long unsupportedNonDistinctAggCallCount = nonDistinctAggCalls.stream()
        .filter(aggCall -> {
          final SqlKind aggCallKind = aggCall.getAggregation().getKind();
          // We only support COUNT/SUM/MIN/MAX for the "single" count distinct optimization
          switch (aggCallKind) {
          case COUNT:
          case SUM:
          case SUM0:
          case MIN:
          case MAX:
            return false;
          default:
            return true;
          }
        }).count();
    // Argument list of distinct agg calls.
    final Set<Pair<List<Integer>, Integer>> distinctCallArgLists = distinctAggCalls.stream()
        .map(aggCall -> Pair.of(aggCall.getArgList(), aggCall.filterArg))
        .collect(Collectors.toCollection(LinkedHashSet::new));

    checkState(!distinctCallArgLists.isEmpty(), "containsDistinctCall lied");

    // If all of the agg expressions are distinct and have the same
    // arguments then we can use a more efficient form.

    // MAX, MIN, BIT_AND, BIT_OR always ignore distinct attribute,
    // when they are mixed in with other distinct agg calls,
    // we can still use this promotion.

    // Treat the agg expression with Optionality.IGNORED as distinct and
    // re-statistic the non-distinct agg call count and the distinct agg
    // call arguments.
    final List<AggregateCall> nonDistinctAggCallsOfIgnoredOptionality =
        nonDistinctAggCalls.stream().filter(aggCall ->
            aggCall.getAggregation().getDistinctOptionality() == Optionality.IGNORED)
            .collect(Collectors.toList());
    // Different with distinctCallArgLists, this list also contains args that come from
    // agg call which can ignore the distinct constraint.
    final Set<Pair<List<Integer>, Integer>> distinctCallArgLists2 =
        Stream.of(distinctAggCalls, nonDistinctAggCallsOfIgnoredOptionality)
            .flatMap(Collection::stream)
            .map(aggCall -> Pair.of(aggCall.getArgList(), aggCall.filterArg))
            .collect(Collectors.toCollection(LinkedHashSet::new));

    if ((nonDistinctAggCalls.size() - nonDistinctAggCallsOfIgnoredOptionality.size()) == 0
        && distinctCallArgLists2.size() == 1
        && aggregate.getGroupType() == Group.SIMPLE) {
      final Pair<List<Integer>, Integer> pair =
          Iterables.getOnlyElement(distinctCallArgLists2);
      final RelBuilder relBuilder = call.builder();
      convertMonopole(relBuilder, aggregate, pair.left, pair.right);
      call.transformTo(relBuilder.build());
      return;
    }

    if (config.isUsingGroupingSets()) {
      rewriteUsingGroupingSets(call, aggregate);
      return;
    }

    // If only one distinct aggregate and one or more non-distinct aggregates,
    // we can generate multi-phase aggregates
    if (distinctAggCalls.size() == 1 // one distinct aggregate
        && filterCount == 0 // no filter
        && unsupportedNonDistinctAggCallCount == 0 // sum/min/max/count in non-distinct aggregate
        && nonDistinctAggCalls.size() > 0) { // one or more non-distinct aggregates
      final RelBuilder relBuilder = call.builder();
      convertSingletonDistinct(relBuilder, aggregate, distinctCallArgLists);
      call.transformTo(relBuilder.build());
      return;
    }

    // Create a list of the expressions which will yield the final result.
    // Initially, the expressions point to the input field.
    final List<RelDataTypeField> aggFields =
        aggregate.getRowType().getFieldList();
    final List<@Nullable RexInputRef> refs = new ArrayList<>();
    final List<String> fieldNames = aggregate.getRowType().getFieldNames();
    final ImmutableBitSet groupSet = aggregate.getGroupSet();
    final int groupCount = aggregate.getGroupCount();
    for (int i : Util.range(groupCount)) {
      refs.add(RexInputRef.of(i, aggFields));
    }

    // Aggregate the original relation, including any non-distinct aggregates.
    final List<AggregateCall> newAggCallList = new ArrayList<>();
    int i = -1;
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
      ++i;
      if (aggCall.isDistinct()) {
        refs.add(null);
        continue;
      }
      refs.add(
          new RexInputRef(
              groupCount + newAggCallList.size(),
              aggFields.get(groupCount + i).getType()));
      newAggCallList.add(aggCall);
    }

    // In the case where there are no non-distinct aggregates (regardless of
    // whether there are group bys), there's no need to generate the
    // extra aggregate and join.
    final RelBuilder relBuilder = call.builder();
    relBuilder.push(aggregate.getInput());
    int n = 0;
    if (!newAggCallList.isEmpty()) {
      final RelBuilder.GroupKey groupKey =
          relBuilder.groupKey(groupSet, aggregate.getGroupSets());
      relBuilder.aggregate(groupKey, newAggCallList);
      ++n;
    }

    // For each set of operands, find and rewrite all calls which have that
    // set of operands.
    for (Pair<List<Integer>, Integer> argList : distinctCallArgLists) {
      doRewrite(relBuilder, aggregate, n++, argList.left, argList.right, refs);
    }
    // It is assumed doRewrite above replaces nulls in refs
    @SuppressWarnings("assignment.type.incompatible")
    List<RexInputRef> nonNullRefs = refs;
    relBuilder.project(nonNullRefs, fieldNames);
    call.transformTo(relBuilder.build());
  }

  /**
   * Converts an aggregate with one distinct aggregate and one or more
   * non-distinct aggregates to multi-phase aggregates (see reference example
   * below).
   *
   * @param relBuilder Contains the input relational expression
   * @param aggregate  Original aggregate
   * @param argLists   Arguments and filters to the distinct aggregate function
   *
   */
  private static RelBuilder convertSingletonDistinct(RelBuilder relBuilder,
      Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {

    // In this case, we are assuming that there is a single distinct function.
    // So make sure that argLists is of size one.
    checkArgument(argLists.size() == 1);

    // For example,
    //    SELECT deptno, COUNT(*), SUM(bonus), MIN(DISTINCT sal)
    //    FROM emp
    //    GROUP BY deptno
    //
    // becomes
    //
    //    SELECT deptno, SUM(cnt), SUM(bonus), MIN(sal)
    //    FROM (
    //          SELECT deptno, COUNT(*) as cnt, SUM(bonus), sal
    //          FROM EMP
    //          GROUP BY deptno, sal)            // Aggregate B
    //    GROUP BY deptno                        // Aggregate A
    relBuilder.push(aggregate.getInput());

    final List<AggregateCall> originalAggCalls = aggregate.getAggCallList();
    final ImmutableBitSet originalGroupSet = aggregate.getGroupSet();

    // Add the distinct aggregate column(s) to the group-by columns,
    // if not already a part of the group-by
    final NavigableSet<Integer> bottomGroups = new TreeSet<>(aggregate.getGroupSet().asList());
    for (AggregateCall aggCall : originalAggCalls) {
      if (aggCall.isDistinct()) {
        bottomGroups.addAll(aggCall.getArgList());
        break;  // since we only have single distinct call
      }
    }
    final ImmutableBitSet bottomGroupSet = ImmutableBitSet.of(bottomGroups);

    // Generate the intermediate aggregate B, the one on the bottom that converts
    // a distinct call to group by call.
    // Bottom aggregate is the same as the original aggregate, except that
    // the bottom aggregate has converted the DISTINCT aggregate to a group by clause.
    final List<AggregateCall> bottomAggregateCalls = new ArrayList<>();
    for (AggregateCall aggCall : originalAggCalls) {
      // Project the column corresponding to the distinct aggregate. Project
      // as-is all the non-distinct aggregates
      if (!aggCall.isDistinct()) {
        final AggregateCall newCall =
            AggregateCall.create(aggCall.getAggregation(), false,
                aggCall.isApproximate(), aggCall.ignoreNulls(), aggCall.rexList,
                aggCall.getArgList(), -1, aggCall.distinctKeys,
                aggCall.collation, bottomGroupSet.cardinality(),
                relBuilder.peek(), null, aggCall.name);
        bottomAggregateCalls.add(newCall);
      }
    }
    // Generate the aggregate B (see the reference example above)
    relBuilder.push(
        aggregate.copy(aggregate.getTraitSet(), relBuilder.build(),
            bottomGroupSet, null, bottomAggregateCalls));

    // Add aggregate A (see the reference example above), the top aggregate
    // to handle the rest of the aggregation that the bottom aggregate hasn't handled
    final List<AggregateCall> topAggregateCalls = new ArrayList<>();
    // Use the remapped arguments for the (non)distinct aggregate calls
    int nonDistinctAggCallProcessedSoFar = 0;
    for (AggregateCall aggCall : originalAggCalls) {
      final AggregateCall newCall;
      if (aggCall.isDistinct()) {
        List<Integer> newArgList = new ArrayList<>();
        for (int arg : aggCall.getArgList()) {
          newArgList.add(bottomGroups.headSet(arg, false).size());
        }
        newCall =
            AggregateCall.create(aggCall.getAggregation(),
                false,
                aggCall.isApproximate(),
                aggCall.ignoreNulls(),
                aggCall.rexList,
                newArgList,
                -1,
                aggCall.distinctKeys,
                aggCall.collation,
                originalGroupSet.cardinality(),
                relBuilder.peek(),
                aggCall.getType(),
                aggCall.name);
      } else {
        // If aggregate B had a COUNT aggregate call the corresponding aggregate at
        // aggregate A must be SUM. For other aggregates, it remains the same.
        final int arg = bottomGroups.size() + nonDistinctAggCallProcessedSoFar;
        final List<Integer> newArgs = ImmutableList.of(arg);
        if (aggCall.getAggregation().getKind() == SqlKind.COUNT) {
          newCall =
              AggregateCall.create(new SqlSumEmptyIsZeroAggFunction(), false,
                  aggCall.isApproximate(), aggCall.ignoreNulls(),
                  aggCall.rexList, newArgs, -1, aggCall.distinctKeys,
                  aggCall.collation, originalGroupSet.cardinality(),
                  relBuilder.peek(), null, aggCall.getName());
        } else {
          newCall =
              AggregateCall.create(aggCall.getAggregation(), false,
                  aggCall.isApproximate(), aggCall.ignoreNulls(),
                  aggCall.rexList, newArgs, -1, aggCall.distinctKeys,
                  aggCall.collation, originalGroupSet.cardinality(),
                  relBuilder.peek(), null, aggCall.name);
        }
        nonDistinctAggCallProcessedSoFar++;
      }

      topAggregateCalls.add(newCall);
    }

    // Populate the group-by keys with the remapped arguments for aggregate A
    // The top groupset is basically an identity (first X fields of aggregate B's
    // output), minus the distinct aggCall's input.
    final Set<Integer> topGroupSet = new HashSet<>();
    int groupSetToAdd = 0;
    for (int bottomGroup : bottomGroups) {
      if (originalGroupSet.get(bottomGroup)) {
        topGroupSet.add(groupSetToAdd);
      }
      groupSetToAdd++;
    }
    relBuilder.push(
        aggregate.copy(aggregate.getTraitSet(), relBuilder.build(),
            ImmutableBitSet.of(topGroupSet), null, topAggregateCalls));

    // Add projection node for case: SUM of COUNT(*):
    // Type of the SUM may be larger than type of COUNT.
    // CAST to original type must be added.
    relBuilder.convert(aggregate.getRowType(), true);

    return relBuilder;
  }

  private static void rewriteUsingGroupingSets(RelOptRuleCall call,
      Aggregate aggregate) {
    final Set<ImmutableBitSet> groupSetTreeSet =
        new TreeSet<>(ImmutableBitSet.ORDERING);
    // GroupSet to distinct filter arg map,
    // filterArg will be -1 for non-distinct agg call.

    // Using `Set` here because it's possible that two agg calls
    // have different filterArgs but same groupSet.
    final Map<ImmutableBitSet, Set<Integer>> distinctFilterArgMap = new HashMap<>();
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
      ImmutableBitSet groupSet;
      int filterArg;
      if (!aggCall.isDistinct()) {
        filterArg = -1;
        groupSet = aggregate.getGroupSet();
        groupSetTreeSet.add(aggregate.getGroupSet());
      } else {
        filterArg = aggCall.filterArg;
        groupSet =
            ImmutableBitSet.of(aggCall.getArgList())
                .setIf(filterArg, filterArg >= 0)
                .union(aggregate.getGroupSet());
        groupSetTreeSet.add(groupSet);
      }
      Set<Integer> filterList = distinctFilterArgMap
          .computeIfAbsent(groupSet, g -> new HashSet<>());
      filterList.add(filterArg);
    }

    final ImmutableList<ImmutableBitSet> groupSets =
        ImmutableList.copyOf(groupSetTreeSet);
    final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets);

    final List<AggregateCall> distinctAggCalls = new ArrayList<>();
    for (Pair<AggregateCall, String> aggCall : aggregate.getNamedAggCalls()) {
      if (!aggCall.left.isDistinct()) {
        AggregateCall newAggCall =
            aggCall.left.adaptTo(aggregate.getInput(),
                aggCall.left.getArgList(), aggCall.left.filterArg,
                aggregate.getGroupCount(), fullGroupSet.cardinality());
        distinctAggCalls.add(newAggCall.withName(aggCall.right));
      }
    }

    final RelBuilder relBuilder = call.builder();
    relBuilder.push(aggregate.getInput());
    final int groupCount = fullGroupSet.cardinality();

    // Get the base ordinal of filter args for different groupSets.
    final Map<Pair<ImmutableBitSet, Integer>, Integer> filters = new LinkedHashMap<>();
    int z = groupCount + distinctAggCalls.size();
    for (ImmutableBitSet groupSet : groupSets) {
      Set<Integer> filterArgList = distinctFilterArgMap.get(groupSet);
      for (Integer filterArg : requireNonNull(filterArgList, "filterArgList")) {
        filters.put(Pair.of(groupSet, filterArg), z);
        z += 1;
      }
    }

    distinctAggCalls.add(
        AggregateCall.create(SqlStdOperatorTable.GROUPING, false, false, false,
            ImmutableList.of(), ImmutableIntList.copyOf(fullGroupSet), -1,
            null, RelCollations.EMPTY,
            groupSets.size(), relBuilder.peek(), null, "$g"));

    relBuilder.aggregate(
        relBuilder.groupKey(fullGroupSet, groupSets),
        distinctAggCalls);

    // GROUPING returns an integer (0 or 1). Add a project to convert those
    // values to BOOLEAN.
    if (!filters.isEmpty()) {
      final List<RexNode> nodes = new ArrayList<>(relBuilder.fields());
      final RexNode nodeZ = nodes.remove(nodes.size() - 1);
      for (Map.Entry<Pair<ImmutableBitSet, Integer>, Integer> entry : filters.entrySet()) {
        final long v = groupValue(fullGroupSet.asList(), entry.getKey().left);
        int distinctFilterArg = remap(fullGroupSet, entry.getKey().right);
        RexNode expr = relBuilder.equals(nodeZ, relBuilder.literal(v));
        if (distinctFilterArg > -1) {
          // 'AND' the filter of the distinct aggregate call and the group value.
          expr =
              relBuilder.and(expr,
                  relBuilder.call(SqlStdOperatorTable.IS_TRUE,
                      relBuilder.field(distinctFilterArg)));
        }
        // "f" means filter.
        nodes.add(
            relBuilder.alias(expr,
            "$g_" + v + (distinctFilterArg < 0 ? "" : "_f_" + distinctFilterArg)));
      }
      relBuilder.project(nodes);
    }

    int x = groupCount;
    final ImmutableBitSet groupSet = aggregate.getGroupSet();
    final List<AggregateCall> newCalls = new ArrayList<>();
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
      final int newFilterArg;
      final List<Integer> newArgList;
      final SqlAggFunction aggregation;
      if (!aggCall.isDistinct()) {
        aggregation = SqlStdOperatorTable.MIN;
        newArgList = ImmutableIntList.of(x++);
        newFilterArg =
            requireNonNull(filters.get(Pair.of(groupSet, -1)),
                "filters.get(Pair.of(groupSet, -1))");
      } else {
        aggregation = aggCall.getAggregation();
        newArgList = remap(fullGroupSet, aggCall.getArgList());
        final ImmutableBitSet newGroupSet = ImmutableBitSet.of(aggCall.getArgList())
            .setIf(aggCall.filterArg, aggCall.filterArg >= 0)
            .union(groupSet);
        newFilterArg =
            requireNonNull(filters.get(Pair.of(newGroupSet, aggCall.filterArg)),
                "filters.get(of(newGroupSet, aggCall.filterArg))");
      }
      final AggregateCall newCall =
          AggregateCall.create(aggregation, false,
              aggCall.isApproximate(), aggCall.ignoreNulls(),
              aggCall.rexList, newArgList, newFilterArg,
              aggCall.distinctKeys, aggCall.collation,
              aggregate.getGroupCount(), relBuilder.peek(), null, aggCall.name);
      newCalls.add(newCall);
    }

    relBuilder.aggregate(
        relBuilder.groupKey(
            remap(fullGroupSet, groupSet),
            remap(fullGroupSet, aggregate.getGroupSets())),
        newCalls);
    relBuilder.convert(aggregate.getRowType(), true);
    call.transformTo(relBuilder.build());
  }

  /** Returns the value that "GROUPING(fullGroupSet)" will return for
   * "groupSet".
   *
   * <p>It is important that {@code fullGroupSet} is not an
   * {@link ImmutableBitSet}; the order of the bits matters. */
  static long groupValue(Collection<Integer> fullGroupSet,
      ImmutableBitSet groupSet) {
    long v = 0;
    long x = 1L << (fullGroupSet.size() - 1);
    assert ImmutableBitSet.of(fullGroupSet).contains(groupSet);
    for (int i : fullGroupSet) {
      if (!groupSet.get(i)) {
        v |= x;
      }
      x >>= 1;
    }
    return v;
  }

  static ImmutableBitSet remap(ImmutableBitSet groupSet,
      ImmutableBitSet bitSet) {
    final ImmutableBitSet.Builder builder = ImmutableBitSet.builder();
    for (Integer bit : bitSet) {
      builder.set(remap(groupSet, bit));
    }
    return builder.build();
  }

  static ImmutableList<ImmutableBitSet> remap(ImmutableBitSet groupSet,
      Iterable<ImmutableBitSet> bitSets) {
    final ImmutableList.Builder<ImmutableBitSet> builder =
        ImmutableList.builder();
    for (ImmutableBitSet bitSet : bitSets) {
      builder.add(remap(groupSet, bitSet));
    }
    return builder.build();
  }

  private static List<Integer> remap(ImmutableBitSet groupSet,
      List<Integer> argList) {
    ImmutableIntList list = ImmutableIntList.of();
    for (int arg : argList) {
      list = list.append(remap(groupSet, arg));
    }
    return list;
  }

  private static int remap(ImmutableBitSet groupSet, int arg) {
    return arg < 0 ? -1 : groupSet.indexOf(arg);
  }

  /**
   * Converts an aggregate relational expression that contains just one
   * distinct aggregate function (or perhaps several over the same arguments)
   * and no non-distinct aggregate functions.
   */
  private static RelBuilder convertMonopole(RelBuilder relBuilder, Aggregate aggregate,
      List<Integer> argList, int filterArg) {
    // For example,
    //    SELECT deptno, COUNT(DISTINCT sal), SUM(DISTINCT sal)
    //    FROM emp
    //    GROUP BY deptno
    //
    // becomes
    //
    //    SELECT deptno, COUNT(distinct_sal), SUM(distinct_sal)
    //    FROM (
    //      SELECT DISTINCT deptno, sal AS distinct_sal
    //      FROM EMP GROUP BY deptno)
    //    GROUP BY deptno

    // Project the columns of the GROUP BY plus the arguments
    // to the agg function.
    final Map<Integer, Integer> sourceOf = new HashMap<>();
    createSelectDistinct(relBuilder, aggregate, argList, filterArg, sourceOf);

    // Create an aggregate on top, with the new aggregate list.
    final List<AggregateCall> newAggCalls =
        Lists.newArrayList(aggregate.getAggCallList());
    rewriteAggCalls(newAggCalls, argList, sourceOf);
    final int cardinality = aggregate.getGroupSet().cardinality();
    relBuilder.push(
        aggregate.copy(aggregate.getTraitSet(), relBuilder.build(),
            ImmutableBitSet.range(cardinality), null, newAggCalls));
    return relBuilder;
  }

  /**
   * Converts all distinct aggregate calls to a given set of arguments.
   *
   * <p>This method is called several times, one for each set of arguments.
   * Each time it is called, it generates a JOIN to a new SELECT DISTINCT
   * relational expression, and modifies the set of top-level calls.
   *
   * @param aggregate Original aggregate
   * @param n         Ordinal of this in a join. {@code relBuilder} contains the
   *                  input relational expression (either the original
   *                  aggregate, the output from the previous call to this
   *                  method. {@code n} is 0 if we're converting the
   *                  first distinct aggregate in a query with no non-distinct
   *                  aggregates)
   * @param argList   Arguments to the distinct aggregate function
   * @param filterArg Argument that filters input to aggregate function, or -1
   * @param refs      Array of expressions which will be the projected by the
   *                  result of this rule. Those relating to this arg list will
   *                  be modified
   */
  private static void doRewrite(RelBuilder relBuilder, Aggregate aggregate, int n,
      List<Integer> argList, int filterArg, List<@Nullable RexInputRef> refs) {
    final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
    final List<RelDataTypeField> leftFields;
    if (n == 0) {
      leftFields = null;
    } else {
      leftFields = relBuilder.peek().getRowType().getFieldList();
    }

    // Aggregate(
    //     child,
    //     {COUNT(DISTINCT 1), SUM(DISTINCT 1), SUM(2)})
    //
    // becomes
    //
    // Aggregate(
    //     Join(
    //         child,
    //         Aggregate(child, < all columns > {}),
    //         INNER,
    //         <f2 = f5>))
    //
    // E.g.
    //   SELECT deptno, SUM(DISTINCT sal), COUNT(DISTINCT gender), MAX(age)
    //   FROM Emps
    //   GROUP BY deptno
    //
    // becomes
    //
    //   SELECT e.deptno, adsal.sum_sal, adgender.count_gender, e.max_age
    //   FROM (
    //     SELECT deptno, MAX(age) as max_age
    //     FROM Emps GROUP BY deptno) AS e
    //   JOIN (
    //     SELECT deptno, COUNT(gender) AS count_gender FROM (
    //       SELECT DISTINCT deptno, gender FROM Emps) AS dgender
    //     GROUP BY deptno) AS adgender
    //     ON e.deptno = adgender.deptno
    //   JOIN (
    //     SELECT deptno, SUM(sal) AS sum_sal FROM (
    //       SELECT DISTINCT deptno, sal FROM Emps) AS dsal
    //     GROUP BY deptno) AS adsal
    //   ON e.deptno = adsal.deptno
    //   GROUP BY e.deptno
    //
    // Note that if a query contains no non-distinct aggregates, then the
    // very first join/group by is omitted.  In the example above, if
    // MAX(age) is removed, then the sub-select of "e" is not needed, and
    // instead the two other group by's are joined to one another.

    // Project the columns of the GROUP BY plus the arguments
    // to the agg function.
    final Map<Integer, Integer> sourceOf = new HashMap<>();
    createSelectDistinct(relBuilder, aggregate, argList, filterArg, sourceOf);

    // Now compute the aggregate functions on top of the distinct dataset.
    // Each distinct agg becomes a non-distinct call to the corresponding
    // field from the right; for example,
    //   "COUNT(DISTINCT e.sal)"
    // becomes
    //   "COUNT(distinct_e.sal)".
    final List<AggregateCall> aggCallList = new ArrayList<>();
    final List<AggregateCall> aggCalls = aggregate.getAggCallList();

    final int groupCount = aggregate.getGroupCount();
    int i = groupCount - 1;
    for (AggregateCall aggCall : aggCalls) {
      ++i;

      // Ignore agg calls which are not distinct or have the wrong set
      // arguments. If we're rewriting aggs whose args are {sal}, we will
      // rewrite COUNT(DISTINCT sal) and SUM(DISTINCT sal) but ignore
      // COUNT(DISTINCT gender) or SUM(sal).
      if (!aggCall.isDistinct()) {
        continue;
      }
      if (!aggCall.getArgList().equals(argList)) {
        continue;
      }

      // Re-map arguments.
      final int argCount = aggCall.getArgList().size();
      final List<Integer> newArgs = new ArrayList<>(argCount);
      for (Integer arg : aggCall.getArgList()) {
        newArgs.add(requireNonNull(sourceOf.get(arg), () -> "sourceOf.get(" + arg + ")"));
      }
      final int newFilterArg =
          aggCall.filterArg < 0 ? -1
              : requireNonNull(sourceOf.get(aggCall.filterArg),
                  () -> "sourceOf.get(" + aggCall.filterArg + ")");
      final AggregateCall newAggCall =
          AggregateCall.create(aggCall.getAggregation(), false,
              aggCall.isApproximate(), aggCall.ignoreNulls(), aggCall.rexList,
              newArgs, newFilterArg, aggCall.distinctKeys, aggCall.collation,
              aggCall.getType(), aggCall.getName());
      assert refs.get(i) == null;
      if (leftFields == null) {
        refs.set(i,
            new RexInputRef(groupCount + aggCallList.size(),
                newAggCall.getType()));
      } else {
        refs.set(i,
            new RexInputRef(leftFields.size() + groupCount
                + aggCallList.size(), newAggCall.getType()));
      }
      aggCallList.add(newAggCall);
    }

    final Map<Integer, Integer> map = new HashMap<>();
    for (Integer key : aggregate.getGroupSet()) {
      map.put(key, map.size());
    }
    final ImmutableBitSet newGroupSet = aggregate.getGroupSet().permute(map);
    assert newGroupSet
        .equals(ImmutableBitSet.range(aggregate.getGroupSet().cardinality()));
    ImmutableList<ImmutableBitSet> newGroupingSets = null;

    relBuilder.push(
        aggregate.copy(aggregate.getTraitSet(), relBuilder.build(),
            newGroupSet, newGroupingSets, aggCallList));

    // If there's no left child yet, no need to create the join
    if (leftFields == null) {
      return;
    }

    // Create the join condition. It is of the form
    //  'left.f0 = right.f0 and left.f1 = right.f1 and ...'
    // where {f0, f1, ...} are the GROUP BY fields.
    final List<RelDataTypeField> distinctFields =
        relBuilder.peek().getRowType().getFieldList();
    final List<RexNode> conditions = new ArrayList<>();
    for (i = 0; i < groupCount; ++i) {
      // null values form its own group
      // use "is not distinct from" so that the join condition
      // allows null values to match.
      conditions.add(
          rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_DISTINCT_FROM,
              RexInputRef.of(i, leftFields),
              new RexInputRef(leftFields.size() + i,
                  distinctFields.get(i).getType())));
    }

    // Join in the new 'select distinct' relation.
    relBuilder.join(JoinRelType.INNER, conditions);
  }

  private static void rewriteAggCalls(
      List<AggregateCall> newAggCalls,
      List<Integer> argList,
      Map<Integer, Integer> sourceOf) {
    // Rewrite the agg calls. Each distinct agg becomes a non-distinct call
    // to the corresponding field from the right; for example,
    // "COUNT(DISTINCT e.sal)" becomes   "COUNT(distinct_e.sal)".
    for (int i = 0; i < newAggCalls.size(); i++) {
      final AggregateCall aggCall = newAggCalls.get(i);

      // Ignore agg calls which are not distinct or have the wrong set
      // arguments. If we're rewriting aggregates whose args are {sal}, we will
      // rewrite COUNT(DISTINCT sal) and SUM(DISTINCT sal) but ignore
      // COUNT(DISTINCT gender) or SUM(sal).
      if (!aggCall.isDistinct()
          && aggCall.getAggregation().getDistinctOptionality() != Optionality.IGNORED) {
        continue;
      }
      if (!aggCall.getArgList().equals(argList)) {
        continue;
      }

      // Re-map arguments.
      final int argCount = aggCall.getArgList().size();
      final List<Integer> newArgs = new ArrayList<>(argCount);
      for (int j = 0; j < argCount; j++) {
        final Integer arg = aggCall.getArgList().get(j);
        newArgs.add(
            requireNonNull(sourceOf.get(arg),
                () -> "sourceOf.get(" + arg + ")"));
      }
      final AggregateCall newAggCall =
          AggregateCall.create(aggCall.getAggregation(), false,
              aggCall.isApproximate(), aggCall.ignoreNulls(),
              aggCall.rexList, newArgs, -1,
              aggCall.distinctKeys, aggCall.collation,
              aggCall.getType(), aggCall.getName());
      newAggCalls.set(i, newAggCall);
    }
  }

  /**
   * Given an {@link org.apache.calcite.rel.core.Aggregate}
   * and the ordinals of the arguments to a
   * particular call to an aggregate function, creates a 'select distinct'
   * relational expression which projects the group columns and those
   * arguments but nothing else.
   *
   * <p>For example, given
   *
   * <blockquote>
   * <pre>select f0, count(distinct f1), count(distinct f2)
   * from t group by f0</pre>
   * </blockquote>
   *
   * <p>and the argument list
   *
   * <blockquote>{2}</blockquote>
   *
   * <p>returns
   *
   * <blockquote>
   * <pre>select distinct f0, f2 from t</pre>
   * </blockquote>
   *
   * <p>The <code>sourceOf</code> map is populated with the source of each
   * column; in this case sourceOf.get(0) = 0, and sourceOf.get(1) = 2.
   *
   * @param relBuilder Relational expression builder
   * @param aggregate Aggregate relational expression
   * @param argList   Ordinals of columns to make distinct
   * @param filterArg Ordinal of column to filter on, or -1
   * @param sourceOf  Out parameter, is populated with a map of where each
   *                  output field came from
   * @return Aggregate relational expression which projects the required
   * columns
   */
  private static RelBuilder createSelectDistinct(RelBuilder relBuilder,
      Aggregate aggregate, List<Integer> argList, int filterArg,
      Map<Integer, Integer> sourceOf) {
    relBuilder.push(aggregate.getInput());
    final PairList<RexNode, String> projects = PairList.of();
    final List<RelDataTypeField> childFields =
        relBuilder.peek().getRowType().getFieldList();
    for (int i : aggregate.getGroupSet()) {
      sourceOf.put(i, projects.size());
      RexInputRef.add2(projects, i, childFields);
    }
    for (Integer arg : argList) {
      if (filterArg >= 0) {
        // Implement
        //   agg(DISTINCT arg) FILTER $f
        // by generating
        //   SELECT DISTINCT ... CASE WHEN $f THEN arg ELSE NULL END AS arg
        // and then applying
        //   agg(arg)
        // as usual.
        //
        // It works except for (rare) agg functions that need to see null
        // values.
        final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
        final RexInputRef filterRef = RexInputRef.of(filterArg, childFields);
        final Pair<RexNode, String> argRef = RexInputRef.of2(arg, childFields);
        RexNode condition =
            rexBuilder.makeCall(SqlStdOperatorTable.CASE, filterRef,
                argRef.left,
                rexBuilder.makeNullLiteral(argRef.left.getType()));
        sourceOf.put(arg, projects.size());
        projects.add(condition, "i$" + argRef.right);
        continue;
      }
      if (sourceOf.get(arg) != null) {
        continue;
      }
      sourceOf.put(arg, projects.size());
      RexInputRef.add2(projects, arg, childFields);
    }
    relBuilder.project(projects.leftList(), projects.rightList());

    // Get the distinct values of the GROUP BY fields and the arguments
    // to the agg functions.
    relBuilder.push(
        aggregate.copy(aggregate.getTraitSet(), relBuilder.build(),
            ImmutableBitSet.range(projects.size()), null, ImmutableList.of()));
    return relBuilder;
  }

  /** Rule configuration. */
  @Value.Immutable
  public interface Config extends RelRule.Config {
    Config DEFAULT = ImmutableAggregateExpandDistinctAggregatesRule.Config.of()
        .withOperandSupplier(b ->
            b.operand(LogicalAggregate.class).anyInputs());

    Config JOIN = DEFAULT.withUsingGroupingSets(false);

    @Override default AggregateExpandDistinctAggregatesRule toRule() {
      return new AggregateExpandDistinctAggregatesRule(this);
    }

    /** Whether to use GROUPING SETS, default true. */
    @Value.Default default boolean isUsingGroupingSets() {
      return true;
    }

    /** Sets {@link #isUsingGroupingSets()}. */
    Config withUsingGroupingSets(boolean usingGroupingSets);
  }
}
