blob: 03395921de276f3b5278918b48d11304044673f6 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.extensions.sql.zetasql.translation;
import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_CAST;
import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_COLUMN_REF;
import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_GET_STRUCT_FIELD;
import static org.apache.beam.sdk.extensions.sql.zetasql.TypeUtils.toSimpleRelDataType;
import com.google.zetasql.FunctionSignature;
import com.google.zetasql.ZetaSQLType.TypeKind;
import com.google.zetasql.resolvedast.ResolvedNode;
import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedAggregateFunctionCall;
import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedAggregateScan;
import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedComputedColumn;
import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedExpr;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.beam.sdk.extensions.sql.zetasql.SqlStdOperatorMappingTable;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.RelNode;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.core.AggregateCall;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.logical.LogicalProject;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.type.RelDataType;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexNode;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.SqlAggFunction;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.util.ImmutableBitSet;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
/**
 * Converts a ZetaSQL {@link ResolvedAggregateScan} into a Calcite {@link LogicalAggregate}.
 *
 * <p>The produced plan is a {@link LogicalAggregate} whose input is a {@link LogicalProject} built
 * on top of the converted input scan. The project lists the GROUP BY expressions first, followed
 * by one column for every aggregate function that takes an argument. The aggregate refers to
 * those projected columns by position.
 */
class AggregateScanConverter extends RelConverter<ResolvedAggregateScan> {
  private static final String AVG_ILLEGAL_LONG_INPUT_TYPE =
      "AVG(LONG) is not supported. You might want to use AVG(CAST(expression AS DOUBLE).";

  AggregateScanConverter(ConversionContext context) {
    super(context);
  }

  /** Returns the single input of the aggregate scan: its input scan. */
  @Override
  public List<ResolvedNode> getInputs(ResolvedAggregateScan zetaNode) {
    return Collections.singletonList(zetaNode.getInputScan());
  }

  /**
   * Builds the {@link LogicalAggregate} for {@code zetaNode}.
   *
   * @param zetaNode the resolved aggregate scan to convert
   * @param inputs converted inputs; only the first element (the converted input scan) is used
   * @return a {@link LogicalAggregate} over a {@link LogicalProject} of the first input
   * @throws RuntimeException if an aggregate call is not supported (see {@link #convertAggCall})
   *     or has more than one argument
   */
  @Override
  public RelNode convert(ResolvedAggregateScan zetaNode, List<RelNode> inputs) {
    RelNode input = convertAggregateScanInputScanToLogicalProject(zetaNode, inputs.get(0));

    // Calcite LogicalAggregate's GroupSet is indexes of group fields starting from 0.
    int groupFieldsListSize = zetaNode.getGroupByList().size();
    ImmutableBitSet groupSet;
    if (groupFieldsListSize != 0) {
      groupSet =
          ImmutableBitSet.of(
              IntStream.rangeClosed(0, groupFieldsListSize - 1)
                  .boxed()
                  .collect(Collectors.toList()));
    } else {
      groupSet = ImmutableBitSet.of();
    }

    // TODO: add support for indicator

    // For aggregate calls, their input refs follow after the GROUP BY input refs.
    List<AggregateCall> aggregateCalls = new ArrayList<>();
    int columnRefOff = groupFieldsListSize;
    for (ResolvedComputedColumn computedColumn : zetaNode.getAggregateList()) {
      AggregateCall aggCall = convertAggCall(computedColumn, columnRefOff);
      aggregateCalls.add(aggCall);
      // The input project only contains a column for aggregates that take an argument (e.g.
      // COUNT(*) contributes none), so the offset must only advance when this call actually
      // consumed a projected column. Advancing unconditionally would make every aggregate that
      // follows a zero-argument one point at the wrong (or a non-existent) input field.
      if (!aggCall.getArgList().isEmpty()) {
        columnRefOff++;
      }
    }

    return new LogicalAggregate(
        getCluster(),
        input.getTraitSet(),
        input,
        false,
        groupSet,
        ImmutableList.of(groupSet),
        aggregateCalls);
  }

