blob: 93d188d973fbffc2d7905e83df99884dedd3cb88 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.codegen
import org.apache.calcite.rex._
import org.apache.calcite.sql.`type`.{ReturnTypes, SqlTypeName}
import org.apache.flink.streaming.api.functions.ProcessFunction
import org.apache.flink.table.api.TableException
import org.apache.flink.table.api.types._
import org.apache.flink.table.calcite.{FlinkTypeFactory, RexAggBufferVariable, RexAggLocalVariable, RexDistinctKeyVariable}
import org.apache.flink.table.codegen.CodeGenUtils._
import org.apache.flink.table.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE}
import org.apache.flink.table.dataformat._
import org.apache.flink.table.functions.sql.{ProctimeSqlFunction, StreamRecordTimestampSqlFunction}
import org.apache.flink.table.typeutils.TypeUtils
import org.apache.flink.util.Preconditions.checkArgument
import scala.collection.JavaConversions._
/**
* This code generator is mainly responsible for generating codes for a given calcite [[RexNode]].
* It can also generate type conversion codes for the result converter.
*/
class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean, nullCheck: Boolean)
extends RexVisitor[GeneratedExpression] {
// check if nullCheck is enabled when inputs can be null
if (nullableInput && !nullCheck) {
throw new CodeGenException("Null check must be enabled if entire rows can be null.")
}
/**
* term of the [[ProcessFunction]]'s context, can be changed when needed
*/
var contextTerm = "ctx"
/**
* information of the first input
*/
var input1Type: InternalType = _
var input1Term: String = _
var input1FieldMapping: Option[Array[Int]] = None
/**
* information of the optional second input
*/
var input2Type: Option[InternalType] = None
var input2Term: Option[String] = None
var input2FieldMapping: Option[Array[Int]] = None
/**
* Bind the input information, should be called before generating expression.
*/
def bindInput(
inputType: InternalType,
inputTerm: String = CodeGeneratorContext.DEFAULT_INPUT1_TERM,
inputFieldMapping: Option[Array[Int]] = None): ExprCodeGenerator = {
input1Type = inputType
input1Term = inputTerm
input1FieldMapping = inputFieldMapping
this
}
/**
* In some cases, the expression will have two inputs (e.g. join condition and udtf). We should
* bind second input information before use.
*/
def bindSecondInput(
inputType: InternalType,
inputTerm: String = CodeGeneratorContext.DEFAULT_INPUT2_TERM,
inputFieldMapping: Option[Array[Int]] = None): ExprCodeGenerator = {
input2Type = Some(inputType)
input2Term = Some(inputTerm)
input2FieldMapping = inputFieldMapping
this
}
protected lazy val input1Mapping: Array[Int] = input1FieldMapping match {
case Some(mapping) => mapping
case _ => (0 until TypeUtils.getArity(input1Type)).toArray
}
protected lazy val input2Mapping: Array[Int] = input2FieldMapping match {
case Some(mapping) => mapping
case _ => input2Type match {
case Some(input) => (0 until TypeUtils.getArity(input)).toArray
case _ => Array[Int]()
}
}
/**
* Generates an expression from a RexNode. If objects or variables can be reused, they will be
* added to reusable code sections internally.
*
* @param rex Calcite row expression
* @return instance of GeneratedExpression
*/
def generateExpression(rex: RexNode): GeneratedExpression = {
rex.accept(this)
}
/**
* Generates an expression that converts the first input (and second input) into the given type.
* If two inputs are converted, the second input is appended. If objects or variables can
* be reused, they will be added to reusable code sections internally. The evaluation result
* will be stored in the variable outRecordTerm.
*
* @param returnType conversion target type. Inputs and output must have the same arity.
* @param outRecordTerm the result term
* @param outRecordWriterTerm the result writer term
* @param reusedOutRow If objects or variables can be reused, they will be added to reusable
* code sections internally.
* @return instance of GeneratedExpression
*/
def generateConverterResultExpression(
returnType: RowType,
returnTypeClazz: Class[_ <: BaseRow],
outRecordTerm: String = CodeGeneratorContext.DEFAULT_OUT_RECORD_TERM,
outRecordWriterTerm: String = CodeGeneratorContext.DEFAULT_OUT_RECORD_WRITER_TERM,
reusedOutRow: Boolean = true,
fieldCopy: Boolean = false,
rowtimeExpression: Option[RexNode] = None)
: GeneratedExpression = {
val input1AccessExprs = input1Mapping.map {
case DataTypes.ROWTIME_STREAM_MARKER |
DataTypes.ROWTIME_BATCH_MARKER if rowtimeExpression.isDefined =>
// generate rowtime attribute from expression
generateExpression(rowtimeExpression.get)
case DataTypes.ROWTIME_STREAM_MARKER |
DataTypes.ROWTIME_BATCH_MARKER =>
throw new TableException("Rowtime extraction expression missing. Please report a bug.")
case DataTypes.PROCTIME_STREAM_MARKER =>
// attribute is proctime indicator.
// we use a null literal and generate a timestamp when we need it.
generateNullLiteral(DataTypes.PROCTIME_INDICATOR, nullCheck)
case DataTypes.PROCTIME_BATCH_MARKER =>
// attribute is proctime field in a batch query.
// it is initialized with the current time.
generateCurrentTimestamp(ctx)
case idx =>
// get type of result field
generateInputAccess(
ctx,
input1Type,
input1Term,
idx,
nullableInput,
nullCheck,
fieldCopy)
}
val input2AccessExprs = input2Type match {
case Some(ti) =>
input2Mapping.map(idx => generateInputAccess(
ctx,
ti,
input2Term.get,
idx,
nullableInput,
nullCheck)
).toSeq
case None => Seq() // add nothing
}
generateResultExpression(
input1AccessExprs ++ input2AccessExprs,
returnType,
returnTypeClazz,
outRow = outRecordTerm,
outRowWriter = Some(outRecordWriterTerm),
reusedOutRow = reusedOutRow)
}
/**
* Generates an expression from a sequence of other expressions. The evaluation result
* may be stored in the variable outRecordTerm.
*
* @param fieldExprs field expressions to be converted
* @param returnType conversion target type. Type must have the same arity than fieldExprs.
* @param outRow the result term
* @param outRowWriter the result writer term for BinaryRow.
* @param reusedOutRow If objects or variables can be reused, they will be added to reusable
* code sections internally.
* @param outRowAlreadyExists Don't need addReusableRecord if out row already exists.
* @return instance of GeneratedExpression
*/
def generateResultExpression(
fieldExprs: Seq[GeneratedExpression],
returnType: RowType,
returnTypeClazz: Class[_ <: BaseRow],
outRow: String = CodeGeneratorContext.DEFAULT_OUT_RECORD_TERM,
outRowWriter: Option[String] = Some(CodeGeneratorContext.DEFAULT_OUT_RECORD_WRITER_TERM),
reusedOutRow: Boolean = true,
outRowAlreadyExists: Boolean = false): GeneratedExpression = {
val fieldExprIdxToOutputRowPosMap = fieldExprs.indices.map(i => i -> i).toMap
generateResultExpression(fieldExprs, fieldExprIdxToOutputRowPosMap, returnType,
returnTypeClazz, outRow, outRowWriter, reusedOutRow, outRowAlreadyExists)
}
/**
* Generates an expression from a sequence of other expressions. The evaluation result
* may be stored in the variable outRecordTerm.
*
* @param fieldExprs field expressions to be converted
* @param fieldExprIdxToOutputRowPosMap Mapping index of fieldExpr in `fieldExprs`
* to position of output row.
* @param returnType conversion target type. Type must have the same arity than fieldExprs.
* @param outRow the result term
* @param outRowWriter the result writer term for BinaryRow.
* @param reusedOutRow If objects or variables can be reused, they will be added to reusable
* code sections internally.
* @param outRowAlreadyExists Don't need addReusableRecord if out row already exists.
* @return instance of GeneratedExpression
*/
def generateResultExpression(
fieldExprs: Seq[GeneratedExpression],
fieldExprIdxToOutputRowPosMap: Map[Int, Int],
returnType: RowType,
returnTypeClazz: Class[_ <: BaseRow],
outRow: String,
outRowWriter: Option[String],
reusedOutRow: Boolean,
outRowAlreadyExists: Boolean)
: GeneratedExpression = {
// initial type check
if (returnType.getArity != fieldExprs.length) {
throw new CodeGenException(
s"Arity [${returnType.getArity}] of result type [$returnType] does not match " +
s"number [${fieldExprs.length}] of expressions [$fieldExprs].")
}
if (fieldExprIdxToOutputRowPosMap.size != fieldExprs.length) {
throw new CodeGenException(
s"Size [${returnType.getArity}] of fieldExprIdxToOutputRowPosMap does not match " +
s"number [${fieldExprs.length}] of expressions [$fieldExprs].")
}
// type check
fieldExprs.zipWithIndex foreach {
// timestamp type(Include TimeIndicator) and generic type can compatible with each other.
case (fieldExpr, i)
if fieldExpr.resultType.isInstanceOf[GenericType[_]] ||
fieldExpr.resultType.isInstanceOf[TimestampType] =>
if (returnType.getInternalTypeAt(i).getClass != fieldExpr.resultType.getClass
&& !returnType.getInternalTypeAt(i).isInstanceOf[GenericType[_]]) {
throw new CodeGenException(
s"Incompatible types of expression and result type, Expression[$fieldExpr] type is " +
s"[${fieldExpr.resultType}], result type is [${returnType.getInternalTypeAt(i)}]")
}
case (fieldExpr, i) if fieldExpr.resultType != returnType.getInternalTypeAt(i) =>
throw new CodeGenException(
s"Incompatible types of expression and result type. Expression[$fieldExpr] type is " +
s"[${fieldExpr.resultType}], result type is [${returnType.getInternalTypeAt(i)}]")
case _ => // ok
}
def getOutputRowPos(fieldExprIdx: Int): Int =
fieldExprIdxToOutputRowPosMap.getOrElse(fieldExprIdx,
throw new CodeGenException(s"Illegal field expr index: $fieldExprIdx"))
def objectArrayRowWrite(
genUpdate: (Int, GeneratedExpression) => String): GeneratedExpression = {
val initReturnRecord = if (outRowAlreadyExists) {
""
} else {
ctx.addOutputRecord(returnType, returnTypeClazz, outRow, reused = reusedOutRow)
}
val resultBuffer: Seq[String] = fieldExprs.zipWithIndex map {
case (fieldExpr, i) =>
val idx = getOutputRowPos(i)
if (nullCheck) {
s"""
|${fieldExpr.code}
|if (${fieldExpr.nullTerm}) {
| $outRow.setNullAt($idx);
|} else {
| ${genUpdate(idx, fieldExpr)};
|}
|""".stripMargin
}
else {
s"""
|${fieldExpr.code}
|${genUpdate(idx, fieldExpr)};
|""".stripMargin
}
}
val statement =
s"""
|$initReturnRecord
|${resultBuffer.mkString("")}
""".stripMargin.trim
GeneratedExpression(outRow, "false", statement, returnType,
codeBuffer = resultBuffer, preceding = s"$initReturnRecord")
}
returnTypeClazz match {
case cls if cls == classOf[BinaryRow] =>
outRowWriter match {
case Some(writer) => // binary row writer.
val initReturnRecord = if (outRowAlreadyExists) {
""
} else {
ctx.addOutputRecord(returnType, returnTypeClazz, outRow, outRowWriter, reusedOutRow)
}
val resetWriter = if (nullCheck) s"$writer.reset();" else s"$writer.resetCursor();"
val completeWriter: String = s"$writer.complete();"
val resultBuffer = fieldExprs.zipWithIndex map {
case (fieldExpr, i) =>
val t = returnType.getInternalTypeAt(i)
val idx = getOutputRowPos(i)
val writeCode = binaryWriterWriteField(ctx, idx, fieldExpr.resultTerm, writer, t)
if (nullCheck) {
s"""
|${fieldExpr.code}
|if (${fieldExpr.nullTerm}) {
| ${binaryWriterWriteNull(idx, writer, t)};
|} else {
| $writeCode;
|}
|""".stripMargin.trim
} else {
s"""
|${fieldExpr.code}
|$writeCode;
|""".stripMargin.trim
}
}
val statement =
s"""
|$initReturnRecord
|$resetWriter
|${resultBuffer.mkString("\n")}
|$completeWriter
|""".stripMargin.trim
GeneratedExpression(outRow, "false", statement, returnType,
codeBuffer = resultBuffer,
preceding = s"$initReturnRecord\n$resetWriter",
flowing = s"$completeWriter")
case None => // update to binary row (setXXX).
checkArgument(outRowAlreadyExists)
val resultBuffer = fieldExprs.zipWithIndex map {
case (fieldExpr, i) =>
val t = returnType.getInternalTypeAt(i)
val idx = getOutputRowPos(i)
val writeCode = binaryRowFieldSetAccess(idx, outRow, t, fieldExpr.resultTerm)
if (nullCheck) {
s"""
|${fieldExpr.code}
|if (${fieldExpr.nullTerm}) {
| ${binaryRowSetNull(idx, outRow, t)};
|} else {
| $writeCode;
|}
|""".stripMargin.trim
} else {
s"""
|${fieldExpr.code}
|$writeCode;
|""".stripMargin.trim
}
}
GeneratedExpression(
outRow, "false", resultBuffer.mkString(""), returnType)
}
case cls if cls == classOf[GenericRow] =>
objectArrayRowWrite((i: Int, expr: GeneratedExpression) =>
s"$outRow.update($i, ${expr.resultTerm})")
case cls if cls == classOf[BoxedWrapperRow] =>
objectArrayRowWrite((i: Int, expr: GeneratedExpression) =>
boxedWrapperRowFieldUpdateAccess(i, expr.resultTerm, outRow, expr.resultType))
}
}
override def visitInputRef(inputRef: RexInputRef): GeneratedExpression = {
// if inputRef index is within size of input1 we work with input1, input2 otherwise
val input = if (inputRef.getIndex < TypeUtils.getArity(input1Type)) {
(input1Type, input1Term)
} else {
(input2Type.getOrElse(throw new CodeGenException("Invalid input access.")),
input2Term.getOrElse(throw new CodeGenException("Invalid input access.")))
}
val index = if (input._2 == input1Term) {
inputRef.getIndex
} else {
inputRef.getIndex - TypeUtils.getArity(input1Type)
}
generateInputAccess(ctx, input._1, input._2, index, nullableInput, nullCheck)
}
override def visitTableInputRef(rexTableInputRef: RexTableInputRef): GeneratedExpression =
visitInputRef(rexTableInputRef)
override def visitFieldAccess(rexFieldAccess: RexFieldAccess): GeneratedExpression = {
val refExpr = rexFieldAccess.getReferenceExpr.accept(this)
val index = rexFieldAccess.getField.getIndex
val fieldAccessExpr = generateFieldAccess(
ctx,
refExpr.resultType,
refExpr.resultTerm,
index,
nullCheck)
val resultTypeTerm = primitiveTypeTermForType(fieldAccessExpr.resultType)
val defaultValue = primitiveDefaultValue(fieldAccessExpr.resultType)
val Seq(resultTerm, nullTerm) = ctx.newReusableFields(
Seq("result", "isNull"),
Seq(resultTypeTerm, "boolean"))
val resultCode = if (nullCheck) {
s"""
|${refExpr.code}
|if (${refExpr.nullTerm}) {
| $resultTerm = $defaultValue;
| $nullTerm = true;
|}
|else {
| ${fieldAccessExpr.code}
| $resultTerm = ${fieldAccessExpr.resultTerm};
| $nullTerm = ${fieldAccessExpr.nullTerm};
|}
|""".stripMargin
} else {
s"""
|${refExpr.code}
|${fieldAccessExpr.code}
|$resultTerm = ${fieldAccessExpr.resultTerm};
|""".stripMargin
}
GeneratedExpression(resultTerm, nullTerm, resultCode, fieldAccessExpr.resultType)
}
override def visitLiteral(literal: RexLiteral): GeneratedExpression = {
val resultType = FlinkTypeFactory.toInternalType(literal.getType)
val value = literal.getValue3
generateLiteral(ctx, literal.getType, resultType, value, nullCheck)
}
override def visitCorrelVariable(correlVariable: RexCorrelVariable): GeneratedExpression = {
GeneratedExpression(input1Term, NEVER_NULL, NO_CODE, input1Type)
}
override def visitLocalRef(localRef: RexLocalRef): GeneratedExpression = localRef match {
case localVar: RexAggBufferVariable =>
val resultTerm = localVar.getName
val nullTerm = resultTerm + "IsNull"
val pType = primitiveTypeTermForType(localVar.internalType)
ctx.addReusableMember(s"$pType $resultTerm;")
ctx.addReusableMember(s"boolean $nullTerm;")
GeneratedExpression(resultTerm, nullTerm, "", localVar.internalType)
case local: RexAggLocalVariable =>
GeneratedExpression(local.fieldTerm, local.nullTerm, NO_CODE, local.internalType)
case value: RexDistinctKeyVariable =>
val inputExpr = ctx.getReusableInputUnboxingExprs(input1Term, 0) match {
case Some(expr) => expr
case None =>
val pType = primitiveTypeTermForType(value.internalType)
val defaultValue = primitiveDefaultValue(value.internalType)
val resultTerm = newName("field")
val nullTerm = newName("isNull")
val code =
s"""
|$pType $resultTerm = $defaultValue;
|boolean $nullTerm = true;
|if ($input1Term != null) {
| $nullTerm = false;
| $resultTerm = ($pType) $input1Term;
|}
""".stripMargin
val expr = GeneratedExpression(resultTerm, nullTerm, code, value.internalType)
ctx.addReusableInputUnboxingExprs(input1Term, 0, expr)
expr
}
// hide the generated code as it will be executed only once
GeneratedExpression(inputExpr.resultTerm, inputExpr.nullTerm, NO_CODE, inputExpr.resultType)
case _ => throw new CodeGenException("Local variables are not supported yet.")
}
override def visitRangeRef(rangeRef: RexRangeRef): GeneratedExpression =
throw new CodeGenException("Range references are not supported yet.")
override def visitDynamicParam(dynamicParam: RexDynamicParam): GeneratedExpression =
throw new CodeGenException("Dynamic parameter references are not supported yet.")
override def visitCall(call: RexCall): GeneratedExpression = {
// special case: time materialization
if (call.getOperator == ProctimeSqlFunction) {
return generateProctimeTimestamp(contextTerm, ctx)
}
if (call.getOperator == StreamRecordTimestampSqlFunction) {
return generateRowtimeAccess(contextTerm, ctx)
}
val resultType = FlinkTypeFactory.toInternalType(call.getType)
// convert operands and help giving untyped NULL literals a type
val operands = call.getOperands.zipWithIndex.map {
// this helps e.g. for AS(null)
// we might need to extend this logic in case some rules do not create typed NULLs
case (operandLiteral: RexLiteral, 0) if
operandLiteral.getType.getSqlTypeName == SqlTypeName.NULL &&
call.getOperator.getReturnTypeInference == ReturnTypes.ARG0 =>
generateNullLiteral(resultType, nullCheck)
case (o@_, idx) => o.accept(this)
}
generateCallExpression(
ctx, call.getOperator, operands, resultType, nullCheck)
}
override def visitOver(over: RexOver): GeneratedExpression =
throw new CodeGenException("Aggregate functions over windows are not supported yet.")
override def visitSubQuery(subQuery: RexSubQuery): GeneratedExpression =
throw new CodeGenException("Subqueries are not supported yet.")
override def visitPatternFieldRef(fieldRef: RexPatternFieldRef): GeneratedExpression =
throw new CodeGenException("Pattern field references are not supported yet.")
}