| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.flink.table.planner.codegen |
| |
| import org.apache.flink.api.common.functions.{FlatMapFunction, Function} |
| import org.apache.flink.api.dag.Transformation |
| import org.apache.flink.table.api.{TableConfig, TableException, ValidationException} |
| import org.apache.flink.table.data.{BoxedWrapperRowData, RowData} |
| import org.apache.flink.table.functions.FunctionKind |
| import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction |
| import org.apache.flink.table.runtime.generated.GeneratedFunction |
| import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory |
| import org.apache.flink.table.runtime.typeutils.InternalTypeInfo |
| import org.apache.flink.table.types.logical.RowType |
| |
| import org.apache.calcite.rex._ |
| |
| import scala.collection.JavaConversions._ |
| |
/**
 * Code generation entry points for calc nodes, i.e. nodes that apply an optional filter
 * condition and/or a projection to every input row.
 */
object CalcCodeGenerator {

  /**
   * Generates a one-input stream operator factory that evaluates the given calc program.
   *
   * The input row type is taken from the output type information of `inputTransform`, which
   * must be an [[InternalTypeInfo]]. Output rows are produced as [[BoxedWrapperRowData]].
   *
   * @param ctx            code generator context shared by all generated fragments
   * @param inputTransform upstream transformation whose output type describes the input rows
   * @param outputType     row type of the produced rows
   * @param calcProgram    program whose expanded project list is used as the projection
   * @param condition      optional filter condition
   * @param retainHeader   whether the row kind of the input row is copied onto the result row
   * @param opName         name handed to the operator code generator
   * @return a factory wrapping the generated operator
   */
  private[flink] def generateCalcOperator(
      ctx: CodeGeneratorContext,
      inputTransform: Transformation[RowData],
      outputType: RowType,
      calcProgram: RexProgram,
      condition: Option[RexNode],
      retainHeader: Boolean = false,
      opName: String): CodeGenOperatorFactory[RowData] = {
    val inputRowType =
      inputTransform.getOutputType.asInstanceOf[InternalTypeInfo[RowData]].toRowType

    // The process code emits the input unboxing itself (eager), which is why the operator
    // generator below is asked for lazy input unboxing.
    val processBody = generateProcessCode(
      ctx = ctx,
      inputType = inputRowType,
      outRowType = outputType,
      outRowClass = classOf[BoxedWrapperRowData],
      calcProgram = calcProgram,
      condition = condition,
      eagerInputUnboxingCode = true,
      retainHeader = retainHeader,
      allowSplit = true)

    val generatedOperator =
      OperatorCodeGenerator.generateOneInputStreamOperator[RowData, RowData](
        ctx,
        opName,
        processBody,
        inputRowType,
        inputTerm = CodeGenUtils.DEFAULT_INPUT1_TERM,
        lazyInputUnboxingCode = true)

    new CodeGenOperatorFactory(generatedOperator)
  }

  /**
   * Generates a [[FlatMapFunction]] that evaluates the given calc program and hands its
   * results straight to the function's collector.
   *
   * A fresh [[CodeGeneratorContext]] is created from `config`. The type parameter `T` is not
   * referenced by the body; it is kept as part of the existing signature.
   *
   * @param inputType      row type of the incoming rows
   * @param name           name handed to the function code generator
   * @param returnType     row type of the produced rows
   * @param outRowClass    concrete row class used for the result row
   * @param calcProjection program whose expanded project list is used as the projection
   * @param calcCondition  optional filter condition
   * @param config         table configuration used to create the code generator context
   * @return the generated flat-map function
   */
  private[flink] def generateFunction[T <: Function](
      inputType: RowType,
      name: String,
      returnType: RowType,
      outRowClass: Class[_ <: RowData],
      calcProjection: RexProgram,
      calcCondition: Option[RexNode],
      config: TableConfig): GeneratedFunction[FlatMapFunction[RowData, RowData]] = {
    val generatorCtx = CodeGeneratorContext(config)
    val collector = CodeGenUtils.DEFAULT_COLLECTOR_TERM

    val flatMapBody = generateProcessCode(
      ctx = generatorCtx,
      inputType = inputType,
      outRowType = returnType,
      outRowClass = outRowClass,
      calcProgram = calcProjection,
      condition = calcCondition,
      collectorTerm = collector,
      eagerInputUnboxingCode = false,
      outputDirectly = true)

    FunctionCodeGenerator.generateFunction(
      generatorCtx,
      name,
      classOf[FlatMapFunction[RowData, RowData]],
      flatMapBody,
      returnType,
      inputType,
      input1Term = CodeGenUtils.DEFAULT_INPUT1_TERM,
      collectorTerm = collector)
  }

