package edu.uci.ics.hivesterix.logical.plan.visitor;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

import org.apache.commons.lang3.mutable.Mutable;
import org.apache.commons.lang3.mutable.MutableObject;
import org.apache.hadoop.hive.ql.exec.ColumnInfo;
import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.plan.ExprNodeConstantDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPAnd;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqual;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;

import edu.uci.ics.hivesterix.logical.plan.visitor.base.DefaultVisitor;
import edu.uci.ics.hivesterix.logical.plan.visitor.base.Translator;
import edu.uci.ics.hyracks.algebricks.core.algebra.base.ILogicalExpression;
import edu.uci.ics.hyracks.algebricks.core.algebra.base.ILogicalOperator;
import edu.uci.ics.hyracks.algebricks.core.algebra.base.LogicalVariable;
import edu.uci.ics.hyracks.algebricks.core.algebra.operators.logical.InnerJoinOperator;
import edu.uci.ics.hyracks.algebricks.core.algebra.operators.logical.ProjectOperator;

/**
 * Visitor that translates a Hive {@link MapJoinOperator} into an Algebricks
 * {@link InnerJoinOperator} followed by a {@link ProjectOperator}.
 * <p>
 * The visitor is stateful: it collects the Algebricks parent operator
 * references handed to {@code visit} for each Hive map-join operator and only
 * emits the join once as many references have been collected as the Hive
 * operator has parents.
 */
@SuppressWarnings("rawtypes")
public class MapJoinVisitor extends DefaultVisitor {

    /**
     * Maps a Hive join operator to the Algebricks parent operator references
     * collected for it so far (one is appended per {@code visit} call).
     */
    private final HashMap<Operator, List<Mutable<ILogicalOperator>>> opMap = new HashMap<Operator, List<Mutable<ILogicalOperator>>>();

    /**
     * Records {@code AlgebricksParentOperatorRef} as one input of the given
     * map join and, once the number of recorded inputs equals the number of
     * Hive parent operators, builds the Algebricks join plan.
     *
     * @param operator
     *            the Hive map-join operator being translated
     * @param AlgebricksParentOperatorRef
     *            the Algebricks operator reference for the parent through
     *            which this call arrived
     * @param t
     *            the translator used to convert Hive expressions and to
     *            create assign operators
     * @return a reference to the project operator on top of the join, or
     *         {@code null} when the number of recorded inputs differs from
     *         the number of Hive parent operators
     */
    @Override
    public Mutable<ILogicalOperator> visit(MapJoinOperator operator,
            Mutable<ILogicalOperator> AlgebricksParentOperatorRef, Translator t) {
        List<Operator<? extends Serializable>> joinSrc = operator.getParentOperators();
        List<Mutable<ILogicalOperator>> parents = opMap.get(operator);
        if (parents == null) {
            parents = new ArrayList<Mutable<ILogicalOperator>>();
            opMap.put(operator, parents);
        }
        parents.add(AlgebricksParentOperatorRef);
        if (joinSrc.size() != parents.size()) {
            // not all join inputs have been seen yet
            return null;
        }

        // make a map join operator
        // TODO: will have trouble for n-way joins
        MapJoinDesc joinDesc = (MapJoinDesc) operator.getConf();
        Map<Byte, List<ExprNodeDesc>> keyMap = joinDesc.getKeys();
        // the projection expressions (already re-written) from each source table
        Map<Byte, List<ExprNodeDesc>> exprMap = joinDesc.getExprs();

        // build hive's join predicate and turn it into an ILogicalExpression
        ExprNodeDesc conjunct = buildJoinCondition(keyMap, joinSrc.size());
        Mutable<ILogicalExpression> expression = t.translateScalarFucntion(conjunct);

        // NOTE(review): the variables collected in left/right and the assign
        // operators chained here are not consumed later in this method; the
        // calls are kept because getAssignOperator may have side effects in
        // the translator that are not visible from here -- TODO confirm.
        ArrayList<LogicalVariable> left = new ArrayList<LogicalVariable>();
        ArrayList<LogicalVariable> right = new ArrayList<LogicalVariable>();
        Mutable<ILogicalOperator> keyInputRef = AlgebricksParentOperatorRef;
        boolean firstKeyList = true;
        for (Entry<Byte, List<ExprNodeDesc>> keyEntry : keyMap.entrySet()) {
            // first key list feeds "left", every later one feeds "right"
            ArrayList<LogicalVariable> keyVariables = firstKeyList ? left : right;
            ILogicalOperator keyAssign = t.getAssignOperator(keyInputRef, keyEntry.getValue(), keyVariables);
            if (keyAssign != null) {
                keyInputRef = new MutableObject<ILogicalOperator>(keyAssign);
            }
            firstKeyList = false;
        }

        // create the join operator and wire in the collected inputs, in the
        // order in which they were recorded
        ILogicalOperator joinOperator = new InnerJoinOperator(expression);
        for (Mutable<ILogicalOperator> input : parents) {
            joinOperator.getInputs().add(input);
        }
        Mutable<ILogicalOperator> currentRef = new MutableObject<ILogicalOperator>(joinOperator);

        // add assign operators for the output expressions, collecting the
        // output variables along the way
        ArrayList<LogicalVariable> variables = new ArrayList<LogicalVariable>();
        for (Entry<Byte, List<ExprNodeDesc>> exprEntry : exprMap.entrySet()) {
            ILogicalOperator outputAssign = t.getAssignOperator(currentRef, exprEntry.getValue(), variables);
            if (outputAssign != null) {
                currentRef = new MutableObject<ILogicalOperator>(outputAssign);
            }
        }

        // project onto the output variables
        ILogicalOperator projectOperator = new ProjectOperator(variables);
        projectOperator.getInputs().add(currentRef);
        t.rewriteOperatorOutputSchema(variables, operator);
        // opMap.clear();
        return new MutableObject<ILogicalOperator>(projectOperator);
    }

