blob: 7c908cc0e8dce5bdcfb7820d3ca62a779e0773ed [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.calcite.interpreter;
import org.apache.calcite.DataContext;
import org.apache.calcite.adapter.enumerable.AggAddContext;
import org.apache.calcite.adapter.enumerable.AggImpState;
import org.apache.calcite.adapter.enumerable.JavaRowFormat;
import org.apache.calcite.adapter.enumerable.PhysType;
import org.apache.calcite.adapter.enumerable.PhysTypeImpl;
import org.apache.calcite.adapter.enumerable.RexToLixTranslator;
import org.apache.calcite.adapter.enumerable.impl.AggAddContextImpl;
import org.apache.calcite.adapter.java.JavaTypeFactory;
import org.apache.calcite.interpreter.Row.RowBuilder;
import org.apache.calcite.linq4j.tree.BlockBuilder;
import org.apache.calcite.linq4j.tree.Expression;
import org.apache.calcite.linq4j.tree.Expressions;
import org.apache.calcite.linq4j.tree.ParameterExpression;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.runtime.FunctionContexts;
import org.apache.calcite.schema.FunctionContext;
import org.apache.calcite.schema.impl.AggregateFunctionImpl;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlInternalOperators;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.validate.SqlConformance;
import org.apache.calcite.sql.validate.SqlConformanceEnum;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Util;
import com.google.common.collect.ImmutableList;
import org.checkerframework.checker.nullness.qual.Nullable;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import static org.apache.calcite.linq4j.Nullness.castNonNull;
import static java.util.Objects.requireNonNull;
/**
 * Interpreter node that implements an
 * {@link org.apache.calcite.rel.core.Aggregate}.
 *
 * <p>Each grouping set of the aggregate is tracked by a {@link Grouping},
 * which maps a key row (the values of that set's grouping columns) to a list
 * of accumulators, one per aggregate call. Output rows consist of the columns
 * of the union of all grouping sets followed by one column per aggregate call.
 */
public class AggregateNode extends AbstractSingleNode<Aggregate> {
  /** One entry per grouping set of the aggregate, in group-set order. */
  private final List<Grouping> groups = new ArrayList<>();
  /** Union of all grouping sets; these form the leading output columns. */
  private final ImmutableBitSet unionGroups;
  /** Width of an output row: union group columns plus one per agg call. */
  private final int outputRowLength;
  /** One factory per aggregate call, in the order of the agg call list. */
  private final ImmutableList<AccumulatorFactory> accumulatorFactories;
  private final DataContext dataContext;

  /**
   * Creates an AggregateNode.
   *
   * @param compiler Interpreter compiler; supplies the data context and is
   *                 used to compile LITERAL_AGG arguments
   * @param rel      Aggregate relational expression to implement
   */
  public AggregateNode(Compiler compiler, Aggregate rel) {
    super(compiler, rel);
    this.dataContext = compiler.getDataContext();

    ImmutableBitSet union = ImmutableBitSet.of();
    for (ImmutableBitSet group : rel.getGroupSets()) {
      union = union.union(group);
      groups.add(new Grouping(group));
    }
    this.unionGroups = union;
    this.outputRowLength = unionGroups.cardinality()
        + rel.getAggCallList().size();

    ImmutableList.Builder<AccumulatorFactory> builder = ImmutableList.builder();
    for (AggregateCall aggregateCall : rel.getAggCallList()) {
      // getAccumulator runs before construction completes; it reads
      // dataContext, which has already been assigned above.
      @SuppressWarnings("method.invocation.invalid")
      AccumulatorFactory accumulator =
          getAccumulator(compiler, aggregateCall, false);
      builder.add(accumulator);
    }
    accumulatorFactories = builder.build();
  }

  /**
   * Feeds every input row to every grouping set. Once the source is
   * exhausted (returns null), emits the aggregated rows of each grouping set
   * to the sink.
   */
  @Override public void run() throws InterruptedException {
    Row r;
    while ((r = source.receive()) != null) {
      for (Grouping group : groups) {
        group.send(r);
      }
    }
    for (Grouping group : groups) {
      group.end(sink);
    }
  }

  /**
   * Returns a factory of accumulators for one aggregate call.
   *
   * <p>COUNT, SUM, SUM0, MIN, MAX and LITERAL_AGG have dedicated
   * implementations. Any other aggregate function goes through generated
   * code built from its {@link AggImpState} implementor.
   *
   * @param compiler     Interpreter compiler
   * @param call         Aggregate call to implement
   * @param ignoreFilter If true, do not wrap the result in a
   *                     {@link FilterAccumulator} even if the call has a
   *                     filter argument (used for the recursive call that
   *                     builds the inner, unfiltered accumulator)
   * @return factory producing a fresh accumulator per group key
   */
  private AccumulatorFactory getAccumulator(Compiler compiler,
      final AggregateCall call, boolean ignoreFilter) {
    if (call.filterArg >= 0 && !ignoreFilter) {
      final AccumulatorFactory factory = getAccumulator(compiler, call, true);
      return () -> {
        final Accumulator accumulator = factory.get();
        return new FilterAccumulator(accumulator, call.filterArg);
      };
    }
    final SqlAggFunction op = call.getAggregation();
    if (op == SqlStdOperatorTable.COUNT) {
      return () -> new CountAccumulator(call);
    } else if (op == SqlStdOperatorTable.SUM
        || op == SqlStdOperatorTable.SUM0) {
      final Class<?> clazz = sumClass(call);
      // SUM returns null on empty input; SUM0 returns the init value (zero).
      return new UdaAccumulatorFactory(getAggFunction(clazz), call,
          op == SqlStdOperatorTable.SUM, dataContext);
    } else if (op == SqlStdOperatorTable.MAX
        || op == SqlStdOperatorTable.MIN) {
      final Class<?> clazz = maxMinClass(call);
      return new UdaAccumulatorFactory(getAggFunction(clazz), call, true,
          dataContext);
    } else if (op == SqlInternalOperators.LITERAL_AGG) {
      // The literal is evaluated once, at compile time.
      final Scalar scalar = compiler.compile(call.rexList, null);
      final Object value = scalar.execute(compiler.createContext());
      return () -> new LiteralAccumulator(value);
    } else {
      // Generic path: generate the "add" code from the aggregate's
      // implementor and compile it with Janino.
      final JavaTypeFactory typeFactory =
          (JavaTypeFactory) rel.getCluster().getTypeFactory();
      int stateOffset = 0;
      final AggImpState agg = new AggImpState(0, call, false);
      int stateSize = requireNonNull(agg.state, "agg.state").size();

      final BlockBuilder builder2 = new BlockBuilder();
      final PhysType inputPhysType =
          PhysTypeImpl.of(typeFactory, rel.getInput().getRowType(),
              JavaRowFormat.ARRAY);
      final RelDataTypeFactory.Builder builder = typeFactory.builder();
      for (Expression expression : agg.state) {
        builder.add("a",
            typeFactory.createJavaType((Class) expression.getType()));
      }
      final PhysType accPhysType =
          PhysTypeImpl.of(typeFactory, builder.build(), JavaRowFormat.ARRAY);
      final ParameterExpression inParameter =
          Expressions.parameter(inputPhysType.getJavaRowType(), "in");
      final ParameterExpression acc_ =
          Expressions.parameter(accPhysType.getJavaRowType(), "acc");

      List<Expression> accumulator = new ArrayList<>(stateSize);
      for (int j = 0; j < stateSize; j++) {
        accumulator.add(accPhysType.fieldReference(acc_, j + stateOffset));
      }
      agg.state = accumulator;

      AggAddContext addContext =
          new AggAddContextImpl(builder2, accumulator) {
            @Override public List<RexNode> rexArguments() {
              List<RexNode> args = new ArrayList<>();
              for (int index : agg.call.getArgList()) {
                args.add(RexInputRef.of(index, inputPhysType.getRowType()));
              }
              return args;
            }

            @Override public @Nullable RexNode rexFilterArgument() {
              return agg.call.filterArg < 0
                  ? null
                  : RexInputRef.of(agg.call.filterArg,
                      inputPhysType.getRowType());
            }

            @Override public RexToLixTranslator rowTranslator() {
              final SqlConformance conformance =
                  SqlConformanceEnum.DEFAULT; // TODO: get this from implementor
              return RexToLixTranslator.forAggregation(typeFactory,
                  currentBlock(),
                  new RexToLixTranslator.InputGetterImpl(inParameter,
                      inputPhysType),
                  conformance);
            }
          };

      agg.implementor.implementAdd(requireNonNull(agg.context, "agg.context"), addContext);

      final ParameterExpression context_ =
          Expressions.parameter(Context.class, "context");
      final ParameterExpression outputValues_ =
          Expressions.parameter(Object[].class, "outputValues");
      final Scalar.Producer addScalarProducer =
          JaninoRexCompiler.baz(context_, outputValues_, builder2.toBlock(),
              ImmutableList.of());
      // NOTE(review): only the "add" step is generated; init and end scalars
      // are null placeholders, so ScalarAccumulator.end() cannot succeed on
      // this path as written.
      final Scalar initScalar = castNonNull(null);
      final Scalar addScalar = addScalarProducer.apply(dataContext);
      final Scalar endScalar = castNonNull(null);
      return new ScalarAccumulatorDef(initScalar, addScalar, endScalar,
          rel.getInput().getRowType().getFieldCount(), stateSize, dataContext);
    }
  }

  /** Chooses the MIN/MAX implementation class for a call's result type. */
  private static Class<?> maxMinClass(AggregateCall call) {
    boolean max = call.getAggregation() == SqlStdOperatorTable.MAX;
    switch (call.getType().getSqlTypeName()) {
    case INTEGER:
      return max ? MaxInt.class : MinInt.class;
    case REAL:
      return max ? MaxFloat.class : MinFloat.class;
    case FLOAT:
    case DOUBLE:
      return max ? MaxDouble.class : MinDouble.class;
    case DECIMAL:
      return max ? MaxBigDecimal.class : MinBigDecimal.class;
    case BOOLEAN:
      return max ? MaxBoolean.class : MinBoolean.class;
    default:
      return max ? MaxLong.class : MinLong.class;
    }
  }

  /** Chooses the SUM implementation class for a call's result type. */
  private static Class<?> sumClass(AggregateCall call) {
    switch (call.type.getSqlTypeName()) {
    case DOUBLE:
    case REAL:
    case FLOAT:
      return DoubleSum.class;
    case DECIMAL:
      return BigDecimalSum.class;
    case INTEGER:
      return IntSum.class;
    case BIGINT:
    default:
      return LongSum.class;
    }
  }

  /** Wraps a UDA class as an {@link AggregateFunctionImpl}, failing loudly
   * if the class does not have the required methods. */
  private static AggregateFunctionImpl getAggFunction(Class<?> clazz) {
    return requireNonNull(
        AggregateFunctionImpl.create(clazz),
        () -> "Unable to create AggregateFunctionImpl for " + clazz);
  }

  /** Accumulator for calls to the COUNT function.
   *
   * <p>Counts rows in which every argument is non-null; with no arguments
   * (COUNT(*)) every row is counted. */
  private static class CountAccumulator implements Accumulator {
    private final AggregateCall call;
    long cnt;

    CountAccumulator(AggregateCall call) {
      this.call = call;
      cnt = 0;
    }

    @Override public void send(Row row) {
      for (int i : call.getArgList()) {
        if (row.getObject(i) == null) {
          return;
        }
      }
      cnt++;
    }

    @Override public Object end() {
      return cnt;
    }
  }

  /** Accumulator for calls to the LITERAL_AGG function; ignores its input
   * and always returns the pre-computed value. */
  private static class LiteralAccumulator implements Accumulator {
    private final @Nullable Object value;

    LiteralAccumulator(@Nullable Object value) {
      this.value = value;
    }

    @Override public void send(Row row) {
    }

    @Override public @Nullable Object end() {
      return value;
    }
  }

  /** Creates an {@link Accumulator}. */
  private interface AccumulatorFactory extends Supplier<Accumulator> {
  }

  /** Accumulator powered by {@link Scalar} code fragments.
   *
   * <p>The send and end contexts are allocated once here and shared by all
   * accumulators this factory creates. */
  private static class ScalarAccumulatorDef implements AccumulatorFactory {
    final Scalar initScalar;
    final Scalar addScalar;
    final Scalar endScalar;
    final Context sendContext;
    final Context endContext;
    final int rowLength;
    final int accumulatorLength;

    private ScalarAccumulatorDef(Scalar initScalar, Scalar addScalar,
        Scalar endScalar, int rowLength, int accumulatorLength,
        DataContext root) {
      this.initScalar = initScalar;
      this.addScalar = addScalar;
      this.endScalar = endScalar;
      this.accumulatorLength = accumulatorLength;
      this.rowLength = rowLength;
      // Send context holds the input row followed by the accumulator state.
      this.sendContext = new Context(root);
      this.sendContext.values = new Object[rowLength + accumulatorLength];
      this.endContext = new Context(root);
      this.endContext.values = new Object[accumulatorLength];
    }

    @Override public Accumulator get() {
      return new ScalarAccumulator(this, new Object[accumulatorLength]);
    }
  }

  /** Accumulator powered by {@link Scalar} code fragments. */
  private static class ScalarAccumulator implements Accumulator {
    final ScalarAccumulatorDef def;
    final Object[] values;

    private ScalarAccumulator(ScalarAccumulatorDef def, Object[] values) {
      this.def = def;
      this.values = values;
    }

    @Override public void send(Row row) {
      @Nullable Object[] sendValues =
          requireNonNull(def.sendContext.values, "def.sendContext.values");
      // Layout: [input row fields..., accumulator state...]
      System.arraycopy(row.getValues(), 0, sendValues, 0,
          def.rowLength);
      System.arraycopy(this.values, 0, sendValues, def.rowLength,
          this.values.length);
      def.addScalar.execute(def.sendContext, this.values);
    }

    @Override public @Nullable Object end() {
      Context endContext = requireNonNull(def.endContext, "def.endContext");
      @Nullable Object[] values = requireNonNull(endContext.values, "endContext.values");
      System.arraycopy(this.values, 0, values, 0, this.values.length);
      return def.endScalar.execute(endContext);
    }
  }

  /**
   * Internal class to track one grouping set: maps each distinct key (the
   * values of this set's grouping columns, in ascending column order) to its
   * accumulators.
   */
  private class Grouping {
    private final ImmutableBitSet grouping;
    private final Map<Row, AccumulatorList> accumulators = new HashMap<>();

    private Grouping(ImmutableBitSet grouping) {
      this.grouping = grouping;
    }

    /** Routes an input row to the accumulators of its group key, creating
     * them on first sight of the key. */
    public void send(Row row) {
      final RowBuilder builder = Row.newBuilder(grouping.cardinality());
      int j = 0;
      for (int i : grouping) {
        builder.set(j++, row.getObject(i));
      }
      final Row key = builder.build();
      accumulators.computeIfAbsent(key, k -> newAccumulatorList()).send(row);
    }

    /** Emits one output row per group key seen. */
    public void end(Sink sink) throws InterruptedException {
      if (accumulators.isEmpty() && grouping.isEmpty()) {
        // An aggregate over the empty grouping set yields exactly one row
        // even when the input is empty (e.g. COUNT(*) returns 0).
        emit(sink, Row.newBuilder(0).build(), newAccumulatorList());
        return;
      }
      for (Map.Entry<Row, AccumulatorList> e : accumulators.entrySet()) {
        emit(sink, e.getKey(), e.getValue());
      }
    }

    /** Creates one fresh accumulator per aggregate call. */
    private AccumulatorList newAccumulatorList() {
      final AccumulatorList list = new AccumulatorList();
      for (AccumulatorFactory factory : accumulatorFactories) {
        list.add(factory.get());
      }
      return list;
    }

    /** Builds and sends the output row for one key. */
    private void emit(Sink sink, Row key, AccumulatorList list)
        throws InterruptedException {
      final RowBuilder rb = Row.newBuilder(outputRowLength);
      int index = 0;    // position in the output row (union of group sets)
      int keyIndex = 0; // position in the key row (this group set only)
      for (int groupPos : unionGroups) {
        if (grouping.get(groupPos)) {
          // Both bit sets iterate in ascending order, so the key's j-th value
          // belongs to the j-th grouped column encountered here.
          rb.set(index, key.getObject(keyIndex++));
        }
        // TODO: columns not part of this grouping set are left unset; the
        // original note says they should be set to false.
        index++;
      }
      list.end(rb);
      sink.send(rb.build());
    }
  }

  /**
   * A list of accumulators used during grouping.
   */
  private static class AccumulatorList extends ArrayList<Accumulator> {
    /** Sends the row to every accumulator. */
    public void send(Row row) {
      for (Accumulator a : this) {
        a.send(row);
      }
    }

    /** Writes each accumulator's result into the trailing {@code size()}
     * slots of {@code r}, in list order. */
    public void end(RowBuilder r) {
      for (int accIndex = 0, rowIndex = r.size() - size();
          rowIndex < r.size(); rowIndex++, accIndex++) {
        r.set(rowIndex, get(accIndex).end());
      }
    }
  }

  /**
   * Defines function implementation for
   * things like {@code count()} and {@code sum()}.
   */
  private interface Accumulator {
    /** Adds one input row. */
    void send(Row row);

    /** Returns the aggregate result. */
    @Nullable Object end();
  }

  /** Implementation of {@code SUM} over INTEGER values as a user-defined
   * aggregate. */
  public static class IntSum {
    public IntSum() {
    }

    public int init() {
      return 0;
    }

    public int add(int accumulator, int v) {
      return accumulator + v;
    }

    public int merge(int accumulator0, int accumulator1) {
      return accumulator0 + accumulator1;
    }

    public int result(int accumulator) {
      return accumulator;
    }
  }

  /** Implementation of {@code SUM} over BIGINT values as a user-defined
   * aggregate. */
  public static class LongSum {
    public LongSum() {
    }

    public long init() {
      return 0L;
    }

    public long add(long accumulator, long v) {
      return accumulator + v;
    }

    public long merge(long accumulator0, long accumulator1) {
      return accumulator0 + accumulator1;
    }

    public long result(long accumulator) {
      return accumulator;
    }
  }

  /** Implementation of {@code SUM} over DOUBLE values as a user-defined
   * aggregate. */
  public static class DoubleSum {
    public DoubleSum() {
    }

    public double init() {
      return 0D;
    }

    public double add(double accumulator, double v) {
      return accumulator + v;
    }

    public double merge(double accumulator0, double accumulator1) {
      return accumulator0 + accumulator1;
    }

    public double result(double accumulator) {
      return accumulator;
    }
  }

  /** Implementation of {@code SUM} over BigDecimal values as a user-defined
   * aggregate. */
  public static class BigDecimalSum {
    public BigDecimalSum() {
    }

    public BigDecimal init() {
      return BigDecimal.ZERO;
    }

    public BigDecimal add(BigDecimal accumulator, BigDecimal v) {
      return accumulator.add(v);
    }

    public BigDecimal merge(BigDecimal accumulator0, BigDecimal accumulator01) {
      return add(accumulator0, accumulator01);
    }

    public BigDecimal result(BigDecimal accumulator) {
      return accumulator;
    }
  }

  /** Common implementation of comparison aggregate methods over numeric
   * values as a user-defined aggregate.
   *
   * <p>{@code initialValue} must be an identity for
   * {@code comparisonFunction} (e.g. negative infinity for MAX), otherwise
   * it can leak into the result.
   *
   * @param <T> The numeric type
   */
  public static class NumericComparison<T> {
    private final T initialValue;
    private final BiFunction<T, T, T> comparisonFunction;

    public NumericComparison(T initialValue, BiFunction<T, T, T> comparisonFunction) {
      this.initialValue = initialValue;
      this.comparisonFunction = comparisonFunction;
    }

    public T init() {
      return this.initialValue;
    }

    public T add(T accumulator, T value) {
      return this.comparisonFunction.apply(accumulator, value);
    }

    public T merge(T accumulator0, T accumulator1) {
      return add(accumulator0, accumulator1);
    }

    public T result(T accumulator) {
      return accumulator;
    }
  }

  /** Implementation of {@code MIN} function to calculate the minimum of
   * {@code integer} values as a user-defined aggregate.
   */
  public static class MinInt extends NumericComparison<Integer> {
    public MinInt() {
      super(Integer.MAX_VALUE, Math::min);
    }
  }

  /** Implementation of {@code MIN} function to calculate the minimum of
   * {@code long} values as a user-defined aggregate.
   */
  public static class MinLong extends NumericComparison<Long> {
    public MinLong() {
      super(Long.MAX_VALUE, Math::min);
    }
  }

  /** Implementation of {@code MIN} function to calculate the minimum of
   * {@code float} values as a user-defined aggregate.
   */
  public static class MinFloat extends NumericComparison<Float> {
    public MinFloat() {
      // Positive infinity (not MAX_VALUE) so that MIN(+Infinity) is correct.
      super(Float.POSITIVE_INFINITY, Math::min);
    }
  }

  /** Implementation of {@code MIN} function to calculate the minimum of
   * {@code double} and {@code real} values as a user-defined aggregate.
   */
  public static class MinDouble extends NumericComparison<Double> {
    public MinDouble() {
      super(Double.POSITIVE_INFINITY, Math::min);
    }
  }

  /** Implementation of {@code MIN} function to calculate the minimum of
   * {@code BigDecimal} values as a user-defined aggregate.
   *
   * <p>BigDecimal has no infinity, so the accumulator starts as null
   * ("no value yet") rather than an arbitrary large constant.
   */
  public static class MinBigDecimal extends NumericComparison<BigDecimal> {
    public MinBigDecimal() {
      super(castNonNull(null), MinBigDecimal::min);
    }

    /** Returns the smaller of {@code a} and {@code b}; {@code b} if
     * {@code a} is null. */
    public static BigDecimal min(@Nullable BigDecimal a, BigDecimal b) {
      return a == null ? b : a.min(b);
    }
  }

  /** Implementation of {@code MIN} function to calculate the minimum of
   * {@code boolean} values as a user-defined aggregate.
   */
  public static class MinBoolean {
    public MinBoolean() { }

    public Boolean init() {
      return Boolean.TRUE;
    }

    public Boolean add(Boolean accumulator, Boolean value) {
      return accumulator && value;
    }

    public Boolean merge(Boolean accumulator0, Boolean accumulator1) {
      return add(accumulator0, accumulator1);
    }

    public Boolean result(Boolean accumulator) {
      return accumulator;
    }
  }

  /** Implementation of {@code MAX} function to calculate the maximum of
   * {@code integer} values as a user-defined aggregate.
   */
  public static class MaxInt extends NumericComparison<Integer> {
    public MaxInt() {
      super(Integer.MIN_VALUE, Math::max);
    }
  }

  /** Implementation of {@code MAX} function to calculate the maximum of
   * {@code long} values as a user-defined aggregate.
   */
  public static class MaxLong extends NumericComparison<Long> {
    public MaxLong() {
      super(Long.MIN_VALUE, Math::max);
    }
  }

  /** Implementation of {@code MAX} function to calculate the maximum of
   * {@code float} values as a user-defined aggregate.
   */
  public static class MaxFloat extends NumericComparison<Float> {
    public MaxFloat() {
      // Float.MIN_VALUE is the smallest *positive* float, not the most
      // negative one; negative infinity is the identity for max.
      super(Float.NEGATIVE_INFINITY, Math::max);
    }
  }

  /** Implementation of {@code MAX} function to calculate the maximum of
   * {@code double} and {@code real} values as a user-defined aggregate.
   */
  public static class MaxDouble extends NumericComparison<Double> {
    public MaxDouble() {
      // Double.MIN_VALUE is positive; use negative infinity instead.
      super(Double.NEGATIVE_INFINITY, Math::max);
    }
  }

  /** Implementation of {@code MAX} function to calculate the maximum of
   * {@code BigDecimal} values as a user-defined aggregate.
   *
   * <p>BigDecimal has no infinity, so the accumulator starts as null
   * ("no value yet") rather than an arbitrary small constant.
   */
  public static class MaxBigDecimal extends NumericComparison<BigDecimal> {
    public MaxBigDecimal() {
      super(castNonNull(null), MaxBigDecimal::max);
    }

    /** Returns the larger of {@code a} and {@code b}; {@code b} if
     * {@code a} is null. */
    public static BigDecimal max(@Nullable BigDecimal a, BigDecimal b) {
      return a == null ? b : a.max(b);
    }
  }

  /** Implementation of {@code MAX} function to calculate the maximum of
   * {@code boolean} values as a user-defined aggregate.
   */
  public static class MaxBoolean {
    public MaxBoolean() { }

    public Boolean init() {
      return Boolean.FALSE;
    }

    public Boolean add(Boolean accumulator, Boolean value) {
      return accumulator || value;
    }

    public Boolean merge(Boolean accumulator0, Boolean accumulator1) {
      return add(accumulator0, accumulator1);
    }

    public Boolean result(Boolean accumulator) {
      return accumulator;
    }
  }

  /** Accumulator factory based on a user-defined aggregate function. */
  private static class UdaAccumulatorFactory implements AccumulatorFactory {
    final AggregateFunctionImpl aggFunction;
    final int argOrdinal;
    public final @Nullable Object instance;
    /** If true, {@link UdaAccumulator#end()} returns null when no non-null
     * value was added. */
    public final boolean nullIfEmpty;

    UdaAccumulatorFactory(AggregateFunctionImpl aggFunction,
        AggregateCall call, boolean nullIfEmpty, DataContext dataContext) {
      this.aggFunction = aggFunction;
      if (call.getArgList().size() != 1) {
        throw new UnsupportedOperationException("in current implementation, "
            + "aggregate must have precisely one argument");
      }
      argOrdinal = call.getArgList().get(0);
      instance = createInstance(aggFunction, dataContext);
      this.nullIfEmpty = nullIfEmpty;
    }

    /** Instantiates the UDA class: returns null for static UDAs, otherwise
     * tries a no-arg constructor, then a {@link FunctionContext} one. */
    static @Nullable Object createInstance(AggregateFunctionImpl aggFunction,
        DataContext dataContext) {
      if (aggFunction.isStatic) {
        return null;
      }
      try {
        final Constructor<?> constructor =
            aggFunction.declaringClass.getConstructor();
        return constructor.newInstance();
      } catch (InstantiationException | IllegalAccessException
          | NoSuchMethodException | InvocationTargetException ignored) {
        // ignore, and try next constructor
      }
      try {
        final Constructor<?> constructor =
            aggFunction.declaringClass.getConstructor(FunctionContext.class);
        final Object[] args = new Object[aggFunction.getParameters().size()];
        final FunctionContext functionContext =
            FunctionContexts.of(dataContext, args);
        return constructor.newInstance(functionContext);
      } catch (InstantiationException | IllegalAccessException
          | NoSuchMethodException | InvocationTargetException e) {
        throw Util.toUnchecked(e);
      }
    }

    @Override public Accumulator get() {
      return new UdaAccumulator(this);
    }
  }

  /** Accumulator based upon a user-defined aggregate. */
  private static class UdaAccumulator implements Accumulator {
    private final UdaAccumulatorFactory factory;
    private @Nullable Object value;
    private boolean empty;

    UdaAccumulator(UdaAccumulatorFactory factory) {
      this.factory = factory;
      try {
        this.value = factory.aggFunction.initMethod.invoke(factory.instance);
      } catch (IllegalAccessException | InvocationTargetException e) {
        throw new RuntimeException(e);
      }
      this.empty = true;
    }

    @Override public void send(Row row) {
      final @Nullable Object[] args = {value, row.getValues()[factory.argOrdinal]};
      // Index 0 is the accumulator, which may legitimately be null; only the
      // input arguments are checked.
      for (int i = 1; i < args.length; i++) {
        if (args[i] == null) {
          return; // one of the arguments is null; don't add to the total
        }
      }
      try {
        value = factory.aggFunction.addMethod.invoke(factory.instance, args);
      } catch (IllegalAccessException | InvocationTargetException e) {
        throw new RuntimeException(e);
      }
      empty = false;
    }

    @Override public @Nullable Object end() {
      if (factory.nullIfEmpty && empty) {
        return null;
      }
      final @Nullable Object[] args = {value};
      try {
        AggregateFunctionImpl aggFunction =
            requireNonNull(factory.aggFunction, "factory.aggFunction");
        return requireNonNull(aggFunction.resultMethod, "aggFunction.resultMethod")
            .invoke(factory.instance, args);
      } catch (IllegalAccessException | InvocationTargetException e) {
        throw new RuntimeException(e);
      }
    }
  }

  /** Accumulator that applies a filter to another accumulator.
   * The filter is a BOOLEAN field in the input row; only rows where it is
   * TRUE (not FALSE or null) are forwarded. */
  private static class FilterAccumulator implements Accumulator {
    private final Accumulator accumulator;
    private final int filterArg;

    FilterAccumulator(Accumulator accumulator, int filterArg) {
      this.accumulator = accumulator;
      this.filterArg = filterArg;
    }

    @Override public void send(Row row) {
      // Value comparison, not reference comparison: a Boolean that is not the
      // canonical Boolean.TRUE instance must still pass the filter.
      if (Boolean.TRUE.equals(row.getValues()[filterArg])) {
        accumulator.send(row);
      }
    }

    @Override public @Nullable Object end() {
      return accumulator.end();
    }
  }
}