| package edu.uci.ics.hivesterix.logical.expression; |
| |
| import java.util.ArrayList; |
| import java.util.List; |
| |
| import org.apache.commons.lang3.mutable.Mutable; |
| import org.apache.commons.lang3.mutable.MutableObject; |
| import org.apache.hadoop.hive.ql.plan.AggregationDesc; |
| import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc; |
| import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; |
| import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; |
| import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; |
| |
| import edu.uci.ics.hyracks.algebricks.common.exceptions.AlgebricksException; |
| import edu.uci.ics.hyracks.algebricks.core.algebra.base.ILogicalExpression; |
| import edu.uci.ics.hyracks.algebricks.core.algebra.base.IOptimizationContext; |
| import edu.uci.ics.hyracks.algebricks.core.algebra.base.LogicalVariable; |
| import edu.uci.ics.hyracks.algebricks.core.algebra.expressions.AggregateFunctionCallExpression; |
| import edu.uci.ics.hyracks.algebricks.core.algebra.expressions.IMergeAggregationExpressionFactory; |
| import edu.uci.ics.hyracks.algebricks.core.algebra.expressions.VariableReferenceExpression; |
| import edu.uci.ics.hyracks.algebricks.core.algebra.functions.FunctionIdentifier; |
| |
| /** |
| * generate merge aggregation expression from an aggregation expression |
| * |
| * @author yingyib |
| */ |
| public class HiveMergeAggregationExpressionFactory implements IMergeAggregationExpressionFactory { |
| |
| public static IMergeAggregationExpressionFactory INSTANCE = new HiveMergeAggregationExpressionFactory(); |
| |
| @Override |
| public ILogicalExpression createMergeAggregation(ILogicalExpression expr, IOptimizationContext context) |
| throws AlgebricksException { |
| /** |
| * type inference for scalar function |
| */ |
| if (expr instanceof AggregateFunctionCallExpression) { |
| AggregateFunctionCallExpression funcExpr = (AggregateFunctionCallExpression) expr; |
| /** |
| * hive aggregation info |
| */ |
| AggregationDesc aggregator = (AggregationDesc) ((HiveFunctionInfo) funcExpr.getFunctionInfo()).getInfo(); |
| LogicalVariable inputVar = context.newVar(); |
| ExprNodeDesc col = new ExprNodeColumnDesc(TypeInfoFactory.voidTypeInfo, inputVar.toString(), null, false); |
| ArrayList<ExprNodeDesc> parameters = new ArrayList<ExprNodeDesc>(); |
| parameters.add(col); |
| |
| GenericUDAFEvaluator.Mode mergeMode; |
| if (aggregator.getMode() == GenericUDAFEvaluator.Mode.PARTIAL1) |
| mergeMode = GenericUDAFEvaluator.Mode.PARTIAL2; |
| else if (aggregator.getMode() == GenericUDAFEvaluator.Mode.COMPLETE) |
| mergeMode = GenericUDAFEvaluator.Mode.FINAL; |
| else |
| mergeMode = aggregator.getMode(); |
| AggregationDesc mergeDesc = new AggregationDesc(aggregator.getGenericUDAFName(), |
| aggregator.getGenericUDAFEvaluator(), parameters, aggregator.getDistinct(), mergeMode); |
| |
| String UDAFName = mergeDesc.getGenericUDAFName(); |
| List<Mutable<ILogicalExpression>> arguments = new ArrayList<Mutable<ILogicalExpression>>(); |
| arguments.add(new MutableObject<ILogicalExpression>(new VariableReferenceExpression(inputVar))); |
| |
| FunctionIdentifier funcId = new FunctionIdentifier(ExpressionConstant.NAMESPACE, UDAFName + "(" |
| + mergeDesc.getMode() + ")"); |
| HiveFunctionInfo funcInfo = new HiveFunctionInfo(funcId, mergeDesc); |
| AggregateFunctionCallExpression aggregationExpression = new AggregateFunctionCallExpression(funcInfo, |
| false, arguments); |
| return aggregationExpression; |
| } else { |
| throw new IllegalStateException("illegal expressions " + expr.getClass().getName()); |
| } |
| } |
| |
| } |