// blob: 2fa275db7f6b7622bf9bc7ffdd397c748ad4f534
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.planner.codegen
import org.apache.flink.api.common.functions.{FlatMapFunction, Function}
import org.apache.flink.api.dag.Transformation
import org.apache.flink.table.api.{TableConfig, TableException, ValidationException}
import org.apache.flink.table.data.{BoxedWrapperRowData, RowData}
import org.apache.flink.table.functions.FunctionKind
import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction
import org.apache.flink.table.runtime.generated.GeneratedFunction
import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo
import org.apache.flink.table.types.logical.RowType
import org.apache.calcite.rex._
import scala.collection.JavaConversions._
object CalcCodeGenerator {
/**
 * Builds the operator factory for a calc (projection and/or filter) that consumes the rows
 * produced by `inputTransform`.
 *
 * The per-record body comes from [[generateProcessCode]]; it is then wrapped into a one-input
 * stream operator by `OperatorCodeGenerator.generateOneInputStreamOperator`.
 *
 * @param ctx            code generator context shared by the process code and the operator
 * @param inputTransform upstream transformation; its output type is cast to
 *                       `InternalTypeInfo[RowData]` to obtain the input row type
 * @param outputType     row type of the rows emitted by the calc
 * @param calcProgram    program whose project list is turned into the projection
 * @param condition      optional filter condition
 * @param retainHeader   forwarded to [[generateProcessCode]]
 * @param opName         name handed to the operator code generator
 * @return a factory wrapping the generated operator
 */
private[flink] def generateCalcOperator(
    ctx: CodeGeneratorContext,
    inputTransform: Transformation[RowData],
    outputType: RowType,
    calcProgram: RexProgram,
    condition: Option[RexNode],
    retainHeader: Boolean = false,
    opName: String): CodeGenOperatorFactory[RowData] = {
  val inputTypeInfo = inputTransform.getOutputType.asInstanceOf[InternalTypeInfo[RowData]]
  val inputRowType = inputTypeInfo.toRowType

  // The process body emits the input unboxing code itself (eager) ...
  val recordBody = generateProcessCode(
    ctx = ctx,
    inputType = inputRowType,
    outRowType = outputType,
    outRowClass = classOf[BoxedWrapperRowData],
    calcProgram = calcProgram,
    condition = condition,
    eagerInputUnboxingCode = true,
    retainHeader = retainHeader,
    allowSplit = true)

  // ... so the operator generator is asked for lazy input unboxing.
  val generatedOperator =
    OperatorCodeGenerator.generateOneInputStreamOperator[RowData, RowData](
      ctx,
      opName,
      recordBody,
      inputRowType,
      inputTerm = CodeGenUtils.DEFAULT_INPUT1_TERM,
      lazyInputUnboxingCode = true)

  new CodeGenOperatorFactory(generatedOperator)
}
/**
 * Generates a `FlatMapFunction[RowData, RowData]` that applies the given projection and
 * optional condition to each input row.
 *
 * A fresh [[CodeGeneratorContext]] is created from `config`. The body comes from
 * [[generateProcessCode]] with `outputDirectly = true`, so results are handed straight to the
 * collector named by `CodeGenUtils.DEFAULT_COLLECTOR_TERM`.
 *
 * @tparam T             not referenced by the body or the result type; kept for
 *                       source compatibility with existing callers
 * @param inputType      row type of the incoming rows
 * @param name           name handed to the function code generator
 * @param returnType     row type of the produced rows
 * @param outRowClass    concrete [[RowData]] class used for the result row
 * @param calcProjection program whose project list is turned into the projection
 * @param calcCondition  optional filter condition
 * @param config         table configuration used to create the code generator context
 * @return the generated flat-map function
 */
private[flink] def generateFunction[T <: Function](
    inputType: RowType,
    name: String,
    returnType: RowType,
    outRowClass: Class[_ <: RowData],
    calcProjection: RexProgram,
    calcCondition: Option[RexNode],
    config: TableConfig): GeneratedFunction[FlatMapFunction[RowData, RowData]] = {
  val codeGenCtx = CodeGeneratorContext(config)
  val rowTerm = CodeGenUtils.DEFAULT_INPUT1_TERM
  val outTerm = CodeGenUtils.DEFAULT_COLLECTOR_TERM

  val flatMapBody = generateProcessCode(
    ctx = codeGenCtx,
    inputType = inputType,
    outRowType = returnType,
    outRowClass = outRowClass,
    calcProgram = calcProjection,
    condition = calcCondition,
    inputTerm = rowTerm,
    collectorTerm = outTerm,
    eagerInputUnboxingCode = false,
    outputDirectly = true)

  FunctionCodeGenerator.generateFunction(
    codeGenCtx,
    name,
    classOf[FlatMapFunction[RowData, RowData]],
    flatMapBody,
    returnType,
    inputType,
    input1Term = rowTerm,
    collectorTerm = outTerm)
}
/**
 * Generates the code that processes one input row of a calc: an optional filter followed by
 * a projection.
 *
 * Depending on the inputs, one of three shapes is produced:
 *  - no condition: the projection is evaluated and its result row is emitted;
 *  - condition with an identity projection ("only filter"): the input row itself is emitted
 *    when the condition holds;
 *  - condition with a real projection: the projection is evaluated and emitted inside the
 *    `if` guarded by the condition.
 *
 * @param ctx                    code generator context
 * @param inputType              row type of the input
 * @param outRowType             row type of the projection result
 * @param outRowClass            concrete [[RowData]] class used for the result row
 * @param calcProgram            program whose (expanded) project list is the projection
 * @param condition              optional filter condition
 * @param inputTerm              term of the input row in the generated code
 * @param collectorTerm          term of the collector; only used when `outputDirectly` is true
 * @param eagerInputUnboxingCode whether the input unboxing code is emitted as part of the
 *                               returned code
 * @param retainHeader           when true, the projection code copies the row kind of the
 *                               input row to the result row
 * @param outputDirectly         when true, emit via `collectorTerm.collect(...)`; otherwise
 *                               via `OperatorCodeGenerator.generateCollect`
 * @param allowSplit             forwarded to `generateResultExpression`
 * @return the generated code
 * @throws TableException      if there is no condition and the projection is an identity
 * @throws ValidationException if the projection or condition calls a non-scalar
 *                             [[BridgingSqlFunction]]
 */
private[flink] def generateProcessCode(
    ctx: CodeGeneratorContext,
    inputType: RowType,
    outRowType: RowType,
    outRowClass: Class[_ <: RowData],
    calcProgram: RexProgram,
    condition: Option[RexNode],
    inputTerm: String = CodeGenUtils.DEFAULT_INPUT1_TERM,
    collectorTerm: String = CodeGenUtils.DEFAULT_OPERATOR_COLLECTOR_TERM,
    eagerInputUnboxingCode: Boolean,
    retainHeader: Boolean = false,
    outputDirectly: Boolean = false,
    allowSplit: Boolean = false): String = {
  // NOTE(review): calling `map` on the project list relies on the implicit conversions of the
  // file-level `scala.collection.JavaConversions._` import.
  val projection = calcProgram.getProjectList.map(calcProgram.expandLocalRef)

  // according to the SQL standard, every table function should also be a scalar function
  // but we don't allow that for now
  projection.foreach(_.accept(ScalarFunctionsValidator))
  condition.foreach(_.accept(ScalarFunctionsValidator))

  val exprGenerator = new ExprCodeGenerator(ctx, false)
    .bindInput(inputType, inputTerm = inputTerm)

  // The projection is an identity when it forwards every input field, in order, unchanged.
  val onlyFilter = projection.lengthCompare(inputType.getFieldCount) == 0 &&
    projection.zipWithIndex.forall {
      case (inputRef: RexInputRef, index) => inputRef.getIndex == index
      case _ => false
    }

  // Code that emits `resultTerm` downstream.
  def produceOutputCode(resultTerm: String): String =
    if (outputDirectly) {
      s"$collectorTerm.collect($resultTerm);"
    } else {
      s"${OperatorCodeGenerator.generateCollect(resultTerm)}"
    }

  // Code that evaluates the projection and emits the result row. Declared with `()` because
  // it drives the expression generator, which works against `ctx`.
  def produceProjectionCode(): String = {
    val projectionExprs = projection.map(exprGenerator.generateExpression)
    val projectionExpression = exprGenerator.generateResultExpression(
      projectionExprs,
      outRowType,
      outRowClass,
      allowSplit = allowSplit)

    val projectionExpressionCode = projectionExpression.code

    val header = if (retainHeader) {
      s"${projectionExpression.resultTerm}.setRowKind($inputTerm.getRowKind());"
    } else {
      ""
    }

    s"""
       |$header
       |$projectionExpressionCode
       |${produceOutputCode(projectionExpression.resultTerm)}
       |""".stripMargin
  }

  condition match {
    case None if onlyFilter =>
      throw new TableException("This calc has no useful projection and no filter. " +
        "It should be removed by CalcRemoveRule.")

    case None => // only projection
      // Generate the projection first; the unboxing code is read from ctx afterwards.
      val projectionCode = produceProjectionCode()
      s"""
         |${if (eagerInputUnboxingCode) ctx.reuseInputUnboxingCode() else ""}
         |$projectionCode
         |""".stripMargin

    case Some(filterRex) =>
      val filterCondition = exprGenerator.generateExpression(filterRex)

      if (onlyFilter) { // only filter
        s"""
           |${if (eagerInputUnboxingCode) ctx.reuseInputUnboxingCode() else ""}
           |${filterCondition.code}
           |if (${filterCondition.resultTerm}) {
           | ${produceOutputCode(inputTerm)}
           |}
           |""".stripMargin
      } else { // both filter and projection
        // Snapshot the unboxing code and the keys of the unboxing expressions known after
        // generating the filter, so that the entries added while generating the projection
        // can be told apart below.
        val filterInputCode = ctx.reuseInputUnboxingCode()
        val filterInputSet = Set(ctx.reusableInputUnboxingExprs.keySet.toSeq: _*)

        // if any filter conditions, projection code will enter a new scope
        val projectionCode = produceProjectionCode()

        // Unboxing code for entries that were not present in the filter snapshot; it is
        // placed inside the `if` body together with the projection.
        val projectionInputCode = ctx.reusableInputUnboxingExprs
          .filter(entry => !filterInputSet.contains(entry._1))
          .values.map(_.code).mkString("\n")

        s"""
           |${if (eagerInputUnboxingCode) filterInputCode else ""}
           |${filterCondition.code}
           |if (${filterCondition.resultTerm}) {
           | ${if (eagerInputUnboxingCode) projectionInputCode else ""}
           | $projectionCode
           |}
           |""".stripMargin
      }
  }
}
/**
 * Rex visitor that rejects any call to a [[BridgingSqlFunction]] whose definition is not of
 * kind [[FunctionKind.SCALAR]]. Calls to all other operators are accepted.
 */
private object ScalarFunctionsValidator extends RexVisitorImpl[Unit](true) {

  /**
   * Delegates to the superclass first, then checks the operator of `call` itself.
   *
   * @throws ValidationException if the operator is a non-scalar [[BridgingSqlFunction]]
   */
  override def visitCall(call: RexCall): Unit = {
    super.visitCall(call)
    call.getOperator match {
      case bridging: BridgingSqlFunction =>
        val isScalar = bridging.getDefinition.getKind == FunctionKind.SCALAR
        if (!isScalar) {
          throw new ValidationException(
            s"Invalid use of function '$bridging'. " +
              "Currently, only scalar functions can be used in a projection or filter operation.")
        }
      case _ => // any other operator is fine
    }
  }
}
}