blob: 779924527de720b52c7601f7c762ec35367d6247 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.extensions.sql.impl.transform;
import java.math.BigDecimal;
import java.math.MathContext;
import java.math.RoundingMode;
import java.util.Map;
import java.util.function.Function;
import javax.annotation.Nullable;
import org.apache.beam.sdk.coders.BigDecimalCoder;
import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.extensions.sql.impl.transform.agg.CovarianceFn;
import org.apache.beam.sdk.extensions.sql.impl.transform.agg.VarianceFn;
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Max;
import org.apache.beam.sdk.transforms.Min;
import org.apache.beam.sdk.transforms.Sum;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableMap;
/** Built-in aggregations functions for COUNT/MAX/MIN/SUM/AVG/VAR_POP/VAR_SAMP. */
public class BeamBuiltinAggregations {
public static final Map<String, Function<Schema.FieldType, CombineFn<?, ?, ?>>>
BUILTIN_AGGREGATOR_FACTORIES =
ImmutableMap.<String, Function<Schema.FieldType, CombineFn<?, ?, ?>>>builder()
.put("COUNT", typeName -> Count.combineFn())
.put("MAX", BeamBuiltinAggregations::createMax)
.put("MIN", BeamBuiltinAggregations::createMin)
.put("SUM", BeamBuiltinAggregations::createSum)
.put("$SUM0", BeamBuiltinAggregations::createSum)
.put("AVG", BeamBuiltinAggregations::createAvg)
.put("VAR_POP", t -> VarianceFn.newPopulation(t.getTypeName()))
.put("VAR_SAMP", t -> VarianceFn.newSample(t.getTypeName()))
.put("COVAR_POP", t -> CovarianceFn.newPopulation(t.getTypeName()))
.put("COVAR_SAMP", t -> CovarianceFn.newSample(t.getTypeName()))
.build();
private static MathContext mc = new MathContext(10, RoundingMode.HALF_UP);
public static CombineFn<?, ?, ?> create(String functionName, Schema.FieldType fieldType) {
Function<Schema.FieldType, CombineFn<?, ?, ?>> aggregatorFactory =
BUILTIN_AGGREGATOR_FACTORIES.get(functionName);
if (aggregatorFactory != null) {
return aggregatorFactory.apply(fieldType);
}
throw new UnsupportedOperationException(
String.format("Aggregator [%s] is not supported", functionName));
}
/** {@link CombineFn} for MAX based on {@link Max} and {@link Combine.BinaryCombineFn}. */
static CombineFn createMax(FieldType fieldType) {
if (CalciteUtils.isDateTimeType(fieldType)) {
return new CustMax<>();
}
switch (fieldType.getTypeName()) {
case BOOLEAN:
case INT16:
case BYTE:
case FLOAT:
case DATETIME:
case DECIMAL:
case STRING:
return new CustMax<>();
case INT32:
return Max.ofIntegers();
case INT64:
return Max.ofLongs();
case DOUBLE:
return Max.ofDoubles();
default:
throw new UnsupportedOperationException(
String.format("[%s] is not support in MAX", fieldType));
}
}
/** {@link CombineFn} for MIN based on {@link Min} and {@link Combine.BinaryCombineFn}. */
static CombineFn createMin(Schema.FieldType fieldType) {
if (CalciteUtils.isDateTimeType(fieldType)) {
return new CustMin();
}
switch (fieldType.getTypeName()) {
case BOOLEAN:
case BYTE:
case INT16:
case FLOAT:
case DATETIME:
case DECIMAL:
case STRING:
return new CustMin();
case INT32:
return Min.ofIntegers();
case INT64:
return Min.ofLongs();
case DOUBLE:
return Min.ofDoubles();
default:
throw new UnsupportedOperationException(
String.format("[%s] is not support in MIN", fieldType));
}
}
/** {@link CombineFn} for Sum based on {@link Sum} and {@link Combine.BinaryCombineFn}. */
static CombineFn createSum(Schema.FieldType fieldType) {
switch (fieldType.getTypeName()) {
case INT32:
return Sum.ofIntegers();
case INT16:
return new ShortSum();
case BYTE:
return new ByteSum();
case INT64:
return Sum.ofLongs();
case FLOAT:
return new FloatSum();
case DOUBLE:
return Sum.ofDoubles();
case DECIMAL:
return new BigDecimalSum();
default:
throw new UnsupportedOperationException(
String.format("[%s] is not support in SUM", fieldType));
}
}
/** {@link CombineFn} for AVG. */
static CombineFn createAvg(Schema.FieldType fieldType) {
switch (fieldType.getTypeName()) {
case INT32:
return new IntegerAvg();
case INT16:
return new ShortAvg();
case BYTE:
return new ByteAvg();
case INT64:
return new LongAvg();
case FLOAT:
return new FloatAvg();
case DOUBLE:
return new DoubleAvg();
case DECIMAL:
return new BigDecimalAvg();
default:
throw new UnsupportedOperationException(
String.format("[%s] is not support in AVG", fieldType));
}
}
static class CustMax<T extends Comparable<T>> extends Combine.BinaryCombineFn<T> {
@Override
public T apply(T left, T right) {
return (right == null || right.compareTo(left) < 0) ? left : right;
}
}
static class CustMin<T extends Comparable<T>> extends Combine.BinaryCombineFn<T> {
@Override
public T apply(T left, T right) {
return (left == null || left.compareTo(right) < 0) ? left : right;
}
}
static class ShortSum extends Combine.BinaryCombineFn<Short> {
@Override
public Short apply(Short left, Short right) {
return (short) (left + right);
}
}
static class ByteSum extends Combine.BinaryCombineFn<Byte> {
@Override
public Byte apply(Byte left, Byte right) {
return (byte) (left + right);
}
}
static class FloatSum extends Combine.BinaryCombineFn<Float> {
@Override
public Float apply(Float left, Float right) {
return left + right;
}
}
static class BigDecimalSum extends Combine.BinaryCombineFn<BigDecimal> {
@Override
public BigDecimal apply(BigDecimal left, BigDecimal right) {
return left.add(right);
}
}
/** {@link CombineFn} for <em>AVG</em> on {@link Number} types. */
abstract static class Avg<T extends Number> extends CombineFn<T, KV<Integer, BigDecimal>, T> {
@Override
public KV<Integer, BigDecimal> createAccumulator() {
return KV.of(0, BigDecimal.ZERO);
}
@Override
public KV<Integer, BigDecimal> addInput(KV<Integer, BigDecimal> accumulator, T input) {
return KV.of(accumulator.getKey() + 1, accumulator.getValue().add(toBigDecimal(input)));
}
@Override
public KV<Integer, BigDecimal> mergeAccumulators(
Iterable<KV<Integer, BigDecimal>> accumulators) {
int size = 0;
BigDecimal acc = BigDecimal.ZERO;
for (KV<Integer, BigDecimal> ele : accumulators) {
size += ele.getKey();
acc = acc.add(ele.getValue());
}
return KV.of(size, acc);
}
@Override
public Coder<KV<Integer, BigDecimal>> getAccumulatorCoder(
CoderRegistry registry, Coder<T> inputCoder) {
return KvCoder.of(BigEndianIntegerCoder.of(), BigDecimalCoder.of());
}
protected BigDecimal prepareOutput(KV<Integer, BigDecimal> accumulator) {
return accumulator.getValue().divide(new BigDecimal(accumulator.getKey()), mc);
}
@Override
public abstract T extractOutput(KV<Integer, BigDecimal> accumulator);
public abstract BigDecimal toBigDecimal(T record);
}
static class IntegerAvg extends Avg<Integer> {
@Override
@Nullable
public Integer extractOutput(KV<Integer, BigDecimal> accumulator) {
return accumulator.getKey() == 0 ? null : prepareOutput(accumulator).intValue();
}
@Override
public BigDecimal toBigDecimal(Integer record) {
return new BigDecimal(record);
}
}
static class LongAvg extends Avg<Long> {
@Override
@Nullable
public Long extractOutput(KV<Integer, BigDecimal> accumulator) {
return accumulator.getKey() == 0 ? null : prepareOutput(accumulator).longValue();
}
@Override
public BigDecimal toBigDecimal(Long record) {
return new BigDecimal(record);
}
}
static class ShortAvg extends Avg<Short> {
@Override
@Nullable
public Short extractOutput(KV<Integer, BigDecimal> accumulator) {
return accumulator.getKey() == 0 ? null : prepareOutput(accumulator).shortValue();
}
@Override
public BigDecimal toBigDecimal(Short record) {
return new BigDecimal(record);
}
}
static class ByteAvg extends Avg<Byte> {
@Override
@Nullable
public Byte extractOutput(KV<Integer, BigDecimal> accumulator) {
return accumulator.getKey() == 0 ? null : prepareOutput(accumulator).byteValue();
}
@Override
public BigDecimal toBigDecimal(Byte record) {
return new BigDecimal(record);
}
}
static class FloatAvg extends Avg<Float> {
@Override
@Nullable
public Float extractOutput(KV<Integer, BigDecimal> accumulator) {
return accumulator.getKey() == 0 ? null : prepareOutput(accumulator).floatValue();
}
@Override
public BigDecimal toBigDecimal(Float record) {
return new BigDecimal(record);
}
}
static class DoubleAvg extends Avg<Double> {
@Override
@Nullable
public Double extractOutput(KV<Integer, BigDecimal> accumulator) {
return accumulator.getKey() == 0 ? null : prepareOutput(accumulator).doubleValue();
}
@Override
public BigDecimal toBigDecimal(Double record) {
return new BigDecimal(record);
}
}
static class BigDecimalAvg extends Avg<BigDecimal> {
@Override
@Nullable
public BigDecimal extractOutput(KV<Integer, BigDecimal> accumulator) {
return accumulator.getKey() == 0 ? null : prepareOutput(accumulator);
}
@Override
public BigDecimal toBigDecimal(BigDecimal record) {
return record;
}
}
}