package edu.uci.ics.hivesterix.runtime.factory.evaluator;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluator;
import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluatorFactory;
import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
import org.apache.hadoop.hive.ql.exec.Utilities;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.plan.AggregationDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.SerDe;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;

import edu.uci.ics.hivesterix.logical.expression.ExpressionTranslator;
import edu.uci.ics.hivesterix.runtime.evaluator.AggregatuibFunctionSerializableEvaluator;
import edu.uci.ics.hivesterix.runtime.jobgen.Schema;
import edu.uci.ics.hivesterix.serde.lazy.LazyFactory;
import edu.uci.ics.hivesterix.serde.lazy.LazyObject;
import edu.uci.ics.hivesterix.serde.lazy.LazySerDe;
import edu.uci.ics.hyracks.algebricks.common.exceptions.AlgebricksException;
import edu.uci.ics.hyracks.algebricks.core.algebra.expressions.AggregateFunctionCallExpression;
import edu.uci.ics.hyracks.algebricks.core.algebra.expressions.IVariableTypeEnvironment;
import edu.uci.ics.hyracks.algebricks.runtime.base.ICopySerializableAggregateFunction;
import edu.uci.ics.hyracks.algebricks.runtime.base.ICopySerializableAggregateFunctionFactory;

public class AggregationFunctionSerializableFactory implements ICopySerializableAggregateFunctionFactory {

    private static final long serialVersionUID = 1L;

    /**
     * list of parameters' serialization
     */
    private List<String> parametersSerialization = new ArrayList<String>();

    /**
     * the name of the udf
     */
    private String genericUDAFName;

    /**
     * aggregation mode
     */
    private GenericUDAFEvaluator.Mode mode;

    /**
     * list of type info
     */
    private List<TypeInfo> types = new ArrayList<TypeInfo>();

    /**
     * distinct or not
     */
    private boolean distinct;

    /**
     * the schema of incoming rows
     */
    private Schema rowSchema;

    /**
     * list of parameters
     */
    private transient List<ExprNodeDesc> parametersOrigin;

    /**
     * row inspector
     */
    private transient ObjectInspector rowInspector = null;

    /**
     * output object inspector
     */
    private transient ObjectInspector outputInspector = null;

    /**
     * output object inspector
     */
    private transient ObjectInspector outputInspectorPartial = null;

    /**
     * parameter inspectors
     */
    private transient ObjectInspector[] parameterInspectors = null;

    /**
     * expression desc
     */
    private transient HashMap<Long, List<ExprNodeDesc>> parameterExprs = new HashMap<Long, List<ExprNodeDesc>>();

    /**
     * evaluators
     */
    private transient HashMap<Long, ExprNodeEvaluator[]> evaluators = new HashMap<Long, ExprNodeEvaluator[]>();

    /**
     * cached parameter objects
     */
    private transient HashMap<Long, Object[]> cachedParameters = new HashMap<Long, Object[]>();

    /**
     * cached row object: one per thread
     */
    private transient HashMap<Long, LazyObject<? extends ObjectInspector>> cachedRowObjects = new HashMap<Long, LazyObject<? extends ObjectInspector>>();

    /**
     * we only use lazy serde to do serialization
     */
    private transient HashMap<Long, SerDe> serDe = new HashMap<Long, SerDe>();

    /**
     * udaf evaluators
     */
    private transient HashMap<Long, GenericUDAFEvaluator> udafsPartial = new HashMap<Long, GenericUDAFEvaluator>();

    /**
     * udaf evaluators
     */
    private transient HashMap<Long, GenericUDAFEvaluator> udafsComplete = new HashMap<Long, GenericUDAFEvaluator>();

    /**
     * aggregation function desc
     */
    private transient AggregationDesc aggregator;

    /**
     * @param aggregator
     *            Algebricks function call expression
     * @param oi
     *            schema
     */
    public AggregationFunctionSerializableFactory(AggregateFunctionCallExpression expression, Schema oi,
            IVariableTypeEnvironment env) throws AlgebricksException {

        try {
            aggregator = (AggregationDesc) ExpressionTranslator.getHiveExpression(expression, env);
        } catch (Exception e) {
            e.printStackTrace();
            throw new AlgebricksException(e.getMessage());
        }
        init(aggregator.getParameters(), aggregator.getGenericUDAFName(), aggregator.getMode(),
                aggregator.getDistinct(), oi);
    }

    /**
     * constructor of aggregation function factory
     * 
     * @param inputs
     * @param name
     * @param udafMode
     * @param distinct
     * @param oi
     */
    private void init(List<ExprNodeDesc> inputs, String name, GenericUDAFEvaluator.Mode udafMode, boolean distinct,
            Schema oi) {
        parametersOrigin = inputs;
        genericUDAFName = name;
        mode = udafMode;
        this.distinct = distinct;
        rowSchema = oi;

        for (ExprNodeDesc input : inputs) {
            TypeInfo type = input.getTypeInfo();
            if (type instanceof StructTypeInfo) {
                types.add(TypeInfoFactory.doubleTypeInfo);
            } else
                types.add(type);

            String s = Utilities.serializeExpression(input);
            parametersSerialization.add(s);
        }
    }

