blob: 1490229ff9788131f4ca37bf7445e0134ac343a7 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.calcite.sql2rel;
import org.apache.calcite.avatica.util.DateTimeUtils;
import org.apache.calcite.avatica.util.TimeUnit;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeFamily;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rel.type.TimeFrame;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexCallBinding;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexRangeRef;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.runtime.SqlFunctions;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlBasicCall;
import org.apache.calcite.sql.SqlBinaryOperator;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlDataTypeSpec;
import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlIntervalLiteral;
import org.apache.calcite.sql.SqlIntervalQualifier;
import org.apache.calcite.sql.SqlJdbcFunctionCall;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlNumericLiteral;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.SqlTableFunction;
import org.apache.calcite.sql.SqlUtil;
import org.apache.calcite.sql.SqlWindowTableFunction;
import org.apache.calcite.sql.fun.SqlArrayValueConstructor;
import org.apache.calcite.sql.fun.SqlBetweenOperator;
import org.apache.calcite.sql.fun.SqlCase;
import org.apache.calcite.sql.fun.SqlCastFunction;
import org.apache.calcite.sql.fun.SqlDatetimeSubtractionOperator;
import org.apache.calcite.sql.fun.SqlExtractFunction;
import org.apache.calcite.sql.fun.SqlInternalOperators;
import org.apache.calcite.sql.fun.SqlJsonValueFunction;
import org.apache.calcite.sql.fun.SqlLibrary;
import org.apache.calcite.sql.fun.SqlLibraryOperators;
import org.apache.calcite.sql.fun.SqlLiteralChainOperator;
import org.apache.calcite.sql.fun.SqlMapValueConstructor;
import org.apache.calcite.sql.fun.SqlMultisetQueryConstructor;
import org.apache.calcite.sql.fun.SqlMultisetValueConstructor;
import org.apache.calcite.sql.fun.SqlOverlapsOperator;
import org.apache.calcite.sql.fun.SqlRowOperator;
import org.apache.calcite.sql.fun.SqlSequenceValueOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.fun.SqlSubstringFunction;
import org.apache.calcite.sql.fun.SqlTrimFunction;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;
import com.google.common.collect.ImmutableList;
import org.checkerframework.checker.initialization.qual.UnknownInitialization;
import org.checkerframework.checker.nullness.qual.Nullable;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.function.UnaryOperator;
import java.util.stream.Collectors;
import static com.google.common.base.Preconditions.checkArgument;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.QUANTIFY_OPERATORS;
import static org.apache.calcite.sql.type.NonNullableAccessors.getComponentTypeOrThrow;
import static org.apache.calcite.util.Util.first;
import static java.util.Objects.requireNonNull;
/**
* Standard implementation of {@link SqlRexConvertletTable}.
*/
public class StandardConvertletTable extends ReflectiveConvertletTable {
/** Singleton instance. */
public static final StandardConvertletTable INSTANCE =
new StandardConvertletTable();
//~ Constructors -----------------------------------------------------------
private StandardConvertletTable() {
super();
// Register aliases (operators which have a different name but
// identical behavior to other operators).
addAlias(SqlLibraryOperators.LEN,
SqlStdOperatorTable.CHAR_LENGTH);
addAlias(SqlLibraryOperators.LENGTH,
SqlStdOperatorTable.CHAR_LENGTH);
addAlias(SqlStdOperatorTable.CHARACTER_LENGTH,
SqlStdOperatorTable.CHAR_LENGTH);
addAlias(SqlStdOperatorTable.IS_UNKNOWN,
SqlStdOperatorTable.IS_NULL);
addAlias(SqlStdOperatorTable.IS_NOT_UNKNOWN,
SqlStdOperatorTable.IS_NOT_NULL);
addAlias(SqlLibraryOperators.NULL_SAFE_EQUAL,
SqlStdOperatorTable.IS_NOT_DISTINCT_FROM);
addAlias(SqlStdOperatorTable.PERCENT_REMAINDER, SqlStdOperatorTable.MOD);
addAlias(SqlLibraryOperators.IFNULL, SqlLibraryOperators.NVL);
addAlias(SqlLibraryOperators.REGEXP_SUBSTR, SqlLibraryOperators.REGEXP_EXTRACT);
addAlias(SqlLibraryOperators.ENDSWITH, SqlLibraryOperators.ENDS_WITH);
addAlias(SqlLibraryOperators.STARTSWITH, SqlLibraryOperators.STARTS_WITH);
addAlias(SqlLibraryOperators.BITAND_AGG, SqlStdOperatorTable.BIT_AND);
addAlias(SqlLibraryOperators.BITOR_AGG, SqlStdOperatorTable.BIT_OR);
// Register convertlets for specific objects.
registerOp(SqlStdOperatorTable.CAST, this::convertCast);
registerOp(SqlLibraryOperators.SAFE_CAST, this::convertCast);
registerOp(SqlLibraryOperators.TRY_CAST, this::convertCast);
registerOp(SqlLibraryOperators.INFIX_CAST, this::convertCast);
registerOp(SqlStdOperatorTable.IS_DISTINCT_FROM,
(cx, call) -> convertIsDistinctFrom(cx, call, false));
registerOp(SqlStdOperatorTable.IS_NOT_DISTINCT_FROM,
(cx, call) -> convertIsDistinctFrom(cx, call, true));
registerOp(SqlStdOperatorTable.PLUS, this::convertPlus);
registerOp(SqlStdOperatorTable.MINUS,
(cx, call) -> {
final RexCall e =
(RexCall) StandardConvertletTable.this.convertCall(cx, call);
switch (e.getOperands().get(0).getType().getSqlTypeName()) {
case DATE:
case TIME:
case TIMESTAMP:
return convertDatetimeMinus(cx, SqlStdOperatorTable.MINUS_DATE,
call);
default:
return e;
}
});
// DATE(string) is equivalent to CAST(string AS DATE),
// but other DATE variants are treated as regular functions.
registerOp(SqlLibraryOperators.DATE,
(cx, call) -> {
final RexCall e =
(RexCall) StandardConvertletTable.this.convertCall(cx, call);
if (e.getOperands().size() == 1
&& SqlTypeUtil.isString(e.getOperands().get(0).getType())) {
return cx.getRexBuilder().makeCast(e.type, e.getOperands().get(0));
}
return e;
});
registerOp(SqlLibraryOperators.DATETIME_TRUNC,
new TruncConvertlet());
registerOp(SqlLibraryOperators.TIMESTAMP_TRUNC,
new TruncConvertlet());
registerOp(SqlLibraryOperators.LTRIM,
new TrimConvertlet(SqlTrimFunction.Flag.LEADING));
registerOp(SqlLibraryOperators.RTRIM,
new TrimConvertlet(SqlTrimFunction.Flag.TRAILING));
registerOp(SqlLibraryOperators.GREATEST, new GreatestConvertlet());
registerOp(SqlLibraryOperators.LEAST, new GreatestConvertlet());
registerOp(SqlLibraryOperators.SUBSTR_BIG_QUERY,
new SubstrConvertlet(SqlLibrary.BIG_QUERY));
registerOp(SqlLibraryOperators.SUBSTR_MYSQL,
new SubstrConvertlet(SqlLibrary.MYSQL));
registerOp(SqlLibraryOperators.SUBSTR_ORACLE,
new SubstrConvertlet(SqlLibrary.ORACLE));
registerOp(SqlLibraryOperators.SUBSTR_POSTGRESQL,
new SubstrConvertlet(SqlLibrary.POSTGRESQL));
registerOp(SqlLibraryOperators.DATE_ADD,
new TimestampAddConvertlet());
registerOp(SqlLibraryOperators.DATE_DIFF,
new TimestampDiffConvertlet());
registerOp(SqlLibraryOperators.DATE_SUB,
new TimestampSubConvertlet());
registerOp(SqlLibraryOperators.DATETIME_ADD,
new TimestampAddConvertlet());
registerOp(SqlLibraryOperators.DATETIME_DIFF,
new TimestampDiffConvertlet());
registerOp(SqlLibraryOperators.DATETIME_SUB,
new TimestampSubConvertlet());
registerOp(SqlLibraryOperators.TIME_ADD,
new TimestampAddConvertlet());
registerOp(SqlLibraryOperators.TIME_DIFF,
new TimestampDiffConvertlet());
registerOp(SqlLibraryOperators.TIME_SUB,
new TimestampSubConvertlet());
registerOp(SqlLibraryOperators.TIMESTAMP_ADD2,
new TimestampAddConvertlet());
registerOp(SqlLibraryOperators.TIMESTAMP_DIFF3,
new TimestampDiffConvertlet());
registerOp(SqlLibraryOperators.TIMESTAMP_SUB,
new TimestampSubConvertlet());
QUANTIFY_OPERATORS.forEach(operator ->
registerOp(operator, StandardConvertletTable::convertQuantifyOperator));
registerOp(SqlLibraryOperators.NVL, StandardConvertletTable::convertNvl);
registerOp(SqlLibraryOperators.DECODE,
StandardConvertletTable::convertDecode);
registerOp(SqlLibraryOperators.IF, StandardConvertletTable::convertIf);
// Expand "x NOT LIKE y" into "NOT (x LIKE y)"
registerOp(SqlStdOperatorTable.NOT_LIKE,
(cx, call) -> cx.convertExpression(
SqlStdOperatorTable.NOT.createCall(SqlParserPos.ZERO,
SqlStdOperatorTable.LIKE.createCall(SqlParserPos.ZERO,
call.getOperandList()))));
// Expand "x NOT ILIKE y" into "NOT (x ILIKE y)"
registerOp(SqlLibraryOperators.NOT_ILIKE,
(cx, call) -> cx.convertExpression(
SqlStdOperatorTable.NOT.createCall(SqlParserPos.ZERO,
SqlLibraryOperators.ILIKE.createCall(SqlParserPos.ZERO,
call.getOperandList()))));
// Expand "x NOT RLIKE y" into "NOT (x RLIKE y)"
registerOp(SqlLibraryOperators.NOT_RLIKE,
(cx, call) -> cx.convertExpression(
SqlStdOperatorTable.NOT.createCall(SqlParserPos.ZERO,
SqlLibraryOperators.RLIKE.createCall(SqlParserPos.ZERO,
call.getOperandList()))));
// Expand "x NOT SIMILAR y" into "NOT (x SIMILAR y)"
registerOp(SqlStdOperatorTable.NOT_SIMILAR_TO,
(cx, call) -> cx.convertExpression(
SqlStdOperatorTable.NOT.createCall(SqlParserPos.ZERO,
SqlStdOperatorTable.SIMILAR_TO.createCall(SqlParserPos.ZERO,
call.getOperandList()))));
// Unary "+" has no effect, so expand "+ x" into "x".
registerOp(SqlStdOperatorTable.UNARY_PLUS,
(cx, call) -> cx.convertExpression(call.operand(0)));
// "DOT"
registerOp(SqlStdOperatorTable.DOT,
(cx, call) -> cx.getRexBuilder().makeFieldAccess(
cx.convertExpression(call.operand(0)),
call.operand(1).toString(), false));
// "ITEM"
registerOp(SqlStdOperatorTable.ITEM, this::convertItem);
// "AS" has no effect, so expand "x AS id" into "x".
registerOp(SqlStdOperatorTable.AS,
(cx, call) -> cx.convertExpression(call.operand(0)));
registerOp(SqlStdOperatorTable.CONVERT, this::convertCharset);
registerOp(SqlStdOperatorTable.TRANSLATE, this::translateCharset);
// "SQRT(x)" is equivalent to "POWER(x, .5)"
registerOp(SqlStdOperatorTable.SQRT,
(cx, call) -> cx.convertExpression(
SqlStdOperatorTable.POWER.createCall(SqlParserPos.ZERO,
call.operand(0),
SqlLiteral.createExactNumeric("0.5", SqlParserPos.ZERO))));
// "STRPOS(string, substring) is equivalent to
// "POSITION(substring IN string)"
registerOp(SqlLibraryOperators.STRPOS,
(cx, call) -> cx.convertExpression(
SqlStdOperatorTable.POSITION.createCall(SqlParserPos.ZERO,
call.operand(1), call.operand(0))));
// "INSTR(string, substring, position, occurrence) is equivalent to
// "POSITION(substring, string, position, occurrence)"
registerOp(SqlLibraryOperators.INSTR, StandardConvertletTable::convertInstr);
// REVIEW jvs 24-Apr-2006: This only seems to be working from within a
// windowed agg. I have added an optimizer rule
// org.apache.calcite.rel.rules.AggregateReduceFunctionsRule which handles
// other cases post-translation. The reason I did that was to defer the
// implementation decision; e.g. we may want to push it down to a foreign
// server directly rather than decomposed; decomposition is easier than
// recognition.
// Convert "avg(<expr>)" to "cast(sum(<expr>) / count(<expr>) as
// <type>)". We don't need to handle the empty set specially, because
// the SUM is already supposed to come out as NULL in cases where the
// COUNT is zero, so the null check should take place first and prevent
// division by zero. We need the cast because SUM and COUNT may use
// different types, say BIGINT.
//
// Similarly STDDEV_POP and STDDEV_SAMP, VAR_POP and VAR_SAMP.
registerOp(SqlStdOperatorTable.AVG,
new AvgVarianceConvertlet(SqlKind.AVG));
registerOp(SqlStdOperatorTable.STDDEV_POP,
new AvgVarianceConvertlet(SqlKind.STDDEV_POP));
registerOp(SqlStdOperatorTable.STDDEV_SAMP,
new AvgVarianceConvertlet(SqlKind.STDDEV_SAMP));
registerOp(SqlStdOperatorTable.STDDEV,
new AvgVarianceConvertlet(SqlKind.STDDEV_SAMP));
registerOp(SqlStdOperatorTable.VAR_POP,
new AvgVarianceConvertlet(SqlKind.VAR_POP));
registerOp(SqlStdOperatorTable.VAR_SAMP,
new AvgVarianceConvertlet(SqlKind.VAR_SAMP));
registerOp(SqlStdOperatorTable.VARIANCE,
new AvgVarianceConvertlet(SqlKind.VAR_SAMP));
registerOp(SqlStdOperatorTable.COVAR_POP,
new RegrCovarianceConvertlet(SqlKind.COVAR_POP));
registerOp(SqlStdOperatorTable.COVAR_SAMP,
new RegrCovarianceConvertlet(SqlKind.COVAR_SAMP));
registerOp(SqlStdOperatorTable.REGR_SXX,
new RegrCovarianceConvertlet(SqlKind.REGR_SXX));
registerOp(SqlStdOperatorTable.REGR_SYY,
new RegrCovarianceConvertlet(SqlKind.REGR_SYY));
final SqlRexConvertlet floorCeilConvertlet = new FloorCeilConvertlet();
registerOp(SqlStdOperatorTable.FLOOR, floorCeilConvertlet);
registerOp(SqlStdOperatorTable.CEIL, floorCeilConvertlet);
registerOp(SqlStdOperatorTable.TIMESTAMP_ADD, new TimestampAddConvertlet());
registerOp(SqlStdOperatorTable.TIMESTAMP_DIFF,
new TimestampDiffConvertlet());
registerOp(SqlStdOperatorTable.INTERVAL,
StandardConvertletTable::convertInterval);
// Convert "element(<expr>)" to "$element_slice(<expr>)", if the
// expression is a multiset of scalars.
if (false) {
registerOp(SqlStdOperatorTable.ELEMENT,
(cx, call) -> {
assert call.operandCount() == 1;
final SqlNode operand = call.operand(0);
final RelDataType type =
cx.getValidator().getValidatedNodeType(operand);
if (!getComponentTypeOrThrow(type).isStruct()) {
return cx.convertExpression(
SqlStdOperatorTable.ELEMENT_SLICE.createCall(
SqlParserPos.ZERO, operand));
}
// fallback on default behavior
return StandardConvertletTable.this.convertCall(cx, call);
});
}
// Convert "$element_slice(<expr>)" to "element(<expr>).field#0"
if (false) {
registerOp(SqlStdOperatorTable.ELEMENT_SLICE,
(cx, call) -> {
assert call.operandCount() == 1;
final SqlNode operand = call.operand(0);
final RexNode expr =
cx.convertExpression(
SqlStdOperatorTable.ELEMENT.createCall(SqlParserPos.ZERO,
operand));
return cx.getRexBuilder().makeFieldAccess(expr, 0);
});
}
}
/** Converts ALL or SOME operators. */
private static RexNode convertQuantifyOperator(SqlRexContext cx, SqlCall call) {
final RexBuilder rexBuilder = cx.getRexBuilder();
final RexNode left = cx.convertExpression(call.getOperandList().get(0));
assert call.getOperandList().get(1) instanceof SqlNodeList;
final RexNode right = cx.convertExpression(((SqlNodeList) call.getOperandList().get(1)).get(0));
final RelDataType rightComponentType = requireNonNull(right.getType().getComponentType());
final RelDataType returnType =
cx.getTypeFactory().createTypeWithNullability(
cx.getTypeFactory().createSqlType(SqlTypeName.BOOLEAN), right.getType().isNullable()
|| left.getType().isNullable() || rightComponentType.isNullable());
return rexBuilder.makeCall(returnType, call.getOperator(), ImmutableList.of(left, right));
}
/** Converts a call to the {@code NVL} function (and also its synonym,
* {@code IFNULL}). */
private static RexNode convertNvl(SqlRexContext cx, SqlCall call) {
final RexBuilder rexBuilder = cx.getRexBuilder();
final RexNode operand0 =
cx.convertExpression(call.getOperandList().get(0));
final RexNode operand1 =
cx.convertExpression(call.getOperandList().get(1));
final RelDataType type =
cx.getValidator().getValidatedNodeType(call);
// Preserve Operand Nullability
return rexBuilder.makeCall(type, SqlStdOperatorTable.CASE,
ImmutableList.of(
rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL,
operand0),
rexBuilder.makeCast(
cx.getTypeFactory()
.createTypeWithNullability(type, operand0.getType().isNullable()),
operand0),
rexBuilder.makeCast(
cx.getTypeFactory()
.createTypeWithNullability(type, operand1.getType().isNullable()),
operand1)));
}
/** Converts a call to the INSTR function.
* INSTR(string, substring, position, occurrence) is equivalent to
* POSITION(substring, string, position, occurrence) */
private static RexNode convertInstr(SqlRexContext cx, SqlCall call) {
final RexBuilder rexBuilder = cx.getRexBuilder();
final List<RexNode> operands =
convertOperands(cx, call, SqlOperandTypeChecker.Consistency.NONE);
final RelDataType type =
cx.getValidator().getValidatedNodeType(call);
final List<RexNode> exprs = new ArrayList<>();
switch (call.operandCount()) {
// Must reverse order of first 2 operands.
case 2:
exprs.add(operands.get(1)); // Substring
exprs.add(operands.get(0)); // String
break;
case 3:
exprs.add(operands.get(1)); // Substring
exprs.add(operands.get(0)); // String
exprs.add(operands.get(2)); // Position
break;
case 4:
exprs.add(operands.get(1)); // Substring
exprs.add(operands.get(0)); // String
exprs.add(operands.get(2)); // Position
exprs.add(operands.get(3)); // Occurrence
break;
default:
throw new UnsupportedOperationException("Position does not accept "
+ call.operandCount() + " operands");
}
return rexBuilder.makeCall(type, SqlStdOperatorTable.POSITION, exprs);
}
/** Converts a call to the DECODE function. */
private static RexNode convertDecode(SqlRexContext cx, SqlCall call) {
final RexBuilder rexBuilder = cx.getRexBuilder();
final List<RexNode> operands =
convertOperands(cx, call, SqlOperandTypeChecker.Consistency.NONE);
final RelDataType type =
cx.getValidator().getValidatedNodeType(call);
final List<RexNode> exprs = new ArrayList<>();
for (int i = 1; i < operands.size() - 1; i += 2) {
exprs.add(
RelOptUtil.isDistinctFrom(rexBuilder, operands.get(0),
operands.get(i), true));
exprs.add(operands.get(i + 1));
}
if (operands.size() % 2 == 0) {
exprs.add(Util.last(operands));
} else {
exprs.add(rexBuilder.makeNullLiteral(type));
}
return rexBuilder.makeCall(type, SqlStdOperatorTable.CASE, exprs);
}
/** Converts a call to the IF function.
*
* <p>{@code IF(b, x, y)} &rarr; {@code CASE WHEN b THEN x ELSE y END}. */
private static RexNode convertIf(SqlRexContext cx, SqlCall call) {
final RexBuilder rexBuilder = cx.getRexBuilder();
final List<RexNode> operands =
convertOperands(cx, call, SqlOperandTypeChecker.Consistency.NONE);
final RelDataType type =
cx.getValidator().getValidatedNodeType(call);
return rexBuilder.makeCall(type, SqlStdOperatorTable.CASE, operands);
}
/** Converts an interval expression to a numeric multiplied by an interval
* literal. */
private static RexNode convertInterval(SqlRexContext cx, SqlCall call) {
// "INTERVAL n HOUR" becomes "n * INTERVAL '1' HOUR"
final SqlNode n = call.operand(0);
final SqlIntervalQualifier intervalQualifier = call.operand(1);
final SqlIntervalLiteral literal =
SqlLiteral.createInterval(1, "1", intervalQualifier,
call.getParserPosition());
final SqlCall multiply =
SqlStdOperatorTable.MULTIPLY.createCall(call.getParserPosition(), n,
literal);
return cx.convertExpression(multiply);
}
//~ Methods ----------------------------------------------------------------
private static RexNode or(RexBuilder rexBuilder, RexNode a0, RexNode a1) {
return rexBuilder.makeCall(SqlStdOperatorTable.OR, a0, a1);
}
private static RexNode eq(RexBuilder rexBuilder, RexNode a0, RexNode a1) {
return rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, a0, a1);
}
private static RexNode ge(RexBuilder rexBuilder, RexNode a0, RexNode a1) {
return rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, a0,
a1);
}
private static RexNode le(RexBuilder rexBuilder, RexNode a0, RexNode a1) {
return rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, a0, a1);
}
private static RexNode and(RexBuilder rexBuilder, RexNode a0, RexNode a1) {
return rexBuilder.makeCall(SqlStdOperatorTable.AND, a0, a1);
}
private static RexNode divideInt(RexBuilder rexBuilder, RexNode a0,
RexNode a1) {
return rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE_INTEGER, a0, a1);
}
private static RexNode plus(RexBuilder rexBuilder, RexNode a0, RexNode a1) {
return rexBuilder.makeCall(SqlStdOperatorTable.PLUS, a0, a1);
}
private static RexNode minus(RexBuilder rexBuilder, RexNode a0, RexNode a1) {
return rexBuilder.makeCall(SqlStdOperatorTable.MINUS, a0, a1);
}
private static RexNode multiply(RexBuilder rexBuilder, RexNode a0,
RexNode a1) {
return rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, a0, a1);
}
private static RexNode case_(RexBuilder rexBuilder, RexNode... args) {
return rexBuilder.makeCall(SqlStdOperatorTable.CASE, args);
}
// SqlNode helpers
private static SqlCall plus(SqlParserPos pos, SqlNode a0, SqlNode a1) {
return SqlStdOperatorTable.PLUS.createCall(pos, a0, a1);
}
/**
* Converts a CASE expression.
*/
public RexNode convertCase(
SqlRexContext cx,
SqlCase call) {
SqlNodeList whenList = call.getWhenOperands();
SqlNodeList thenList = call.getThenOperands();
assert whenList.size() == thenList.size();
RexBuilder rexBuilder = cx.getRexBuilder();
final List<RexNode> exprList = new ArrayList<>();
final RelDataTypeFactory typeFactory = rexBuilder.getTypeFactory();
final RexLiteral unknownLiteral =
rexBuilder.makeNullLiteral(typeFactory.createSqlType(SqlTypeName.BOOLEAN));
final RexLiteral nullLiteral =
rexBuilder.makeNullLiteral(typeFactory.createSqlType(SqlTypeName.NULL));
for (int i = 0; i < whenList.size(); i++) {
if (SqlUtil.isNullLiteral(whenList.get(i), false)) {
exprList.add(unknownLiteral);
} else {
exprList.add(cx.convertExpression(whenList.get(i)));
}
if (SqlUtil.isNullLiteral(thenList.get(i), false)) {
exprList.add(nullLiteral);
} else {
exprList.add(cx.convertExpression(thenList.get(i)));
}
}
SqlNode elseOperand = call.getElseOperand();
if (SqlUtil.isNullLiteral(elseOperand, false)) {
exprList.add(nullLiteral);
} else {
exprList.add(cx.convertExpression(requireNonNull(elseOperand, "elseOperand")));
}
RelDataType type =
rexBuilder.deriveReturnType(call.getOperator(), exprList);
for (int i : elseArgs(exprList.size())) {
exprList.set(i,
rexBuilder.ensureType(type, exprList.get(i), false));
}
return rexBuilder.makeCall(type, SqlStdOperatorTable.CASE, exprList);
}
public RexNode convertMultiset(
SqlRexContext cx,
SqlMultisetValueConstructor op,
SqlCall call) {
final RelDataType originalType =
cx.getValidator().getValidatedNodeType(call);
RexRangeRef rr = cx.getSubQueryExpr(call);
assert rr != null;
RelDataType msType = rr.getType().getFieldList().get(0).getType();
RexNode expr =
cx.getRexBuilder().makeInputRef(
msType,
rr.getOffset());
assert msType.getComponentType() != null && msType.getComponentType().isStruct()
: "componentType of " + msType + " must be struct";
assert originalType.getComponentType() != null
: "componentType of " + originalType + " must be struct";
if (!originalType.getComponentType().isStruct()) {
// If the type is not a struct, the multiset operator will have
// wrapped the type as a record. Add a call to the $SLICE operator
// to compensate. For example,
// if '<ms>' has type 'RECORD (INTEGER x) MULTISET',
// then '$SLICE(<ms>) has type 'INTEGER MULTISET'.
// This will be removed as the expression is translated.
expr =
cx.getRexBuilder().makeCall(originalType, SqlStdOperatorTable.SLICE,
ImmutableList.of(expr));
}
return expr;
}
public RexNode convertArray(
SqlRexContext cx,
SqlArrayValueConstructor op,
SqlCall call) {
return convertCall(cx, call);
}
public RexNode convertMap(
SqlRexContext cx,
SqlMapValueConstructor op,
SqlCall call) {
return convertCall(cx, call);
}
public RexNode convertMultisetQuery(
SqlRexContext cx,
SqlMultisetQueryConstructor op,
SqlCall call) {
final RelDataType originalType =
cx.getValidator().getValidatedNodeType(call);
RexRangeRef rr = cx.getSubQueryExpr(call);
assert rr != null;
RelDataType msType = rr.getType().getFieldList().get(0).getType();
RexNode expr =
cx.getRexBuilder().makeInputRef(
msType,
rr.getOffset());
assert msType.getComponentType() != null
: "componentType of " + msType + " must not be null";
assert originalType.getComponentType() != null
: "componentType of " + originalType + " must not be null";
return expr;
}
public RexNode convertJdbc(
SqlRexContext cx,
SqlJdbcFunctionCall op,
SqlCall call) {
// Yuck!! The function definition contains arguments!
// TODO: adopt a more conventional definition/instance structure
final SqlCall convertedCall = op.getLookupCall();
return cx.convertExpression(convertedCall);
}
protected RexNode convertCast(
@UnknownInitialization StandardConvertletTable this,
SqlRexContext cx,
final SqlCall call) {
RelDataTypeFactory typeFactory = cx.getTypeFactory();
final SqlValidator validator = cx.getValidator();
final SqlKind kind = call.getKind();
checkArgument(kind == SqlKind.CAST || kind == SqlKind.SAFE_CAST, kind);
final boolean safe = kind == SqlKind.SAFE_CAST;
final SqlNode left = call.operand(0);
final SqlNode right = call.operand(1);
final SqlLiteral format = call.getOperandList().size() > 2
? call.operand(2) : SqlLiteral.createNull(SqlParserPos.ZERO);
final RexBuilder rexBuilder = cx.getRexBuilder();
final RexNode arg = cx.convertExpression(left);
final RexLiteral formatArg = (RexLiteral) cx.convertLiteral(format);
if (right instanceof SqlIntervalQualifier) {
final SqlIntervalQualifier intervalQualifier =
(SqlIntervalQualifier) right;
if (left instanceof SqlIntervalLiteral) {
RexLiteral sourceInterval =
(RexLiteral) cx.convertExpression(left);
BigDecimal sourceValue =
(BigDecimal) sourceInterval.getValue();
RexLiteral castedInterval =
rexBuilder.makeIntervalLiteral(sourceValue,
intervalQualifier);
return castToValidatedType(call, castedInterval, validator, rexBuilder,
safe);
} else if (left instanceof SqlNumericLiteral) {
RexLiteral sourceInterval =
(RexLiteral) cx.convertExpression(left);
BigDecimal sourceValue =
requireNonNull(sourceInterval.getValueAs(BigDecimal.class),
"sourceValue");
final BigDecimal multiplier = intervalQualifier.getUnit().multiplier;
RexLiteral castedInterval =
rexBuilder.makeIntervalLiteral(
SqlFunctions.multiply(sourceValue, multiplier),
intervalQualifier);
return castToValidatedType(call, castedInterval, validator, rexBuilder,
safe);
}
RexNode value = cx.convertExpression(left);
return castToValidatedType(call, value, validator, rexBuilder, safe);
}
final SqlDataTypeSpec dataType = (SqlDataTypeSpec) right;
RelDataType type =
SqlCastFunction.deriveType(cx.getTypeFactory(), arg.getType(),
dataType.deriveType(validator), safe);
if (SqlUtil.isNullLiteral(left, false)) {
validator.setValidatedNodeType(left, type);
return cx.convertExpression(left);
}
if (null != dataType.getCollectionsTypeName()) {
RelDataType argComponentType = arg.getType().getComponentType();
// arg.getType() may be ANY
if (argComponentType == null) {
argComponentType = dataType.getComponentTypeSpec().deriveType(validator);
}
requireNonNull(argComponentType, () -> "componentType of " + arg);
RelDataType typeFinal = type;
final RelDataType componentType =
requireNonNull(type.getComponentType(),
() -> "componentType of " + typeFinal);
if (argComponentType.isStruct()
&& !componentType.isStruct()) {
RelDataType tt =
typeFactory.builder()
.add(argComponentType.getFieldList().get(0).getName(),
componentType)
.build();
tt = typeFactory.createTypeWithNullability(tt, componentType.isNullable());
boolean isn = type.isNullable();
type = typeFactory.createMultisetType(tt, -1);
type = typeFactory.createTypeWithNullability(type, isn);
}
}
return rexBuilder.makeCast(type, arg, safe, safe, formatArg);
}
protected RexNode convertFloorCeil(SqlRexContext cx, SqlCall call) {
final boolean floor = call.getKind() == SqlKind.FLOOR;
// Rewrite floor, ceil of interval
if (call.operandCount() == 1
&& call.operand(0) instanceof SqlIntervalLiteral) {
final SqlIntervalLiteral literal = call.operand(0);
SqlIntervalLiteral.IntervalValue interval =
literal.getValueAs(SqlIntervalLiteral.IntervalValue.class);
BigDecimal val =
interval.getIntervalQualifier().getStartUnit().multiplier;
RexNode rexInterval = cx.convertExpression(literal);
final RexBuilder rexBuilder = cx.getRexBuilder();
RexNode zero = rexBuilder.makeExactLiteral(BigDecimal.valueOf(0));
RexNode cond = ge(rexBuilder, rexInterval, zero);
RexNode pad =
rexBuilder.makeExactLiteral(val.subtract(BigDecimal.ONE));
RexNode cast =
rexBuilder.makeReinterpretCast(rexInterval.getType(), pad,
rexBuilder.makeLiteral(false));
RexNode sum =
floor ? minus(rexBuilder, rexInterval, cast)
: plus(rexBuilder, rexInterval, cast);
RexNode kase = floor
? case_(rexBuilder, rexInterval, cond, sum)
: case_(rexBuilder, sum, cond, rexInterval);
RexNode factor = rexBuilder.makeExactLiteral(val);
RexNode div = divideInt(rexBuilder, kase, factor);
return multiply(rexBuilder, div, factor);
}
// normal floor, ceil function
return convertFunction(cx, (SqlFunction) call.getOperator(), call);
}
protected RexNode convertCharset(
@UnknownInitialization StandardConvertletTable this,
SqlRexContext cx, SqlCall call) {
final SqlNode expr = call.operand(0);
final String srcCharset = call.operand(1).toString();
final String destCharset = call.operand(2).toString();
final RexBuilder rexBuilder = cx.getRexBuilder();
return rexBuilder.makeCall(SqlStdOperatorTable.CONVERT,
cx.convertExpression(expr),
rexBuilder.makeLiteral(srcCharset),
rexBuilder.makeLiteral(destCharset));
}
protected RexNode translateCharset(
@UnknownInitialization StandardConvertletTable this,
SqlRexContext cx, SqlCall call) {
final SqlNode expr = call.operand(0);
final String transcodingName = call.operand(1).toString();
final RexBuilder rexBuilder = cx.getRexBuilder();
return rexBuilder.makeCall(SqlStdOperatorTable.TRANSLATE,
cx.convertExpression(expr),
rexBuilder.makeLiteral(transcodingName));
}
/**
* Converts a call to the {@code EXTRACT} function.
*
* <p>Called automatically via reflection.
*/
public RexNode convertExtract(
SqlRexContext cx,
SqlExtractFunction op,
SqlCall call) {
return convertFunction(cx, (SqlFunction) call.getOperator(), call);
}
@SuppressWarnings("unused")
private static RexNode mod(RexBuilder rexBuilder, RelDataType resType, RexNode res,
BigDecimal val) {
if (val.equals(BigDecimal.ONE)) {
return res;
}
return rexBuilder.makeCall(SqlStdOperatorTable.MOD, res,
rexBuilder.makeExactLiteral(val, resType));
}
private static RexNode divide(RexBuilder rexBuilder, RexNode res,
BigDecimal val) {
if (val.equals(BigDecimal.ONE)) {
return res;
}
// If val is between 0 and 1, rather than divide by val, multiply by its
// reciprocal. For example, rather than divide by 0.001 multiply by 1000.
if (val.compareTo(BigDecimal.ONE) < 0
&& val.signum() == 1) {
try {
final BigDecimal reciprocal =
BigDecimal.ONE.divide(val, RoundingMode.UNNECESSARY);
return multiply(rexBuilder, res,
rexBuilder.makeExactLiteral(reciprocal));
} catch (ArithmeticException e) {
// ignore - reciprocal is not an integer
}
}
return divideInt(rexBuilder, res, rexBuilder.makeExactLiteral(val));
}
public RexNode convertDatetimeMinus(
@UnknownInitialization StandardConvertletTable this,
SqlRexContext cx,
SqlDatetimeSubtractionOperator op,
SqlCall call) {
// Rewrite datetime minus
final RexBuilder rexBuilder = cx.getRexBuilder();
final List<RexNode> exprs =
convertOperands(cx, call, SqlOperandTypeChecker.Consistency.NONE);
final RelDataType resType =
cx.getValidator().getValidatedNodeType(call);
return rexBuilder.makeCall(resType, op, exprs.subList(0, 2));
}
public RexNode convertFunction(
SqlRexContext cx,
SqlFunction fun,
SqlCall call) {
final List<RexNode> exprs =
convertOperands(cx, call, SqlOperandTypeChecker.Consistency.NONE);
if (fun.getFunctionType() == SqlFunctionCategory.USER_DEFINED_CONSTRUCTOR) {
return makeConstructorCall(cx, fun, exprs);
}
RelDataType returnType =
cx.getValidator().getValidatedNodeTypeIfKnown(call);
if (returnType == null) {
returnType = cx.getRexBuilder().deriveReturnType(fun, exprs);
}
return cx.getRexBuilder().makeCall(returnType, fun, exprs);
}
public RexNode convertWindowFunction(
SqlRexContext cx,
SqlWindowTableFunction fun,
SqlCall call) {
// The first operand of window function is actually a query, skip that.
final List<SqlNode> operands = Util.skip(call.getOperandList());
final List<RexNode> exprs =
convertOperands(cx, call, operands,
SqlOperandTypeChecker.Consistency.NONE);
RelDataType returnType =
cx.getValidator().getValidatedNodeTypeIfKnown(call);
if (returnType == null) {
returnType = cx.getRexBuilder().deriveReturnType(fun, exprs);
}
return cx.getRexBuilder().makeCall(returnType, fun, exprs);
}
public RexNode convertJsonValueFunction(
SqlRexContext cx,
SqlJsonValueFunction fun,
SqlCall call) {
// For Expression with explicit return type:
// i.e. json_value('{"foo":"bar"}', 'lax $.foo', returning varchar(2000))
// use the specified type as the return type.
List<SqlNode> operands =
SqlJsonValueFunction.removeTypeSpecOperands(call);
final List<RexNode> exprs =
convertOperands(cx, call, operands,
SqlOperandTypeChecker.Consistency.NONE);
RelDataType returnType =
cx.getValidator().getValidatedNodeTypeIfKnown(call);
requireNonNull(returnType, () -> "Unable to get type of " + call);
return cx.getRexBuilder().makeCall(returnType, fun, exprs);
}
public RexNode convertSequenceValue(
SqlRexContext cx,
SqlSequenceValueOperator fun,
SqlCall call) {
final List<SqlNode> operands = call.getOperandList();
assert operands.size() == 1;
assert operands.get(0) instanceof SqlIdentifier;
final SqlIdentifier id = (SqlIdentifier) operands.get(0);
final String key = Util.listToString(id.names);
RelDataType returnType =
cx.getValidator().getValidatedNodeType(call);
return cx.getRexBuilder().makeCall(returnType, fun,
ImmutableList.of(cx.getRexBuilder().makeLiteral(key)));
}
public RexNode convertAggregateFunction(
SqlRexContext cx,
SqlAggFunction fun,
SqlCall call) {
final List<RexNode> exprs;
if (call.isCountStar()) {
exprs = ImmutableList.of();
} else {
exprs = convertOperands(cx, call, SqlOperandTypeChecker.Consistency.NONE);
}
RelDataType returnType =
cx.getValidator().getValidatedNodeTypeIfKnown(call);
final int groupCount = cx.getGroupCount();
if (returnType == null) {
RexCallBinding binding =
new RexCallBinding(cx.getTypeFactory(), fun, exprs,
ImmutableList.of()) {
@Override public int getGroupCount() {
return groupCount;
}
};
returnType = fun.inferReturnType(binding);
}
return cx.getRexBuilder().makeCall(returnType, fun, exprs);
}
private static RexNode makeConstructorCall(
SqlRexContext cx,
SqlFunction constructor,
List<RexNode> exprs) {
final RexBuilder rexBuilder = cx.getRexBuilder();
RelDataType type = rexBuilder.deriveReturnType(constructor, exprs);
int n = type.getFieldCount();
ImmutableList.Builder<RexNode> initializationExprs =
ImmutableList.builder();
final InitializerContext initializerContext = new InitializerContext() {
@Override public RexBuilder getRexBuilder() {
return rexBuilder;
}
@Override public SqlNode validateExpression(RelDataType rowType, SqlNode expr) {
throw new UnsupportedOperationException();
}
@Override public RexNode convertExpression(SqlNode e) {
throw new UnsupportedOperationException();
}
};
for (int i = 0; i < n; ++i) {
initializationExprs.add(
cx.getInitializerExpressionFactory().newAttributeInitializer(
type, constructor, i, exprs, initializerContext));
}
List<RexNode> defaultCasts =
RexUtil.generateCastExpressions(
rexBuilder,
type,
initializationExprs.build());
return rexBuilder.makeNewInvocation(type, defaultCasts);
}
private RexNode convertItem(
@UnknownInitialization StandardConvertletTable this,
SqlRexContext cx,
SqlCall call) {
final RexBuilder rexBuilder = cx.getRexBuilder();
final SqlOperator op = call.getOperator();
SqlOperandTypeChecker operandTypeChecker = op.getOperandTypeChecker();
final SqlOperandTypeChecker.Consistency consistency =
operandTypeChecker == null
? SqlOperandTypeChecker.Consistency.NONE
: operandTypeChecker.getConsistency();
final List<RexNode> exprs = convertOperands(cx, call, consistency);
final RelDataType collectionType = exprs.get(0).getType();
final boolean isRowTypeField = SqlTypeUtil.isRow(collectionType);
final boolean isNumericIndex = SqlTypeUtil.isIntType(exprs.get(1).getType());
if (isRowTypeField && isNumericIndex) {
final SqlOperatorBinding opBinding =
new RexCallBinding(cx.getTypeFactory(), op, exprs, ImmutableList.of());
final RelDataType operandType = opBinding.getOperandType(0);
final Integer index = opBinding.getOperandLiteralValue(1, Integer.class);
if (index == null || index < 1 || index > operandType.getFieldCount()) {
throw new AssertionError("Cannot access field at position "
+ index + " within ROW type: " + operandType);
} else {
RelDataTypeField relDataTypeField = collectionType.getFieldList().get(index - 1);
return rexBuilder.makeFieldAccess(
exprs.get(0), relDataTypeField.getName(), false);
}
}
RelDataType type = rexBuilder.deriveReturnType(op, exprs);
return rexBuilder.makeCall(type, op, RexUtil.flatten(exprs, op));
}
/**
* Converts a call to an operator into a {@link RexCall} to the same
* operator.
*
* <p>Called automatically via reflection.
*
* @param cx Context
* @param call Call
* @return Rex call
*/
public RexNode convertCall(
@UnknownInitialization StandardConvertletTable this,
SqlRexContext cx,
SqlCall call) {
final SqlOperator op = call.getOperator();
final RexBuilder rexBuilder = cx.getRexBuilder();
SqlOperandTypeChecker operandTypeChecker = op.getOperandTypeChecker();
final SqlOperandTypeChecker.Consistency consistency =
operandTypeChecker == null
? SqlOperandTypeChecker.Consistency.NONE
: operandTypeChecker.getConsistency();
final List<RexNode> exprs = convertOperands(cx, call, consistency);
RelDataType type = rexBuilder.deriveReturnType(op, exprs);
// Expand 'ROW (x0, x1, ...) = ROW (y0, y1, ...)'
// to 'x0 = y0 AND x1 = y1 AND ...'
if (op.kind == SqlKind.EQUALS) {
final RexNode expr0 = RexUtil.removeCast(exprs.get(0));
final RexNode expr1 = RexUtil.removeCast(exprs.get(1));
if (expr0.getKind() == SqlKind.ROW && expr1.getKind() == SqlKind.ROW) {
final RexCall call0 = (RexCall) expr0;
final RexCall call1 = (RexCall) expr1;
final List<RexNode> eqList = new ArrayList<>();
Pair.forEach(call0.getOperands(), call1.getOperands(), (x, y) ->
eqList.add(rexBuilder.makeCall(op, x, y)));
return RexUtil.composeConjunction(rexBuilder, eqList);
}
}
return rexBuilder.makeCall(type, op, RexUtil.flatten(exprs, op));
}
private static List<Integer> elseArgs(int count) {
// If list is odd, e.g. [0, 1, 2, 3, 4] we get [1, 3, 4]
// If list is even, e.g. [0, 1, 2, 3, 4, 5] we get [2, 4, 5]
final List<Integer> list = new ArrayList<>();
for (int i = count % 2;;) {
list.add(i);
i += 2;
if (i >= count) {
list.add(i - 1);
break;
}
}
return list;
}
private static List<RexNode> convertOperands(SqlRexContext cx,
SqlCall call, SqlOperandTypeChecker.Consistency consistency) {
List<SqlNode> operandList;
if (call.getOperator() instanceof SqlTableFunction) {
// skip set semantic table node of table function
operandList =
call.getOperandList().stream().filter(
operand -> operand.getKind() != SqlKind.SET_SEMANTICS_TABLE)
.collect(Collectors.toList());
} else {
operandList = call.getOperandList();
}
return convertOperands(cx, call, operandList, consistency);
}
private static List<RexNode> convertOperands(SqlRexContext cx,
SqlCall call, List<SqlNode> nodes,
SqlOperandTypeChecker.Consistency consistency) {
final List<RexNode> exprs = new ArrayList<>();
for (SqlNode node : nodes) {
exprs.add(cx.convertExpression(node));
}
final List<RelDataType> operandTypes =
cx.getValidator().getValidatedOperandTypes(call);
if (operandTypes != null) {
final List<RexNode> oldExprs = new ArrayList<>(exprs);
exprs.clear();
Pair.forEach(oldExprs, operandTypes, (expr, type) ->
exprs.add(cx.getRexBuilder().ensureType(type, expr, true)));
}
if (exprs.size() > 1) {
final RelDataType type =
consistentType(cx, consistency, RexUtil.types(exprs));
if (type != null) {
final List<RexNode> oldExprs = new ArrayList<>(exprs);
exprs.clear();
for (RexNode expr : oldExprs) {
exprs.add(cx.getRexBuilder().ensureType(type, expr, true));
}
}
}
return exprs;
}
private static @Nullable RelDataType consistentType(SqlRexContext cx,
SqlOperandTypeChecker.Consistency consistency, List<RelDataType> types) {
switch (consistency) {
case COMPARE:
if (SqlTypeUtil.areSameFamily(types)) {
// All arguments are of same family. No need for explicit casts.
return null;
}
final List<RelDataType> nonCharacterTypes = new ArrayList<>();
for (RelDataType type : types) {
if (type.getFamily() != SqlTypeFamily.CHARACTER) {
nonCharacterTypes.add(type);
}
}
if (!nonCharacterTypes.isEmpty()) {
final int typeCount = types.size();
types = nonCharacterTypes;
if (nonCharacterTypes.size() < typeCount) {
final RelDataTypeFamily family =
nonCharacterTypes.get(0).getFamily();
if (family instanceof SqlTypeFamily) {
// The character arguments might be larger than the numeric
// argument. Give ourselves some headroom.
switch ((SqlTypeFamily) family) {
case INTEGER:
case NUMERIC:
nonCharacterTypes.add(
cx.getTypeFactory().createSqlType(SqlTypeName.BIGINT));
break;
default:
break;
}
}
}
}
// fall through
case LEAST_RESTRICTIVE:
return cx.getTypeFactory().leastRestrictive(types);
default:
return null;
}
}
private RexNode convertPlus(
@UnknownInitialization StandardConvertletTable this,
SqlRexContext cx, SqlCall call) {
final RexNode rex = convertCall(cx, call);
switch (rex.getType().getSqlTypeName()) {
case DATE:
case TIME:
case TIMESTAMP:
// Use special "+" operator for datetime + interval.
// Re-order operands, if necessary, so that interval is second.
final RexBuilder rexBuilder = cx.getRexBuilder();
List<RexNode> operands = ((RexCall) rex).getOperands();
if (operands.size() == 2) {
final SqlTypeName sqlTypeName = operands.get(0).getType().getSqlTypeName();
switch (sqlTypeName) {
case INTERVAL_YEAR:
case INTERVAL_YEAR_MONTH:
case INTERVAL_MONTH:
case INTERVAL_DAY:
case INTERVAL_DAY_HOUR:
case INTERVAL_DAY_MINUTE:
case INTERVAL_DAY_SECOND:
case INTERVAL_HOUR:
case INTERVAL_HOUR_MINUTE:
case INTERVAL_HOUR_SECOND:
case INTERVAL_MINUTE:
case INTERVAL_MINUTE_SECOND:
case INTERVAL_SECOND:
operands = ImmutableList.of(operands.get(1), operands.get(0));
break;
default:
break;
}
}
return rexBuilder.makeCall(rex.getType(),
SqlStdOperatorTable.DATETIME_PLUS, operands);
default:
return rex;
}
}
private RexNode convertIsDistinctFrom(
@UnknownInitialization StandardConvertletTable this,
SqlRexContext cx,
SqlCall call,
boolean neg) {
RexNode op0 = cx.convertExpression(call.operand(0));
RexNode op1 = cx.convertExpression(call.operand(1));
return RelOptUtil.isDistinctFrom(
cx.getRexBuilder(), op0, op1, neg);
}
/**
* Converts a BETWEEN expression.
*
* <p>Called automatically via reflection.
*/
public RexNode convertBetween(
SqlRexContext cx,
SqlBetweenOperator op,
SqlCall call) {
SqlOperandTypeChecker operandTypeChecker = op.getOperandTypeChecker();
final SqlOperandTypeChecker.Consistency consistency =
operandTypeChecker == null
? SqlOperandTypeChecker.Consistency.NONE
: operandTypeChecker.getConsistency();
final List<RexNode> list =
convertOperands(cx, call,
consistency);
final RexNode x = list.get(SqlBetweenOperator.VALUE_OPERAND);
final RexNode y = list.get(SqlBetweenOperator.LOWER_OPERAND);
final RexNode z = list.get(SqlBetweenOperator.UPPER_OPERAND);
final RexBuilder rexBuilder = cx.getRexBuilder();
RexNode ge1 = ge(rexBuilder, x, y);
RexNode le1 = le(rexBuilder, x, z);
RexNode and1 = and(rexBuilder, ge1, le1);
RexNode res;
final SqlBetweenOperator.Flag symmetric = op.flag;
switch (symmetric) {
case ASYMMETRIC:
res = and1;
break;
case SYMMETRIC:
RexNode ge2 = ge(rexBuilder, x, z);
RexNode le2 = le(rexBuilder, x, y);
RexNode and2 = and(rexBuilder, ge2, le2);
res = or(rexBuilder, and1, and2);
break;
default:
throw Util.unexpected(symmetric);
}
final SqlBetweenOperator betweenOp =
(SqlBetweenOperator) call.getOperator();
if (betweenOp.isNegated()) {
res = rexBuilder.makeCall(SqlStdOperatorTable.NOT, res);
}
return res;
}
/**
* Converts a SUBSTRING expression.
*
* <p>Called automatically via reflection.
*/
public RexNode convertSubstring(
SqlRexContext cx,
SqlSubstringFunction op,
SqlCall call) {
final SqlLibrary library =
cx.getValidator().config().conformance().semantics();
final SqlBasicCall basicCall = (SqlBasicCall) call;
switch (library) {
case BIG_QUERY:
return toRex(cx, basicCall, SqlLibraryOperators.SUBSTR_BIG_QUERY);
case MYSQL:
return toRex(cx, basicCall, SqlLibraryOperators.SUBSTR_MYSQL);
case ORACLE:
return toRex(cx, basicCall, SqlLibraryOperators.SUBSTR_ORACLE);
case POSTGRESQL:
default:
return convertFunction(cx, op, call);
}
}
private RexNode toRex(SqlRexContext cx, SqlBasicCall call, SqlFunction f) {
final SqlCall call2 =
new SqlBasicCall(f, call.getOperandList(), call.getParserPosition());
final SqlRexConvertlet convertlet = requireNonNull(get(call2));
return convertlet.convertCall(cx, call2);
}
/**
* Converts a LiteralChain expression: that is, concatenates the operands
* immediately, to produce a single literal string.
*
* <p>Called automatically via reflection.
*/
public RexNode convertLiteralChain(
SqlRexContext cx,
SqlLiteralChainOperator op,
SqlCall call) {
Util.discard(cx);
SqlLiteral sum = SqlLiteralChainOperator.concatenateOperands(call);
return cx.convertLiteral(sum);
}
/**
* Converts a ROW.
*
* <p>Called automatically via reflection.
*/
public RexNode convertRow(
SqlRexContext cx,
SqlRowOperator op,
SqlCall call) {
if (cx.getValidator().getValidatedNodeType(call).getSqlTypeName()
!= SqlTypeName.COLUMN_LIST) {
return convertCall(cx, call);
}
final RexBuilder rexBuilder = cx.getRexBuilder();
final List<RexNode> columns = new ArrayList<>();
for (String operand : SqlIdentifier.simpleNames(call.getOperandList())) {
columns.add(rexBuilder.makeLiteral(operand));
}
final RelDataType type =
rexBuilder.deriveReturnType(SqlStdOperatorTable.COLUMN_LIST, columns);
return rexBuilder.makeCall(type, SqlStdOperatorTable.COLUMN_LIST, columns);
}
/**
* Converts a call to OVERLAPS.
*
* <p>Called automatically via reflection.
*/
public RexNode convertOverlaps(
SqlRexContext cx,
SqlOverlapsOperator op,
SqlCall call) {
// for intervals [t0, t1] overlaps [t2, t3], we can find if the
// intervals overlaps by: ~(t1 < t2 or t3 < t0)
assert call.getOperandList().size() == 2;
final Pair<RexNode, RexNode> left =
convertOverlapsOperand(cx, call.getParserPosition(), call.operand(0));
final RexNode r0 = left.left;
final RexNode r1 = left.right;
final Pair<RexNode, RexNode> right =
convertOverlapsOperand(cx, call.getParserPosition(), call.operand(1));
final RexNode r2 = right.left;
final RexNode r3 = right.right;
// Sort end points into start and end, such that (s0 <= e0) and (s1 <= e1).
final RexBuilder rexBuilder = cx.getRexBuilder();
RexNode leftSwap = le(rexBuilder, r0, r1);
final RexNode s0 = case_(rexBuilder, leftSwap, r0, r1);
final RexNode e0 = case_(rexBuilder, leftSwap, r1, r0);
RexNode rightSwap = le(rexBuilder, r2, r3);
final RexNode s1 = case_(rexBuilder, rightSwap, r2, r3);
final RexNode e1 = case_(rexBuilder, rightSwap, r3, r2);
// (e0 >= s1) AND (e1 >= s0)
switch (op.kind) {
case OVERLAPS:
return and(rexBuilder,
ge(rexBuilder, e0, s1),
ge(rexBuilder, e1, s0));
case CONTAINS:
return and(rexBuilder,
le(rexBuilder, s0, s1),
ge(rexBuilder, e0, e1));
case PERIOD_EQUALS:
return and(rexBuilder,
eq(rexBuilder, s0, s1),
eq(rexBuilder, e0, e1));
case PRECEDES:
return le(rexBuilder, e0, s1);
case IMMEDIATELY_PRECEDES:
return eq(rexBuilder, e0, s1);
case SUCCEEDS:
return ge(rexBuilder, s0, e1);
case IMMEDIATELY_SUCCEEDS:
return eq(rexBuilder, s0, e1);
default:
throw new AssertionError(op);
}
}
private static Pair<RexNode, RexNode> convertOverlapsOperand(SqlRexContext cx,
SqlParserPos pos, SqlNode operand) {
final SqlNode a0;
final SqlNode a1;
switch (operand.getKind()) {
case ROW:
a0 = ((SqlCall) operand).operand(0);
final SqlNode a10 = ((SqlCall) operand).operand(1);
final RelDataType t1 = cx.getValidator().getValidatedNodeType(a10);
if (SqlTypeUtil.isInterval(t1)) {
// make t1 = t0 + t1 when t1 is an interval.
a1 = plus(pos, a0, a10);
} else {
a1 = a10;
}
break;
default:
a0 = operand;
a1 = operand;
}
final RexNode r0 = cx.convertExpression(a0);
final RexNode r1 = cx.convertExpression(a1);
return Pair.of(r0, r1);
}
@Deprecated // to be removed before 2.0
public RexNode castToValidatedType(
@UnknownInitialization StandardConvertletTable this,
SqlRexContext cx,
SqlCall call,
RexNode value) {
return castToValidatedType(call, value, cx.getValidator(),
cx.getRexBuilder(), false);
}
@Deprecated // to be removed before 2.0
public static RexNode castToValidatedType(SqlNode node, RexNode e,
SqlValidator validator, RexBuilder rexBuilder) {
return castToValidatedType(node, e, validator, rexBuilder, false);
}
/**
* Casts a RexNode value to the validated type of a SqlCall. If the value
* was already of the validated type, then the value is returned without an
* additional cast.
*/
public static RexNode castToValidatedType(SqlNode node, RexNode e,
SqlValidator validator, RexBuilder rexBuilder, boolean safe) {
final RelDataType type = validator.getValidatedNodeType(node);
if (e.getType() == type) {
return e;
}
return rexBuilder.makeCast(type, e, safe, safe);
}
/** Convertlet that handles {@code COVAR_POP}, {@code COVAR_SAMP},
* {@code REGR_SXX}, {@code REGR_SYY} windowed aggregate functions.
*/
private static class RegrCovarianceConvertlet implements SqlRexConvertlet {
private final SqlKind kind;
RegrCovarianceConvertlet(SqlKind kind) {
this.kind = kind;
}
@Override public RexNode convertCall(SqlRexContext cx, SqlCall call) {
assert call.operandCount() == 2;
final SqlNode arg1 = call.operand(0);
final SqlNode arg2 = call.operand(1);
final SqlNode expr;
final RelDataType type =
cx.getValidator().getValidatedNodeType(call);
switch (kind) {
case COVAR_POP:
expr = expandCovariance(arg1, arg2, null, type, cx, true);
break;
case COVAR_SAMP:
expr = expandCovariance(arg1, arg2, null, type, cx, false);
break;
case REGR_SXX:
expr = expandRegrSzz(arg2, arg1, type, cx, true);
break;
case REGR_SYY:
expr = expandRegrSzz(arg1, arg2, type, cx, true);
break;
default:
throw Util.unexpected(kind);
}
RexNode rex = cx.convertExpression(expr);
return cx.getRexBuilder().ensureType(type, rex, true);
}
private static SqlNode expandRegrSzz(
final SqlNode arg1, final SqlNode arg2,
final RelDataType avgType, final SqlRexContext cx, boolean variance) {
final SqlParserPos pos = SqlParserPos.ZERO;
final SqlNode count =
SqlStdOperatorTable.REGR_COUNT.createCall(pos, arg1, arg2);
final SqlNode varPop =
expandCovariance(arg1, variance ? arg1 : arg2, arg2, avgType, cx, true);
final RexNode varPopRex = cx.convertExpression(varPop);
final SqlNode varPopCast;
varPopCast = getCastedSqlNode(varPop, avgType, pos, varPopRex);
return SqlStdOperatorTable.MULTIPLY.createCall(pos, varPopCast, count);
}
private static SqlNode expandCovariance(
final SqlNode arg0Input,
final SqlNode arg1Input,
final @Nullable SqlNode dependent,
final RelDataType varType,
final SqlRexContext cx,
boolean biased) {
// covar_pop(x1, x2) ==>
// (sum(x1 * x2) - sum(x2) * sum(x1) / count(x1, x2))
// / count(x1, x2)
//
// covar_samp(x1, x2) ==>
// (sum(x1 * x2) - sum(x1) * sum(x2) / count(x1, x2))
// / (count(x1, x2) - 1)
final SqlParserPos pos = SqlParserPos.ZERO;
final SqlLiteral nullLiteral = SqlLiteral.createNull(SqlParserPos.ZERO);
final RexNode arg0Rex = cx.convertExpression(arg0Input);
final RexNode arg1Rex = cx.convertExpression(arg1Input);
final SqlNode arg0 = getCastedSqlNode(arg0Input, varType, pos, arg0Rex);
final SqlNode arg1 = getCastedSqlNode(arg1Input, varType, pos, arg1Rex);
final SqlNode argSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, arg0, arg1);
final SqlNode sumArgSquared;
final SqlNode sum0;
final SqlNode sum1;
final SqlNode count;
if (dependent == null) {
sumArgSquared = SqlStdOperatorTable.SUM.createCall(pos, argSquared);
sum0 = SqlStdOperatorTable.SUM.createCall(pos, arg0, arg1);
sum1 = SqlStdOperatorTable.SUM.createCall(pos, arg1, arg0);
count = SqlStdOperatorTable.REGR_COUNT.createCall(pos, arg0, arg1);
} else {
sumArgSquared = SqlStdOperatorTable.SUM.createCall(pos, argSquared, dependent);
sum0 =
SqlStdOperatorTable.SUM.createCall(pos, arg0,
Objects.equals(dependent, arg0Input) ? arg1 : dependent);
sum1 =
SqlStdOperatorTable.SUM.createCall(pos, arg1,
Objects.equals(dependent, arg1Input) ? arg0 : dependent);
count =
SqlStdOperatorTable.REGR_COUNT.createCall(pos, arg0,
Objects.equals(dependent, arg0Input) ? arg1 : dependent);
}
final SqlNode sumSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, sum0, sum1);
final SqlNode countCasted =
getCastedSqlNode(count, varType, pos, cx.convertExpression(count));
final SqlNode avgSumSquared =
SqlStdOperatorTable.DIVIDE.createCall(pos, sumSquared, countCasted);
final SqlNode diff =
SqlStdOperatorTable.MINUS.createCall(pos, sumArgSquared, avgSumSquared);
SqlNode denominator;
if (biased) {
denominator = countCasted;
} else {
final SqlNumericLiteral one = SqlLiteral.createExactNumeric("1", pos);
denominator =
new SqlCase(SqlParserPos.ZERO, countCasted,
SqlNodeList.of(
SqlStdOperatorTable.EQUALS.createCall(pos, countCasted, one)),
SqlNodeList.of(getCastedSqlNode(nullLiteral, varType, pos, null)),
SqlStdOperatorTable.MINUS.createCall(pos, countCasted, one));
}
return SqlStdOperatorTable.DIVIDE.createCall(pos, diff, denominator);
}
private static SqlNode getCastedSqlNode(SqlNode argInput,
RelDataType varType, SqlParserPos pos, @Nullable RexNode argRex) {
if (argRex == null || argRex.getType().equals(varType)) {
return argInput;
}
return SqlStdOperatorTable.CAST.createCall(pos, argInput,
SqlTypeUtil.convertTypeToSpec(varType));
}
}
/** Convertlet that handles {@code AVG} and {@code VARIANCE}
* windowed aggregate functions. */
private static class AvgVarianceConvertlet implements SqlRexConvertlet {
private final SqlKind kind;
AvgVarianceConvertlet(SqlKind kind) {
this.kind = kind;
}
@Override public RexNode convertCall(SqlRexContext cx, SqlCall call) {
assert call.operandCount() == 1;
final SqlNode arg = call.operand(0);
final SqlNode expr;
final RelDataType type =
cx.getValidator().getValidatedNodeType(call);
switch (kind) {
case AVG:
expr = expandAvg(arg, type, cx);
break;
case STDDEV_POP:
expr = expandVariance(arg, type, cx, true, true);
break;
case STDDEV_SAMP:
expr = expandVariance(arg, type, cx, false, true);
break;
case VAR_POP:
expr = expandVariance(arg, type, cx, true, false);
break;
case VAR_SAMP:
expr = expandVariance(arg, type, cx, false, false);
break;
default:
throw Util.unexpected(kind);
}
RexNode rex = cx.convertExpression(expr);
return cx.getRexBuilder().ensureType(type, rex, true);
}
private static SqlNode expandAvg(
final SqlNode arg, final RelDataType avgType, final SqlRexContext cx) {
final SqlParserPos pos = SqlParserPos.ZERO;
final SqlNode sum =
SqlStdOperatorTable.SUM.createCall(pos, arg);
final RexNode sumRex = cx.convertExpression(sum);
final SqlNode sumCast;
sumCast = getCastedSqlNode(sum, avgType, pos, sumRex);
final SqlNode count =
SqlStdOperatorTable.COUNT.createCall(pos, arg);
return SqlStdOperatorTable.DIVIDE.createCall(
pos, sumCast, count);
}
private static SqlNode expandVariance(
final SqlNode argInput,
final RelDataType varType,
final SqlRexContext cx,
boolean biased,
boolean sqrt) {
// stddev_pop(x) ==>
// power(
// (sum(x * x) - sum(x) * sum(x) / count(x))
// / count(x),
// .5)
//
// stddev_samp(x) ==>
// power(
// (sum(x * x) - sum(x) * sum(x) / count(x))
// / (count(x) - 1),
// .5)
//
// var_pop(x) ==>
// (sum(x * x) - sum(x) * sum(x) / count(x))
// / count(x)
//
// var_samp(x) ==>
// (sum(x * x) - sum(x) * sum(x) / count(x))
// / (count(x) - 1)
final SqlParserPos pos = SqlParserPos.ZERO;
final SqlNode arg =
getCastedSqlNode(argInput, varType, pos,
cx.convertExpression(argInput));
final SqlNode argSquared =
SqlStdOperatorTable.MULTIPLY.createCall(pos, arg, arg);
final SqlNode argSquaredCasted =
getCastedSqlNode(argSquared, varType, pos,
cx.convertExpression(argSquared));
final SqlNode sumArgSquared =
SqlStdOperatorTable.SUM.createCall(pos, argSquaredCasted);
final SqlNode sumArgSquaredCasted =
getCastedSqlNode(sumArgSquared, varType, pos,
cx.convertExpression(sumArgSquared));
final SqlNode sum = SqlStdOperatorTable.SUM.createCall(pos, arg);
final SqlNode sumCasted =
getCastedSqlNode(sum, varType, pos, cx.convertExpression(sum));
final SqlNode sumSquared =
SqlStdOperatorTable.MULTIPLY.createCall(pos, sumCasted, sumCasted);
final SqlNode sumSquaredCasted =
getCastedSqlNode(sumSquared, varType, pos,
cx.convertExpression(sumSquared));
final SqlNode count = SqlStdOperatorTable.COUNT.createCall(pos, arg);
final SqlNode countCasted =
getCastedSqlNode(count, varType, pos, cx.convertExpression(count));
final SqlNode avgSumSquared =
SqlStdOperatorTable.DIVIDE.createCall(pos, sumSquaredCasted,
countCasted);
final SqlNode avgSumSquaredCasted =
getCastedSqlNode(avgSumSquared, varType, pos,
cx.convertExpression(avgSumSquared));
final SqlNode diff =
SqlStdOperatorTable.MINUS.createCall(pos, sumArgSquaredCasted,
avgSumSquaredCasted);
final SqlNode diffCasted =
getCastedSqlNode(diff, varType, pos, cx.convertExpression(diff));
final SqlNode denominator;
if (biased) {
denominator = countCasted;
} else {
final SqlNumericLiteral one = SqlLiteral.createExactNumeric("1", pos);
final SqlLiteral nullLiteral = SqlLiteral.createNull(SqlParserPos.ZERO);
denominator =
new SqlCase(SqlParserPos.ZERO, count,
SqlNodeList.of(
SqlStdOperatorTable.EQUALS.createCall(pos, count, one)),
SqlNodeList.of(getCastedSqlNode(nullLiteral, varType, pos, null)),
SqlStdOperatorTable.MINUS.createCall(pos, count, one));
}
final SqlNode div =
SqlStdOperatorTable.DIVIDE.createCall(pos, diffCasted, denominator);
final SqlNode divCasted =
getCastedSqlNode(div, varType, pos, cx.convertExpression(div));
SqlNode result = div;
if (sqrt) {
final SqlNumericLiteral half = SqlLiteral.createExactNumeric("0.5", pos);
result = SqlStdOperatorTable.POWER.createCall(pos, divCasted, half);
}
return result;
}
private static SqlNode getCastedSqlNode(SqlNode argInput,
RelDataType varType, SqlParserPos pos, @Nullable RexNode argRex) {
if (argRex == null || argRex.getType().equals(varType)) {
return argInput;
}
return SqlStdOperatorTable.CAST.createCall(pos, argInput,
SqlTypeUtil.convertTypeToSpec(varType));
}
}
/** Convertlet that converts {@code LTRIM} and {@code RTRIM} to
* {@code TRIM}. */
private static class TrimConvertlet implements SqlRexConvertlet {
private final SqlTrimFunction.Flag flag;
TrimConvertlet(SqlTrimFunction.Flag flag) {
this.flag = flag;
}
@Override public RexNode convertCall(SqlRexContext cx, SqlCall call) {
final RexBuilder rexBuilder = cx.getRexBuilder();
final RexNode operand =
cx.convertExpression(call.getOperandList().get(0));
return rexBuilder.makeCall(SqlStdOperatorTable.TRIM,
rexBuilder.makeFlag(flag), rexBuilder.makeLiteral(" "), operand);
}
}
/** Convertlet that converts {@code GREATEST} and {@code LEAST}. */
private static class GreatestConvertlet implements SqlRexConvertlet {
@Override public RexNode convertCall(SqlRexContext cx, SqlCall call) {
// Translate
// GREATEST(a, b, c, d)
// to
// CASE
// WHEN a IS NULL OR b IS NULL OR c IS NULL OR d IS NULL
// THEN NULL
// WHEN a > b AND a > c AND a > d
// THEN a
// WHEN b > c AND b > d
// THEN b
// WHEN c > d
// THEN c
// ELSE d
// END
final RexBuilder rexBuilder = cx.getRexBuilder();
final RelDataType type =
cx.getValidator().getValidatedNodeType(call);
final SqlBinaryOperator op;
switch (call.getKind()) {
case GREATEST:
op = SqlStdOperatorTable.GREATER_THAN;
break;
case LEAST:
op = SqlStdOperatorTable.LESS_THAN;
break;
default:
throw new AssertionError();
}
final List<RexNode> exprs =
convertOperands(cx, call, SqlOperandTypeChecker.Consistency.NONE);
final List<RexNode> list = new ArrayList<>();
final List<RexNode> orList = new ArrayList<>();
for (RexNode expr : exprs) {
orList.add(rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, expr));
}
list.add(RexUtil.composeDisjunction(rexBuilder, orList));
list.add(rexBuilder.makeNullLiteral(type));
for (int i = 0; i < exprs.size() - 1; i++) {
RexNode expr = exprs.get(i);
final List<RexNode> andList = new ArrayList<>();
for (int j = i + 1; j < exprs.size(); j++) {
final RexNode expr2 = exprs.get(j);
andList.add(rexBuilder.makeCall(op, expr, expr2));
}
list.add(RexUtil.composeConjunction(rexBuilder, andList));
list.add(expr);
}
list.add(exprs.get(exprs.size() - 1));
return rexBuilder.makeCall(type, SqlStdOperatorTable.CASE, list);
}
}
/** Convertlet that handles {@code FLOOR} and {@code CEIL} functions. */
private class FloorCeilConvertlet implements SqlRexConvertlet {
@Override public RexNode convertCall(SqlRexContext cx, SqlCall call) {
return convertFloorCeil(cx, call);
}
}
/** Convertlet that handles the {@code SUBSTR} function; various dialects
* have slightly different specifications. PostgreSQL seems to comply with
* the ISO standard for the {@code SUBSTRING} function, and therefore
* Calcite's default behavior matches PostgreSQL. */
private static class SubstrConvertlet implements SqlRexConvertlet {
private final SqlLibrary library;
SubstrConvertlet(SqlLibrary library) {
this.library = library;
checkArgument(library == SqlLibrary.ORACLE
|| library == SqlLibrary.MYSQL
|| library == SqlLibrary.BIG_QUERY
|| library == SqlLibrary.POSTGRESQL);
}
@Override public RexNode convertCall(SqlRexContext cx, SqlCall call) {
// Translate
// SUBSTR(value, start, length)
//
// to the following if we want PostgreSQL semantics:
// SUBSTRING(value, start, length)
//
// to the following if we want Oracle semantics:
// SUBSTRING(
// value
// FROM CASE
// WHEN start = 0
// THEN 1
// WHEN start + (length(value) + 1) < 1
// THEN length(value) + 1
// WHEN start < 0
// THEN start + (length(value) + 1)
// ELSE start)
// FOR CASE WHEN length < 0 THEN 0 ELSE length END)
//
// to the following in MySQL:
// SUBSTRING(
// value
// FROM CASE
// WHEN start = 0
// THEN length(value) + 1 -- different from Oracle
// WHEN start + (length(value) + 1) < 1
// THEN length(value) + 1
// WHEN start < 0
// THEN start + length(value) + 1
// ELSE start)
// FOR CASE WHEN length < 0 THEN 0 ELSE length END)
//
// to the following if we want BigQuery semantics:
// CASE
// WHEN start + (length(value) + 1) < 1
// THEN value
// ELSE SUBSTRING(
// value
// FROM CASE
// WHEN start = 0
// THEN 1
// WHEN start < 0
// THEN start + length(value) + 1
// ELSE start)
// FOR CASE WHEN length < 0 THEN 0 ELSE length END)
final RexBuilder rexBuilder = cx.getRexBuilder();
final List<RexNode> exprs =
convertOperands(cx, call, SqlOperandTypeChecker.Consistency.NONE);
final RexNode value = exprs.get(0);
final RexNode start = exprs.get(1);
final RelDataType startType = start.getType();
final RexLiteral zeroLiteral = rexBuilder.makeLiteral(0, startType);
final RexLiteral oneLiteral = rexBuilder.makeLiteral(1, startType);
final RexNode valueLength =
SqlTypeUtil.isBinary(value.getType())
? rexBuilder.makeCall(SqlStdOperatorTable.OCTET_LENGTH, value)
: rexBuilder.makeCall(SqlStdOperatorTable.CHAR_LENGTH, value);
final RexNode valueLengthPlusOne =
rexBuilder.makeCall(SqlStdOperatorTable.PLUS, valueLength,
oneLiteral);
final RexNode newStart;
switch (library) {
case POSTGRESQL:
if (call.operandCount() == 2) {
newStart =
rexBuilder.makeCall(SqlStdOperatorTable.CASE,
rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, start,
oneLiteral),
oneLiteral, start);
} else {
newStart = start;
}
break;
case BIG_QUERY:
newStart =
rexBuilder.makeCall(SqlStdOperatorTable.CASE,
rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, start,
zeroLiteral),
oneLiteral,
rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, start,
zeroLiteral),
rexBuilder.makeCall(SqlStdOperatorTable.PLUS, start,
valueLengthPlusOne),
start);
break;
default:
newStart =
rexBuilder.makeCall(SqlStdOperatorTable.CASE,
rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, start,
zeroLiteral),
library == SqlLibrary.MYSQL ? valueLengthPlusOne : oneLiteral,
rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN,
rexBuilder.makeCall(SqlStdOperatorTable.PLUS, start,
valueLengthPlusOne),
oneLiteral),
valueLengthPlusOne,
rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, start,
zeroLiteral),
rexBuilder.makeCall(SqlStdOperatorTable.PLUS, start,
valueLengthPlusOne),
start);
break;
}
if (call.operandCount() == 2) {
return rexBuilder.makeCall(SqlStdOperatorTable.SUBSTRING, value,
newStart);
}
assert call.operandCount() == 3;
final RexNode length = exprs.get(2);
final RexNode newLength;
switch (library) {
case POSTGRESQL:
newLength = length;
break;
default:
newLength =
rexBuilder.makeCall(SqlStdOperatorTable.CASE,
rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, length,
zeroLiteral),
zeroLiteral, length);
}
final RexNode substringCall =
rexBuilder.makeCall(SqlStdOperatorTable.SUBSTRING, value, newStart,
newLength);
switch (library) {
case BIG_QUERY:
return rexBuilder.makeCall(SqlStdOperatorTable.CASE,
rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN,
rexBuilder.makeCall(SqlStdOperatorTable.PLUS, start,
valueLengthPlusOne), oneLiteral),
value, substringCall);
default:
return substringCall;
}
}
}
/** Convertlet that handles the 3-argument {@code TIMESTAMPADD} function
* and the 2-argument BigQuery-style {@code TIMESTAMP_ADD} function. */
private static class TimestampAddConvertlet implements SqlRexConvertlet {
@Override public RexNode convertCall(SqlRexContext cx, SqlCall call) {
// TIMESTAMPADD(unit, count, timestamp)
// => timestamp + count * INTERVAL '1' UNIT
// TIMESTAMP_ADD(timestamp, interval)
// => timestamp + interval
// "timestamp" may be of type TIMESTAMP or TIMESTAMP WITH LOCAL TIME ZONE.
final RexBuilder rexBuilder = cx.getRexBuilder();
SqlIntervalQualifier qualifier;
final RexNode op1;
final RexNode op2;
switch (call.operandCount()) {
case 2:
// BigQuery-style 'TIMESTAMP_ADD(timestamp, interval)'
final SqlBasicCall operandCall = call.operand(1);
qualifier = operandCall.operand(1);
op1 = cx.convertExpression(operandCall.operand(0));
op2 = cx.convertExpression(call.operand(0));
break;
default:
// JDBC-style 'TIMESTAMPADD(unit, count, timestamp)'
qualifier = call.operand(0);
op1 = cx.convertExpression(call.operand(1));
op2 = cx.convertExpression(call.operand(2));
}
final TimeFrame timeFrame = cx.getValidator().validateTimeFrame(qualifier);
final TimeUnit unit = first(timeFrame.unit(), TimeUnit.EPOCH);
final RelDataType type = cx.getValidator().getValidatedNodeType(call);
if (unit == TimeUnit.EPOCH && qualifier.timeFrameName != null) {
// Custom time frames have a different path. They are kept as names,
// and then handled by Java functions such as
// SqlFunctions.customTimestampAdd.
final RexLiteral timeFrameName =
rexBuilder.makeLiteral(qualifier.timeFrameName);
// If the TIMESTAMPADD call has type TIMESTAMP and op2 has type DATE
// (which can happen for sub-day time frames such as HOUR), cast op2 to
// TIMESTAMP.
final RexNode op2b = rexBuilder.makeCast(type, op2);
return rexBuilder.makeCall(type, SqlStdOperatorTable.TIMESTAMP_ADD,
ImmutableList.of(timeFrameName, op1, op2b));
}
if (qualifier.getUnit() != unit) {
qualifier =
new SqlIntervalQualifier(unit, null,
qualifier.getParserPosition());
}
RexNode interval2Add;
switch (unit) {
case MICROSECOND:
case NANOSECOND:
interval2Add =
divide(rexBuilder,
multiply(rexBuilder,
rexBuilder.makeIntervalLiteral(BigDecimal.ONE, qualifier), op1),
BigDecimal.ONE.divide(unit.multiplier,
RoundingMode.UNNECESSARY));
break;
default:
interval2Add =
multiply(rexBuilder,
rexBuilder.makeIntervalLiteral(unit.multiplier, qualifier), op1);
}
return rexBuilder.makeCall(SqlStdOperatorTable.DATETIME_PLUS,
op2, interval2Add);
}
}
/** Convertlet that handles the BigQuery {@code DATETIME_TRUNC} and
* {@code TIMESTAMP_TRUNC} functions. Ensures that DATE operands are
* cast to TIMESTAMPs to match the expected return type for BigQuery. */
private static class TruncConvertlet implements SqlRexConvertlet {
@Override public RexNode convertCall(SqlRexContext cx, SqlCall call) {
final RexBuilder rexBuilder = cx.getRexBuilder();
RexNode op1 = cx.convertExpression(call.operand(0));
RexNode op2 = cx.convertExpression(call.operand(1));
if (op1.getType().getSqlTypeName() == SqlTypeName.DATE) {
RelDataType type = cx.getValidator().getValidatedNodeType(call);
op1 = cx.getRexBuilder().makeCast(type, op1);
}
return rexBuilder.makeCall(call.getOperator(), op1, op2);
}
}
/** Convertlet that handles the BigQuery {@code TIMESTAMP_SUB} function. */
private static class TimestampSubConvertlet implements SqlRexConvertlet {
@Override public RexNode convertCall(SqlRexContext cx, SqlCall call) {
// TIMESTAMP_SUB(timestamp, interval)
// => timestamp - count * INTERVAL '1' UNIT
final RexBuilder rexBuilder = cx.getRexBuilder();
final SqlBasicCall operandCall = call.operand(1);
SqlIntervalQualifier qualifier = operandCall.operand(1);
final RexNode op1 = cx.convertExpression(operandCall.operand(0));
final RexNode op2 = cx.convertExpression(call.operand(0));
final TimeFrame timeFrame = cx.getValidator().validateTimeFrame(qualifier);
final TimeUnit unit = first(timeFrame.unit(), TimeUnit.EPOCH);
final RexNode interval2Sub;
switch (unit) {
// Fractional second units are converted to seconds using their
// associated multiplier.
case MICROSECOND:
case NANOSECOND:
interval2Sub =
divide(rexBuilder,
multiply(rexBuilder,
rexBuilder.makeIntervalLiteral(BigDecimal.ONE, qualifier), op1),
BigDecimal.ONE.divide(unit.multiplier,
RoundingMode.UNNECESSARY));
break;
default:
interval2Sub =
multiply(rexBuilder,
rexBuilder.makeIntervalLiteral(unit.multiplier, qualifier), op1);
}
return rexBuilder.makeCall(SqlInternalOperators.MINUS_DATE2,
op2, interval2Sub);
}
}
/** Convertlet that handles the {@code TIMESTAMPDIFF} function. */
private static class TimestampDiffConvertlet implements SqlRexConvertlet {
@Override public RexNode convertCall(SqlRexContext cx, SqlCall call) {
// The standard TIMESTAMPDIFF and BigQuery's TIMESTAMP_DIFF have two key
// differences. The first being the order of the subtraction, outlined
// below. The second is that BigQuery truncates each timestamp to the
// specified time unit before the difference is computed.
//
// In fact, all BigQuery functions (TIMESTAMP_DIFF, DATETIME_DIFF,
// DATE_DIFF) truncate before subtracting when applied to date intervals
// (DAY, WEEK, ISOWEEK, MONTH, YEAR, etc.)
//
// For example, if computing the number of weeks between two timestamps,
// one occurring on a Saturday and the other occurring the next day on
// Sunday, their week difference is 1. This is because the first timestamp
// is truncated to the previous Sunday. This is done by making calls to
// TIMESTAMP_TRUNC and the difference is then computed using their
// results.
//
// TIMESTAMPDIFF(unit, t1, t2)
// => (t2 - t1) UNIT
// TIMESTAMP_DIFF(t1, t2, unit)
// => (t1 - t2) UNIT
SqlIntervalQualifier qualifier;
final boolean preTruncate;
final RexNode op1;
final RexNode op2;
if (call.operand(0).getKind() == SqlKind.INTERVAL_QUALIFIER) {
qualifier = call.operand(0);
preTruncate = false;
op1 = cx.convertExpression(call.operand(1));
op2 = cx.convertExpression(call.operand(2));
} else {
qualifier = call.operand(2);
preTruncate = qualifier.isDate();
op1 = cx.convertExpression(call.operand(1));
op2 = cx.convertExpression(call.operand(0));
}
final RexBuilder rexBuilder = cx.getRexBuilder();
final RelDataTypeFactory typeFactory = cx.getTypeFactory();
final TimeFrame timeFrame = cx.getValidator().validateTimeFrame(qualifier);
final TimeUnit unit = first(timeFrame.unit(), TimeUnit.EPOCH);
UnaryOperator<RexNode> truncateFn = UnaryOperator.identity();
if (unit == TimeUnit.EPOCH && qualifier.timeFrameName != null) {
// Custom time frames have a different path. They are kept as names, and
// then handled by Java functions.
final RexLiteral timeFrameName =
rexBuilder.makeLiteral(qualifier.timeFrameName);
// This additional logic accounts for BigQuery truncating prior to
// computing the difference.
if (preTruncate) {
truncateFn = e ->
rexBuilder.makeCall(e.getType(),
SqlLibraryOperators.TIMESTAMP_TRUNC,
ImmutableList.of(e, timeFrameName));
}
return rexBuilder.makeCall(cx.getValidator().getValidatedNodeType(call),
SqlStdOperatorTable.TIMESTAMP_DIFF,
ImmutableList.of(timeFrameName, truncateFn.apply(op1),
truncateFn.apply(op2)));
}
if (preTruncate) {
// The timestamps should be truncated unless the time unit is HOUR, in
// which case only the whole number of hours between the timestamps
// should be returned.
final RexNode timeUnit = cx.convertExpression(qualifier);
truncateFn = e ->
rexBuilder.makeCall(e.getType(),
SqlLibraryOperators.TIMESTAMP_TRUNC,
ImmutableList.of(e, timeUnit));
}
BigDecimal multiplier = BigDecimal.ONE;
BigDecimal divider = BigDecimal.ONE;
SqlTypeName sqlTypeName = unit == TimeUnit.NANOSECOND
? SqlTypeName.BIGINT
: SqlTypeName.INTEGER;
switch (unit) {
case MICROSECOND:
case MILLISECOND:
case NANOSECOND:
case WEEK:
multiplier = BigDecimal.valueOf(DateTimeUtils.MILLIS_PER_SECOND);
divider = unit.multiplier;
qualifier =
new SqlIntervalQualifier(TimeUnit.SECOND, null,
qualifier.getParserPosition());
break;
case QUARTER:
case CENTURY:
case MILLENNIUM:
divider = unit.multiplier;
qualifier =
new SqlIntervalQualifier(TimeUnit.MONTH, null,
qualifier.getParserPosition());
break;
default:
qualifier =
new SqlIntervalQualifier(unit, null,
qualifier.getParserPosition());
break;
}
final RelDataType intervalType =
typeFactory.createTypeWithNullability(
typeFactory.createSqlIntervalType(qualifier),
op1.getType().isNullable() || op2.getType().isNullable());
final RexNode call2 =
rexBuilder.makeCall(intervalType,
SqlStdOperatorTable.MINUS_DATE,
ImmutableList.of(truncateFn.apply(op2), truncateFn.apply(op1)));
final RelDataType intType =
typeFactory.createTypeWithNullability(
typeFactory.createSqlType(sqlTypeName),
SqlTypeUtil.containsNullable(call2.getType()));
RexNode e = rexBuilder.makeCast(intType, call2);
return rexBuilder.multiplyDivide(e, multiplier, divider);
}
}
}