blob: a4a93f7be0822ed2203ecbde95f9a1e313733e31 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.plan.nodes.dataset
import org.apache.calcite.plan.{RelOptCluster, RelOptCost, RelOptPlanner, RelTraitSet}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.metadata.RelMetadataQuery
import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
import org.apache.calcite.rex.{RexCall, RexNode}
import org.apache.calcite.sql.SemiJoinType
import org.apache.flink.api.common.functions.FlatMapFunction
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.DataSet
import org.apache.flink.table.api.{BatchQueryConfig, BatchTableEnvironment}
import org.apache.flink.table.functions.utils.TableSqlFunction
import org.apache.flink.table.plan.nodes.CommonCorrelate
import org.apache.flink.table.plan.nodes.logical.FlinkLogicalTableFunctionScan
import org.apache.flink.table.plan.schema.RowSchema
import org.apache.flink.table.runtime.CorrelateFlatMapRunner
import org.apache.flink.types.Row
/**
* Flink RelNode which matches along with join a user defined table function.
*/
class DataSetCorrelate(
cluster: RelOptCluster,
traitSet: RelTraitSet,
inputNode: RelNode,
scan: FlinkLogicalTableFunctionScan,
condition: Option[RexNode],
relRowType: RelDataType,
joinRowType: RelDataType,
joinType: SemiJoinType,
ruleDescription: String)
extends SingleRel(cluster, traitSet, inputNode)
with CommonCorrelate
with DataSetRel {
override def deriveRowType() = relRowType
override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
val rowCnt = metadata.getRowCount(getInput) * 1.5
planner.getCostFactory.makeCost(rowCnt, rowCnt, rowCnt * 0.5)
}
override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
new DataSetCorrelate(
cluster,
traitSet,
inputs.get(0),
scan,
condition,
relRowType,
joinRowType,
joinType,
ruleDescription)
}
override def toString: String = {
val rexCall = scan.getCall.asInstanceOf[RexCall]
val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
correlateToString(joinRowType, rexCall, sqlFunction, getExpressionString)
}
override def explainTerms(pw: RelWriter): RelWriter = {
val rexCall = scan.getCall.asInstanceOf[RexCall]
val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
super.explainTerms(pw)
.item("invocation", scan.getCall)
.item("correlate", correlateToString(
inputNode.getRowType,
rexCall, sqlFunction,
getExpressionString))
.item("select", selectToString(relRowType))
.item("rowType", relRowType)
.item("joinType", joinType)
.itemIf("condition", condition.orNull, condition.isDefined)
}
override def translateToPlan(
tableEnv: BatchTableEnvironment,
queryConfig: BatchQueryConfig): DataSet[Row] = {
val config = tableEnv.getConfig
// we do not need to specify input type
val inputDS = inputNode.asInstanceOf[DataSetRel].translateToPlan(tableEnv, queryConfig)
val funcRel = scan.asInstanceOf[FlinkLogicalTableFunctionScan]
val rexCall = funcRel.getCall.asInstanceOf[RexCall]
val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
val pojoFieldMapping = Some(sqlFunction.getPojoFieldMapping)
val udtfTypeInfo = sqlFunction.getRowTypeInfo.asInstanceOf[TypeInformation[Any]]
val flatMap = generateFunction(
config,
new RowSchema(getInput.getRowType),
udtfTypeInfo,
new RowSchema(getRowType),
joinType,
rexCall,
pojoFieldMapping,
ruleDescription,
classOf[FlatMapFunction[Row, Row]])
val collector = generateCollector(
config,
new RowSchema(getInput.getRowType),
udtfTypeInfo,
new RowSchema(getRowType),
condition,
pojoFieldMapping)
val mapFunc = new CorrelateFlatMapRunner[Row, Row](
flatMap.name,
flatMap.code,
collector.name,
collector.code,
flatMap.returnType)
inputDS
.flatMap(mapFunc)
.name(correlateOpName(
inputNode.getRowType,
rexCall,
sqlFunction,
relRowType,
getExpressionString)
)
}
}