    @Override
    public synchronized ICopySerializableAggregateFunction createAggregateFunction() throws AlgebricksException {
        if (parametersOrigin == null) {
            Configuration config = new Configuration();
            config.setClassLoader(this.getClass().getClassLoader());
            /**
             * in case of class.forname(...) call in hive code
             */
            Thread.currentThread().setContextClassLoader(this.getClass().getClassLoader());

            parametersOrigin = new ArrayList<ExprNodeDesc>();
            for (String serialization : parametersSerialization) {
                parametersOrigin.add(Utilities.deserializeExpression(serialization, config));
            }
        }

        /**
         * exprs
         */
        if (parameterExprs == null)
            parameterExprs = new HashMap<Long, List<ExprNodeDesc>>();

        /**
         * evaluators
         */
        if (evaluators == null)
            evaluators = new HashMap<Long, ExprNodeEvaluator[]>();

        /**
         * cached parameter objects
         */
        if (cachedParameters == null)
            cachedParameters = new HashMap<Long, Object[]>();

        /**
         * cached row object: one per thread
         */
        if (cachedRowObjects == null)
            cachedRowObjects = new HashMap<Long, LazyObject<? extends ObjectInspector>>();

        /**
         * we only use lazy serde to do serialization
         */
        if (serDe == null)
            serDe = new HashMap<Long, SerDe>();

        /**
         * UDAF functions
         */
        if (udafsComplete == null)
            udafsComplete = new HashMap<Long, GenericUDAFEvaluator>();

        /**
         * UDAF functions
         */
        if (udafsPartial == null)
            udafsPartial = new HashMap<Long, GenericUDAFEvaluator>();

        if (parameterInspectors == null)
            parameterInspectors = new ObjectInspector[parametersOrigin.size()];

        if (rowInspector == null)
            rowInspector = rowSchema.toObjectInspector();

        // get current thread id
        long threadId = Thread.currentThread().getId();

        /**
         * expressions, expressions are thread local
         */
        List<ExprNodeDesc> parameters = parameterExprs.get(threadId);
        if (parameters == null) {
            parameters = new ArrayList<ExprNodeDesc>();
            for (ExprNodeDesc parameter : parametersOrigin)
                parameters.add(parameter.clone());
            parameterExprs.put(threadId, parameters);
        }

        /**
         * cached parameter objects
         */
        Object[] cachedParas = cachedParameters.get(threadId);
        if (cachedParas == null) {
            cachedParas = new Object[parameters.size()];
            cachedParameters.put(threadId, cachedParas);
        }

        /**
         * cached row object: one per thread
         */
        LazyObject<? extends ObjectInspector> cachedRowObject = cachedRowObjects.get(threadId);
        if (cachedRowObject == null) {
            cachedRowObject = LazyFactory.createLazyObject(rowInspector);
            cachedRowObjects.put(threadId, cachedRowObject);
        }

        /**
         * we only use lazy serde to do serialization
         */
        SerDe lazySer = serDe.get(threadId);
        if (lazySer == null) {
            lazySer = new LazySerDe();
            serDe.put(threadId, lazySer);
        }

        /**
         * evaluators
         */
        ExprNodeEvaluator[] evals = evaluators.get(threadId);
        if (evals == null) {
            evals = new ExprNodeEvaluator[parameters.size()];
            evaluators.put(threadId, evals);
        }

        GenericUDAFEvaluator udafPartial;
        GenericUDAFEvaluator udafComplete;

        // initialize object inspectors
        try {
            /**
             * evaluators, udf, object inpsectors are shared in one thread
             */
            for (int i = 0; i < evals.length; i++) {
                if (evals[i] == null) {
                    evals[i] = ExprNodeEvaluatorFactory.get(parameters.get(i));
                    if (parameterInspectors[i] == null) {
                        parameterInspectors[i] = evals[i].initialize(rowInspector);
                    } else {
                        evals[i].initialize(rowInspector);
                    }
                }
            }

            udafComplete = udafsComplete.get(threadId);
            if (udafComplete == null) {
                try {
                    udafComplete = FunctionRegistry.getGenericUDAFEvaluator(genericUDAFName, types, distinct, false);
                } catch (HiveException e) {
                    throw new AlgebricksException(e);
                }
                udafsComplete.put(threadId, udafComplete);
                udafComplete.init(mode, parameterInspectors);
            }

            // multiple stage group by, determined by the mode parameter
            if (outputInspector == null)
                outputInspector = udafComplete.init(mode, parameterInspectors);

            // initial partial gby udaf
            GenericUDAFEvaluator.Mode partialMode;
            // adjust mode for external groupby
            if (mode == GenericUDAFEvaluator.Mode.COMPLETE)
                partialMode = GenericUDAFEvaluator.Mode.PARTIAL1;
            else if (mode == GenericUDAFEvaluator.Mode.FINAL)
                partialMode = GenericUDAFEvaluator.Mode.PARTIAL2;
            else
                partialMode = mode;
            udafPartial = udafsPartial.get(threadId);
            if (udafPartial == null) {
                try {
                    udafPartial = FunctionRegistry.getGenericUDAFEvaluator(genericUDAFName, types, distinct, false);
                } catch (HiveException e) {
                    throw new AlgebricksException(e);
                }
                udafPartial.init(partialMode, parameterInspectors);
                udafsPartial.put(threadId, udafPartial);
            }

            // multiple stage group by, determined by the mode parameter
            if (outputInspectorPartial == null)
                outputInspectorPartial = udafPartial.init(partialMode, parameterInspectors);
        } catch (Exception e) {
            e.printStackTrace();
            throw new AlgebricksException(e);
        }

        return new AggregatuibFunctionSerializableEvaluator(parameters, types, genericUDAFName, mode, distinct,
                rowInspector, evals, parameterInspectors, cachedParas, lazySer, cachedRowObject, udafPartial,
                udafComplete, outputInspector, outputInspectorPartial);
    }

    public String toString() {
        return "aggregation function expression evaluator factory: " + this.genericUDAFName;
    }

}
