/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.drill.exec.expr.fn.registry;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.calcite.sql.SqlOperator;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.drill.common.scanner.persistence.AnnotatedClassDescriptor;
import org.apache.drill.common.scanner.persistence.ScanResult;
import org.apache.drill.exec.exception.FunctionValidationException;
import org.apache.drill.exec.exception.JarValidationException;
import org.apache.drill.exec.expr.annotations.FunctionTemplate;
import org.apache.drill.exec.expr.fn.DrillFuncHolder;
import org.apache.drill.exec.expr.fn.FunctionConverter;
import org.apache.drill.exec.planner.logical.DrillConstExecutor;
import org.apache.drill.exec.planner.sql.DrillOperatorTable;
import org.apache.drill.exec.planner.sql.DrillSqlAggOperator;
import org.apache.drill.exec.planner.sql.DrillSqlAggOperatorWithoutInference;
import org.apache.drill.exec.planner.sql.DrillSqlOperator;
import org.apache.drill.exec.planner.sql.DrillSqlOperatorWithoutInference;
import org.apache.drill.exec.store.sys.store.DataChangeVersion;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Lists;

/**
 * Registry of Drill functions.
 */
public class LocalFunctionRegistry implements AutoCloseable {

  public static final String BUILT_IN = "built-in";

  private static final Logger logger = LoggerFactory.getLogger(LocalFunctionRegistry.class);
  private static final String functionSignaturePattern = "%s(%s)";

  private static final ImmutableMap<String, Pair<Integer, Integer>> registeredFuncNameToArgRange =
      ImmutableMap.<String, Pair<Integer, Integer>> builder()
      // CONCAT is allowed to take [1, infinity) number of arguments.
      // Currently, this flexibility is offered by DrillOptiq to rewrite it as
      // a nested structure
      .put("CONCAT", Pair.of(1, Integer.MAX_VALUE))

      // When LENGTH is given two arguments, this function relies on DrillOptiq to rewrite it as
      // another function based on the second argument (encodingType)
      .put("LENGTH", Pair.of(1, 2))

      // Dummy functions
      .put("CONVERT_TO", Pair.of(2, 2))
      .put("CONVERT_FROM", Pair.of(2, 2))
      .put("FLATTEN", Pair.of(1, 1)).build();

  private final FunctionRegistryHolder registryHolder;

  /**
   * Registers all functions present in Drill classpath on start-up.
   * All functions will be marked as built-in. Built-in functions are not allowed to be unregistered.
   * Since local function registry version is based on remote function registry version,
   * initially sync version will be set to {@link DataChangeVersion#UNDEFINED}
   * to ensure that upon first check both registries would be synchronized.
   */
  public LocalFunctionRegistry(ScanResult classpathScan) {
    registryHolder = new FunctionRegistryHolder();
    validate(BUILT_IN, classpathScan);
    register(Lists.newArrayList(new JarScan(BUILT_IN, classpathScan, this.getClass().getClassLoader())), DataChangeVersion.UNDEFINED);
    if (logger.isTraceEnabled()) {
      StringBuilder allFunctions = new StringBuilder();
      for (DrillFuncHolder method: registryHolder.getAllFunctionsWithHolders().values()) {
        allFunctions.append(method.toString()).append("\n");
      }
      logger.trace("Registered functions: [\n{}]", allFunctions);
    }
  }

  /**
   * @return remote function registry version number with which local function
   *         registry is synced
   */
  public int getVersion() {
    return registryHolder.getVersion();
  }

  /**
   * Validates all functions, present in jars.
   * Will throw {@link FunctionValidationException} if:
   * <ol>
   *  <li>Jar with the same name has been already registered.</li>
   *  <li>Conflicting function with the similar signature is found.</li>
   *  <li>Aggregating function is not deterministic.</li>
   *</ol>
   * @param jarName jar name to be validated
   * @param scanResult scan of all classes present in jar
   * @return list of validated function signatures
   */
  public List<String> validate(String jarName, ScanResult scanResult) {
    List<String> functions = Lists.newArrayList();
    FunctionConverter converter = new FunctionConverter();
    List<AnnotatedClassDescriptor> providerClasses = scanResult.getAnnotatedClasses(FunctionTemplate.class.getName());

    if (registryHolder.containsJar(jarName)) {
      throw new JarValidationException(String.format("Jar with %s name has been already registered", jarName));
    }

    final ListMultimap<String, String> allFuncWithSignatures = registryHolder.getAllFunctionsWithSignatures();

    for (AnnotatedClassDescriptor func : providerClasses) {
      DrillFuncHolder holder = converter.getHolder(func, ClassLoader.getSystemClassLoader());
      if (holder != null) {
        String functionInput = holder.getInputParameters();

        String[] names = holder.getRegisteredNames();
        for (String name : names) {
          String functionName = name.toLowerCase();
          String functionSignature = String.format(functionSignaturePattern, functionName, functionInput);

          if (allFuncWithSignatures.get(functionName).contains(functionSignature)) {
            throw new FunctionValidationException(String.format("Found duplicated function in %s: %s",
                registryHolder.getJarNameByFunctionSignature(functionName, functionSignature), functionSignature));
          } else if (holder.isAggregating() && !holder.isDeterministic()) {
            throw new FunctionValidationException(
                String.format("Aggregate functions must be deterministic: %s", func.getClassName()));
          } else {
            functions.add(functionSignature);
            allFuncWithSignatures.put(functionName, functionSignature);
          }
        }
      } else {
        logger.warn("Unable to initialize function for class {}", func.getClassName());
      }
    }
    return functions;
  }

