/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to you under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.plan.rules.logical;

import org.apache.flink.table.api.TableException;
import org.apache.flink.table.errorcode.TableErrors;
import org.apache.flink.util.Preconditions;

import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.plan.Contexts;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.Aggregate.Group;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;

/**
 * This rule is copied from Calcite's {@link org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule}.
 * Modifications:
 * - Throws an exception if the aggregate contains both an approximate distinct aggregate call
 *   and an accurate distinct aggregate call.
 * - Excludes non-simple aggregates (e.g. CUBE, ROLLUP).
 * - Fix bug: some aggregate functions (e.g. COUNT) have a non-null result even without any input.
 * - Fix bug: adds the filter argument to the rewritten aggregate call if the filter arg is not -1.
 */

/**
 * Planner rule that expands distinct aggregates
 * (such as {@code COUNT(DISTINCT x)}) from a
 * {@link org.apache.calcite.rel.core.Aggregate}.
 *
 * <p>How this is done depends upon the arguments to the function. If all
 * functions have the same argument
 * (e.g. {@code COUNT(DISTINCT x), SUM(DISTINCT x)} both have the argument
 * {@code x}) then one extra {@link org.apache.calcite.rel.core.Aggregate} is
 * sufficient.
 *
 * <p>If there are multiple arguments
 * (e.g. {@code COUNT(DISTINCT x), COUNT(DISTINCT y)})
 * the rule creates separate {@code Aggregate}s and combines using a
 * {@link org.apache.calcite.rel.core.Join}.
 */
public final class FlinkAggregateExpandDistinctAggregatesRule extends RelOptRule {
	//~ Static fields/initializers ---------------------------------------------

	/**
	 * The default instance of the rule; operates only on logical expressions.
	 * It is created with {@code useGroupingSets = true}.
	 */
	public static final FlinkAggregateExpandDistinctAggregatesRule INSTANCE =
			new FlinkAggregateExpandDistinctAggregatesRule(LogicalAggregate.class, true,
					RelFactories.LOGICAL_BUILDER);

	/**
	 * Instance of the rule that operates only on logical expressions and
	 * generates a join. It is created with {@code useGroupingSets = false}.
	 */
	public static final FlinkAggregateExpandDistinctAggregatesRule JOIN =
			new FlinkAggregateExpandDistinctAggregatesRule(LogicalAggregate.class, false,
					RelFactories.LOGICAL_BUILDER);

	/**
	 * Whether {@link #onMatch} rewrites via {@code rewriteUsingGroupingSets}
	 * (when the single-argument shortcut does not apply) instead of the
	 * multi-phase / join based rewrites.
	 */
	public final boolean useGroupingSets;

	//~ Constructors -----------------------------------------------------------

	/**
	 * Creates a rule that matches any relational expression of the given
	 * aggregate class.
	 *
	 * @param clazz             Aggregate class the rule's operand matches
	 * @param useGroupingSets   Stored in {@link #useGroupingSets}; selects the
	 *                          grouping-sets based rewrite in {@link #onMatch}
	 * @param relBuilderFactory Builder factory handed to the super constructor
	 */
	public FlinkAggregateExpandDistinctAggregatesRule(
			Class<? extends Aggregate> clazz,
			boolean useGroupingSets,
			RelBuilderFactory relBuilderFactory) {
		// The rule description is passed as null.
		super(operand(clazz, any()), relBuilderFactory, null);
		this.useGroupingSets = useGroupingSets;
	}

	/**
	 * Creates the rule from a join factory; the factory is wrapped into a
	 * builder factory and delegated to the primary constructor.
	 *
	 * @param clazz           Aggregate class the rule's operand matches
	 * @param useGroupingSets Stored in {@link #useGroupingSets}
	 * @param joinFactory     Factory placed in the builder context
	 */
	@Deprecated // to be removed before 2.0
	public FlinkAggregateExpandDistinctAggregatesRule(
			Class<? extends LogicalAggregate> clazz,
			boolean useGroupingSets,
			RelFactories.JoinFactory joinFactory) {
		this(clazz, useGroupingSets, RelBuilder.proto(Contexts.of(joinFactory)));
	}

	/**
	 * Creates the rule from a join factory with {@code useGroupingSets}
	 * fixed to {@code false}.
	 *
	 * @param clazz       Aggregate class the rule's operand matches
	 * @param joinFactory Factory placed in the builder context
	 */
	@Deprecated // to be removed before 2.0
	public FlinkAggregateExpandDistinctAggregatesRule(
			Class<? extends LogicalAggregate> clazz,
			RelFactories.JoinFactory joinFactory) {
		this(clazz, false, RelBuilder.proto(Contexts.of(joinFactory)));
	}

	//~ Methods ----------------------------------------------------------------

