package edu.uci.ics.hivesterix.logical.plan.visitor;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

import org.apache.commons.lang3.mutable.Mutable;
import org.apache.commons.lang3.mutable.MutableObject;
import org.apache.hadoop.hive.ql.exec.JoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.plan.ExprNodeConstantDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
import org.apache.hadoop.hive.ql.plan.JoinCondDesc;
import org.apache.hadoop.hive.ql.plan.JoinDesc;
import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPAnd;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqual;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;

import edu.uci.ics.hivesterix.logical.plan.visitor.base.DefaultVisitor;
import edu.uci.ics.hivesterix.logical.plan.visitor.base.Translator;
import edu.uci.ics.hyracks.algebricks.core.algebra.base.ILogicalExpression;
import edu.uci.ics.hyracks.algebricks.core.algebra.base.ILogicalOperator;
import edu.uci.ics.hyracks.algebricks.core.algebra.base.LogicalVariable;
import edu.uci.ics.hyracks.algebricks.core.algebra.operators.logical.InnerJoinOperator;
import edu.uci.ics.hyracks.algebricks.core.algebra.operators.logical.LeftOuterJoinOperator;
import edu.uci.ics.hyracks.algebricks.core.algebra.operators.logical.ProjectOperator;

/**
 * Translates Hive join plans into Algebricks join operators.
 * <p>
 * The visitor is stateful: every {@link ReduceSinkOperator} that feeds a
 * {@link JoinOperator} is translated first and remembered; the join itself is
 * only emitted once the join operator has been visited as many times as it has
 * Hive parents.
 */
@SuppressWarnings("rawtypes")
public class JoinVisitor extends DefaultVisitor {

    /**
     * reduce sink operator to the variables produced for its key columns
     */
    private final Map<Operator, List<LogicalVariable>> reduceSinkToKeyVariables = new HashMap<Operator, List<LogicalVariable>>();

    /**
     * reduce sink operator to the field names (variable names) of its unified
     * key/value schema; recorded for bookkeeping, not read back inside this
     * class
     */
    private final Map<Operator, List<String>> reduceSinkToFieldNames = new HashMap<Operator, List<String>>();

    /**
     * reduce sink operator to the types of its unified key/value schema;
     * recorded for bookkeeping, not read back inside this class
     */
    private final Map<Operator, List<TypeInfo>> reduceSinkToTypes = new HashMap<Operator, List<TypeInfo>>();

    /**
     * map a join operator (in hive) to its parent operators (in hive)
     */
    private final Map<Operator, List<Operator>> operatorToHiveParents = new HashMap<Operator, List<Operator>>();

    /**
     * map a join operator (in hive) to its parent operators (in asterix)
     */
    private final Map<Operator, List<ILogicalOperator>> operatorToAsterixParents = new HashMap<Operator, List<ILogicalOperator>>();

    /**
     * the latest traversed reduce sink operator
     */
    private Operator latestReduceSink = null;

    /**
     * the latest generated parent for join
     */
    private ILogicalOperator latestAlgebricksOperator = null;

    /**
     * Process a join operator. Each visit registers the most recently
     * translated reduce sink as one input of the join.
     * 
     * @param operator
     *            the hive join operator
     * @param AlgebricksParentOperator
     *            the already translated input branch reaching this join
     * @param t
     *            the translator
     * @return {@code null} while fewer inputs have been registered than the
     *         join has hive parents; otherwise the translated join plan
     */
    @Override
    public Mutable<ILogicalOperator> visit(JoinOperator operator, Mutable<ILogicalOperator> AlgebricksParentOperator,
            Translator t) {
        latestAlgebricksOperator = AlgebricksParentOperator.getValue();
        translateJoinOperatorPreprocess(operator, t);

        List<Operator> parents = operatorToHiveParents.get(operator);
        if (parents.size() < operator.getParentOperators().size()) {
            // some input branches have not been visited yet
            return null;
        }

        ILogicalOperator joinOp = translateJoinOperator(operator, AlgebricksParentOperator, t);
        return new MutableObject<ILogicalOperator>(joinOp);
    }

