blob: 82d3e6a1d8b34a01ddd2a0dae96926506d2bd714 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.codegen
import java.lang.reflect.Method
import java.lang.{Boolean => JBoolean, Byte => JByte, Character => JChar, Double => JDouble, Float => JFloat, Integer => JInt, Long => JLong, Short => JShort}
import java.math.{BigDecimal => JBigDecimal}
import java.sql.{Date, Time, Timestamp}
import java.util.concurrent.atomic.AtomicInteger
import org.apache.calcite.avatica.util.ByteString
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.sql.SqlOperator
import org.apache.calcite.sql.`type`.SqlTypeName.{ROW => _, _}
import org.apache.calcite.sql.fun.SqlStdOperatorTable._
import org.apache.calcite.util.BuiltInMethod
import org.apache.commons.lang3.StringEscapeUtils
import org.apache.flink.api.common.InvalidProgramException
import org.apache.flink.api.common.functions.{FlatJoinFunction, FlatMapFunction, MapFunction}
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo._
import org.apache.flink.streaming.api.functions.ProcessFunction
import org.apache.flink.table.api.types._
import org.apache.flink.table.codegen.CodeGeneratorContext.{BASE_ROW_UTIL, BINARY_STRING}
import org.apache.flink.table.codegen.GeneratedExpression.NEVER_NULL
import org.apache.flink.table.codegen.calls.ScalarOperators._
import org.apache.flink.table.codegen.calls.{BinaryStringCallGen, BuiltInMethods, CurrentTimePointCallGen, FunctionGenerator}
import org.apache.flink.table.dataformat._
import org.apache.flink.table.errorcode.TableErrors
import org.apache.flink.table.functions.sql.ScalarSqlFunctions
import org.apache.flink.table.functions.sql.internal.{SqlRuntimeFilterBuilderFunction, SqlRuntimeFilterFunction, SqlThrowExceptionFunction}
import org.apache.flink.table.typeutils.TypeCheckUtils.{isNumeric, isTemporal, isTimeInterval}
import org.apache.flink.table.typeutils._
import org.apache.flink.table.util.Logging.CODE_LOG
import org.apache.flink.types.Row
import org.apache.flink.util.StringUtils
import org.codehaus.commons.compiler.{CompileException, ICookable}
import org.codehaus.janino.SimpleCompiler
import scala.collection.mutable.ListBuffer
object CodeGenUtils {
// Counter shared by every name generated through this object; each call consumes one value.
private val nameCounter = new AtomicInteger()

/**
 * Creates an identifier by appending a dollar sign and the next counter value to `name`.
 *
 * @param name prefix of the generated identifier
 * @return `name`, a dollar sign and the counter value taken by this call
 */
def newName(name: String): String = {
  val suffix = nameCounter.getAndIncrement()
  name + "$" + suffix
}
/**
 * Derives one identifier per entry of `names`; all of them share a single counter value.
 *
 * @param names prefixes to decorate; must not contain duplicates
 * @return the prefixes, each followed by a dollar sign and the same counter value
 */
def newNames(names: Seq[String]): Seq[String] = {
  require(names.length == names.toSet.size, "Duplicated names")
  val sharedSuffix = "$" + nameCounter.getAndIncrement()
  names.map(prefix => prefix + sharedSuffix)
}
/**
 * Returns the canonical name of the runtime class behind the type parameter `T`.
 *
 * @tparam T type whose runtime class is looked up through the implicit manifest
 */
def className[T](implicit m: Manifest[T]): String = {
  val runtimeClass: Class[_] = m.runtimeClass
  runtimeClass.getCanonicalName
}
/**
 * Tells whether `t` is one of the types this utility marks as needing a copy:
 * strings, arrays, maps, rows and generic types. Every other type yields false.
 */
def needCopyForType(t: InternalType): Boolean = t match {
  case DataTypes.STRING | _: ArrayType | _: MapType | _: RowType | _: GenericType[_] => true
  case _ => false
}
/**
 * Tells whether `t` is the string type, the only internal type flagged here as needing
 * its reference cloned.
 */
def needCloneRefForType(t: InternalType): Boolean =
  DataTypes.STRING == t
/**
 * Tells whether the external type information derived from `t` is the binary string
 * type information singleton.
 */
def needCloneRefForDataType(t: DataType): Boolean = {
  val externalTypeInfo = TypeConverters.createExternalTypeInfoFromDataType(t)
  BinaryStringTypeInfo.INSTANCE == externalTypeInfo
}
// Java type name used for values of `t` in generated code. Types listed below map to a
// primitive name; every other type falls back to its boxed type term.
//
// The primitive name matters when casting. Primitives can be cast directly:
//   float a = 1.0f;
//   byte b = (byte) a;
// whereas boxed values have to be unboxed first:
//   Float a = 1.0f;
//   Byte b = (byte)(float) a;
def primitiveTypeTermForType(t: InternalType): String = {
  val intTerm = "int"
  val longTerm = "long"
  t match {
    case DataTypes.INT => intTerm
    case DataTypes.LONG => longTerm
    case DataTypes.SHORT => "short"
    case DataTypes.BYTE => "byte"
    case DataTypes.FLOAT => "float"
    case DataTypes.DOUBLE => "double"
    case DataTypes.BOOLEAN => "boolean"
    case DataTypes.CHAR => "char"
    // dates and times are represented as int, timestamps as long
    case _: DateType | DataTypes.TIME => intTerm
    case _: TimestampType => longTerm
    case DataTypes.INTERVAL_MONTHS => intTerm
    case DataTypes.INTERVAL_MILLIS => longTerm
    case other => boxedTypeTermForType(other)
  }
}
/**
 * Tells whether `tpe` is internally held as a primitive. At present this is simply the
 * temporal check, because only temporal types use a primitive representation.
 */
def isInternalPrimitive(tpe: InternalType): Boolean =
  TypeCheckUtils.isTemporal(tpe)
/**
 * Returns the Java type name of the external (user-facing) representation of `t`.
 *
 * Cases are tried in order; internal types not handled explicitly fall back to
 * [[boxedTypeTermForType]], and wrapped type information is resolved from its type class.
 */
def externalBoxedTermForType(t: DataType): String = t match {
  case DataTypes.STRING =>
    classOf[String].getCanonicalName
  case _: DecimalType =>
    classOf[JBigDecimal].getCanonicalName
  case primitiveArray: ArrayType if primitiveArray.isPrimitive =>
    primitiveTypeTermForType(primitiveArray.getElementInternalType) + "[]"
  case objectArray: ArrayType =>
    externalBoxedTermForType(objectArray.getElementType) + "[]"
  case _: RowType =>
    classOf[Row].getCanonicalName
  case _: MapType =>
    classOf[java.util.Map[_, _]].getCanonicalName
  case _: TimestampType if t != DataTypes.INTERVAL_MILLIS =>
    classOf[Timestamp].getCanonicalName
  case _: DateType if t != DataTypes.INTERVAL_MONTHS =>
    classOf[Date].getCanonicalName
  case DataTypes.TIME =>
    classOf[Time].getCanonicalName
  case internal: InternalType =>
    boxedTypeTermForType(internal)
  case wrapped: TypeInfoWrappedDataType =>
    wrapped.getTypeInfo match {
      // PrimitiveArrayTypeInfo would give us the class "int[]", which scala reflection
      // does not seem to like, so the array type names are spelled out manually.
      case INT_PRIMITIVE_ARRAY_TYPE_INFO => "int[]"
      case LONG_PRIMITIVE_ARRAY_TYPE_INFO => "long[]"
      case SHORT_PRIMITIVE_ARRAY_TYPE_INFO => "short[]"
      case BYTE_PRIMITIVE_ARRAY_TYPE_INFO => "byte[]"
      case FLOAT_PRIMITIVE_ARRAY_TYPE_INFO => "float[]"
      case DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO => "double[]"
      case BOOLEAN_PRIMITIVE_ARRAY_TYPE_INFO => "boolean[]"
      case CHAR_PRIMITIVE_ARRAY_TYPE_INFO => "char[]"
      case otherInfo => otherInfo.getTypeClass.getCanonicalName
    }
}
/**
 * Returns the Java type name of the boxed (object) internal representation of `t`.
 *
 * Dates and times share the boxed int term, timestamps the boxed long term. The match has
 * no fallback case, so a type not listed here raises a `MatchError`.
 */
def boxedTypeTermForType(t: InternalType): String = {
  val boxedInt = classOf[JInt].getCanonicalName
  val boxedLong = classOf[JLong].getCanonicalName
  t match {
    case DataTypes.INT => boxedInt
    case DataTypes.LONG => boxedLong
    case DataTypes.SHORT => classOf[JShort].getCanonicalName
    case DataTypes.BYTE => classOf[JByte].getCanonicalName
    case DataTypes.FLOAT => classOf[JFloat].getCanonicalName
    case DataTypes.DOUBLE => classOf[JDouble].getCanonicalName
    case DataTypes.BOOLEAN => classOf[JBoolean].getCanonicalName
    case DataTypes.CHAR => classOf[JChar].getCanonicalName
    case _: DateType | DataTypes.TIME => boxedInt
    case _: TimestampType => boxedLong
    case DataTypes.STRING => BINARY_STRING
    case DataTypes.BYTE_ARRAY => "byte[]"
    case _: DecimalType => classOf[Decimal].getCanonicalName
    case _: ArrayType => classOf[BaseArray].getCanonicalName
    case _: MapType => classOf[BaseMap].getCanonicalName
    case _: RowType => classOf[BaseRow].getCanonicalName
    case generic: GenericType[_] => generic.getTypeInfo.getTypeClass.getCanonicalName
  }
}
/**
 * Returns the Java source text of the placeholder value assigned to a variable of type `t`
 * in generated code. Types without a dedicated placeholder yield the text `null`.
 */
def primitiveDefaultValue(t: InternalType): String = {
  val minusOne = "-1"
  val minusOneLong = "-1L"
  t match {
    case DataTypes.INT | DataTypes.BYTE | DataTypes.SHORT => minusOne
    case DataTypes.LONG => minusOneLong
    case DataTypes.FLOAT => "-1.0f"
    case DataTypes.DOUBLE => "-1.0d"
    case DataTypes.BOOLEAN => "false"
    case DataTypes.STRING => BINARY_STRING + ".EMPTY_UTF8"
    case DataTypes.CHAR => "'\\0'"
    case _: DateType | DataTypes.TIME => minusOne
    case _: TimestampType => minusOneLong
    case _ => "null"
  }
}
/**
 * Tells whether `clazz` is already internally compatible with `t`, in which case no data
 * structure converter is needed.
 *
 * `Object` and `Row` are never internal (a `Row` class can only be inferred as
 * GenericType[Row]). Any other class is internal when it is a `BaseRow` or exactly the
 * type class of the internal type information created for `t`.
 */
def isInternalClass(clazz: Class[_], t: DataType): Boolean = {
  val excluded = clazz == classOf[Object] || clazz == classOf[Row]
  if (excluded) {
    false
  } else if (classOf[BaseRow].isAssignableFrom(clazz)) {
    true
  } else {
    clazz == TypeConverters.createInternalTypeInfoFromDataType(t).getTypeClass
  }
}
/**
 * Renders `method` as `<canonical name of declaring class>.<method name>`, the form used
 * to reference it from generated code.
 */
def qualifyMethod(method: Method): String = {
  val owner = method.getDeclaringClass.getCanonicalName
  s"$owner.${method.getName}"
}
/**
 * Renders an enum constant as `<canonical name of enum type>.<constant name>`, the form used
 * to reference it from generated code.
 *
 * The enum type is taken from `getDeclaringClass` rather than `getClass`: for a constant
 * that has its own class body, `getClass` is an anonymous subclass whose canonical name is
 * null, which would produce the unusable term `null.<constant name>`.
 *
 * @param enum constant to render
 * @return fully qualified reference to the constant
 */
def qualifyEnum(enum: Enum[_]): String =
  enum.getDeclaringClass.getCanonicalName + "." + enum.name()
/**
 * Builds the Java expression that converts an internal temporal value to its string form.
 *
 * Only DATE, TIME and timestamp types are handled; any other type raises a `MatchError`.
 *
 * @param t          internal type of the value
 * @param resultTerm term holding the value in generated code
 * @param zoneTerm   term holding the time zone; only used for timestamps
 */
def internalToStringCode(
    t: InternalType,
    resultTerm: String,
    zoneTerm: String): String =
  t match {
    case DataTypes.DATE =>
      val converter = qualifyMethod(BuiltInMethod.UNIX_DATE_TO_STRING.method)
      s"$converter($resultTerm)"
    case DataTypes.TIME =>
      val converter = qualifyMethod(BuiltInMethods.UNIX_TIME_TO_STRING)
      s"$converter($resultTerm)"
    case _: TimestampType =>
      val converter = qualifyMethod(BuiltInMethods.TIMESTAMP_TO_STRING)
      s"$converter($resultTerm, 3, $zoneTerm)"
  }
/** Tells whether `term` is exactly the qualified reference produced for `enum`. */
def compareEnum(term: String, enum: Enum[_]): Boolean = {
  val qualified = qualifyEnum(enum)
  term == qualified
}
/**
 * Recovers the enum constant a generated symbol expression refers to.
 *
 * The constant name is the last dot-separated segment of the result term; the enum class
 * comes from the expression's result type, which is cast to a generic type.
 */
def getEnum(genExpr: GeneratedExpression): Enum[_] = {
  val constantName = genExpr.resultTerm.split('.').last
  val symbolType = genExpr.resultType.asInstanceOf[GenericType[_]]
  enumValueOf(symbolType.getTypeInfo.getTypeClass, constantName)
}
/**
 * Looks up the constant named `stringValue` in the enum class `cls`.
 *
 * @param cls         enum class, given untyped and cast internally
 * @param stringValue exact name of the constant
 */
def enumValueOf[T <: Enum[T]](cls: Class[_], stringValue: String): Enum[_] = {
  val enumClass = cls.asInstanceOf[Class[T]]
  Enum.valueOf(enumClass, stringValue).asInstanceOf[Enum[_]]
}
// ----------------------------------------------------------------------------------------------
/**
 * Fails code generation unless the operand has a numeric result type.
 *
 * @param genExpr      operand to validate
 * @param operatorName operator name reported in the error
 */
def requireNumeric(genExpr: GeneratedExpression, operatorName: String): Unit = {
  val actualType = genExpr.resultType
  if (!TypeCheckUtils.isNumeric(actualType)) {
    val reason = s"Numeric expression type expected, but was '$actualType'."
    throw new CodeGenException(
      TableErrors.INST.sqlCodeGenOperatorParamError(reason, operatorName))
  }
}
/**
 * Fails code generation unless the operand has a comparable result type.
 *
 * @param genExpr      operand to validate
 * @param operatorName operator name reported in the error
 */
def requireComparable(genExpr: GeneratedExpression, operatorName: String): Unit = {
  val actualType = genExpr.resultType
  if (!TypeCheckUtils.isComparable(actualType)) {
    val reason = "Comparable type expected, but was '" + actualType + "'."
    throw new CodeGenException(
      TableErrors.INST.sqlCodeGenOperatorParamError(reason, operatorName))
  }
}
/**
 * Fails code generation unless the operand has a string result type.
 *
 * @param genExpr      operand to validate
 * @param operatorName operator name reported in the error
 */
def requireString(genExpr: GeneratedExpression, operatorName: String): Unit = {
  val isString = TypeCheckUtils.isString(genExpr.resultType)
  if (!isString) {
    val reason = "String expression type expected."
    throw new CodeGenException(
      TableErrors.INST.sqlCodeGenOperatorParamError(reason, operatorName))
  }
}
/**
 * Fails code generation unless the operand has a boolean result type.
 *
 * @param genExpr      operand to validate
 * @param operatorName operator name reported in the error
 */
def requireBoolean(genExpr: GeneratedExpression, operatorName: String): Unit = {
  val isBoolean = TypeCheckUtils.isBoolean(genExpr.resultType)
  if (!isBoolean) {
    val reason = "Boolean expression type expected."
    throw new CodeGenException(
      TableErrors.INST.sqlCodeGenOperatorParamError(reason, operatorName))
  }
}
/**
 * Fails code generation unless the operand has a temporal result type.
 *
 * @param genExpr      operand to validate
 * @param operatorName operator name reported in the error
 */
def requireTemporal(genExpr: GeneratedExpression, operatorName: String): Unit = {
  val temporal = TypeCheckUtils.isTemporal(genExpr.resultType)
  if (!temporal) {
    val reason = "Temporal expression type expected."
    throw new CodeGenException(
      TableErrors.INST.sqlCodeGenOperatorParamError(reason, operatorName))
  }
}
/**
 * Fails code generation unless the operand has a time interval result type.
 *
 * @param genExpr      operand to validate
 * @param operatorName operator name reported in the error
 */
def requireTimeInterval(genExpr: GeneratedExpression, operatorName: String): Unit = {
  val isInterval = TypeCheckUtils.isTimeInterval(genExpr.resultType)
  if (!isInterval) {
    val reason = "Interval expression type expected."
    throw new CodeGenException(
      TableErrors.INST.sqlCodeGenOperatorParamError(reason, operatorName))
  }
}
/**
 * Fails code generation unless the operand has an array result type.
 *
 * @param genExpr      operand to validate
 * @param operatorName operator name reported in the error
 */
def requireArray(genExpr: GeneratedExpression, operatorName: String): Unit = {
  val isArray = TypeCheckUtils.isArray(genExpr.resultType)
  if (!isArray) {
    val reason = "Array expression type expected."
    throw new CodeGenException(
      TableErrors.INST.sqlCodeGenOperatorParamError(reason, operatorName))
  }
}
/**
 * Fails code generation unless the operand has a map result type.
 *
 * @param genExpr      operand to validate
 * @param operatorName operator name reported in the error
 * @throws CodeGenException if the operand's result type is not a map
 */
def requireMap(genExpr: GeneratedExpression, operatorName: String): Unit =
  if (!TypeCheckUtils.isMap(genExpr.resultType)) {
    throw new CodeGenException(
      TableErrors.INST.sqlCodeGenOperatorParamError(
        // the message previously claimed an array was expected (copied from requireArray)
        "Map expression type expected.",
        operatorName))
  }
/**
 * Fails code generation unless the operand has an integer result type.
 *
 * @param genExpr      operand to validate
 * @param operatorName operator name reported in the error
 */
def requireInteger(genExpr: GeneratedExpression, operatorName: String): Unit = {
  val isInteger = TypeCheckUtils.isInteger(genExpr.resultType)
  if (!isInteger) {
    val reason = "Integer expression type expected."
    throw new CodeGenException(
      TableErrors.INST.sqlCodeGenOperatorParamError(reason, operatorName))
  }
}
/**
 * Fails code generation unless the operand has a list result type.
 *
 * @param genExpr      operand to validate
 * @param operatorName operator name reported in the error
 */
def requireList(genExpr: GeneratedExpression, operatorName: String): Unit = {
  val isList = TypeCheckUtils.isList(genExpr.resultType)
  if (!isList) {
    val reason = "List expression type expected."
    throw new CodeGenException(
      TableErrors.INST.sqlCodeGenOperatorParamError(reason, operatorName))
  }
}
/**
 * Generates the expression for a NULL literal of the given type.
 *
 * The result term is the type's placeholder value cast to its primitive type term, and the
 * null term is the constant `true`.
 *
 * @param resultType type of the literal
 * @param nullCheck  must be true; null literals are rejected when null checking is disabled
 */
def generateNullLiteral(
    resultType: InternalType,
    nullCheck: Boolean): GeneratedExpression = {
  val placeholder = primitiveDefaultValue(resultType)
  val typeTerm = primitiveTypeTermForType(resultType)
  if (!nullCheck) {
    throw new CodeGenException("Null literals are not allowed if nullCheck is disabled.")
  }
  GeneratedExpression(s"(($typeTerm)$placeholder)", "true", "", resultType, literal = true)
}
/**
 * Generates the expression for a literal that is known not to be null.
 *
 * @param literalType  type of the literal
 * @param literalCode  Java source text of the literal; it is cast to the primitive type term
 * @param literalValue value attached to the expression as its literal value
 * @param nullCheck    not consulted by this method
 */
def generateNonNullLiteral(
    literalType: InternalType,
    literalCode: String,
    literalValue: Any,
    nullCheck: Boolean): GeneratedExpression = {
  val typeTerm = primitiveTypeTermForType(literalType)
  val castedLiteral = "((" + typeTerm + ")" + literalCode + ")"
  GeneratedExpression(
    castedLiteral,
    "false",
    "",
    literalType,
    literal = true,
    literalValue = literalValue)
}
/**
 * Generates the expression for a literal, dispatching on the SQL type name of
 * `literalRelDataType`.
 *
 * A null `literalValue` is delegated to `generateNullLiteral`. Otherwise the value is
 * rendered as Java source text; for DECIMAL, VARCHAR/CHAR and VARBINARY/BINARY the value is
 * registered on `ctx` as a reusable member/constant/object and the literal refers to it.
 *
 * @param ctx                 code generator context that receives reusable members
 * @param literalRelDataType  Calcite type of the literal; supplies the SQL type name and,
 *                            for DECIMAL, precision and scale
 * @param literalInternalType internal type attached to the generated expression
 * @param literalValue        literal value, may be null
 * @param nullCheck           forwarded to the null / non-null literal generators
 * @throws CodeGenException if the SQL type is not supported, if an interval value does not
 *                          fit its int/long representation, or (via `generateNullLiteral`)
 *                          for a null value when `nullCheck` is false
 */
def generateLiteral(
ctx: CodeGeneratorContext,
literalRelDataType: RelDataType,
literalInternalType: InternalType,
literalValue: Any,
nullCheck: Boolean): GeneratedExpression = {
if (literalValue == null) {
return generateNullLiteral(literalInternalType, nullCheck)
}
// non-null values
literalRelDataType.getSqlTypeName match {
case BOOLEAN =>
generateNonNullLiteral(literalInternalType, literalValue.toString, literalValue, nullCheck)
// Integral types: the value is cast to a java.math.BigDecimal and narrowed with
// byteValue/shortValue/intValue/longValue; no range check is done here.
case TINYINT =>
val decimal = BigDecimal(literalValue.asInstanceOf[JBigDecimal])
generateNonNullLiteral(
literalInternalType,
decimal.byteValue().toString,
decimal.byteValue(), nullCheck)
case SMALLINT =>
val decimal = BigDecimal(literalValue.asInstanceOf[JBigDecimal])
generateNonNullLiteral(
literalInternalType,
decimal.shortValue().toString,
decimal.shortValue(), nullCheck)
case INTEGER =>
val decimal = BigDecimal(literalValue.asInstanceOf[JBigDecimal])
generateNonNullLiteral(
literalInternalType,
decimal.intValue().toString,
decimal.intValue(), nullCheck)
case BIGINT =>
val decimal = BigDecimal(literalValue.asInstanceOf[JBigDecimal])
generateNonNullLiteral(
literalInternalType,
decimal.longValue().toString + "L",
decimal.longValue(), nullCheck)
// Floating point: non-finite values are emitted as references to the java.lang constants.
case FLOAT =>
val floatValue = literalValue.asInstanceOf[JBigDecimal].floatValue()
floatValue match {
// NOTE(review): a constant pattern is tested with ==, and NaN == NaN is false, so this
// case looks unreachable — confirm whether a NaN check was intended.
case Float.NaN => generateNonNullLiteral(
literalInternalType, "java.lang.Float.NaN", Float.NaN, nullCheck)
case Float.NegativeInfinity =>
generateNonNullLiteral(
literalInternalType,
"java.lang.Float.NEGATIVE_INFINITY",
Float.NegativeInfinity, nullCheck)
case Float.PositiveInfinity => generateNonNullLiteral(
literalInternalType,
"java.lang.Float.POSITIVE_INFINITY",
Float.PositiveInfinity, nullCheck)
case _ => generateNonNullLiteral(
literalInternalType,
floatValue.toString + "f",
floatValue,
nullCheck)
}
case DOUBLE =>
val doubleValue = literalValue.asInstanceOf[JBigDecimal].doubleValue()
doubleValue match {
// NOTE(review): same as for FLOAT — the NaN constant pattern looks unreachable.
case Double.NaN => generateNonNullLiteral(
literalInternalType, "java.lang.Double.NaN", Double.NaN, nullCheck)
case Double.NegativeInfinity =>
generateNonNullLiteral(
literalInternalType,
"java.lang.Double.NEGATIVE_INFINITY",
Double.NegativeInfinity, nullCheck)
case Double.PositiveInfinity =>
generateNonNullLiteral(
literalInternalType,
"java.lang.Double.POSITIVE_INFINITY",
Double.PositiveInfinity, nullCheck)
case _ => generateNonNullLiteral(
literalInternalType, doubleValue.toString + "d", doubleValue, nullCheck)
}
// DECIMAL: a member holding the decimal, built from the value's string form, is
// registered on ctx and the literal refers to that member.
case DECIMAL =>
val precision = literalRelDataType.getPrecision
val scale = literalRelDataType.getScale
val fieldTerm = newName("decimal")
val fieldDecimal =
s"""
|${classOf[Decimal].getCanonicalName} $fieldTerm =
| ${Decimal.Ref.castFrom}("${literalValue.toString}", $precision, $scale);
|""".stripMargin
ctx.addReusableMember(fieldDecimal)
generateNonNullLiteral(
literalInternalType,
fieldTerm,
Decimal.fromBigDecimal(literalValue.asInstanceOf[JBigDecimal], precision, scale),
nullCheck)
// Strings: the Java-escaped text is registered as a reusable string constant.
// NOTE(review): the attached literal value is built from the escaped text, not from the
// raw value — confirm that consumers of literalValue expect the escaped form.
case VARCHAR | CHAR =>
val escapedValue = StringEscapeUtils.ESCAPE_JAVA.translate(literalValue.toString)
val field = ctx.addReusableStringConstants(escapedValue)
generateNonNullLiteral(
literalInternalType,
field,
BinaryString.fromString(escapedValue),
nullCheck)
// Binary: the byte array itself is registered on ctx as a reusable object.
case VARBINARY | BINARY =>
val bytesVal = literalValue.asInstanceOf[ByteString].getBytes
val fieldTerm = ctx.addReusableObject(bytesVal, "binary",
bytesVal.getClass.getCanonicalName)
generateNonNullLiteral(
literalInternalType,
fieldTerm,
BinaryString.fromBytes(bytesVal),
nullCheck)
case SYMBOL =>
generateSymbol(literalValue.asInstanceOf[Enum[_]])
// DATE and TIME: the value's string form is used directly as the literal code.
case DATE =>
generateNonNullLiteral(literalInternalType, literalValue.toString, literalValue, nullCheck)
case TIME =>
generateNonNullLiteral(literalInternalType, literalValue.toString, literalValue, nullCheck)
case TIMESTAMP =>
// Hack
// Currently, in RexLiteral/SqlLiteral(Calcite), TimestampString has no time zone.
// TimeString, DateString TimestampString are treated as UTC time/(unix time)
// when they are converted/formatted/validated
// Here, we adjust millis before Calcite solve TimeZone perfectly
val millis = literalValue.asInstanceOf[Long]
val adjustedValue = millis - ctx.getTableConfig.getTimeZone.getOffset(millis)
// only the emitted code carries the adjusted value; the attached literal value is the
// original, unadjusted one
generateNonNullLiteral(
literalInternalType, adjustedValue.toString + "L", literalValue, nullCheck)
// Year-month intervals are emitted as int literals; the value must be a valid int.
case typeName if YEAR_INTERVAL_TYPES.contains(typeName) =>
val decimal = BigDecimal(literalValue.asInstanceOf[JBigDecimal])
if (decimal.isValidInt) {
generateNonNullLiteral(
literalInternalType,
decimal.intValue().toString,
decimal.intValue(), nullCheck)
} else {
throw new CodeGenException(
s"Decimal '$decimal' can not be converted to interval of months.")
}
// Day-time intervals are emitted as long literals; the value must be a valid long.
case typeName if DAY_INTERVAL_TYPES.contains(typeName) =>
val decimal = BigDecimal(literalValue.asInstanceOf[JBigDecimal])
if (decimal.isValidLong) {
generateNonNullLiteral(
literalInternalType,
decimal.longValue().toString + "L",
decimal.longValue(), nullCheck)
} else {
throw new CodeGenException(
s"Decimal '$decimal' can not be converted to interval of milliseconds.")
}
case t@_ =>
throw new CodeGenException(s"Type not supported: $t")
}
}
/**
 * Wraps `code` into an expression of type `t` that is never null: the code is cast to the
 * primitive type term and the null term is the constant `false`.
 *
 * @param nullCheck not consulted by this method
 */
def generateNonNullField(
    t: InternalType,
    code: String,
    nullCheck: Boolean): GeneratedExpression = {
  val typeTerm = primitiveTypeTermForType(t)
  val castedCode = "((" + typeTerm + ") " + code + ")"
  GeneratedExpression(castedCode, "false", "", t)
}
/**
 * Generates the expression for an enum constant: the result term is the qualified constant
 * reference and the result type is a generic type over the enum's declaring class.
 */
def generateSymbol(enum: Enum[_]): GeneratedExpression = {
  val constantTerm = qualifyEnum(enum)
  GeneratedExpression(constantTerm, "false", "", new GenericType(enum.getDeclaringClass))
}
/**
 * Generates code that stores the current processing time of the timer service in a
 * reusable `long` field.
 *
 * @param contextTerm term of the context object that exposes `timerService()`
 * @param ctx         code generator context that owns the reusable field
 */
def generateProctimeTimestamp(
    contextTerm: String,
    ctx: CodeGeneratorContext): GeneratedExpression = {
  val timeField = ctx.newReusableField("result", "long")
  val assignment = s"$timeField = $contextTerm.timerService().currentProcessingTime();"
  GeneratedExpression(timeField, NEVER_NULL, assignment, DataTypes.TIMESTAMP)
}
/**
 * Generates the expression for the current timestamp by delegating to a
 * `CurrentTimePointCallGen` with no operands and a TIMESTAMP result type.
 */
def generateCurrentTimestamp(
    ctx: CodeGeneratorContext): GeneratedExpression = {
  val timePointGenerator = new CurrentTimePointCallGen(false)
  timePointGenerator.generate(ctx, Seq(), DataTypes.TIMESTAMP, false)
}
/**
 * Generates code that reads the record timestamp from the context into a reusable field.
 * The generated code throws a `RuntimeException` when the timestamp is null.
 *
 * @param contextTerm term of the context object that exposes `timestamp()`
 * @param ctx         code generator context that owns the reusable fields
 */
def generateRowtimeAccess(
    contextTerm: String,
    ctx: CodeGeneratorContext): GeneratedExpression = {
  val fieldNames = Seq("result", "isNull")
  val fieldTypes = Seq("Long", "boolean")
  val Seq(rowtimeTerm, isNullTerm) = ctx.newReusableFields(fieldNames, fieldTypes)
  val readRowtime =
    s"""
       |$rowtimeTerm = $contextTerm.timestamp();
       |if ($rowtimeTerm == null) {
       | throw new RuntimeException("Rowtime timestamp is null. Please make sure that a " +
       | "proper TimestampAssigner is defined and the stream environment uses the EventTime " +
       | "time characteristic.");
       |}
       |$isNullTerm = false;
       """.stripMargin.trim
  GeneratedExpression(rowtimeTerm, isNullTerm, readRowtime, DataTypes.ROWTIME_INDICATOR)
}
/**
 * Returns the expression that accesses field `index` of the input `inputTerm`.
 *
 * The access code is generated once per (input term, index) pair and cached on `ctx`.
 * The returned expression carries only the terms and the type, with empty code, because
 * the cached access code is meant to be executed only once.
 *
 * @param nullableInput whether the input object itself may be null
 * @param fieldCopy     forwarded to the field access generators
 */
def generateInputAccess(
    ctx: CodeGeneratorContext,
    inputType: InternalType,
    inputTerm: String,
    index: Int,
    nullableInput: Boolean,
    nullCheck: Boolean,
    fieldCopy: Boolean = false): GeneratedExpression = {
  // reuse the access generated earlier for this input and index, if there is one
  val cached = ctx.getReusableInputUnboxingExprs(inputTerm, index)
  val access = cached.getOrElse {
    val fresh =
      if (nullableInput) {
        generateNullableInputFieldAccess(ctx, inputType, inputTerm, index, nullCheck, fieldCopy)
      } else {
        generateFieldAccess(ctx, inputType, inputTerm, index, nullCheck, fieldCopy)
      }
    ctx.addReusableInputUnboxingExprs(inputTerm, index, fresh)
    fresh
  }
  GeneratedExpression(access.resultTerm, access.nullTerm, "", access.resultType)
}
/**
 * Generates a field access expression. The difference between this method and the
 * five-argument `generateFieldAccess` overload is the additional `fieldCopy` parameter:
 * when `fieldCopy` is true, the result is passed through `copyResultIfNeeded`.
 *
 * NOTE: Please set `fieldCopy` to true when the result will be buffered.
 */
def generateFieldAccess(
    ctx: CodeGeneratorContext,
    inputType: InternalType,
    inputTerm: String,
    index: Int,
    nullCheck: Boolean,
    fieldCopy: Boolean): GeneratedExpression = {
  val access = generateFieldAccess(ctx, inputType, inputTerm, index, nullCheck)
  if (!fieldCopy) access else access.copyResultIfNeeded(ctx, fieldCopy)
}
/**
 * Generates the code that reads a value from `inputTerm`.
 *
 * For a row type, field `index` is read through the row accessor of its type into a
 * reusable field, guarded by `isNullAt` when `nullCheck` is on. For any other type the
 * input term itself is cast to its boxed type and unboxed.
 */
def generateFieldAccess(
    ctx: CodeGeneratorContext,
    inputType: InternalType,
    inputTerm: String,
    index: Int,
    nullCheck: Boolean): GeneratedExpression =
  inputType match {
    case rowType: RowType =>
      val fieldType = rowType.getFieldTypes()(index).toInternalType
      val typeTerm = primitiveTypeTermForType(fieldType)
      val fallback = primitiveDefaultValue(fieldType)
      val reader = baseRowFieldReadAccess(ctx, index.toString, inputTerm, fieldType)
      val Seq(valueTerm, isNullTerm) =
        ctx.newReusableFields(Seq("field", "isNull"), Seq(typeTerm, "boolean"))
      val accessCode =
        if (!nullCheck) {
          s"""
             |$isNullTerm = false;
             |$valueTerm = $reader;
             |""".stripMargin
        } else {
          s"""
             |$isNullTerm = $inputTerm.isNullAt($index);
             |$valueTerm = $fallback;
             |if (!$isNullTerm) {
             | $valueTerm = $reader;
             |}
             """.stripMargin.trim
        }
      GeneratedExpression(valueTerm, isNullTerm, accessCode, fieldType)
    case _ =>
      val boxedTerm = boxedTypeTermForType(inputType)
      val castedInput = "(" + boxedTerm + ") " + inputTerm
      generateInputFieldUnboxing(inputType, castedInput, nullCheck, ctx)
  }
/**
 * Generates a field access for an input object that may itself be null.
 *
 * The generated code presets the result to the type's placeholder and the null flag to
 * true, and runs the regular field access only when the input term is not null.
 *
 * @param fieldCopy forwarded to the regular field access
 */
def generateNullableInputFieldAccess(
    ctx: CodeGeneratorContext,
    inputType: InternalType,
    inputTerm: String,
    index: Int,
    nullCheck: Boolean,
    fieldCopy: Boolean = false): GeneratedExpression = {
  // for rows the accessed value has the field's type, otherwise the input's own type
  val fieldType = inputType match {
    case rowType: RowType => rowType.getFieldTypes()(index).toInternalType
    case other => other
  }
  val typeTerm = primitiveTypeTermForType(fieldType)
  val fallback = primitiveDefaultValue(fieldType)
  val Seq(outTerm, outNullTerm) =
    ctx.newReusableFields(Seq("result", "isNull"), Seq(typeTerm, "boolean"))
  val inner = generateFieldAccess(ctx, inputType, inputTerm, index, nullCheck, fieldCopy)
  val guardedCode =
    s"""
       |$outTerm = $fallback;
       |$outNullTerm = true;
       |if ($inputTerm != null) {
       | ${inner.code}
       | $outTerm = ${inner.resultTerm};
       | $outNullTerm = ${inner.nullTerm};
       |}
       |""".stripMargin.trim
  GeneratedExpression(outTerm, outNullTerm, guardedCode, fieldType)
}
/**
 * Converts the external boxed format to an internal, mostly primitive field representation.
 * Wrapper types are auto-unboxed to their corresponding primitive type (Integer -> int).
 *
 * @param fieldType type of field
 * @param fieldTerm expression term of the field to be unboxed
 * @param nullCheck whether to guard the assignment with a null test
 * @param ctx       code generator context that owns the reusable result fields
 * @return internal unboxed field representation
 */
def generateInputFieldUnboxing(
    fieldType: InternalType,
    fieldTerm: String,
    nullCheck: Boolean,
    ctx: CodeGeneratorContext): GeneratedExpression = {
  val typeTerm = primitiveTypeTermForType(fieldType)
  val fallback = primitiveDefaultValue(fieldType)
  val Seq(unboxedTerm, isNullTerm) =
    ctx.newReusableFields(Seq("result", "isNull"), Seq(typeTerm, "boolean"))
  val unboxingCode =
    if (!nullCheck) {
      s"""
         |$unboxedTerm = $fieldTerm;
         |""".stripMargin.trim
    } else {
      s"""
         |$isNullTerm = $fieldTerm == null;
         |$unboxedTerm = $fallback;
         |if (!$isNullTerm) {
         | $unboxedTerm = $fieldTerm;
         |}
         |""".stripMargin.trim
    }
  GeneratedExpression(unboxedTerm, isNullTerm, unboxingCode, fieldType)
}
/**
 * Generates the expression for a call of `operator` on already generated `operands`.
 *
 * The operator is matched against the built-in arithmetic, comparison, logic, cast,
 * row/array/map and internal operators in the order written below; several cases are
 * additionally guarded by the result type, so the case order is significant. Operators
 * that match none of the explicit cases are resolved through `BinaryStringCallGen` and
 * then `FunctionGenerator`.
 *
 * @param ctx        code generator context
 * @param operator   SQL operator being called
 * @param operands   generated operand expressions, in call order
 * @param resultType internal type of the call result
 * @param nullCheck  forwarded to the individual generators
 * @throws CodeGenException if an operand has an unexpected type or no generator exists
 *                          for the operator
 */
def generateCallExpression(
ctx: CodeGeneratorContext,
operator: SqlOperator,
operands: Seq[GeneratedExpression],
resultType: InternalType,
nullCheck: Boolean): GeneratedExpression = {
operator match {
// arithmetic
case PLUS if isNumeric(resultType) =>
val left = operands.head
val right = operands(1)
requireNumeric(left, operator.getName)
requireNumeric(right, operator.getName)
generateArithmeticOperator(ctx, "+", nullCheck, resultType, left, right)
case PLUS | DATETIME_PLUS if isTemporal(resultType) =>
val left = operands.head
val right = operands(1)
requireTemporal(left, operator.getName)
requireTemporal(right, operator.getName)
generateTemporalPlusMinus(ctx, plus = true, nullCheck, resultType, left, right)
case MINUS if isNumeric(resultType) =>
val left = operands.head
val right = operands(1)
requireNumeric(left, operator.getName)
requireNumeric(right, operator.getName)
generateArithmeticOperator(ctx, "-", nullCheck, resultType, left, right)
case MINUS | MINUS_DATE if isTemporal(resultType) =>
val left = operands.head
val right = operands(1)
requireTemporal(left, operator.getName)
requireTemporal(right, operator.getName)
generateTemporalPlusMinus(ctx, plus = false, nullCheck, resultType, left, right)
case MULTIPLY if isNumeric(resultType) =>
val left = operands.head
val right = operands(1)
requireNumeric(left, operator.getName)
requireNumeric(right, operator.getName)
generateArithmeticOperator(ctx, "*", nullCheck, resultType, left, right)
// interval * number: the left operand must be the interval, the right one numeric
case MULTIPLY if isTimeInterval(resultType) =>
val left = operands.head
val right = operands(1)
requireTimeInterval(left, operator.getName)
requireNumeric(right, operator.getName)
generateArithmeticOperator(ctx, "*", nullCheck, resultType, left, right)
case ScalarSqlFunctions.DIVIDE | DIVIDE_INTEGER if isNumeric(resultType) =>
val left = operands.head
val right = operands(1)
requireNumeric(left, operator.getName)
requireNumeric(right, operator.getName)
generateArithmeticOperator(ctx, "/", nullCheck, resultType, left, right)
case MOD if isNumeric(resultType) =>
val left = operands.head
val right = operands(1)
requireNumeric(left, operator.getName)
requireNumeric(right, operator.getName)
generateArithmeticOperator(ctx, "%", nullCheck, resultType, left, right)
case UNARY_MINUS if isNumeric(resultType) =>
val operand = operands.head
requireNumeric(operand, operator.getName)
generateUnaryArithmeticOperator(ctx, "-", nullCheck, resultType, operand)
case UNARY_MINUS if isTimeInterval(resultType) =>
val operand = operands.head
requireTimeInterval(operand, operator.getName)
generateUnaryIntervalPlusMinus(ctx, plus = false, nullCheck, operand)
case UNARY_PLUS if isNumeric(resultType) =>
val operand = operands.head
requireNumeric(operand, operator.getName)
generateUnaryArithmeticOperator(ctx, "+", nullCheck, resultType, operand)
case UNARY_PLUS if isTimeInterval(resultType) =>
val operand = operands.head
requireTimeInterval(operand, operator.getName)
generateUnaryIntervalPlusMinus(ctx, plus = true, nullCheck, operand)
// comparison
case EQUALS =>
val left = operands.head
val right = operands(1)
generateEquals(ctx, nullCheck, left, right)
case NOT_EQUALS =>
val left = operands.head
val right = operands(1)
generateNotEquals(ctx, nullCheck, left, right)
case GREATER_THAN =>
val left = operands.head
val right = operands(1)
requireComparable(left, operator.getName)
requireComparable(right, operator.getName)
generateComparison(ctx, ">", nullCheck, left, right)
case GREATER_THAN_OR_EQUAL =>
val left = operands.head
val right = operands(1)
requireComparable(left, operator.getName)
requireComparable(right, operator.getName)
generateComparison(ctx, ">=", nullCheck, left, right)
case LESS_THAN =>
val left = operands.head
val right = operands(1)
requireComparable(left, operator.getName)
requireComparable(right, operator.getName)
generateComparison(ctx, "<", nullCheck, left, right)
case LESS_THAN_OR_EQUAL =>
val left = operands.head
val right = operands(1)
requireComparable(left, operator.getName)
requireComparable(right, operator.getName)
generateComparison(ctx, "<=", nullCheck, left, right)
case IS_NULL =>
val operand = operands.head
generateIsNull(nullCheck, operand)
case IS_NOT_NULL =>
val operand = operands.head
generateIsNotNull(nullCheck, operand)
// logic
// AND / OR accept any number of operands; they are folded pairwise from the left
case AND =>
operands.reduceLeft { (left: GeneratedExpression, right: GeneratedExpression) =>
requireBoolean(left, operator.getName)
requireBoolean(right, operator.getName)
generateAnd(nullCheck, left, right)
}
case OR =>
operands.reduceLeft { (left: GeneratedExpression, right: GeneratedExpression) =>
requireBoolean(left, operator.getName)
requireBoolean(right, operator.getName)
generateOr(nullCheck, left, right)
}
case NOT =>
val operand = operands.head
requireBoolean(operand, operator.getName)
generateNot(ctx, nullCheck, operand)
case CASE =>
generateIfElse(ctx, nullCheck, operands, resultType)
case IS_TRUE =>
val operand = operands.head
requireBoolean(operand, operator.getName)
generateIsTrue(operand)
case IS_NOT_TRUE =>
val operand = operands.head
requireBoolean(operand, operator.getName)
generateIsNotTrue(operand)
case IS_FALSE =>
val operand = operands.head
requireBoolean(operand, operator.getName)
generateIsFalse(operand)
case IS_NOT_FALSE =>
val operand = operands.head
requireBoolean(operand, operator.getName)
generateIsNotFalse(operand)
// IN / NOT IN: the first operand is the needle, the remaining operands are the candidates
case IN =>
val left = operands.head
val right = operands.tail
generateIn(ctx, left, right, nullCheck)
case NOT_IN =>
val left = operands.head
val right = operands.tail
generateNot(ctx, nullCheck, generateIn(ctx, left, right, nullCheck))
// casting
case CAST =>
val operand = operands.head
generateCast(ctx, nullCheck, operand, resultType)
// Reinterpret
case REINTERPRET =>
val operand = operands.head
generateReinterpret(ctx, nullCheck, operand, resultType)
// as / renaming
case AS =>
operands.head
// rows
case ROW =>
generateRow(ctx, resultType, operands, nullCheck)
// arrays
case ARRAY_VALUE_CONSTRUCTOR =>
generateArray(ctx, resultType, operands, nullCheck)
// maps
case MAP_VALUE_CONSTRUCTOR =>
generateMap(ctx, resultType, operands, nullCheck)
// element access: array index lookup or map key lookup, chosen by the first operand's type
case ITEM =>
operands.head.resultType match {
case t: InternalType if TypeCheckUtils.isArray(t) =>
val array = operands.head
val index = operands(1)
requireInteger(index, operator.getName)
generateArrayElementAt(ctx, array, index, nullCheck)
case t: InternalType if TypeCheckUtils.isMap(t) =>
val key = operands(1)
generateMapGet(ctx, operands.head, key, nullCheck)
case _ => throw new CodeGenException("Expect an array or a map.")
}
case CARDINALITY =>
operands.head.resultType match {
case t: InternalType if TypeCheckUtils.isArray(t) =>
val array = operands.head
generateArrayCardinality(ctx, nullCheck, array)
case t: InternalType if TypeCheckUtils.isMap(t) =>
val map = operands.head
generateMapCardinality(ctx, nullCheck, map)
case _ => throw new CodeGenException("Expect an array or a map.")
}
case ELEMENT =>
val array = operands.head
requireArray(array, operator.getName)
generateArrayElement(ctx, array, nullCheck)
case DOT =>
generateDOT(ctx, operands, nullCheck)
case func: SqlRuntimeFilterFunction => generateRuntimeFilter(ctx, operands, func)
case func: SqlRuntimeFilterBuilderFunction =>
generateRuntimeFilterBuilder(ctx, operands, func)
// The generated code rethrows a RuntimeException whose message is the first operand's
// result term converted with toString(); the expression itself is a null literal.
// NOTE(review): only the operand's result term is referenced, its code is not included
// here — confirm the operand is always evaluated before this point.
case _: SqlThrowExceptionFunction =>
val nullValue = generateNullLiteral(resultType, nullCheck)
val code =
s"""
|${nullValue.code}
|org.apache.flink.util.ExceptionUtils.rethrow(
| new RuntimeException(${operands.head.resultTerm}.toString()));
|""".stripMargin
GeneratedExpression(nullValue.resultTerm, nullValue.nullTerm, code, resultType)
case ScalarSqlFunctions.PROCTIME =>
// attribute is proctime indicator.
// We use a null literal and generate a timestamp when we need it.
generateNullLiteral(DataTypes.PROCTIME_INDICATOR, nullCheck)
// advanced scalar functions
// BinaryStringCallGen is tried first; FunctionGenerator is the fallback, and a missing
// call generator is reported as an unsupported scalar function.
case sqlOperator: SqlOperator =>
BinaryStringCallGen.generateCallExpression(ctx, operator, operands, resultType).getOrElse{
FunctionGenerator.getCallGenerator(
sqlOperator,
operands.map(expr => expr.resultType),
resultType).getOrElse(
throw new CodeGenException(TableErrors.INST.sqlCodeGenUnsupportedScalaFunc(
s"$sqlOperator(${operands.map(_.resultType).mkString(",")})")))
.generate(ctx, operands, resultType, nullCheck)
}
// unknown or invalid
// NOTE(review): `operator` is statically a SqlOperator, so the typed pattern above appears
// to match every non-null operator; this branch looks reachable only for a null operator.
case call@_ =>
throw new CodeGenException(
TableErrors.INST.sqlCodeGenUnsupportedCall(
s"$call${operands.map(_.resultType).mkString(",")}"))
}
}
// ----------------------------------------------------------------------------------------------
/**
  * Tells whether the result type of a generated expression is treated as a reference type.
  * Delegates to the [[InternalType]] overload.
  */
def isReference(genExpr: GeneratedExpression): Boolean = {
  val resultType = genExpr.resultType
  isReference(resultType)
}
/**
  * Tells whether the given type is treated as a reference type.
  *
  * INT, LONG, SHORT, BYTE, FLOAT, DOUBLE, BOOLEAN and CHAR are reported as non-reference;
  * every other type yields `true`.
  */
def isReference(t: InternalType): Boolean = {
  val isPrimitive = t match {
    case DataTypes.INT | DataTypes.LONG | DataTypes.SHORT | DataTypes.BYTE
         | DataTypes.FLOAT | DataTypes.DOUBLE | DataTypes.BOOLEAN | DataTypes.CHAR => true
    case _ => false
  }
  !isPrimitive
}
/**
  * Generates the read-access expression for the field at a fixed numeric position.
  * The position is rendered as its decimal string and handed to the overload that takes the
  * position as a term.
  */
def baseRowFieldReadAccess(
    ctx: CodeGeneratorContext,
    pos: Int,
    rowTerm: String,
    fieldType: InternalType) : String = {
  val posTerm = pos.toString
  baseRowFieldReadAccess(ctx, posTerm, rowTerm, fieldType)
}
/**
  * Generates the expression that reads one field from the row referred to by `rowTerm`.
  * The accessor method is selected by `fieldType`.
  *
  * Side effects on `ctx`: for STRING fields a reusable binary string member is registered, and
  * for generic types a reusable type serializer is registered.
  *
  * @param ctx       code generator context that receives reusable members/serializers
  * @param pos       term (literal or variable name) denoting the field position
  * @param rowTerm   term of the row to read from
  * @param fieldType type of the field being read
  * @return the generated field access expression
  * @throws CodeGenException if no read accessor is known for `fieldType`
  */
def baseRowFieldReadAccess(
    ctx: CodeGeneratorContext,
    pos: String,
    rowTerm: String,
    fieldType: InternalType) : String =
  fieldType match {
    case DataTypes.INT => s"$rowTerm.getInt($pos)"
    case DataTypes.LONG => s"$rowTerm.getLong($pos)"
    case DataTypes.SHORT => s"$rowTerm.getShort($pos)"
    case DataTypes.BYTE => s"$rowTerm.getByte($pos)"
    case DataTypes.FLOAT => s"$rowTerm.getFloat($pos)"
    case DataTypes.DOUBLE => s"$rowTerm.getDouble($pos)"
    case DataTypes.BOOLEAN => s"$rowTerm.getBoolean($pos)"
    case DataTypes.STRING =>
      // each access site gets its own reusable binary string instance
      val reuse = newName("reuseBString")
      ctx.addReusableMember(s"$BINARY_STRING $reuse = new $BINARY_STRING();")
      s"$rowTerm.getBinaryString($pos, $reuse)"
    case dt: DecimalType => s"$rowTerm.getDecimal($pos, ${dt.precision()}, ${dt.scale()})"
    case DataTypes.CHAR => s"$rowTerm.getChar($pos)"
    // timestamps are read as long, dates and times as int
    case _: TimestampType => s"$rowTerm.getLong($pos)"
    case _: DateType => s"$rowTerm.getInt($pos)"
    case DataTypes.TIME => s"$rowTerm.getInt($pos)"
    case DataTypes.BYTE_ARRAY => s"$rowTerm.getByteArray($pos)"
    case _: ArrayType => s"$rowTerm.getBaseArray($pos)"
    case _: MapType => s"$rowTerm.getBaseMap($pos)"
    case rt: RowType =>
      s"$rowTerm.getBaseRow($pos, ${rt.getArity})"
    case gt: GenericType[_] =>
      s"""
         |(${gt.getTypeClass.getCanonicalName})
         | $rowTerm.getGeneric($pos, ${ctx.addReusableTypeSerializer(fieldType)})
       """.stripMargin.trim
    case _ =>
      // fail with a descriptive code generation error instead of an opaque scala.MatchError
      throw new CodeGenException(
        s"Fail to find base row field read method of InternalType $fieldType.")
  }
/**
  * Generates the statement that writes a null value at `pos` through the writer `writerTerm`.
  *
  * Decimal types whose precision is not compact (according to `Decimal.isCompact`) are written
  * with `writeDecimal(pos, null, precision, scale)`; all other types use `setNullAt(pos)`.
  */
def binaryWriterWriteNull(pos: Int, writerTerm: String, t: InternalType): String = {
  t match {
    case decimalType: DecimalType if !Decimal.isCompact(decimalType.precision()) =>
      val precision = decimalType.precision()
      val scale = decimalType.scale()
      s"$writerTerm.writeDecimal($pos, null, $precision, $scale)"
    case _ =>
      s"$writerTerm.setNullAt($pos)"
  }
}
/**
  * Generates the statement that sets the field at `pos` of the row `rowTerm` to null.
  *
  * Decimal types whose precision is not compact (according to `Decimal.isCompact`) are set
  * with `setDecimal(pos, null, precision, scale)`; all other types use `setNullAt(pos)`.
  */
def binaryRowSetNull(pos: Int, rowTerm: String, t: InternalType): String = {
  t match {
    case decimalType: DecimalType if !Decimal.isCompact(decimalType.precision()) =>
      val precision = decimalType.precision()
      val scale = decimalType.scale()
      s"$rowTerm.setDecimal($pos, null, $precision, $scale)"
    case _ =>
      s"$rowTerm.setNullAt($pos)"
  }
}
/**
  * Generates the statement that stores `fieldValTerm` into the field at `pos` of the binary
  * row `binaryRowTerm`, choosing the setter by `fieldType`.
  *
  * Dates and times use the int setter, timestamps the long setter, decimals additionally pass
  * precision and scale.
  *
  * @throws CodeGenException if `fieldType` has no setter listed here
  */
def binaryRowFieldSetAccess(
    pos: Int,
    binaryRowTerm: String,
    fieldType: InternalType,
    fieldValTerm: String)
  : String = {
  // renders "<row>.<setter>(<pos>, <value>)"
  def setterCall(setter: String): String = s"$binaryRowTerm.$setter($pos, $fieldValTerm)"

  fieldType match {
    case DataTypes.INT => setterCall("setInt")
    case DataTypes.LONG => setterCall("setLong")
    case DataTypes.SHORT => setterCall("setShort")
    case DataTypes.BYTE => setterCall("setByte")
    case DataTypes.FLOAT => setterCall("setFloat")
    case DataTypes.DOUBLE => setterCall("setDouble")
    case DataTypes.BOOLEAN => setterCall("setBoolean")
    case DataTypes.CHAR => setterCall("setChar")
    case _: DateType => setterCall("setInt")
    case DataTypes.TIME => setterCall("setInt")
    case _: TimestampType => setterCall("setLong")
    case decimalType: DecimalType =>
      val precision = decimalType.precision()
      val scale = decimalType.scale()
      s"$binaryRowTerm.setDecimal($pos, $fieldValTerm, $precision, $scale)"
    case _ =>
      throw new CodeGenException(
        s"Fail to find binary row field setter method of InternalType $fieldType.")
  }
}
/**
  * Generates the statement that writes `fieldValTerm` at position `pos` through the binary
  * writer `writerTerm`, choosing the write method by `fieldType`.
  *
  * Arrays, maps and rows are written through the `BASE_ROW_UTIL` helper with a serializer
  * registered on `ctx`; generic types also register a serializer on `ctx`.
  *
  * @param ctx          code generator context that receives reusable type serializers
  * @param pos          position of the field to write
  * @param fieldValTerm term holding the value to write
  * @param writerTerm   term of the binary writer
  * @param fieldType    type of the field being written
  * @return the generated write statement (without trailing semicolon)
  * @throws CodeGenException if no write method is known for `fieldType`
  */
def binaryWriterWriteField(
    ctx: CodeGeneratorContext,
    pos: Int,
    fieldValTerm: String,
    writerTerm: String,
    fieldType: InternalType): String =
  fieldType match {
    case DataTypes.INT => s"$writerTerm.writeInt($pos, $fieldValTerm)"
    case DataTypes.LONG => s"$writerTerm.writeLong($pos, $fieldValTerm)"
    case DataTypes.SHORT => s"$writerTerm.writeShort($pos, $fieldValTerm)"
    case DataTypes.BYTE => s"$writerTerm.writeByte($pos, $fieldValTerm)"
    case DataTypes.FLOAT => s"$writerTerm.writeFloat($pos, $fieldValTerm)"
    case DataTypes.DOUBLE => s"$writerTerm.writeDouble($pos, $fieldValTerm)"
    case DataTypes.BOOLEAN => s"$writerTerm.writeBoolean($pos, $fieldValTerm)"
    case DataTypes.STRING => s"$writerTerm.writeBinaryString($pos, $fieldValTerm)"
    case d: DecimalType =>
      s"$writerTerm.writeDecimal($pos, $fieldValTerm, ${d.precision()}, ${d.scale()})"
    case DataTypes.CHAR => s"$writerTerm.writeChar($pos, $fieldValTerm)"
    // dates and times are written as int, timestamps as long
    case _: DateType => s"$writerTerm.writeInt($pos, $fieldValTerm)"
    case DataTypes.TIME => s"$writerTerm.writeInt($pos, $fieldValTerm)"
    case _: TimestampType => s"$writerTerm.writeLong($pos, $fieldValTerm)"
    case DataTypes.BYTE_ARRAY => s"$writerTerm.writeByteArray($pos, $fieldValTerm)"
    case _: ArrayType =>
      s"$BASE_ROW_UTIL.writeBaseArray($writerTerm, $pos, $fieldValTerm, " +
        s"(${classOf[BaseArraySerializer].getCanonicalName}) " +
        s"${ctx.addReusableTypeSerializer(fieldType)})"
    case _: MapType =>
      s"$BASE_ROW_UTIL.writeBaseMap($writerTerm, $pos, $fieldValTerm, " +
        s"(${classOf[BaseMapSerializer].getCanonicalName}) " +
        s"${ctx.addReusableTypeSerializer(fieldType)})"
    case _: RowType =>
      s"$BASE_ROW_UTIL.writeBaseRow($writerTerm, $pos, $fieldValTerm, " +
        s"(${classOf[BaseRowSerializer[_]].getCanonicalName}) " +
        s"${ctx.addReusableTypeSerializer(fieldType)})"
    case _: GenericType[_] => s"$writerTerm.writeGeneric($pos, $fieldValTerm, " +
      s"${ctx.addReusableTypeSerializer(fieldType)})"
    case _ =>
      // fail with a descriptive code generation error instead of an opaque scala.MatchError
      throw new CodeGenException(
        s"Fail to find binary writer write method of InternalType $fieldType.")
  }
/**
  * Generates the statement that marks the element at `pos` of the array `term` as null.
  *
  * Each primitive type uses its dedicated `setNullXxx` method; TIME and dates use the int
  * variant, timestamps and every remaining type use the long variant.
  */
def baseArraySetNull(
    pos: Int,
    term: String,
    t: InternalType): String = {
  // renders "<array>.<method>(<pos>)"
  def nullSetter(method: String): String = s"$term.$method($pos)"

  t match {
    case DataTypes.BOOLEAN => nullSetter("setNullBoolean")
    case DataTypes.BYTE => nullSetter("setNullByte")
    case DataTypes.CHAR => nullSetter("setNullChar")
    case DataTypes.SHORT => nullSetter("setNullShort")
    case DataTypes.INT => nullSetter("setNullInt")
    case DataTypes.LONG => nullSetter("setNullLong")
    case DataTypes.FLOAT => nullSetter("setNullFloat")
    case DataTypes.DOUBLE => nullSetter("setNullDouble")
    case DataTypes.TIME => nullSetter("setNullInt")
    case _: DateType => nullSetter("setNullInt")
    case _: TimestampType => nullSetter("setNullLong")
    case _ => nullSetter("setNullLong")
  }
}
/**
  * Generates the statement that updates the field at `pos` of the row `rowTerm` with
  * `fieldValTerm`.
  *
  * Primitive types use their typed setter, dates and times the int setter, timestamps the
  * long setter; any other type is stored through `setNonPrimitiveValue`.
  */
def boxedWrapperRowFieldUpdateAccess(
    pos: Int,
    fieldValTerm: String,
    rowTerm: String,
    fieldType: InternalType): String = {
  // renders "<row>.<setter>(<pos>, <value>)"
  def update(setter: String): String = s"$rowTerm.$setter($pos, $fieldValTerm)"

  fieldType match {
    case DataTypes.INT => update("setInt")
    case DataTypes.LONG => update("setLong")
    case DataTypes.SHORT => update("setShort")
    case DataTypes.BYTE => update("setByte")
    case DataTypes.FLOAT => update("setFloat")
    case DataTypes.DOUBLE => update("setDouble")
    case DataTypes.BOOLEAN => update("setBoolean")
    case DataTypes.CHAR => update("setChar")
    case _: DateType => update("setInt")
    case DataTypes.TIME => update("setInt")
    case _: TimestampType => update("setLong")
    case _ => update("setNonPrimitiveValue")
  }
}
// ----------------------------------------------------------------------------------------------
/**
  * Compiles generated source code with a `SimpleCompiler` whose parent class loader is `cl`
  * and loads the class called `name` from the compiler's class loader.
  *
  * The code is logged at debug level before the class loader check is performed.
  *
  * @param cl   parent class loader for the compiler, must not be null
  * @param name name of the class to load after compilation
  * @param code source code to compile
  * @return the loaded class, cast (unchecked) to `Class[T]`
  * @throws IllegalArgumentException if `cl` is null
  * @throws InvalidProgramException  if cooking the code fails with any `Throwable`
  */
@throws(classOf[CompileException])
def compile[T](cl: ClassLoader, name: String, code: String): Class[T] = {
  CODE_LOG.debug(s"Compiling: $name \n\n Code:\n$code")
  require(cl != null, "Classloader must not be null.")

  val janino = new SimpleCompiler()
  janino.setParentClassLoader(cl)
  try {
    janino.cook(code)
  } catch {
    case cause: Throwable =>
      // NOTE(review): the formatted code goes to stdout rather than to a logger, and every
      // Throwable (including fatal errors) is wrapped -- confirm both are intended.
      println(CodeFormatter.format(code))
      throw new InvalidProgramException(
        "Table program cannot be compiled. This is a bug. Please file an issue.", cause)
  }

  val loaded = janino.getClassLoader.loadClass(name)
  loaded.asInstanceOf[Class[T]]
}
/**
  * Switches on source-level debugging of generated code for janino (comparable to "gcc -g")
  * by setting the system property `ICookable.SYSTEM_PROPERTY_SOURCE_DEBUGGING_ENABLE`
  * to "true".
  */
def enableCodeGenerateDebug(): Unit = {
  val debuggingProperty = ICookable.SYSTEM_PROPERTY_SOURCE_DEBUGGING_ENABLE
  System.setProperty(debuggingProperty, true.toString)
}
/**
  * Switches off source-level debugging of generated code by setting the system property
  * `ICookable.SYSTEM_PROPERTY_SOURCE_DEBUGGING_ENABLE` to "false".
  */
def disableCodeGenerateDebug(): Unit = {
  val debuggingProperty = ICookable.SYSTEM_PROPERTY_SOURCE_DEBUGGING_ENABLE
  System.setProperty(debuggingProperty, false.toString)
}
/**
  * Stores `path` in the system property `ICookable.SYSTEM_PROPERTY_SOURCE_DEBUGGING_DIR`.
  *
  * @param path directory path; must not be null, empty or whitespace-only
  * @throws RuntimeException if `path` is null or whitespace-only
  */
def setCodeGenerateTmpDir(path: String): Unit = {
  // reject unusable paths up front
  if (StringUtils.isNullOrWhitespaceOnly(path)) {
    throw new RuntimeException("code generate tmp dir can't be empty")
  }
  System.setProperty(ICookable.SYSTEM_PROPERTY_SOURCE_DEBUGGING_DIR, path)
}
// ----------------------------------------------------------------------------------------------
/**
  * Generates an info-level logging statement of the form `logTerm.info("format", argTerm);`.
  *
  * @param logTerm term of the logger in the generated code
  * @param format  message format, emitted inside double quotes without any escaping
  * @param argTerm term of the single format argument
  */
def genLogInfo(logTerm: String, format: String, argTerm: String): String = {
  val quotedFormat = "\"" + format + "\""
  s"$logTerm.info($quotedFormat, $argTerm);"
}
/**
  * Packs a sequence of code snippets into function bodies of bounded length and generates
  * the matching function signatures and call statements.
  *
  * Snippets are joined in order, separated by a newline. A new body is started whenever
  * adding the next snippet would push the combined snippet length above `limitLength`
  * (the separating newlines are not counted). The last pending body is always emitted, even
  * when it is empty, so the result contains at least one function.
  *
  * @param codeBuffer sequence of code snippets to distribute over functions
  * @param limitLength maximum combined snippet length per function body
  * @param subFunctionName name prefix of the generated functions; the body index is appended
  *                        as `_index`
  * @param subFunctionModifier modifier placed in front of each generated function name
  * @param fieldStatementLength length added to the total snippet length when deciding whether
  *                             a single body still counts as split
  * @param defineParams parameter list used in the function definitions
  * @param callingParams argument list used in the function calls
  * @return the definitions, bodies and calls of the generated functions plus a flag telling
  *         whether the code is considered split
  */
def generateSplitFunctionCalls(
    codeBuffer: Seq[String],
    limitLength: Int,
    subFunctionName: String,
    subFunctionModifier: String,
    fieldStatementLength: Int,
    defineParams: String = "",
    callingParams: String = ""): GeneratedSplittableExpression = {
  val bodies = new ListBuffer[String]()

  // fold over the snippets, carrying the body that is currently being filled
  val lastBody = codeBuffer.foldLeft("") { (pending, snippet) =>
    val fits = pending.length + snippet.length <= limitLength
    if (fits && pending.isEmpty) {
      snippet
    } else if (fits) {
      pending + "\n" + snippet
    } else {
      // the pending body is full: flush it (unless empty) and start over with this snippet
      if (pending.nonEmpty) {
        bodies += pending
      }
      snippet
    }
  }
  bodies += lastBody

  val functionIds = bodies.indices
  val defines = functionIds.map { id =>
    s"$subFunctionModifier ${subFunctionName}_$id($defineParams)"
  }
  val callings = functionIds.map { id =>
    s"${subFunctionName}_$id($callingParams);"
  }

  val totalLength = codeBuffer.map(_.length).sum + fieldStatementLength
  val isSplit = defines.length > 1 || (defines.length == 1 && totalLength > limitLength)

  GeneratedSplittableExpression(defines, bodies, callings, isSplit)
}
/**
  * Returns the parameter list used when defining split functions for the given function
  * class.
  *
  * Supported classes are `FlatMapFunction`, `MapFunction`, `FlatJoinFunction` and
  * `ProcessFunction`; any other class yields an empty string.
  */
def getDefineParamsByFunctionClass(clazz: Class[_]): String = {
  // declared as def so the collector term is only read by the branches that need it
  def collectorParam: String =
    s"org.apache.flink.util.Collector ${CodeGeneratorContext.DEFAULT_COLLECTOR_TERM}"

  clazz match {
    case c if c == classOf[FlatMapFunction[_, _]] =>
      s"Object _in1, $collectorParam"
    case c if c == classOf[MapFunction[_, _]] =>
      "Object _in1"
    case c if c == classOf[FlatJoinFunction[_, _, _]] =>
      s"Object _in1, Object _in2, $collectorParam"
    case c if c == classOf[ProcessFunction[_, _]] =>
      "Object _in1, org.apache.flink.streaming.api.functions.ProcessFunction.Context " +
        s"${CodeGeneratorContext.DEFAULT_CONTEXT_TERM}, $collectorParam"
    case _ =>
      ""
  }
}
/**
  * Returns the argument list used when calling split functions for the given function class.
  *
  * The arguments are the default input, context and collector terms of
  * `CodeGeneratorContext`, joined with ", ". Supported classes are `FlatMapFunction`,
  * `MapFunction`, `FlatJoinFunction` and `ProcessFunction`; any other class yields an empty
  * string.
  */
def getCallingParamsByFunctionClass(clazz: Class[_]): String = {
  // declared as defs so each term is only read by the branches that need it
  def input1 = CodeGeneratorContext.DEFAULT_INPUT1_TERM
  def input2 = CodeGeneratorContext.DEFAULT_INPUT2_TERM
  def context = CodeGeneratorContext.DEFAULT_CONTEXT_TERM
  def collector = CodeGeneratorContext.DEFAULT_COLLECTOR_TERM

  val arguments = clazz match {
    case c if c == classOf[FlatMapFunction[_, _]] => Seq(input1, collector)
    case c if c == classOf[MapFunction[_, _]] => Seq(input1)
    case c if c == classOf[FlatJoinFunction[_, _, _]] => Seq(input1, input2, collector)
    case c if c == classOf[ProcessFunction[_, _]] => Seq(input1, context, collector)
    case _ => Seq.empty
  }
  arguments.mkString(", ")
}
// ----------------------------------------------------------------------------------------------
/**
  * Casts the result term of a numeric expression to another numeric type with larger range.
  * This function must be in sync with [[NumericOrDefaultReturnTypeInference]].
  *
  * @param expr       expression whose result term is cast
  * @param targetType type to cast to
  * @return the unchanged result term if both types are equal, the cast expression for a
  *         supported combination, or `null` if the combination is not supported
  */
def getNumericCastedResultTerm(expr: GeneratedExpression, targetType: InternalType): String = {
  def term: String = expr.resultTerm
  def decimalClass: String = classOf[Decimal].getCanonicalName

  // plain Java primitive cast, e.g. "(long) term"
  def primitiveCast(javaType: String): String = s"($javaType) $term"

  // conversion of a primitive numeric value into a decimal of the target precision/scale
  def toDecimal(dt: DecimalType): String =
    s"$decimalClass.castFrom($term, ${dt.precision}, ${dt.scale})"

  (expr.resultType, targetType) match {
    case _ if expr.resultType == targetType => term

    // byte -> other numeric types
    case (_: ByteType, _: ShortType) => primitiveCast("short")
    case (_: ByteType, _: IntType) => primitiveCast("int")
    case (_: ByteType, _: LongType) => primitiveCast("long")
    case (_: ByteType, dt: DecimalType) => toDecimal(dt)
    case (_: ByteType, _: FloatType) => primitiveCast("float")
    case (_: ByteType, _: DoubleType) => primitiveCast("double")

    // short -> other numeric types
    case (_: ShortType, _: IntType) => primitiveCast("int")
    case (_: ShortType, _: LongType) => primitiveCast("long")
    case (_: ShortType, dt: DecimalType) => toDecimal(dt)
    case (_: ShortType, _: FloatType) => primitiveCast("float")
    case (_: ShortType, _: DoubleType) => primitiveCast("double")

    // int -> other numeric types
    case (_: IntType, _: LongType) => primitiveCast("long")
    case (_: IntType, dt: DecimalType) => toDecimal(dt)
    case (_: IntType, _: FloatType) => primitiveCast("float")
    case (_: IntType, _: DoubleType) => primitiveCast("double")

    // long -> other numeric types
    case (_: LongType, dt: DecimalType) => toDecimal(dt)
    case (_: LongType, _: FloatType) => primitiveCast("float")
    case (_: LongType, _: DoubleType) => primitiveCast("double")

    // decimal -> other numeric types
    case (_: DecimalType, dt: DecimalType) =>
      s"$decimalClass.castToDecimal($term, ${dt.precision}, ${dt.scale})"
    case (_: DecimalType, _: FloatType) =>
      s"$decimalClass.castToFloat($term)"
    case (_: DecimalType, _: DoubleType) =>
      s"$decimalClass.castToDouble($term)"

    // float -> other numeric types
    case (_: FloatType, _: DoubleType) => primitiveCast("double")

    // unsupported combination
    case _ => null
  }
}
}