	/**
	 * Expands the accurate distinct aggregate calls of the matched aggregate.
	 *
	 * <p>Nothing is done when the aggregate has no accurate distinct call or has
	 * more than one grouping set. Otherwise one of the following rewrites is
	 * chosen, in this order:
	 * <ol>
	 *   <li>all calls are distinct over the same (arguments, filter) pair:
	 *       {@code convertMonopole};</li>
	 *   <li>{@link #useGroupingSets} is set: {@code rewriteUsingGroupingSets};</li>
	 *   <li>exactly one distinct call, no filters and only
	 *       COUNT/SUM/SUM0/MIN/MAX non-distinct calls:
	 *       {@code convertSingletonDistinct};</li>
	 *   <li>otherwise one aggregate per distinct (arguments, filter) pair,
	 *       combined by joins ({@code doRewrite}).</li>
	 * </ol>
	 *
	 * @param call Rule call whose first relational expression is the aggregate
	 * @throws TableException if the aggregate contains both an approximate and
	 *                        an accurate distinct call
	 */
	@Override
	public void onMatch(RelOptRuleCall call) {
		final Aggregate aggregate = call.rel(0);
		if (!aggregate.containsAccurateDistinctCall()) {
			return;
		}
		// Check unsupported aggregate which contains both approximate distinct call and
		// accurate distinct call.
		if (aggregate.containsApproximateDistinctCall()) {
			throw new TableException(TableErrors.INST.sqlDistinctConflict());
		}

		// If this aggregate is a non-simple aggregate(e.g. CUBE, ROLLUP)
		// and contains distinct calls, it should be transformed to simple aggregate first
		// by DecomposeGroupingSetsRule. Then this rule expands its distinct aggregates.
		if (aggregate.getGroupSets().size() > 1) {
			return;
		}

		// Find all of the agg expressions. We use a LinkedHashSet to ensure determinism.
		int nonDistinctAggCallCount = 0;  // find all aggregate calls without distinct
		int filterCount = 0;
		int unsupportedNonDistinctAggCallCount = 0;
		// One entry per distinct (argument list, filter argument) combination.
		final Set<Pair<List<Integer>, Integer>> argLists = new LinkedHashSet<>();
		for (AggregateCall aggCall : aggregate.getAggCallList()) {
			if (aggCall.filterArg >= 0) {
				++filterCount;
			}
			if (!aggCall.isDistinct()) {
				++nonDistinctAggCallCount;
				final SqlKind aggCallKind = aggCall.getAggregation().getKind();
				// We only support COUNT/SUM/MIN/MAX for the "single" count distinct optimization
				switch (aggCallKind) {
					case COUNT:
					case SUM:
					case SUM0:
					case MIN:
					case MAX:
						break;
					default:
						++unsupportedNonDistinctAggCallCount;
				}
			} else {
				argLists.add(Pair.of(aggCall.getArgList(), aggCall.filterArg));
			}
		}

		final int distinctAggCallCount =
				aggregate.getAggCallList().size() - nonDistinctAggCallCount;
		Preconditions.checkState(argLists.size() > 0, "containsDistinctCall lied");

		// If all of the agg expressions are distinct and have the same
		// arguments then we can use a more efficient form.
		if (nonDistinctAggCallCount == 0
				&& argLists.size() == 1
				&& aggregate.getGroupType() == Group.SIMPLE) {
			final Pair<List<Integer>, Integer> pair =
					com.google.common.collect.Iterables.getOnlyElement(argLists);
			final RelBuilder relBuilder = call.builder();
			convertMonopole(relBuilder, aggregate, pair.left, pair.right);
			call.transformTo(relBuilder.build());
			return;
		}

		if (useGroupingSets) {
			rewriteUsingGroupingSets(call, aggregate);
			return;
		}

		// If only one distinct aggregate and one or more non-distinct aggregates,
		// we can generate multi-phase aggregates
		if (distinctAggCallCount == 1 // one distinct aggregate
				&& filterCount == 0 // no filter
				&& unsupportedNonDistinctAggCallCount == 0 // sum/min/max/count in non-distinct aggregate
				&& nonDistinctAggCallCount > 0) { // one or more non-distinct aggregates
			final RelBuilder relBuilder = call.builder();
			convertSingletonDistinct(relBuilder, aggregate, argLists);
			call.transformTo(relBuilder.build());
			return;
		}

		// Create a list of the expressions which will yield the final result.
		// Initially, the expressions point to the input field.
		final List<RelDataTypeField> aggFields =
				aggregate.getRowType().getFieldList();
		final List<RexInputRef> refs = new ArrayList<>();
		final List<String> fieldNames = aggregate.getRowType().getFieldNames();
		final ImmutableBitSet groupSet = aggregate.getGroupSet();
		final int groupAndIndicatorCount =
				aggregate.getGroupCount() + aggregate.getIndicatorCount();
		for (int i : Util.range(groupAndIndicatorCount)) {
			refs.add(RexInputRef.of(i, aggFields));
		}

		// Aggregate the original relation, including any non-distinct aggregates.
		// Distinct calls get a null placeholder in refs which doRewrite fills in later.
		final List<AggregateCall> newAggCallList = new ArrayList<>();
		int i = -1;
		for (AggregateCall aggCall : aggregate.getAggCallList()) {
			++i;
			if (aggCall.isDistinct()) {
				refs.add(null);
				continue;
			}
			refs.add(
					new RexInputRef(
							groupAndIndicatorCount + newAggCallList.size(),
							aggFields.get(groupAndIndicatorCount + i).getType()));
			newAggCallList.add(aggCall);
		}

		// In the case where there are no non-distinct aggregates (regardless of
		// whether there are group bys), there's no need to generate the
		// extra aggregate and join.
		final RelBuilder relBuilder = call.builder();
		relBuilder.push(aggregate.getInput());
		// Number of relational expressions joined so far.
		int n = 0;
		if (!newAggCallList.isEmpty()) {
			final RelBuilder.GroupKey groupKey =
					relBuilder.groupKey(groupSet, aggregate.getGroupSets());
			relBuilder.aggregate(groupKey, newAggCallList);
			++n;
		}

		// For each set of operands, find and rewrite all calls which have that
		// set of operands.
		for (Pair<List<Integer>, Integer> argList : argLists) {
			doRewrite(relBuilder, aggregate, n++, argList.left, argList.right, refs);
		}

		relBuilder.project(refs, fieldNames);
		call.transformTo(relBuilder.build());
	}