    /**
     * Reorder both parent lists in place so that the parent whose reduce sink
     * tag equals {@code order[i]} ends up at position {@code i}. The two lists
     * are expected to be parallel (same length, same order).
     * 
     * @param order
     *            the tag order of the join
     * @param parents
     *            the algebricks parents of the join
     * @param hiveParents
     *            the hive (reduce sink) parents of the join
     * @throws IllegalStateException
     *             if a reduce sink tag does not occur in {@code order}
     */
    private void reorder(Byte[] order, List<ILogicalOperator> parents, List<Operator> hiveParents) {
        ILogicalOperator[] lops = new ILogicalOperator[parents.size()];
        Operator[] ops = new Operator[hiveParents.size()];

        for (int pos = 0; pos < hiveParents.size(); pos++) {
            Operator op = hiveParents.get(pos);
            int tag = ((ReduceSinkOperator) op).getConf().getTag();

            int index = -1;
            for (int i = 0; i < order.length; i++) {
                if (order[i] == tag) {
                    index = i;
                    break;
                }
            }
            if (index < 0) {
                throw new IllegalStateException("reduce sink tag " + tag + " is not part of the join tag order");
            }
            lops[index] = parents.get(pos);
            ops[index] = op;
        }

        parents.clear();
        hiveParents.clear();

        for (int i = 0; i < lops.length; i++) {
            parents.add(lops[i]);
            hiveParents.add(ops[i]);
        }
    }

    /**
     * translate a hive join operator to asterix join operator->assign
     * operator->project operator
     * 
     * @param operator
     *            the hive join operator; all of its inputs must already be
     *            registered
     * @param parentOperator
     *            not read; the plan is built from the registered inputs
     * @param t
     *            the translator
     * @return the project operator on top of the generated join plan
     */
    private ILogicalOperator translateJoinOperator(Operator operator, Mutable<ILogicalOperator> parentOperator,
            Translator t) {
        JoinDesc joinDesc = (JoinDesc) operator.getConf();

        // get the projection expression (already re-written) from each source
        // table
        Map<Byte, List<ExprNodeDesc>> exprMap = joinDesc.getExprs();
        List<ILogicalOperator> asterixParents = operatorToAsterixParents.get(operator);
        List<Operator> hiveParents = operatorToHiveParents.get(operator);
        reorder(joinDesc.getTagOrder(), asterixParents, hiveParents);

        // build the tree of binary joins
        ILogicalOperator joinRoot = generateJoinTree(joinDesc.getCondsList(), asterixParents, hiveParents, 0, t);
        Mutable<ILogicalOperator> top = new MutableObject<ILogicalOperator>(joinRoot);

        // add assign and project operator on top of a join
        // output variables
        ArrayList<LogicalVariable> variables = new ArrayList<LogicalVariable>();
        for (List<ExprNodeDesc> outputExprs : exprMap.values()) {
            ILogicalOperator assignOperator = t.getAssignOperator(top, outputExprs, variables);
            if (assignOperator != null) {
                top = new MutableObject<ILogicalOperator>(assignOperator);
            }
        }

        ILogicalOperator po = new ProjectOperator(variables);
        po.getInputs().add(top);
        t.rewriteOperatorOutputSchema(variables, operator);
        return po;
    }

    /**
     * deal with reduce sink operator for the case of join
     * 
     * @return {@code null} if the first child of the reduce sink is not a join
     *         operator; otherwise the project operator that exposes the
     *         unified key/value schema of this reduce sink
     */
    @Override
    public Mutable<ILogicalOperator> visit(ReduceSinkOperator operator, Mutable<ILogicalOperator> parentOperator,
            Translator t) {

        Operator downStream = (Operator) operator.getChildOperators().get(0);
        if (!(downStream instanceof JoinOperator)) {
            return null;
        }

        ReduceSinkDesc desc = (ReduceSinkDesc) operator.getConf();
        List<ExprNodeDesc> keys = desc.getKeyCols();
        List<ExprNodeDesc> values = desc.getValueCols();
        List<ExprNodeDesc> partitionCols = desc.getPartitionCols();

        // rewrite key, value, partition column expressions
        for (ExprNodeDesc key : keys) {
            t.rewriteExpression(key);
        }
        for (ExprNodeDesc value : values) {
            t.rewriteExpression(value);
        }
        for (ExprNodeDesc col : partitionCols) {
            t.rewriteExpression(col);
        }

        // add assign operator for keys if necessary
        ArrayList<LogicalVariable> keyVariables = new ArrayList<LogicalVariable>();
        ILogicalOperator assignOperator = t.getAssignOperator(parentOperator, keys, keyVariables);
        if (assignOperator != null) {
            parentOperator = new MutableObject<ILogicalOperator>(assignOperator);
        }

        // add assign operator for values if necessary
        ArrayList<LogicalVariable> variables = new ArrayList<LogicalVariable>();
        assignOperator = t.getAssignOperator(parentOperator, values, variables);
        if (assignOperator != null) {
            parentOperator = new MutableObject<ILogicalOperator>(assignOperator);
        }

        // unified schema: keys first, then the values that are not keys
        ArrayList<LogicalVariable> unifiedKeyValues = new ArrayList<LogicalVariable>();
        unifiedKeyValues.addAll(keyVariables);
        for (LogicalVariable value : variables) {
            if (!keyVariables.contains(value)) {
                unifiedKeyValues.add(value);
            }
        }

        // insert projection operator, it is a *must*,
        // in hive, reduce sink sometimes also do the projection operator's
        // task
        ILogicalOperator currentOperator = new ProjectOperator(unifiedKeyValues);
        currentOperator.getInputs().add(parentOperator);

        reduceSinkToKeyVariables.put(operator, keyVariables);
        List<String> fieldNames = new ArrayList<String>();
        List<TypeInfo> types = new ArrayList<TypeInfo>();
        for (LogicalVariable var : unifiedKeyValues) {
            fieldNames.add(var.toString());
            types.add(t.getType(var));
        }
        reduceSinkToFieldNames.put(operator, fieldNames);
        reduceSinkToTypes.put(operator, types);
        t.rewriteOperatorOutputSchema(variables, operator);

        // remember this branch so the following join visit can pick it up
        latestAlgebricksOperator = currentOperator;
        latestReduceSink = operator;
        return new MutableObject<ILogicalOperator>(currentOperator);
    }

