blob: 419cef17881eeea00fb6063dd4239f8071f9703b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.extensions.sql.impl.rel;
import static java.util.stream.Collectors.toList;
import static org.apache.beam.sdk.values.PCollection.IsBounded.BOUNDED;
import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkArgument;
import java.io.Serializable;
import java.util.List;
import javax.annotation.Nullable;
import org.apache.beam.sdk.extensions.sql.impl.transform.agg.AggregationCombineFnAdapter;
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.Field;
import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.WithTimestamps;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.DefaultTrigger;
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
import org.apache.beam.sdk.transforms.windowing.Sessions;
import org.apache.beam.sdk.transforms.windowing.SlidingWindows;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.transforms.windowing.WindowFn;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionList;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Lists;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelWriter;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.util.ImmutableBitSet;
import org.joda.time.Duration;
/** {@link BeamRelNode} to replace a {@link Aggregate} node. */
public class BeamAggregationRel extends Aggregate implements BeamRelNode {
private @Nullable WindowFn<Row, IntervalWindow> windowFn;
private final int windowFieldIndex;
public BeamAggregationRel(
RelOptCluster cluster,
RelTraitSet traits,
RelNode child,
boolean indicator,
ImmutableBitSet groupSet,
List<ImmutableBitSet> groupSets,
List<AggregateCall> aggCalls,
@Nullable WindowFn<Row, IntervalWindow> windowFn,
int windowFieldIndex) {
super(cluster, traits, child, indicator, groupSet, groupSets, aggCalls);
this.windowFn = windowFn;
this.windowFieldIndex = windowFieldIndex;
}
@Override
public RelWriter explainTerms(RelWriter pw) {
super.explainTerms(pw);
if (this.windowFn != null) {
WindowFn windowFn = this.windowFn;
String window = windowFn.getClass().getSimpleName() + "($" + String.valueOf(windowFieldIndex);
if (windowFn instanceof FixedWindows) {
FixedWindows fn = (FixedWindows) windowFn;
window = window + ", " + fn.getSize().toString() + ", " + fn.getOffset().toString();
} else if (windowFn instanceof SlidingWindows) {
SlidingWindows fn = (SlidingWindows) windowFn;
window =
window
+ ", "
+ fn.getPeriod().toString()
+ ", "
+ fn.getSize().toString()
+ ", "
+ fn.getOffset().toString();
} else if (windowFn instanceof Sessions) {
Sessions fn = (Sessions) windowFn;
window = window + ", " + fn.getGapDuration().toString();
} else {
throw new RuntimeException(
"Unknown window function " + windowFn.getClass().getSimpleName());
}
window = window + ")";
pw.item("window", window);
}
return pw;
}
@Override
public PTransform<PCollectionList<Row>, PCollection<Row>> buildPTransform() {
Schema outputSchema = CalciteUtils.toSchema(getRowType());
List<FieldAggregation> aggregationAdapters =
getNamedAggCalls().stream()
.map(aggCall -> new FieldAggregation(aggCall.getKey(), aggCall.getValue()))
.collect(toList());
return new Transform(
windowFn, windowFieldIndex, getGroupSet(), aggregationAdapters, outputSchema);
}
private static class FieldAggregation implements Serializable {
final List<Integer> inputs;
final CombineFn combineFn;
final Field outputField;
FieldAggregation(AggregateCall call, String alias) {
inputs = call.getArgList();
outputField = CalciteUtils.toField(alias, call.getType());
combineFn =
AggregationCombineFnAdapter.createCombineFn(
call, outputField, call.getAggregation().getName());
}
}
private static class Transform extends PTransform<PCollectionList<Row>, PCollection<Row>> {
private final List<Integer> keyFieldsIds;
private Schema outputSchema;
private WindowFn<Row, IntervalWindow> windowFn;
private int windowFieldIndex;
private List<FieldAggregation> fieldAggregations;
private Transform(
WindowFn<Row, IntervalWindow> windowFn,
int windowFieldIndex,
ImmutableBitSet groupSet,
List<FieldAggregation> fieldAggregations,
Schema outputSchema) {
this.windowFn = windowFn;
this.windowFieldIndex = windowFieldIndex;
this.fieldAggregations = fieldAggregations;
this.outputSchema = outputSchema;
this.keyFieldsIds =
groupSet.asList().stream().filter(i -> i != windowFieldIndex).collect(toList());
}
@Override
public PCollection<Row> expand(PCollectionList<Row> pinput) {
checkArgument(
pinput.size() == 1,
"Wrong number of inputs for %s: %s",
BeamAggregationRel.class.getSimpleName(),
pinput);
PCollection<Row> upstream = pinput.get(0);
PCollection<Row> windowedStream = upstream;
if (windowFn != null) {
windowedStream = assignTimestampsAndWindow(upstream);
}
validateWindowIsSupported(windowedStream);
org.apache.beam.sdk.schemas.transforms.Group.ByFields<Row> byFields =
org.apache.beam.sdk.schemas.transforms.Group.byFieldIds(keyFieldsIds);
org.apache.beam.sdk.schemas.transforms.Group.CombineFieldsByFields<Row> combined = null;
for (FieldAggregation fieldAggregation : fieldAggregations) {
List<Integer> inputs = fieldAggregation.inputs;
CombineFn combineFn = fieldAggregation.combineFn;
if (inputs.size() > 1 || inputs.isEmpty()) {
// In this path we extract a Row (an empty row if inputs.isEmpty).
combined =
(combined == null)
? byFields.aggregateFieldsById(inputs, combineFn, fieldAggregation.outputField)
: combined.aggregateFieldsById(inputs, combineFn, fieldAggregation.outputField);
} else {
// Combining over a single field, so extract just that field.
combined =
(combined == null)
? byFields.aggregateField(inputs.get(0), combineFn, fieldAggregation.outputField)
: combined.aggregateField(inputs.get(0), combineFn, fieldAggregation.outputField);
}
}
PTransform<PCollection<Row>, PCollection<KV<Row, Row>>> combiner = combined;
if (combiner == null) {
// If no field aggregations were specified, we run a constant combiner that always returns
// a single empty row for each key. This is used by the SELECT DISTINCT query plan - in this
// case a group by is generated to determine unique keys, and a constant null combiner is
// used.
combiner = byFields.aggregate(AggregationCombineFnAdapter.createConstantCombineFn());
}
return windowedStream
.apply(combiner)
.apply("mergeRecord", ParDo.of(mergeRecord(outputSchema, windowFieldIndex)))
.setRowSchema(outputSchema);
}
/** Extract timestamps from the windowFieldIndex, then window into windowFns. */
private PCollection<Row> assignTimestampsAndWindow(PCollection<Row> upstream) {
PCollection<Row> windowedStream;
windowedStream =
upstream
.apply(
"assignEventTimestamp",
WithTimestamps.<Row>of(row -> row.getDateTime(windowFieldIndex).toInstant())
.withAllowedTimestampSkew(new Duration(Long.MAX_VALUE)))
.setCoder(upstream.getCoder())
.apply(Window.into(windowFn));
return windowedStream;
}
/**
* Performs the same check as {@link GroupByKey}, provides more context in exception.
*
* <p>Verifies that the input PCollection is bounded, or that there is windowing/triggering
* being used. Without this, the watermark (at end of global window) will never be reached.
*
* <p>Throws {@link UnsupportedOperationException} if validation fails.
*/
private void validateWindowIsSupported(PCollection<Row> upstream) {
WindowingStrategy<?, ?> windowingStrategy = upstream.getWindowingStrategy();
if (windowingStrategy.getWindowFn() instanceof GlobalWindows
&& windowingStrategy.getTrigger() instanceof DefaultTrigger
&& upstream.isBounded() != BOUNDED) {
throw new UnsupportedOperationException(
"Please explicitly specify windowing in SQL query using HOP/TUMBLE/SESSION functions "
+ "(default trigger will be used in this case). "
+ "Unbounded input with global windowing and default trigger is not supported "
+ "in Beam SQL aggregations. "
+ "See GroupByKey section in Beam Programming Guide");
}
}
static DoFn<KV<Row, Row>, Row> mergeRecord(Schema outputSchema, int windowStartFieldIndex) {
return new DoFn<KV<Row, Row>, Row>() {
@ProcessElement
public void processElement(
@Element KV<Row, Row> kvRow, BoundedWindow window, OutputReceiver<Row> o) {
List<Object> fieldValues =
Lists.newArrayListWithCapacity(
kvRow.getKey().getValues().size() + kvRow.getValue().getValues().size());
fieldValues.addAll(kvRow.getKey().getValues());
fieldValues.addAll(kvRow.getValue().getValues());
if (windowStartFieldIndex != -1) {
fieldValues.add(windowStartFieldIndex, ((IntervalWindow) window).start());
}
o.output(Row.withSchema(outputSchema).addValues(fieldValues).build());
}
};
}
}
@Override
public Aggregate copy(
RelTraitSet traitSet,
RelNode input,
boolean indicator,
ImmutableBitSet groupSet,
List<ImmutableBitSet> groupSets,
List<AggregateCall> aggCalls) {
return new BeamAggregationRel(
getCluster(),
traitSet,
input,
indicator,
groupSet,
groupSets,
aggCalls,
windowFn,
windowFieldIndex);
}
}