	/**
	 * Rewrites an aggregate that has exactly one distinct call and at least one
	 * non-distinct call as two stacked aggregates: a bottom aggregate that groups
	 * by the original keys plus the distinct arguments, and a top aggregate that
	 * groups by the original keys only.
	 *
	 * <p>For example,
	 * <pre>
	 *    SELECT deptno, COUNT(*), SUM(bonus), MIN(DISTINCT sal)
	 *    FROM emp
	 *    GROUP BY deptno
	 * </pre>
	 * becomes
	 * <pre>
	 *    SELECT deptno, SUM(cnt), SUM(bonus), MIN(sal)
	 *    FROM (
	 *          SELECT deptno, COUNT(*) as cnt, SUM(bonus), sal
	 *          FROM EMP
	 *          GROUP BY deptno, sal)            // bottom aggregate
	 *    GROUP BY deptno                        // top aggregate
	 * </pre>
	 *
	 * @param relBuilder Builder that receives the rewritten expression
	 * @param aggregate  Original aggregate
	 * @param argLists   Arguments and filters of the distinct call; must hold
	 *                   exactly one entry
	 * @return {@code relBuilder}, with the top aggregate pushed
	 */
	private RelBuilder convertSingletonDistinct(RelBuilder relBuilder,
			Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {

		// Exactly one distinct function is expected here.
		Preconditions.checkArgument(argLists.size() == 1);

		relBuilder.push(aggregate.getInput());

		final List<AggregateCall> originalAggCalls = aggregate.getAggCallList();
		final ImmutableBitSet originalGroupSet = aggregate.getGroupSet();
		final int originalGroupCount = originalGroupSet.cardinality();

		// Bottom grouping keys: the original keys plus the arguments of the
		// (single) distinct call.
		final SortedSet<Integer> bottomGroupSet = new TreeSet<>();
		bottomGroupSet.addAll(originalGroupSet.asList());
		for (AggregateCall candidate : originalAggCalls) {
			if (candidate.isDistinct()) {
				bottomGroupSet.addAll(candidate.getArgList());
				break;  // there is only one distinct call
			}
		}

		// Bottom aggregate: keeps every non-distinct call as it is; the
		// distinct call is represented by the additional grouping keys.
		final int bottomGroupCount = ImmutableBitSet.of(bottomGroupSet).cardinality();
		final RelNode bottomInput = relBuilder.peek();
		final List<AggregateCall> bottomAggregateCalls = new ArrayList<>();
		for (AggregateCall original : originalAggCalls) {
			if (original.isDistinct()) {
				continue;
			}
			bottomAggregateCalls.add(
					AggregateCall.create(original.getAggregation(), false,
							original.isApproximate(), original.getArgList(), -1,
							bottomGroupCount, bottomInput, null, original.name));
		}
		relBuilder.push(
				aggregate.copy(
						aggregate.getTraitSet(), relBuilder.build(),
						false, ImmutableBitSet.of(bottomGroupSet), null, bottomAggregateCalls));

		// Top aggregate: finishes what the bottom aggregate started, reading
		// its inputs from the bottom aggregate's output columns.
		final RelNode bottomAggregate = relBuilder.peek();
		final List<AggregateCall> topAggregateCalls = new ArrayList<>();
		int nonDistinctSeen = 0;
		for (AggregateCall original : originalAggCalls) {
			final SqlAggFunction topFunction;
			final List<Integer> topArgs = new ArrayList<>();
			if (original.isDistinct()) {
				// Each argument is now a grouping column of the bottom aggregate;
				// its ordinal is the number of smaller grouping keys.
				topFunction = original.getAggregation();
				for (int arg : original.getArgList()) {
					topArgs.add(bottomGroupSet.headSet(arg).size());
				}
			} else {
				// A COUNT computed at the bottom must be summed up at the top;
				// every other function is re-applied unchanged.
				topFunction = original.getAggregation().getKind() == SqlKind.COUNT
						? new SqlSumEmptyIsZeroAggFunction()
						: original.getAggregation();
				topArgs.add(bottomGroupSet.size() + nonDistinctSeen);
				nonDistinctSeen++;
			}
			topAggregateCalls.add(
					AggregateCall.create(topFunction, false, original.isApproximate(),
							topArgs, -1, originalGroupCount, bottomAggregate,
							original.getType(), original.name));
		}

		// Top grouping keys: positions, within the bottom grouping columns, of
		// the keys that belong to the original group set.
		final ImmutableBitSet.Builder topGroupSet = ImmutableBitSet.builder();
		int position = 0;
		for (int bottomGroup : bottomGroupSet) {
			if (originalGroupSet.get(bottomGroup)) {
				topGroupSet.set(position);
			}
			position++;
		}
		relBuilder.push(
				aggregate.copy(aggregate.getTraitSet(),
						relBuilder.build(), aggregate.indicator,
						topGroupSet.build(), null, topAggregateCalls));
		return relBuilder;
	}

	/**
	 * Rewrites the aggregate as a bottom aggregate over several grouping sets
	 * followed by a top aggregate whose calls are filtered on the grouping set
	 * their input rows belong to, and registers the result on {@code call}.
	 *
	 * @param call      Rule call that receives the transformed expression
	 * @param aggregate Original aggregate
	 */
	private void rewriteUsingGroupingSets(RelOptRuleCall call,
			Aggregate aggregate) {
		// One grouping set per aggregate call: the original group set for a
		// non-distinct call; group set + arguments (+ filter argument, if any)
		// for a distinct call. The sorted set removes duplicates.
		final Set<ImmutableBitSet> groupSetTreeSet =
				new TreeSet<>(ImmutableBitSet.ORDERING);
		for (AggregateCall aggCall : aggregate.getAggCallList()) {
			if (!aggCall.isDistinct()) {
				groupSetTreeSet.add(aggregate.getGroupSet());
			} else {
				groupSetTreeSet.add(
						ImmutableBitSet.of(aggCall.getArgList())
								.setIf(aggCall.filterArg, aggCall.filterArg >= 0)
								.union(aggregate.getGroupSet()));
			}
		}

		final com.google.common.collect.ImmutableList<ImmutableBitSet> groupSets =
				com.google.common.collect.ImmutableList.copyOf(groupSetTreeSet);
		final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets);

		// Calls of the bottom aggregate. Despite the name, this list holds the
		// adapted NON-distinct calls (plus the GROUPING call added below).
		final List<AggregateCall> distinctAggCalls = new ArrayList<>();
		for (Pair<AggregateCall, String> aggCall : aggregate.getNamedAggCalls()) {
			if (!aggCall.left.isDistinct()) {
				AggregateCall newAggCall = aggCall.left.adaptTo(
						aggregate.getInput(),
						aggCall.left.getArgList(),
						aggCall.left.filterArg,
						aggregate.getGroupCount(),
						fullGroupSet.cardinality());
				distinctAggCalls.add(newAggCall.rename(aggCall.right));
			}
		}

		final RelBuilder relBuilder = call.builder();
		relBuilder.push(aggregate.getInput());
		final int groupCount = fullGroupSet.cardinality();

		// Maps each grouping set to the ordinal of its boolean flag column in
		// the project created below.
		final Map<ImmutableBitSet, Integer> filters = new LinkedHashMap<>();
		// Ordinal of the "$g" column in the bottom aggregate's output.
		final int z = groupCount + distinctAggCalls.size();
		distinctAggCalls.add(
				AggregateCall.create(SqlStdOperatorTable.GROUPING, false, false,
						ImmutableIntList.copyOf(fullGroupSet), -1, groupSets.size(),
						relBuilder.peek(), null, "$g"));
		for (Ord<ImmutableBitSet> groupSet : Ord.zip(groupSets)) {
			filters.put(groupSet.e, z + groupSet.i);
		}

		relBuilder.aggregate(relBuilder.groupKey(fullGroupSet, groupSets),
				distinctAggCalls);
		final RelNode distinct = relBuilder.peek();

		// The "$g" column holds an integer. Add a project that replaces it by
		// one BOOLEAN column per grouping set, each comparing "$g" with the
		// value computed by groupValue for that grouping set.
		if (!filters.isEmpty()) {
			final List<RexNode> nodes = new ArrayList<>(relBuilder.fields());
			final RexNode nodeZ = nodes.remove(nodes.size() - 1);
			for (Map.Entry<ImmutableBitSet, Integer> entry : filters.entrySet()) {
				final long v = groupValue(fullGroupSet, entry.getKey());
				nodes.add(
						relBuilder.alias(
								relBuilder.equals(nodeZ, relBuilder.literal(v)),
								"$g_" + v));
			}
			relBuilder.project(nodes);
		}

		int aggCallIdx = 0;
		// Ordinal of the next non-distinct call's result column in the bottom
		// aggregate; these columns follow the grouping columns.
		int x = groupCount;
		final List<AggregateCall> newCalls = new ArrayList<>();
		// TODO supports more aggCalls (currently only supports COUNT)
		// Some aggregate functions (e.g. COUNT) have the special property that they can return a
		// non-null result without any input. We need to make sure we return a result in this case.
		final List<Integer> needDefaultValueAggCalls = new ArrayList<>();
		for (AggregateCall aggCall : aggregate.getAggCallList()) {
			final int newFilterArg;
			final List<Integer> newArgList;
			final SqlAggFunction aggregation;
			if (!aggCall.isDistinct()) {
				// The value was already computed by the bottom aggregate; MIN
				// filtered on the original group set's flag picks it up.
				aggregation = SqlStdOperatorTable.MIN;
				newArgList = ImmutableIntList.of(x++);
				newFilterArg = filters.get(aggregate.getGroupSet());
				switch (aggCall.getAggregation().getKind()) {
					case COUNT:
						needDefaultValueAggCalls.add(aggCallIdx);
						break;
					default:
				}
			} else {
				// Same function, now non-distinct, over the remapped argument
				// columns and filtered on the flag of its own grouping set.
				aggregation = aggCall.getAggregation();
				newArgList = remap(fullGroupSet, aggCall.getArgList());
				newFilterArg =
						filters.get(
								ImmutableBitSet.of(aggCall.getArgList())
										.setIf(aggCall.filterArg, aggCall.filterArg >= 0)
										.union(aggregate.getGroupSet()));
			}
			final AggregateCall newCall =
					AggregateCall.create(aggregation, false, aggCall.isApproximate(),
							newArgList, newFilterArg, aggregate.getGroupCount(), distinct,
							null, aggCall.name);
			newCalls.add(newCall);
			aggCallIdx++;
		}

		relBuilder.aggregate(
				relBuilder.groupKey(
						remap(fullGroupSet, aggregate.getGroupSet()),
						remap(fullGroupSet, aggregate.getGroupSets())),
				newCalls);
		// Without grouping keys, a non-distinct COUNT rewritten to MIN may yield
		// NULL; wrap such columns in CASE WHEN ... IS NOT NULL ... ELSE 0.
		if (!needDefaultValueAggCalls.isEmpty() && aggregate.getGroupCount() == 0) {
			// NOTE(review): assumes the builder left an Aggregate on top of the stack — confirm.
			final Aggregate newAgg = (Aggregate) relBuilder.peek();
			final List<RexNode> nodes = new ArrayList<>();
			for (int i = 0; i < newAgg.getGroupCount(); ++i) {
				nodes.add(RexInputRef.of(i, newAgg.getRowType()));
			}
			for (int i = 0; i < newAgg.getAggCallList().size(); ++i) {
				final RexNode inputRef = RexInputRef.of(newAgg.getGroupCount() + i, newAgg.getRowType());
				RexNode newNode = inputRef;
				if (needDefaultValueAggCalls.contains(i)) {
					SqlKind originalFunKind = aggregate.getAggCallList().get(i).getAggregation().getKind();
					switch (originalFunKind) {
						case COUNT:
							newNode = relBuilder.call(
									SqlStdOperatorTable.CASE,
									relBuilder.isNotNull(inputRef),
									inputRef,
									relBuilder.literal(BigDecimal.ZERO));
							break;
						default:
					}
				}
				nodes.add(newNode);
			}
			relBuilder.project(nodes);
		}

		// Make the result's row type match the original aggregate's row type.
		relBuilder.convert(aggregate.getRowType(), true);
		call.transformTo(relBuilder.build());
	}