    /**
     * partial rewrite a join operator: rewrite the output expressions that
     * belong to the latest reduce sink (looked up by its tag) and register the
     * latest reduce sink / algebricks operator as parents of the join
     * 
     * @param operator
     *            the hive join operator
     * @param t
     *            the translator
     * @throws IllegalStateException
     *             if no reduce sink has been translated yet
     */
    private void translateJoinOperatorPreprocess(Operator operator, Translator t) {
        if (latestReduceSink == null) {
            throw new IllegalStateException("join operator visited before any of its reduce sink parents");
        }

        JoinDesc desc = (JoinDesc) operator.getConf();
        ReduceSinkDesc reduceSinkDesc = (ReduceSinkDesc) latestReduceSink.getConf();
        int tag = reduceSinkDesc.getTag();

        Map<Byte, List<ExprNodeDesc>> exprMap = desc.getExprs();
        List<ExprNodeDesc> exprs = exprMap.get(Byte.valueOf((byte) tag));

        for (ExprNodeDesc expr : exprs) {
            t.rewriteExpression(expr);
        }

        List<Operator> parents = operatorToHiveParents.get(operator);
        if (parents == null) {
            parents = new ArrayList<Operator>();
            operatorToHiveParents.put(operator, parents);
        }
        parents.add(latestReduceSink);

        List<ILogicalOperator> asterixParents = operatorToAsterixParents.get(operator);
        if (asterixParents == null) {
            asterixParents = new ArrayList<ILogicalOperator>();
            operatorToAsterixParents.put(operator, asterixParents);
        }
        asterixParents.add(latestAlgebricksOperator);
    }

    /**
     * generate a join tree from a list of exchange/reducesink operator; both
     * exchanges and reduce sinks have the same order
     * <p>
     * The inputs from {@code offset} onwards are joined recursively: the last
     * two inputs form the innermost join, and every earlier input is joined
     * with the tree built from the inputs after it.
     * 
     * @param conds
     *            the join conditions; {@code conds.get(offset)} decides the
     *            type of the join created at this level
     * @param exchanges
     *            the algebricks inputs of the join
     * @param reduceSinks
     *            the hive reduce sinks, parallel to {@code exchanges}
     * @param offset
     *            index of the first input handled at this level
     * @param t
     *            the translator
     * @return the root of the generated join tree
     * @throws IllegalStateException
     *             if fewer than two inputs remain from {@code offset}
     */
    private ILogicalOperator generateJoinTree(List<JoinCondDesc> conds, List<ILogicalOperator> exchanges,
            List<Operator> reduceSinks, int offset, Translator t) {
        int inputSize = reduceSinks.size() - offset;
        if (inputSize < 2) {
            // without this guard the recursion below would never terminate
            throw new IllegalStateException("a join needs at least two inputs, but got " + inputSize);
        }

        List<ReduceSinkDesc> reduceSinkDescs = new ArrayList<ReduceSinkDesc>();
        Mutable<ILogicalOperator> leftBranch;
        Mutable<ILogicalOperator> rightBranch;
        Mutable<ILogicalExpression> expression;

        if (inputSize == 2) {
            // innermost join: the last two inputs, last one first
            for (int i = reduceSinks.size() - 1; i >= offset; i--) {
                reduceSinkDescs.add((ReduceSinkDesc) reduceSinks.get(i).getConf());
            }
            expression = t.translateScalarFucntion(buildJoinCondition(reduceSinkDescs));

            leftBranch = new MutableObject<ILogicalOperator>(exchanges.get(exchanges.size() - 1));
            rightBranch = new MutableObject<ILogicalOperator>(exchanges.get(exchanges.size() - 2));
        } else {
            // first build the join over all inputs after the current one
            ILogicalOperator joinSrcOne = generateJoinTree(conds, exchanges, reduceSinks, offset + 1, t);

            // the condition pairs the keys of the current input and the next
            // one
            for (int i = offset; i < offset + 2; i++) {
                reduceSinkDescs.add((ReduceSinkDesc) reduceSinks.get(i).getConf());
            }
            expression = t.translateScalarFucntion(buildJoinCondition(reduceSinkDescs));

            leftBranch = new MutableObject<ILogicalOperator>(joinSrcOne);
            rightBranch = new MutableObject<ILogicalOperator>(exchanges.get(offset));
        }

        return createJoinOperator(conds.get(offset).getType(), expression, leftBranch, rightBranch);
    }

