/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.extensions.sql.impl.transform.agg;
import javax.annotation.Nullable;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.extensions.sql.impl.UdafImpl;
import org.apache.beam.sdk.extensions.sql.impl.transform.BeamBuiltinAggregations;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaCoder;
import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.beam.sdk.values.Row;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction;
/** Wrapper {@link CombineFn}s for aggregation function calls. */
public class AggregationCombineFnAdapter<T> {
/**
 * Base {@link CombineFn} that forwards every operation to a wrapped delegate, giving subclasses
 * the chance to pre-process (or drop) each input before the delegate sees it.
 *
 * <p>An input for which {@link #getInput} returns {@code null} is skipped: the accumulator is
 * handed back untouched and the delegate is not invoked for that input.
 */
private abstract static class WrappedCombinerBase<T> extends CombineFn<T, Object, Object> {
  /** The combiner that performs the actual aggregation; assigned once at construction. */
  private final CombineFn<T, Object, Object> combineFn;

  /**
   * @param combineFn the delegate that every accumulator operation is forwarded to
   */
  WrappedCombinerBase(CombineFn<T, Object, Object> combineFn) {
    this.combineFn = combineFn;
  }

  /** Returns a fresh accumulator created by the delegate. */
  @Override
  public Object createAccumulator() {
    return combineFn.createAccumulator();
  }

  /**
   * Adds {@code input} to {@code accumulator} via the delegate, unless {@link #getInput} maps the
   * input to {@code null}, in which case the accumulator is returned unchanged.
   */
  @Override
  public Object addInput(Object accumulator, T input) {
    // Evaluate getInput exactly once and reuse the result; subclasses may scan the whole input.
    T processedInput = getInput(input);
    if (processedInput == null) {
      return accumulator;
    }
    return combineFn.addInput(accumulator, processedInput);
  }

  /** Merges the given accumulators using the delegate. */
  @Override
  public Object mergeAccumulators(Iterable<Object> accumulators) {
    return combineFn.mergeAccumulators(accumulators);
  }

  /** Extracts the final result from {@code accumulator} using the delegate. */
  @Override
  public Object extractOutput(Object accumulator) {
    return combineFn.extractOutput(accumulator);
  }

  /**
   * Maps a raw input to the value that should be passed to the delegate.
   *
   * @param input the element received by {@link #addInput}
   * @return the value to aggregate, or {@code null} if this input must be skipped
   */
  @Nullable
  abstract T getInput(T input);

  /** Returns the accumulator coder chosen by the delegate. */
  @Override
  public Coder<Object> getAccumulatorCoder(CoderRegistry registry, Coder<T> inputCoder)
      throws CannotProvideCoderException {
    return combineFn.getAccumulatorCoder(registry, inputCoder);
  }
}
/**
 * Combiner wrapper for aggregations whose input is a {@link Row}. A row is only forwarded to the
 * delegate when none of its values is {@code null}; otherwise it is skipped.
 */
private static class MultiInputCombiner extends WrappedCombinerBase<Row> {
  MultiInputCombiner(CombineFn<Row, Object, Object> delegate) {
    super(delegate);
  }

  /** Returns {@code input} itself, or {@code null} when any of its values is {@code null}. */
  @Override
  Row getInput(Row input) {
    return hasNullValue(input) ? null : input;
  }

  /** Tells whether at least one value held by {@code row} is {@code null}. */
  private static boolean hasNullValue(Row row) {
    for (Object value : row.getValues()) {
      if (value == null) {
        return true;
      }
    }
    return false;
  }
}
/**
 * Combiner wrapper that hands each input to the delegate as is. Because the base class skips
 * {@code null} results of {@code getInput}, a {@code null} input never reaches the delegate.
 */
private static class SingleInputCombiner extends WrappedCombinerBase<Object> {
  SingleInputCombiner(CombineFn<Object, Object, Object> delegate) {
    super(delegate);
  }

  /** Identity mapping: the value is returned without any transformation. */
  @Override
  Object getInput(Object value) {
    return value;
  }
}
/**
 * A {@link CombineFn} that ignores all of its arguments: every accumulator and the final output
 * are the same {@link Row} built from a schema to which no fields were added.
 */
private static class ConstantEmpty extends CombineFn<Row, Row, Row> {
  private static final Schema EMPTY_SCHEMA = Schema.builder().build();
  private static final Row EMPTY_ROW = Row.withSchema(EMPTY_SCHEMA).build();

