blob: 8150165cf01a88523259cb439c315824823d858f [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.pig.builtin;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Type;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.pig.backend.executionengine.ExecException;
import org.apache.pig.data.DataBag;
import org.apache.pig.data.Tuple;
import org.apache.pig.impl.PigContext;
import org.apache.pig.impl.logicalLayer.FrontendException;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
/**
 * Reflective bridge that calls an arbitrary public Java method on behalf of Pig's dynamic
 * invoker UDFs.
 *
 * <p>The target method is named by its fully qualified name ({@code package.Class.method}) plus a
 * whitespace-separated list of parameter type specs. Recognised specs (case-insensitive) are
 * {@code string}, {@code int}, {@code long}, {@code float}, {@code double} and the matching
 * {@code []} array forms. Array-typed arguments arrive as a {@link DataBag}; the value in field 0
 * of each tuple in the bag becomes one array element.
 *
 * <p>For instance (non-static) methods the first parameter spec is the receiver's type: field 0
 * of the input tuple is the object the method is invoked on, and the remaining fields are passed
 * as the method's arguments.
 *
 * @param <T> the result type the caller expects from the invoked method
 */
public class Invoker<T> {

    private static final Log LOG = LogFactory.getLog(Invoker.class);

    private static final Class<?> DOUBLE_ARRAY_CLASS = new double[0].getClass();
    private static final Class<?> INT_ARRAY_CLASS = new int[0].getClass();
    private static final Class<?> FLOAT_ARRAY_CLASS = new float[0].getClass();
    private static final Class<?> STRING_ARRAY_CLASS = new String[0].getClass();
    private static final Class<?> LONG_ARRAY_CLASS = new long[0].getClass();

    /** Parameter types whose values must be converted from a bag before invocation. */
    @SuppressWarnings("unchecked")
    private static final Set<Class<?>> ARRAY_CLASSES = Sets.newHashSet(
            DOUBLE_ARRAY_CLASS, INT_ARRAY_CLASS, FLOAT_ARRAY_CLASS, STRING_ARRAY_CLASS,
            LONG_ARRAY_CLASS);

    /** The resolved target method. */
    private final Method method_;
    /** Declared parameter types; for instance methods element 0 is the receiver type. */
    private final Class<?>[] paramClasses_;
    /** Whether the target method is invoked without a receiver. */
    private final boolean isStatic_;
    /** Receiver type for instance invocations; {@code null} for static invocations. */
    private final Class<?> selfClass_;

    /**
     * Creates an invoker for a static method.
     *
     * @param fullName fully qualified method name, {@code package.Class.method}
     * @param paramSpecsStr whitespace-separated parameter type specs (may be empty)
     * @throws FrontendException if the name or a parameter spec cannot be understood, or if
     *         class resolution fails inside {@link PigContext}
     * @throws NoSuchMethodException if no public method matches the name and parameter types
     */
    public Invoker(String fullName, String paramSpecsStr)
            throws ClassNotFoundException, FrontendException, SecurityException, NoSuchMethodException {
        this(fullName, paramSpecsStr, "true");
    }

    /**
     * Creates an invoker for a static or instance method.
     *
     * @param fullName fully qualified method name, {@code package.Class.method}
     * @param paramSpecsStr whitespace-separated parameter type specs (may be empty for static
     *        methods); for instance methods the first spec is the receiver type
     * @param isStatic {@code "true"} or {@code "static"} (the latter case-insensitive) selects a
     *        static invocation; any other value selects an instance invocation
     * @throws FrontendException if the name or a parameter spec cannot be understood, if an
     *         instance invocation has no receiver spec, or if class resolution fails inside
     *         {@link PigContext}
     * @throws NoSuchMethodException if no public method matches the name and parameter types
     */
    public Invoker(String fullName, String paramSpecsStr, String isStatic)
            throws ClassNotFoundException, FrontendException, SecurityException, NoSuchMethodException {
        int lastDot = fullName.lastIndexOf('.');
        if (lastDot < 0) {
            throw new FrontendException("expected a fully qualified method name of the form "
                    + "'package.Class.method' but got: " + fullName);
        }
        String className = fullName.substring(0, lastDot);
        String methodName = fullName.substring(lastDot + 1);

        Class<?> klazz;
        try {
            klazz = PigContext.resolveClassName(className);
        } catch (IOException e) {
            // PigContext reports resolution failures as an IOException wrapping one of the
            // exceptions this constructor already declares; surface it as a FrontendException.
            throw new FrontendException(e);
        }

        // Tolerate null, surrounding whitespace and repeated separators in the spec list.
        String trimmedSpecs = paramSpecsStr == null ? "" : paramSpecsStr.trim();
        String[] paramSpecs = trimmedSpecs.length() == 0
                ? new String[0] : trimmedSpecs.split("\\s+");

        isStatic_ = "static".equalsIgnoreCase(isStatic) || "true".equals(isStatic);
        paramClasses_ = new Class<?>[paramSpecs.length];
        for (int i = 0; i < paramSpecs.length; i++) {
            paramClasses_[i] = stringToClass(paramSpecs[i]);
        }

        if (isStatic_) {
            selfClass_ = null;
        } else {
            if (paramClasses_.length == 0) {
                throw new FrontendException("non-static invocation of " + fullName
                        + " requires the receiver type as the first parameter spec");
            }
            selfClass_ = paramClasses_[0];
        }

        // For instance methods the receiver is not part of the Java method signature.
        Class<?>[] lookupTypes = isStatic_ ? paramClasses_ : dropFirstClass(paramClasses_);
        method_ = klazz.getMethod(methodName, lookupTypes);
    }

    /**
     * Returns the target method's return type, with primitives replaced by their wrapper classes
     * (e.g. {@code int} becomes {@link Integer}).
     *
     * <p>Generic return types such as {@code List<String>} are reported as their erased class
     * ({@code List}), so the returned value is always a {@link Class}.
     */
    public Type getReturnType() {
        return unPrimitivize(method_.getReturnType());
    }

    /** Returns a copy of {@code original} without its first element (empty if none remain). */
    private static Class<?>[] dropFirstClass(Class<?>[] original) {
        if (original.length <= 1) {
            return new Class<?>[0];
        }
        // copyOfRange's upper bound is exclusive, so original.length keeps the last element.
        return Arrays.copyOfRange(original, 1, original.length);
    }

    /** Returns a copy of {@code original} without its first element (empty if none remain). */
    private static Object[] dropFirstObject(Object[] original) {
        if (original.length <= 1) {
            return new Object[0];
        }
        return Arrays.copyOfRange(original, 1, original.length);
    }

    /**
     * Maps a parameter spec such as {@code "int"} or {@code "double[]"} (case-insensitive) to
     * the Java type used for method lookup.
     *
     * @throws FrontendException if the spec is not recognised
     */
    private static Class<?> stringToClass(String klass) throws FrontendException {
        if ("string".equalsIgnoreCase(klass)) {
            return String.class;
        } else if ("int".equalsIgnoreCase(klass)) {
            return Integer.TYPE;
        } else if ("double".equalsIgnoreCase(klass)) {
            return Double.TYPE;
        } else if ("float".equalsIgnoreCase(klass)) {
            return Float.TYPE;
        } else if ("long".equalsIgnoreCase(klass)) {
            return Long.TYPE;
        } else if ("double[]".equalsIgnoreCase(klass)) {
            return DOUBLE_ARRAY_CLASS;
        } else if ("int[]".equalsIgnoreCase(klass)) {
            return INT_ARRAY_CLASS;
        } else if ("long[]".equalsIgnoreCase(klass)) {
            return LONG_ARRAY_CLASS;
        } else if ("float[]".equalsIgnoreCase(klass)) {
            return FLOAT_ARRAY_CLASS;
        } else if ("string[]".equalsIgnoreCase(klass)) {
            return STRING_ARRAY_CLASS;
        } else {
            throw new FrontendException("unable to find matching class for " + klass);
        }
    }

    /**
     * Returns the wrapper class for {@code int}, {@code long}, {@code float} and
     * {@code double}; any other class is returned unchanged.
     */
    private static Class<?> unPrimitivize(Class<?> klass) {
        if (klass.equals(Integer.TYPE)) {
            return Integer.class;
        } else if (klass.equals(Long.TYPE)) {
            return Long.class;
        } else if (klass.equals(Float.TYPE)) {
            return Float.class;
        } else if (klass.equals(Double.TYPE)) {
            return Double.class;
        } else {
            return klass;
        }
    }

    /**
     * Converts a tuple field to the type expected by the target method.
     *
     * <p>Array types are built from a bag of single-field tuples; {@code null} passes through
     * unchanged for every type. Other values are cast directly.
     *
     * @throws ExecException if the value cannot be converted or cast to {@code klass}
     */
    private static <T> T convertToExpectedArg(Class<T> klass, Object obj) throws ExecException {
        if (obj != null && ARRAY_CLASSES.contains(klass)) {
            obj = bagToArray(klass, obj);
        }
        try {
            return klass.cast(obj);
        } catch (ClassCastException e) {
            // obj is non-null here: Class.cast(null) never throws.
            LOG.error("Error in dynamic argument processing. Casting to: "
                    + klass + " from: " + obj.getClass(), e);
            throw new ExecException(e);
        }
    }

    /**
     * Builds an array of type {@code klass} (one of {@link #ARRAY_CLASSES}) from a bag whose
     * tuples each hold one element in field 0.
     *
     * @throws ExecException if {@code obj} is not a {@link DataBag} or a tuple cannot be read
     */
    private static Object bagToArray(Class<?> klass, Object obj) throws ExecException {
        if (!(obj instanceof DataBag)) {
            throw new ExecException("expected a bag for array argument of type "
                    + klass.getSimpleName() + " but got " + obj.getClass().getName());
        }
        DataBag dbag = (DataBag) obj;

        if (STRING_ARRAY_CLASS.equals(klass)) {
            List<String> dataList = Lists.newArrayList();
            for (Tuple t : dbag) {
                dataList.add((String) t.get(0));
            }
            return dataList.toArray(new String[dataList.size()]);
        }

        List<Number> dataList = bagToNumberList(dbag);
        int size = dataList.size();
        if (DOUBLE_ARRAY_CLASS.equals(klass)) {
            double[] dataArray = new double[size];
            for (int i = 0; i < size; i++) {
                dataArray[i] = dataList.get(i).doubleValue();
            }
            return dataArray;
        } else if (INT_ARRAY_CLASS.equals(klass)) {
            int[] dataArray = new int[size];
            for (int i = 0; i < size; i++) {
                dataArray[i] = dataList.get(i).intValue();
            }
            return dataArray;
        } else if (FLOAT_ARRAY_CLASS.equals(klass)) {
            float[] dataArray = new float[size];
            for (int i = 0; i < size; i++) {
                dataArray[i] = dataList.get(i).floatValue();
            }
            return dataArray;
        } else if (LONG_ARRAY_CLASS.equals(klass)) {
            long[] dataArray = new long[size];
            for (int i = 0; i < size; i++) {
                dataArray[i] = dataList.get(i).longValue();
            }
            return dataArray;
        }
        // Not reached for members of ARRAY_CLASSES; the caller's cast reports any mismatch.
        return dbag;
    }

    /** Collects field 0 of every tuple in {@code dbag} as a {@link Number}. */
    private static List<Number> bagToNumberList(DataBag dbag) throws ExecException {
        List<Number> dataList = Lists.newArrayList();
        for (Tuple t : dbag) {
            dataList.add((Number) t.get(0));
        }
        return dataList;
    }

    /**
     * Converts the input tuple to the argument array for {@link Method#invoke}, including the
     * receiver as element 0 for instance methods.
     *
     * <p>A {@code null} tuple is accepted only when no parameters are declared.
     *
     * @throws ExecException if the tuple has fewer fields than declared parameters, or a field
     *         cannot be converted
     */
    private Object[] tupleToArgs(Tuple t) throws ExecException {
        if (t == null) {
            if (paramClasses_.length != 0) {
                throw new ExecException("unable to match function arguments to declared signature.");
            }
            return new Object[0];
        }
        if (t.size() < paramClasses_.length) {
            throw new ExecException("unable to match function arguments to declared signature.");
        }
        Object[] args = new Object[paramClasses_.length];
        for (int i = 0; i < paramClasses_.length; i++) {
            args[i] = convertToExpectedArg(unPrimitivize(paramClasses_[i]), t.get(i));
        }
        return args;
    }

    /**
     * Invokes the target method with arguments taken from {@code input}.
     *
     * @param input tuple holding the receiver (instance methods only) followed by the arguments
     * @return the method's result, cast unchecked to {@code T}
     * @throws ExecException if argument conversion fails or reflection reports an error,
     *         including an exception thrown by the target method itself
     */
    @SuppressWarnings("unchecked")
    public T invoke(Tuple input) throws IOException {
        Object[] args = tupleToArgs(input);
        try {
            if (isStatic_) {
                return (T) method_.invoke(null, args);
            }
            return (T) method_.invoke(selfClass_.cast(args[0]), dropFirstObject(args));
        } catch (IllegalArgumentException e) {
            throw new ExecException(e);
        } catch (IllegalAccessException e) {
            throw new ExecException(e);
        } catch (InvocationTargetException e) {
            throw new ExecException(e);
        }
    }
}