	/**
	 * Computes a bit mask describing {@code groupSet} relative to
	 * {@code fullGroupSet}: walking the columns of {@code fullGroupSet} in
	 * iteration order, from the highest bit down, a bit is set for every column
	 * that is NOT part of {@code groupSet}.
	 *
	 * @param fullGroupSet Union of all grouping sets
	 * @param groupSet     Grouping set to encode; must be contained in
	 *                     {@code fullGroupSet}
	 * @return the bit mask
	 */
	private static long groupValue(ImmutableBitSet fullGroupSet,
			ImmutableBitSet groupSet) {
		assert fullGroupSet.contains(groupSet);
		long mask = 1L << (fullGroupSet.cardinality() - 1);
		long result = 0;
		for (int column : fullGroupSet) {
			final boolean missing = !groupSet.get(column);
			if (missing) {
				result |= mask;
			}
			mask >>= 1;
		}
		return result;
	}

	/**
	 * Translates every bit of {@code bitSet} into its ordinal within
	 * {@code groupSet} and returns the resulting bit set.
	 */
	private static ImmutableBitSet remap(ImmutableBitSet groupSet,
			ImmutableBitSet bitSet) {
		final ImmutableBitSet.Builder remapped = ImmutableBitSet.builder();
		for (int bit : bitSet) {
			final int ordinal = remap(groupSet, bit);
			remapped.set(ordinal);
		}
		return remapped.build();
	}

