| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to you under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.calcite.interpreter; |
| |
| import org.apache.calcite.DataContext; |
| import org.apache.calcite.adapter.enumerable.AggAddContext; |
| import org.apache.calcite.adapter.enumerable.AggImpState; |
| import org.apache.calcite.adapter.enumerable.JavaRowFormat; |
| import org.apache.calcite.adapter.enumerable.PhysType; |
| import org.apache.calcite.adapter.enumerable.PhysTypeImpl; |
| import org.apache.calcite.adapter.enumerable.RexToLixTranslator; |
| import org.apache.calcite.adapter.enumerable.impl.AggAddContextImpl; |
| import org.apache.calcite.adapter.java.JavaTypeFactory; |
| import org.apache.calcite.interpreter.Row.RowBuilder; |
| import org.apache.calcite.linq4j.tree.BlockBuilder; |
| import org.apache.calcite.linq4j.tree.Expression; |
| import org.apache.calcite.linq4j.tree.Expressions; |
| import org.apache.calcite.linq4j.tree.ParameterExpression; |
| import org.apache.calcite.rel.core.Aggregate; |
| import org.apache.calcite.rel.core.AggregateCall; |
| import org.apache.calcite.rel.type.RelDataTypeFactory; |
| import org.apache.calcite.rex.RexInputRef; |
| import org.apache.calcite.rex.RexNode; |
| import org.apache.calcite.runtime.FunctionContexts; |
| import org.apache.calcite.schema.FunctionContext; |
| import org.apache.calcite.schema.impl.AggregateFunctionImpl; |
| import org.apache.calcite.sql.SqlAggFunction; |
| import org.apache.calcite.sql.fun.SqlInternalOperators; |
| import org.apache.calcite.sql.fun.SqlStdOperatorTable; |
| import org.apache.calcite.sql.validate.SqlConformance; |
| import org.apache.calcite.sql.validate.SqlConformanceEnum; |
| import org.apache.calcite.util.ImmutableBitSet; |
| import org.apache.calcite.util.Util; |
| |
| import com.google.common.collect.ImmutableList; |
| |
| import org.checkerframework.checker.nullness.qual.Nullable; |
| |
| import java.lang.reflect.Constructor; |
| import java.lang.reflect.InvocationTargetException; |
| import java.math.BigDecimal; |
| import java.util.ArrayList; |
| import java.util.HashMap; |
| import java.util.List; |
| import java.util.Map; |
| import java.util.function.BiFunction; |
| import java.util.function.Supplier; |
| |
| import static org.apache.calcite.linq4j.Nullness.castNonNull; |
| |
| import static java.util.Objects.requireNonNull; |
| |
| /** |
| * Interpreter node that implements an |
| * {@link org.apache.calcite.rel.core.Aggregate}. |
| */ |
| public class AggregateNode extends AbstractSingleNode<Aggregate> { |
| private final List<Grouping> groups = new ArrayList<>(); |
| private final ImmutableBitSet unionGroups; |
| private final int outputRowLength; |
| private final ImmutableList<AccumulatorFactory> accumulatorFactories; |
| private final DataContext dataContext; |
| |
| public AggregateNode(Compiler compiler, Aggregate rel) { |
| super(compiler, rel); |
| this.dataContext = compiler.getDataContext(); |
| |
| ImmutableBitSet union = ImmutableBitSet.of(); |
| |
| for (ImmutableBitSet group : rel.getGroupSets()) { |
| union = union.union(group); |
| groups.add(new Grouping(group)); |
| } |
| |
| this.unionGroups = union; |
| this.outputRowLength = unionGroups.cardinality() |
| + rel.getAggCallList().size(); |
| |
| ImmutableList.Builder<AccumulatorFactory> builder = ImmutableList.builder(); |
| for (AggregateCall aggregateCall : rel.getAggCallList()) { |
| @SuppressWarnings("method.invocation.invalid") |
| AccumulatorFactory accumulator = |
| getAccumulator(compiler, aggregateCall, false); |
| builder.add(accumulator); |
| } |
| accumulatorFactories = builder.build(); |
| } |
| |
| @Override public void run() throws InterruptedException { |
| Row r; |
| while ((r = source.receive()) != null) { |
| for (Grouping group : groups) { |
| group.send(r); |
| } |
| } |
| |
| for (Grouping group : groups) { |
| group.end(sink); |
| } |
| } |
| |
/**
 * Creates the factory that produces accumulators for one aggregate call.
 *
 * <p>The strategy depends on the call:
 * <ul>
 * <li>a call with a filter argument (unless {@code ignoreFilter} is set)
 *   wraps the accumulators of the unfiltered factory in a
 *   {@link FilterAccumulator};
 * <li>{@code COUNT} uses {@link CountAccumulator};
 * <li>{@code SUM}, {@code SUM0}, {@code MAX} and {@code MIN} use a
 *   {@link UdaAccumulatorFactory} over the helper class chosen by
 *   {@link #sumClass} or {@link #maxMinClass};
 * <li>{@code LITERAL_AGG} evaluates its expression once and uses
 *   {@link LiteralAccumulator};
 * <li>any other function has its "add" step generated via the aggregate
 *   implementor and is wrapped in a {@link ScalarAccumulatorDef}.
 * </ul>
 *
 * @param compiler Compiler used to compile and evaluate expressions
 * @param call Aggregate call to create accumulators for
 * @param ignoreFilter Whether to disregard {@code call.filterArg}
 * @return factory for accumulators of {@code call}
 */
private AccumulatorFactory getAccumulator(Compiler compiler,
    final AggregateCall call, boolean ignoreFilter) {
  if (call.filterArg >= 0 && !ignoreFilter) {
    // Build the unfiltered factory once, then wrap each accumulator it
    // produces with the filter check.
    final AccumulatorFactory factory = getAccumulator(compiler, call, true);
    return () -> {
      final Accumulator accumulator = factory.get();
      return new FilterAccumulator(accumulator, call.filterArg);
    };
  }
  final SqlAggFunction op = call.getAggregation();
  if (op == SqlStdOperatorTable.COUNT) {
    return () -> new CountAccumulator(call);
  } else if (op == SqlStdOperatorTable.SUM
      || op == SqlStdOperatorTable.SUM0) {
    final Class<?> clazz = sumClass(call);
    // nullIfEmpty is true for SUM and false for SUM0.
    return new UdaAccumulatorFactory(getAggFunction(clazz), call,
        op == SqlStdOperatorTable.SUM, dataContext);
  } else if (op == SqlStdOperatorTable.MAX
      || op == SqlStdOperatorTable.MIN) {
    final Class<?> clazz = maxMinClass(call);
    // MAX and MIN always yield null when no value was accumulated.
    return new UdaAccumulatorFactory(getAggFunction(clazz), call, true,
        dataContext);
  } else if (op == SqlInternalOperators.LITERAL_AGG) {
    // The expression is evaluated once, here; every accumulator returns
    // that same value.
    final Scalar scalar = compiler.compile(call.rexList, null);
    final Object value = scalar.execute(compiler.createContext());
    return () -> new LiteralAccumulator(value);
  } else {
    final JavaTypeFactory typeFactory =
        (JavaTypeFactory) rel.getCluster().getTypeFactory();
    // Only one aggregate is handled per definition, so its state starts
    // at offset 0.
    int stateOffset = 0;
    final AggImpState agg = new AggImpState(0, call, false);
    int stateSize = requireNonNull(agg.state, "agg.state").size();

    final BlockBuilder builder2 = new BlockBuilder();
    final PhysType inputPhysType =
        PhysTypeImpl.of(typeFactory, rel.getInput().getRowType(),
            JavaRowFormat.ARRAY);
    // Describe the accumulator as a row with one field (all named "a")
    // per state expression.
    final RelDataTypeFactory.Builder builder = typeFactory.builder();
    for (Expression expression : agg.state) {
      builder.add("a",
          typeFactory.createJavaType((Class) expression.getType()));
    }
    final PhysType accPhysType =
        PhysTypeImpl.of(typeFactory, builder.build(), JavaRowFormat.ARRAY);
    final ParameterExpression inParameter =
        Expressions.parameter(inputPhysType.getJavaRowType(), "in");
    final ParameterExpression acc_ =
        Expressions.parameter(accPhysType.getJavaRowType(), "acc");

    // Replace the state expressions with references to the fields of the
    // "acc" parameter.
    List<Expression> accumulator = new ArrayList<>(stateSize);
    for (int j = 0; j < stateSize; j++) {
      accumulator.add(accPhysType.fieldReference(acc_, j + stateOffset));
    }
    agg.state = accumulator;

    AggAddContext addContext =
        new AggAddContextImpl(builder2, accumulator) {
          // Arguments of the call, as references to input fields.
          @Override public List<RexNode> rexArguments() {
            List<RexNode> args = new ArrayList<>();
            for (int index : agg.call.getArgList()) {
              args.add(RexInputRef.of(index, inputPhysType.getRowType()));
            }
            return args;
          }

          // Filter of the call as an input reference, or null if the call
          // has no filter.
          @Override public @Nullable RexNode rexFilterArgument() {
            return agg.call.filterArg < 0
                ? null
                : RexInputRef.of(agg.call.filterArg,
                    inputPhysType.getRowType());
          }

          @Override public RexToLixTranslator rowTranslator() {
            final SqlConformance conformance =
                SqlConformanceEnum.DEFAULT; // TODO: get this from implementor
            return RexToLixTranslator.forAggregation(typeFactory,
                currentBlock(),
                new RexToLixTranslator.InputGetterImpl(inParameter,
                    inputPhysType),
                conformance);
          }
        };

    // Emit the "add" code into builder2.
    agg.implementor.implementAdd(requireNonNull(agg.context, "agg.context"), addContext);

    final ParameterExpression context_ =
        Expressions.parameter(Context.class, "context");
    final ParameterExpression outputValues_ =
        Expressions.parameter(Object[].class, "outputValues");
    final Scalar.Producer addScalarProducer =
        JaninoRexCompiler.baz(context_, outputValues_, builder2.toBlock(),
            ImmutableList.of());
    // NOTE(review): no init or end code is generated; both scalars are null
    // placeholders passed through castNonNull. ScalarAccumulator.end()
    // calls endScalar.execute(...), which would fail on null -- confirm
    // whether this branch is expected to be reached.
    final Scalar initScalar = castNonNull(null);
    final Scalar addScalar = addScalarProducer.apply(dataContext);
    final Scalar endScalar = castNonNull(null);
    return new ScalarAccumulatorDef(initScalar, addScalar, endScalar,
        rel.getInput().getRowType().getFieldCount(), stateSize, dataContext);
  }
}
| |
| private static Class<?> maxMinClass(AggregateCall call) { |
| boolean max = call.getAggregation() == SqlStdOperatorTable.MAX; |
| switch (call.getType().getSqlTypeName()) { |
| case INTEGER: |
| return max ? MaxInt.class : MinInt.class; |
| case REAL: |
| return max ? MaxFloat.class : MinFloat.class; |
| case FLOAT: |
| case DOUBLE: |
| return max ? MaxDouble.class : MinDouble.class; |
| case DECIMAL: |
| return max ? MaxBigDecimal.class : MinBigDecimal.class; |
| case BOOLEAN: |
| return max ? MaxBoolean.class : MinBoolean.class; |
| default: |
| return max ? MaxLong.class : MinLong.class; |
| } |
| } |
| |
| private static Class<?> sumClass(AggregateCall call) { |
| switch (call.type.getSqlTypeName()) { |
| case DOUBLE: |
| case REAL: |
| case FLOAT: |
| return DoubleSum.class; |
| case DECIMAL: |
| return BigDecimalSum.class; |
| case INTEGER: |
| return IntSum.class; |
| case BIGINT: |
| default: |
| return LongSum.class; |
| } |
| } |
| |
| private static AggregateFunctionImpl getAggFunction(Class<?> clazz) { |
| return requireNonNull( |
| AggregateFunctionImpl.create(clazz), |
| () -> "Unable to create AggregateFunctionImpl for " + clazz); |
| } |
| |
| /** Accumulator for calls to the COUNT function. */ |
| private static class CountAccumulator implements Accumulator { |
| private final AggregateCall call; |
| long cnt; |
| |
| CountAccumulator(AggregateCall call) { |
| this.call = call; |
| cnt = 0; |
| } |
| |
| @Override public void send(Row row) { |
| boolean notNull = true; |
| for (Integer i : call.getArgList()) { |
| if (row.getObject(i) == null) { |
| notNull = false; |
| break; |
| } |
| } |
| if (notNull) { |
| cnt++; |
| } |
| } |
| |
| @Override public Object end() { |
| return cnt; |
| } |
| } |
| |
| /** Accumulator for calls to the LITERAL_AGG function. */ |
| private static class LiteralAccumulator implements Accumulator { |
| private final @Nullable Object value; |
| |
| LiteralAccumulator(@Nullable Object value) { |
| this.value = value; |
| } |
| |
| @Override public void send(Row row) { |
| } |
| |
| @Override public @Nullable Object end() { |
| return value; |
| } |
| } |
| |
/**
 * Creates an {@link Accumulator}.
 *
 * <p>{@link Grouping#send} calls {@link #get()} once per aggregate call for
 * each new group key, so each group gets its own accumulator instances.
 */
private interface AccumulatorFactory extends Supplier<Accumulator> {
}
| |
| /** Accumulator powered by {@link Scalar} code fragments. */ |
| private static class ScalarAccumulatorDef implements AccumulatorFactory { |
| final Scalar initScalar; |
| final Scalar addScalar; |
| final Scalar endScalar; |
| final Context sendContext; |
| final Context endContext; |
| final int rowLength; |
| final int accumulatorLength; |
| |
| private ScalarAccumulatorDef(Scalar initScalar, Scalar addScalar, |
| Scalar endScalar, int rowLength, int accumulatorLength, |
| DataContext root) { |
| this.initScalar = initScalar; |
| this.addScalar = addScalar; |
| this.endScalar = endScalar; |
| this.accumulatorLength = accumulatorLength; |
| this.rowLength = rowLength; |
| this.sendContext = new Context(root); |
| this.sendContext.values = new Object[rowLength + accumulatorLength]; |
| this.endContext = new Context(root); |
| this.endContext.values = new Object[accumulatorLength]; |
| } |
| |
| @Override public Accumulator get() { |
| return new ScalarAccumulator(this, new Object[accumulatorLength]); |
| } |
| } |
| |
| /** Accumulator powered by {@link Scalar} code fragments. */ |
| private static class ScalarAccumulator implements Accumulator { |
| final ScalarAccumulatorDef def; |
| final Object[] values; |
| |
| private ScalarAccumulator(ScalarAccumulatorDef def, Object[] values) { |
| this.def = def; |
| this.values = values; |
| } |
| |
| @Override public void send(Row row) { |
| @Nullable Object[] sendValues = |
| requireNonNull(def.sendContext.values, "def.sendContext.values"); |
| System.arraycopy(row.getValues(), 0, sendValues, 0, |
| def.rowLength); |
| System.arraycopy(this.values, 0, sendValues, def.rowLength, |
| this.values.length); |
| def.addScalar.execute(def.sendContext, this.values); |
| } |
| |
| @Override public @Nullable Object end() { |
| Context endContext = requireNonNull(def.endContext, "def.endContext"); |
| @Nullable Object[] values = requireNonNull(endContext.values, "endContext.values"); |
| System.arraycopy(this.values, 0, values, 0, this.values.length); |
| return def.endScalar.execute(endContext); |
| } |
| } |
| |
| /** |
| * Internal class to track groupings. |
| */ |
| private class Grouping { |
| private final ImmutableBitSet grouping; |
| private final Map<Row, AccumulatorList> accumulators = new HashMap<>(); |
| |
| private Grouping(ImmutableBitSet grouping) { |
| this.grouping = grouping; |
| } |
| |
| public void send(Row row) { |
| // TODO: fix the size of this row. |
| RowBuilder builder = Row.newBuilder(grouping.cardinality()); |
| int j = 0; |
| for (Integer i : grouping) { |
| builder.set(j++, row.getObject(i)); |
| } |
| Row key = builder.build(); |
| |
| if (!accumulators.containsKey(key)) { |
| AccumulatorList list = new AccumulatorList(); |
| for (AccumulatorFactory factory : accumulatorFactories) { |
| list.add(factory.get()); |
| } |
| accumulators.put(key, list); |
| } |
| |
| accumulators.get(key).send(row); |
| } |
| |
| public void end(Sink sink) throws InterruptedException { |
| for (Map.Entry<Row, AccumulatorList> e : accumulators.entrySet()) { |
| final Row key = e.getKey(); |
| final AccumulatorList list = e.getValue(); |
| |
| RowBuilder rb = Row.newBuilder(outputRowLength); |
| int index = 0; |
| for (Integer groupPos : unionGroups) { |
| if (grouping.get(groupPos)) { |
| rb.set(index, key.getObject(index)); |
| } |
| // need to set false when not part of grouping set. |
| |
| index++; |
| } |
| |
| list.end(rb); |
| |
| sink.send(rb.build()); |
| } |
| } |
| } |
| |
| /** |
| * A list of accumulators used during grouping. |
| */ |
| private static class AccumulatorList extends ArrayList<Accumulator> { |
| public void send(Row row) { |
| for (Accumulator a : this) { |
| a.send(row); |
| } |
| } |
| |
| public void end(RowBuilder r) { |
| for (int accIndex = 0, rowIndex = r.size() - size(); |
| rowIndex < r.size(); rowIndex++, accIndex++) { |
| r.set(rowIndex, get(accIndex).end()); |
| } |
| } |
| } |
| |
/**
 * Defines function implementation for
 * things like {@code count()} and {@code sum()}.
 *
 * <p>An accumulator receives the rows of one group via {@link #send} and
 * reports the aggregate value via {@link #end}.
 */
private interface Accumulator {
  /** Offers one input row to this accumulator. */
  void send(Row row);

  /** Returns the aggregate value for the rows sent so far; may be null. */
  @Nullable Object end();
}
| |
| /** Implementation of {@code SUM} over INTEGER values as a user-defined |
| * aggregate. */ |
| public static class IntSum { |
| public IntSum() { |
| } |
| public int init() { |
| return 0; |
| } |
| public int add(int accumulator, int v) { |
| return accumulator + v; |
| } |
| public int merge(int accumulator0, int accumulator1) { |
| return accumulator0 + accumulator1; |
| } |
| public int result(int accumulator) { |
| return accumulator; |
| } |
| } |
| |
| /** Implementation of {@code SUM} over BIGINT values as a user-defined |
| * aggregate. */ |
| public static class LongSum { |
| public LongSum() { |
| } |
| public long init() { |
| return 0L; |
| } |
| public long add(long accumulator, long v) { |
| return accumulator + v; |
| } |
| public long merge(long accumulator0, long accumulator1) { |
| return accumulator0 + accumulator1; |
| } |
| public long result(long accumulator) { |
| return accumulator; |
| } |
| } |
| |
| /** Implementation of {@code SUM} over DOUBLE values as a user-defined |
| * aggregate. */ |
| public static class DoubleSum { |
| public DoubleSum() { |
| } |
| public double init() { |
| return 0D; |
| } |
| public double add(double accumulator, double v) { |
| return accumulator + v; |
| } |
| public double merge(double accumulator0, double accumulator1) { |
| return accumulator0 + accumulator1; |
| } |
| public double result(double accumulator) { |
| return accumulator; |
| } |
| } |
| |
| /** Implementation of {@code SUM} over BigDecimal values as a user-defined |
| * aggregate. */ |
| public static class BigDecimalSum { |
| public BigDecimalSum(){ |
| } |
| |
| public BigDecimal init() { |
| return new BigDecimal("0"); |
| } |
| |
| public BigDecimal add(BigDecimal accumulator, BigDecimal v) { |
| return accumulator.add(v); |
| } |
| |
| public BigDecimal merge(BigDecimal accumulator0, BigDecimal accumulator01) { |
| return add(accumulator0, accumulator01); |
| } |
| |
| public BigDecimal result(BigDecimal accumulator) { |
| return accumulator; |
| } |
| } |
| |
| /** Common implementation of comparison aggregate methods over numeric |
| * values as a user-defined aggregate. |
| * |
| * @param <T> The numeric type |
| */ |
| public static class NumericComparison<T> { |
| private final T initialValue; |
| private final BiFunction<T, T, T> comparisonFunction; |
| |
| public NumericComparison(T initialValue, BiFunction<T, T, T> comparisonFunction) { |
| this.initialValue = initialValue; |
| this.comparisonFunction = comparisonFunction; |
| } |
| |
| public T init() { |
| return this.initialValue; |
| } |
| |
| public T add(T accumulator, T value) { |
| return this.comparisonFunction.apply(accumulator, value); |
| } |
| |
| public T merge(T accumulator0, T accumulator1) { |
| return add(accumulator0, accumulator1); |
| } |
| |
| public T result(T accumulator) { |
| return accumulator; |
| } |
| } |
| |
| /** Implementation of {@code MIN} function to calculate the minimum of |
| * {@code integer} values as a user-defined aggregate. |
| */ |
| public static class MinInt extends NumericComparison<Integer> { |
| public MinInt() { |
| super(Integer.MAX_VALUE, Math::min); |
| } |
| } |
| |
| /** Implementation of {@code MIN} function to calculate the minimum of |
| * {@code long} values as a user-defined aggregate. |
| */ |
| public static class MinLong extends NumericComparison<Long> { |
| public MinLong() { |
| super(Long.MAX_VALUE, Math::min); |
| } |
| } |
| |
| /** Implementation of {@code MIN} function to calculate the minimum of |
| * {@code float} values as a user-defined aggregate. |
| */ |
| public static class MinFloat extends NumericComparison<Float> { |
| public MinFloat() { |
| super(Float.MAX_VALUE, Math::min); |
| } |
| } |
| |
| /** Implementation of {@code MIN} function to calculate the minimum of |
| * {@code double} and {@code real} values as a user-defined aggregate. |
| */ |
| public static class MinDouble extends NumericComparison<Double> { |
| public MinDouble() { |
| super(Double.MAX_VALUE, Math::min); |
| } |
| } |
| |
| /** Implementation of {@code MIN} function to calculate the minimum of |
| * {@code BigDecimal} values as a user-defined aggregate. |
| */ |
| public static class MinBigDecimal extends NumericComparison<BigDecimal> { |
| public MinBigDecimal() { |
| super(new BigDecimal(Double.MAX_VALUE), MinBigDecimal::min); |
| } |
| |
| public static BigDecimal min(BigDecimal a, BigDecimal b) { |
| return a.min(b); |
| } |
| } |
| |
| /** Implementation of {@code MIN} function to calculate the minimum of |
| * {@code boolean} values as a user-defined aggregate. |
| */ |
| public static class MinBoolean { |
| public MinBoolean() { } |
| |
| public Boolean init() { |
| return Boolean.TRUE; |
| } |
| |
| public Boolean add(Boolean accumulator, Boolean value) { |
| return accumulator && value; |
| } |
| |
| public Boolean merge(Boolean accumulator0, Boolean accumulator1) { |
| return add(accumulator0, accumulator1); |
| } |
| |
| public Boolean result(Boolean accumulator) { |
| return accumulator; |
| } |
| } |
| |
| /** Implementation of {@code MAX} function to calculate the maximum of |
| * {@code integer} values as a user-defined aggregate. |
| */ |
| public static class MaxInt extends NumericComparison<Integer> { |
| public MaxInt() { |
| super(Integer.MIN_VALUE, Math::max); |
| } |
| } |
| |
| /** Implementation of {@code MAX} function to calculate the maximum of |
| * {@code long} values as a user-defined aggregate. |
| */ |
| public static class MaxLong extends NumericComparison<Long> { |
| public MaxLong() { |
| super(Long.MIN_VALUE, Math::max); |
| } |
| } |
| |
| /** Implementation of {@code MAX} function to calculate the maximum of |
| * {@code float} values as a user-defined aggregate. |
| */ |
| public static class MaxFloat extends NumericComparison<Float> { |
| public MaxFloat() { |
| super(Float.MIN_VALUE, Math::max); |
| } |
| } |
| |
| /** Implementation of {@code MAX} function to calculate the maximum of |
| * {@code double} and {@code real} values as a user-defined aggregate. |
| */ |
| public static class MaxDouble extends NumericComparison<Double> { |
| public MaxDouble() { |
| super(Double.MIN_VALUE, Math::max); |
| } |
| } |
| |
| /** Implementation of {@code MAX} function to calculate the maximum of |
| * {@code BigDecimal} values as a user-defined aggregate. |
| */ |
| public static class MaxBigDecimal extends NumericComparison<BigDecimal> { |
| public MaxBigDecimal() { |
| super(new BigDecimal(Double.MIN_VALUE), MaxBigDecimal::max); |
| } |
| |
| public static BigDecimal max(BigDecimal a, BigDecimal b) { |
| return a.max(b); |
| } |
| } |
| |
| /** Implementation of {@code MAX} function to calculate the maximum of |
| * {@code boolean} values as a user-defined aggregate. |
| */ |
| public static class MaxBoolean { |
| public MaxBoolean() { } |
| |
| public Boolean init() { |
| return Boolean.FALSE; |
| } |
| |
| public Boolean add(Boolean accumulator, Boolean value) { |
| return accumulator || value; |
| } |
| |
| public Boolean merge(Boolean accumulator0, Boolean accumulator1) { |
| return add(accumulator0, accumulator1); |
| } |
| |
| public Boolean result(Boolean accumulator) { |
| return accumulator; |
| } |
| } |
| |
| /** Accumulator factory based on a user-defined aggregate function. */ |
| private static class UdaAccumulatorFactory implements AccumulatorFactory { |
| final AggregateFunctionImpl aggFunction; |
| final int argOrdinal; |
| public final @Nullable Object instance; |
| public final boolean nullIfEmpty; |
| |
| UdaAccumulatorFactory(AggregateFunctionImpl aggFunction, |
| AggregateCall call, boolean nullIfEmpty, DataContext dataContext) { |
| this.aggFunction = aggFunction; |
| if (call.getArgList().size() != 1) { |
| throw new UnsupportedOperationException("in current implementation, " |
| + "aggregate must have precisely one argument"); |
| } |
| argOrdinal = call.getArgList().get(0); |
| instance = createInstance(aggFunction, dataContext); |
| this.nullIfEmpty = nullIfEmpty; |
| } |
| |
| static @Nullable Object createInstance(AggregateFunctionImpl aggFunction, |
| DataContext dataContext) { |
| if (aggFunction.isStatic) { |
| return null; |
| } |
| try { |
| final Constructor<?> constructor = |
| aggFunction.declaringClass.getConstructor(); |
| return constructor.newInstance(); |
| } catch (InstantiationException | IllegalAccessException |
| | NoSuchMethodException | InvocationTargetException e) { |
| // ignore, and try next constructor |
| } |
| try { |
| final Constructor<?> constructor = |
| aggFunction.declaringClass.getConstructor(FunctionContext.class); |
| final Object[] args = new Object[aggFunction.getParameters().size()]; |
| final FunctionContext functionContext = |
| FunctionContexts.of(dataContext, args); |
| return constructor.newInstance(functionContext); |
| } catch (InstantiationException | IllegalAccessException |
| | NoSuchMethodException | InvocationTargetException e) { |
| throw Util.toUnchecked(e); |
| } |
| } |
| |
| @Override public Accumulator get() { |
| return new UdaAccumulator(this); |
| } |
| } |
| |
| /** Accumulator based upon a user-defined aggregate. */ |
| private static class UdaAccumulator implements Accumulator { |
| private final UdaAccumulatorFactory factory; |
| private @Nullable Object value; |
| private boolean empty; |
| |
| UdaAccumulator(UdaAccumulatorFactory factory) { |
| this.factory = factory; |
| try { |
| this.value = factory.aggFunction.initMethod.invoke(factory.instance); |
| } catch (IllegalAccessException | InvocationTargetException e) { |
| throw new RuntimeException(e); |
| } |
| this.empty = true; |
| } |
| |
| @Override public void send(Row row) { |
| final @Nullable Object[] args = {value, row.getValues()[factory.argOrdinal]}; |
| for (int i = 1; i < args.length; i++) { |
| if (args[i] == null) { |
| return; // one of the arguments is null; don't add to the total |
| } |
| } |
| try { |
| value = factory.aggFunction.addMethod.invoke(factory.instance, args); |
| } catch (IllegalAccessException | InvocationTargetException e) { |
| throw new RuntimeException(e); |
| } |
| empty = false; |
| } |
| |
| @Override public @Nullable Object end() { |
| if (factory.nullIfEmpty && empty) { |
| return null; |
| } |
| final @Nullable Object[] args = {value}; |
| try { |
| AggregateFunctionImpl aggFunction = |
| requireNonNull(factory.aggFunction, "factory.aggFunction"); |
| return requireNonNull(aggFunction.resultMethod, "aggFunction.resultMethod") |
| .invoke(factory.instance, args); |
| } catch (IllegalAccessException | InvocationTargetException e) { |
| throw new RuntimeException(e); |
| } |
| } |
| } |
| |
| /** Accumulator that applies a filter to another accumulator. |
| * The filter is a BOOLEAN field in the input row. */ |
| private static class FilterAccumulator implements Accumulator { |
| private final Accumulator accumulator; |
| private final int filterArg; |
| |
| FilterAccumulator(Accumulator accumulator, int filterArg) { |
| this.accumulator = accumulator; |
| this.filterArg = filterArg; |
| } |
| |
| @Override public void send(Row row) { |
| if (row.getValues()[filterArg] == Boolean.TRUE) { |
| accumulator.send(row); |
| } |
| } |
| |
| @Override public @Nullable Object end() { |
| return accumulator.end(); |
| } |
| } |
| } |