  /**
   * Generates the per-row processing code of a calc: an optional filter followed by an
   * optional projection.
   *
   * Three shapes are produced:
   *  - projection only (no condition),
   *  - filter only (the projection forwards every input field unchanged, in order), in which
   *    case the input row itself is emitted,
   *  - filter and projection, where the projection runs inside the `if` of the filter.
   *
   * @param ctx                    code generator context; expression generation registers
   *                               reusable input unboxing code in it
   * @param inputType              row type of the incoming rows
   * @param outRowType             row type of the produced rows
   * @param outRowClass            concrete row class used for the result row
   * @param calcProgram            program whose expanded project list is used as the projection
   * @param condition              optional filter condition
   * @param inputTerm              name of the input row variable in the generated code
   * @param collectorTerm          name of the collector variable, used when `outputDirectly`
   * @param eagerInputUnboxingCode whether the input unboxing code is included in the result
   * @param retainHeader           whether the row kind of the input row is copied onto the
   *                               result row
   * @param outputDirectly         emit via `collectorTerm.collect(...)` instead of the
   *                               operator's collect code
   * @param allowSplit             forwarded to the result expression generation
   * @return the generated code
   * @throws TableException      if there is neither a condition nor a non-trivial projection
   * @throws ValidationException if the projection or the condition calls a non-scalar
   *                             bridging function
   */
  private[flink] def generateProcessCode(
      ctx: CodeGeneratorContext,
      inputType: RowType,
      outRowType: RowType,
      outRowClass: Class[_ <: RowData],
      calcProgram: RexProgram,
      condition: Option[RexNode],
      inputTerm: String = CodeGenUtils.DEFAULT_INPUT1_TERM,
      collectorTerm: String = CodeGenUtils.DEFAULT_OPERATOR_COLLECTOR_TERM,
      eagerInputUnboxingCode: Boolean,
      retainHeader: Boolean = false,
      outputDirectly: Boolean = false,
      allowSplit: Boolean = false): String = {

    val projectedNodes = calcProgram.getProjectList.map(calcProgram.expandLocalRef)

    // According to the SQL standard every table function should also be usable as a scalar
    // function, but that is not allowed for now: reject anything that is not scalar.
    for (node <- projectedNodes) node.accept(ScalarFunctionsValidator)
    for (node <- condition) node.accept(ScalarFunctionsValidator)

    val exprGenerator = new ExprCodeGenerator(ctx, false)
      .bindInput(inputType, inputTerm = inputTerm)

    // True when the projection is exactly "field i of the input at position i" for all fields,
    // i.e. the calc can only be a filter.
    val forwardsInputUnchanged =
      projectedNodes.size == inputType.getFieldCount &&
        projectedNodes.zipWithIndex.forall {
          case (ref: RexInputRef, position) => ref.getIndex == position
          case _ => false
        }

    // Code that emits the row held in `rowTerm`.
    def emit(rowTerm: String): String =
      if (outputDirectly) {
        s"$collectorTerm.collect($rowTerm);"
      } else {
        OperatorCodeGenerator.generateCollect(rowTerm)
      }

    // Code that builds the projected row and emits it. Generating the expressions registers
    // the input fields they access in `ctx`.
    def projectAndEmit(): String = {
      val fieldExprs = projectedNodes.map(exprGenerator.generateExpression)
      val rowExpr = exprGenerator.generateResultExpression(
        fieldExprs,
        outRowType,
        outRowClass,
        allowSplit = allowSplit)

      val rowKindCopy =
        if (retainHeader) s"${rowExpr.resultTerm}.setRowKind($inputTerm.getRowKind());" else ""

      s"""
         |$rowKindCopy
         |${rowExpr.code}
         |${emit(rowExpr.resultTerm)}
         |""".stripMargin
    }

    condition match {
      case None if forwardsInputUnchanged =>
        throw new TableException("This calc has no useful projection and no filter. " +
          "It should be removed by CalcRemoveRule.")

      case None =>
        // Projection only. The projection must be generated before the unboxing code is read
        // from the context.
        val projectionCode = projectAndEmit()
        val inputUnboxing = if (eagerInputUnboxingCode) ctx.reuseInputUnboxingCode() else ""
        s"""
           |$inputUnboxing
           |$projectionCode
           |""".stripMargin

      case Some(predicate) =>
        val filterExpr = exprGenerator.generateExpression(predicate)

        if (forwardsInputUnchanged) {
          // Filter only: forward the input row when the condition holds.
          val inputUnboxing = if (eagerInputUnboxingCode) ctx.reuseInputUnboxingCode() else ""
          s"""
             |$inputUnboxing
             |${filterExpr.code}
             |if (${filterExpr.resultTerm}) {
             | ${emit(inputTerm)}
             |}
             |""".stripMargin
        } else {
          // Filter and projection. Snapshot what the filter needs unboxed *before* generating
          // the projection, so that inputs used only by the projection can be unboxed inside
          // the `if` scope.
          val filterUnboxing = ctx.reuseInputUnboxingCode()
          val unboxedForFilter = ctx.reusableInputUnboxingExprs.keySet.toSet

          val projectionCode = projectAndEmit()

          val projectionOnlyUnboxing = ctx.reusableInputUnboxingExprs
            .filter { case (key, _) => !unboxedForFilter.contains(key) }
            .values
            .map(_.code)
            .mkString("\n")

          val outerUnboxing = if (eagerInputUnboxingCode) filterUnboxing else ""
          val innerUnboxing = if (eagerInputUnboxingCode) projectionOnlyUnboxing else ""
          s"""
             |$outerUnboxing
             |${filterExpr.code}
             |if (${filterExpr.resultTerm}) {
             | $innerUnboxing
             | $projectionCode
             |}
             |""".stripMargin
        }
    }
  }

  /**
   * Visitor that walks an expression tree and rejects calls to bridging functions whose
   * definition is not of kind [[FunctionKind.SCALAR]].
   */
  private object ScalarFunctionsValidator extends RexVisitorImpl[Unit](true) {
    override def visitCall(call: RexCall): Unit = {
      // Let the base visitor handle the call first, then check the call itself.
      super.visitCall(call)
      call.getOperator match {
        case function: BridgingSqlFunction =>
          if (function.getDefinition.getKind != FunctionKind.SCALAR) {
            throw new ValidationException(
              s"Invalid use of function '$function'. " +
                s"Currently, only scalar functions can be used in a projection or filter operation.")
          }
        case _ => // any other operator is fine
      }
    }
  }
}