blob: 7f32e6d50aa379646a6bbd2eec161e109eef296b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.plan.nodes.dataset
import java.lang.Iterable
import java.lang.{Boolean => JBool}
import org.apache.calcite.plan._
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.{JoinInfo, JoinRelType}
import org.apache.calcite.rel.metadata.RelMetadataQuery
import org.apache.calcite.rel.{BiRel, RelNode, RelWriter}
import org.apache.calcite.rex.RexNode
import org.apache.calcite.util.mapping.IntPair
import org.apache.flink.api.common.functions.{FilterFunction, FlatJoinFunction, GroupReduceFunction, JoinFunction}
import org.apache.flink.api.common.operators.Order
import org.apache.flink.api.common.operators.base.JoinOperatorBase.JoinHint
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.DataSet
import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.table.api.{BatchQueryConfig, BatchTableEnvironment, TableConfig, TableException, Types}
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.codegen.{FunctionCodeGenerator, GeneratedFunction}
import org.apache.flink.table.plan.nodes.CommonJoin
import org.apache.flink.table.runtime._
import org.apache.flink.types.Row
import org.apache.flink.util.Collector
import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
/**
 * Flink RelNode that translates a relational join into DataSet join operators
 * (inner, left outer, right outer, and full outer) and their supporting operations.
 */
class DataSetJoin(
cluster: RelOptCluster,
traitSet: RelTraitSet,
leftNode: RelNode,
rightNode: RelNode,
rowRelDataType: RelDataType,
joinCondition: RexNode,
joinRowType: RelDataType,
joinInfo: JoinInfo,
keyPairs: List[IntPair],
joinType: JoinRelType,
joinHint: JoinHint,
ruleDescription: String)
extends BiRel(cluster, traitSet, leftNode, rightNode)
with CommonJoin
with DataSetRel {
/** Returns the row type that was handed to this node as `rowRelDataType`. */
override def deriveRowType(): RelDataType = {
  rowRelDataType
}
/**
 * Creates a new [[DataSetJoin]] with the given trait set and inputs.
 *
 * The first element of `inputs` is used as the left input and the second as the right input.
 * The row type is taken from `getRowType`; every other constructor argument is carried over
 * unchanged from this node.
 *
 * @param traitSet trait set of the new node
 * @param inputs   new inputs; elements 0 and 1 are read
 * @return the newly created join node
 */
override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
  val newLeft = inputs.get(0)
  val newRight = inputs.get(1)

  new DataSetJoin(
    cluster, traitSet, newLeft, newRight,
    getRowType, joinCondition, joinRowType, joinInfo,
    keyPairs, joinType, joinHint, ruleDescription)
}
/**
 * Returns the string produced by `joinToString` for this join's row type, condition and
 * join type.
 */
override def toString: String =
  joinToString(joinRowType, joinCondition, joinType, getExpressionString)
/**
 * Describes this node for plan explanation.
 *
 * The writer returned by the superclass implementation is passed on to `joinExplainTerms`
 * together with this join's row type, condition and join type.
 *
 * @param pw writer to add the terms to
 * @return the writer returned by `joinExplainTerms`
 */
override def explainTerms(pw: RelWriter): RelWriter = {
  val baseTerms = super.explainTerms(pw)
  joinExplainTerms(baseTerms, joinRowType, joinCondition, joinType, getExpressionString)
}
/**
 * Estimates the cost of this join from the row counts and estimated row sizes of its inputs.
 *
 * Both the row count and the CPU cost are the sum of the two input row counts; the IO cost is
 * the sum of (row count * estimated row size) over both inputs.
 *
 * @param planner  planner whose cost factory creates the cost object
 * @param metadata metadata query used to obtain the input row counts
 * @return the cost created by the planner's cost factory
 */
