// blob: d254499ed28148f0064409c5db2663fe17fc56a2 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tajo.plan.function;
import org.apache.hadoop.io.*;
import org.apache.tajo.catalog.FunctionDesc;
import org.apache.tajo.common.TajoDataTypes;
import org.apache.tajo.datum.Datum;
import org.apache.tajo.exception.*;
import org.apache.tajo.function.UDFInvocationDesc;
import org.apache.tajo.storage.Tuple;
import org.apache.tajo.plan.util.WritableTypeConverter;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.Arrays;
/**
 * Invokes a Hive UDF reflectively on behalf of a Tajo function.
 *
 * <p>{@link #init} loads the UDF class from the path named by the function's
 * UDF invocation descriptor, instantiates it, and resolves the {@code evaluate}
 * overload whose parameter types map onto the function's declared Tajo
 * parameter types. {@link #eval} then converts each tuple field to a
 * {@link Writable}, calls that method, and converts the result back to a
 * {@link Datum}.</p>
 */
public class HiveFunctionInvoke extends FunctionInvoke implements Cloneable {
  /** Instance of the loaded UDF class that {@code evaluate} is invoked on; assigned during {@link #init}. */
  private Object instance = null;
  /** The resolved {@code evaluate} overload; assigned during {@link #init}. */
  private Method evalMethod = null;
  /** Reusable argument buffer with one slot per declared parameter; overwritten by every {@link #eval} call. */
  private Writable [] params;

  /**
   * Creates an invoker for the given function and sizes the argument buffer
   * from the function's declared parameter types.
   *
   * @param desc descriptor of the function to invoke
   */
  public HiveFunctionInvoke(FunctionDesc desc) {
    super(desc);
    params = new Writable[desc.getParamTypes().length];
  }

  /**
   * Loads the UDF class and resolves the {@code evaluate} method to call.
   *
   * @param context invocation context (not used by this implementation)
   * @throws IOException if the UDF path is not a well-formed URL
   * @throws TajoInternalError if the UDF class cannot be found, cannot be
   *         instantiated, or has no matching {@code evaluate} overload
   */
  @Override
  public void init(FunctionInvokeContext context) throws IOException {
    UDFInvocationDesc udfDesc = functionDesc.getInvocation().getUDF();

    URL [] urls = new URL [] { new URL(udfDesc.getPath()) };
    // The loader is deliberately left open: the UDF may still need to load
    // further classes from the same location when evaluate() runs later.
    URLClassLoader loader = new URLClassLoader(urls);

    try {
      Class<?> udfclass = loader.loadClass(udfDesc.getName());
      evalMethod = getEvaluateMethod(functionDesc.getParamTypes(), udfclass);
    } catch (ClassNotFoundException e) {
      throw new TajoInternalError(e);
    }
  }

  /**
   * Instantiates the UDF class and finds its {@code evaluate} overload whose
   * parameter types correspond to the given Tajo parameter types.
   *
   * <p>As a side effect, the created object is stored in {@link #instance}.</p>
   *
   * @param tajoParamTypes declared Tajo parameter types of the function
   * @param clazz the loaded UDF class
   * @return the matching public {@code evaluate} method
   * @throws TajoInternalError if the class has no public no-argument
   *         constructor, if instantiation fails, or if no {@code evaluate}
   *         overload matches
   */
  private Method getEvaluateMethod(TajoDataTypes.DataType [] tajoParamTypes, Class<?> clazz) {
    try {
      // Ask for the no-arg constructor explicitly. The order of
      // getConstructors() is unspecified, so taking element 0 could pick an
      // argument-taking constructor (or fail on an empty array) and surface
      // as a raw IllegalArgumentException / ArrayIndexOutOfBoundsException.
      Constructor<?> constructor = clazz.getConstructor();
      instance = constructor.newInstance();
    } catch (NoSuchMethodException|InstantiationException|IllegalAccessException|InvocationTargetException e) {
      throw new TajoInternalError(e);
    }

    for (Method m: clazz.getMethods()) {
      if (m.getName().equals("evaluate")) {
        // Raw Class[] is kept on purpose: it is handed to the type converter,
        // whose exact generic parameter type is declared outside this file.
        Class [] methodParamTypes = m.getParameterTypes();
        if (checkParamTypes(methodParamTypes, tajoParamTypes)) {
          return m;
        }
      }
    }

    throw new TajoInternalError(new UndefinedFunctionException(String.format("Hive UDF (%s)", clazz.getSimpleName())));
  }

  /**
   * Tells whether a candidate method's parameter classes map, position by
   * position, onto the expected Tajo data types.
   *
   * @param writableParams parameter classes of a candidate {@code evaluate} method
   * @param tajoParams expected Tajo parameter types
   * @return {@code true} if the arity is equal and every converted type equals
   *         the expected type at the same position
   * @throws TajoRuntimeException if a parameter class has no supported Tajo type
   */
  private boolean checkParamTypes(Class [] writableParams, TajoDataTypes.DataType [] tajoParams) {
    if (writableParams.length != tajoParams.length) {
      return false;
    }

    for (int i = 0; i < writableParams.length; i++) {
      try {
        if (!WritableTypeConverter.convertWritableToTajoType(writableParams[i]).equals(tajoParams[i])) {
          return false;
        }
      } catch (UnsupportedDataTypeException e) {
        throw new TajoRuntimeException(e);
      }
    }

    return true;
  }

  /**
   * Evaluates the UDF for one input tuple.
   *
   * @param tuple input whose fields are passed, in order, as the UDF arguments
   * @return the UDF result converted to a {@link Datum}
   * @throws TajoInternalError if the reflective call or the result conversion fails
   */
  @Override
  public Datum eval(Tuple tuple) {
    Datum resultDatum;

    // NOTE(review): assumes tuple.size() never exceeds the number of declared
    // parameters (the length of params) -- confirm against the caller.
    for (int i=0; i<tuple.size(); i++) {
      params[i] = WritableTypeConverter.convertDatum2Writable(tuple.asDatum(i));
    }

    try {
      // params is a Writable[], so it is spread as the varargs argument list.
      Writable result = (Writable)evalMethod.invoke(instance, params);
      resultDatum = WritableTypeConverter.convertWritable2Datum(result);
    } catch (Exception e) {
      throw new TajoInternalError(e);
    }

    return resultDatum;
  }
}