/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.functions.utils

import java.sql.Timestamp
import java.util

import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFactory, RelDataTypeFactoryImpl}
import org.apache.calcite.schema.impl.ReflectiveFunctionBase
import org.apache.calcite.sql._
import org.apache.calcite.sql.`type`.SqlOperandTypeChecker.Consistency
import org.apache.calcite.sql.`type`._
import org.apache.calcite.sql.parser.SqlParserPos
import org.apache.calcite.sql.validate.{SqlUserDefinedTableFunction, SqlUserDefinedTableMacro}
import org.apache.calcite.util.{NlsString, TimestampString}

import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.dataformat.BinaryString
import org.apache.flink.table.plan.schema.FlinkTableFunction
import org.apache.flink.table.api.ValidationException
import org.apache.flink.table.api.functions.TableFunction
import org.apache.flink.table.api.types.DataType
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
import org.apache.flink.table.functions.utils.TableSqlFunction._

import scala.collection.JavaConversions._
import scala.collection.JavaConverters._

/**
* Calcite wrapper for user-defined table functions.
*
* @param name function name (used by the SQL parser)
* @param displayName name used when the function is displayed, e.g. by [[toString]]
* @param udtf user-defined table function to be called
* @param implicitResultType implicit result type information of the table function
* @param typeFactory type factory for converting between Flink's and Calcite's types
* @param functionImpl Calcite table function schema
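*
* A construction sketch (`mySplitUdtf`, `resultType`, `flinkTypeFactory`, and
* `tableFunctionSchema` are hypothetical values assumed to be provided by the
* surrounding planner code):
* {{{
*   val sqlFunction = new TableSqlFunction(
*     "split", "split", mySplitUdtf, resultType, flinkTypeFactory, tableFunctionSchema)
* }}}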
*/
class TableSqlFunction(
name: String,
displayName: String,
udtf: TableFunction[_],
implicitResultType: DataType,
typeFactory: FlinkTypeFactory,
functionImpl: FlinkTableFunction)
extends SqlUserDefinedTableFunction(
new SqlIdentifier(name, SqlParserPos.ZERO),
ReturnTypes.CURSOR,
createOperandTypeInference(name, udtf, typeFactory),
createOperandTypeChecker(name, udtf),
null, // no parameter types are pre-declared; operands are validated by the checker above
functionImpl) {
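// The function is registered with Calcite using the CURSOR return type, as is usual for
// table functions; the concrete row type of the produced table is derived per call in
// getRowType below.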
/**
* Get the user-defined table function.
*/
def getTableFunction: TableFunction[_] = udtf
/**
* Get the implicit result type of the table returned by the table function.
*/
def getImplicitResultType: DataType = implicitResultType
override def isDeterministic: Boolean = udtf.isDeterministic
override def toString: String = displayName
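/**
* Derives the row type of the table produced by this function for a specific call.
* Literal operands are converted to the Java values expected by the matching `eval`
* method before the row type is obtained from the underlying [[FlinkTableFunction]].
*/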
override def getRowType(
typeFactory: RelDataTypeFactory,
operandList: util.List[SqlNode],
relTypes: util.List[RelDataType]): RelDataType = {
val types = relTypes.asScala.map(FlinkTypeFactory.toInternalType).toArray
val parameterClasses = getEvalMethodSignature(udtf, types)
val parameters = {
val builder = ReflectiveFunctionBase.builder
parameterClasses.foreach((cls) => builder.add(cls, null))
builder.build()
}
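// Convert each literal operand to the Java value expected by the corresponding eval
// parameter: character literals become Strings for BinaryString parameters, timestamp
// literals become java.sql.Timestamp, and everything else is coerced by Calcite.
// Non-literal operands cannot be evaluated at planning time and are passed as null.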
val arguments = parameters.zip(operandList.asScala).map {
case (parameter, operand) =>
try {
val o = SqlUserDefinedTableMacro.getValue(operand)
// TODO Type should convert to internal
val t = parameter.getType(typeFactory)
if (o != null && o.isInstanceOf[NlsString] &&
t.isInstanceOf[RelDataTypeFactoryImpl#JavaType] &&
t.asInstanceOf[RelDataTypeFactoryImpl#JavaType].getJavaClass
== classOf[BinaryString]) {
o.asInstanceOf[NlsString].getValue
} else if (o != null && o.isInstanceOf[TimestampString] &&
t.isInstanceOf[RelDataTypeFactoryImpl#JavaType] &&
t.asInstanceOf[RelDataTypeFactoryImpl#JavaType].getJavaClass == classOf[Timestamp]) {
new Timestamp(o.asInstanceOf[TimestampString].getMillisSinceEpoch)
} else {
SqlUserDefinedTableMacro.coerce(o, t)
}
} catch {
case _: SqlUserDefinedTableMacro.NonLiteralException =>
null
}
}.toArray
functionImpl.getRowType(typeFactory, arguments, parameterClasses)
}
}
object TableSqlFunction {
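/**
* Creates an operand type inference strategy for the given [[TableFunction]],
* delegating to the shared inference logic in [[ScalarSqlFunction]].
*/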
private[flink] def createOperandTypeInference(
name: String,
udtf: TableFunction[_],
typeFactory: FlinkTypeFactory)
: SqlOperandTypeInference = {
/**
* Operand type inference based on the information of the given [[TableFunction]].
*/
new SqlOperandTypeInference {
override def inferOperandTypes(
callBinding: SqlCallBinding,
returnType: RelDataType,
operandTypes: Array[RelDataType]): Unit = {
ScalarSqlFunction.inferOperandTypes(
name, udtf, typeFactory, callBinding, returnType, operandTypes)
}
}
}
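/**
* Creates an operand type checker that validates calls against the signatures of the
* `eval` methods defined on the given [[TableFunction]].
*/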
private[flink] def createOperandTypeChecker(
name: String,
udtf: TableFunction[_])
: SqlOperandTypeChecker = {
val methods = checkAndExtractMethods(udtf, "eval")
/**
* Operand type checker based on the information of the given [[TableFunction]].
*/
new SqlOperandTypeChecker {
override def getAllowedSignatures(op: SqlOperator, opName: String): String = {
s"$opName[${signaturesToString(udtf, "eval")}]"
}
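// Compute the allowed operand count across all eval overloads: the minimum and maximum
// parameter counts are taken over the methods, and a varargs eval removes the upper
// bound (max = -1) so that Calcite skips the length check.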
override def getOperandCountRange: SqlOperandCountRange = {
var min = 254
var max = -1
var isVarargs = false
methods.foreach( m => {
var len = m.getParameterTypes.length
if (len > 0 && m.isVarArgs && m.getParameterTypes()(len - 1).isArray) {
isVarargs = true
len = len - 1
}
max = Math.max(len, max)
min = Math.min(len, min)
})
if (isVarargs) {
// if eval method is varargs, set max to -1 to skip length check in Calcite
max = -1
}
SqlOperandCountRanges.between(min, max)
}
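// A call is accepted only if the actual operand types resolve to one of the eval
// methods; on failure either throw a ValidationException listing the expected
// signatures or return false, depending on throwOnFailure.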
override def checkOperandTypes(
callBinding: SqlCallBinding,
throwOnFailure: Boolean)
: Boolean = {
val operandTypes = getOperandType(callBinding)
if (getEvalUserDefinedMethod(udtf, operandTypes).isEmpty) {
if (throwOnFailure) {
throw new ValidationException(
s"Given parameters of function '$name' do not match any signature. \n" +
s"Actual: ${signatureToString(operandTypes)} \n" +
s"Expected: ${signaturesToString(udtf, "eval")}")
} else {
false
}
} else {
true
}
}
override def isOptional(i: Int): Boolean = false
override def getConsistency: Consistency = Consistency.NONE
}
}
}