override def computeSelfCost (planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
  val rowsOnLeft = metadata.getRowCount(getLeft)
  val sizeOfLeftRow = estimateRowSize(getLeft.getRowType)
  val rowsOnRight = metadata.getRowCount(getRight)
  val sizeOfRightRow = estimateRowSize(getRight.getRowType)

  // the summed input cardinality serves as both the row count and the CPU cost
  val totalRows = rowsOnLeft + rowsOnRight
  val ioCost = (rowsOnLeft * sizeOfLeftRow) + (rowsOnRight * sizeOfRightRow)

  planner.getCostFactory.makeCost(totalRows, totalRows, ioCost)
}
/**
 * Translates this join into a DataSet program.
 *
 * The equality key pairs are validated first: there must be at least one pair, and the SQL
 * type names of the two fields of every pair must be equal. Afterwards both inputs are
 * translated and the join is built by the method matching the join type.
 *
 * @param tableEnv    environment that supplies the [[TableConfig]] and is passed to the inputs
 * @param queryConfig query configuration passed on to the inputs
 * @return the DataSet produced by the join
 * @throws TableException if there is no equality key pair, if a key pair has different SQL
 *                        type names, or if the join type is not INNER, LEFT, RIGHT or FULL
 */
override def translateToPlan(
    tableEnv: BatchTableEnvironment,
    queryConfig: BatchQueryConfig): DataSet[Row] = {

  val config = tableEnv.getConfig
  val returnType = FlinkTypeFactory.toInternalRowTypeInfo(getRowType)

  // Builds the exception for a failed key validation; both failures share the same layout.
  def invalidJoin(reason: String): TableException = {
    val conditionStr = joinConditionToString(joinRowType, joinCondition, getExpressionString)
    new TableException(
      s"$reason\n" +
        s"\tLeft: ${left.toString},\n" +
        s"\tRight: ${right.toString},\n" +
        s"\tCondition: ($conditionStr)")
  }

  // if no equality keys => not supported
  if (keyPairs.isEmpty) {
    throw invalidJoin("Joins should have at least one equality condition.")
  }

  // at least one equality expression: check that the keys of every pair are compatible
  val leftFields = left.getRowType.getFieldList
  val rightFields = right.getRowType.getFieldList
  keyPairs.foreach { pair =>
    val leftKeyType = leftFields.get(pair.source).getType.getSqlTypeName
    val rightKeyType = rightFields.get(pair.target).getType.getSqlTypeName
    if (leftKeyType != rightKeyType) {
      throw invalidJoin("Equality join predicate on incompatible types.")
    }
  }

  // all pairs are valid, so the key arrays can be derived directly from the pairs
  val leftKeys = keyPairs.map(_.source).toArray
  val rightKeys = keyPairs.map(_.target).toArray

  val leftDataSet = left.asInstanceOf[DataSetRel].translateToPlan(tableEnv, queryConfig)
  val rightDataSet = right.asInstanceOf[DataSetRel].translateToPlan(tableEnv, queryConfig)

  joinType match {
    case JoinRelType.INNER =>
      addInnerJoin(leftDataSet, rightDataSet, leftKeys, rightKeys, returnType, config)
    case JoinRelType.LEFT =>
      addLeftOuterJoin(leftDataSet, rightDataSet, leftKeys, rightKeys, returnType, config)
    case JoinRelType.RIGHT =>
      addRightOuterJoin(leftDataSet, rightDataSet, leftKeys, rightKeys, returnType, config)
    case JoinRelType.FULL =>
      addFullOuterJoin(leftDataSet, rightDataSet, leftKeys, rightKeys, returnType, config)
    case unsupported =>
      // report any other join type as a TableException instead of a raw scala.MatchError
      throw new TableException(s"Unsupported join type '$unsupported' for DataSet joins.")
  }
}
/**
 * Builds an inner join of the two inputs on the given equality keys.
 *
 * A [[FlatJoinFunction]] is generated that evaluates the join condition and, only if it
 * holds, converts the pair into a result row and emits it through the collector.
 *
 * Note: the join is created via `left.join(right)`; the `joinHint` of this node is not
 * passed to it.
 *
 * @param left       left input
 * @param right      right input
 * @param leftKeys   key field indices of the left input
 * @param rightKeys  key field indices of the right input
 * @param resultType type of the emitted rows
 * @param config     table configuration used for code generation
 * @return the joined DataSet, named with the join operator name
 */
