| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.beam.sdk.extensions.sql.zetasql.translation; |
| |
| import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_CAST; |
| import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_COLUMN_REF; |
| import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_GET_STRUCT_FIELD; |
| import static org.apache.beam.sdk.extensions.sql.zetasql.TypeUtils.toSimpleRelDataType; |
| |
| import com.google.zetasql.FunctionSignature; |
| import com.google.zetasql.ZetaSQLType.TypeKind; |
| import com.google.zetasql.resolvedast.ResolvedNode; |
| import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedAggregateFunctionCall; |
| import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedAggregateScan; |
| import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedComputedColumn; |
| import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedExpr; |
| import java.util.ArrayList; |
| import java.util.Collections; |
| import java.util.List; |
| import java.util.stream.Collectors; |
| import java.util.stream.IntStream; |
| import org.apache.beam.sdk.extensions.sql.zetasql.SqlStdOperatorMappingTable; |
| import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.RelNode; |
| import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.core.AggregateCall; |
| import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.logical.LogicalAggregate; |
| import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.logical.LogicalProject; |
| import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.type.RelDataType; |
| import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexNode; |
| import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.SqlAggFunction; |
| import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.fun.SqlStdOperatorTable; |
| import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.util.ImmutableBitSet; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; |
| |
| /** Converts aggregate calls. */ |
| class AggregateScanConverter extends RelConverter<ResolvedAggregateScan> { |
| private static final String AVG_ILLEGAL_LONG_INPUT_TYPE = |
| "AVG(LONG) is not supported. You might want to use AVG(CAST(expression AS DOUBLE)."; |
| |
| AggregateScanConverter(ConversionContext context) { |
| super(context); |
| } |
| |
| @Override |
| public List<ResolvedNode> getInputs(ResolvedAggregateScan zetaNode) { |
| return Collections.singletonList(zetaNode.getInputScan()); |
| } |
| |
| @Override |
| public RelNode convert(ResolvedAggregateScan zetaNode, List<RelNode> inputs) { |
| RelNode input = convertAggregateScanInputScanToLogicalProject(zetaNode, inputs.get(0)); |
| |
| // Calcite LogicalAggregate's GroupSet is indexes of group fields starting from 0. |
| int groupFieldsListSize = zetaNode.getGroupByList().size(); |
| ImmutableBitSet groupSet; |
| if (groupFieldsListSize != 0) { |
| groupSet = |
| ImmutableBitSet.of( |
| IntStream.rangeClosed(0, groupFieldsListSize - 1) |
| .boxed() |
| .collect(Collectors.toList())); |
| } else { |
| groupSet = ImmutableBitSet.of(); |
| } |
| |
| // TODO: add support for indicator |
| |
| List<AggregateCall> aggregateCalls; |
| if (zetaNode.getAggregateList().isEmpty()) { |
| aggregateCalls = ImmutableList.of(); |
| } else { |
| aggregateCalls = new ArrayList<>(); |
| // For aggregate calls, their input ref follow after GROUP BY input ref. |
| int columnRefoff = groupFieldsListSize; |
| for (ResolvedComputedColumn computedColumn : zetaNode.getAggregateList()) { |
| aggregateCalls.add(convertAggCall(computedColumn, columnRefoff)); |
| columnRefoff++; |
| } |
| } |
| |
| LogicalAggregate logicalAggregate = |
| new LogicalAggregate( |
| getCluster(), |
| input.getTraitSet(), |
| input, |
| false, |
| groupSet, |
| ImmutableList.of(groupSet), |
| aggregateCalls); |
| |
| return logicalAggregate; |
| } |
| |
| private RelNode convertAggregateScanInputScanToLogicalProject( |
| ResolvedAggregateScan node, RelNode input) { |
| // AggregateScan's input is the source of data (e.g. TableScan), which is different from the |
| // design of CalciteSQL, in which the LogicalAggregate's input is a LogicalProject, whose input |
| // is a LogicalTableScan. When AggregateScan's input is WithRefScan, the WithRefScan is |
| // ebullient to a LogicalTableScan. So it's still required to build another LogicalProject as |
| // the input of LogicalAggregate. |
| List<RexNode> projects = new ArrayList<>(); |
| List<String> fieldNames = new ArrayList<>(); |
| |
| // LogicalProject has a list of expr, which including UDF in GROUP BY clause for |
| // LogicalAggregate. |
| for (ResolvedComputedColumn computedColumn : node.getGroupByList()) { |
| projects.add( |
| getExpressionConverter() |
| .convertRexNodeFromResolvedExpr( |
| computedColumn.getExpr(), |
| node.getInputScan().getColumnList(), |
| input.getRowType().getFieldList())); |
| fieldNames.add(getTrait().resolveAlias(computedColumn.getColumn())); |
| } |
| |
| // LogicalProject should also include columns used by aggregate functions. These columns should |
| // follow after GROUP BY columns. |
| // TODO: remove duplicate columns in projects. |
| for (ResolvedComputedColumn resolvedComputedColumn : node.getAggregateList()) { |
| // Should create Calcite's RexInputRef from ResolvedColumn from ResolvedComputedColumn. |
| // TODO: handle aggregate function with more than one argument and handle OVER |
| // TODO: is there is general way for column reference tracking and deduplication for |
| // aggregation? |
| ResolvedAggregateFunctionCall aggregateFunctionCall = |
| ((ResolvedAggregateFunctionCall) resolvedComputedColumn.getExpr()); |
| if (aggregateFunctionCall.getArgumentList() != null |
| && aggregateFunctionCall.getArgumentList().size() == 1) { |
| ResolvedExpr resolvedExpr = aggregateFunctionCall.getArgumentList().get(0); |
| |
| // TODO: assume aggregate function's input is either a ColumnRef or a cast(ColumnRef). |
| // TODO: user might use multiple CAST so we need to handle this rare case. |
| projects.add( |
| getExpressionConverter() |
| .convertRexNodeFromResolvedExpr( |
| resolvedExpr, |
| node.getInputScan().getColumnList(), |
| input.getRowType().getFieldList())); |
| fieldNames.add(getTrait().resolveAlias(resolvedComputedColumn.getColumn())); |
| } else if (aggregateFunctionCall.getArgumentList() != null |
| && aggregateFunctionCall.getArgumentList().size() > 1) { |
| throw new RuntimeException( |
| aggregateFunctionCall.getFunction().getName() + " has more than one argument."); |
| } |
| } |
| |
| return LogicalProject.create(input, projects, fieldNames); |
| } |
| |
| private AggregateCall convertAggCall(ResolvedComputedColumn computedColumn, int columnRefOff) { |
| ResolvedAggregateFunctionCall aggregateFunctionCall = |
| (ResolvedAggregateFunctionCall) computedColumn.getExpr(); |
| |
| // Reject AVG(INT64) |
| if (aggregateFunctionCall.getFunction().getName().equals("avg")) { |
| FunctionSignature signature = aggregateFunctionCall.getSignature(); |
| if (signature |
| .getFunctionArgumentList() |
| .get(0) |
| .getType() |
| .getKind() |
| .equals(TypeKind.TYPE_INT64)) { |
| throw new RuntimeException(AVG_ILLEGAL_LONG_INPUT_TYPE); |
| } |
| } |
| |
| // Reject aggregation DISTINCT |
| if (aggregateFunctionCall.getDistinct()) { |
| throw new RuntimeException( |
| "Does not support " |
| + aggregateFunctionCall.getFunction().getSqlName() |
| + " DISTINCT. 'SELECT DISTINCT' syntax could be used to deduplicate before" |
| + " aggregation."); |
| } |
| |
| SqlAggFunction sqlAggFunction = |
| (SqlAggFunction) |
| SqlStdOperatorMappingTable.ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR.get( |
| aggregateFunctionCall.getFunction().getName()); |
| if (sqlAggFunction == null) { |
| throw new RuntimeException( |
| "Does not support ZetaSQL aggregate function: " |
| + aggregateFunctionCall.getFunction().getName()); |
| } |
| |
| List<Integer> argList = new ArrayList<>(); |
| for (ResolvedExpr expr : |
| ((ResolvedAggregateFunctionCall) computedColumn.getExpr()).getArgumentList()) { |
| // Throw an error if aggregate function's input isn't either a ColumnRef or a cast(ColumnRef). |
| // TODO: is there a general way to handle aggregation calls conversion? |
| if (expr.nodeKind() == RESOLVED_CAST |
| || expr.nodeKind() == RESOLVED_COLUMN_REF |
| || expr.nodeKind() == RESOLVED_GET_STRUCT_FIELD) { |
| argList.add(columnRefOff); |
| } else { |
| throw new RuntimeException( |
| "Aggregate function only accepts Column Reference or CAST(Column Reference) as its" |
| + " input."); |
| } |
| } |
| |
| // TODO: there should be a general way to decide if a return type of a aggcall is nullable. |
| RelDataType returnType; |
| if (sqlAggFunction.equals(SqlStdOperatorTable.ANY_VALUE)) { |
| returnType = |
| toSimpleRelDataType( |
| computedColumn.getColumn().getType().getKind(), getCluster().getRexBuilder(), true); |
| } else { |
| returnType = |
| toSimpleRelDataType( |
| computedColumn.getColumn().getType().getKind(), getCluster().getRexBuilder(), false); |
| } |
| |
| String aggName = getTrait().resolveAlias(computedColumn.getColumn()); |
| return AggregateCall.create(sqlAggFunction, false, false, argList, -1, returnType, aggName); |
| } |
| } |