  /**
   * Wraps {@code input} in a {@link LogicalProject} that exposes exactly the columns the aggregate
   * needs: the GROUP BY expressions, then the single argument of each aggregate function that has
   * one. Aggregate functions without arguments contribute no column.
   *
   * @throws RuntimeException if an aggregate function has more than one argument
   */
  private RelNode convertAggregateScanInputScanToLogicalProject(
      ResolvedAggregateScan node, RelNode input) {
    // AggregateScan's input is the source of data (e.g. TableScan), which is different from the
    // design of CalciteSQL, in which the LogicalAggregate's input is a LogicalProject, whose input
    // is a LogicalTableScan. When AggregateScan's input is WithRefScan, the WithRefScan is
    // equivalent to a LogicalTableScan. So it's still required to build another LogicalProject as
    // the input of LogicalAggregate.
    List<RexNode> projects = new ArrayList<>();
    List<String> fieldNames = new ArrayList<>();

    // LogicalProject has a list of expressions, which includes any UDF used in the GROUP BY
    // clause of the LogicalAggregate.
    for (ResolvedComputedColumn groupByColumn : node.getGroupByList()) {
      projects.add(convertInputExpr(groupByColumn.getExpr(), node, input));
      fieldNames.add(getTrait().resolveAlias(groupByColumn.getColumn()));
    }

    // LogicalProject should also include columns used by aggregate functions. These columns should
    // follow after GROUP BY columns.
    // TODO: remove duplicate columns in projects.
    for (ResolvedComputedColumn aggregateColumn : node.getAggregateList()) {
      // Should create Calcite's RexInputRef from ResolvedColumn from ResolvedComputedColumn.
      // TODO: handle aggregate function with more than one argument and handle OVER
      // TODO: is there a general way for column reference tracking and deduplication for
      // aggregation?
      ResolvedAggregateFunctionCall aggregateFunctionCall =
          (ResolvedAggregateFunctionCall) aggregateColumn.getExpr();
      List<ResolvedExpr> arguments = aggregateFunctionCall.getArgumentList();
      if (arguments == null || arguments.isEmpty()) {
        // Nothing to project for argument-less aggregates.
        continue;
      }
      if (arguments.size() > 1) {
        throw new RuntimeException(
            aggregateFunctionCall.getFunction().getName() + " has more than one argument.");
      }
      // TODO: assume aggregate function's input is either a ColumnRef or a cast(ColumnRef).
      // TODO: user might use multiple CAST so we need to handle this rare case.
      projects.add(convertInputExpr(arguments.get(0), node, input));
      fieldNames.add(getTrait().resolveAlias(aggregateColumn.getColumn()));
    }

    return LogicalProject.create(input, projects, fieldNames);
  }

  /**
   * Converts {@code expr} to a {@link RexNode}, resolving it against the column list of the
   * aggregate's input scan and the row type fields of the converted {@code input}.
   */
  private RexNode convertInputExpr(ResolvedExpr expr, ResolvedAggregateScan node, RelNode input) {
    return getExpressionConverter()
        .convertRexNodeFromResolvedExpr(
            expr, node.getInputScan().getColumnList(), input.getRowType().getFieldList());
  }

  /**
   * Converts one resolved aggregate column into a Calcite {@link AggregateCall}.
   *
   * @param computedColumn the computed column whose expression is a {@link
   *     ResolvedAggregateFunctionCall}
   * @param columnRefOff index, in the aggregate's input, of the column holding this call's
   *     argument; unused when the call has no arguments
   * @return the aggregate call, named after the resolved alias of the computed column
   * @throws RuntimeException if the call is AVG over INT64, uses DISTINCT, has no Calcite
   *     counterpart in the operator mapping table, or has an argument of an unsupported node kind
   */
  private AggregateCall convertAggCall(ResolvedComputedColumn computedColumn, int columnRefOff) {
    ResolvedAggregateFunctionCall aggregateFunctionCall =
        (ResolvedAggregateFunctionCall) computedColumn.getExpr();

    // Reject AVG(INT64)
    if (aggregateFunctionCall.getFunction().getName().equals("avg")) {
      FunctionSignature signature = aggregateFunctionCall.getSignature();
      if (signature
          .getFunctionArgumentList()
          .get(0)
          .getType()
          .getKind()
          .equals(TypeKind.TYPE_INT64)) {
        throw new RuntimeException(AVG_ILLEGAL_LONG_INPUT_TYPE);
      }
    }

    // Reject aggregation DISTINCT
    if (aggregateFunctionCall.getDistinct()) {
      throw new RuntimeException(
          "Does not support "
              + aggregateFunctionCall.getFunction().getSqlName()
              + " DISTINCT. 'SELECT DISTINCT' syntax could be used to deduplicate before"
              + " aggregation.");
    }

    SqlAggFunction sqlAggFunction =
        (SqlAggFunction)
            SqlStdOperatorMappingTable.ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR.get(
                aggregateFunctionCall.getFunction().getName());
    if (sqlAggFunction == null) {
      throw new RuntimeException(
          "Does not support ZetaSQL aggregate function: "
              + aggregateFunctionCall.getFunction().getName());
    }

    List<Integer> argList = new ArrayList<>();
    for (ResolvedExpr expr : aggregateFunctionCall.getArgumentList()) {
      // Only a column reference, a cast, or a struct field access is accepted as an argument;
      // anything else is rejected.
      // TODO: is there a general way to handle aggregation calls conversion?
      boolean supportedArgument =
          expr.nodeKind() == RESOLVED_CAST
              || expr.nodeKind() == RESOLVED_COLUMN_REF
              || expr.nodeKind() == RESOLVED_GET_STRUCT_FIELD;
      if (!supportedArgument) {
        throw new RuntimeException(
            "Aggregate function only accepts Column Reference or CAST(Column Reference) as its"
                + " input.");
      }
      argList.add(columnRefOff);
    }

    // TODO: there should be a general way to decide if a return type of a aggcall is nullable.
    // For now only ANY_VALUE is given a nullable return type.
    boolean nullable = sqlAggFunction.equals(SqlStdOperatorTable.ANY_VALUE);
    RelDataType returnType =
        toSimpleRelDataType(
            computedColumn.getColumn().getType().getKind(), getCluster().getRexBuilder(), nullable);

    String aggName = getTrait().resolveAlias(computedColumn.getColumn());
    return AggregateCall.create(sqlAggFunction, false, false, argList, -1, returnType, aggName);
  }
}