blob: b82ecaf4be9e1fceeb7d613d2eb11c083d2b2731 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.schemas.transforms;
import com.google.auto.value.AutoValue;
import java.io.Serializable;
import java.util.List;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.annotations.Experimental.Kind;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
import org.apache.beam.sdk.schemas.FieldTypeDescriptors;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.Field;
import org.apache.beam.sdk.schemas.SchemaCoder;
import org.apache.beam.sdk.schemas.utils.SelectHelpers;
import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.beam.sdk.transforms.CombineFns;
import org.apache.beam.sdk.transforms.CombineFns.CoCombineResult;
import org.apache.beam.sdk.transforms.CombineFns.ComposedCombineFn;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.SerializableFunctions;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
/** This is the builder used by {@link Group} to build up a composed {@link CombineFn}. */
@Experimental(Kind.SCHEMAS)
class SchemaAggregateFn {
static <T> Inner<T> create() {
return new AutoValue_SchemaAggregateFn_Inner.Builder<T>()
.setFieldAggregations(Lists.newArrayList())
.build();
}
/** Implementation of {@link #create}. */
@AutoValue
abstract static class Inner<T> extends CombineFn<T, Object[], Row> {
// Represents an aggregation of one or more fields.
static class FieldAggregation<FieldT, AccumT, OutputT> implements Serializable {
FieldAccessDescriptor fieldsToAggregate;
// The specification of the output field.
private final Field outputField;
// The combine function.
private final CombineFn<FieldT, AccumT, OutputT> fn;
// The TupleTag identifying this aggregation element in the composed combine fn.
private final TupleTag<Object> combineTag;
// The schema corresponding to the the subset of input fields being aggregated.
@Nullable private final Schema inputSubSchema;
// The flattened version of inputSubSchema.
@Nullable private final Schema unnestedInputSubSchema;
// The output schema resulting from the aggregation.
private final Schema aggregationSchema;
private final boolean needsUnnesting;
FieldAggregation(
FieldAccessDescriptor fieldsToAggregate,
Field outputField,
CombineFn<FieldT, AccumT, OutputT> fn,
TupleTag<Object> combineTag) {
this(
fieldsToAggregate,
outputField,
fn,
combineTag,
Schema.builder().addField(outputField).build(),
null);
}
FieldAggregation(
FieldAccessDescriptor fieldsToAggregate,
Field outputField,
CombineFn<FieldT, AccumT, OutputT> fn,
TupleTag<Object> combineTag,
Schema aggregationSchema,
@Nullable Schema inputSchema) {
if (inputSchema != null) {
this.fieldsToAggregate = fieldsToAggregate.resolve(inputSchema);
this.inputSubSchema = SelectHelpers.getOutputSchema(inputSchema, this.fieldsToAggregate);
this.unnestedInputSubSchema = Unnest.getUnnestedSchema(inputSubSchema);
this.needsUnnesting = !inputSchema.equals(unnestedInputSubSchema);
} else {
this.fieldsToAggregate = fieldsToAggregate;
this.inputSubSchema = null;
this.unnestedInputSubSchema = null;
this.needsUnnesting = false;
}
this.outputField = outputField;
this.fn = fn;
this.combineTag = combineTag;
this.aggregationSchema = aggregationSchema;
}
// The Schema is not necessarily known when the SchemaAggregateFn is created. Once the schema
// is known, resolve will be called with the proper schema.
FieldAggregation<FieldT, AccumT, OutputT> resolve(Schema schema) {
return new FieldAggregation<>(
fieldsToAggregate, outputField, fn, combineTag, aggregationSchema, schema);
}
}
abstract Builder<T> toBuilder();
@AutoValue.Builder
abstract static class Builder<T> {
abstract Builder<T> setInputSchema(@Nullable Schema inputSchema);
abstract Builder<T> setOutputSchema(@Nullable Schema outputSchema);
abstract Builder<T> setComposedCombineFn(@Nullable ComposedCombineFn<T> composedCombineFn);
abstract Builder<T> setFieldAggregations(List<FieldAggregation> fieldAggregations);
abstract Inner<T> build();
}
abstract @Nullable Schema getInputSchema();
abstract @Nullable Schema getOutputSchema();
abstract @Nullable ComposedCombineFn<T> getComposedCombineFn();
abstract List<FieldAggregation> getFieldAggregations();
/** Once the schema is known, this function is called by the {@link Group} transform. */
Inner<T> withSchema(Schema inputSchema, SerializableFunction<T, Row> toRowFunction) {
List<FieldAggregation> fieldAggregations =
getFieldAggregations().stream()
.map(f -> f.resolve(inputSchema))
.collect(Collectors.toList());
ComposedCombineFn<T> composedCombineFn = null;
for (int i = 0; i < fieldAggregations.size(); ++i) {
FieldAggregation fieldAggregation = fieldAggregations.get(i);
SimpleFunction<T, ?> extractFunction;
Coder extractOutputCoder;
if (fieldAggregation.unnestedInputSubSchema.getFieldCount() == 1) {
extractFunction = new ExtractSingleFieldFunction<>(fieldAggregation, toRowFunction);
extractOutputCoder =
SchemaCoder.coderForFieldType(
fieldAggregation.unnestedInputSubSchema.getField(0).getType());
} else {
extractFunction = new ExtractFieldsFunction<>(fieldAggregation, toRowFunction);
extractOutputCoder = SchemaCoder.of(fieldAggregation.inputSubSchema);
}
if (i == 0) {
composedCombineFn =
CombineFns.compose()
.with(
extractFunction,
extractOutputCoder,
fieldAggregation.fn,
fieldAggregation.combineTag);
} else {
composedCombineFn =
composedCombineFn.with(
extractFunction,
extractOutputCoder,
fieldAggregation.fn,
fieldAggregation.combineTag);
}
}
return toBuilder()
.setInputSchema(inputSchema)
.setComposedCombineFn(composedCombineFn)
.setFieldAggregations(fieldAggregations)
.build();
}
/** Aggregate all values of a set of fields into an output field. */
<CombineInputT, AccumT, CombineOutputT> Inner<T> aggregateFields(
FieldAccessDescriptor fieldsToAggregate,
CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
String outputFieldName) {
return aggregateFields(
fieldsToAggregate,
fn,
Field.of(outputFieldName, FieldTypeDescriptors.fieldTypeForJavaType(fn.getOutputType())));
}
/** Aggregate all values of a set of fields into an output field. */
<CombineInputT, AccumT, CombineOutputT> Inner<T> aggregateFields(
FieldAccessDescriptor fieldsToAggregate,
CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
Field outputField) {
List<FieldAggregation> fieldAggregations = getFieldAggregations();
TupleTag<Object> combineTag = new TupleTag<>(Integer.toString(fieldAggregations.size()));
FieldAggregation fieldAggregation =
new FieldAggregation<>(fieldsToAggregate, outputField, fn, combineTag);
fieldAggregations.add(fieldAggregation);
return toBuilder()
.setOutputSchema(getOutputSchema(fieldAggregations))
.setFieldAggregations(fieldAggregations)
.build();
}
private Schema getOutputSchema(List<FieldAggregation> fieldAggregations) {
Schema.Builder outputSchema = Schema.builder();
for (FieldAggregation aggregation : fieldAggregations) {
outputSchema.addField(aggregation.outputField);
}
return outputSchema.build();
}
/** Extract a single field from an input {@link Row}. */
private static class ExtractSingleFieldFunction<InputT, OutputT>
extends SimpleFunction<InputT, OutputT> {
private final FieldAggregation fieldAggregation;
private final SerializableFunction<InputT, Row> toRowFunction;
private ExtractSingleFieldFunction(
FieldAggregation fieldAggregation, SerializableFunction<InputT, Row> toRowFunction) {
this.fieldAggregation = fieldAggregation;
this.toRowFunction = toRowFunction;
}
@Override
public OutputT apply(InputT input) {
Row row = toRowFunction.apply(input);
Row selected =
SelectHelpers.selectRow(
row,
fieldAggregation.fieldsToAggregate,
row.getSchema(),
fieldAggregation.inputSubSchema);
if (fieldAggregation.needsUnnesting) {
selected = Unnest.unnestRow(selected, fieldAggregation.unnestedInputSubSchema);
}
return selected.getValue(0);
}
}
/** Extract multiple fields from an input {@link Row}. */
private static class ExtractFieldsFunction<T> extends SimpleFunction<T, Row> {
private FieldAggregation fieldAggregation;
private SerializableFunction<T, Row> toRowFunction;
private ExtractFieldsFunction(
FieldAggregation fieldAggregation, SerializableFunction<T, Row> toRowFunction) {
this.fieldAggregation = fieldAggregation;
this.toRowFunction = toRowFunction;
}
@Override
public Row apply(T input) {
Row row = toRowFunction.apply(input);
return SelectHelpers.selectRow(
row,
fieldAggregation.fieldsToAggregate,
row.getSchema(),
fieldAggregation.inputSubSchema);
}
}
@Override
public Object[] createAccumulator() {
return getComposedCombineFn().createAccumulator();
}
@Override
public Object[] addInput(Object[] accumulator, T input) {
return getComposedCombineFn().addInput(accumulator, input);
}
@Override
public Object[] mergeAccumulators(Iterable<Object[]> accumulator) {
return getComposedCombineFn().mergeAccumulators(accumulator);
}
@Override
public Coder<Object[]> getAccumulatorCoder(CoderRegistry registry, Coder<T> inputCoder)
throws CannotProvideCoderException {
return getComposedCombineFn().getAccumulatorCoder(registry, inputCoder);
}
@Override
public Coder<Row> getDefaultOutputCoder(CoderRegistry registry, Coder<T> inputCoder) {
return SchemaCoder.of(
getOutputSchema(), SerializableFunctions.identity(), SerializableFunctions.identity());
}
@Override
public Row extractOutput(Object[] accumulator) {
// Build a row containing a field for every aggregate that was registered.
CoCombineResult coCombineResult = getComposedCombineFn().extractOutput(accumulator);
Row.Builder output = Row.withSchema(getOutputSchema());
for (FieldAggregation fieldAggregation : getFieldAggregations()) {
Object aggregate = coCombineResult.get(fieldAggregation.combineTag);
output.addValue(aggregate);
}
return output.build();
}
}
}