/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.functions.utils
import java.lang.reflect.{Method, Modifier}
import java.lang.{Integer => JInt, Long => JLong}
import java.sql.{Date, Time, Timestamp}
import com.google.common.primitives.Primitives
import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFactory}
import org.apache.calcite.rex.{RexLiteral, RexNode}
import org.apache.calcite.sql.`type`.SqlTypeName
import org.apache.calcite.sql.{SqlFunction, SqlOperatorBinding}
import org.apache.commons.codec.binary.Base64
import org.apache.flink.api.common.functions.InvalidTypesException
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils._
import org.apache.flink.table.api.functions._
import org.apache.flink.table.api.scala._
import org.apache.flink.table.api.{TableEnvironment, TableException, ValidationException}
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.expressions._
import org.apache.flink.table.plan.logical._
import org.apache.flink.table.plan.schema.DeferredTypeFlinkTableFunction
import org.apache.flink.table.dataformat.{BaseRow, BinaryString, Decimal}
import org.apache.flink.table.errorcode.TableErrors
import org.apache.flink.table.hive.functions._
import org.apache.flink.table.api.types._
import org.apache.flink.table.typeutils.TypeUtils
import org.apache.flink.types.Row
import org.apache.flink.util.InstantiationUtil
import org.apache.hadoop.hive.ql.exec.{UDAF, UDF}
import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFResolver2, GenericUDF, GenericUDTF}
import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.language.postfixOps
object UserDefinedFunctionUtils {
/**
* Checks if a user-defined function can be easily instantiated.
*/
def checkForInstantiation(clazz: Class[_]): Unit = {
if (!InstantiationUtil.isPublic(clazz)) {
throw new ValidationException(s"Function class ${clazz.getCanonicalName} is not public.")
}
else if (!InstantiationUtil.isProperClass(clazz)) {
throw new ValidationException(
s"Function class ${clazz.getCanonicalName} is no proper class," +
" it is either abstract, an interface, or a primitive type.")
}
else if (InstantiationUtil.isNonStaticInnerClass(clazz)) {
throw new ValidationException(
s"The class ${clazz.getCanonicalName} is an inner class, but" +
" not statically accessible.")
}
}
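/*
 * Illustrative sketch (hypothetical classes, not part of this file) of what the check
 * above accepts and rejects:
 *
 *   class Outer {
 *     class InnerFunc extends ScalarFunction {         // non-static inner class: rejected
 *       def eval(i: Int): Int = i
 *     }
 *   }
 *   abstract class AbstractFunc extends ScalarFunction // abstract: rejected
 *   class GoodFunc extends ScalarFunction {            // public top-level class: accepted
 *     def eval(i: Int): Int = i + 1
 *   }
 *   checkForInstantiation(classOf[GoodFunc])           // passes without exception
 */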
/**
* Checks whether the given class is a Scala object. It is forbidden to use a [[TableFunction]]
* implemented by a Scala object, due to concurrency risks.
*/
def checkNotSingleton(clazz: Class[_]): Unit = {
// TODO this is not a robust way to detect a singleton. Maybe improve it further.
if (clazz.getFields.map(_.getName) contains "MODULE$") {
throw new ValidationException(
s"TableFunction implemented by class ${clazz.getCanonicalName} " +
s"is a Scala object, it is forbidden since concurrent risks.")
}
}
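/*
 * Illustrative sketch (hypothetical definitions, not part of this file): a Scala `object`
 * compiles to a class exposing a static `MODULE$` field, which is what the check above
 * detects.
 *
 *   object SplitObj extends TableFunction[String] {
 *     def eval(s: String): Unit = s.split(",").foreach(collect)
 *   }
 *   checkNotSingleton(SplitObj.getClass)      // throws ValidationException
 *
 *   class SplitClass extends TableFunction[String] {
 *     def eval(s: String): Unit = s.split(",").foreach(collect)
 *   }
 *   checkNotSingleton(classOf[SplitClass])    // passes
 */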
// ----------------------------------------------------------------------------------------------
// Utilities for user-defined methods
// ----------------------------------------------------------------------------------------------
def throwValidationException(
name: String,
func: UserDefinedFunction,
parameters: Array[InternalType]): Method = {
throw new ValidationException(
s"Given parameters of function '$name' do not match any signature. \n" +
s"Actual: ${signatureToString(parameters)} \n" +
s"Expected: ${signaturesToString(func, "eval")}")
}
private def getParamClassesConsiderVarArgs(
isVarArgs: Boolean,
matchingSignature: Array[Class[_]],
expectedLength: Int): Array[Class[_]] = {
var paramClasses = new mutable.ArrayBuffer[Class[_]]
for (i <- 0 until expectedLength) {
if (i < matchingSignature.length - 1) {
paramClasses += matchingSignature(i)
} else if (isVarArgs) {
paramClasses += matchingSignature.last.getComponentType
} else {
// last argument is not an array type
paramClasses += matchingSignature.last
}
}
paramClasses.toArray
}
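/*
 * Example of the varargs expansion above (illustrative values): a varargs signature such
 * as eval(String, Int*) compiles to the parameter classes (String, Array[Int]); expanding
 * it to three expected arguments repeats the component type of the trailing array.
 *
 *   getParamClassesConsiderVarArgs(
 *     isVarArgs = true,
 *     matchingSignature = Array(classOf[String], classOf[Array[Int]]),
 *     expectedLength = 3)
 *   // -> Array(classOf[String], classOf[Int], classOf[Int])
 */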
def getEvalMethodSignature(
func: CustomTypeDefinedFunction,
expectedTypes: Array[InternalType]): Array[Class[_]] = {
val method = getEvalUserDefinedMethod(func, expectedTypes).getOrElse(
throwValidationException(func.getClass.getCanonicalName, func, expectedTypes)
)
getParamClassesConsiderVarArgs(method.isVarArgs, method.getParameterTypes, expectedTypes.length)
}
def getAggUserDefinedInputTypes(
func: AggregateFunction[_, _],
externalAccType: DataType,
expectedTypes: Array[InternalType]): Array[DataType] = {
val accMethod = getAggFunctionUDIMethod(
func, "accumulate", externalAccType, expectedTypes).getOrElse(
throwValidationException(func.getClass.getCanonicalName, func, expectedTypes)
)
func.getUserDefinedInputTypes(
getParamClassesConsiderVarArgs(
accMethod.isVarArgs,
// drop the first parameter: it must be the accumulator.
accMethod.getParameterTypes.drop(1),
expectedTypes.length)).zipWithIndex.map {
case (t, i) =>
// we don't trust GenericType; fall back to the expected type.
if (TypeUtils.isGeneric(t)) expectedTypes(i) else t
}
}
/**
* Returns signatures of accumulate methods matching the given signature of [[InternalType]].
* Elements of the signature can be null (acting as wildcards).
*/
def getAccumulateMethodSignature(
function: AggregateFunction[_, _],
expectedTypes: Seq[InternalType])
: Option[Array[Class[_]]] = {
getAggFunctionUDIMethod(
function,
"accumulate",
getAccumulatorTypeOfAggregateFunction(function),
expectedTypes
).map(_.getParameterTypes)
}
def getParameterTypes(
function: UserDefinedFunction,
signature: Array[Class[_]]): Array[InternalType] = {
signature.map { c =>
try {
TypeExtractor.getForClass(c).toInternalType
} catch {
case ite: InvalidTypesException =>
throw new ValidationException(
s"Parameter types of function '${function.getClass.getCanonicalName}' cannot be " +
s"automatically determined. Please provide type information manually.")
}
}
}
def getEvalUserDefinedMethod(
function: CustomTypeDefinedFunction,
expectedTypes: Seq[InternalType])
: Option[Method] = {
getUserDefinedMethod(
function,
"eval",
typesToClasses(expectedTypes),
expectedTypes.toArray,
(paraClasses) => function.getParameterTypes(paraClasses))
}
def getAggFunctionUDIMethod(
function: AggregateFunction[_, _],
methodName: String,
accType: DataType,
expectedTypes: Seq[InternalType])
: Option[Method] = {
val input = (Array(accType) ++ expectedTypes).toSeq
getUserDefinedMethod(
function,
methodName,
typesToClasses(input),
input.map(_.toInternalType).toArray,
(cls) => Array(accType) ++
function.getUserDefinedInputTypes(cls.drop(1)))
}
/**
* Gets the method without matching against the DataType.
*/
def getUserDefinedMethod(
function: UserDefinedFunction,
methodName: String,
signature: Seq[DataType])
: Option[Method] = {
getUserDefinedMethod(
function,
methodName,
typesToClasses(signature),
signature.map(_.toInternalType).toArray,
(cls) => cls.indices.map((_) => null).toArray)
}
/**
* Returns user defined method matching the given name and signature.
*
* @param function function instance
* @param methodName method name
* @param methodSignature an array of raw Java classes. We compare the raw Java classes, not the
* DataType. The DataType does not matter at runtime (e.g.
* within a MapFunction).
* @param internalTypes the internal types of the actual arguments
* @param parameterTypes a function mapping a candidate method's parameter classes to the
* expected DataTypes
*/
def getUserDefinedMethod(
function: UserDefinedFunction,
methodName: String,
methodSignature: Array[Class[_]],
internalTypes: Array[InternalType],
parameterTypes: (Array[Class[_]] => Array[DataType]))
: Option[Method] = {
val methods = checkAndExtractMethods(function, methodName)
var applyCnt = 0
val filtered = methods
// go over all the methods and filter out matching methods
.filter {
case cur if !cur.isVarArgs =>
val signatures = cur.getParameterTypes
val dataTypes = parameterTypes(signatures)
// match parameters of signature to actual parameters
methodSignature.length == signatures.length &&
signatures.zipWithIndex.forall { case (clazz, i) => {
if (methodSignature(i) == classOf[Object]) {
// This element of the method signature comes from the Table API's apply().
// We cannot decide the type here; it is an unresolved expression.
// Actually, we do not have to decide the type here: any overloaded method
// that matches the argument count will do the job,
// so choosing any of them is correct.
applyCnt += 1
}
parameterTypeEquals(methodSignature(i), clazz) ||
parameterDataTypeEquals(internalTypes(i), dataTypes(i))
}
}
case cur if cur.isVarArgs =>
val signatures = cur.getParameterTypes
val dataTypes = parameterTypes(signatures)
methodSignature.zipWithIndex.forall {
// non-varargs
case (clazz, i) if i < signatures.length - 1 =>
parameterTypeEquals(clazz, signatures(i)) ||
parameterDataTypeEquals(internalTypes(i), dataTypes(i))
// varargs
case (clazz, i) if i >= signatures.length - 1 =>
parameterTypeEquals(clazz, signatures.last.getComponentType) ||
parameterDataTypeEquals(internalTypes(i), dataTypes(i))
} || (methodSignature.isEmpty && signatures.length == 1) // empty varargs
}
// if there is a fixed-arity method, it is preferred over a varargs variant
val fixedMethodsCount = filtered.count(!_.isVarArgs)
val found = filtered.filter { cur =>
fixedMethodsCount > 0 && !cur.isVarArgs ||
fixedMethodsCount == 0 && cur.isVarArgs
}.filter { cur =>
// filter out bridge methods (for methods, the VOLATILE modifier bit marks a bridge method)
!Modifier.isVolatile(cur.getModifiers)
}
// check if a method uses Scala-style variable arguments (scala.collection.Seq) without
// the @varargs annotation
if (found.isEmpty &&
methods.exists { method =>
val signatures = method.getParameterTypes
val dataTypes = parameterTypes(signatures)
if (!method.isVarArgs && signatures.length != methodSignature.length) {
false
} else if (method.isVarArgs && signatures.length > methodSignature.length + 1) {
false
} else {
signatures.zipWithIndex.forall {
case (clazz, i) if i < signatures.length - 1 =>
parameterTypeEquals(methodSignature(i), clazz) ||
parameterDataTypeEquals(internalTypes(i), dataTypes(i))
case (clazz, i) if i == signatures.length - 1 =>
clazz.getName.equals("scala.collection.Seq")
}
}
}) {
throw new ValidationException(
s"Scala-style variable arguments in '$methodName' methods are not supported. Please " +
s"add a @scala.annotation.varargs annotation.")
} else if (found.length > 1) {
if (applyCnt > 0) {
// As we cannot decide the type when apply() is involved, choosing any matching method is correct
return found.headOption
}
throw new ValidationException(
s"Found multiple '$methodName' methods which match the signature.")
}
found.headOption
}
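/*
 * Illustrative sketch of the overload resolution above (hypothetical function; the
 * DataTypes.STRING handle is assumed to denote the String DataType):
 *
 *   class MyConcat extends ScalarFunction {
 *     def eval(a: String, b: String): String = a + b   // fixed arity
 *     @scala.annotation.varargs
 *     def eval(args: String*): String = args.mkString  // varargs
 *   }
 *   val func = new MyConcat
 *   // two String arguments: both methods match, the fixed-arity one is preferred
 *   getUserDefinedMethod(func, "eval", Seq(DataTypes.STRING, DataTypes.STRING))
 *   // three String arguments: only the varargs method matches
 *   getUserDefinedMethod(func, "eval", Seq(DataTypes.STRING, DataTypes.STRING, DataTypes.STRING))
 */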
/**
* Checks if a method with the given name exists in the given function.
*/
def ifMethodExistInFunction(method: String, function: UserDefinedFunction): Boolean = {
val methods = function
.getClass
.getMethods
.filter {
m => m.getName == method
}
!methods.isEmpty
}
/**
* Extracts methods and throws a [[ValidationException]] if no implementation
* can be found, or the implementation does not match the requirements.
*/
def checkAndExtractMethods(
function: UserDefinedFunction,
methodName: String): Array[Method] = {
val methods = function
.getClass
.getMethods
.filter { m =>
val modifiers = m.getModifiers
m.getName == methodName &&
Modifier.isPublic(modifiers) &&
!Modifier.isAbstract(modifiers) &&
!(function.isInstanceOf[TableFunction[_]] && Modifier.isStatic(modifiers))
}
if (methods.isEmpty) {
throw new ValidationException(
s"Function class '${function.getClass.getCanonicalName}' does not implement at least " +
s"one method named '$methodName' which is public, not abstract and " +
s"(in case of table functions) not static.")
}
methods
}
def getMethodSignatures(
function: UserDefinedFunction,
methodName: String): Array[Array[Class[_]]] = {
checkAndExtractMethods(function, methodName).map(_.getParameterTypes)
}
// ----------------------------------------------------------------------------------------------
// Utilities for SQL functions
// ----------------------------------------------------------------------------------------------
/**
* Create [[SqlFunction]] for a [[ScalarFunction]]
*
* @param name function name
* @param function scalar function
* @param typeFactory type factory
* @return the ScalarSqlFunction
*/
def createScalarSqlFunction(
name: String,
displayName: String,
function: ScalarFunction,
typeFactory: FlinkTypeFactory)
: SqlFunction = {
new ScalarSqlFunction(name, displayName, function, typeFactory)
}
/**
* Create [[SqlFunction]] for a [[TableFunction]]
*
* @param name function name
* @param tableFunction table function
* @param implicitResultType the implicitly determined result type of the returned table
* @param typeFactory type factory
* @return the TableSqlFunction
*/
def createTableSqlFunction(
name: String,
displayName: String,
tableFunction: TableFunction[_],
implicitResultType: DataType,
typeFactory: FlinkTypeFactory)
: TableSqlFunction = {
// we don't know the exact result type yet.
val function = new DeferredTypeFlinkTableFunction(tableFunction, implicitResultType)
new TableSqlFunction(name, displayName, tableFunction, implicitResultType,
typeFactory, function)
}
/**
* Create [[SqlFunction]] for an [[AggregateFunction]]
*
* @param name function name
* @param aggFunction aggregate function
* @param typeFactory type factory
* @return the AggSqlFunction
*/
def createAggregateSqlFunction(
name: String,
displayName: String,
aggFunction: AggregateFunction[_, _],
externalResultType: DataType,
externalAccType: DataType,
typeFactory: FlinkTypeFactory)
: SqlFunction = {
// check that a qualified accumulate method exists before creating the SQL function
checkAndExtractMethods(aggFunction, "accumulate")
AggSqlFunction(
name,
displayName,
aggFunction,
externalResultType,
externalAccType,
typeFactory,
aggFunction.requiresOver)
}
// ----------------------------------------------------------------------------------------------
// Utilities for user-defined functions
// ----------------------------------------------------------------------------------------------
/**
* Tries to infer the DataType of an AggregateFunction's return type.
*
* @param aggregateFunction The AggregateFunction for which the return type is inferred.
* @param extractedType The implicitly extracted result type.
*
* @return The inferred result type of the AggregateFunction.
*/
def getResultTypeOfAggregateFunction(
aggregateFunction: AggregateFunction[_, _],
extractedType: DataType = null)
: DataType = {
val resultType = aggregateFunction.getResultType
if (resultType != null) {
resultType
} else if (extractedType != null) {
extractedType
} else {
try {
extractTypeFromAggregateFunction(aggregateFunction, 0)
} catch {
case ite: InvalidTypesException =>
throw new TableException(
"Cannot infer generic type of ${aggregateFunction.getClass}. " +
"You can override AggregateFunction.getResultType() to specify the type.",
ite
)
}
}
}
/**
* Tries to infer the DataType of an AggregateFunction's accumulator type.
*
* @param aggregateFunction The AggregateFunction for which the accumulator type is inferred.
* @param extractedType The implicitly extracted accumulator type.
*
* @return The inferred accumulator type of the AggregateFunction.
*/
def getAccumulatorTypeOfAggregateFunction(
aggregateFunction: AggregateFunction[_, _],
extractedType: DataType = null): DataType = {
val accType = aggregateFunction.getAccumulatorType
if (accType != null) {
accType
} else if (extractedType != null) {
extractedType
} else {
try {
extractTypeFromAggregateFunction(aggregateFunction, 1)
} catch {
case ite: InvalidTypesException =>
throw new TableException(
"Cannot infer generic type of ${aggregateFunction.getClass}. " +
"You can override AggregateFunction.getAccumulatorType() to specify the type.",
ite
)
}
}
}
/**
* Internal method to extract a type from an AggregateFunction's type parameters.
*
* @param aggregateFunction The AggregateFunction for which the type is extracted.
* @param parameterTypePos The position of the type parameter for which the type is extracted.
*
* @return The extracted type.
*/
@throws(classOf[InvalidTypesException])
private def extractTypeFromAggregateFunction(
aggregateFunction: AggregateFunction[_, _],
parameterTypePos: Int): DataType = {
new TypeInfoWrappedDataType(TypeExtractor.createTypeInfo(
aggregateFunction,
classOf[AggregateFunction[_, _]],
aggregateFunction.getClass,
parameterTypePos))
}
def getResultTypeOfScalarFunction(
function: ScalarFunction,
arguments: Array[AnyRef],
argTypes: Array[InternalType]): DataType = {
val userDefinedTypeInfo = function.getResultType(
arguments, getEvalMethodSignature(function, argTypes))
if (userDefinedTypeInfo != null) {
userDefinedTypeInfo
} else {
extractTypeFromScalarFunc(function, argTypes)
}
}
private[flink] def extractTypeFromScalarFunc(
function: ScalarFunction,
argTypes: Array[InternalType]): DataType = {
try {
TypeExtractor.getForClass(
getResultTypeClassOfScalarFunction(function, argTypes))
} catch {
case _: InvalidTypesException =>
throw new ValidationException(
s"Return type of scalar function '${function.getClass.getCanonicalName}' cannot be " +
s"automatically determined. Please provide type information manually.")
}
}
/**
* Returns the return type of the evaluation method matching the given signature.
*/
def getResultTypeClassOfScalarFunction(
function: ScalarFunction,
argTypes: Array[InternalType]): Class[_] = {
// find method for signature
getEvalUserDefinedMethod(function, argTypes).getOrElse(
throw new IllegalArgumentException("Given signature is invalid.")).getReturnType
}
/**
* Returns the Java result class corresponding to the given return type of a Python scalar function.
*/
def getResultTypeClassOfPythonScalarFunction(returnType: InternalType): Class[_] = {
returnType match {
case _: StringType => classOf[org.apache.flink.table.dataformat.BinaryString]
case _: BooleanType => classOf[java.lang.Boolean]
case _: ByteType => classOf[java.lang.Byte]
case _: ShortType => classOf[java.lang.Short]
case _: IntType => classOf[java.lang.Integer]
case _: LongType => classOf[java.lang.Long]
case _: FloatType => classOf[java.lang.Float]
case _: DoubleType => classOf[java.lang.Double]
case _: DecimalType => classOf[java.math.BigDecimal]
case _: DateType => classOf[java.lang.Integer]
case _: TimeType => classOf[java.lang.Integer]
case _: TimestampType => classOf[java.lang.Long]
case _: ByteArrayType => classOf[Array[Byte]]
}
}
// ----------------------------------------------------------------------------------------------
// Miscellaneous
// ----------------------------------------------------------------------------------------------
/**
* Returns field names and field positions for a given [[DataType]].
*
* Field names are automatically extracted for [[RowType]].
*
* @param inputType The DataType to extract the field names and positions from.
* @return A tuple of three arrays holding the field names, the corresponding field positions,
* and the field types.
*/
def getFieldInfo(inputType: DataType)
: (Array[String], Array[Int], Array[InternalType]) = {
(
TableEnvironment.getFieldNames(inputType),
TableEnvironment.getFieldIndices(inputType),
TableEnvironment.getFieldTypes(inputType))
}
/**
* Formats one signature consisting of classes.
*/
def signatureToString(signature: Array[Class[_]]): String =
signature.map { clazz =>
if (clazz == null) {
"null"
} else {
clazz.getCanonicalName
}
}.mkString("(", ", ", ")")
/**
* Formats one signature consisting of DataTypes.
*/
def signatureToString(signature: Seq[DataType]): String = {
signatureToString(typesToClasses(signature))
}
/**
* Formats all signatures of methods with the given name in a class.
*/
def signaturesToString(function: UserDefinedFunction, name: String): String = {
getMethodSignatures(function, name).map(signatureToString).mkString(", ")
}
/**
* Extracts type classes of [[DataType]] in a null-aware way.
*/
def typesToClasses(types: Seq[DataType]): Array[Class[_]] =
types.map { t =>
if (t == null) {
null
} else {
TypeUtils.getExternalClassForType(t)
}
}.toArray
/**
* Compares a candidate parameter class with the expected class; returns true if they match.
* The candidate can be null (acting as a wildcard).
*/
private def parameterTypeEquals(candidate: Class[_], expected: Class[_]): Boolean =
candidate == null ||
candidate == expected ||
expected == classOf[Object] ||
candidate == classOf[Object] || // Special case when we don't know the type
expected.isPrimitive && Primitives.wrap(expected) == candidate ||
candidate == classOf[Date] && (expected == classOf[Int] || expected == classOf[JInt]) ||
candidate == classOf[Time] && (expected == classOf[Int] || expected == classOf[JInt]) ||
candidate == classOf[Timestamp] && (expected == classOf[Long] || expected == classOf[JLong]) ||
candidate == classOf[BinaryString] && expected == classOf[String] ||
candidate == classOf[String] && expected == classOf[BinaryString] ||
classOf[BaseRow].isAssignableFrom(candidate) && expected == classOf[Row] ||
candidate == classOf[Row] && classOf[BaseRow].isAssignableFrom(expected) ||
candidate == classOf[Decimal] && expected == classOf[BigDecimal] ||
candidate == classOf[BigDecimal] && expected == classOf[Decimal] ||
(candidate.isArray &&
expected.isArray &&
candidate.getComponentType.isInstanceOf[Object] &&
expected.getComponentType == classOf[Object])
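/*
 * A few concrete cases decided by the equivalence check above; internal binary formats
 * are treated as interchangeable with their external counterparts:
 *
 *   parameterTypeEquals(null, classOf[String])                     // true (wildcard)
 *   parameterTypeEquals(classOf[java.lang.Integer], classOf[Int])  // true (boxing)
 *   parameterTypeEquals(classOf[BinaryString], classOf[String])    // true
 *   parameterTypeEquals(classOf[Timestamp], classOf[Long])         // true
 *   parameterTypeEquals(classOf[String], classOf[Int])             // false
 */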
private def parameterDataTypeEquals(internal: InternalType, parameterType: DataType): Boolean =
// GenericType has special equality semantics; type extraction would need to be rewritten
// for BaseRow etc.
parameterType.toInternalType == internal || internal.getTypeClass == parameterType.getTypeClass
@throws[Exception]
def serialize(function: UserDefinedFunction): String = {
val byteArray = InstantiationUtil.serializeObject(function)
Base64.encodeBase64URLSafeString(byteArray)
}
@throws[Exception]
def deserialize(data: String): UserDefinedFunction = {
val byteData = Base64.decodeBase64(data)
InstantiationUtil
.deserializeObject[UserDefinedFunction](byteData, Thread.currentThread.getContextClassLoader)
}
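/*
 * Round-trip sketch of the Base64 (de)serialization helpers above (udf stands for any
 * serializable UserDefinedFunction instance):
 *
 *   val encoded: String = serialize(udf)              // URL-safe Base64 string
 *   val restored: UserDefinedFunction = deserialize(encoded)
 */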
/**
* Creates a [[LogicalTableFunctionCall]] by parsing a String expression.
*
* @param tableEnv The table environment to look up the function.
* @param udtf a String expression of a TableFunctionCall, such as "split(c)"
* @return A LogicalTableFunctionCall.
*/
def createLogicalFunctionCall(
tableEnv: TableEnvironment,
udtf: String): LogicalTableFunctionCall = {
var alias: Option[Seq[String]] = None
// unwrap an Expression until we get a TableFunctionCall
def unwrap(expr: Expression): TableFunctionCall = expr match {
case Alias(child, name, extraNames) =>
alias = Some(Seq(name) ++ extraNames)
unwrap(child)
case Call(name, args) =>
val function = tableEnv.chainedFunctionCatalog.lookupFunction(name, args)
unwrap(function)
case c: TableFunctionCall => c
case _ =>
throw new TableException(
"Table(TableEnv, String) constructor only accept String that " +
"define table function followed by some Alias.")
}
val functionCall: LogicalTableFunctionCall = unwrap(ExpressionParser.parseExpression(udtf))
.as(alias).toLogicalTableFunctionCall(child = null)
functionCall
}
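/*
 * Usage sketch (assumes a table function named "split" is registered in the given
 * TableEnvironment): the expression is parsed, "split" is looked up in the function
 * catalog, and the call is wrapped into a LogicalTableFunctionCall whose output fields
 * are aliased to (word, pos).
 *
 *   val call = createLogicalFunctionCall(tableEnv, "split(c) as (word, pos)")
 */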
def getOperandType(callBinding: SqlOperatorBinding): Seq[InternalType] = {
val operandTypes = for (i <- 0 until callBinding.getOperandCount)
yield callBinding.getOperandType(i)
operandTypes.map { operandType =>
if (operandType.getSqlTypeName == SqlTypeName.NULL) {
null
} else {
FlinkTypeFactory.toInternalType(operandType)
}
}
}
private[table] def getResultTypeOfCTDFunction(
func: CustomTypeDefinedFunction,
params: Array[Expression],
getImplicitResultType: () => DataType): DataType = {
val arguments = params.map {
case exp: Literal =>
exp.value.asInstanceOf[AnyRef]
case _ =>
null
}
val signature = params.map { param =>
if (param.valid) param.resultType
else new GenericType(new Object().getClass) // apply() can not decide type
}
val argTypes = getEvalMethodSignature(func, signature)
val udt = func.getResultType(arguments, argTypes)
if (udt != null) udt else getImplicitResultType()
}
/**
* Transforms the RexNodes into Objects.
* Only literal RexNodes are transformed; non-literal RexNodes are
* translated to nulls.
*
* @param rexNodes actual parameters of the function
* @return An array of the resulting Objects
*/
private[table] def transformRexNodes(
rexNodes: java.util.List[RexNode]): Array[AnyRef] = {
rexNodes.map {
case rexNode: RexLiteral =>
val value = rexNode.getValue2
rexNode.getType.getSqlTypeName match {
case SqlTypeName.INTEGER =>
value.asInstanceOf[Long].toInt.asInstanceOf[AnyRef]
case SqlTypeName.SMALLINT =>
value.asInstanceOf[Long].toShort.asInstanceOf[AnyRef]
case SqlTypeName.TINYINT =>
value.asInstanceOf[Long].toByte.asInstanceOf[AnyRef]
case SqlTypeName.FLOAT =>
value.asInstanceOf[Double].toFloat.asInstanceOf[AnyRef]
case SqlTypeName.REAL =>
value.asInstanceOf[Double].toFloat.asInstanceOf[AnyRef]
case _ =>
value.asInstanceOf[AnyRef]
}
case _ =>
null
}.toArray
}
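/*
 * Example of the literal narrowing above (hypothetical sketch, assuming a Calcite
 * RexBuilder and a typeFactory in scope): RexLiteral.getValue2 exposes exact numeric
 * literals as Long/Double, so they are narrowed back to their SQL type here, while
 * non-literal operands become null.
 *
 *   val rexBuilder = new org.apache.calcite.rex.RexBuilder(typeFactory)
 *   val intLiteral = rexBuilder.makeExactLiteral(
 *     java.math.BigDecimal.valueOf(42), typeFactory.createSqlType(SqlTypeName.INTEGER))
 *   val inputRef = rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.INTEGER), 0)
 *   transformRexNodes(java.util.Arrays.asList[RexNode](intLiteral, inputRef))
 *   // -> Array(Integer.valueOf(42), null)
 */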
private[table] def buildRelDataType(
typeFactory: RelDataTypeFactory,
resultType: InternalType,
fieldNames: Array[String],
fieldIndexes: Array[Int]): RelDataType = {
if (fieldIndexes.length != fieldNames.length) {
throw new TableException(
"Number of field indexes and field names must be equal.")
}
// check uniqueness of field names
if (fieldNames.length != fieldNames.toSet.size) {
throw new TableException(
"Table field names must be unique.")
}
val fieldTypes: Array[InternalType] =
resultType match {
case bt: RowType =>
if (fieldNames.length != bt.getArity) {
throw new TableException(
s"Arity of type (" + bt.getFieldNames.deep + ") " +
"not equal to number of field names " + fieldNames.deep + ".")
}
fieldIndexes.map(i => bt.getInternalTypeAt(i).toInternalType)
case _ =>
if (fieldIndexes.length != 1 || fieldIndexes(0) != 0) {
throw new TableException(
"Non-composite input type may have only a single field and its index must be 0.")
}
Array(resultType)
}
val flinkTypeFactory = typeFactory.asInstanceOf[FlinkTypeFactory]
val builder = flinkTypeFactory.builder
fieldNames
.zip(fieldTypes)
.foreach { f =>
builder.add(f._1, flinkTypeFactory.createTypeFromInternalType(f._2, isNullable = true))
}
builder.build
}
def createUserDefinedFunction(
classLoader: ClassLoader,
funcName: String,
functionDef: String): Any = {
var func: Any = Nil
val javaClass = functionDef.contains(".")
if (javaClass) {
try {
func = classLoader.loadClass(functionDef).newInstance()
} catch {
case e: Exception =>
throw new RuntimeException(
TableErrors.INST.sqlCreateUserDefinedFuncError(
funcName,
functionDef,
e.getClass.getCanonicalName + " : " + e.getMessage),
e)
}
} else {
try {
// try to deserialize first
func = InstantiationUtil
.deserializeObject[Object](Base64.decodeBase64(hexString2String(functionDef)),
classLoader)
} catch {
case e: Exception =>
try {
// It might be a Java class without a package name
func = classLoader.loadClass(functionDef).newInstance()
} catch {
case e: Exception =>
throw new RuntimeException(
TableErrors.INST.sqlCreateUserDefinedFuncError(
funcName,
functionDef,
e.getClass.getCanonicalName + " : " + e.getMessage),
e)
}
}
}
// Check for Hive functions and convert them to Flink UDFs
func match {
case _: UDF =>
func = new HiveSimpleUDF(new HiveFunctionWrapper[UDF](functionDef))
case _: GenericUDF =>
func = new HiveGenericUDF(new HiveFunctionWrapper[GenericUDF](functionDef))
case _: UDAF =>
func = new HiveUDAFFunction(new HiveFunctionWrapper[UDAF](functionDef))
case _: GenericUDAFResolver2 =>
func = new HiveUDAFFunction(new HiveFunctionWrapper[GenericUDAFResolver2](functionDef))
case _: GenericUDTF =>
func = new HiveGenericUDTF(new HiveFunctionWrapper[GenericUDTF](functionDef))
case _ =>
}
func
}
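/*
 * Usage sketch of the factory above (hypothetical class name): the function definition
 * may be a fully qualified class name, a class name without a package, or a hex-encoded
 * Base64 string of a serialized function; Hive UDF classes are wrapped into their Flink
 * counterparts.
 *
 *   val cl = Thread.currentThread.getContextClassLoader
 *   val f = createUserDefinedFunction(cl, "myFunc", "com.example.udf.MyScalarFunc")
 */
/**
* Gets the implicit result type information of a table function through reflection,
* falling back to GenericType(classOf[AnyRef]) if it cannot be extracted.
*/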
def getImplicitResultType[T](tf: TableFunction[T]) = {
val implicitResultType = try {
TypeExtractor.createTypeInfo(tf, classOf[TableFunction[_]], tf.getClass, 0)
} catch {
case e: InvalidTypesException =>
// maybe we should get the type from getResultType instead
new GenericType(classOf[AnyRef])
}
implicitResultType
}
/**
* Gets the implicit result type of a table function through reflection. getResultType is
* considered first, then this method; see [[getResultTypeIgnoreException]] for details.
* @return The inferred implicit [[TypeInformation]]; if an [[InvalidTypesException]] is thrown,
* GenericTypeInfo(classOf[AnyRef]) is returned as a fallback
*/
def getImplicitResultTypeInfo[T](tf: TableFunction[T]): TypeInformation[T] = {
val implicitResultType = try {
TypeExtractor.createTypeInfo(tf, classOf[TableFunction[_]], tf.getClass, 0)
} catch {
case _: InvalidTypesException =>
new GenericTypeInfo(classOf[AnyRef])
}
implicitResultType.asInstanceOf[TypeInformation[T]]
}
/**
* Gets the result type from a [[TableFunction]] without throwing any exception.
* @return null if any exception is thrown
*/
def getResultTypeIgnoreException[T](
tf: TableFunction[T],
arguments: Array[AnyRef] = null,
argTypes: Array[Class[_]] = null): DataType = {
try {
tf.getResultType(arguments, argTypes)
} catch {
case _: Exception =>
null
}
}
private def hexString2String(str: String): String = {
str.sliding(2, 2).map(s =>
Integer.parseInt(s, 16).asInstanceOf[Char].toString).mkString("")
}
}