blob: de7ef2a29f697ba29959e13b5a59538cf700451a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.extensions.sql.impl.rel;
import java.io.Serializable;
import java.math.BigDecimal;
import java.util.Collections;
import java.util.List;
import java.util.NavigableMap;
import java.util.TreeMap;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.ListCoder;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamCostModel;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamRelMetadataQuery;
import org.apache.beam.sdk.extensions.sql.impl.planner.NodeStats;
import org.apache.beam.sdk.extensions.sql.impl.transform.BeamBuiltinAnalyticFunctions;
import org.apache.beam.sdk.extensions.sql.impl.transform.agg.AggregationCombineFnAdapter;
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionList;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.plan.RelOptCluster;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.plan.RelOptPlanner;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.plan.RelTraitSet;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.RelFieldCollation;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.RelNode;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.AggregateCall;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.Window;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.type.RelDataType;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexInputRef;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexLiteral;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexNode;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
/**
* {@code BeamRelNode} to replace a {@code Window} node.
*
* <p>The following types of Analytic Functions are supported:
*
* <pre>{@code
* SELECT agg(c) over () FROM t
* SELECT agg(c1) over (PARTITION BY c2 ORDER BY c3 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM t
* SELECT agg(c1) over (PARTITION BY c2 ORDER BY c3 RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM t
* }</pre>
*
* <h3>Constraints</h3>
*
* <ul>
* <li>Only Aggregate Analytic Functions are available.
* </ul>
*/
@SuppressWarnings({
"rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
"nullness" // TODO(https://github.com/apache/beam/issues/20497)
})
public class BeamWindowRel extends Window implements BeamRelNode {
public BeamWindowRel(
RelOptCluster cluster,
RelTraitSet traitSet,
RelNode input,
List<RexLiteral> constants,
RelDataType rowType,
List<Group> groups) {
super(cluster, traitSet, input, constants, rowType, groups);
}
@Override
public PTransform<PCollectionList<Row>, PCollection<Row>> buildPTransform() {
Schema outputSchema = CalciteUtils.toSchema(getRowType());
final List<FieldAggregation> analyticFields = Lists.newArrayList();
this.groups.stream()
.forEach(
anAnalyticGroup -> {
List<Integer> partitionKeysDef = anAnalyticGroup.keys.toList();
List<Integer> orderByKeys = Lists.newArrayList();
List<Boolean> orderByDirections = Lists.newArrayList();
List<Boolean> orderByNullDirections = Lists.newArrayList();
anAnalyticGroup.orderKeys.getFieldCollations().stream()
.forEach(
fc -> {
orderByKeys.add(fc.getFieldIndex());
orderByDirections.add(
fc.direction == RelFieldCollation.Direction.ASCENDING);
orderByNullDirections.add(
fc.nullDirection == RelFieldCollation.NullDirection.FIRST);
});
BigDecimal lowerB = null; // Unbounded by default
BigDecimal upperB = null; // Unbounded by default
if (anAnalyticGroup.lowerBound.isCurrentRow()) {
lowerB = BigDecimal.ZERO;
} else if (anAnalyticGroup.lowerBound.isPreceding()) {
if (!anAnalyticGroup.lowerBound.isUnbounded()) {
lowerB = getLiteralValueConstants(anAnalyticGroup.lowerBound.getOffset());
}
} else if (anAnalyticGroup.lowerBound.isFollowing()) {
if (!anAnalyticGroup.lowerBound.isUnbounded()) {
lowerB =
getLiteralValueConstants(anAnalyticGroup.lowerBound.getOffset()).negate();
}
}
if (anAnalyticGroup.upperBound.isCurrentRow()) {
upperB = BigDecimal.ZERO;
} else if (anAnalyticGroup.upperBound.isPreceding()) {
if (!anAnalyticGroup.upperBound.isUnbounded()) {
upperB =
getLiteralValueConstants(anAnalyticGroup.upperBound.getOffset()).negate();
}
} else if (anAnalyticGroup.upperBound.isFollowing()) {
if (!anAnalyticGroup.upperBound.isUnbounded()) {
upperB = getLiteralValueConstants(anAnalyticGroup.upperBound.getOffset());
}
}
final BigDecimal lowerBFinal = lowerB;
final BigDecimal upperBFinal = upperB;
List<AggregateCall> aggregateCalls = anAnalyticGroup.getAggregateCalls(this);
aggregateCalls.stream()
.forEach(
anAggCall -> {
List<Integer> argList = anAggCall.getArgList();
Schema.Field field =
CalciteUtils.toField(anAggCall.getName(), anAggCall.getType());
Combine.CombineFn combineFn =
AggregationCombineFnAdapter.createCombineFnAnalyticsFunctions(
anAggCall, field, anAggCall.getAggregation().getName());
FieldAggregation fieldAggregation =
new FieldAggregation(
partitionKeysDef,
orderByKeys,
orderByDirections,
orderByNullDirections,
lowerBFinal,
upperBFinal,
anAnalyticGroup.isRows,
argList,
combineFn,
field);
analyticFields.add(fieldAggregation);
});
});
return new Transform(outputSchema, analyticFields);
}
private BigDecimal getLiteralValueConstants(RexNode n) {
int idx = ((RexInputRef) n).getIndex() - input.getRowType().getFieldCount();
return (BigDecimal) this.constants.get(idx).getValue();
}
private static class FieldAggregation implements Serializable {
private List<Integer> partitionKeys;
private List<Integer> orderKeys;
private List<Boolean> orderOrientations;
private List<Boolean> orderNulls;
private BigDecimal lowerLimit = null;
private BigDecimal upperLimit = null;
private boolean rows = true;
private List<Integer> inputFields;
private Combine.CombineFn combineFn;
private Schema.Field outputField;
public FieldAggregation(
List<Integer> partitionKeys,
List<Integer> orderKeys,
List<Boolean> orderOrientations,
List<Boolean> orderNulls,
BigDecimal lowerLimit,
BigDecimal upperLimit,
boolean rows,
List<Integer> inputFields,
Combine.CombineFn combineFn,
Schema.Field outputField) {
this.partitionKeys = partitionKeys;
this.orderKeys = orderKeys;
this.orderOrientations = orderOrientations;
this.orderNulls = orderNulls;
this.lowerLimit = lowerLimit;
this.upperLimit = upperLimit;
this.rows = rows;
this.inputFields = inputFields;
this.combineFn = combineFn;
this.outputField = outputField;
}
}
@Override
public NodeStats estimateNodeStats(BeamRelMetadataQuery mq) {
NodeStats inputStat = BeamSqlRelUtils.getNodeStats(this.input, mq);
return inputStat;
}
/**
* A dummy cost computation based on a fixed multiplier.
*
* <p>Since, there are not additional LogicalWindow to BeamWindowRel strategies, the result of
* this cost computation has no impact in the query execution plan.
*/
@Override
public BeamCostModel beamComputeSelfCost(RelOptPlanner planner, BeamRelMetadataQuery mq) {
NodeStats inputStat = BeamSqlRelUtils.getNodeStats(this.input, mq);
float multiplier = 1f + 0.125f;
return BeamCostModel.FACTORY.makeCost(
inputStat.getRowCount() * multiplier, inputStat.getRate() * multiplier);
}
private static class Transform extends PTransform<PCollectionList<Row>, PCollection<Row>> {
private Schema outputSchema;
private List<FieldAggregation> aggFields;
public Transform(Schema schema, List<FieldAggregation> fieldAgg) {
this.outputSchema = schema;
this.aggFields = fieldAgg;
}
@Override
public PCollection<Row> expand(PCollectionList<Row> input) {
PCollection<Row> inputData = input.get(0);
Schema inputSchema = inputData.getSchema();
int ids = 0;
for (FieldAggregation af : aggFields) {
ids++;
String prefix = "transform_" + ids;
Coder<Row> rowCoder = inputData.getCoder();
PCollection<Iterable<Row>> partitioned = null;
if (af.partitionKeys.isEmpty()) {
partitioned =
inputData.apply(
prefix + "globalPartition",
org.apache.beam.sdk.schemas.transforms.Group.globally());
} else {
org.apache.beam.sdk.schemas.transforms.Group.ByFields<Row> myg =
org.apache.beam.sdk.schemas.transforms.Group.byFieldIds(af.partitionKeys);
PCollection<KV<Row, Iterable<Row>>> partitionBy =
inputData
.apply(prefix + "partitionByKV", myg.getToKV())
.apply(prefix + "partitionByGK", GroupByKey.create());
partitioned =
partitionBy
.apply(prefix + "selectOnlyValues", ParDo.of(new SelectOnlyValues()))
.setCoder(IterableCoder.of(rowCoder));
}
// Migrate to a SortedValues transform.
PCollection<List<Row>> sortedPartition =
partitioned
.apply(prefix + "orderBy", ParDo.of(sortPartition(af)))
.setCoder(ListCoder.of(rowCoder));
inputSchema =
Schema.builder().addFields(inputSchema.getFields()).addFields(af.outputField).build();
inputData =
sortedPartition
.apply(prefix + "aggCall", ParDo.of(aggField(inputSchema, af)))
.setRowSchema(inputSchema);
}
return inputData.setRowSchema(this.outputSchema);
}
}
private static DoFn<List<Row>, Row> aggField(
final Schema expectedSchema, final FieldAggregation fieldAgg) {
return new DoFn<List<Row>, Row>() {
@ProcessElement
public void processElement(
@Element List<Row> inputPartition, OutputReceiver<Row> out, ProcessContext c) {
List<Row> sortedRowsAsList = inputPartition;
NavigableMap<BigDecimal, List<Row>> indexRange = null;
if (!fieldAgg.rows) {
indexRange = indexRows(sortedRowsAsList);
}
for (int idx = 0; idx < sortedRowsAsList.size(); idx++) {
List<Row> aggRange = null;
if (fieldAgg.rows) {
aggRange = getRows(sortedRowsAsList, idx);
} else {
aggRange = getRange(indexRange, sortedRowsAsList.get(idx));
}
Object accumulator = fieldAgg.combineFn.createAccumulator();
// if not inputs are needed, put a mock Field index
final int aggFieldIndex =
fieldAgg.inputFields.isEmpty() ? -1 : fieldAgg.inputFields.get(0);
long count = 0;
for (Row aggRow : aggRange) {
if (fieldAgg.combineFn instanceof BeamBuiltinAnalyticFunctions.PositionAwareCombineFn) {
BeamBuiltinAnalyticFunctions.PositionAwareCombineFn fn =
(BeamBuiltinAnalyticFunctions.PositionAwareCombineFn) fieldAgg.combineFn;
accumulator =
fn.addInput(
accumulator,
getOrderByValue(aggRow),
count,
(long) idx,
(long) sortedRowsAsList.size());
} else {
accumulator =
fieldAgg.combineFn.addInput(accumulator, aggRow.getBaseValue(aggFieldIndex));
}
count++;
}
Object result = fieldAgg.combineFn.extractOutput(accumulator);
Row processingRow = sortedRowsAsList.get(idx);
List<Object> fieldValues = Lists.newArrayListWithCapacity(processingRow.getFieldCount());
fieldValues.addAll(processingRow.getValues());
fieldValues.add(result);
Row build = Row.withSchema(expectedSchema).addValues(fieldValues).build();
out.output(build);
}
}
private NavigableMap<BigDecimal, List<Row>> indexRows(List<Row> input) {
NavigableMap<BigDecimal, List<Row>> map = new TreeMap<BigDecimal, List<Row>>();
for (Row r : input) {
BigDecimal orderByValue = getOrderByValue(r);
if (orderByValue == null) {
/** Special case agg(X) OVER () set dummy value. */
orderByValue = BigDecimal.ZERO;
}
if (!map.containsKey(orderByValue)) {
map.put(orderByValue, Lists.newArrayList());
}
map.get(orderByValue).add(r);
}
return map;
}
private List<Row> getRange(NavigableMap<BigDecimal, List<Row>> indexRanges, Row aRow) {
NavigableMap<BigDecimal, List<Row>> subMap;
BigDecimal currentRowValue = getOrderByValue(aRow);
if (currentRowValue != null && fieldAgg.lowerLimit != null && fieldAgg.upperLimit != null) {
BigDecimal ll = currentRowValue.subtract(fieldAgg.lowerLimit);
BigDecimal ul = currentRowValue.add(fieldAgg.upperLimit);
subMap = indexRanges.subMap(ll, true, ul, true);
} else if (currentRowValue != null
&& fieldAgg.lowerLimit != null
&& fieldAgg.upperLimit == null) {
BigDecimal ll = currentRowValue.subtract(fieldAgg.lowerLimit);
subMap = indexRanges.tailMap(ll, true);
} else if (currentRowValue != null
&& fieldAgg.lowerLimit == null
&& fieldAgg.upperLimit != null) {
BigDecimal ul = currentRowValue.add(fieldAgg.upperLimit);
subMap = indexRanges.headMap(ul, true);
} else {
subMap = indexRanges;
}
List<Row> result = Lists.newArrayList();
for (List<Row> partialList : subMap.values()) {
result.addAll(partialList);
}
return result;
}
private BigDecimal getOrderByValue(Row r) {
/**
* Special Case: This query is transformed by calcite as follows: agg(X) over () -> agg(X)
* over (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) No orderKeys, so return
* null.
*/
if (fieldAgg.orderKeys.size() == 0) {
return null;
} else {
return new BigDecimal(((Number) r.getBaseValue(fieldAgg.orderKeys.get(0))).toString());
}
}
private List<Row> getRows(List<Row> input, int index) {
Integer ll =
fieldAgg.lowerLimit != null ? fieldAgg.lowerLimit.intValue() : Integer.MAX_VALUE;
Integer ul =
fieldAgg.upperLimit != null ? fieldAgg.upperLimit.intValue() : Integer.MAX_VALUE;
int lowerIndex = ll == Integer.MAX_VALUE ? Integer.MIN_VALUE : index - ll;
int upperIndex = ul == Integer.MAX_VALUE ? Integer.MAX_VALUE : index + ul + 1;
lowerIndex = lowerIndex < 0 ? 0 : lowerIndex;
upperIndex = upperIndex > input.size() ? input.size() : upperIndex;
List<Row> out = input.subList(lowerIndex, upperIndex);
return out;
}
};
}
static class SelectOnlyValues extends DoFn<KV<Row, Iterable<Row>>, Iterable<Row>> {
@ProcessElement
public void processElement(
@Element KV<Row, Iterable<Row>> inputPartition,
OutputReceiver<Iterable<Row>> out,
ProcessContext c) {
out.output(inputPartition.getValue());
}
}
private static DoFn<Iterable<Row>, List<Row>> sortPartition(final FieldAggregation fieldAgg) {
return new DoFn<Iterable<Row>, List<Row>>() {
@ProcessElement
public void processElement(
@Element Iterable<Row> inputPartition, OutputReceiver<List<Row>> out, ProcessContext c) {
List<Row> partitionRows = Lists.newArrayList(inputPartition);
BeamSortRel.BeamSqlRowComparator beamSqlRowComparator =
new BeamSortRel.BeamSqlRowComparator(
fieldAgg.orderKeys, fieldAgg.orderOrientations, fieldAgg.orderNulls);
Collections.sort(partitionRows, beamSqlRowComparator);
out.output(partitionRows);
}
};
}
@Override
public RelNode copy(RelTraitSet traitSet, List<RelNode> inputs) {
return this.copy(traitSet, sole(inputs), this.constants, this.rowType, this.groups);
}
public BeamWindowRel copy(
RelTraitSet traitSet,
RelNode input,
List<RexLiteral> constants,
RelDataType rowType,
List<Group> groups) {
return new BeamWindowRel(getCluster(), traitSet, input, constants, rowType, groups);
}
}