blob: d3969685091eeb78e5eb6f07e86643f62bf767de [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.drill.exec.expr.fn.registry;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.drill.exec.store.sys.store.DataChangeVersion;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Lists;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.calcite.sql.SqlOperator;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.drill.common.scanner.persistence.AnnotatedClassDescriptor;
import org.apache.drill.common.scanner.persistence.ScanResult;
import org.apache.drill.exec.exception.FunctionValidationException;
import org.apache.drill.exec.exception.JarValidationException;
import org.apache.drill.exec.expr.annotations.FunctionTemplate;
import org.apache.drill.exec.expr.fn.DrillFuncHolder;
import org.apache.drill.exec.expr.fn.FunctionConverter;
import org.apache.drill.exec.planner.logical.DrillConstExecutor;
import org.apache.drill.exec.planner.sql.DrillOperatorTable;
import org.apache.drill.exec.planner.sql.DrillSqlAggOperator;
import org.apache.drill.exec.planner.sql.DrillSqlAggOperatorWithoutInference;
import org.apache.drill.exec.planner.sql.DrillSqlOperator;
import com.google.common.collect.ArrayListMultimap;
import org.apache.drill.exec.planner.sql.DrillSqlOperatorWithoutInference;
/**
* Registry of Drill functions.
*/
public class LocalFunctionRegistry implements AutoCloseable {

  /** Jar name under which all functions found on the Drill classpath are registered. */
  public static final String BUILT_IN = "built-in";

  private static final Logger logger = LoggerFactory.getLogger(LocalFunctionRegistry.class);

  /** Format of a function signature: {@code name(inputParameters)}. */
  private static final String functionSignaturePattern = "%s(%s)";

  /**
   * Upper-case function names whose accepted argument range differs from the parameter
   * count declared by their holders. Each value is a {@code (min, max)} pair.
   */
  private static final ImmutableMap<String, Pair<Integer, Integer>> registeredFuncNameToArgRange =
      ImmutableMap.<String, Pair<Integer, Integer>> builder()
          // CONCAT is allowed to take [1, infinity) number of arguments.
          // Currently, this flexibility is offered by DrillOptiq to rewrite it as
          // a nested structure
          .put("CONCAT", Pair.of(1, Integer.MAX_VALUE))
          // When LENGTH is given two arguments, this function relies on DrillOptiq to rewrite it as
          // another function based on the second argument (encodingType)
          .put("LENGTH", Pair.of(1, 2))
          // Dummy functions
          .put("CONVERT_TO", Pair.of(2, 2))
          .put("CONVERT_FROM", Pair.of(2, 2))
          .put("FLATTEN", Pair.of(1, 1)).build();

  private final FunctionRegistryHolder registryHolder;

  /**
   * Registers all functions present in the Drill classpath on start-up.
   * All of them are marked as built-in. Built-in functions cannot be unregistered.
   * The local function registry version follows the remote function registry version.
   * The initial sync version is therefore {@link DataChangeVersion#UNDEFINED}, which
   * ensures both registries are synchronized on the first check.
   *
   * @param classpathScan scan of the Drill classpath whose functions become built-in
   */
  public LocalFunctionRegistry(ScanResult classpathScan) {
    registryHolder = new FunctionRegistryHolder();
    validate(BUILT_IN, classpathScan);
    register(Lists.newArrayList(new JarScan(BUILT_IN, classpathScan, this.getClass().getClassLoader())), DataChangeVersion.UNDEFINED);
    if (logger.isTraceEnabled()) {
      StringBuilder allFunctions = new StringBuilder();
      for (DrillFuncHolder method: registryHolder.getAllFunctionsWithHolders().values()) {
        allFunctions.append(method.toString()).append("\n");
      }
      logger.trace("Registered functions: [\n{}]", allFunctions);
    }
  }

  /**
   * @return remote function registry version number with which local function
   * registry is synced
   */
  public int getVersion() {
    return registryHolder.getVersion();
  }

  /**
   * Validates all functions present in a jar against the functions that are already
   * registered, and against each other.
   *
   * @param jarName jar name to be validated
   * @param scanResult scan of all classes present in jar
   * @return list of validated function signatures
   * @throws JarValidationException if a jar with the same name has already been registered
   * @throws FunctionValidationException if a function with the same signature is already
   *         registered or occurs more than once in the jar, or if an aggregate function is
   *         not deterministic
   */
  public List<String> validate(String jarName, ScanResult scanResult) {
    List<String> functions = Lists.newArrayList();
    FunctionConverter converter = new FunctionConverter();
    List<AnnotatedClassDescriptor> providerClasses = scanResult.getAnnotatedClasses(FunctionTemplate.class.getName());

    if (registryHolder.containsJar(jarName)) {
      throw new JarValidationException(String.format("Jar with %s name has been already registered", jarName));
    }

    // Signatures from this jar are added as they pass validation, so duplicates
    // within the jar itself are detected too.
    final ListMultimap<String, String> allFuncWithSignatures = registryHolder.getAllFunctionsWithSignatures();

    for (AnnotatedClassDescriptor func : providerClasses) {
      // NOTE(review): register() builds holders with the jar's own class loader, while
      // validation uses the system class loader. Presumably this is fine because only
      // metadata is inspected here - confirm.
      DrillFuncHolder holder = converter.getHolder(func, ClassLoader.getSystemClassLoader());
      if (holder != null) {
        String functionInput = holder.getInputParameters();
        String[] names = holder.getRegisteredNames();
        for (String name : names) {
          String functionName = name.toLowerCase();
          String functionSignature = buildFunctionSignature(functionName, functionInput);
          if (allFuncWithSignatures.get(functionName).contains(functionSignature)) {
            // A signature already accepted from this (not yet registered) jar cannot be
            // attributed by the registry, so report the jar under validation instead.
            String conflictingJar = functions.contains(functionSignature)
                ? jarName
                : registryHolder.getJarNameByFunctionSignature(functionName, functionSignature);
            throw new FunctionValidationException(String.format("Found duplicated function in %s: %s",
                conflictingJar, functionSignature));
          } else if (holder.isAggregating() && !holder.isDeterministic()) {
            throw new FunctionValidationException(
                String.format("Aggregate functions must be deterministic: %s", func.getClassName()));
          } else {
            functions.add(functionSignature);
            allFuncWithSignatures.put(functionName, functionSignature);
          }
        }
      } else {
        logger.warn("Unable to initialize function for class {}", func.getClassName());
      }
    }
    return functions;
  }

  /**
   * Registers all functions present in jars and updates the registry version.
   * If a jar name is already registered, all functions of that jar are overridden.
   * To prevent classpath collisions during loading and unloading of jars,
   * each jar is shipped with its own class loader.
   *
   * @param jars list of jars to be registered
   * @param version remote function registry version number with which local function registry is synced
   */
  public void register(List<JarScan> jars, int version) {
    Map<String, List<FunctionHolder>> newJars = new HashMap<>();
    for (JarScan jarScan : jars) {
      FunctionConverter converter = new FunctionConverter();
      List<AnnotatedClassDescriptor> providerClasses = jarScan.getScanResult().getAnnotatedClasses(FunctionTemplate.class.getName());
      List<FunctionHolder> functions = new ArrayList<>();
      newJars.put(jarScan.getJarName(), functions);
      for (AnnotatedClassDescriptor func : providerClasses) {
        DrillFuncHolder holder = converter.getHolder(func, jarScan.getClassLoader());
        if (holder != null) {
          String functionInput = holder.getInputParameters();
          String[] names = holder.getRegisteredNames();
          for (String name : names) {
            String functionName = name.toLowerCase();
            String functionSignature = buildFunctionSignature(functionName, functionInput);
            functions.add(new FunctionHolder(functionName, functionSignature, holder));
          }
        }
      }
    }
    registryHolder.addJars(newJars, version);
  }

  /**
   * Builds a function signature in the form {@code name(input)}. Used by both
   * {@link #validate} and {@link #register(List, int)}, so they always agree.
   *
   * @param functionName lower-cased function name
   * @param functionInput input parameter description of the function holder
   * @return function signature
   */
  private static String buildFunctionSignature(String functionName, String functionInput) {
    return String.format(functionSignaturePattern, functionName, functionInput);
  }

  /**
   * Removes all functions associated with the given jar name.
   * Functions marked as built-in cannot be unregistered. If a user attempts to
   * unregister built-in functions, a warning is logged and nothing else happens.
   * The jar name is case-sensitive.
   *
   * @param jarName jar name to be unregistered
   */
  public void unregister(String jarName) {
    if (BUILT_IN.equals(jarName)) {
      logger.warn("Functions marked as built-in are not allowed to be unregistered.");
      return;
    }
    registryHolder.removeJar(jarName);
  }

  /**
   * Returns the list of jar names registered in the function registry.
   *
   * @return list of jar names
   */
  public List<String> getAllJarNames() {
    return registryHolder.getAllJarNames();
  }

  /**
   * @return quantity of all registered functions
   */
  public int size(){
    return registryHolder.functionsSize();
  }

  /**
   * @param name function name; case insensitive
   * @param version out-parameter, presumably populated with the registry version the
   *        result corresponds to (as in {@link #register(DrillOperatorTable)}) - confirm
   * @return all function holders associated with the function name
   */
  public List<DrillFuncHolder> getMethods(String name, AtomicInteger version) {
    return registryHolder.getHoldersByFunctionName(name.toLowerCase(), version);
  }

  /**
   * @param name function name; case insensitive
   * @return all function holders associated with the function name
   */
  public List<DrillFuncHolder> getMethods(String name) {
    return registryHolder.getHoldersByFunctionName(name.toLowerCase());
  }

  /**
   * Returns a map of all function holders, keyed by source jar name.
   *
   * @return all functions organized by source jars
   */
  public Map<String, List<FunctionHolder>> getAllJarsWithFunctionsHolders() {
    return registryHolder.getAllJarsWithFunctionHolders();
  }

  /**
   * Registers all functions in the given {@link DrillOperatorTable} and records there
   * the registry version that was current when the functions were read.
   *
   * @param operatorTable drill operator table
   */
  public void register(DrillOperatorTable operatorTable) {
    AtomicInteger versionHolder = new AtomicInteger();
    final Map<String, Collection<DrillFuncHolder>> registeredFunctions =
        registryHolder.getAllFunctionsWithHolders(versionHolder).asMap();
    operatorTable.setFunctionRegistryVersion(versionHolder.get());
    registerOperatorsWithInference(operatorTable, registeredFunctions);
    registerOperatorsWithoutInference(operatorTable, registeredFunctions);
  }

  /**
   * Builds one {@link DrillSqlOperator} and/or one {@link DrillSqlAggOperator} per
   * upper-cased function name and adds them to the operator table as operators with
   * inference.
   *
   * @param operatorTable drill operator table
   * @param registeredFunctions function holders keyed by registered function name
   */
  private void registerOperatorsWithInference(DrillOperatorTable operatorTable, Map<String,
      Collection<DrillFuncHolder>> registeredFunctions) {
    final Map<String, DrillSqlOperator.DrillSqlOperatorBuilder> map = new HashMap<>();
    final Map<String, DrillSqlAggOperator.DrillSqlAggOperatorBuilder> mapAgg = new HashMap<>();
    for (Entry<String, Collection<DrillFuncHolder>> function : registeredFunctions.entrySet()) {
      final ArrayListMultimap<Pair<Integer, Integer>, DrillFuncHolder> functions = ArrayListMultimap.create();
      final ArrayListMultimap<Integer, DrillFuncHolder> aggregateFunctions = ArrayListMultimap.create();
      final String name = function.getKey().toUpperCase();
      boolean isDeterministic = true;
      boolean isNiladic = false;
      boolean isVarArg = false;
      for (DrillFuncHolder func : function.getValue()) {
        final int paramCount = func.getParamCount();
        if (func.isAggregating()) {
          aggregateFunctions.put(paramCount, func);
        } else {
          // Some functions accept a different argument range than their holders declare.
          // ImmutableMap never holds null values, so null here means "not listed".
          Pair<Integer, Integer> argNumberRange = registeredFuncNameToArgRange.get(name);
          if (argNumberRange == null) {
            argNumberRange = Pair.of(paramCount, paramCount);
          }
          functions.put(argNumberRange, func);
        }
        // A single non-deterministic (or complex-writer) overload makes the whole operator
        // non-deterministic.
        if (!func.isDeterministic() || func.isComplexWriterFuncHolder()) {
          isDeterministic = false;
        }
        if (func.isNiladic()) {
          isNiladic = true;
        }
        if (func.isVarArg()) {
          isVarArg = true;
        }
      }

      for (Entry<Pair<Integer, Integer>, Collection<DrillFuncHolder>> entry : functions.asMap().entrySet()) {
        final Pair<Integer, Integer> range = entry.getKey();
        final int max = range.getRight();
        final int min = range.getLeft();
        // Created lazily so that names with only aggregate overloads get no plain operator.
        DrillSqlOperator.DrillSqlOperatorBuilder drillSqlOperatorBuilder = map.get(name);
        if (drillSqlOperatorBuilder == null) {
          drillSqlOperatorBuilder = new DrillSqlOperator.DrillSqlOperatorBuilder().setName(name);
          map.put(name, drillSqlOperatorBuilder);
        }
        drillSqlOperatorBuilder
            .addFunctions(entry.getValue())
            .setVarArg(isVarArg)
            .setArgumentCount(min, max)
            .setDeterministic(isDeterministic)
            .setNiladic(isNiladic);
      }

      for (Entry<Integer, Collection<DrillFuncHolder>> entry : aggregateFunctions.asMap().entrySet()) {
        DrillSqlAggOperator.DrillSqlAggOperatorBuilder drillSqlAggOperatorBuilder = mapAgg.get(name);
        if (drillSqlAggOperatorBuilder == null) {
          drillSqlAggOperatorBuilder = new DrillSqlAggOperator.DrillSqlAggOperatorBuilder().setName(name);
          mapAgg.put(name, drillSqlAggOperatorBuilder);
        }
        drillSqlAggOperatorBuilder
            .addFunctions(entry.getValue())
            .setArgumentCount(entry.getKey(), entry.getKey());
      }
    }

    for (final Entry<String, DrillSqlOperator.DrillSqlOperatorBuilder> entry : map.entrySet()) {
      operatorTable.addOperatorWithInference(
          entry.getKey(),
          entry.getValue().build());
    }
    for (final Entry<String, DrillSqlAggOperator.DrillSqlAggOperatorBuilder> entry : mapAgg.entrySet()) {
      operatorTable.addOperatorWithInference(
          entry.getKey(),
          entry.getValue().build());
    }
  }

  /**
   * Adds one operator without inference per distinct (function name, parameter count)
   * pair. The first holder seen for a given parameter count determines the operator.
   *
   * @param operatorTable drill operator table
   * @param registeredFunctions function holders keyed by registered function name
   */
  private void registerOperatorsWithoutInference(DrillOperatorTable operatorTable, Map<String, Collection<DrillFuncHolder>> registeredFunctions) {
    for (Entry<String, Collection<DrillFuncHolder>> function : registeredFunctions.entrySet()) {
      Set<Integer> argCounts = new HashSet<>();
      String name = function.getKey().toUpperCase();
      for (DrillFuncHolder func : function.getValue()) {
        if (argCounts.add(func.getParamCount())) {
          SqlOperator op;
          if (func.isAggregating()) {
            op = new DrillSqlAggOperatorWithoutInference(name, func.getParamCount(), func.isVarArg());
          } else {
            boolean isDeterministic;
            // prevent Drill from folding constant functions with types that cannot be materialized
            // into literals
            if (DrillConstExecutor.NON_REDUCIBLE_TYPES.contains(func.getReturnType().getMinorType()) || func.isComplexWriterFuncHolder()) {
              isDeterministic = false;
            } else {
              isDeterministic = func.isDeterministic();
            }
            op = new DrillSqlOperatorWithoutInference(name, func.getParamCount(),
                func.getReturnType(), isDeterministic, func.isNiladic(), func.isVarArg());
          }
          operatorTable.addOperatorWithoutInference(function.getKey(), op);
        }
      }
    }
  }

  @Override
  public void close() {
    registryHolder.close();
  }
}