	/**
	 * Remaps each of the given bit sets against {@code groupSet}, keeping
	 * their order, and returns them as an immutable list.
	 */
	private static com.google.common.collect.ImmutableList<ImmutableBitSet> remap(
			ImmutableBitSet groupSet,
			Iterable<ImmutableBitSet> bitSets) {
		final List<ImmutableBitSet> remapped = new ArrayList<>();
		for (ImmutableBitSet original : bitSets) {
			remapped.add(remap(groupSet, original));
		}
		return com.google.common.collect.ImmutableList.copyOf(remapped);
	}

	/**
	 * Remaps each argument ordinal against {@code groupSet}, keeping the
	 * order of {@code argList}.
	 */
	private static List<Integer> remap(ImmutableBitSet groupSet,
			List<Integer> argList) {
		// Collect first and copy once instead of appending to an immutable
		// list element by element.
		final List<Integer> remapped = new ArrayList<>(argList.size());
		for (int arg : argList) {
			remapped.add(remap(groupSet, arg));
		}
		return ImmutableIntList.copyOf(remapped);
	}

	/**
	 * Returns the ordinal of {@code arg} within {@code groupSet}; a negative
	 * {@code arg} is mapped to -1.
	 */
	private static int remap(ImmutableBitSet groupSet, int arg) {
		if (arg < 0) {
			return -1;
		}
		return groupSet.indexOf(arg);
	}