  /**
   * Registers all functions present in jar and updates registry version.
   * If jar name is already registered, all jar related functions will be overridden.
   * To prevent classpath collisions during loading and unloading jars,
   * each jar is shipped with its own class loader.
   *
   * @param jars list of jars to be registered
   * @param version remote function registry version number with which local function registry is synced
   */
  public void register(List<JarScan> jars, int version) {
    Map<String, List<FunctionHolder>> newJars = new HashMap<>();
    for (JarScan jarScan : jars) {
      FunctionConverter converter = new FunctionConverter();
      List<AnnotatedClassDescriptor> providerClasses = jarScan.getScanResult().getAnnotatedClasses(FunctionTemplate.class.getName());
      List<FunctionHolder> functions = new ArrayList<>();
      newJars.put(jarScan.getJarName(), functions);
      for (AnnotatedClassDescriptor func : providerClasses) {
        DrillFuncHolder holder = converter.getHolder(func, jarScan.getClassLoader());
        if (holder != null) {
          String functionInput = holder.getInputParameters();
          String[] names = holder.getRegisteredNames();
          for (String name : names) {
            String functionName = name.toLowerCase();
            String functionSignature = String.format(functionSignaturePattern, functionName, functionInput);
            functions.add(new FunctionHolder(functionName, functionSignature, holder));
          }
        }
      }
    }
    registryHolder.addJars(newJars, version);
  }

  /**
   * Removes all function associated with the given jar name.
   * Functions marked as built-in is not allowed to be unregistered.
   * If user attempts to unregister built-in functions, logs warning and does nothing.
   * Jar name is case-sensitive.
   *
   * @param jarName jar name to be unregistered
   */
  public void unregister(String jarName) {
    if (BUILT_IN.equals(jarName)) {
      logger.warn("Functions marked as built-in are not allowed to be unregistered.");
      return;
    }
    registryHolder.removeJar(jarName);
  }

  /**
   * Returns list of jar names registered in function registry.
   *
   * @return list of jar names
   */
  public List<String> getAllJarNames() {
    return registryHolder.getAllJarNames();
  }

  /**
   * @return quantity of all registered functions
   */
  public int size(){
    return registryHolder.functionsSize();
  }

  /**
   * @param name function name
   * @return all function holders associated with the function name. Function name is case insensitive.
   */
  public List<DrillFuncHolder> getMethods(String name, AtomicInteger version) {
    return registryHolder.getHoldersByFunctionName(name.toLowerCase(), version);
  }

  /**
   * @param name function name
   * @return all function holders associated with the function name. Function name is case insensitive.
   */
  public List<DrillFuncHolder> getMethods(String name) {
    return registryHolder.getHoldersByFunctionName(name.toLowerCase());
  }

  /**
   * Returns a map of all function holders mapped by source jars
   * @return all functions organized by source jars
   */
  public Map<String, List<FunctionHolder>> getAllJarsWithFunctionsHolders() {
    return registryHolder.getAllJarsWithFunctionHolders();
  }

  /**
   * Registers all functions present in {@link DrillOperatorTable},
   * also sets sync registry version used at the moment of function registration.
   *
   * @param operatorTable drill operator table
   */
  public void register(DrillOperatorTable operatorTable) {
    AtomicInteger versionHolder = new AtomicInteger();
    final Map<String, Collection<DrillFuncHolder>> registeredFunctions =
        registryHolder.getAllFunctionsWithHolders(versionHolder).asMap();
    operatorTable.setFunctionRegistryVersion(versionHolder.get());
    registerOperatorsWithInference(operatorTable, registeredFunctions);
    registerOperatorsWithoutInference(operatorTable, registeredFunctions);
  }