private def addInnerJoin(
    left: DataSet[Row],
    right: DataSet[Row],
    leftKeys: Array[Int],
    rightKeys: Array[Int],
    resultType: TypeInformation[Row],
    config: TableConfig): DataSet[Row] = {

  val codeGen = new FunctionCodeGenerator(config, false, left.getType, Some(right.getType))

  val resultExpr = codeGen.generateConverterResultExpression(
    resultType,
    joinRowType.getFieldNames)
  val predicate = codeGen.generateExpression(joinCondition)

  // emit the converted row only when the generated predicate evaluates to true
  val functionBody =
    s"""
       |${predicate.code}
       |if (${predicate.resultTerm}) {
       | ${resultExpr.code}
       | ${codeGen.collectorTerm}.collect(${resultExpr.resultTerm});
       |}
       |""".stripMargin

  val generated = codeGen.generateFunction(
    ruleDescription,
    classOf[FlatJoinFunction[Row, Row, Row]],
    functionBody,
    resultType)

  val runner = new FlatJoinRunner[Row, Row, Row](
    generated.name,
    generated.code,
    generated.returnType)

  val keyedJoin = left.join(right).where(leftKeys: _*).equalTo(rightKeys: _*)
  keyedJoin.`with`(runner).name(getJoinOpName)
}
/**
 * Builds a left outer join of the two inputs on the given equality keys.
 *
 * The left input is partitioned, sorted and folded (identical rows are collapsed into one row
 * plus a count), joined with the right input by a runner that evaluates the generated join
 * predicate, and finally grouped on the folded left row to produce the result rows.
 *
 * @param left       left input
 * @param right      right input
 * @param leftKeys   key field indices of the left input
 * @param rightKeys  key field indices of the right input
 * @param resultType type of the emitted rows
 * @param config     table configuration; its null check must be enabled
 * @return the joined DataSet
 * @throws TableException if the null check is disabled in `config`
 */
private def addLeftOuterJoin(
    left: DataSet[Row],
    right: DataSet[Row],
    leftKeys: Array[Int],
    rightKeys: Array[Int],
    resultType: TypeInformation[Row],
    config: TableConfig): DataSet[Row] = {

  if (!config.getNullCheck) {
    throw new TableException("Null check in TableConfig must be enabled for outer joins.")
  }

  val opName = getJoinOpName

  // rebuild the left row type from its field types only, so that fields carry indexed names
  val leftFieldTypes = left.getType.asInstanceOf[RowTypeInfo].getFieldTypes
  val leftRowType = new RowTypeInfo(leftFieldTypes: _*)
  val rightRowType = right.getType.asInstanceOf[RowTypeInfo]

  // partition + sort once so the sorting can be reused by groupBy -> join -> groupBy,
  // then collapse identical left rows
  val sortedLeft: DataSet[Row] = partitionAndSort(left, leftKeys)
  val foldedLeft: DataSet[Row] = foldIdenticalRows(sortedLeft, leftRowType)

  // runner that evaluates the generated join predicate
  val predicate = generatePredicateFunction(leftRowType, rightRowType, config)
  val pairType = new RowTypeInfo(leftRowType, rightRowType, Types.INT)
  val pairRunner = new LeftOuterJoinRunner(predicate.name, predicate.code, pairType)

  // after folding, the original left row is nested in field f0
  val foldedLeftKeys = leftKeys.map("f0.f" + _)
  val joinPairs = foldedLeft
    .leftOuterJoin(right, JoinHint.REPARTITION_SORT_MERGE)
    .where(foldedLeftKeys: _*)
    .equalTo(rightKeys: _*)
    .`with`(pairRunner)
    .withForwardedFieldsFirst("f0->f0")
    .name(opName)

  // group reducer that turns the join pairs into result rows
  val converter = generateConversionFunction(leftRowType, rightRowType, resultType, config)
  val resultReducer = new LeftOuterJoinGroupReduceRunner(
    converter.name,
    converter.code,
    converter.returnType)

  // grouping on the left row ensures the rows of the left input are preserved
  joinPairs
    .groupBy("f0")
    .reduceGroup(resultReducer)
    .name(opName)
    .returns(resultType)
}
/**
 * Builds a right outer join of the two inputs on the given equality keys.
 *
 * The right input is partitioned, sorted and folded (identical rows are collapsed into one
 * row plus a count), joined with the left input by a runner that evaluates the generated join
 * predicate, and finally grouped on the folded right row to produce the result rows.
 *
 * @param left       left input
 * @param right      right input
 * @param leftKeys   key field indices of the left input
 * @param rightKeys  key field indices of the right input
 * @param resultType type of the emitted rows
 * @param config     table configuration; its null check must be enabled
 * @return the joined DataSet
 * @throws TableException if the null check is disabled in `config`
 */
