blob: 413093d416d83f0b9042925746357c0102486bf6 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.extensions.sql.impl.transform;
import java.math.BigDecimal;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
/** Built-in Analytic Functions for the aggregation analytics functionality. */
@SuppressWarnings({
"rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
"nullness" // TODO(https://github.com/apache/beam/issues/20497)
})
public class BeamBuiltinAnalyticFunctions {
public static final Map<String, Function<Schema.FieldType, Combine.CombineFn<?, ?, ?>>>
BUILTIN_ANALYTIC_FACTORIES =
ImmutableMap.<String, Function<Schema.FieldType, Combine.CombineFn<?, ?, ?>>>builder()
// Aggregate Analytic Functions
.putAll(BeamBuiltinAggregations.BUILTIN_AGGREGATOR_FACTORIES)
// Navigation Functions
.put("FIRST_VALUE", typeName -> navigationFirstValue())
.put("LAST_VALUE", typeName -> navigationLastValue())
// Numbering Functions
.put("ROW_NUMBER", typeName -> numberingRowNumber())
.put("DENSE_RANK", typeName -> numberingDenseRank())
.put("RANK", typeName -> numberingRank())
.put("PERCENT_RANK", typeName -> numberingPercentRank())
.build();
public static Combine.CombineFn<?, ?, ?> create(String functionName, Schema.FieldType fieldType) {
Function<Schema.FieldType, Combine.CombineFn<?, ?, ?>> aggregatorFactory =
BUILTIN_ANALYTIC_FACTORIES.get(functionName);
if (aggregatorFactory != null) {
return aggregatorFactory.apply(fieldType);
}
throw new UnsupportedOperationException(
String.format("Analytics Function [%s] is not supported", functionName));
}
// Navigation functions
public static <T> Combine.CombineFn<T, ?, T> navigationFirstValue() {
return new FirstValueCombineFn();
}
public static <T> Combine.CombineFn<T, ?, T> navigationLastValue() {
return new LastValueCombineFn();
}
private static class FirstValueCombineFn<T> extends Combine.CombineFn<T, Optional<T>, T> {
private FirstValueCombineFn() {}
@Override
public Optional<T> createAccumulator() {
return Optional.empty();
}
@Override
public Optional<T> addInput(Optional<T> accumulator, T input) {
Optional<T> r = accumulator;
if (!accumulator.isPresent()) {
r = Optional.of(input);
}
return r;
}
@Override
public Optional<T> mergeAccumulators(Iterable<Optional<T>> accumulators) {
throw new UnsupportedOperationException();
}
@Override
public T extractOutput(Optional<T> accumulator) {
return accumulator.isPresent() ? accumulator.get() : null;
}
}
private static class LastValueCombineFn<T> extends Combine.CombineFn<T, Optional<T>, T> {
private LastValueCombineFn() {}
@Override
public Optional<T> createAccumulator() {
return Optional.empty();
}
@Override
public Optional<T> addInput(Optional<T> accumulator, T input) {
Optional<T> r = Optional.of(input);
return r;
}
@Override
public Optional<T> mergeAccumulators(Iterable<Optional<T>> accumulators) {
throw new UnsupportedOperationException();
}
@Override
public T extractOutput(Optional<T> accumulator) {
return accumulator.isPresent() ? accumulator.get() : null;
}
}
// Numbering functions
public static <T> Combine.CombineFn<T, ?, T> numberingRowNumber() {
return new RowNumberCombineFn();
}
public static <T> Combine.CombineFn<T, ?, T> numberingDenseRank() {
return new DenseRankCombineFn();
}
public static <T> Combine.CombineFn<T, ?, T> numberingRank() {
return new RankCombineFn();
}
public static <T> Combine.CombineFn<T, ?, T> numberingPercentRank() {
return new PercentRankCombineFn();
}
public abstract static class PositionAwareCombineFn<InputT, AccumT, OutputT>
extends Combine.CombineFn<InputT, AccumT, OutputT> {
public abstract AccumT addInput(
AccumT accumulator,
InputT input,
Long cursorWindow,
Long cursorPartition,
Long countPartition);
@Override
public AccumT addInput(AccumT mutableAccumulator, InputT input) {
throw new UnsupportedOperationException();
}
@Override
public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
throw new UnsupportedOperationException();
}
}
private static class RowNumberCombineFn<T>
extends PositionAwareCombineFn<BigDecimal, Optional<Long>, Long> {
@Override
public Optional<Long> addInput(
Optional<Long> accumulator,
BigDecimal input,
Long cursorPosition,
Long cursorPartition,
Long countPartition) {
return Optional.of(cursorPartition);
}
@Override
public Optional<Long> createAccumulator() {
return Optional.empty();
}
@Override
public Long extractOutput(Optional<Long> accumulator) {
// 1-based result
return accumulator.isPresent() ? accumulator.get() + 1 : null;
}
}
private static class DenseRankCombineFn<T>
extends PositionAwareCombineFn<BigDecimal, KV<BigDecimal, Long>, Long> {
@Override
public KV<BigDecimal, Long> addInput(
KV<BigDecimal, Long> accumulator,
BigDecimal input,
Long cursorPosition,
Long cursorPartition,
Long countPartition) {
KV<BigDecimal, Long> r = null;
if (accumulator == null) {
r = KV.of(input, 0L);
} else {
if (accumulator.getKey().compareTo(input) == 0) {
r = KV.of(input, accumulator.getValue());
} else {
r = KV.of(input, accumulator.getValue() + 1);
}
}
return r;
}
@Override
public KV<BigDecimal, Long> createAccumulator() {
return null;
}
@Override
public Long extractOutput(KV<BigDecimal, Long> accumulator) {
// 1-based result
return accumulator != null ? accumulator.getValue() + 1 : null;
}
}
private static class RankCombineFn<T>
extends PositionAwareCombineFn<BigDecimal, KV<BigDecimal, Long>, Long> {
@Override
public KV<BigDecimal, Long> addInput(
KV<BigDecimal, Long> accumulator,
BigDecimal input,
Long cursorPosition,
Long cursorPartition,
Long countPartition) {
KV<BigDecimal, Long> r = null;
if (accumulator == null) {
r = KV.of(input, 0L);
} else {
if (accumulator.getKey().compareTo(input) == 0) {
r = KV.of(input, accumulator.getValue());
} else {
r = KV.of(input, cursorPosition);
}
}
return r;
}
@Override
public KV<BigDecimal, Long> createAccumulator() {
return null;
}
@Override
public Long extractOutput(KV<BigDecimal, Long> accumulator) {
// 1-based result
return accumulator != null ? accumulator.getValue() + 1 : null;
}
}
private static class PercentRankCombineFn<T>
extends PositionAwareCombineFn<BigDecimal, KV<Optional<Long>, KV<BigDecimal, Long>>, Double> {
RankCombineFn internalRank;
PercentRankCombineFn() {
internalRank = new RankCombineFn();
}
@Override
public KV<Optional<Long>, KV<BigDecimal, Long>> addInput(
KV<Optional<Long>, KV<BigDecimal, Long>> accumulator,
BigDecimal input,
Long cursorPosition,
Long cursorPartition,
Long countPartition) {
KV<BigDecimal, Long> ac1 =
internalRank.addInput(
accumulator.getValue(), input, cursorPosition, cursorPartition, countPartition);
Optional<Long> ac2 = Optional.of(countPartition);
return KV.of(ac2, ac1);
}
@Override
public KV<Optional<Long>, KV<BigDecimal, Long>> createAccumulator() {
return KV.of(Optional.empty(), internalRank.createAccumulator());
}
@Override
public Double extractOutput(KV<Optional<Long>, KV<BigDecimal, Long>> accumulator) {
Long nr = accumulator.getKey().orElse(null);
Long rk = internalRank.extractOutput(accumulator.getValue());
Double r = 0.0;
if (nr != null && rk != null && nr > 1L) {
r = (rk.doubleValue() - 1) / (nr.doubleValue() - 1);
}
return r;
}
}
}