	/**
	 * Rewrites an aggregate whose calls are all distinct and share one argument
	 * list (and filter) as a plain aggregate on top of a 'select distinct'.
	 *
	 * <p>For example,
	 * <pre>
	 *    SELECT deptno, COUNT(DISTINCT sal), SUM(DISTINCT sal)
	 *    FROM emp
	 *    GROUP BY deptno
	 * </pre>
	 * becomes
	 * <pre>
	 *    SELECT deptno, COUNT(distinct_sal), SUM(distinct_sal)
	 *    FROM (
	 *      SELECT DISTINCT deptno, sal AS distinct_sal
	 *      FROM EMP)
	 *    GROUP BY deptno
	 * </pre>
	 *
	 * @param relBuilder Builder that receives the rewritten expression
	 * @param aggregate  Original aggregate
	 * @param argList    Arguments shared by the distinct calls
	 * @param filterArg  Filter argument shared by the distinct calls, or -1
	 * @return {@code relBuilder}, with the new aggregate pushed
	 */
	private RelBuilder convertMonopole(RelBuilder relBuilder, Aggregate aggregate,
			List<Integer> argList, int filterArg) {
		// Bottom: distinct rows of the grouping columns plus the arguments.
		// columnOf records where each input column ended up.
		final Map<Integer, Integer> columnOf = new HashMap<>();
		createSelectDistinct(relBuilder, aggregate, argList, filterArg, columnOf);

		// Top: the original calls, made non-distinct and pointed at the
		// columns of the 'select distinct'.
		final List<AggregateCall> rewrittenCalls =
				new ArrayList<>(aggregate.getAggCallList());
		rewriteAggCalls(rewrittenCalls, argList, columnOf);

		// The grouping columns are the leading columns of the bottom output.
		final ImmutableBitSet topGroupSet =
				ImmutableBitSet.range(aggregate.getGroupSet().cardinality());
		final RelNode distinctInput = relBuilder.build();
		relBuilder.push(
				aggregate.copy(aggregate.getTraitSet(), distinctInput,
						aggregate.indicator, topGroupSet, null, rewrittenCalls));
		return relBuilder;
	}