private def addRightOuterJoin(
    left: DataSet[Row],
    right: DataSet[Row],
    leftKeys: Array[Int],
    rightKeys: Array[Int],
    resultType: TypeInformation[Row],
    config: TableConfig): DataSet[Row] = {

  if (!config.getNullCheck) {
    throw new TableException("Null check in TableConfig must be enabled for outer joins.")
  }

  val opName = getJoinOpName

  // rebuild the right row type from its field types only, so that fields carry indexed names
  val leftRowType = left.getType.asInstanceOf[RowTypeInfo]
  val rightFieldTypes = right.getType.asInstanceOf[RowTypeInfo].getFieldTypes
  val rightRowType = new RowTypeInfo(rightFieldTypes: _*)

  // partition + sort once so the sorting can be reused by groupBy -> join -> groupBy,
  // then collapse identical right rows
  val sortedRight: DataSet[Row] = partitionAndSort(right, rightKeys)
  val foldedRight: DataSet[Row] = foldIdenticalRows(sortedRight, rightRowType)

  // runner that evaluates the generated join predicate
  val predicate = generatePredicateFunction(leftRowType, rightRowType, config)
  val pairType = new RowTypeInfo(leftRowType, rightRowType, Types.INT)
  val pairRunner = new RightOuterJoinRunner(predicate.name, predicate.code, pairType)

  // after folding, the original right row is nested in field f0
  val foldedRightKeys = rightKeys.map("f0.f" + _)
  val joinPairs = left
    .rightOuterJoin(foldedRight, JoinHint.REPARTITION_SORT_MERGE)
    .where(leftKeys: _*)
    .equalTo(foldedRightKeys: _*)
    .`with`(pairRunner)
    .withForwardedFieldsSecond("f0->f1")
    .name(opName)

  // group reducer that turns the join pairs into result rows
  val converter = generateConversionFunction(leftRowType, rightRowType, resultType, config)
  val resultReducer = new RightOuterJoinGroupReduceRunner(
    converter.name,
    converter.code,
    converter.returnType)

  // grouping on the right row ensures the rows of the right input are preserved
  joinPairs
    .groupBy("f1")
    .reduceGroup(resultReducer)
    .name(opName)
    .returns(resultType)
}
/**
 * Builds a full outer join of the two inputs on the given equality keys.
 *
 * Both inputs are partitioned, sorted and folded (identical rows are collapsed into one row
 * plus a count) and then joined by a runner that evaluates the generated join predicate.
 * The join pairs are consumed twice: once grouped on the left row (pairs whose field 0 is
 * not null) and once grouped on the right row (pairs whose field 1 is not null). The two
 * results are unioned.
 *
 * @param left       left input
 * @param right      right input
 * @param leftKeys   key field indices of the left input
 * @param rightKeys  key field indices of the right input
 * @param resultType type of the emitted rows
 * @param config     table configuration; its null check must be enabled
 * @return the union of both grouped results
 * @throws TableException if the null check is disabled in `config`
 */
