blob: b77fe49ca7af33ffcb3901e2f7c8c6fc12f1d409 [file] [log] [blame]
package edu.uci.ics.hivesterix.logical.expression;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.lang3.mutable.Mutable;
import org.apache.commons.lang3.mutable.MutableObject;
import org.apache.hadoop.hive.ql.plan.AggregationDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import edu.uci.ics.hyracks.algebricks.common.exceptions.AlgebricksException;
import edu.uci.ics.hyracks.algebricks.core.algebra.base.ILogicalExpression;
import edu.uci.ics.hyracks.algebricks.core.algebra.base.IOptimizationContext;
import edu.uci.ics.hyracks.algebricks.core.algebra.base.LogicalVariable;
import edu.uci.ics.hyracks.algebricks.core.algebra.expressions.AggregateFunctionCallExpression;
import edu.uci.ics.hyracks.algebricks.core.algebra.expressions.IMergeAggregationExpressionFactory;
import edu.uci.ics.hyracks.algebricks.core.algebra.expressions.VariableReferenceExpression;
import edu.uci.ics.hyracks.algebricks.core.algebra.functions.FunctionIdentifier;
/**
 * Generates the merge aggregation expression that combines the results of a Hive aggregation
 * expression.
 *
 * @author yingyib
 */
public class HiveMergeAggregationExpressionFactory implements IMergeAggregationExpressionFactory {

    /** Shared instance of this factory; the class holds no per-instance state. */
    public static final IMergeAggregationExpressionFactory INSTANCE = new HiveMergeAggregationExpressionFactory();

    /**
     * Builds an aggregation expression that merges the output of {@code expr}.
     * <p>
     * The returned expression has the following properties:
     * <ul>
     * <li>It reuses the Hive UDAF name, evaluator and distinct flag of {@code expr}.</li>
     * <li>It takes a single argument that references a fresh variable allocated from
     * {@code context}.</li>
     * <li>It runs in the mode returned by {@link #toMergeMode(GenericUDAFEvaluator.Mode)} for the
     * original aggregation's mode.</li>
     * </ul>
     *
     * @param expr
     *            the aggregation expression to merge; expected to be an
     *            {@link AggregateFunctionCallExpression} whose function info is a
     *            {@link HiveFunctionInfo} wrapping an {@link AggregationDesc} (otherwise the casts
     *            below fail with {@link ClassCastException})
     * @param context
     *            optimization context used to allocate the merge step's input variable
     * @return a new {@link AggregateFunctionCallExpression} performing the merge step
     * @throws IllegalStateException
     *             if {@code expr} is {@code null} or not an {@link AggregateFunctionCallExpression}
     * @throws AlgebricksException
     *             declared by {@link IMergeAggregationExpressionFactory}; this method does not throw
     *             it itself
     */
    @Override
    public ILogicalExpression createMergeAggregation(ILogicalExpression expr, IOptimizationContext context)
            throws AlgebricksException {
        if (!(expr instanceof AggregateFunctionCallExpression)) {
            // Report null explicitly instead of failing with an NPE on expr.getClass().
            String kind = expr == null ? "null" : expr.getClass().getName();
            throw new IllegalStateException("illegal expressions " + kind);
        }
        AggregateFunctionCallExpression funcExpr = (AggregateFunctionCallExpression) expr;

        // Hive-side description of the aggregation being merged.
        AggregationDesc aggregator = (AggregationDesc) ((HiveFunctionInfo) funcExpr.getFunctionInfo()).getInfo();

        // The merge step reads its input through a fresh variable. The same variable backs both the
        // Hive column parameter (by name) and the logical argument built below.
        LogicalVariable inputVar = context.newVar();
        // NOTE(review): voidTypeInfo looks like a placeholder for the partial-result type — confirm.
        ExprNodeDesc col = new ExprNodeColumnDesc(TypeInfoFactory.voidTypeInfo, inputVar.toString(), null, false);
        ArrayList<ExprNodeDesc> parameters = new ArrayList<ExprNodeDesc>();
        parameters.add(col);

        GenericUDAFEvaluator.Mode mergeMode = toMergeMode(aggregator.getMode());
        AggregationDesc mergeDesc = new AggregationDesc(aggregator.getGenericUDAFName(),
                aggregator.getGenericUDAFEvaluator(), parameters, aggregator.getDistinct(), mergeMode);

        List<Mutable<ILogicalExpression>> arguments = new ArrayList<Mutable<ILogicalExpression>>();
        arguments.add(new MutableObject<ILogicalExpression>(new VariableReferenceExpression(inputVar)));

        // The identifier embeds the mode, "<udafName>(<mode>)", so identifiers differ per mode.
        FunctionIdentifier funcId = new FunctionIdentifier(ExpressionConstant.NAMESPACE,
                mergeDesc.getGenericUDAFName() + "(" + mergeDesc.getMode() + ")");
        HiveFunctionInfo funcInfo = new HiveFunctionInfo(funcId, mergeDesc);
        // NOTE(review): the boolean is presumably Algebricks' "isTwoStep" flag — confirm.
        return new AggregateFunctionCallExpression(funcInfo, false, arguments);
    }

    /**
     * Maps the mode of the aggregation being merged to the mode of its merge step.
     * <p>
     * {@code PARTIAL1} maps to {@code PARTIAL2}, and {@code COMPLETE} maps to {@code FINAL}. Any
     * other mode, including {@code null}, is returned unchanged.
     *
     * @param mode
     *            mode of the original aggregation
     * @return mode to use for the merge aggregation
     */
    private static GenericUDAFEvaluator.Mode toMergeMode(GenericUDAFEvaluator.Mode mode) {
        if (mode == GenericUDAFEvaluator.Mode.PARTIAL1) {
            return GenericUDAFEvaluator.Mode.PARTIAL2;
        }
        if (mode == GenericUDAFEvaluator.Mode.COMPLETE) {
            return GenericUDAFEvaluator.Mode.FINAL;
        }
        return mode;
    }
}