	/**
	 * Converts all distinct aggregate calls over a given set of arguments and
	 * a given filter.
	 *
	 * <p>This method is called several times, once for each distinct
	 * (arguments, filter) combination. Each time it is called, it generates an
	 * aggregate over a new SELECT DISTINCT relational expression, joins it to
	 * the expression built so far (unless it is the first one), and modifies
	 * the set of top-level calls.
	 *
	 * <p>Only calls whose argument list AND filter argument both match are
	 * rewritten in a given invocation; calls sharing the arguments but using a
	 * different filter are left to the invocation for their own combination.
	 *
	 * @param relBuilder Builder holding the expression built so far; the
	 *                   result is pushed onto it
	 * @param aggregate  Original aggregate
	 * @param n          Ordinal of this in a join. {@code relBuilder} contains the
	 *                   input relational expression (either the original
	 *                   aggregate, or the output from the previous call to this
	 *                   method). {@code n} is 0 if we're converting the
	 *                   first distinct aggregate in a query with no non-distinct
	 *                   aggregates
	 * @param argList    Arguments to the distinct aggregate function
	 * @param filterArg  Argument that filters input to aggregate function, or -1
	 * @param refs       Array of expressions which will be projected by the
	 *                   result of this rule. Those relating to this arg list
	 *                   and filter will be modified
	 */
	private void doRewrite(RelBuilder relBuilder, Aggregate aggregate, int n,
			List<Integer> argList, int filterArg, List<RexInputRef> refs) {
		final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
		// Fields of the expression built so far; there is none for the first one.
		final List<RelDataTypeField> leftFields;
		if (n == 0) {
			leftFields = null;
		} else {
			leftFields = relBuilder.peek().getRowType().getFieldList();
		}

		// Aggregate(
		//     child,
		//     {COUNT(DISTINCT 1), SUM(DISTINCT 1), SUM(2)})
		//
		// becomes
		//
		// Aggregate(
		//     Join(
		//         child,
		//         Aggregate(child, < all columns > {}),
		//         INNER,
		//         <f2 = f5>))
		//
		// E.g.
		//   SELECT deptno, SUM(DISTINCT sal), COUNT(DISTINCT gender), MAX(age)
		//   FROM Emps
		//   GROUP BY deptno
		//
		// becomes
		//
		//   SELECT e.deptno, adsal.sum_sal, adgender.count_gender, e.max_age
		//   FROM (
		//     SELECT deptno, MAX(age) as max_age
		//     FROM Emps GROUP BY deptno) AS e
		//   JOIN (
		//     SELECT deptno, COUNT(gender) AS count_gender FROM (
		//       SELECT DISTINCT deptno, gender FROM Emps) AS dgender
		//     GROUP BY deptno) AS adgender
		//     ON e.deptno = adgender.deptno
		//   JOIN (
		//     SELECT deptno, SUM(sal) AS sum_sal FROM (
		//       SELECT DISTINCT deptno, sal FROM Emps) AS dsal
		//     GROUP BY deptno) AS adsal
		//   ON e.deptno = adsal.deptno
		//   GROUP BY e.deptno
		//
		// Note that if a query contains no non-distinct aggregates, then the
		// very first join/group by is omitted.  In the example above, if
		// MAX(age) is removed, then the sub-select of "e" is not needed, and
		// instead the two other group by's are joined to one another.

		// Project the columns of the GROUP BY plus the arguments
		// to the agg function.
		final Map<Integer, Integer> sourceOf = new HashMap<>();
		createSelectDistinct(relBuilder, aggregate, argList, filterArg, sourceOf);

		// Now compute the aggregate functions on top of the distinct dataset.
		// Each distinct agg becomes a non-distinct call to the corresponding
		// field from the right; for example,
		//   "COUNT(DISTINCT e.sal)"
		// becomes
		//   "COUNT(distinct_e.sal)".
		final List<AggregateCall> aggCallList = new ArrayList<>();
		final List<AggregateCall> aggCalls = aggregate.getAggCallList();

		final int groupAndIndicatorCount =
				aggregate.getGroupCount() + aggregate.getIndicatorCount();
		// Ordinal, in the original aggregate's output, of the current call.
		int i = groupAndIndicatorCount - 1;
		for (AggregateCall aggCall : aggCalls) {
			++i;

			// Ignore agg calls which are not distinct or have the wrong set
			// arguments. If we're rewriting aggs whose args are {sal}, we will
			// rewrite COUNT(DISTINCT sal) and SUM(DISTINCT sal) but ignore
			// COUNT(DISTINCT gender) or SUM(sal).
			if (!aggCall.isDistinct()) {
				continue;
			}
			if (!aggCall.getArgList().equals(argList)) {
				continue;
			}
			// Also ignore calls over the same arguments but a different filter:
			// the 'select distinct' built above only projects (and applies)
			// this invocation's filter column, and such calls are handled by
			// the invocation for their own (arguments, filter) combination.
			if (aggCall.filterArg != filterArg) {
				continue;
			}

			// Re-map arguments.
			final int argCount = aggCall.getArgList().size();
			final List<Integer> newArgs = new ArrayList<>(argCount);
			for (int j = 0; j < argCount; j++) {
				final Integer arg = aggCall.getArgList().get(j);
				newArgs.add(sourceOf.get(arg));
			}
			// The filter column, if any, was registered in sourceOf by
			// createSelectDistinct.
			final int newFilterArg =
					aggCall.filterArg >= 0 ? sourceOf.get(aggCall.filterArg) : -1;
			final AggregateCall newAggCall =
					AggregateCall.create(aggCall.getAggregation(), false,
							aggCall.isApproximate(), newArgs,
							newFilterArg, aggCall.getType(), aggCall.getName());
			assert refs.get(i) == null;
			if (n == 0) {
				refs.set(i,
						new RexInputRef(groupAndIndicatorCount + aggCallList.size(),
								newAggCall.getType()));
			} else {
				// Columns of the new aggregate follow those of the left input.
				refs.set(i,
						new RexInputRef(leftFields.size() + groupAndIndicatorCount
								+ aggCallList.size(), newAggCall.getType()));
			}
			aggCallList.add(newAggCall);
		}

		// The grouping columns are the leading columns of the 'select distinct'.
		final Map<Integer, Integer> map = new HashMap<>();
		for (Integer key : aggregate.getGroupSet()) {
			map.put(key, map.size());
		}
		final ImmutableBitSet newGroupSet = aggregate.getGroupSet().permute(map);
		assert newGroupSet
				.equals(ImmutableBitSet.range(aggregate.getGroupSet().cardinality()));
		com.google.common.collect.ImmutableList<ImmutableBitSet> newGroupingSets = null;
		if (aggregate.indicator) {
			newGroupingSets =
					ImmutableBitSet.ORDERING.immutableSortedCopy(
							ImmutableBitSet.permute(aggregate.getGroupSets(), map));
		}

		relBuilder.push(
				aggregate.copy(aggregate.getTraitSet(), relBuilder.build(),
						aggregate.indicator, newGroupSet, newGroupingSets, aggCallList));

		// If there's no left child yet, no need to create the join
		if (n == 0) {
			return;
		}

		// Create the join condition. It is of the form
		//  'left.f0 = right.f0 and left.f1 = right.f1 and ...'
		// where {f0, f1, ...} are the GROUP BY fields.
		final List<RelDataTypeField> distinctFields =
				relBuilder.peek().getRowType().getFieldList();
		final List<RexNode> conditions = new ArrayList<>();
		for (i = 0; i < groupAndIndicatorCount; ++i) {
			// null values form its own group
			// use "is not distinct from" so that the join condition
			// allows null values to match.
			conditions.add(
					rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_DISTINCT_FROM,
							RexInputRef.of(i, leftFields),
							new RexInputRef(leftFields.size() + i,
									distinctFields.get(i).getType())));
		}