private def addFullOuterJoin(
    left: DataSet[Row],
    right: DataSet[Row],
    leftKeys: Array[Int],
    rightKeys: Array[Int],
    resultType: TypeInformation[Row],
    config: TableConfig): DataSet[Row] = {

  if (!config.getNullCheck) {
    throw new TableException("Null check in TableConfig must be enabled for outer joins.")
  }

  val opName = getJoinOpName

  // rebuild both row types from their field types only, so that fields carry indexed names
  val leftFieldTypes = left.getType.asInstanceOf[RowTypeInfo].getFieldTypes
  val leftRowType = new RowTypeInfo(leftFieldTypes: _*)
  val rightFieldTypes = right.getType.asInstanceOf[RowTypeInfo].getFieldTypes
  val rightRowType = new RowTypeInfo(rightFieldTypes: _*)

  // partition + sort both inputs so the sorting can be reused by groupBy -> join -> groupBy
  // (except for the second grouping that preserves the right rows)
  val sortedLeft: DataSet[Row] = partitionAndSort(left, leftKeys)
  val sortedRight: DataSet[Row] = partitionAndSort(right, rightKeys)

  // collapse identical rows on both sides
  val foldedLeft: DataSet[Row] = foldIdenticalRows(sortedLeft, leftRowType)
  val foldedRight: DataSet[Row] = foldIdenticalRows(sortedRight, rightRowType)

  // runner that evaluates the generated join predicate
  val predicate = generatePredicateFunction(leftRowType, rightRowType, config)
  val pairType = new RowTypeInfo(leftRowType, rightRowType, Types.INT, Types.INT)
  val pairRunner = new FullOuterJoinRunner(predicate.name, predicate.code, pairType)

  // after folding, the original row of either side is nested in field f0
  val foldedLeftKeys = leftKeys.map("f0.f" + _)
  val foldedRightKeys = rightKeys.map("f0.f" + _)
  val joinPairs = foldedLeft
    .fullOuterJoin(foldedRight, JoinHint.REPARTITION_SORT_MERGE)
    .where(foldedLeftKeys: _*)
    .equalTo(foldedRightKeys: _*)
    .`with`(pairRunner)
    .withForwardedFieldsFirst("f0->f0")
    .withForwardedFieldsSecond("f0->f1")
    .name(opName)

  // both group reducers are built from the same generated conversion function
  val converter = generateConversionFunction(leftRowType, rightRowType, resultType, config)
  val reducerForLeft = new LeftFullOuterJoinGroupReduceRunner(
    converter.name,
    converter.code,
    converter.returnType)
  val reducerForRight = new RightFullOuterJoinGroupReduceRunner(
    converter.name,
    converter.code,
    converter.returnType)

  // joined (left + right) and left preserved (left + null): pairs that carry a left row
  val joinedAndLeftPreserved = joinPairs
    .filter(new FilterFunction[Row]() {
      override def filter(pair: Row): Boolean = pair.getField(0) != null
    })
    .groupBy("f0")
    .reduceGroup(reducerForLeft)
    .name(opName)
    .returns(resultType)

  // right preserved (null + right): pairs that carry a right row
  val rightPreserved = joinPairs
    .filter(new FilterFunction[Row]() {
      override def filter(pair: Row): Boolean = pair.getField(1) != null
    })
    .groupBy("f1")
    .reduceGroup(reducerForRight)
    .name(opName)
    .returns(resultType)

  // joined, left preserved and right preserved rows together form the result
  joinedAndLeftPreserved.union(rightPreserved)
}
/**
 * Returns the operator name of this join, composed of the join condition string and the
 * join selection string.
 */
private def getJoinOpName: String = {
  val conditionStr = joinConditionToString(joinRowType, joinCondition, getExpressionString)
  val selectionStr = joinSelectionToString(joinRowType)
  s"where: ($conditionStr), join: ($selectionStr)"
}
/**
 * Returns `keys` followed by every index in `0 until numFields` that is not contained in
 * `keys`, in ascending order.
 *
 * @param keys      indices that form the prefix of the result, kept in the given order
 * @param numFields exclusive upper bound of the indices that are appended
 * @return the prefix indices followed by the remaining indices
 */
private def getFullIndiciesWithPrefix(keys: Array[Int], numFields: Int): Array[Int] = {
  val remaining = for (idx <- 0 until numFields if !keys.contains(idx)) yield idx
  keys ++ remaining
}
/**
 * Hash-partitions the data set on the partition keys and sorts each partition in ascending
 * order on all fields, with the partition keys forming the prefix of the sort order.
 *
 * @param dataSet       data set to partition and sort
 * @param partitionKeys field indices to partition on
 * @return the partitioned and sorted data set
 */
