blob: 3fc299bd5a336422325b50dcc08589be891e0b9c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.extensions.sql.impl.transform;
import java.io.Serializable;
import java.math.BigDecimal;
import java.math.MathContext;
import java.math.RoundingMode;
import java.util.Map;
import java.util.function.Function;
import org.apache.beam.sdk.coders.BigDecimalCoder;
import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.NullableCoder;
import org.apache.beam.sdk.extensions.sql.impl.transform.agg.CountIf;
import org.apache.beam.sdk.extensions.sql.impl.transform.agg.CovarianceFn;
import org.apache.beam.sdk.extensions.sql.impl.transform.agg.VarianceFn;
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.schemas.Schema.TypeName;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Max;
import org.apache.beam.sdk.transforms.Min;
import org.apache.beam.sdk.transforms.Sample;
import org.apache.beam.sdk.transforms.Sum;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Predicates;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
import org.checkerframework.checker.nullness.qual.Nullable;
/** Built-in aggregations functions for COUNT/MAX/MIN/SUM/AVG/VAR_POP/VAR_SAMP. */
@SuppressWarnings({
"rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
"nullness" // TODO(https://github.com/apache/beam/issues/20497)
})
public class BeamBuiltinAggregations {
public static final Map<String, Function<Schema.FieldType, CombineFn<?, ?, ?>>>
BUILTIN_AGGREGATOR_FACTORIES =
ImmutableMap.<String, Function<Schema.FieldType, CombineFn<?, ?, ?>>>builder()
.put("ANY_VALUE", typeName -> Sample.anyValueCombineFn())
// Drop null elements for these aggregations.
.put("COUNT", typeName -> new DropNullFnWithDefault(Count.combineFn()))
.put("MAX", typeName -> new DropNullFn(BeamBuiltinAggregations.createMax(typeName)))
.put("MIN", typeName -> new DropNullFn(BeamBuiltinAggregations.createMin(typeName)))
.put("SUM", typeName -> new DropNullFn(BeamBuiltinAggregations.createSum(typeName)))
.put(
"$SUM0",
typeName ->
new DropNullFnWithDefault(BeamBuiltinAggregations.createSum0(typeName)))
.put("AVG", typeName -> new DropNullFn(BeamBuiltinAggregations.createAvg(typeName)))
.put(
"BIT_OR",
typeName -> new DropNullFn(BeamBuiltinAggregations.createBitOr(typeName)))
.put(
"BIT_XOR",
typeName -> new DropNullFn(BeamBuiltinAggregations.createBitXOr(typeName)))
// JIRA link:https://issues.apache.org/jira/browse/BEAM-10379
.put(
"BIT_AND",
typeName -> new DropNullFn(BeamBuiltinAggregations.createBitAnd(typeName)))
.put("VAR_POP", t -> VarianceFn.newPopulation(t.getTypeName()))
.put("VAR_SAMP", t -> VarianceFn.newSample(t.getTypeName()))
.put("COVAR_POP", t -> CovarianceFn.newPopulation(t.getTypeName()))
.put("COVAR_SAMP", t -> CovarianceFn.newSample(t.getTypeName()))
.put("COUNTIF", typeName -> CountIf.combineFn())
.build();
private static MathContext mc = new MathContext(10, RoundingMode.HALF_UP);
public static CombineFn<?, ?, ?> create(String functionName, Schema.FieldType fieldType) {
Function<Schema.FieldType, CombineFn<?, ?, ?>> aggregatorFactory =
BUILTIN_AGGREGATOR_FACTORIES.get(functionName);
if (aggregatorFactory != null) {
return aggregatorFactory.apply(fieldType);
}
throw new UnsupportedOperationException(
String.format("Aggregator [%s] is not supported", functionName));
}
/** {@link CombineFn} for MAX based on {@link Max} and {@link Combine.BinaryCombineFn}. */
static CombineFn createMax(FieldType fieldType) {
if (CalciteUtils.isDateTimeType(fieldType)) {
return new CustMax<>();
}
switch (fieldType.getTypeName()) {
case BOOLEAN:
case INT16:
case BYTE:
case FLOAT:
case DATETIME:
case DECIMAL:
case STRING:
return new CustMax<>();
case INT32:
return Max.ofIntegers();
case INT64:
return Max.ofLongs();
case DOUBLE:
return Max.ofDoubles();
default:
throw new UnsupportedOperationException(
String.format("[%s] is not supported in MAX", fieldType));
}
}
/** {@link CombineFn} for MIN based on {@link Min} and {@link Combine.BinaryCombineFn}. */
static CombineFn createMin(Schema.FieldType fieldType) {
if (CalciteUtils.isDateTimeType(fieldType)) {
return new CustMin();
}
switch (fieldType.getTypeName()) {
case BOOLEAN:
case BYTE:
case INT16:
case FLOAT:
case DATETIME:
case DECIMAL:
case STRING:
return new CustMin();
case INT32:
return Min.ofIntegers();
case INT64:
return Min.ofLongs();
case DOUBLE:
return Min.ofDoubles();
default:
throw new UnsupportedOperationException(
String.format("[%s] is not supported in MIN", fieldType));
}
}
/** {@link CombineFn} for Sum based on {@link Sum} and {@link Combine.BinaryCombineFn}. */
static CombineFn createSum(Schema.FieldType fieldType) {
switch (fieldType.getTypeName()) {
case INT32:
return Sum.ofIntegers();
case INT16:
return new ShortSum();
case BYTE:
return new ByteSum();
case INT64:
return new LongSum();
case FLOAT:
return new FloatSum();
case DOUBLE:
return Sum.ofDoubles();
case DECIMAL:
return new BigDecimalSum();
default:
throw new UnsupportedOperationException(
String.format("[%s] is not supported in SUM", fieldType));
}
}
/**
* {@link CombineFn} for Sum0 where sum of null returns 0 based on {@link Sum} and {@link
* Combine.BinaryCombineFn}.
*/
static CombineFn createSum0(Schema.FieldType fieldType) {
switch (fieldType.getTypeName()) {
case INT32:
return new IntegerSum0();
case INT16:
return new ShortSum0();
case BYTE:
return new ByteSum0();
case INT64:
return new LongSum0();
case FLOAT:
return new FloatSum0();
case DOUBLE:
return new DoubleSum0();
case DECIMAL:
return new BigDecimalSum0();
default:
throw new UnsupportedOperationException(
String.format("[%s] is not supported in SUM0", fieldType));
}
}
/** {@link CombineFn} for AVG. */
static CombineFn createAvg(Schema.FieldType fieldType) {
switch (fieldType.getTypeName()) {
case INT32:
return new IntegerAvg();
case INT16:
return new ShortAvg();
case BYTE:
return new ByteAvg();
case INT64:
return new LongAvg();
case FLOAT:
return new FloatAvg();
case DOUBLE:
return new DoubleAvg();
case DECIMAL:
return new BigDecimalAvg();
default:
throw new UnsupportedOperationException(
String.format("[%s] is not supported in AVG", fieldType));
}
}
static CombineFn createBitOr(Schema.FieldType fieldType) {
if (fieldType.getTypeName() == TypeName.INT64) {
return new BitOr();
}
throw new UnsupportedOperationException(
String.format("[%s] is not supported in BIT_OR", fieldType));
}
static CombineFn createBitAnd(Schema.FieldType fieldType) {
if (fieldType.getTypeName() == TypeName.INT64) {
return new BitAnd();
}
throw new UnsupportedOperationException(
String.format("[%s] is not supported in BIT_AND", fieldType));
}
public static CombineFn createBitXOr(Schema.FieldType fieldType) {
if (fieldType.getTypeName() == TypeName.INT64) {
return new BitXOr();
}
throw new UnsupportedOperationException(
String.format("[%s] is not supported in BIT_XOR", fieldType));
}
static class CustMax<T extends Comparable<T>> extends Combine.BinaryCombineFn<T> {
@Override
public T apply(T left, T right) {
return (right == null || right.compareTo(left) < 0) ? left : right;
}
}
static class CustMin<T extends Comparable<T>> extends Combine.BinaryCombineFn<T> {
@Override
public T apply(T left, T right) {
return (left == null || left.compareTo(right) < 0) ? left : right;
}
}
static class IntegerSum extends Combine.BinaryCombineFn<Integer> {
@Override
public Integer apply(Integer left, Integer right) {
return (int) (left + right);
}
}
static class ShortSum extends Combine.BinaryCombineFn<Short> {
@Override
public Short apply(Short left, Short right) {
return (short) (left + right);
}
}
static class ByteSum extends Combine.BinaryCombineFn<Byte> {
@Override
public Byte apply(Byte left, Byte right) {
return (byte) (left + right);
}
}
static class FloatSum extends Combine.BinaryCombineFn<Float> {
@Override
public Float apply(Float left, Float right) {
return left + right;
}
}
static class DoubleSum extends Combine.BinaryCombineFn<Double> {
@Override
public Double apply(Double left, Double right) {
return (double) left + right;
}
}
static class LongSum extends Combine.BinaryCombineFn<Long> {
@Override
public Long apply(Long left, Long right) {
return Math.addExact(left, right);
}
}
static class BigDecimalSum extends Combine.BinaryCombineFn<BigDecimal> {
@Override
public BigDecimal apply(BigDecimal left, BigDecimal right) {
return left.add(right);
}
}
static class IntegerSum0 extends IntegerSum {
@Override
public @Nullable Integer identity() {
return 0;
}
}
static class ShortSum0 extends ShortSum {
@Override
public @Nullable Short identity() {
return 0;
}
}
static class ByteSum0 extends ByteSum {
@Override
public @Nullable Byte identity() {
return 0;
}
}
static class FloatSum0 extends FloatSum {
@Override
public @Nullable Float identity() {
return 0F;
}
}
static class DoubleSum0 extends DoubleSum {
@Override
public @Nullable Double identity() {
return 0D;
}
}
static class LongSum0 extends LongSum {
@Override
public @Nullable Long identity() {
return 0L;
}
}
static class BigDecimalSum0 extends BigDecimalSum {
@Override
public @Nullable BigDecimal identity() {
return BigDecimal.ZERO;
}
}
private static class DropNullFn<InputT, AccumT, OutputT>
extends CombineFn<InputT, AccumT, OutputT> {
protected final CombineFn<InputT, AccumT, OutputT> combineFn;
DropNullFn(CombineFn<InputT, AccumT, OutputT> combineFn) {
this.combineFn = combineFn;
}
@Override
public AccumT createAccumulator() {
return null;
}
@Override
public AccumT addInput(AccumT accumulator, InputT input) {
if (input == null) {
return accumulator;
}
if (accumulator == null) {
accumulator = combineFn.createAccumulator();
}
return combineFn.addInput(accumulator, input);
}
@Override
public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
// filter out nulls
accumulators = Iterables.filter(accumulators, Predicates.notNull());
// handle only nulls
if (!accumulators.iterator().hasNext()) {
return null;
}
return combineFn.mergeAccumulators(accumulators);
}
@Override
public OutputT extractOutput(AccumT accumulator) {
if (accumulator == null) {
return null;
}
return combineFn.extractOutput(accumulator);
}
@Override
public Coder<AccumT> getAccumulatorCoder(CoderRegistry registry, Coder<InputT> inputCoder)
throws CannotProvideCoderException {
Coder<AccumT> coder = combineFn.getAccumulatorCoder(registry, inputCoder);
if (coder instanceof NullableCoder) {
return coder;
}
return NullableCoder.of(coder);
}
}
private static class DropNullFnWithDefault<InputT, AccumT, OutputT>
extends DropNullFn<InputT, AccumT, OutputT> {
DropNullFnWithDefault(CombineFn<InputT, AccumT, OutputT> combineFn) {
super(combineFn);
}
@Override
public AccumT createAccumulator() {
return combineFn.createAccumulator();
}
}
/** {@link CombineFn} for <em>AVG</em> on {@link Number} types. */
abstract static class Avg<T extends Number> extends CombineFn<T, KV<Integer, BigDecimal>, T> {
@Override
public KV<Integer, BigDecimal> createAccumulator() {
return KV.of(0, BigDecimal.ZERO);
}
@Override
public KV<Integer, BigDecimal> addInput(KV<Integer, BigDecimal> accumulator, T input) {
return KV.of(accumulator.getKey() + 1, accumulator.getValue().add(toBigDecimal(input)));
}
@Override
public KV<Integer, BigDecimal> mergeAccumulators(
Iterable<KV<Integer, BigDecimal>> accumulators) {
int size = 0;
BigDecimal acc = BigDecimal.ZERO;
for (KV<Integer, BigDecimal> ele : accumulators) {
size += ele.getKey();
acc = acc.add(ele.getValue());
}
return KV.of(size, acc);
}
@Override
public Coder<KV<Integer, BigDecimal>> getAccumulatorCoder(
CoderRegistry registry, Coder<T> inputCoder) {
return KvCoder.of(BigEndianIntegerCoder.of(), BigDecimalCoder.of());
}
protected BigDecimal prepareOutput(KV<Integer, BigDecimal> accumulator) {
return accumulator.getValue().divide(new BigDecimal(accumulator.getKey()), mc);
}
@Override
public abstract T extractOutput(KV<Integer, BigDecimal> accumulator);
public abstract BigDecimal toBigDecimal(T record);
}
static class IntegerAvg extends Avg<Integer> {
@Override
public @Nullable Integer extractOutput(KV<Integer, BigDecimal> accumulator) {
return accumulator.getKey() == 0 ? null : prepareOutput(accumulator).intValue();
}
@Override
public BigDecimal toBigDecimal(Integer record) {
return new BigDecimal(record);
}
}
static class LongAvg extends Avg<Long> {
@Override
public @Nullable Long extractOutput(KV<Integer, BigDecimal> accumulator) {
return accumulator.getKey() == 0 ? null : prepareOutput(accumulator).longValue();
}
@Override
public BigDecimal toBigDecimal(Long record) {
return new BigDecimal(record);
}
}
static class ShortAvg extends Avg<Short> {
@Override
public @Nullable Short extractOutput(KV<Integer, BigDecimal> accumulator) {
return accumulator.getKey() == 0 ? null : prepareOutput(accumulator).shortValue();
}
@Override
public BigDecimal toBigDecimal(Short record) {
return new BigDecimal(record);
}
}
static class ByteAvg extends Avg<Byte> {
@Override
public @Nullable Byte extractOutput(KV<Integer, BigDecimal> accumulator) {
return accumulator.getKey() == 0 ? null : prepareOutput(accumulator).byteValue();
}
@Override
public BigDecimal toBigDecimal(Byte record) {
return new BigDecimal(record);
}
}
static class FloatAvg extends Avg<Float> {
@Override
public @Nullable Float extractOutput(KV<Integer, BigDecimal> accumulator) {
return accumulator.getKey() == 0 ? null : prepareOutput(accumulator).floatValue();
}
@Override
public BigDecimal toBigDecimal(Float record) {
return new BigDecimal(record);
}
}
static class DoubleAvg extends Avg<Double> {
@Override
public @Nullable Double extractOutput(KV<Integer, BigDecimal> accumulator) {
return accumulator.getKey() == 0 ? null : prepareOutput(accumulator).doubleValue();
}
@Override
public BigDecimal toBigDecimal(Double record) {
return new BigDecimal(record);
}
}
static class BigDecimalAvg extends Avg<BigDecimal> {
@Override
public @Nullable BigDecimal extractOutput(KV<Integer, BigDecimal> accumulator) {
return accumulator.getKey() == 0 ? null : prepareOutput(accumulator);
}
@Override
public BigDecimal toBigDecimal(BigDecimal record) {
return record;
}
}
static class BitOr<T extends Number> extends CombineFn<T, BitOr.Accum, Long> {
static class Accum implements Serializable {
/** True if no inputs have been seen yet. */
boolean isEmpty = true;
/** The bitwise-or of the inputs seen so far. */
long bitOr = 0L;
}
@Override
public BitOr.Accum createAccumulator() {
return new BitOr.Accum();
}
@Override
public BitOr.Accum addInput(BitOr.Accum accum, T input) {
accum.isEmpty = false;
accum.bitOr |= input.longValue();
return accum;
}
@Override
public BitOr.Accum mergeAccumulators(Iterable<BitOr.Accum> accums) {
BitOr.Accum merged = createAccumulator();
for (BitOr.Accum accum : accums) {
if (accum.isEmpty) {
continue;
}
merged.isEmpty = false;
merged.bitOr |= accum.bitOr;
}
return merged;
}
@Override
public Long extractOutput(BitOr.Accum accum) {
if (accum.isEmpty) {
return null;
}
return accum.bitOr;
}
}
/**
* Bitwise AND function implementation.
*
* <p>Note: null values are ignored when mixed with non-null values.
* (https://issues.apache.org/jira/browse/BEAM-10379)
*/
static class BitAnd<T extends Number> extends CombineFn<T, BitAnd.Accum, Long> {
static class Accum implements Serializable {
/** True if no inputs have been seen yet. */
boolean isEmpty = true;
/** The bitwise-and of the inputs seen so far. */
long bitAnd = -1L;
}
@Override
public BitAnd.Accum createAccumulator() {
return new BitAnd.Accum();
}
@Override
public BitAnd.Accum addInput(BitAnd.Accum accum, T input) {
accum.isEmpty = false;
accum.bitAnd &= input.longValue();
return accum;
}
@Override
public BitAnd.Accum mergeAccumulators(Iterable<BitAnd.Accum> accums) {
BitAnd.Accum merged = createAccumulator();
for (BitAnd.Accum accum : accums) {
if (accum.isEmpty) {
continue;
}
merged.isEmpty = false;
merged.bitAnd &= accum.bitAnd;
}
return merged;
}
@Override
public Long extractOutput(BitAnd.Accum accum) {
if (accum.isEmpty) {
return null;
}
return accum.bitAnd;
}
}
public static class BitXOr<T extends Number> extends CombineFn<T, BitXOr.Accum, Long> {
static class Accum implements Serializable {
/** True if no inputs have been seen yet. */
boolean isEmpty = true;
/** The bitwise-and of the inputs seen so far. */
long bitXOr = 0L;
}
@Override
public BitXOr.Accum createAccumulator() {
return new BitXOr.Accum();
}
@Override
public BitXOr.Accum addInput(Accum mutableAccumulator, T input) {
if (input != null) {
mutableAccumulator.isEmpty = false;
mutableAccumulator.bitXOr ^= input.longValue();
}
return mutableAccumulator;
}
@Override
public BitXOr.Accum mergeAccumulators(Iterable<BitXOr.Accum> accumulators) {
BitXOr.Accum merged = createAccumulator();
for (BitXOr.Accum accum : accumulators) {
if (accum.isEmpty) {
continue;
}
merged.isEmpty = false;
merged.bitXOr ^= accum.bitXOr;
}
return merged;
}
@Override
public Long extractOutput(BitXOr.Accum accumulator) {
if (accumulator.isEmpty) {
return null;
}
return accumulator.bitXOr;
}
}
}