    /**
     * build the hive join condition for the given reduce sinks: a conjunction
     * of equalities, one per key column position, over the key columns of all
     * given reduce sinks
     * 
     * @param reduceSinkDescs
     *            the reduce sink descriptors taking part in the join; the
     *            number of key columns is taken from the first one
     * @return the single equality, the conjunction of equalities, or constant
     *         {@code true} when there are no key columns
     */
    private ExprNodeDesc buildJoinCondition(List<ReduceSinkDesc> reduceSinkDescs) {
        // number of equality conjunctions in the final join condition
        int size = reduceSinkDescs.get(0).getKeyCols().size();

        List<ExprNodeDesc> joinConditionChildren = new ArrayList<ExprNodeDesc>();
        for (int i = 0; i < size; i++) {
            // create a join key pair
            List<ExprNodeDesc> keyPair = new ArrayList<ExprNodeDesc>();
            for (ReduceSinkDesc sink : reduceSinkDescs) {
                keyPair.add(sink.getKeyCols().get(i));
            }
            // create a hive equal condition and add it to the conjunction list
            joinConditionChildren.add(new ExprNodeGenericFuncDesc(TypeInfoFactory.booleanTypeInfo,
                    new GenericUDFOPEqual(), keyPair));
        }

        if (joinConditionChildren.size() > 1) {
            return new ExprNodeGenericFuncDesc(TypeInfoFactory.booleanTypeInfo, new GenericUDFOPAnd(),
                    joinConditionChildren);
        }
        if (joinConditionChildren.size() == 1) {
            return joinConditionChildren.get(0);
        }
        // no equality keys at all: the join condition is constant true
        return new ExprNodeConstantDesc(TypeInfoFactory.booleanTypeInfo, Boolean.TRUE);
    }

    /**
     * create the algebricks join operator for a hive join type and attach its
     * two inputs
     * <p>
     * Both hive outer join types are mapped to a
     * {@link LeftOuterJoinOperator}; for a hive left outer join the two
     * branches are swapped before being attached. Every other join type
     * becomes an {@link InnerJoinOperator}.
     * 
     * @param joinType
     *            the hive join type (a {@link JoinDesc} constant)
     * @param expression
     *            the translated join condition
     * @param leftBranch
     *            the first input candidate
     * @param rightBranch
     *            the second input candidate
     * @return the join operator with both inputs attached
     */
    private ILogicalOperator createJoinOperator(int joinType, Mutable<ILogicalExpression> expression,
            Mutable<ILogicalOperator> leftBranch, Mutable<ILogicalOperator> rightBranch) {
        ILogicalOperator currentRoot;
        Mutable<ILogicalOperator> firstInput = leftBranch;
        Mutable<ILogicalOperator> secondInput = rightBranch;

        if (joinType == JoinDesc.LEFT_OUTER_JOIN) {
            currentRoot = new LeftOuterJoinOperator(expression);
            firstInput = rightBranch;
            secondInput = leftBranch;
        } else if (joinType == JoinDesc.RIGHT_OUTER_JOIN) {
            currentRoot = new LeftOuterJoinOperator(expression);
        } else {
            // NOTE(review): any other type, including a full outer join, is
            // translated as an inner join here - confirm this is intended
            currentRoot = new InnerJoinOperator(expression);
        }

        currentRoot.getInputs().add(firstInput);
        currentRoot.getInputs().add(secondInput);
        return currentRoot;
    }
}