private def partitionAndSort(
    dataSet: DataSet[Row],
    partitionKeys: Array[Int]): DataSet[Row] = {

  // full sort order: the partition keys first, then all remaining fields
  val sortFieldOrder = getFullIndiciesWithPrefix(partitionKeys, dataSet.getType.getArity)

  val hashPartitioned: DataSet[Row] = dataSet.partitionByHash(partitionKeys: _*)

  // add one ascending sort key per field, in sort order
  sortFieldOrder.foldLeft(hashPartitioned) { (sortedSoFar, fieldIdx) =>
    sortedSoFar.sortPartition(fieldIdx, Order.ASCENDING)
  }
}
/**
 * Folds identical rows of a data set into a single row with a duplicate count.
 *
 * The data set is grouped on all of its fields. For every group one two-field row is emitted:
 * field 0 holds a row of the group and field 1 holds the number of rows in the group.
 *
 * @param dataSet     data set whose identical rows are folded
 * @param dataSetType row type of `dataSet`
 * @return data set of (row, count) rows
 */
private def foldIdenticalRows(
    dataSet: DataSet[Row],
    dataSetType: TypeInformation[Row]): DataSet[Row] = {

  val foldedType = new RowTypeInfo(dataSetType, Types.INT)
  val allFields = 0 until dataSetType.getArity

  val folder: GroupReduceFunction[Row, Row] = new GroupReduceFunction[Row, Row] {
    // a single output row instance is reused for every group
    val foldedRow = new Row(2)

    override def reduce(values: Iterable[Row], out: Collector[Row]): Unit = {
      var duplicates = 0
      val rows = values.iterator()
      while (rows.hasNext) {
        // all rows of a group are identical, so any of them can serve as the output row
        foldedRow.setField(0, rows.next())
        duplicates += 1
      }
      foldedRow.setField(1, duplicates)
      out.collect(foldedRow)
    }
  }

  dataSet
    .groupBy(allFields: _*)
    .reduceGroup(folder)
    .returns(foldedType)
    .withForwardedFields("*->f0")
    .name("fold identical rows")
}
/**
 * Generates a [[GeneratedFunction]] of a [[JoinFunction]] that evaluates the join condition
 * of this node and returns its result as [[JBool]].
 *
 * @param leftType  row type of the left input
 * @param rightType row type of the right input
 * @param config    table configuration used for code generation
 * @return the generated function, named "OuterJoinPredicate"
 */
private def generatePredicateFunction(
    leftType: TypeInformation[Row],
    rightType: TypeInformation[Row],
    config: TableConfig): GeneratedFunction[JoinFunction[Row, Row, JBool], JBool] = {

  val codeGen = new FunctionCodeGenerator(config, false, leftType, Some(rightType))
  val predicate = codeGen.generateExpression(joinCondition)

  // the function body evaluates the condition and returns its result term
  val functionBody =
    s"""
       |${predicate.code}
       |return (${predicate.resultTerm});
       |""".stripMargin

  codeGen.generateFunction(
    "OuterJoinPredicate",
    classOf[JoinFunction[Row, Row, JBool]],
    functionBody,
    Types.BOOLEAN)
}
/**
 * Generates a [[GeneratedFunction]] of a [[JoinFunction]] that converts a pair of input rows
 * into a join result row.
 *
 * @param leftType   row type of the left input
 * @param rightType  row type of the right input
 * @param resultType type of the produced result row
 * @param config     table configuration used for code generation
 * @return the generated function, named "OuterJoinConverter"
 */
private def generateConversionFunction(
    leftType: TypeInformation[Row],
    rightType: TypeInformation[Row],
    resultType: TypeInformation[Row],
    config: TableConfig): GeneratedFunction[JoinFunction[Row, Row, Row], Row] = {

  // NOTE(review): unlike the predicate generator, the second constructor argument is `true`
  // here -- presumably it enables nullable inputs; confirm against FunctionCodeGenerator.
  val codeGen = new FunctionCodeGenerator(config, true, leftType, Some(rightType))
  val resultExpr = codeGen.generateConverterResultExpression(
    resultType,
    joinRowType.getFieldNames)

  // the function body builds the result row and returns it
  val functionBody =
    s"""
       |${resultExpr.code}
       |return ${resultExpr.resultTerm};
       |""".stripMargin

  codeGen.generateFunction(
    "OuterJoinConverter",
    classOf[JoinFunction[Row, Row, Row]],
    functionBody,
    resultType)
}
}