  /** Shared instance handed out by the enclosing class. */
  public static final ConstantEmpty INSTANCE = new ConstantEmpty();

  /** Builds a coder for rows of the empty schema. */
  private static Coder<Row> emptyRowCoder() {
    return SchemaCoder.of(EMPTY_SCHEMA);
  }

  @Override
  public Row createAccumulator() {
    return EMPTY_ROW;
  }

  /** Both arguments are ignored; the empty row is always returned. */
  @Override
  public Row addInput(Row ignoredAccumulator, Row ignoredInput) {
    return EMPTY_ROW;
  }

  /** The accumulators are never inspected; the empty row is always returned. */
  @Override
  public Row mergeAccumulators(Iterable<Row> ignoredAccumulators) {
    return EMPTY_ROW;
  }

  /** The accumulator is ignored; the empty row is always returned. */
  @Override
  public Row extractOutput(Row ignoredAccumulator) {
    return EMPTY_ROW;
  }

  @Override
  public Coder<Row> getAccumulatorCoder(CoderRegistry registry, Coder<Row> inputCoder)
      throws CannotProvideCoderException {
    return emptyRowCoder();
  }

  @Override
  public Coder<Row> getDefaultOutputCoder(CoderRegistry registry, Coder<Row> inputCoder) {
    return emptyRowCoder();
  }
}
/**
 * Creates either a UDAF or a built-in {@link CombineFn} for the given aggregate call and wraps
 * it in an adapter chosen by the number of arguments of the call.
 *
 * <p>Calls with zero or one argument get a wrapper that passes each input through unchanged;
 * calls with more arguments get a wrapper that skips any input row containing a {@code null}
 * value.
 *
 * @param call the aggregate call to build a combiner for
 * @param field the field whose type is used to select a built-in aggregation
 * @param functionName the name used to look up a built-in aggregation
 * @return the wrapped combiner
 * @throws IllegalArgumentException if {@code call} is a DISTINCT aggregation
 * @throws IllegalStateException if the combiner of a user-defined aggregation cannot be obtained
 */
// The combiner's type arguments are only known at runtime, so it is handled as a raw type here.
@SuppressWarnings({"unchecked", "rawtypes"})
public static CombineFn<?, ?, ?> createCombineFn(
    AggregateCall call, Schema.Field field, String functionName) {
  if (call.isDistinct()) {
    throw new IllegalArgumentException(
        "Does not support " + call.getAggregation().getName() + " DISTINCT");
  }

  CombineFn combineFn;
  if (call.getAggregation() instanceof SqlUserDefinedAggFunction) {
    combineFn = getUdafCombineFn(call);
  } else {
    combineFn = BeamBuiltinAggregations.create(functionName, field.getType());
  }

  // Zero-argument and single-argument calls share the same pass-through wrapper.
  if (call.getArgList().size() <= 1) {
    return new SingleInputCombiner(combineFn);
  }
  return new MultiInputCombiner(combineFn);
}
/**
 * Returns the shared {@link CombineFn} whose accumulators and output are always the same row of
 * an empty schema, whatever input it receives.
 */
public static CombineFn<Row, ?, Row> createConstantCombineFn() {
  CombineFn<Row, ?, Row> constantFn = ConstantEmpty.INSTANCE;
  return constantFn;
}
/**
 * Obtains the {@link CombineFn} behind a user-defined aggregate call.
 *
 * <p>The aggregation of {@code call} is cast to {@link SqlUserDefinedAggFunction} and its
 * function to {@link UdafImpl}; any exception raised along the way is rethrown wrapped in an
 * {@link IllegalStateException}.
 */
private static CombineFn<?, ?, ?> getUdafCombineFn(AggregateCall call) {
  try {
    SqlUserDefinedAggFunction udaf = (SqlUserDefinedAggFunction) call.getAggregation();
    UdafImpl udafImpl = (UdafImpl) udaf.function;
    return udafImpl.getCombineFn();
  } catch (Exception cause) {
    throw new IllegalStateException(cause);
  }
}
}