    /**
     * Builds the Hive join predicate from the per-input join keys: one
     * equality per key position, combined with AND when there are several.
     *
     * @param keyMap
     *            join key expressions, looked up by input index
     *            {@code 0 .. inputSize - 1}
     * @param inputSize
     *            number of join inputs
     * @return the conjunction of key equalities, the single equality, or a
     *         constant {@code true} predicate when there are no keys
     */
    private static ExprNodeDesc buildJoinCondition(Map<Byte, List<ExprNodeDesc>> keyMap, int inputSize) {
        // the number of equality conjuncts is taken from the first key list
        int keyCount = 0;
        Iterator<Entry<Byte, List<ExprNodeDesc>>> keyEntries = keyMap.entrySet().iterator();
        if (keyEntries.hasNext()) {
            keyCount = keyEntries.next().getValue().size();
        }

        List<ExprNodeDesc> equalities = new ArrayList<ExprNodeDesc>();
        for (int i = 0; i < keyCount; i++) {
            // gather the i-th key of every input as the equality's operands
            List<ExprNodeDesc> keyPair = new ArrayList<ExprNodeDesc>();
            for (int j = 0; j < inputSize; j++) {
                keyPair.add(keyMap.get(Byte.valueOf((byte) j)).get(i));
            }
            equalities.add(new ExprNodeGenericFuncDesc(TypeInfoFactory.booleanTypeInfo, new GenericUDFOPEqual(),
                    keyPair));
        }

        if (equalities.size() > 1) {
            return new ExprNodeGenericFuncDesc(TypeInfoFactory.booleanTypeInfo, new GenericUDFOPAnd(), equalities);
        }
        if (equalities.size() == 1) {
            return equalities.get(0);
        }
        // no equality keys: join on constant true, i.e. a cross product under
        // the inner join built by visit
        return new ExprNodeConstantDesc(TypeInfoFactory.booleanTypeInfo, Boolean.TRUE);
    }
}