		// Join in the new 'select distinct' relation.
		relBuilder.join(JoinRelType.INNER, conditions);
	}

	/**
	 * Replaces, in place, every distinct call over {@code argList} by a
	 * non-distinct, unfiltered call over the remapped columns; for example,
	 * "COUNT(DISTINCT e.sal)" becomes "COUNT(distinct_e.sal)".
	 *
	 * <p>Calls that are not distinct, or that use other arguments, are left
	 * untouched: when rewriting args {sal}, COUNT(DISTINCT sal) and
	 * SUM(DISTINCT sal) are rewritten but COUNT(DISTINCT gender) and SUM(sal)
	 * are not.
	 *
	 * @param newAggCalls Calls to rewrite; modified in place
	 * @param argList     Arguments of the distinct calls to rewrite
	 * @param sourceOf    Maps an input column to its column in the new input
	 */
	private static void rewriteAggCalls(
			List<AggregateCall> newAggCalls,
			List<Integer> argList,
			Map<Integer, Integer> sourceOf) {
		final int callCount = newAggCalls.size();
		for (int idx = 0; idx < callCount; idx++) {
			final AggregateCall current = newAggCalls.get(idx);
			final boolean rewritable =
					current.isDistinct() && current.getArgList().equals(argList);
			if (rewritable) {
				// Point every argument at its column in the new input.
				final List<Integer> remappedArgs =
						new ArrayList<>(current.getArgList().size());
				for (Integer arg : current.getArgList()) {
					remappedArgs.add(sourceOf.get(arg));
				}
				newAggCalls.set(idx,
						AggregateCall.create(current.getAggregation(), false,
								current.isApproximate(), remappedArgs, -1,
								current.getType(), current.getName()));
			}
		}
	}

	/**
	 * Builds a 'select distinct' over the input of the given
	 * {@link org.apache.calcite.rel.core.Aggregate}: it projects the group
	 * columns, the filter column (if any) and the given argument columns, and
	 * groups by all of the projected columns.
	 *
	 * <p>For example, given
	 *
	 * <blockquote>
	 * <pre>select f0, count(distinct f1), count(distinct f2)
	 * from t group by f0</pre>
	 * </blockquote>
	 *
	 * <p>and the argument list
	 *
	 * <blockquote>{2}</blockquote>
	 *
	 * <p>the result is
	 *
	 * <blockquote>
	 * <pre>select distinct f0, f2 from t</pre>
	 * </blockquote>
	 *
	 * <p>and <code>sourceOf</code> maps input column 0 to output column 0 and
	 * input column 2 to output column 1.
	 *
	 * <p>With a filter, each argument is projected as
	 * {@code CASE WHEN filter THEN arg ELSE NULL END}, so that
	 * {@code agg(DISTINCT arg) FILTER $f} can afterwards be computed as
	 * {@code agg(arg)}. This works except for (rare) aggregate functions that
	 * need to see null values.
	 *
	 * @param relBuilder Relational expression builder
	 * @param aggregate Aggregate relational expression
	 * @param argList   Ordinals of columns to make distinct
	 * @param filterArg Ordinal of column to filter on, or -1
	 * @param sourceOf  Out parameter, is populated with the output column of
	 *                  each projected input column
	 * @return {@code relBuilder}, with the 'select distinct' pushed
	 */
	private RelBuilder createSelectDistinct(RelBuilder relBuilder,
			Aggregate aggregate, List<Integer> argList, int filterArg,
			Map<Integer, Integer> sourceOf) {
		relBuilder.push(aggregate.getInput());
		final List<RelDataTypeField> inputFields =
				relBuilder.peek().getRowType().getFieldList();
		final List<Pair<RexNode, String>> projections = new ArrayList<>();

		// Grouping columns come first.
		for (int groupKey : aggregate.getGroupSet()) {
			sourceOf.put(groupKey, projections.size());
			projections.add(RexInputRef.of2(groupKey, inputFields));
		}

		// Then the filter column, when there is one.
		final boolean filtered = filterArg >= 0;
		if (filtered) {
			sourceOf.put(filterArg, projections.size());
			projections.add(RexInputRef.of2(filterArg, inputFields));
		}

		// Finally the arguments.
		for (Integer arg : argList) {
			if (filtered) {
				// Null out the argument on rows rejected by the filter.
				final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
				final RexInputRef filterRef = RexInputRef.of(filterArg, inputFields);
				final Pair<RexNode, String> argRef = RexInputRef.of2(arg, inputFields);
				final RexNode typedNull =
						rexBuilder.ensureType(argRef.left.getType(),
								rexBuilder.makeCast(argRef.left.getType(),
										rexBuilder.constantNull()),
								true);
				final RexNode guardedArg =
						rexBuilder.makeCall(SqlStdOperatorTable.CASE, filterRef,
								argRef.left, typedNull);
				sourceOf.put(arg, projections.size());
				projections.add(Pair.of(guardedArg, "i$" + argRef.right));
			} else if (sourceOf.get(arg) == null) {
				// Not projected yet (it is neither a group key nor a repeated argument).
				sourceOf.put(arg, projections.size());
				projections.add(RexInputRef.of2(arg, inputFields));
			}
		}
		relBuilder.project(Pair.left(projections), Pair.right(projections));

		// Group by every projected column, without aggregate calls, to obtain
		// the distinct rows.
		final ImmutableBitSet allColumns = ImmutableBitSet.range(projections.size());
		relBuilder.push(
				aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), false,
						allColumns,
						null, com.google.common.collect.ImmutableList.<AggregateCall>of()));
		return relBuilder;
	}
}

// End FlinkAggregateExpandDistinctAggregatesRule.java
