blob: 54a115526bc1e7ccc4a7089f549c772a0e8ae983 [file] [log] [blame]
package edu.uci.ics.hivesterix.runtime.factory.evaluator;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluator;
import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluatorFactory;
import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
import org.apache.hadoop.hive.ql.exec.Utilities;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.plan.AggregationDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.SerDe;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import edu.uci.ics.hivesterix.logical.expression.ExpressionTranslator;
import edu.uci.ics.hivesterix.runtime.evaluator.AggregatuibFunctionSerializableEvaluator;
import edu.uci.ics.hivesterix.runtime.jobgen.Schema;
import edu.uci.ics.hivesterix.serde.lazy.LazyFactory;
import edu.uci.ics.hivesterix.serde.lazy.LazyObject;
import edu.uci.ics.hivesterix.serde.lazy.LazySerDe;
import edu.uci.ics.hyracks.algebricks.common.exceptions.AlgebricksException;
import edu.uci.ics.hyracks.algebricks.core.algebra.expressions.AggregateFunctionCallExpression;
import edu.uci.ics.hyracks.algebricks.core.algebra.expressions.IVariableTypeEnvironment;
import edu.uci.ics.hyracks.algebricks.runtime.base.ICopySerializableAggregateFunction;
import edu.uci.ics.hyracks.algebricks.runtime.base.ICopySerializableAggregateFunctionFactory;
public class AggregationFunctionSerializableFactory implements ICopySerializableAggregateFunctionFactory {
private static final long serialVersionUID = 1L;
/**
* list of parameters' serialization
*/
private List<String> parametersSerialization = new ArrayList<String>();
/**
* the name of the udf
*/
private String genericUDAFName;
/**
* aggregation mode
*/
private GenericUDAFEvaluator.Mode mode;
/**
* list of type info
*/
private List<TypeInfo> types = new ArrayList<TypeInfo>();
/**
* distinct or not
*/
private boolean distinct;
/**
* the schema of incoming rows
*/
private Schema rowSchema;
/**
* list of parameters
*/
private transient List<ExprNodeDesc> parametersOrigin;
/**
* row inspector
*/
private transient ObjectInspector rowInspector = null;
/**
* output object inspector
*/
private transient ObjectInspector outputInspector = null;
/**
* output object inspector
*/
private transient ObjectInspector outputInspectorPartial = null;
/**
* parameter inspectors
*/
private transient ObjectInspector[] parameterInspectors = null;
/**
* expression desc
*/
private transient HashMap<Long, List<ExprNodeDesc>> parameterExprs = new HashMap<Long, List<ExprNodeDesc>>();
/**
* evaluators
*/
private transient HashMap<Long, ExprNodeEvaluator[]> evaluators = new HashMap<Long, ExprNodeEvaluator[]>();
/**
* cached parameter objects
*/
private transient HashMap<Long, Object[]> cachedParameters = new HashMap<Long, Object[]>();
/**
* cached row object: one per thread
*/
private transient HashMap<Long, LazyObject<? extends ObjectInspector>> cachedRowObjects = new HashMap<Long, LazyObject<? extends ObjectInspector>>();
/**
* we only use lazy serde to do serialization
*/
private transient HashMap<Long, SerDe> serDe = new HashMap<Long, SerDe>();
/**
* udaf evaluators
*/
private transient HashMap<Long, GenericUDAFEvaluator> udafsPartial = new HashMap<Long, GenericUDAFEvaluator>();
/**
* udaf evaluators
*/
private transient HashMap<Long, GenericUDAFEvaluator> udafsComplete = new HashMap<Long, GenericUDAFEvaluator>();
/**
* aggregation function desc
*/
private transient AggregationDesc aggregator;
/**
* @param aggregator
* Algebricks function call expression
* @param oi
* schema
*/
public AggregationFunctionSerializableFactory(AggregateFunctionCallExpression expression, Schema oi,
IVariableTypeEnvironment env) throws AlgebricksException {
try {
aggregator = (AggregationDesc) ExpressionTranslator.getHiveExpression(expression, env);
} catch (Exception e) {
e.printStackTrace();
throw new AlgebricksException(e.getMessage());
}
init(aggregator.getParameters(), aggregator.getGenericUDAFName(), aggregator.getMode(),
aggregator.getDistinct(), oi);
}
/**
* constructor of aggregation function factory
*
* @param inputs
* @param name
* @param udafMode
* @param distinct
* @param oi
*/
private void init(List<ExprNodeDesc> inputs, String name, GenericUDAFEvaluator.Mode udafMode, boolean distinct,
Schema oi) {
parametersOrigin = inputs;
genericUDAFName = name;
mode = udafMode;
this.distinct = distinct;
rowSchema = oi;
for (ExprNodeDesc input : inputs) {
TypeInfo type = input.getTypeInfo();
if (type instanceof StructTypeInfo) {
types.add(TypeInfoFactory.doubleTypeInfo);
} else
types.add(type);
String s = Utilities.serializeExpression(input);
parametersSerialization.add(s);
}
}
@Override
public synchronized ICopySerializableAggregateFunction createAggregateFunction() throws AlgebricksException {
if (parametersOrigin == null) {
Configuration config = new Configuration();
config.setClassLoader(this.getClass().getClassLoader());
/**
* in case of class.forname(...) call in hive code
*/
Thread.currentThread().setContextClassLoader(this.getClass().getClassLoader());
parametersOrigin = new ArrayList<ExprNodeDesc>();
for (String serialization : parametersSerialization) {
parametersOrigin.add(Utilities.deserializeExpression(serialization, config));
}
}
/**
* exprs
*/
if (parameterExprs == null)
parameterExprs = new HashMap<Long, List<ExprNodeDesc>>();
/**
* evaluators
*/
if (evaluators == null)
evaluators = new HashMap<Long, ExprNodeEvaluator[]>();
/**
* cached parameter objects
*/
if (cachedParameters == null)
cachedParameters = new HashMap<Long, Object[]>();
/**
* cached row object: one per thread
*/
if (cachedRowObjects == null)
cachedRowObjects = new HashMap<Long, LazyObject<? extends ObjectInspector>>();
/**
* we only use lazy serde to do serialization
*/
if (serDe == null)
serDe = new HashMap<Long, SerDe>();
/**
* UDAF functions
*/
if (udafsComplete == null)
udafsComplete = new HashMap<Long, GenericUDAFEvaluator>();
/**
* UDAF functions
*/
if (udafsPartial == null)
udafsPartial = new HashMap<Long, GenericUDAFEvaluator>();
if (parameterInspectors == null)
parameterInspectors = new ObjectInspector[parametersOrigin.size()];
if (rowInspector == null)
rowInspector = rowSchema.toObjectInspector();
// get current thread id
long threadId = Thread.currentThread().getId();
/**
* expressions, expressions are thread local
*/
List<ExprNodeDesc> parameters = parameterExprs.get(threadId);
if (parameters == null) {
parameters = new ArrayList<ExprNodeDesc>();
for (ExprNodeDesc parameter : parametersOrigin)
parameters.add(parameter.clone());
parameterExprs.put(threadId, parameters);
}
/**
* cached parameter objects
*/
Object[] cachedParas = cachedParameters.get(threadId);
if (cachedParas == null) {
cachedParas = new Object[parameters.size()];
cachedParameters.put(threadId, cachedParas);
}
/**
* cached row object: one per thread
*/
LazyObject<? extends ObjectInspector> cachedRowObject = cachedRowObjects.get(threadId);
if (cachedRowObject == null) {
cachedRowObject = LazyFactory.createLazyObject(rowInspector);
cachedRowObjects.put(threadId, cachedRowObject);
}
/**
* we only use lazy serde to do serialization
*/
SerDe lazySer = serDe.get(threadId);
if (lazySer == null) {
lazySer = new LazySerDe();
serDe.put(threadId, lazySer);
}
/**
* evaluators
*/
ExprNodeEvaluator[] evals = evaluators.get(threadId);
if (evals == null) {
evals = new ExprNodeEvaluator[parameters.size()];
evaluators.put(threadId, evals);
}
GenericUDAFEvaluator udafPartial;
GenericUDAFEvaluator udafComplete;
// initialize object inspectors
try {
/**
* evaluators, udf, object inpsectors are shared in one thread
*/
for (int i = 0; i < evals.length; i++) {
if (evals[i] == null) {
evals[i] = ExprNodeEvaluatorFactory.get(parameters.get(i));
if (parameterInspectors[i] == null) {
parameterInspectors[i] = evals[i].initialize(rowInspector);
} else {
evals[i].initialize(rowInspector);
}
}
}
udafComplete = udafsComplete.get(threadId);
if (udafComplete == null) {
try {
udafComplete = FunctionRegistry.getGenericUDAFEvaluator(genericUDAFName, types, distinct, false);
} catch (HiveException e) {
throw new AlgebricksException(e);
}
udafsComplete.put(threadId, udafComplete);
udafComplete.init(mode, parameterInspectors);
}
// multiple stage group by, determined by the mode parameter
if (outputInspector == null)
outputInspector = udafComplete.init(mode, parameterInspectors);
// initial partial gby udaf
GenericUDAFEvaluator.Mode partialMode;
// adjust mode for external groupby
if (mode == GenericUDAFEvaluator.Mode.COMPLETE)
partialMode = GenericUDAFEvaluator.Mode.PARTIAL1;
else if (mode == GenericUDAFEvaluator.Mode.FINAL)
partialMode = GenericUDAFEvaluator.Mode.PARTIAL2;
else
partialMode = mode;
udafPartial = udafsPartial.get(threadId);
if (udafPartial == null) {
try {
udafPartial = FunctionRegistry.getGenericUDAFEvaluator(genericUDAFName, types, distinct, false);
} catch (HiveException e) {
throw new AlgebricksException(e);
}
udafPartial.init(partialMode, parameterInspectors);
udafsPartial.put(threadId, udafPartial);
}
// multiple stage group by, determined by the mode parameter
if (outputInspectorPartial == null)
outputInspectorPartial = udafPartial.init(partialMode, parameterInspectors);
} catch (Exception e) {
e.printStackTrace();
throw new AlgebricksException(e);
}
return new AggregatuibFunctionSerializableEvaluator(parameters, types, genericUDAFName, mode, distinct,
rowInspector, evals, parameterInspectors, cachedParas, lazySer, cachedRowObject, udafPartial,
udafComplete, outputInspector, outputInspectorPartial);
}
public String toString() {
return "aggregation function expression evaluator factory: " + this.genericUDAFName;
}
}