blob: 652aabd75d20cbe820c26d747dbce426d4f617ab [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.extensions.sql.zetasql.translation;
import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_FUNCTION_CALL;
import static com.google.zetasql.ZetaSQLType.TypeKind.TYPE_BOOL;
import static com.google.zetasql.ZetaSQLType.TypeKind.TYPE_BYTES;
import static com.google.zetasql.ZetaSQLType.TypeKind.TYPE_DOUBLE;
import static com.google.zetasql.ZetaSQLType.TypeKind.TYPE_INT64;
import static com.google.zetasql.ZetaSQLType.TypeKind.TYPE_STRING;
import static com.google.zetasql.ZetaSQLType.TypeKind.TYPE_TIMESTAMP;
import static org.apache.beam.sdk.extensions.sql.zetasql.DateTimeUtils.convertDateValueToDateString;
import static org.apache.beam.sdk.extensions.sql.zetasql.DateTimeUtils.convertTimeValueToTimeString;
import static org.apache.beam.sdk.extensions.sql.zetasql.DateTimeUtils.safeMicrosToMillis;
import static org.apache.beam.sdk.extensions.sql.zetasql.SqlStdOperatorMappingTable.FUNCTION_FAMILY_DATE_ADD;
import static org.apache.beam.sdk.extensions.sql.zetasql.ZetaSQLCastFunctionImpl.ZETASQL_CAST_OP;
import com.google.common.base.Ascii;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.zetasql.ArrayType;
import com.google.zetasql.Type;
import com.google.zetasql.Value;
import com.google.zetasql.ZetaSQLType.TypeKind;
import com.google.zetasql.functions.ZetaSQLDateTime.DateTimestampPart;
import com.google.zetasql.resolvedast.ResolvedColumn;
import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedAggregateScan;
import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedCast;
import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedColumnRef;
import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedComputedColumn;
import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedExpr;
import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedFunctionCall;
import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedGetStructField;
import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedLiteral;
import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedOrderByScan;
import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedParameter;
import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedProjectScan;
import io.grpc.Status;
import java.math.BigDecimal;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.sdk.extensions.sql.zetasql.SqlOperatorRewriter;
import org.apache.beam.sdk.extensions.sql.zetasql.SqlOperators;
import org.apache.beam.sdk.extensions.sql.zetasql.SqlStdOperatorMappingTable;
import org.apache.beam.sdk.extensions.sql.zetasql.TypeUtils;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.avatica.util.ByteString;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.avatica.util.TimeUnit;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.avatica.util.TimeUnitRange;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelOptCluster;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.type.RelDataType;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexBuilder;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexLiteral;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexNode;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.SqlIntervalQualifier;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.SqlOperator;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.type.SqlTypeName;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.util.TimestampString;
/**
* Extracts expressions (function calls, field accesses) from the resolved query nodes and converts
* them to RexNodes.
*/
@Internal
public class ExpressionConverter {
// Function group name compared against getFunction().getGroup() to recognize the pre-defined
// window functions (TUMBLE/HOP/SESSION and their _START/_END companions).
private static final String PRE_DEFINED_WINDOW_FUNCTIONS = "pre_defined_window_functions";
// Constants of pre-defined functions.
// Suffixes appended to a window function name to form its start/end function name.
private static final String WINDOW_START = "_START";
private static final String WINDOW_END = "_END";
private static final String FIXED_WINDOW = "TUMBLE";
private static final String FIXED_WINDOW_START = FIXED_WINDOW + WINDOW_START;
private static final String FIXED_WINDOW_END = FIXED_WINDOW + WINDOW_END;
private static final String SLIDING_WINDOW = "HOP";
private static final String SLIDING_WINDOW_START = SLIDING_WINDOW + WINDOW_START;
private static final String SLIDING_WINDOW_END = SLIDING_WINDOW + WINDOW_END;
private static final String SESSION_WINDOW = "SESSION";
private static final String SESSION_WINDOW_START = SESSION_WINDOW + WINDOW_START;
private static final String SESSION_WINDOW_END = SESSION_WINDOW + WINDOW_END;
// Maps each window start/end function name to the name of the window function it belongs to.
private static final ImmutableMap<String, String> WINDOW_START_END_TO_WINDOW_MAP =
ImmutableMap.<String, String>builder()
.put(FIXED_WINDOW_START, FIXED_WINDOW)
.put(FIXED_WINDOW_END, FIXED_WINDOW)
.put(SLIDING_WINDOW_START, SLIDING_WINDOW)
.put(SLIDING_WINDOW_END, SLIDING_WINDOW)
.put(SESSION_WINDOW_START, SESSION_WINDOW)
.put(SESSION_WINDOW_END, SESSION_WINDOW)
.build();
// Names of all window start/end functions; used to decide whether a projected function call
// needs the index of its window column resolved.
private static final ImmutableSet<String> WINDOW_START_END_FUNCTION_SET =
ImmutableSet.of(
FIXED_WINDOW_START,
FIXED_WINDOW_END,
SLIDING_WINDOW_START,
SLIDING_WINDOW_END,
SESSION_WINDOW_START,
SESSION_WINDOW_END);
// Cast combinations consulted by isCastingSupported, which looks entries up by the cast's target
// type. NOTE(review): the value sets presumably hold the rejected source types -- confirm against
// the body of isCastingSupported.
private static final ImmutableMap<TypeKind, ImmutableSet<TypeKind>> UNSUPPORTED_CASTING =
ImmutableMap.<TypeKind, ImmutableSet<TypeKind>>builder()
.put(TYPE_INT64, ImmutableSet.of(TYPE_DOUBLE))
.put(TYPE_BOOL, ImmutableSet.of(TYPE_STRING))
.put(TYPE_STRING, ImmutableSet.of(TYPE_BOOL, TYPE_DOUBLE))
.build();
// Maps ZetaSQL DateTimestampPart enum numbers to Calcite TimeUnit values. Date parts without an
// entry make get() return null, so callers have to handle a missing mapping. Note that ISOWEEK is
// mapped to TimeUnit.WEEK.
private static final ImmutableMap<Integer, TimeUnit> TIME_UNIT_CASTING_MAP =
ImmutableMap.<Integer, TimeUnit>builder()
.put(DateTimestampPart.YEAR.getNumber(), TimeUnit.YEAR)
.put(DateTimestampPart.MONTH.getNumber(), TimeUnit.MONTH)
.put(DateTimestampPart.DAY.getNumber(), TimeUnit.DAY)
.put(DateTimestampPart.DAYOFWEEK.getNumber(), TimeUnit.DOW)
.put(DateTimestampPart.DAYOFYEAR.getNumber(), TimeUnit.DOY)
.put(DateTimestampPart.QUARTER.getNumber(), TimeUnit.QUARTER)
.put(DateTimestampPart.HOUR.getNumber(), TimeUnit.HOUR)
.put(DateTimestampPart.MINUTE.getNumber(), TimeUnit.MINUTE)
.put(DateTimestampPart.SECOND.getNumber(), TimeUnit.SECOND)
.put(DateTimestampPart.MILLISECOND.getNumber(), TimeUnit.MILLISECOND)
.put(DateTimestampPart.MICROSECOND.getNumber(), TimeUnit.MICROSECOND)
.put(DateTimestampPart.NANOSECOND.getNumber(), TimeUnit.NANOSECOND)
.put(DateTimestampPart.ISOYEAR.getNumber(), TimeUnit.ISOYEAR)
.put(DateTimestampPart.ISOWEEK.getNumber(), TimeUnit.WEEK)
.build();
// Interval date parts whose value createCalciteIntervalRexLiteral converts to milliseconds.
private static final ImmutableSet<String> DATE_PART_UNITS_TO_MILLIS =
ImmutableSet.of("DAY", "HOUR", "MINUTE", "SECOND");
// Interval date parts whose value createCalciteIntervalRexLiteral multiplies by 12.
private static final ImmutableSet<String> DATE_PART_UNITS_TO_MONTHS = ImmutableSet.of("YEAR");
// Millisecond multipliers used by convertIntervalValueToMillis.
private static final long ONE_SECOND_IN_MILLIS = 1000L;
private static final long ONE_MINUTE_IN_MILLIS = 60L * ONE_SECOND_IN_MILLIS;
private static final long ONE_HOUR_IN_MILLIS = 60L * ONE_MINUTE_IN_MILLIS;
private static final long ONE_DAY_IN_MILLIS = 24L * ONE_HOUR_IN_MILLIS;
// Fixed 30-day month; marked unused because nothing in the visible code reads it.
@SuppressWarnings("unused")
private static final long ONE_MONTH_IN_MILLIS = 30L * ONE_DAY_IN_MILLIS;
// Fixed 365-day year; marked unused because nothing in the visible code reads it.
@SuppressWarnings("unused")
private static final long ONE_YEAR_IN_MILLIS = 365L * ONE_DAY_IN_MILLIS;
// Constants of error messages.
// Date parts accepted by convertIntervalDatepartToSqlIntervalQualifier, as shown to the user.
private static final String INTERVAL_DATE_PART_MSG =
"YEAR, QUARTER, MONTH, WEEK, DAY, HOUR, MINUTE, SECOND, MILLISECOND";
// Message thrown when an interval literal is not a string of the form "INTERVAL int64 date_part".
private static final String INTERVAL_FORMAT_MSG =
"INTERVAL should be set as a STRING in the specific format: \"INTERVAL int64 date_part\"."
+ " The date_part includes: "
+ INTERVAL_DATE_PART_MSG;
// Calcite cluster supplied to the constructor.
private final RelOptCluster cluster;
// Query parameter values supplied to the constructor; the map is stored by reference.
private final Map<String, Value> queryParams;
/**
 * Creates a converter bound to the given Calcite cluster and query parameters.
 *
 * @param cluster the Calcite cluster kept by this converter
 * @param params query parameter values keyed by string; the map reference is stored as-is
 */
public ExpressionConverter(RelOptCluster cluster, Map<String, Value> params) {
  this.queryParams = params;
  this.cluster = cluster;
}
/**
 * Builds one RexNode per output column of a project scan.
 *
 * <p>A column that matches a computed column in the scan's expression list is converted from that
 * expression. When the expression is a call to a window start/end function, the index of the
 * matching window column in the input aggregate scan is resolved first. Every other column is
 * treated as a reference to an input column and becomes an input ref.
 *
 * @param node the project scan whose column list is converted
 * @param fieldList fields accessed with the same index as the input scan's column list
 * @return the converted expressions, in the order of {@code node.getColumnList()}
 * @throws RuntimeException if a non-computed column cannot be found among the input scan's
 *     columns
 */
public List<RexNode> retrieveRexNode(ResolvedProjectScan node, List<RelDataTypeField> fieldList) {
  List<RexNode> rexNodes = new ArrayList<>();
  for (ResolvedColumn column : node.getColumnList()) {
    int exprIndex = indexOfResolvedColumnInExprList(node.getExprList(), column);
    if (exprIndex == -1) {
      // Not a computed expression, so the column has to be an input column reference.
      int inputIndex =
          indexOfProjectionColumnRef(column.getId(), node.getInputScan().getColumnList());
      if (inputIndex < 0 || inputIndex >= node.getInputScan().getColumnList().size()) {
        throw new RuntimeException(
            String.format("Cannot find %s in fieldList %s", column, fieldList));
      }
      rexNodes.add(rexBuilder().makeInputRef(fieldList.get(inputIndex).getType(), inputIndex));
      continue;
    }

    ResolvedComputedColumn computedColumn = node.getExprList().get(exprIndex);
    ResolvedExpr computedExpr = computedColumn.getExpr();
    // Stays -1 unless the expression is a window start/end function call.
    int windowFieldIndex = -1;
    if (computedExpr.nodeKind() == RESOLVED_FUNCTION_CALL) {
      String functionName = ((ResolvedFunctionCall) computedExpr).getFunction().getName();
      if (WINDOW_START_END_FUNCTION_SET.contains(functionName)) {
        // Window start/end functions are expected to sit directly on top of an aggregate scan.
        ResolvedAggregateScan aggregateScan = (ResolvedAggregateScan) node.getInputScan();
        windowFieldIndex =
            indexOfWindowField(
                aggregateScan.getGroupByList(),
                aggregateScan.getColumnList(),
                WINDOW_START_END_TO_WINDOW_MAP.get(functionName));
      }
    }
    rexNodes.add(
        convertRexNodeFromComputedColumnWithFieldList(
            computedColumn, node.getInputScan().getColumnList(), fieldList, windowFieldIndex));
  }
  return rexNodes;
}
/**
 * Extract expressions from order by scan node.
 *
 * <p>Each output column of the scan is looked up by id in the input scan's column list and turned
 * into an input ref of the field at the same index.
 *
 * @param cluster the cluster whose RexBuilder creates the input refs (shadows the field)
 * @param node the order-by scan whose column list is converted
 * @param fieldList fields accessed with the same index as the input scan's column list
 * @return one input ref per column of {@code node}, in column order
 * @throws RuntimeException if a column is not found in the input scan's column list or its index
 *     falls outside {@code fieldList}
 */
public List<RexNode> retrieveRexNodeFromOrderByScan(
    RelOptCluster cluster, ResolvedOrderByScan node, List<RelDataTypeField> fieldList) {
  final RexBuilder rexBuilder = cluster.getRexBuilder();
  List<RexNode> ret = new ArrayList<>();
  for (ResolvedColumn column : node.getColumnList()) {
    int index = indexOfProjectionColumnRef(column.getId(), node.getInputScan().getColumnList());
    // Report a missing column explicitly instead of letting fieldList.get(-1) fail with a bare
    // IndexOutOfBoundsException; this mirrors the check done in retrieveRexNode.
    if (index < 0 || index >= fieldList.size()) {
      throw new RuntimeException(
          String.format("Cannot find %s in fieldList %s", column, fieldList));
    }
    ret.add(rexBuilder.makeInputRef(fieldList.get(index).getType(), index));
  }
  return ret;
}
/**
 * Finds the position of the computed column that defines {@code column}.
 *
 * @param exprList computed columns to search; may be null or empty
 * @param column the column to look for
 * @return the index of the first computed column whose column equals {@code column}, or -1
 */
private static int indexOfResolvedColumnInExprList(
    ImmutableList<ResolvedComputedColumn> exprList, ResolvedColumn column) {
  if (exprList == null) {
    return -1;
  }
  int position = 0;
  for (ResolvedComputedColumn candidate : exprList) {
    if (candidate.getColumn().equals(column)) {
      return position;
    }
    position++;
  }
  // Also reached for an empty list.
  return -1;
}
/**
 * Locates the column produced by the first group-by call to the window function {@code windowFn}.
 *
 * @param groupByList group-by computed columns of the aggregate scan
 * @param columnList column list in which the window column's index is looked up
 * @param windowFn name of the window function to find
 * @return the index in {@code columnList} of the column computed by the first matching call
 * @throws RuntimeException if no group-by entry calls {@code windowFn}, or if the column of the
 *     first matching entry is not present in {@code columnList}
 */
private static int indexOfWindowField(
    List<ResolvedComputedColumn> groupByList, List<ResolvedColumn> columnList, String windowFn) {
  for (ResolvedComputedColumn groupByColumn : groupByList) {
    ResolvedExpr groupByExpr = groupByColumn.getExpr();
    if (groupByExpr.nodeKind() != RESOLVED_FUNCTION_CALL) {
      continue;
    }
    String calledFunction = ((ResolvedFunctionCall) groupByExpr).getFunction().getName();
    if (!calledFunction.equals(windowFn)) {
      continue;
    }
    int position = indexOfResolvedColumnInColumnList(columnList, groupByColumn.getColumn());
    if (position != -1) {
      return position;
    }
    // Only the first matching call is considered; fall through to the failure below.
    break;
  }
  throw new RuntimeException("Cannot find " + windowFn + " in " + groupByList);
}
/**
 * Finds the position of {@code column} in {@code columnList}.
 *
 * @param columnList columns to search; may be null or empty
 * @param column the column to look for
 * @return the index of the first element equal to {@code column}, or -1 if there is none
 */
private static int indexOfResolvedColumnInColumnList(
    List<ResolvedColumn> columnList, ResolvedColumn column) {
  if (columnList == null) {
    return -1;
  }
  int position = 0;
  for (ResolvedColumn candidate : columnList) {
    if (candidate.equals(column)) {
      return position;
    }
    position++;
  }
  // Also reached for an empty list.
  return -1;
}
/**
 * Converts a resolved expression to a RexNode, resolving column references against the given
 * column and field lists.
 *
 * <p>If either list is null, or the node kind is not one handled here, the conversion is
 * delegated to {@link #convertRexNodeFromResolvedExpr(ResolvedExpr)}.
 *
 * @param expr the resolved expression to convert
 * @param columnList columns that column references are looked up in; may be null
 * @param fieldList fields accessed with the same index as {@code columnList}; may be null
 * @return the converted expression
 */
public RexNode convertRexNodeFromResolvedExpr(
    ResolvedExpr expr, List<ResolvedColumn> columnList, List<RelDataTypeField> fieldList) {
  if (columnList == null || fieldList == null) {
    return convertRexNodeFromResolvedExpr(expr);
  }
  switch (expr.nodeKind()) {
    case RESOLVED_LITERAL:
      return convertResolvedLiteral((ResolvedLiteral) expr);
    case RESOLVED_COLUMN_REF:
      return convertResolvedColumnRef((ResolvedColumnRef) expr, columnList, fieldList);
    case RESOLVED_FUNCTION_CALL:
      return convertResolvedFunctionCall((ResolvedFunctionCall) expr, columnList, fieldList);
    case RESOLVED_CAST:
      return convertResolvedCast((ResolvedCast) expr, columnList, fieldList);
    case RESOLVED_PARAMETER:
      return convertResolvedParameter((ResolvedParameter) expr);
    case RESOLVED_GET_STRUCT_FIELD:
      return convertResolvedStructFieldAccess(
          (ResolvedGetStructField) expr, columnList, fieldList);
    default:
      // Everything else is handled (or rejected) by the list-free overload.
      return convertRexNodeFromResolvedExpr(expr);
  }
}
/**
 * Converts a resolved expression to a RexNode without any column/field list context.
 *
 * @param expr the resolved expression to convert
 * @return the converted expression
 * @throws IllegalArgumentException if {@code expr} is a sub-query expression
 * @throws RuntimeException if the node kind of {@code expr} is not supported
 */
public RexNode convertRexNodeFromResolvedExpr(ResolvedExpr expr) {
  switch (expr.nodeKind()) {
    case RESOLVED_LITERAL:
      return convertResolvedLiteral((ResolvedLiteral) expr);
    case RESOLVED_COLUMN_REF:
      return convertResolvedColumnRef((ResolvedColumnRef) expr);
    case RESOLVED_FUNCTION_CALL:
      // TODO: is there a better way to shared code for different cases of
      // convertResolvedFunctionCall than passing into two nulls?
      return convertResolvedFunctionCall((ResolvedFunctionCall) expr, null, null);
    case RESOLVED_CAST:
      return convertResolvedCast((ResolvedCast) expr, null, null);
    case RESOLVED_PARAMETER:
      return convertResolvedParameter((ResolvedParameter) expr);
    case RESOLVED_GET_STRUCT_FIELD:
      return convertResolvedStructFieldAccess((ResolvedGetStructField) expr);
    case RESOLVED_SUBQUERY_EXPR:
      throw new IllegalArgumentException("Does not support sub-queries");
    default:
      throw new RuntimeException("Does not support expr node kind " + expr.nodeKind());
  }
}
/**
 * Extract the RexNode from expression with ref scan.
 *
 * <p>Literals are converted directly. Column references are looked up on the left side first and
 * on the right side only if the left lookup returned null. Function calls and casts convert their
 * operands recursively with the same left/right context.
 *
 * @param expr the resolved expression to convert
 * @param refScanLeftColumnList ref-scan columns of the left input
 * @param leftFieldList fields of the left input
 * @param originalLeftColumnList original columns of the left input
 * @param refScanRightColumnList ref-scan columns of the right input
 * @param rightFieldList fields of the right input
 * @param originalRightColumnList original columns of the right input
 * @return the converted expression
 * @throws RuntimeException if the node kind is unsupported, or if a function call has no mapped
 *     Calcite operator
 */
public RexNode convertRexNodeFromResolvedExprWithRefScan(
    ResolvedExpr expr,
    List<ResolvedColumn> refScanLeftColumnList,
    List<RelDataTypeField> leftFieldList,
    List<ResolvedColumn> originalLeftColumnList,
    List<ResolvedColumn> refScanRightColumnList,
    List<RelDataTypeField> rightFieldList,
    List<ResolvedColumn> originalRightColumnList) {
  RexNode ret;
  switch (expr.nodeKind()) {
    case RESOLVED_LITERAL:
      ret = convertResolvedLiteral((ResolvedLiteral) expr);
      break;
    case RESOLVED_COLUMN_REF:
      ResolvedColumnRef columnRef = (ResolvedColumnRef) expr;
      // first look for column ref on the left side
      ret =
          convertRexNodeFromResolvedColumnRefWithRefScan(
              columnRef, refScanLeftColumnList, originalLeftColumnList, leftFieldList);
      // if not found there look on the right
      // NOTE(review): the result may still be null if the right-side lookup also finds nothing;
      // confirm how callers deal with that.
      if (ret == null) {
        ret =
            convertRexNodeFromResolvedColumnRefWithRefScan(
                columnRef, refScanRightColumnList, originalRightColumnList, rightFieldList);
      }
      break;
    case RESOLVED_FUNCTION_CALL:
      // JOIN only support equal join.
      ResolvedFunctionCall resolvedFunctionCall = (ResolvedFunctionCall) expr;
      List<RexNode> operands = new ArrayList<>();
      for (ResolvedExpr resolvedExpr : resolvedFunctionCall.getArgumentList()) {
        operands.add(
            convertRexNodeFromResolvedExprWithRefScan(
                resolvedExpr,
                refScanLeftColumnList,
                leftFieldList,
                originalLeftColumnList,
                refScanRightColumnList,
                rightFieldList,
                originalRightColumnList));
      }
      String functionName = resolvedFunctionCall.getFunction().getName();
      SqlOperator op =
          SqlStdOperatorMappingTable.ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR.get(functionName);
      // An unmapped function would otherwise hand a null operator to makeCall; fail with the
      // same message convertResolvedFunctionCall uses for this situation.
      if (op == null) {
        throw new RuntimeException("Does not support ZetaSQL function: " + functionName);
      }
      ret = rexBuilder().makeCall(op, operands);
      break;
    case RESOLVED_CAST:
      ResolvedCast resolvedCast = (ResolvedCast) expr;
      RexNode operand =
          convertRexNodeFromResolvedExprWithRefScan(
              resolvedCast.getExpr(),
              refScanLeftColumnList,
              leftFieldList,
              originalLeftColumnList,
              refScanRightColumnList,
              rightFieldList,
              originalRightColumnList);
      TypeKind fromType = resolvedCast.getExpr().getType().getKind();
      TypeKind toType = resolvedCast.getType().getKind();
      // Result is not used; NOTE(review): presumably throws for an unsupported combination.
      isCastingSupported(fromType, toType);
      // The output type takes its nullability from the converted operand.
      RelDataType outputType =
          TypeUtils.toSimpleRelDataType(toType, rexBuilder(), operand.getType().isNullable());
      if (isZetaSQLCast(fromType, toType)) {
        ret = rexBuilder().makeCall(outputType, ZETASQL_CAST_OP, ImmutableList.of(operand));
      } else {
        ret = rexBuilder().makeCast(outputType, operand);
      }
      break;
    default:
      throw new RuntimeException("Does not support expr node kind " + expr.nodeKind());
  }
  return ret;
}
/**
 * Converts a computed column of a projection, with special handling for window start/end
 * functions.
 *
 * <p>Expressions that are not function calls, and function calls outside the pre-defined window
 * function group, go through the regular expression conversion. Window start functions (and the
 * session window end) become an input ref to the window column; fixed and sliding window end
 * functions become {@code windowColumn + interval}.
 *
 * @param column the computed column to convert
 * @param columnList columns that column references are looked up in
 * @param fieldList fields accessed by {@code windowFieldIndex} and by column reference lookups
 * @param windowFieldIndex index of the window column in {@code fieldList}; only read for window
 *     start/end functions
 * @return the converted expression
 * @throws RuntimeException if a window function itself (TUMBLE/HOP/SESSION) appears, or if the
 *     window start/end function is not supported
 */
private RexNode convertRexNodeFromComputedColumnWithFieldList(
    ResolvedComputedColumn column,
    List<ResolvedColumn> columnList,
    List<RelDataTypeField> fieldList,
    int windowFieldIndex) {
  ResolvedExpr expr = column.getExpr();
  if (expr.nodeKind() != RESOLVED_FUNCTION_CALL) {
    return convertRexNodeFromResolvedExpr(expr, columnList, fieldList);
  }

  ResolvedFunctionCall functionCall = (ResolvedFunctionCall) expr;
  String functionName = functionCall.getFunction().getName();
  // TODO: is there any other illegal case?
  boolean isWindowFunctionItself =
      functionName.equals(FIXED_WINDOW)
          || functionName.equals(SLIDING_WINDOW)
          || functionName.equals(SESSION_WINDOW);
  if (isWindowFunctionItself) {
    throw new RuntimeException(functionName + " shouldn't appear in SELECT exprlist.");
  }

  if (!functionCall.getFunction().getGroup().equals(PRE_DEFINED_WINDOW_FUNCTIONS)) {
    // non-window function should still go through normal FunctionCall conversion process.
    return convertRexNodeFromResolvedExpr(expr, columnList, fieldList);
  }

  // ONLY window_start and window_end should arrive here.
  // TODO: Have extra verification here to make sure window start/end functions have the same
  // parameter with window function.
  // Position of the interval argument that is added to the window column for *_END functions.
  int intervalArgIndex;
  switch (functionName) {
    case FIXED_WINDOW_START:
    case SLIDING_WINDOW_START:
    case SESSION_WINDOW_START:
      // TODO: in Calcite implementation, session window's start is equal to end. Need to fix it
      // in Calcite.
    case SESSION_WINDOW_END:
      return rexBuilder()
          .makeInputRef(fieldList.get(windowFieldIndex).getType(), windowFieldIndex);
    case FIXED_WINDOW_END:
      intervalArgIndex = 0;
      break;
    case SLIDING_WINDOW_END:
      intervalArgIndex = 1;
      break;
    default:
      throw new RuntimeException("Does not support window start/end: " + functionName);
  }

  // WINDOW END is a function call: window column PLUS the interval literal.
  // TODO: check window_end 's duration is the same as it's aggregate window.
  RexNode windowColumnRef =
      rexBuilder().makeInputRef(fieldList.get(windowFieldIndex).getType(), windowFieldIndex);
  RexNode interval =
      convertIntervalToRexIntervalLiteral(
          (ResolvedLiteral) functionCall.getArgumentList().get(intervalArgIndex));
  List<RexNode> operands = new ArrayList<>();
  operands.add(windowColumnRef);
  operands.add(interval);
  return rexBuilder().makeCall(SqlStdOperatorTable.PLUS, operands);
}
/**
 * Converts a resolved literal to a RexNode.
 *
 * <p>Literals of the listed type kinds are handed to {@code convertValueToRexNode}; note that
 * method still rejects kinds it has no conversion for (its default branch throws).
 *
 * @param resolvedLiteral the literal to convert
 * @return the converted literal
 * @throws RuntimeException if the literal's type kind is not in the accepted list
 */
public RexNode convertResolvedLiteral(ResolvedLiteral resolvedLiteral) {
  final Type literalType = resolvedLiteral.getType();
  final TypeKind kind = literalType.getKind();
  switch (kind) {
    case TYPE_BOOL:
    case TYPE_INT32:
    case TYPE_INT64:
    case TYPE_FLOAT:
    case TYPE_DOUBLE:
    case TYPE_STRING:
    case TYPE_TIMESTAMP:
    case TYPE_DATE:
    case TYPE_TIME:
      // TYPE_DATETIME is not handled here.
    case TYPE_BYTES:
    case TYPE_ARRAY:
    case TYPE_STRUCT:
    case TYPE_ENUM:
      return convertValueToRexNode(literalType, resolvedLiteral.getValue());
    default:
      throw new RuntimeException(
          MessageFormat.format(
              "Unsupported ResolvedLiteral type: {0}, kind: {1}, value: {2}, class: {3}",
              literalType.typeName(),
              kind,
              resolvedLiteral.getValue(),
              resolvedLiteral.getClass()));
  }
}
/**
 * Converts a ZetaSQL value of the given type to a RexNode by dispatching on the type kind:
 * simple kinds, arrays and enums each have their own conversion.
 *
 * @param type the ZetaSQL type of {@code value}
 * @param value the value to convert
 * @return the converted value
 * @throws RuntimeException if the type kind has no conversion
 */
private RexNode convertValueToRexNode(Type type, Value value) {
  final TypeKind kind = type.getKind();
  switch (kind) {
    case TYPE_BOOL:
    case TYPE_INT32:
    case TYPE_INT64:
    case TYPE_FLOAT:
    case TYPE_DOUBLE:
    case TYPE_STRING:
    case TYPE_TIMESTAMP:
    case TYPE_DATE:
    case TYPE_TIME:
      // TYPE_DATETIME is not handled here.
    case TYPE_BYTES:
      return convertSimpleValueToRexNode(kind, value);
    case TYPE_ARRAY:
      return convertArrayValueToRexNode(type.asArray(), value);
    case TYPE_ENUM:
      return convertEnumToRexNode(type, value);
    default:
      // TODO: convert struct literal.
      throw new RuntimeException(
          "Unsupported ResolvedLiteral kind: " + type.getKind() + " type: " + type.typeName());
  }
}
/**
 * Converts a non-composite ZetaSQL value to a Calcite literal.
 *
 * <p>A null value becomes a null literal of the matching simple type. Otherwise each kind maps to
 * the corresponding RexBuilder literal factory.
 *
 * @param kind the ZetaSQL type kind of {@code value}
 * @param value the value to convert
 * @return the literal RexNode
 * @throws RuntimeException if {@code kind} has no conversion
 */
private RexNode convertSimpleValueToRexNode(TypeKind kind, Value value) {
  if (value.isNull()) {
    return rexBuilder().makeNullLiteral(TypeUtils.toSimpleRelDataType(kind, rexBuilder()));
  }
  switch (kind) {
    case TYPE_BOOL:
      return rexBuilder().makeLiteral(value.getBoolValue());
    case TYPE_INT32:
      return rexBuilder()
          .makeExactLiteral(
              new BigDecimal(value.getInt32Value()),
              TypeUtils.toSimpleRelDataType(kind, rexBuilder()));
    case TYPE_INT64:
      return rexBuilder()
          .makeExactLiteral(
              new BigDecimal(value.getInt64Value()),
              TypeUtils.toSimpleRelDataType(kind, rexBuilder()));
    case TYPE_FLOAT:
      return rexBuilder()
          .makeApproxLiteral(
              new BigDecimal(value.getFloatValue()),
              TypeUtils.toSimpleRelDataType(kind, rexBuilder()));
    case TYPE_DOUBLE:
      return rexBuilder()
          .makeApproxLiteral(
              new BigDecimal(value.getDoubleValue()),
              TypeUtils.toSimpleRelDataType(kind, rexBuilder()));
    case TYPE_STRING:
      {
        // has to allow CAST because Calcite create CHAR type first and does a CAST to VARCHAR.
        // If not allow cast, rexBuilder() will only build a literal with CHAR type.
        boolean allowCast = true;
        return rexBuilder()
            .makeLiteral(
                value.getStringValue(),
                typeFactory().createSqlType(SqlTypeName.VARCHAR),
                allowCast);
      }
    case TYPE_TIMESTAMP:
      // The value is carried in microseconds and converted to milliseconds for Calcite.
      return rexBuilder()
          .makeTimestampLiteral(
              TimestampString.fromMillisSinceEpoch(
                  safeMicrosToMillis(value.getTimestampUnixMicros())),
              typeFactory().getTypeSystem().getMaxPrecision(SqlTypeName.TIMESTAMP));
    case TYPE_DATE:
      return rexBuilder().makeDateLiteral(convertDateValueToDateString(value));
    case TYPE_TIME:
      {
        int maxTimePrecision = typeFactory().getTypeSystem().getMaxPrecision(SqlTypeName.TIME);
        RelDataType timeType = typeFactory().createSqlType(SqlTypeName.TIME, maxTimePrecision);
        // TODO: Doing micro to mills truncation, need to throw exception.
        return rexBuilder().makeLiteral(convertTimeValueToTimeString(value), timeType, false);
      }
    case TYPE_BYTES:
      return rexBuilder().makeBinaryLiteral(new ByteString(value.getBytesValue().toByteArray()));
    default:
      throw new RuntimeException("Unsupported column type: " + kind);
  }
}
/**
 * Converts a ZetaSQL array value to a Calcite array constructor call, or to a null literal of the
 * array type when the value is null.
 *
 * @param arrayType the ZetaSQL array type of {@code value}
 * @param value the array value to convert
 * @return an ARRAY_VALUE_CONSTRUCTOR call over the converted elements, or a null literal
 */
private RexNode convertArrayValueToRexNode(ArrayType arrayType, Value value) {
  if (!value.isNull()) {
    List<RexNode> elements = new ArrayList<>();
    for (Value element : value.getElementList()) {
      elements.add(convertValueToRexNode(arrayType.getElementType(), element));
    }
    return rexBuilder().makeCall(SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR, elements);
  }
  // TODO: should the nullable be false for a array?
  RelDataType nullArrayType = TypeUtils.toArrayRelDataType(rexBuilder(), arrayType, false);
  return rexBuilder().makeNullLiteral(nullArrayType);
}
/**
 * Converts an enum value to a RexNode. Only the {@code zetasql.functions.DateTimestampPart} enum
 * is supported.
 *
 * @param type the enum type of {@code value}
 * @param value the enum value to convert
 * @return the converted enum value
 * @throws RuntimeException if the enum type is not supported
 */
private RexNode convertEnumToRexNode(Type type, Value value) {
  boolean isDateTimestampPart = type.typeName().equals("`zetasql.functions.DateTimestampPart`");
  if (!isDateTimestampPart) {
    throw new RuntimeException(
        MessageFormat.format(
            "Unsupported enum. Kind: {0} Type: {1}", type.getKind(), type.typeName()));
  }
  return convertTimeUnitRangeEnumToRexNode(type, value);
}
/**
 * Converts a DateTimestampPart enum value to a Calcite flag holding the matching TimeUnitRange.
 *
 * @param type the enum type, used only for the error message
 * @param value the enum value; its number is looked up in {@code TIME_UNIT_CASTING_MAP}
 * @return a flag RexNode wrapping the TimeUnitRange of the mapped time unit
 * @throws RuntimeException if the enum number has no mapped time unit
 */
private RexNode convertTimeUnitRangeEnumToRexNode(Type type, Value value) {
  int enumNumber = value.getEnumValue();
  TimeUnit mappedUnit = TIME_UNIT_CASTING_MAP.get(enumNumber);
  if (mappedUnit == null) {
    // The arguments follow the placeholder labels: {2} is the numeric value, {3} is the name.
    // (They were previously passed in the opposite order, mislabeling the message.)
    throw new RuntimeException(
        MessageFormat.format(
            "Unsupported enum value. Kind: {0} Type: {1} Value: {2} EnumName: {3}",
            type.getKind(), type.typeName(), enumNumber, value.getEnumName()));
  }
  TimeUnitRange mappedRange = TimeUnitRange.of(mappedUnit, null);
  return rexBuilder().makeFlag(mappedRange);
}
/**
 * Converts a column reference to an input ref by locating the referenced column's id in
 * {@code columnList} and using the field at the same index.
 *
 * @param columnRef the column reference to convert
 * @param columnList columns searched by id
 * @param fieldList fields accessed with the index found in {@code columnList}
 * @return an input ref with the field's type and the found index
 * @throws RuntimeException if the referenced column is not in {@code columnList}
 */
private RexNode convertResolvedColumnRef(
    ResolvedColumnRef columnRef,
    List<ResolvedColumn> columnList,
    List<RelDataTypeField> fieldList) {
  final ResolvedColumn referenced = columnRef.getColumn();
  final int position = indexOfProjectionColumnRef(referenced.getId(), columnList);
  boolean found = position >= 0 && position < columnList.size();
  if (!found) {
    throw new RuntimeException(
        String.format("Cannot find %s in fieldList %s", columnRef.getColumn(), fieldList));
  }
  RelDataType fieldType = fieldList.get(position).getType();
  return rexBuilder().makeInputRef(fieldType, position);
}
/**
 * Converts a column reference to an input ref without a column list, using the column id minus
 * one as the input index and a non-nullable type.
 *
 * @param columnRef the column reference to convert
 * @return an input ref at index {@code id - 1}
 */
private RexNode convertResolvedColumnRef(ResolvedColumnRef columnRef) {
  // TODO: id - 1 might be only correct if the columns read from TableScan.
  // What if the columns come from other scans (which means their id are not indexed from 0),
  // and what if there are some mis-order?
  // TODO: can join key be NULL?
  RexBuilder builder = rexBuilder();
  RelDataType nonNullableType = TypeUtils.toRelDataType(rexBuilder(), columnRef.getType(), false);
  // The id is narrowed to int first, then decremented.
  int inputIndex = (int) columnRef.getColumn().getId() - 1;
  return builder.makeInputRef(nonNullableType, inputIndex);
}
/**
 * Returns the index of the first column in {@code columnList} whose id equals {@code colId}.
 *
 * @param colId the column id to look for
 * @param columnList the columns to search
 * @return the index of the first matching column, or -1 if none matches
 */
public int indexOfProjectionColumnRef(long colId, List<ResolvedColumn> columnList) {
  int position = 0;
  for (ResolvedColumn candidate : columnList) {
    if (candidate.getId() == colId) {
      return position;
    }
    position++;
  }
  return -1;
}
/**
 * Converts a resolved function call to a RexNode.
 *
 * <p>Calls in the pre-defined window function group (TUMBLE, HOP, SESSION) take a timestamp
 * expression followed by one or two interval literals. Calls in the "ZetaSQL" group are mapped
 * through the operator mapping table; the date-add family is delegated to
 * {@code convertTimestampAddFunction}. If a rewriter is registered for the function name it
 * builds the result, otherwise a plain call to the mapped operator is made.
 *
 * @param functionCall the function call to convert
 * @param columnList columns that column references are looked up in; may be null
 * @param fieldList fields matching {@code columnList}; may be null
 * @return the converted call
 * @throws RuntimeException if the function group, window function or ZetaSQL function is not
 *     supported
 */
private RexNode convertResolvedFunctionCall(
    ResolvedFunctionCall functionCall,
    List<ResolvedColumn> columnList,
    List<RelDataTypeField> fieldList) {
  final String functionGroup = functionCall.getFunction().getGroup();
  final String functionName = functionCall.getFunction().getName();
  final List<RexNode> operands = new ArrayList<>();
  final SqlOperator op;

  if (functionGroup.equals(PRE_DEFINED_WINDOW_FUNCTIONS)) {
    // Number of interval-literal arguments that follow the timestamp argument:
    // TUMBLE/SESSION carry a size or gap, HOP carries emit frequency and size.
    final int intervalArgCount;
    switch (functionName) {
      case FIXED_WINDOW:
      case SESSION_WINDOW:
        intervalArgCount = 1;
        break;
      case SLIDING_WINDOW:
        intervalArgCount = 2;
        break;
      default:
        throw new RuntimeException("Only support TUMBLE, HOP AND SESSION functions right now.");
    }
    op = SqlStdOperatorMappingTable.ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR.get(functionName);
    // TODO: check size and type of window function argument list.
    // Add ts column reference to operands.
    operands.add(
        convertRexNodeFromResolvedExpr(
            functionCall.getArgumentList().get(0), columnList, fieldList));
    for (int argIndex = 1; argIndex <= intervalArgCount; argIndex++) {
      operands.add(
          convertIntervalToRexIntervalLiteral(
              (ResolvedLiteral) functionCall.getArgumentList().get(argIndex)));
    }
  } else if (functionGroup.equals("ZetaSQL")) {
    op = SqlStdOperatorMappingTable.ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR.get(functionName);
    if (op == null) {
      throw new RuntimeException("Does not support ZetaSQL function: " + functionName);
    }
    // There are different processes to handle argument conversion because INTERVAL is not a
    // type in ZetaSQL.
    if (FUNCTION_FAMILY_DATE_ADD.contains(functionName)) {
      return convertTimestampAddFunction(functionCall, columnList, fieldList);
    }
    for (ResolvedExpr argument : functionCall.getArgumentList()) {
      operands.add(convertRexNodeFromResolvedExpr(argument, columnList, fieldList));
    }
  } else {
    throw new RuntimeException("Does not support function group: " + functionGroup);
  }

  SqlOperatorRewriter rewriter =
      SqlStdOperatorMappingTable.ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR_REWRITER.get(
          functionName);
  if (rewriter == null) {
    return rexBuilder().makeCall(op, operands);
  }
  return rewriter.apply(rexBuilder(), operands);
}
/**
 * Converts a date-add family call into {@code VALIDATE_TIMESTAMP(timestamp + interval)}.
 *
 * <p>Argument 0 is the timestamp expression, argument 1 the interval amount and argument 2 a
 * literal date part enum. The interval is built as {@code unitInterval * VALIDATE_TIME_INTERVAL(
 * amount, unit)}.
 *
 * @param functionCall the call to convert
 * @param columnList columns that column references are looked up in; may be null
 * @param fieldList fields matching {@code columnList}; may be null
 * @return the validated timestamp-plus-interval expression
 * @throws RuntimeException (gRPC UNIMPLEMENTED status) if the date part has no mapped time unit
 *     or is MICROSECOND/NANOSECOND
 */
private RexNode convertTimestampAddFunction(
    ResolvedFunctionCall functionCall,
    List<ResolvedColumn> columnList,
    List<RelDataTypeField> fieldList) {
  int datePartNumber =
      ((ResolvedLiteral) functionCall.getArgumentList().get(2)).getValue().getEnumValue();
  TimeUnit unit = TIME_UNIT_CASTING_MAP.get(datePartNumber);
  // Date parts without a mapping would otherwise surface as a NullPointerException when the
  // unit is dereferenced below; fail with a descriptive error instead.
  if (unit == null) {
    throw Status.UNIMPLEMENTED
        .withDescription(
            "Date part with enum value " + datePartNumber + " is not supported by Beam ZetaSQL")
        .asRuntimeException();
  }
  if ((unit == TimeUnit.MICROSECOND) || (unit == TimeUnit.NANOSECOND)) {
    throw Status.UNIMPLEMENTED
        .withDescription("Micro and Nanoseconds are not supported by Beam ZetaSQL")
        .asRuntimeException();
  }
  SqlIntervalQualifier qualifier = new SqlIntervalQualifier(unit, null, SqlParserPos.ZERO);
  RexNode intervalArgumentNode =
      convertRexNodeFromResolvedExpr(
          functionCall.getArgumentList().get(1), columnList, fieldList);
  RexNode validatedIntervalArgument =
      rexBuilder()
          .makeCall(
              SqlOperators.VALIDATE_TIME_INTERVAL,
              intervalArgumentNode,
              rexBuilder().makeFlag(unit));
  // One unit expressed as an interval literal, scaled by the validated amount.
  RexNode intervalNode =
      rexBuilder()
          .makeCall(
              SqlStdOperatorTable.MULTIPLY,
              rexBuilder().makeIntervalLiteral(unit.multiplier, qualifier),
              validatedIntervalArgument);
  RexNode timestampNode =
      convertRexNodeFromResolvedExpr(
          functionCall.getArgumentList().get(0), columnList, fieldList);
  RexNode dateTimePlusResult =
      rexBuilder().makeCall(SqlStdOperatorTable.DATETIME_PLUS, timestampNode, intervalNode);
  return rexBuilder().makeCall(SqlOperators.VALIDATE_TIMESTAMP, dateTimePlusResult);
}
/**
 * Parses a string literal of the form {@code "INTERVAL int64 date_part"} into a Calcite interval
 * literal. Tokens are separated by one or more space characters; the keyword and date part are
 * matched case-insensitively (ASCII).
 *
 * @param resolvedLiteral the literal holding the interval string
 * @return the interval literal
 * @throws IllegalArgumentException if the literal is not a string, does not have exactly three
 *     tokens, does not start with INTERVAL, or its amount is not a valid long
 */
private RexNode convertIntervalToRexIntervalLiteral(ResolvedLiteral resolvedLiteral) {
  if (resolvedLiteral.getType().getKind() != TYPE_STRING) {
    throw new IllegalArgumentException(INTERVAL_FORMAT_MSG);
  }

  // Split on single spaces and drop the empty pieces produced by repeated spaces.
  List<String> tokens = new ArrayList<>();
  for (String piece : resolvedLiteral.getValue().getStringValue().split(" ")) {
    if (!piece.isEmpty()) {
      tokens.add(piece);
    }
  }

  if (tokens.size() != 3 || !Ascii.toUpperCase(tokens.get(0)).equals("INTERVAL")) {
    throw new IllegalArgumentException(INTERVAL_FORMAT_MSG);
  }

  long amount;
  try {
    amount = Long.parseLong(tokens.get(1));
  } catch (NumberFormatException e) {
    throw new IllegalArgumentException(INTERVAL_FORMAT_MSG, e);
  }

  return createCalciteIntervalRexLiteral(amount, Ascii.toUpperCase(tokens.get(2)));
}
/**
 * Creates a Calcite interval literal for an amount of the given date part.
 *
 * <p>DAY/HOUR/MINUTE/SECOND amounts are converted to milliseconds, YEAR amounts are multiplied by
 * 12, and every other accepted date part passes its amount through unchanged.
 *
 * @param intervalValue the interval amount
 * @param intervalTimeUnit the upper-case date part name
 * @return the interval literal
 * @throws IllegalArgumentException if the YEAR amount multiplied by 12 overflows a long
 * @throws RuntimeException if the date part is not recognized
 */
private RexLiteral createCalciteIntervalRexLiteral(long intervalValue, String intervalTimeUnit) {
  SqlIntervalQualifier sqlIntervalQualifier =
      convertIntervalDatepartToSqlIntervalQualifier(intervalTimeUnit);
  BigDecimal decimalValue;
  if (DATE_PART_UNITS_TO_MILLIS.contains(intervalTimeUnit)) {
    decimalValue = convertIntervalValueToMillis(sqlIntervalQualifier, intervalValue);
  } else if (DATE_PART_UNITS_TO_MONTHS.contains(intervalTimeUnit)) {
    // Plain long multiplication wraps around silently; reject amounts that do not fit.
    long months;
    try {
      months = Math.multiplyExact(intervalValue, 12L);
    } catch (ArithmeticException e) {
      throw new IllegalArgumentException(
          "INTERVAL value is too large: " + intervalValue + " " + intervalTimeUnit, e);
    }
    decimalValue = new BigDecimal(months);
  } else {
    decimalValue = new BigDecimal(intervalValue);
  }
  return rexBuilder().makeIntervalLiteral(decimalValue, sqlIntervalQualifier);
}
/**
 * Converts an interval amount to milliseconds according to the qualifier's interval type.
 *
 * @param qualifier the interval qualifier; only DAY, HOUR, MINUTE and SECOND types are supported
 * @param value the interval amount in units of the qualifier
 * @return the amount in milliseconds
 * @throws IllegalArgumentException if the qualifier's type is not supported, or if the result
 *     overflows a long
 */
private static BigDecimal convertIntervalValueToMillis(
    SqlIntervalQualifier qualifier, long value) {
  final long millisPerUnit;
  switch (qualifier.typeName()) {
    case INTERVAL_DAY:
      millisPerUnit = ONE_DAY_IN_MILLIS;
      break;
    case INTERVAL_HOUR:
      millisPerUnit = ONE_HOUR_IN_MILLIS;
      break;
    case INTERVAL_MINUTE:
      millisPerUnit = ONE_MINUTE_IN_MILLIS;
      break;
    case INTERVAL_SECOND:
      millisPerUnit = ONE_SECOND_IN_MILLIS;
      break;
    default:
      throw new IllegalArgumentException(qualifier.typeName().toString());
  }
  // Plain long multiplication wraps around silently; reject amounts that do not fit.
  try {
    return new BigDecimal(Math.multiplyExact(value, millisPerUnit));
  } catch (ArithmeticException e) {
    throw new IllegalArgumentException(
        "INTERVAL value is too large: " + value + " " + qualifier.typeName(), e);
  }
}
/**
 * Maps an upper-case date part name to a Calcite {@link SqlIntervalQualifier}.
 *
 * <p>Matching is exact (case-sensitive). The qualifier is constructed with the matching {@code
 * TimeUnit} as its first argument, {@code null} as its second argument and {@code
 * SqlParserPos.ZERO} as its position.
 *
 * @param datePart one of {@code YEAR}, {@code MONTH}, {@code DAY}, {@code HOUR}, {@code MINUTE},
 *     {@code SECOND}, {@code WEEK}, {@code QUARTER} or {@code MILLISECOND}
 * @return the qualifier for the given date part
 * @throws RuntimeException if {@code datePart} is none of the supported names
 */
private static SqlIntervalQualifier convertIntervalDatepartToSqlIntervalQualifier(
    String datePart) {
  final TimeUnit unit;
  switch (datePart) {
    case "YEAR":
      unit = TimeUnit.YEAR;
      break;
    case "MONTH":
      unit = TimeUnit.MONTH;
      break;
    case "DAY":
      unit = TimeUnit.DAY;
      break;
    case "HOUR":
      unit = TimeUnit.HOUR;
      break;
    case "MINUTE":
      unit = TimeUnit.MINUTE;
      break;
    case "SECOND":
      unit = TimeUnit.SECOND;
      break;
    case "WEEK":
      unit = TimeUnit.WEEK;
      break;
    case "QUARTER":
      unit = TimeUnit.QUARTER;
      break;
    case "MILLISECOND":
      unit = TimeUnit.MILLISECOND;
      break;
    default:
      throw new RuntimeException(
          String.format(
              "Received an undefined INTERVAL unit: %s. Please specify unit from the following"
                  + " list: %s.",
              datePart, INTERVAL_DATE_PART_MSG));
  }
  // Every supported date part is built the same way; only the unit differs.
  return new SqlIntervalQualifier(unit, null, SqlParserPos.ZERO);
}
/**
 * Converts a resolved CAST expression into a {@link RexNode}.
 *
 * <p>The source/target kind pair is first checked with {@code isCastingSupported}, which throws
 * for unsupported combinations. The operand is then converted, and the cast is emitted either
 * as a call to {@code ZETASQL_CAST_OP} (when {@code isZetaSQLCast} accepts the pair) or as a
 * regular {@code RexBuilder} cast.
 *
 * @param resolvedCast the cast expression to convert
 * @param columnList columns passed through to the operand conversion
 * @param fieldList fields passed through to the operand conversion
 * @return the converted cast expression
 * @throws IllegalArgumentException if the cast combination is listed as unsupported
 */
private RexNode convertResolvedCast(
    ResolvedCast resolvedCast,
    List<ResolvedColumn> columnList,
    List<RelDataTypeField> fieldList) {
  TypeKind sourceKind = resolvedCast.getExpr().getType().getKind();
  TypeKind targetKind = resolvedCast.getType().getKind();
  // Throws for unsupported combinations before any conversion work is done.
  isCastingSupported(sourceKind, targetKind);

  RexNode operand =
      convertRexNodeFromResolvedExpr(resolvedCast.getExpr(), columnList, fieldList);

  // The result type is nullable exactly when the converted operand's type is.
  boolean nullable = operand.getType().isNullable();
  RelDataType resultType = TypeUtils.toSimpleRelDataType(targetKind, rexBuilder(), nullable);

  return isZetaSQLCast(sourceKind, targetKind)
      ? rexBuilder().makeCall(resultType, ZETASQL_CAST_OP, ImmutableList.of(operand))
      : rexBuilder().makeCast(resultType, operand);
}
/**
 * Verifies that a cast between the given kinds is not listed in {@code UNSUPPORTED_CASTING}.
 *
 * <p>Despite its name, this method returns nothing; it signals an unsupported cast by throwing.
 *
 * @param fromType the kind being cast from
 * @param toType the kind being cast to
 * @throws IllegalArgumentException if {@code UNSUPPORTED_CASTING} has an entry for {@code
 *     toType} that contains {@code fromType}
 */
private static void isCastingSupported(TypeKind fromType, TypeKind toType) {
  if (!UNSUPPORTED_CASTING.containsKey(toType)) {
    return;
  }
  if (!UNSUPPORTED_CASTING.get(toType).contains(fromType)) {
    return;
  }
  throw new IllegalArgumentException(
      "Does not support CAST(" + fromType + " AS " + toType + ")");
}
/**
 * Tells whether a cast between the given kinds is emitted as a ZetaSQL-specific cast call.
 *
 * <p>The accepted pairs are BYTES to STRING, TIMESTAMP to STRING, INT64 to BOOL and BOOL to
 * INT64.
 *
 * @param fromType the kind being cast from
 * @param toType the kind being cast to
 * @return {@code true} for the pairs listed above, {@code false} otherwise
 */
private static boolean isZetaSQLCast(TypeKind fromType, TypeKind toType) {
  // TODO: Structure ZETASQL_CAST_OP so that we don't have to repeat the supported types
  // here
  if (fromType.equals(TYPE_BYTES) || fromType.equals(TYPE_TIMESTAMP)) {
    return toType.equals(TYPE_STRING);
  }
  if (fromType.equals(TYPE_INT64)) {
    return toType.equals(TYPE_BOOL);
  }
  if (fromType.equals(TYPE_BOOL)) {
    return toType.equals(TYPE_INT64);
  }
  return false;
}
/**
 * Resolves a column reference against the columns of a ref scan and produces an input
 * reference.
 *
 * <p>The first entry of {@code refScanColumnList} whose id equals the referenced column's id
 * determines the position. Nullability is taken from {@code fieldList} at that position, and
 * the input index is the id of {@code originalColumnList} at that position minus one.
 *
 * @param columnRef the column reference to convert
 * @param refScanColumnList the ref scan's columns, searched by id
 * @param originalColumnList columns whose ids supply the input index; indexed in parallel with
 *     {@code refScanColumnList}
 * @param fieldList fields supplying nullability; indexed in parallel with {@code
 *     refScanColumnList}
 * @return the input reference, or {@code null} if no ref scan column matches
 */
private RexNode convertRexNodeFromResolvedColumnRefWithRefScan(
    ResolvedColumnRef columnRef,
    List<ResolvedColumn> refScanColumnList,
    List<ResolvedColumn> originalColumnList,
    List<RelDataTypeField> fieldList) {
  for (int position = 0; position < refScanColumnList.size(); position++) {
    if (refScanColumnList.get(position).getId() != columnRef.getColumn().getId()) {
      continue;
    }
    boolean isNullable = fieldList.get(position).getType().isNullable();
    // Column ids are narrowed to int first and then shifted down by one to form the index.
    int inputIndex = ((int) originalColumnList.get(position).getId()) - 1;
    RelDataType refType =
        TypeUtils.toSimpleRelDataType(columnRef.getType().getKind(), rexBuilder(), isNullable);
    return rexBuilder().makeInputRef(refType, inputIndex);
  }
  return null;
}
/**
 * Converts a resolved query parameter into a {@link RexNode} using the value registered for its
 * name in {@code queryParams}.
 *
 * <p>When assertions are enabled, the parameter's resolved type is checked against the type of
 * the registered value.
 *
 * @param parameter the resolved parameter; its name is used as the lookup key
 * @return the converted parameter value
 */
private RexNode convertResolvedParameter(ResolvedParameter parameter) {
  final String paramName = parameter.getName();
  assert parameter.getType().equals(queryParams.get(paramName).getType());
  return convertValueToRexNode(
      queryParams.get(paramName).getType(), queryParams.get(paramName));
}
/**
 * Converts a struct field access into a Calcite field-access expression.
 *
 * @param resolvedGetStructField the field access; its expression is converted first and its
 *     field index (narrowed to {@code int}) selects the field
 * @return the field-access expression
 */
private RexNode convertResolvedStructFieldAccess(ResolvedGetStructField resolvedGetStructField) {
  RexNode structExpr = convertRexNodeFromResolvedExpr(resolvedGetStructField.getExpr());
  int fieldIndex = (int) resolvedGetStructField.getFieldIdx();
  return rexBuilder().makeFieldAccess(structExpr, fieldIndex);
}
/**
 * Converts a struct field access into a Calcite field-access expression, resolving the struct
 * expression against the given columns and fields.
 *
 * @param resolvedGetStructField the field access; its field index (narrowed to {@code int})
 *     selects the field
 * @param columnList columns passed through to the struct expression conversion
 * @param fieldList fields passed through to the struct expression conversion
 * @return the field-access expression
 */
private RexNode convertResolvedStructFieldAccess(
    ResolvedGetStructField resolvedGetStructField,
    List<ResolvedColumn> columnList,
    List<RelDataTypeField> fieldList) {
  RexNode structExpr =
      convertRexNodeFromResolvedExpr(resolvedGetStructField.getExpr(), columnList, fieldList);
  int fieldIndex = (int) resolvedGetStructField.getFieldIdx();
  return rexBuilder().makeFieldAccess(structExpr, fieldIndex);
}
/** Returns the {@link RexBuilder} obtained from {@code cluster}. */
private RexBuilder rexBuilder() {
return cluster.getRexBuilder();
}
/** Returns the {@link RelDataTypeFactory} obtained from {@code cluster}. */
private RelDataTypeFactory typeFactory() {
return cluster.getTypeFactory();
}
}