blob: 41da51bdb2917c1bd2b54eaf89e1c22a46b1fff8 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.plan.nodes.common
import java.lang.reflect.{Method, Modifier}
import java.util
import java.util.Collections
import com.google.common.primitives.Primitives
import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeField}
import org.apache.calcite.rel.core.{JoinInfo, JoinRelType}
import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
import org.apache.calcite.rex._
import org.apache.calcite.sql.SqlKind
import org.apache.calcite.sql.fun.SqlStdOperatorTable
import org.apache.calcite.sql.validate.SqlValidatorUtil
import org.apache.calcite.tools.RelBuilder
import org.apache.calcite.util.mapping.IntPair
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.{RowTypeInfo, TypeExtractor}
import org.apache.flink.streaming.api.datastream.AsyncDataStream.OutputMode
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment
import org.apache.flink.streaming.api.functions.async.ResultFuture
import org.apache.flink.streaming.api.operators.ProcessOperator
import org.apache.flink.streaming.api.operators.async.AsyncWaitOperator
import org.apache.flink.streaming.api.transformations.{OneInputTransformation, StreamTransformation}
import org.apache.flink.table.api.functions.{AsyncTableFunction, TableFunction, UserDefinedFunction}
import org.apache.flink.table.api.scala._
import org.apache.flink.table.api.types.{DataType, InternalType, TypeConverters}
import org.apache.flink.table.api.{TableConfig, TableException, ValidationException}
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.codegen.TemporalJoinCodeGenerator._
import org.apache.flink.table.codegen.{CodeGeneratorContext, TemporalJoinCodeGenerator}
import org.apache.flink.table.dataformat.{BaseRow, BinaryRow, GenericRow, JoinedRow}
import org.apache.flink.table.functions.sql.ScalarSqlFunctions
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{checkAndExtractMethods, signatureToString, signaturesToString}
import org.apache.flink.table.plan.nodes.FlinkRelNode
import org.apache.flink.table.plan.schema.{BaseRowSchema, TimeIndicatorRelDataType}
import org.apache.flink.table.plan.util.TemporalJoinUtil._
import org.apache.flink.table.plan.util.{CalcUtil, RexLiteralUtil}
import org.apache.flink.table.runtime.join.{TemporalTableJoinAsyncRunner, TemporalTableJoinProcessRunner, TemporalTableJoinWithCalcAsyncRunner, TemporalTableJoinWithCalcProcessRunner}
import org.apache.flink.table.sources.{IndexKey, LookupConfig, LookupableTableSource, TableSource}
import org.apache.flink.table.typeutils.{BaseRowTypeInfo, TypeUtils}
import org.apache.flink.table.util.TableConnectorUtil
import org.apache.flink.types.Row
import scala.collection.JavaConversions._
import scala.collection.mutable
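/**
 * Common abstract RelNode for temporal table join which shares most methods.
 *
 * @param cluster the cluster this rel node belongs to
 * @param traitSet the traits of this rel node
 * @param input input rel node
 * @param tableSource the table source to be temporal joined
 * @param tableRowType the row type of the table source
 * @param tableCalcProgram the calc (projection & filter) applied after the table scan and
 *                         before joining
 * @param period the point in time to snapshot
 * @param joinInfo the equi-join key pairs and remaining condition of the join
 * @param joinType the join type; only INNER and LEFT are supported
 */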
abstract class CommonTemporalTableJoin(
    cluster: RelOptCluster,
    traitSet: RelTraitSet,
    input: RelNode,
    val tableSource: TableSource,
    tableRowType: RelDataType,
    val tableCalcProgram: Option[RexProgram],
    period: RexNode,
    val joinInfo: JoinInfo,
    val joinType: JoinRelType)
  extends SingleRel(cluster, traitSet, input)
  with FlinkRelNode {

  /** Join key pairs from left input field to temporal table field (remapped through the calc). */
  val joinKeyPairs: util.List[IntPair] = getTemporalTableJoinKeyPairs(joinInfo, tableCalcProgram)

  /** All index keys declared on the temporal table source. */
  val indexKeys: util.List[IndexKey] = getTableIndexKeys(tableSource)

  // constant keys which maybe empty if calc program is None
  val constantLookupKeys: util.Map[Int, (InternalType, Object)] = analyzeConstantLookupKeys(
    cluster,
    tableCalcProgram,
    indexKeys)

  /** The first declared index fully covered by join keys and constant keys, if any. */
  val joinedIndex: Option[IndexKey] = findMatchedIndex(indexKeys, joinKeyPairs, constantLookupKeys)

  /**
   * The join row type is the left input row type followed by the temporal table row type
   * (or the calc output row type when a calc program is present).
   */
  override def deriveRowType(): RelDataType = {
    val flinkTypeFactory = cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]
    val rightType = tableCalcProgram.map(_.getOutputRowType).getOrElse(tableRowType)
    SqlValidatorUtil.deriveJoinRowType(
      input.getRowType,
      rightType,
      joinType,
      flinkTypeFactory,
      null,
      Collections.emptyList[RelDataTypeField])
  }

  override def explainTerms(pw: RelWriter): RelWriter = {
    val remaining = joinInfo.getRemaining(cluster.getRexBuilder)
    // an always-true remaining condition is not worth printing
    val joinCondition = if (remaining.isAlwaysTrue) None else Some(remaining)
    joinExplainTerms(
      super.explainTerms(pw),
      tableSource,
      input.getRowType,
      getRowType,
      tableCalcProgram,
      joinInfo.pairs(),
      joinCondition,
      joinType,
      period,
      getExpressionString)
  }

  // ----------------------------------------------------------------------------------------
  // Physical Translation
  // ----------------------------------------------------------------------------------------

  /**
   * Translates this node into a [[OneInputTransformation]] that performs the lookup join,
   * using an [[AsyncWaitOperator]] when the table source enables async lookup and a
   * [[ProcessOperator]] otherwise.
   *
   * @throws TableException if the node fails validation or the lookup function does not
   *                        provide a matching `eval` method / return type
   */
  def translateToPlanInternal(
      inputTransformation: StreamTransformation[BaseRow],
      env: StreamExecutionEnvironment,
      config: TableConfig,
      relBuilder: RelBuilder): StreamTransformation[BaseRow] = {

    val inputSchema = new BaseRowSchema(input.getRowType)
    val tableSchema = new BaseRowSchema(tableRowType)
    val resultSchema = new BaseRowSchema(getRowType)
    val inputBaseRowType = inputSchema.internalType()
    val tableBaseRowType = tableSchema.internalType()
    val resultBaseRowType = resultSchema.internalType()
    val resultBaseRowTypeInfo = resultSchema.typeInfo()
    val tableReturnTypeInfo =
      TypeConverters.createExternalTypeInfoFromDataType(tableSource.getReturnType)
    val tableReturnClass = CommonScan.extractTableSourceTypeClass(tableSource)

    // validate whether the node is valid and supported.
    validate(
      tableSource,
      period,
      inputSchema,
      tableSchema,
      joinKeyPairs,
      constantLookupKeys,
      indexKeys,
      joinedIndex,
      joinType)

    // validate() guarantees joinedIndex is defined
    val checkedIndexInOrder = joinedIndex.get.getDefinedColumns.map(_.intValue()).toArray
    val indexFieldTypes = checkedIndexInOrder.map(tableSchema.fieldTypes(_))
    val remainingCondition = getRemainingJoinCondition(
      cluster.getRexBuilder,
      relBuilder,
      input.getRowType,
      tableRowType,
      tableCalcProgram,
      checkedIndexInOrder,
      joinKeyPairs,
      joinInfo,
      constantLookupKeys)

    val lookupKeysFromConstant: Map[Int, RexLiteral] = constantLookupKeys.toMap.map {
      case (i, (_, o)) => (i, relBuilder.literal(o).asInstanceOf[RexLiteral])
    }
    val lookupKeyPairs = joinKeyPairs.filter(p => checkedIndexInOrder.contains(p.target))
    // lookup key index -> input field index
    val lookupKey2InputFieldIndex: Map[Int, Int] = lookupKeyPairs
      .map { k => (k.target, k.source) }
      .toMap

    // validate() guarantees the source is lookupable
    val lookupableTableSource = tableSource.asInstanceOf[LookupableTableSource[_]]
    val lookupConfig = Option(lookupableTableSource.getLookupConfig).getOrElse(new LookupConfig)
    val leftOuterJoin = joinType == JoinRelType.LEFT

    val operator = if (lookupConfig.isAsyncEnabled) {
      val asyncBufferCapacity = lookupConfig.getAsyncBufferCapacity
      val asyncTimeout = lookupConfig.getAsyncTimeoutMs
      val asyncOutputMode = lookupConfig.getAsyncOutputMode

      val asyncTableFunction = lookupableTableSource.getAsyncLookupFunction(checkedIndexInOrder)
      // async eval methods take the ResultFuture as their first argument
      val parameters = Array(classOf[ResultFuture[_]]) ++
        indexFieldTypes.map(TypeUtils.getInternalClassForType(_))
      val method = getSignatureMatchedEvalMethod(
        asyncTableFunction,
        parameters)

      // eval method valid check
      if (method.isEmpty) {
        val msg = s"Given parameter types of the async lookup TableFunction of TableSource " +
          s"'${tableSource.explainSource()}' do not match the expected signature.\n" +
          s"Expected: eval${signatureToString(parameters)} \n" +
          s"Actual: eval${signaturesToString(asyncTableFunction, "eval")}"
        throw new TableException(msg)
      }
      // return type valid check
      val udtfResultType = asyncTableFunction.getResultType(Array(), Array())
      val extractedResultTypeInfo = TypeExtractor.createTypeInfo(
        asyncTableFunction,
        classOf[AsyncTableFunction[_]],
        asyncTableFunction.getClass,
        0)
      checkUdtfReturnType(
        tableSource.explainSource(),
        tableReturnTypeInfo,
        udtfResultType,
        extractedResultTypeInfo)

      val generatedFetcher = TemporalJoinCodeGenerator.generateAsyncLookupFunction(
        config,
        relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory],
        inputBaseRowType,
        resultBaseRowType,
        tableReturnTypeInfo,
        tableReturnClass,
        checkedIndexInOrder,
        lookupKey2InputFieldIndex,
        lookupKeysFromConstant,
        asyncTableFunction)

      val asyncFunc = if (tableCalcProgram.isDefined) {
        // a projection or filter after table source scan
        val calcSchema = new BaseRowSchema(tableCalcProgram.get.getOutputRowType)
        val rightTypeInfo = calcSchema.internalType
        val collector = generateAsyncCollector(
          config,
          inputBaseRowType,
          rightTypeInfo,
          remainingCondition)
        val calcMap = generateCalcMapFunction(config, tableCalcProgram, tableSchema)
        new TemporalTableJoinWithCalcAsyncRunner(
          generatedFetcher.name,
          generatedFetcher.code,
          calcMap.name,
          calcMap.code,
          collector.name,
          collector.code,
          asyncBufferCapacity,
          leftOuterJoin,
          inputSchema.fieldTypes.toArray,
          resultBaseRowTypeInfo)
      } else {
        val collector = generateAsyncCollector(
          config,
          inputBaseRowType,
          tableBaseRowType,
          remainingCondition)
        new TemporalTableJoinAsyncRunner(
          generatedFetcher.name,
          generatedFetcher.code,
          collector.name,
          collector.code,
          asyncBufferCapacity,
          leftOuterJoin,
          inputSchema.fieldTypes.toArray,
          resultBaseRowTypeInfo)
      }

      val mode = if (asyncOutputMode == LookupConfig.AsyncOutputMode.ORDERED) {
        OutputMode.ORDERED
      } else {
        OutputMode.UNORDERED
      }
      new AsyncWaitOperator(asyncFunc, asyncTimeout, asyncBufferCapacity, mode)
    } else {
      // sync join
      val lookupFunction = lookupableTableSource.getLookupFunction(checkedIndexInOrder)
      val parameters = indexFieldTypes.map(TypeUtils.getInternalClassForType(_))
      val method = getSignatureMatchedEvalMethod(
        lookupFunction,
        parameters)

      // valid check
      if (method.isEmpty) {
        val msg = s"Given parameter types of the lookup TableFunction of TableSource " +
          s"'${tableSource.explainSource()}' do not match the expected signature.\n" +
          s"Expected: eval${signatureToString(parameters)} \n" +
          s"Actual: eval${signaturesToString(lookupFunction, "eval")}"
        throw new TableException(msg)
      }
      // return type valid check
      val udtfResultType = lookupFunction.getResultType(Array(), Array())
      val extractedResultTypeInfo = TypeExtractor.createTypeInfo(
        lookupFunction,
        classOf[TableFunction[_]],
        lookupFunction.getClass,
        0)
      checkUdtfReturnType(
        tableSource.explainSource(),
        tableReturnTypeInfo,
        udtfResultType,
        extractedResultTypeInfo)

      val generatedFetcher = TemporalJoinCodeGenerator.generateLookupFunction(
        config,
        relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory],
        inputBaseRowType,
        resultBaseRowType,
        tableReturnTypeInfo,
        tableReturnClass,
        checkedIndexInOrder,
        lookupKey2InputFieldIndex,
        lookupKeysFromConstant,
        lookupFunction,
        env.getConfig.isObjectReuseEnabled)

      val ctx = CodeGeneratorContext(config)
      val processFunc = if (tableCalcProgram.isDefined) {
        // a projection or filter after table source scan
        val calcSchema = new BaseRowSchema(tableCalcProgram.get.getOutputRowType)
        val rightTypeInfo = calcSchema.internalType()
        val collector = generateCollector(
          ctx,
          config,
          inputBaseRowType,
          rightTypeInfo,
          resultBaseRowType,
          remainingCondition,
          None)
        val calcMap = generateCalcMapFunction(config, tableCalcProgram, tableSchema)
        new TemporalTableJoinWithCalcProcessRunner(
          generatedFetcher.name,
          generatedFetcher.code,
          calcMap.name,
          calcMap.code,
          collector.name,
          collector.code,
          leftOuterJoin,
          inputSchema.fieldTypes.toArray,
          resultBaseRowTypeInfo)
      } else {
        val collector = generateCollector(
          ctx,
          config,
          inputBaseRowType,
          tableBaseRowType,
          resultBaseRowType,
          remainingCondition,
          None)
        new TemporalTableJoinProcessRunner(
          generatedFetcher.name,
          generatedFetcher.code,
          collector.name,
          collector.code,
          leftOuterJoin,
          inputSchema.fieldTypes.toArray,
          resultBaseRowTypeInfo)
      }
      new ProcessOperator(processFunc)
    }

    val operatorName = joinToString(
      lookupableTableSource,
      joinType,
      resultSchema,
      inputSchema,
      tableSchema,
      remainingCondition,
      constantLookupKeys,
      joinKeyPairs,
      getExpressionString)

    new OneInputTransformation(
      inputTransformation,
      operatorName,
      operator,
      TypeConverters.toBaseRowTypeInfo(resultBaseRowType),
      inputTransformation.getParallelism)
  }

  /**
   * Returns true if the two row types are compatible, comparing both the external and the
   * internal class because external classes are auto converted to internal ones
   * (e.g. Row => BaseRow), and GenericType<Row> / GenericType<BaseRow> may appear on either side.
   */
  private def rowTypeEquals(expected: TypeInformation[_], actual: TypeInformation[_]): Boolean = {
    TypeUtils.getExternalClassForType(expected) == TypeUtils.getExternalClassForType(actual) ||
      TypeUtils.getInternalClassForType(expected) == TypeUtils.getInternalClassForType(actual)
  }

  /**
   * Finds the `eval` method of `function` that accepts arguments of the given classes.
   *
   * Fixed-arity methods are preferred over varargs methods, mirroring the Java compiler.
   *
   * @return the matching method, or None if no method matches
   * @throws ValidationException if several methods match and none of the expected argument
   *                             classes is `Object` (i.e. the choice is genuinely ambiguous)
   */
  private def getSignatureMatchedEvalMethod(
      function: UserDefinedFunction,
      methodSignature: Array[Class[_]]): Option[Method] = {
    val methods = checkAndExtractMethods(function, "eval")

    // Counts expected arguments of class Object. Such an element comes from the Table API's
    // apply() and is an unresolved expression whose type cannot be decided here; any overload
    // with a matching argument count will do the job, so ambiguity is then tolerated.
    var applyCnt = 0

    // go over all the methods and keep the ones whose parameters accept the arguments
    val filtered = methods.filter { cur =>
      val signatures = cur.getParameterTypes
      if (!cur.isVarArgs) {
        methodSignature.length == signatures.length &&
          signatures.zipWithIndex.forall { case (clazz, i) =>
            if (methodSignature(i) == classOf[Object]) {
              applyCnt += 1
            }
            parameterTypeEquals(methodSignature(i), clazz)
          }
      } else {
        // every fixed (non-varargs) parameter must receive an argument; the trailing
        // varargs parameter may receive zero or more arguments
        val fixedArity = signatures.length - 1
        methodSignature.length >= fixedArity &&
          methodSignature.zipWithIndex.forall { case (clazz, i) =>
            if (i < fixedArity) {
              parameterTypeEquals(clazz, signatures(i))
            } else {
              parameterTypeEquals(clazz, signatures.last.getComponentType)
            }
          }
      }
    }

    // if there is a fixed method, compiler will call this method preferentially
    val fixedMethodsCount = filtered.count(!_.isVarArgs)
    val found = filtered.filter { cur =>
      val preferred = if (fixedMethodsCount > 0) !cur.isVarArgs else cur.isVarArgs
      // For methods the VOLATILE modifier bit (0x40) is ACC_BRIDGE: this skips
      // compiler-generated bridge methods.
      preferred && !Modifier.isVolatile(cur.getModifiers)
    }

    if (found.length > 1 && applyCnt == 0) {
      throw new ValidationException(
        s"Found multiple 'eval' methods which match the signature.")
    }
    // with applyCnt > 0 we cannot decide the type, so choosing any one is correct
    found.headOption
  }

  /**
   * Returns true if an argument of class `candidate` can be passed to a parameter of class
   * `expected`. A null or Object candidate means the type is unknown and always matches.
   */
  private def parameterTypeEquals(candidate: Class[_], expected: Class[_]): Boolean = {
    candidate == null ||
      candidate == expected ||
      expected == classOf[Object] ||
      candidate == classOf[Object] || // Special case when we don't know the type
      expected.isPrimitive && Primitives.wrap(expected) == candidate ||
      // any array of references can be passed as Object[] (array covariance);
      // arrays of primitives cannot
      (candidate.isArray &&
        expected.isArray &&
        !candidate.getComponentType.isPrimitive &&
        expected.getComponentType == classOf[Object])
  }

  /**
   * Builds the join condition that still has to be evaluated after the lookup: equalities
   * for join key pairs that are not part of the chosen index, AND-ed with the non-equi
   * remaining condition of the join.
   *
   * @return None if the resulting condition is always true
   */
  private def getRemainingJoinCondition(
      rexBuilder: RexBuilder,
      relBuilder: RelBuilder,
      leftRowType: RelDataType,
      tableRowType: RelDataType,
      tableCalcProgram: Option[RexProgram],
      checkedIndexInOrder: Array[Int],
      joinKeyPairs: util.List[IntPair],
      joinInfo: JoinInfo,
      constantLookupKeys: util.Map[Int, (InternalType, Object)]): Option[RexNode] = {
    val remainingPairs = joinKeyPairs
      .filter(p => !checkedIndexInOrder.contains(p.target))
    // convert remaining pairs to RexInputRef tuple for building sqlStdOperatorTable.EQUALS calls
    val remainingAnds = remainingPairs.map { p =>
      val leftInputRef = new RexInputRef(p.source, leftRowType.getFieldList.get(p.source).getType)
      val rightInputRef = tableCalcProgram match {
        case Some(program) =>
          // p.target indexes the table fields; locate the same field in the calc output
          val rightKeyIdx = program
            .getOutputRowType.getFieldNames
            .indexOf(program.getInputRowType.getFieldNames.get(p.target))
          new RexInputRef(
            leftRowType.getFieldCount + rightKeyIdx,
            program.getOutputRowType.getFieldList.get(rightKeyIdx).getType)
        case None =>
          new RexInputRef(
            leftRowType.getFieldCount + p.target,
            tableRowType.getFieldList.get(p.target).getType)
      }
      (leftInputRef, rightInputRef)
    }
    val equiAnds = relBuilder.and(remainingAnds.map(p => relBuilder.equals(p._1, p._2)): _*)
    val condition = relBuilder.and(equiAnds, joinInfo.getRemaining(rexBuilder))
    if (condition.isAlwaysTrue) None else Some(condition)
  }

  /**
   * Gets the join key pairs from left input key to temporal table key.
   *
   * When a calc program is present, the join targets refer to calc output fields; they are
   * remapped to table source fields, and pairs whose target is not a plain (or cast) input
   * field are dropped.
   *
   * @param joinInfo the join information of temporal table join
   * @param temporalTableCalcProgram the calc programs on temporal table
   */
  private def getTemporalTableJoinKeyPairs(
      joinInfo: JoinInfo,
      temporalTableCalcProgram: Option[RexProgram]): util.List[IntPair] = {
    temporalTableCalcProgram match {
      case Some(program) =>
        // the target key of joinInfo is the calc output fields, we have to remapping to table here
        val keyPairs: util.List[IntPair] = new util.ArrayList[IntPair]()
        joinInfo.pairs().foreach { p =>
          val calcSrcIdx = getIdenticalSourceField(program, p.target)
          if (calcSrcIdx != -1) {
            keyPairs.add(new IntPair(p.source, calcSrcIdx))
          }
        }
        keyPairs
      case None => joinInfo.pairs()
    }
  }

  /**
   * Analyze the constant lookup keys in the temporal table from the calc program on the temporal
   * table.
   *
   * @return map from index column to (column type, constant value); empty when there is no
   *         calc program or it has no condition
   */
  def analyzeConstantLookupKeys(
      cluster: RelOptCluster,
      temporalTableCalcProgram: Option[RexProgram],
      indexKeys: util.List[IndexKey]): util.Map[Int, (InternalType, Object)] = {
    val constantKeyMap: util.Map[Int, (InternalType, Object)] =
      new util.HashMap[Int, (InternalType, Object)]
    // all the columns in index keys
    val allKeys = mutable.HashSet.empty[Int]
    indexKeys.foreach(_.getDefinedColumns.foreach(allKeys += _))

    temporalTableCalcProgram
      .filter(_.getCondition != null)
      .foreach { program =>
        val condition = RexUtil.toCnf(
          cluster.getRexBuilder,
          program.expandLocalRef(program.getCondition))
        // presume 'A = 1 AND A = 2' will be reduced to ALWAYS_FALSE
        extractConstantKeysFromEquiCondition(condition, allKeys.toArray, constantKeyMap)
      }
    constantKeyMap
  }

  /**
   * Returns the first index whose columns are all covered by join key targets or constant
   * keys. Validation of a missing match is deferred to [[validate]].
   */
  private def findMatchedIndex(
      allIndexes: util.List[IndexKey],
      joinKeyPairs: util.List[IntPair],
      constantLookupKeys: util.Map[Int, (InternalType, Object)]): Option[IndexKey] = {
    val lookupKeyCandidates = joinKeyPairs.map(_.target) ++ constantLookupKeys.keySet()
    // do validation later due to unified ErrorCode
    allIndexes.find(_.isIndex(lookupKeyCandidates.toArray))
  }

  // ----------------------------------------------------------------------------------------
  // Physical Optimization Utilities
  // ----------------------------------------------------------------------------------------

  /**
   * Resolves a calc output field to the input field it forwards, drilling through
   * `IN_FENNEL` and `CAST` calls. This is highly inspired by Calcite's
   * RexProgram#getSourceField(int).
   *
   * @return the input field index, or -1 if the output is not a (cast of an) input field
   */
  private def getIdenticalSourceField(rexProgram: RexProgram, outputOrdinal: Int): Int = {
    assert((outputOrdinal >= 0) && (outputOrdinal < rexProgram.getProjectList.size()))

    @scala.annotation.tailrec
    def resolve(exprIndex: Int): Int = {
      val expr = rexProgram.getExprList.get(exprIndex) match {
        // drill through identity function
        case call: RexCall if call.getOperator == SqlStdOperatorTable.IN_FENNEL ||
            call.getOperator == SqlStdOperatorTable.CAST =>
          call.getOperands.get(0)
        case other => other
      }
      expr match {
        case ref: RexLocalRef => resolve(ref.getIndex)
        case ref: RexInputRef => ref.getIndex
        case _ => -1
      }
    }

    resolve(rexProgram.getProjectList.get(outputOrdinal).getIndex)
  }

  /**
   * Collects `key = literal` constants from a (CNF) condition into `constantKeyMap`,
   * looking at each conjunct of a top-level AND or at the condition itself.
   */
  private def extractConstantKeysFromEquiCondition(
      condition: RexNode,
      indexKeys: Array[Int],
      constantKeyMap: util.Map[Int, (InternalType, Object)]): Unit = {
    condition match {
      case c: RexCall if c.getKind == SqlKind.AND =>
        c.getOperands.foreach(r => extractConstantKeys(r, indexKeys, constantKeyMap))
      case rex: RexNode => extractConstantKeys(rex, indexKeys, constantKeyMap)
      case _ => // null condition
    }
  }

  /**
   * If `pred` is `ref = literal` or `literal = ref` where `ref` is one of `keyIndexes`,
   * records (key column type, literal value) for that key in `constantKeyMap`.
   *
   * @return `constantKeyMap`
   */
  private def extractConstantKeys(
      pred: RexNode,
      keyIndexes: Array[Int],
      constantKeyMap: util.Map[Int, (InternalType, Object)])
    : util.Map[Int, (InternalType, Object)] = {
    pred match {
      case c: RexCall if c.getKind == SqlKind.EQUALS =>
        // normalize both operand orders to (input ref, literal)
        val refAndLiteral = (c.getOperands.get(0), c.getOperands.get(1)) match {
          case (literal: RexLiteral, ref: RexInputRef) => Some((ref, literal))
          case (ref: RexInputRef, literal: RexLiteral) => Some((ref, literal))
          case _ => None
        }
        refAndLiteral.foreach { case (ref, literal) =>
          if (keyIndexes.contains(ref.getIndex)) {
            // record the key column's type, independent of which side the literal is on
            constantKeyMap.put(
              ref.getIndex,
              (FlinkTypeFactory.toInternalType(ref.getType),
                RexLiteralUtil.literalValue(literal)))
          }
        }
      case _ =>
    }
    constantKeyMap
  }

  // ----------------------------------------------------------------------------------------
  // Validation
  // ----------------------------------------------------------------------------------------

  /**
   * Validates that this temporal table join can be executed.
   *
   * @throws TableException if there are no lookup keys, the table declares no index, no index
   *                        is fully covered, the source is not a [[LookupableTableSource]], key
   *                        types mismatch, the join type is not INNER/LEFT, the source return
   *                        type is not Row/BaseRow, or the period is not a proctime reference
   */
  def validate(
      tableSource: TableSource,
      period: RexNode,
      inputSchema: BaseRowSchema,
      tableSourceSchema: BaseRowSchema,
      joinKeyPairs: util.List[IntPair],
      constantLookupKeys: util.Map[Int, (InternalType, Object)],
      allIndexKeys: util.List[IndexKey],
      joinedIndex: Option[IndexKey],
      joinType: JoinRelType): Unit = {

    if (joinKeyPairs.isEmpty && constantLookupKeys.isEmpty) {
      throw new TableException(
        "Temporal table join requires an equality condition on ALL of " +
          "temporal table's primary key(s) or unique key(s) or index field(s).")
    }

    // checked index never be null, so declared index also not null.
    if (allIndexKeys.isEmpty) {
      throw new TableException(
        "Temporal table require to define an primary key or unique key or index.")
    }

    // check a matched index exist
    if (joinedIndex.isEmpty) {
      throw new TableException(
        "Temporal table join requires an equality condition on ALL of " +
          "temporal table's primary key(s) or unique key(s) or index field(s).")
    }

    if (!tableSource.isInstanceOf[LookupableTableSource[_]]) {
      throw new TableException("TableSource must implement LookupableTableSource interface " +
        "if it is used as a temporal table.")
    }

    val checkedLookupKeys = joinedIndex.get.getDefinedColumns

    val lookupKeyPairs = joinKeyPairs.filter(p => checkedLookupKeys.contains(p.target))
    val leftKeys = lookupKeyPairs.map(_.source).toArray
    val rightKeys = lookupKeyPairs.map(_.target) ++ constantLookupKeys.keys
    val leftKeyTypes = leftKeys.map(inputSchema.fieldTypeInfos(_))
    // use original keyPair to validate key types (rightKeys may include constant keys)
    val rightKeyTypes = lookupKeyPairs.map(p => tableSourceSchema.fieldTypeInfos(p.target))

    // check type
    leftKeyTypes.zip(rightKeyTypes).foreach(f => {
      if (f._1 != f._2) {
        val leftNames = leftKeys.map(inputSchema.fieldNames(_))
        val rightNames = rightKeys.map(tableSourceSchema.fieldNames(_))

        val leftNameTypes = leftKeyTypes
          .zip(leftNames)
          .map(f => s"${f._2}[${f._1.toString}]")

        val rightNameTypes = rightKeyTypes
          .zip(rightNames)
          .map(f => s"${f._2}[${f._1.toString}]")

        val condition = leftNameTypes
          .zip(rightNameTypes)
          .map(f => s"${f._1}=${f._2}")
          .mkString(", ")
        throw new TableException("Join: Equality join predicate on incompatible types. " +
          s"And the condition is $condition")
      }
    })

    if (joinType != JoinRelType.LEFT && joinType != JoinRelType.INNER) {
      throw new TableException(
        "Temporal table join currently only support INNER JOIN and LEFT JOIN, " +
          "but was " + joinType.toString + " JOIN")
    }

    val tableReturnType = TypeConverters.createExternalTypeInfoFromDataType(
      tableSource.getReturnType)
    if (!tableReturnType.isInstanceOf[BaseRowTypeInfo] &&
      !tableReturnType.isInstanceOf[RowTypeInfo]) {
      throw new TableException(
        "Temporal table join only support Row or BaseRow type as return type of temporal table." +
          " But was " + tableReturnType)
    }

    // period specification check
    period.getType match {
      case t: TimeIndicatorRelDataType if !t.isEventTime => // ok
      case _ =>
        throw new TableException(
          "Currently only support join temporal table as of on left table's proctime field")
    }
    period match {
      case r: RexFieldAccess if r.getReferenceExpr.isInstanceOf[RexCorrelVariable] =>
      // it's left table's field, ok
      case call: RexCall if call.getOperator == ScalarSqlFunctions.PROCTIME =>
      // it is PROCTIME() call, ok
      case _ =>
        throw new TableException(
          "Currently only support join temporal table as of on left table's proctime field.")
    }

    // success
  }

  /**
   * Checks that the lookup function's result type matches the table source's return type and
   * is Row or BaseRow. When the function declares no result type, the type extracted from its
   * class is checked instead.
   *
   * @throws TableException if the types do not match or are not Row / BaseRow
   */
  def checkUdtfReturnType(
      tableDesc: String,
      tableReturnTypeInfo: TypeInformation[_],
      udtfReturnType: DataType,
      extractedUdtfReturnTypeInfo: TypeInformation[_]): Unit = {
    if (udtfReturnType == null) {
      if (!rowTypeEquals(tableReturnTypeInfo, extractedUdtfReturnTypeInfo)) {
        throw new TableException(
          s"The TableSource [$tableDesc] return type $tableReturnTypeInfo " +
            s"do not match its lookup function extracted return type $extractedUdtfReturnTypeInfo")
      }
      if (extractedUdtfReturnTypeInfo.getTypeClass != classOf[BaseRow] &&
        extractedUdtfReturnTypeInfo.getTypeClass != classOf[Row]) {
        throw new TableException(
          "Result type of the async lookup TableFunction of TableSource " +
            s"'$tableDesc' is " +
            s"$extractedUdtfReturnTypeInfo type, " +
            s"currently only Row and BaseRow are supported.")
      }
    } else {
      val udtfReturnTypeInfo = TypeConverters.createExternalTypeInfoFromDataType(udtfReturnType)
      if (!rowTypeEquals(tableReturnTypeInfo, udtfReturnTypeInfo)) {
        throw new TableException(
          s"The TableSource [$tableDesc] return type $tableReturnTypeInfo " +
            s"do not match its lookup function return type $udtfReturnTypeInfo")
      }
      if (!udtfReturnTypeInfo.isInstanceOf[BaseRowTypeInfo] &&
        !udtfReturnTypeInfo.isInstanceOf[RowTypeInfo]) {
        throw new TableException(
          "Result type of the async lookup TableFunction of TableSource " +
            s"'$tableDesc' is $udtfReturnTypeInfo type, " +
            s"currently only Row and BaseRow are supported.")
      }
    }
  }

  // ----------------------------------------------------------------------------------------
  // toString Utilities
  // ----------------------------------------------------------------------------------------

  /** Comma-separated field names of `inputType`. */
  private def joinSelectionToString(inputType: RelDataType): String = {
    inputType.getFieldNames.toList.mkString(", ")
  }

  /** Renders `joinCondition` against the fields of `inputType`; null if there is no condition. */
  private def joinConditionToString(
      inputType: RelDataType,
      joinCondition: RexNode,
      expression: (RexNode, List[String], Option[List[RexNode]]) => String): String = {
    val inFields = inputType.getFieldNames.toList
    if (joinCondition != null) {
      expression(joinCondition, inFields, None)
    } else {
      null
    }
  }

  private def joinTypeToString(joinType: JoinRelType): String = joinType match {
    case JoinRelType.INNER => "InnerJoin"
    case JoinRelType.LEFT => "LeftOuterJoin"
    case JoinRelType.RIGHT => "RightOuterJoin"
    case JoinRelType.FULL => "FullOuterJoin"
  }

  /** Builds the runtime operator name describing this lookup join. */
  private def joinToString(
      tableSource: LookupableTableSource[_],
      joinType: JoinRelType,
      joinResultSchema: BaseRowSchema,
      inputSchema: BaseRowSchema,
      tableSchema: BaseRowSchema,
      joinCondition: Option[RexNode],
      constantLookupKeys: util.Map[Int, (InternalType, Object)],
      lookupKeyPairs: util.List[IntPair],
      expression: (RexNode, List[String], Option[List[RexNode]]) => String): String = {

    val isAsyncEnabled = Option(tableSource.getLookupConfig)
      .getOrElse(new LookupConfig)
      .isAsyncEnabled
    val prefix = if (isAsyncEnabled) "AsyncJoinTable" else "JoinTable"

    var str = s"$prefix(table: (${tableSource.explainSource()})" +
      s", joinType: ${joinTypeToString(joinType)}" +
      s", join: (${joinSelectionToString(inputSchema.relDataType)}), "

    val inputFieldNames = inputSchema.fieldNames
    val tableFieldNames = tableSchema.fieldNames
    val keyPairNames = lookupKeyPairs.map { p =>
      s"${inputFieldNames(p.source)}=${
        if (p.target > -1) tableFieldNames(p.target) else -1
      }"
    }
    val constantKeyNames = constantLookupKeys.map(k => tableFieldNames(k._1) + " = " + k._2)
    // join key pairs and constant keys share one comma-separated list
    str += s" on: (${(keyPairNames ++ constantKeyNames).mkString(", ")})"
    if (joinCondition.isDefined) {
      val joinConditionString = joinConditionToString(
        joinResultSchema.relDataType,
        joinCondition.get,
        expression)
      str += s", where: ($joinConditionString)"
    }
    str += ")"
    str
  }

  /** Adds the explain items (join, source, on, joinType, where, joinCondition, period). */
  def joinExplainTerms(
      pw: RelWriter,
      tableSource: TableSource,
      inputType: RelDataType,
      joinResultType: RelDataType,
      calcProgram: Option[RexProgram],
      lookupKeyPairs: util.List[IntPair],
      joinCondition: Option[RexNode],
      joinType: JoinRelType,
      period: RexNode,
      expression: (RexNode, List[String], Option[List[RexNode]]) => String): RelWriter = {

    val condition: String = calcProgram
      .map(CalcUtil.conditionToString(_, expression))
      .getOrElse("")

    var source = tableSource.explainSource()
    if (source == null || source.isEmpty) {
      source = TableConnectorUtil.generateRuntimeName(
        tableSource.getClass, tableSource.getTableSchema.getColumnNames)
    }

    val inputFieldNames = inputType.getFieldNames
    val tableFieldNames = tableSource.getTableSchema.getColumnNames
    val keyPairNames = lookupKeyPairs.map { p =>
      s"${inputFieldNames(p.source)}=${
        if (p.target >= 0 && p.target < tableFieldNames.length) tableFieldNames(p.target) else -1
      }"
    }

    pw.item("join", joinSelectionToString(joinResultType))
      .item("source", source)
      .item("on", keyPairNames.mkString(", "))
      .item("joinType", joinTypeToString(joinType))
      .itemIf("where", condition, !condition.isEmpty)
      .itemIf("joinCondition",
        joinConditionToString(joinResultType, joinCondition.orNull, expression),
        joinCondition.isDefined)
      .item("period", period)
  }
}