  private void registerOperatorsWithInference(DrillOperatorTable operatorTable, Map<String,
      Collection<DrillFuncHolder>> registeredFunctions) {
    final Map<String, DrillSqlOperator.DrillSqlOperatorBuilder> map = new HashMap<>();
    final Map<String, DrillSqlAggOperator.DrillSqlAggOperatorBuilder> mapAgg = new HashMap<>();
    for (Entry<String, Collection<DrillFuncHolder>> function : registeredFunctions.entrySet()) {
      final ArrayListMultimap<Pair<Integer, Integer>, DrillFuncHolder> functions = ArrayListMultimap.create();
      final ArrayListMultimap<Integer, DrillFuncHolder> aggregateFunctions = ArrayListMultimap.create();
      final String name = function.getKey().toUpperCase();
      boolean isDeterministic = true;
      boolean isNiladic = false;
      boolean isVarArg = false;
      for (DrillFuncHolder func : function.getValue()) {
        final int paramCount = func.getParamCount();
        if (func.isAggregating()) {
          aggregateFunctions.put(paramCount, func);
        } else {
          final Pair<Integer, Integer> argNumberRange;
          if(registeredFuncNameToArgRange.containsKey(name)) {
            argNumberRange = registeredFuncNameToArgRange.get(name);
          } else {
            argNumberRange = Pair.of(func.getParamCount(), func.getParamCount());
          }
          functions.put(argNumberRange, func);
        }

        if (!func.isDeterministic() || func.isComplexWriterFuncHolder()) {
          isDeterministic = false;
        }

        if (func.isNiladic()) {
          isNiladic = true;
        }

        if (func.isVarArg()) {
          isVarArg = true;
        }
      }
      for (Entry<Pair<Integer, Integer>, Collection<DrillFuncHolder>> entry : functions.asMap().entrySet()) {
        final Pair<Integer, Integer> range = entry.getKey();
        final int max = range.getRight();
        final int min = range.getLeft();
        if(!map.containsKey(name)) {
          map.put(name, new DrillSqlOperator.DrillSqlOperatorBuilder()
              .setName(name));
        }

        final DrillSqlOperator.DrillSqlOperatorBuilder drillSqlOperatorBuilder = map.get(name);
        drillSqlOperatorBuilder
            .addFunctions(entry.getValue())
            .setVarArg(isVarArg)
            .setArgumentCount(min, max)
            .setDeterministic(isDeterministic)
            .setNiladic(isNiladic);
      }
      for (Entry<Integer, Collection<DrillFuncHolder>> entry : aggregateFunctions.asMap().entrySet()) {
        if(!mapAgg.containsKey(name)) {
          mapAgg.put(name, new DrillSqlAggOperator.DrillSqlAggOperatorBuilder().setName(name));
        }

        final DrillSqlAggOperator.DrillSqlAggOperatorBuilder drillSqlAggOperatorBuilder = mapAgg.get(name);
        drillSqlAggOperatorBuilder
            .addFunctions(entry.getValue())
            .setArgumentCount(entry.getKey(), entry.getKey());
      }
    }

    for (final Entry<String, DrillSqlOperator.DrillSqlOperatorBuilder> entry : map.entrySet()) {
      operatorTable.addOperatorWithInference(
          entry.getKey(),
          entry.getValue().build());
    }

    for (final Entry<String, DrillSqlAggOperator.DrillSqlAggOperatorBuilder> entry : mapAgg.entrySet()) {
      operatorTable.addOperatorWithInference(
          entry.getKey(),
          entry.getValue().build());
    }
  }

  private void registerOperatorsWithoutInference(DrillOperatorTable operatorTable, Map<String, Collection<DrillFuncHolder>> registeredFunctions) {
    SqlOperator op;
    for (Entry<String, Collection<DrillFuncHolder>> function : registeredFunctions.entrySet()) {
      Set<Integer> argCounts = new HashSet<>();
      String name = function.getKey().toUpperCase();
      for (DrillFuncHolder func : function.getValue()) {
        if (argCounts.add(func.getParamCount())) {
          if (func.isAggregating()) {
            op = new DrillSqlAggOperatorWithoutInference(name, func.getParamCount(), func.isVarArg());
          } else {
            boolean isDeterministic;
            // prevent Drill from folding constant functions with types that cannot be materialized
            // into literals
            if (DrillConstExecutor.NON_REDUCIBLE_TYPES.contains(func.getReturnType().getMinorType()) || func.isComplexWriterFuncHolder()) {
              isDeterministic = false;
            } else {
              isDeterministic = func.isDeterministic();
            }
            op = new DrillSqlOperatorWithoutInference(name, func.getParamCount(),
                func.getReturnType(), isDeterministic, func.isNiladic(), func.isVarArg());
          }
          operatorTable.addOperatorWithoutInference(function.getKey(), op);
        }
      }
    }
  }

  @Override
  public void close() {
    registryHolder.close();
  }
}
