/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.planner.codegen
import org.apache.flink.streaming.api.functions.ProcessFunction
import org.apache.flink.table.api.TableException
import org.apache.flink.table.data.RowData
import org.apache.flink.table.data.binary.BinaryRowData
import org.apache.flink.table.data.util.DataFormatConverters.{DataFormatConverter, getConverterForDataType}
import org.apache.flink.table.functions.{BuiltInFunctionDefinitions, FunctionDefinition}
import org.apache.flink.table.planner.calcite.{FlinkRexBuilder, FlinkTypeFactory, RexDistinctKeyVariable, RexFieldVariable}
import org.apache.flink.table.planner.codegen.CodeGenUtils.{requireTemporal, requireTimeInterval, _}
import org.apache.flink.table.planner.codegen.GenerateUtils._
import org.apache.flink.table.planner.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE}
import org.apache.flink.table.planner.codegen.calls.ScalarOperatorGens._
import org.apache.flink.table.planner.codegen.calls._
import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction
import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable._
import org.apache.flink.table.planner.functions.sql.SqlThrowExceptionFunction
import org.apache.flink.table.planner.functions.utils.{ScalarSqlFunction, TableSqlFunction}
import org.apache.flink.table.planner.plan.utils.FlinkRexUtil
import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromLogicalTypeToDataType
import org.apache.flink.table.runtime.types.PlannerTypeUtils.isInteroperable
import org.apache.flink.table.runtime.typeutils.TypeCheckUtils
import org.apache.flink.table.runtime.typeutils.TypeCheckUtils.{isNumeric, isTemporal, isTimeInterval}
import org.apache.flink.table.types.logical._
import org.apache.flink.table.types.logical.utils.LogicalTypeChecks.{getFieldCount, isCompositeType}
import org.apache.flink.table.typeutils.TimeIndicatorTypeInfo
import org.apache.calcite.rex._
import org.apache.calcite.sql.{SqlKind, SqlOperator}
import org.apache.calcite.sql.`type`.{ReturnTypes, SqlTypeName}
import org.apache.calcite.util.{Sarg, TimestampString}
import scala.collection.JavaConversions._
/**
* This code generator is mainly responsible for generating code for a given Calcite [[RexNode]].
* It can also generate type conversion code for the result converter.
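*
* A minimal usage sketch (the `ctx`, `inputRowType`, and `rexNode` values are assumed to be
* provided by the caller):
* {{{
*   val exprGenerator = new ExprCodeGenerator(ctx, nullableInput = false)
*     .bindInput(inputRowType)
*   // generates code and reusable sections for the given RexNode
*   val generatedExpr: GeneratedExpression = exprGenerator.generateExpression(rexNode)
* }}}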
*/
class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)
extends RexVisitor[GeneratedExpression] {
// check if nullCheck is enabled when inputs can be null
if (nullableInput && !ctx.nullCheck) {
throw new CodeGenException("Null check must be enabled if entire rows can be null.")
}
/**
* term of the [[ProcessFunction]]'s context; can be changed when needed
*/
var contextTerm = "ctx"
/**
* information of the first input
*/
var input1Type: LogicalType = _
var input1Term: String = _
var input1FieldMapping: Option[Array[Int]] = None
/**
* information of the optional second input
*/
var input2Type: Option[LogicalType] = None
var input2Term: Option[String] = None
var input2FieldMapping: Option[Array[Int]] = None
/**
* term of the user-provided function context (see [[bindConstructorTerm]])
* */
var functionContextTerm: Option[String] = None
/**
* Bind the input information. This must be called before generating any expression.
*/
def bindInput(
inputType: LogicalType,
inputTerm: String = DEFAULT_INPUT1_TERM,
inputFieldMapping: Option[Array[Int]] = None): ExprCodeGenerator = {
input1Type = inputType
input1Term = inputTerm
input1FieldMapping = inputFieldMapping
this
}
/**
* In some cases, the expression will have two inputs (e.g. a join condition or a UDTF). The
* second input information must be bound before use.
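*
* A minimal sketch for a join condition (`leftRowType`, `rightRowType`, and `joinCondition`
* are assumed to be provided by the caller):
* {{{
*   val conditionExpr = new ExprCodeGenerator(ctx, nullableInput = false)
*     .bindInput(leftRowType)
*     .bindSecondInput(rightRowType)
*     .generateExpression(joinCondition)
* }}}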
*/
def bindSecondInput(
inputType: LogicalType,
inputTerm: String = DEFAULT_INPUT2_TERM,
inputFieldMapping: Option[Array[Int]] = None): ExprCodeGenerator = {
input2Type = Some(inputType)
input2Term = Some(inputTerm)
input2FieldMapping = inputFieldMapping
this
}
/**
* In some cases, a user-defined term should be used for the function context. For example,
* ScalaFunctionCodeGen allows using a user-defined context term rather than obtaining the
* context by invoking the getRuntimeContext() method.
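*
* For example (the term name below is only illustrative):
* {{{
*   exprGenerator.bindConstructorTerm("myFunctionContext")
* }}}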
* */
def bindConstructorTerm(
inputFunctionContextTerm: String): ExprCodeGenerator = {
functionContextTerm = Some(inputFunctionContextTerm)
this
}
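// field mapping of the first input; defaults to the identity mapping over all of its fields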
private lazy val input1Mapping: Array[Int] = input1FieldMapping match {
case Some(mapping) => mapping
case _ => fieldIndices(input1Type)
}
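// field mapping of the optional second input; empty if no second input has been bound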
private lazy val input2Mapping: Array[Int] = input2FieldMapping match {
case Some(mapping) => mapping
case _ => input2Type match {
case Some(input) => fieldIndices(input)
case _ => Array[Int]()
}
}
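// returns indices 0 until fieldCount for composite types, or a single index 0 for atomic types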
private def fieldIndices(t: LogicalType): Array[Int] = {
if (isCompositeType(t)) {
(0 until getFieldCount(t)).toArray
} else {
Array(0)
}
}
/**
* Generates an expression from a RexNode. If objects or variables can be reused, they will be
* added to reusable code sections internally.
*
* @param rex Calcite row expression
* @return instance of GeneratedExpression
*/
def generateExpression(rex: RexNode): GeneratedExpression = {
rex.accept(this)
}
/**
* Generates an expression that converts the first input (and second input) into the given type.
* If two inputs are converted, the second input is appended. If objects or variables can
* be reused, they will be added to reusable code sections internally. The evaluation result
* will be stored in the variable outRecordTerm.
*
* @param returnType conversion target type. Inputs and output must have the same arity.
* @param outRecordTerm the result term
* @param outRecordWriterTerm the result writer term
* @param reusedOutRow If objects or variables can be reused, they will be added to reusable
* code sections internally.
* @return instance of GeneratedExpression
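*
* A minimal sketch (assuming the generator is already bound to its input and `outputRowType`
* has the same arity as the bound input):
* {{{
*   val converterExpr = exprGenerator.generateConverterResultExpression(
*     outputRowType, classOf[BinaryRowData])
* }}}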
*/
def generateConverterResultExpression(
returnType: RowType,
returnTypeClazz: Class[_ <: RowData],
outRecordTerm: String = DEFAULT_OUT_RECORD_TERM,
outRecordWriterTerm: String = DEFAULT_OUT_RECORD_WRITER_TERM,
reusedOutRow: Boolean = true,
fieldCopy: Boolean = false,
rowtimeExpression: Option[RexNode] = None)
: GeneratedExpression = {
val input1AccessExprs = input1Mapping.map {
case TimeIndicatorTypeInfo.ROWTIME_STREAM_MARKER |
TimeIndicatorTypeInfo.ROWTIME_BATCH_MARKER if rowtimeExpression.isDefined =>
// generate rowtime attribute from expression
generateExpression(rowtimeExpression.get)
case TimeIndicatorTypeInfo.ROWTIME_STREAM_MARKER |
TimeIndicatorTypeInfo.ROWTIME_BATCH_MARKER =>
throw new TableException("Rowtime extraction expression missing. Please report a bug.")
case TimeIndicatorTypeInfo.PROCTIME_STREAM_MARKER =>
// attribute is proctime indicator.
// we use a null literal and generate a timestamp when we need it.
generateNullLiteral(
new LocalZonedTimestampType(true, TimestampKind.PROCTIME, 3),
ctx.nullCheck)
case TimeIndicatorTypeInfo.PROCTIME_BATCH_MARKER =>
// attribute is proctime field in a batch query.
// it is initialized with the current time.
generateCurrentTimestamp(ctx)
case idx =>
// generate access to the input field
generateInputAccess(
ctx,
input1Type,
input1Term,
idx,
nullableInput,
fieldCopy)
}
val input2AccessExprs = input2Type match {
case Some(ti) =>
input2Mapping.map(idx => generateInputAccess(
ctx,
ti,
input2Term.get,
idx,
nullableInput,
ctx.nullCheck)
).toSeq
case None => Seq() // add nothing
}
generateResultExpression(
input1AccessExprs ++ input2AccessExprs,
returnType,
returnTypeClazz,
outRow = outRecordTerm,
outRowWriter = Some(outRecordWriterTerm),
reusedOutRow = reusedOutRow)
}
/**
* Generates an expression from a sequence of other expressions. The evaluation result
* may be stored in the variable outRecordTerm.
*
* @param fieldExprs field expressions to be converted
* @param returnType conversion target type. The type must have the same arity as fieldExprs.
* @param outRow the result term
* @param outRowWriter the result writer term for BinaryRowData.
* @param reusedOutRow If objects or variables can be reused, they will be added to reusable
* code sections internally.
* @param outRowAlreadyExists if the out row already exists, no reusable record needs to be added.
* @return instance of GeneratedExpression
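*
* A minimal sketch (assuming `fieldExprs` were produced by this generator and match
* `outputRowType` field by field):
* {{{
*   val resultExpr = exprGenerator.generateResultExpression(
*     fieldExprs, outputRowType, classOf[BinaryRowData])
* }}}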
*/
def generateResultExpression(
fieldExprs: Seq[GeneratedExpression],
returnType: RowType,
returnTypeClazz: Class[_ <: RowData],
outRow: String = DEFAULT_OUT_RECORD_TERM,
outRowWriter: Option[String] = Some(DEFAULT_OUT_RECORD_WRITER_TERM),
reusedOutRow: Boolean = true,
outRowAlreadyExists: Boolean = false): GeneratedExpression = {
val fieldExprIdxToOutputRowPosMap = fieldExprs.indices.map(i => i -> i).toMap
generateResultExpression(fieldExprs, fieldExprIdxToOutputRowPosMap, returnType,
returnTypeClazz, outRow, outRowWriter, reusedOutRow, outRowAlreadyExists)
}
/**
* Generates an expression from a sequence of other expressions. The evaluation result
* may be stored in the variable outRecordTerm.
*
* @param fieldExprs field expressions to be converted
* @param fieldExprIdxToOutputRowPosMap mapping from the index of a field expression in
* `fieldExprs` to its position in the output row.
* @param returnType conversion target type. The type must have the same arity as fieldExprs.
* @param outRow the result term
* @param outRowWriter the result writer term for BinaryRowData.
* @param reusedOutRow If objects or variables can be reused, they will be added to reusable
* code sections internally.
* @param outRowAlreadyExists if the out row already exists, no reusable record needs to be added.
* @return instance of GeneratedExpression
*/
def generateResultExpression(
fieldExprs: Seq[GeneratedExpression],
fieldExprIdxToOutputRowPosMap: Map[Int, Int],
returnType: RowType,
returnTypeClazz: Class[_ <: RowData],
outRow: String,
outRowWriter: Option[String],
reusedOutRow: Boolean,
outRowAlreadyExists: Boolean)
: GeneratedExpression = {
// initial type check
if (returnType.getFieldCount != fieldExprs.length) {
throw new CodeGenException(
s"Arity [${returnType.getFieldCount}] of result type [$returnType] does not match " +
s"number [${fieldExprs.length}] of expressions [$fieldExprs].")
}
if (fieldExprIdxToOutputRowPosMap.size != fieldExprs.length) {
throw new CodeGenException(
s"Size [${returnType.getFieldCount}] of fieldExprIdxToOutputRowPosMap does not match " +
s"number [${fieldExprs.length}] of expressions [$fieldExprs].")
}
// type check
fieldExprs.zipWithIndex foreach {
// timestamp types (including time indicators) and generic types are compatible with each other.
case (fieldExpr, i)
if fieldExpr.resultType.isInstanceOf[TypeInformationRawType[_]] ||
fieldExpr.resultType.isInstanceOf[TimestampType] =>
if (returnType.getTypeAt(i).getClass != fieldExpr.resultType.getClass
&& !returnType.getTypeAt(i).isInstanceOf[TypeInformationRawType[_]]) {
throw new CodeGenException(
s"Incompatible types of expression and result type, Expression[$fieldExpr] type is " +
s"[${fieldExpr.resultType}], result type is [${returnType.getTypeAt(i)}]")
}
case (fieldExpr, i) if !isInteroperable(fieldExpr.resultType, returnType.getTypeAt(i)) =>
throw new CodeGenException(
s"Incompatible types of expression and result type. Expression[$fieldExpr] type is " +
s"[${fieldExpr.resultType}], result type is [${returnType.getTypeAt(i)}]")
case _ => // ok
}
val setFieldsCode = fieldExprs.zipWithIndex.map { case (fieldExpr, index) =>
val pos = fieldExprIdxToOutputRowPosMap.getOrElse(index,
throw new CodeGenException(s"Illegal field expr index: $index"))
rowSetField(ctx, returnTypeClazz, outRow, pos.toString, fieldExpr, outRowWriter)
}.mkString("\n")
val outRowInitCode = if (!outRowAlreadyExists) {
val initCode = generateRecordStatement(returnType, returnTypeClazz, outRow, outRowWriter, ctx)
if (reusedOutRow) {
NO_CODE
} else {
initCode
}
} else {
NO_CODE
}
val code = if (returnTypeClazz == classOf[BinaryRowData] && outRowWriter.isDefined) {
val writer = outRowWriter.get
val resetWriter = if (ctx.nullCheck) s"$writer.reset();" else s"$writer.resetCursor();"
val completeWriter: String = s"$writer.complete();"
s"""
|$outRowInitCode
|$resetWriter
|$setFieldsCode
|$completeWriter
""".stripMargin
} else {
s"""
|$outRowInitCode
|$setFieldsCode
""".stripMargin
}
GeneratedExpression(outRow, NEVER_NULL, code, returnType)
}
override def visitInputRef(inputRef: RexInputRef): GeneratedExpression = {
val input1Arity = input1Type match {
case r: RowType => r.getFieldCount
case _ => 1
}
// if the inputRef index is within the arity of input1 we work with input1, otherwise with input2
val input = if (inputRef.getIndex < input1Arity) {
(input1Type, input1Term)
} else {
(input2Type.getOrElse(throw new CodeGenException("Invalid input access.")),
input2Term.getOrElse(throw new CodeGenException("Invalid input access.")))
}
val index = if (input._2 == input1Term) {
inputRef.getIndex
} else {
inputRef.getIndex - input1Arity
}
generateInputAccess(ctx, input._1, input._2, index, nullableInput, ctx.nullCheck)
}
override def visitTableInputRef(rexTableInputRef: RexTableInputRef): GeneratedExpression =
visitInputRef(rexTableInputRef)
override def visitFieldAccess(rexFieldAccess: RexFieldAccess): GeneratedExpression = {
val refExpr = rexFieldAccess.getReferenceExpr.accept(this)
val index = rexFieldAccess.getField.getIndex
val fieldAccessExpr = generateFieldAccess(
ctx,
refExpr.resultType,
refExpr.resultTerm,
index)
val resultTypeTerm = primitiveTypeTermForType(fieldAccessExpr.resultType)
val defaultValue = primitiveDefaultValue(fieldAccessExpr.resultType)
val Seq(resultTerm, nullTerm) = ctx.addReusableLocalVariables(
(resultTypeTerm, "result"),
("boolean", "isNull"))
val resultCode = if (ctx.nullCheck) {
s"""
|${refExpr.code}
|if (${refExpr.nullTerm}) {
| $resultTerm = $defaultValue;
| $nullTerm = true;
|}
|else {
| ${fieldAccessExpr.code}
| $resultTerm = ${fieldAccessExpr.resultTerm};
| $nullTerm = ${fieldAccessExpr.nullTerm};
|}
|""".stripMargin
} else {
s"""
|${refExpr.code}
|${fieldAccessExpr.code}
|$resultTerm = ${fieldAccessExpr.resultTerm};
|""".stripMargin
}
GeneratedExpression(resultTerm, nullTerm, resultCode, fieldAccessExpr.resultType)
}
override def visitLiteral(literal: RexLiteral): GeneratedExpression = {
val resultType = FlinkTypeFactory.toLogicalType(literal.getType)
val value = resultType.getTypeRoot match {
case LogicalTypeRoot.TIMESTAMP_WITHOUT_TIME_ZONE |
LogicalTypeRoot.TIMESTAMP_WITH_LOCAL_TIME_ZONE =>
literal.getValueAs(classOf[TimestampString])
case _ =>
literal.getValue3
}
generateLiteral(ctx, resultType, value)
}
override def visitCorrelVariable(correlVariable: RexCorrelVariable): GeneratedExpression = {
GeneratedExpression(input1Term, NEVER_NULL, NO_CODE, input1Type)
}
override def visitLocalRef(localRef: RexLocalRef): GeneratedExpression =
throw new CodeGenException("RexLocalRef are not supported yet.")
def visitRexFieldVariable(variable: RexFieldVariable): GeneratedExpression = {
val internalType = FlinkTypeFactory.toLogicalType(variable.dataType)
val nullTerm = variable.fieldTerm + "IsNull" // do not use newName; the field term is already unique, so the derived isNull term stays unique
ctx.addReusableMember(s"${primitiveTypeTermForType(internalType)} ${variable.fieldTerm};")
ctx.addReusableMember(s"boolean $nullTerm;")
GeneratedExpression(variable.fieldTerm, nullTerm, NO_CODE, internalType)
}
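/**
* Generates access to a distinct key variable. The unboxing code is registered once as a
* reusable input unboxing expression, so the returned expression carries no code of its own.
*/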
def visitDistinctKeyVariable(value: RexDistinctKeyVariable): GeneratedExpression = {
val inputExpr = ctx.getReusableInputUnboxingExprs(input1Term, 0) match {
case Some(expr) => expr
case None =>
val pType = primitiveTypeTermForType(value.internalType)
val defaultValue = primitiveDefaultValue(value.internalType)
val resultTerm = newName("field")
val nullTerm = newName("isNull")
val code =
s"""
|$pType $resultTerm = $defaultValue;
|boolean $nullTerm = true;
|if ($input1Term != null) {
| $nullTerm = false;
| $resultTerm = ($pType) $input1Term;
|}
""".stripMargin
val expr = GeneratedExpression(resultTerm, nullTerm, code, value.internalType)
ctx.addReusableInputUnboxingExprs(input1Term, 0, expr)
expr
}
// hide the generated code as it will be executed only once
GeneratedExpression(inputExpr.resultTerm, inputExpr.nullTerm, NO_CODE, inputExpr.resultType)
}
override def visitRangeRef(rangeRef: RexRangeRef): GeneratedExpression =
throw new CodeGenException("Range references are not supported yet.")
override def visitDynamicParam(dynamicParam: RexDynamicParam): GeneratedExpression =
throw new CodeGenException("Dynamic parameter references are not supported yet.")
override def visitCall(call: RexCall): GeneratedExpression = {
val resultType = FlinkTypeFactory.toLogicalType(call.getType)
if (call.getKind == SqlKind.SEARCH) {
val sarg = call.getOperands.get(1).asInstanceOf[RexLiteral]
.getValueAs(classOf[Sarg[_]])
val rexBuilder = new FlinkRexBuilder(FlinkTypeFactory.INSTANCE)
if (sarg.isPoints) {
val operands = FlinkRexUtil.expandSearchOperands(rexBuilder, call)
.map(operand => operand.accept(this))
return generateCallExpression(ctx, call, operands, resultType)
} else {
return RexUtil.expandSearch(
rexBuilder,
null,
call).accept(this)
}
}
// convert operands and help give untyped NULL literals a type
val operands = call.getOperands.zipWithIndex.map {
// this helps e.g. for AS(null)
// we might need to extend this logic in case some rules do not create typed NULLs
case (operandLiteral: RexLiteral, 0) if
operandLiteral.getType.getSqlTypeName == SqlTypeName.NULL &&
call.getOperator.getReturnTypeInference == ReturnTypes.ARG0 =>
generateNullLiteral(resultType, ctx.nullCheck)
case (o@_, _) => o.accept(this)
}
generateCallExpression(ctx, call, operands, resultType)
}
override def visitOver(over: RexOver): GeneratedExpression =
throw new CodeGenException("Aggregate functions over windows are not supported yet.")
override def visitSubQuery(subQuery: RexSubQuery): GeneratedExpression =
throw new CodeGenException("Subqueries are not supported yet.")
override def visitPatternFieldRef(fieldRef: RexPatternFieldRef): GeneratedExpression =
throw new CodeGenException("Pattern field references are not supported yet.")
// ----------------------------------------------------------------------------------------
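/**
* Generates the expression for a [[RexCall]] by dispatching on its SQL operator (arithmetic,
* comparison, logic, casting, collection access, and scalar/table/bridging function calls).
*/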
private def generateCallExpression(
ctx: CodeGeneratorContext,
call: RexCall,
operands: Seq[GeneratedExpression],
resultType: LogicalType): GeneratedExpression = {
call.getOperator match {
// arithmetic
case PLUS if isNumeric(resultType) =>
val left = operands.head
val right = operands(1)
requireNumeric(left)
requireNumeric(right)
generateBinaryArithmeticOperator(ctx, "+", resultType, left, right)
case PLUS | DATETIME_PLUS if isTemporal(resultType) =>
val left = operands.head
val right = operands(1)
requireTemporal(left)
requireTemporal(right)
generateTemporalPlusMinus(ctx, plus = true, resultType, left, right)
case MINUS if isNumeric(resultType) =>
val left = operands.head
val right = operands(1)
requireNumeric(left)
requireNumeric(right)
generateBinaryArithmeticOperator(ctx, "-", resultType, left, right)
case MINUS | MINUS_DATE if isTemporal(resultType) =>
val left = operands.head
val right = operands(1)
requireTemporal(left)
requireTemporal(right)
generateTemporalPlusMinus(ctx, plus = false, resultType, left, right)
case MULTIPLY if isNumeric(resultType) =>
val left = operands.head
val right = operands(1)
requireNumeric(left)
requireNumeric(right)
generateBinaryArithmeticOperator(ctx, "*", resultType, left, right)
case MULTIPLY if isTimeInterval(resultType) =>
val left = operands.head
val right = operands(1)
requireTimeInterval(left)
requireNumeric(right)
generateBinaryArithmeticOperator(ctx, "*", resultType, left, right)
case DIVIDE | DIVIDE_INTEGER if isNumeric(resultType) =>
val left = operands.head
val right = operands(1)
requireNumeric(left)
requireNumeric(right)
generateBinaryArithmeticOperator(ctx, "/", resultType, left, right)
case MOD if isNumeric(resultType) =>
val left = operands.head
val right = operands(1)
requireNumeric(left)
requireNumeric(right)
generateBinaryArithmeticOperator(ctx, "%", resultType, left, right)
case UNARY_MINUS if isNumeric(resultType) =>
val operand = operands.head
requireNumeric(operand)
generateUnaryArithmeticOperator(ctx, "-", resultType, operand)
case UNARY_MINUS if isTimeInterval(resultType) =>
val operand = operands.head
requireTimeInterval(operand)
generateUnaryIntervalPlusMinus(ctx, plus = false, operand)
case UNARY_PLUS if isNumeric(resultType) =>
val operand = operands.head
requireNumeric(operand)
generateUnaryArithmeticOperator(ctx, "+", resultType, operand)
case UNARY_PLUS if isTimeInterval(resultType) =>
val operand = operands.head
requireTimeInterval(operand)
generateUnaryIntervalPlusMinus(ctx, plus = true, operand)
// comparison
case EQUALS =>
val left = operands.head
val right = operands(1)
generateEquals(ctx, left, right)
case IS_NOT_DISTINCT_FROM =>
val left = operands.head
val right = operands(1)
generateIsNotDistinctFrom(ctx, left, right)
case NOT_EQUALS =>
val left = operands.head
val right = operands(1)
generateNotEquals(ctx, left, right)
case GREATER_THAN =>
val left = operands.head
val right = operands(1)
requireComparable(left)
requireComparable(right)
generateComparison(ctx, ">", left, right)
case GREATER_THAN_OR_EQUAL =>
val left = operands.head
val right = operands(1)
requireComparable(left)
requireComparable(right)
generateComparison(ctx, ">=", left, right)
case LESS_THAN =>
val left = operands.head
val right = operands(1)
requireComparable(left)
requireComparable(right)
generateComparison(ctx, "<", left, right)
case LESS_THAN_OR_EQUAL =>
val left = operands.head
val right = operands(1)
requireComparable(left)
requireComparable(right)
generateComparison(ctx, "<=", left, right)
case IS_NULL =>
val operand = operands.head
generateIsNull(ctx, operand)
case IS_NOT_NULL =>
val operand = operands.head
generateIsNotNull(ctx, operand)
// logic
case AND =>
operands.reduceLeft { (left: GeneratedExpression, right: GeneratedExpression) =>
requireBoolean(left)
requireBoolean(right)
generateAnd(ctx, left, right)
}
case OR =>
operands.reduceLeft { (left: GeneratedExpression, right: GeneratedExpression) =>
requireBoolean(left)
requireBoolean(right)
generateOr(ctx, left, right)
}
case NOT =>
val operand = operands.head
requireBoolean(operand)
generateNot(ctx, operand)
case CASE =>
generateIfElse(ctx, operands, resultType)
case IS_TRUE =>
val operand = operands.head
requireBoolean(operand)
generateIsTrue(operand)
case IS_NOT_TRUE =>
val operand = operands.head
requireBoolean(operand)
generateIsNotTrue(operand)
case IS_FALSE =>
val operand = operands.head
requireBoolean(operand)
generateIsFalse(operand)
case IS_NOT_FALSE =>
val operand = operands.head
requireBoolean(operand)
generateIsNotFalse(operand)
case SEARCH | IN =>
val left = operands.head
val right = operands.tail
generateIn(ctx, left, right)
case NOT_IN =>
val left = operands.head
val right = operands.tail
generateNot(ctx, generateIn(ctx, left, right))
// casting
case CAST =>
val operand = operands.head
generateCast(ctx, operand, resultType)
// Reinterpret
case REINTERPRET =>
val operand = operands.head
generateReinterpret(ctx, operand, resultType)
// as / renaming
case AS =>
operands.head
// rows
case ROW =>
generateRow(ctx, resultType, operands)
// arrays
case ARRAY_VALUE_CONSTRUCTOR =>
generateArray(ctx, resultType, operands)
// maps
case MAP_VALUE_CONSTRUCTOR =>
generateMap(ctx, resultType, operands)
case ITEM =>
operands.head.resultType.getTypeRoot match {
case LogicalTypeRoot.ARRAY =>
val array = operands.head
val index = operands(1)
requireInteger(index)
generateArrayElementAt(ctx, array, index)
case LogicalTypeRoot.MAP =>
val key = operands(1)
generateMapGet(ctx, operands.head, key)
case LogicalTypeRoot.ROW | LogicalTypeRoot.STRUCTURED_TYPE =>
generateDot(ctx, operands)
case _ => throw new CodeGenException("Expect an array, a map or a row.")
}
case CARDINALITY =>
operands.head.resultType match {
case t: LogicalType if TypeCheckUtils.isArray(t) =>
val array = operands.head
generateArrayCardinality(ctx, array)
case t: LogicalType if TypeCheckUtils.isMap(t) =>
val map = operands.head
generateMapCardinality(ctx, map)
case _ => throw new CodeGenException("Expect an array or a map.")
}
case ELEMENT =>
val array = operands.head
requireArray(array)
generateArrayElement(ctx, array)
case DOT =>
generateDot(ctx, operands)
case PROCTIME =>
// attribute is proctime indicator.
// We use a null literal and generate a timestamp when we need it.
generateNullLiteral(
new LocalZonedTimestampType(true, TimestampKind.PROCTIME, 3),
ctx.nullCheck)
case PROCTIME_MATERIALIZE =>
generateProctimeTimestamp(ctx, contextTerm)
case STREAMRECORD_TIMESTAMP =>
generateRowtimeAccess(ctx, contextTerm, false)
case JSON_VALUE => new JsonValueCallGen().generate(ctx, operands, resultType)
case JSON_OBJECT => new JsonObjectCallGen(call).generate(ctx, operands, resultType)
case _: SqlThrowExceptionFunction =>
val nullValue = generateNullLiteral(resultType, nullCheck = true)
val code =
s"""
|${operands.map(_.code).mkString("\n")}
|${nullValue.code}
|org.apache.flink.util.ExceptionUtils.rethrow(
| new RuntimeException(${operands.head.resultTerm}.toString()));
|""".stripMargin
GeneratedExpression(nullValue.resultTerm, nullValue.nullTerm, code, resultType)
case ssf: ScalarSqlFunction =>
new ScalarFunctionCallGen(
ssf.makeFunction(getOperandLiterals(operands), operands.map(_.resultType).toArray))
.generate(ctx, operands, resultType)
case tsf: TableSqlFunction =>
new TableFunctionCallGen(
call,
tsf.makeFunction(getOperandLiterals(operands), operands.map(_.resultType).toArray))
.generate(ctx, operands, resultType)
case bsf: BridgingSqlFunction =>
bsf.getDefinition match {
case functionDefinition : FunctionDefinition
if functionDefinition eq BuiltInFunctionDefinitions.CURRENT_WATERMARK =>
generateWatermark(ctx, contextTerm, resultType)
case functionDefinition : FunctionDefinition
if functionDefinition eq BuiltInFunctionDefinitions.GREATEST =>
operands.foreach { operand =>
requireComparable(operand)
}
generateGreatestLeast(resultType, operands)
case functionDefinition : FunctionDefinition
if functionDefinition eq BuiltInFunctionDefinitions.LEAST =>
operands.foreach { operand =>
requireComparable(operand)
}
generateGreatestLeast(resultType, operands, false)
case _ =>
new BridgingSqlFunctionCallGen(call).generate(ctx, operands, resultType)
}
// advanced scalar functions
case sqlOperator: SqlOperator =>
StringCallGen.generateCallExpression(ctx, call.getOperator, operands, resultType)
.getOrElse {
FunctionGenerator
.getInstance(ctx.tableConfig)
.getCallGenerator(
sqlOperator,
operands.map(expr => expr.resultType),
resultType)
.getOrElse(
throw new CodeGenException(s"Unsupported call: " +
s"$sqlOperator(${operands.map(_.resultType).mkString(", ")}) \n" +
s"If you think this function should be supported, " +
s"you can create an issue and start a discussion for it."))
.generate(ctx, operands, resultType)
}
// unknown or invalid
case call@_ =>
val explainCall = s"$call(${operands.map(_.resultType).mkString(", ")})"
throw new CodeGenException(s"Unsupported call: $explainCall")
}
}
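/**
* Extracts the literal value of each operand, converted to its external representation,
* using null for operands that are not literals.
*/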
def getOperandLiterals(operands: Seq[GeneratedExpression]): Array[AnyRef] = {
operands.map { expr =>
expr.literalValue match {
case None => null
case Some(literal) =>
getConverterForDataType(fromLogicalTypeToDataType(expr.resultType))
.asInstanceOf[DataFormatConverter[AnyRef, AnyRef]
].toExternal(literal.asInstanceOf[AnyRef])
}
}.toArray
}
}