blob: 8bccdc4324c1a397cc7fa1966559efa286a652c9 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tajo.plan.function.python;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.tajo.TajoConstants;
import org.apache.tajo.catalog.*;
import org.apache.tajo.catalog.proto.CatalogProtos;
import org.apache.tajo.common.TajoDataTypes;
import org.apache.tajo.conf.TajoConf;
import org.apache.tajo.datum.Datum;
import org.apache.tajo.function.FunctionInvocation;
import org.apache.tajo.function.FunctionSignature;
import org.apache.tajo.function.FunctionSupplement;
import org.apache.tajo.function.PythonInvocationDesc;
import org.apache.tajo.plan.function.FunctionContext;
import org.apache.tajo.plan.function.PythonAggFunctionInvoke.PythonAggFunctionContext;
import org.apache.tajo.plan.function.stream.*;
import org.apache.tajo.storage.Tuple;
import org.apache.tajo.storage.VTuple;
import org.apache.tajo.unit.StorageUnit;
import org.apache.tajo.util.FileUtil;
import java.io.*;
import java.net.URI;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.regex.Pattern;
/**
* {@link PythonScriptEngine} is responsible for registering python functions and maintaining the controller process.
* The controller is a python process that executes the python UDFs.
* (Please refer to 'tajo-core/src/main/resources/python/controller.py')
* Data are exchanged via standard I/O between PythonScriptEngine and the controller.
*/
public class PythonScriptEngine extends TajoScriptEngine {
private static final Log LOG = LogFactory.getLog(PythonScriptEngine.class);
public static final String FILE_EXTENSION = ".py";
/**
* Register functions defined in a python script
*
* @param path path to the python script file
* @param namespace namespace where the functions will be defined
* @return set of function descriptions
* @throws IOException
*/
public static Set<FunctionDesc> registerFunctions(URI path, String namespace) throws IOException {
// TODO: we should support the namespace for python functions.
Set<FunctionDesc> functionDescs = new HashSet<>();
List<FunctionInfo> functions = null;
try (InputStream in = getScriptAsStream(path)) {
functions = getFunctions(in);
}
for(FunctionInfo funcInfo : functions) {
try {
FunctionSignature signature;
FunctionInvocation invocation = new FunctionInvocation();
FunctionSupplement supplement = new FunctionSupplement();
if (funcInfo instanceof ScalarFuncInfo) {
ScalarFuncInfo scalarFuncInfo = (ScalarFuncInfo) funcInfo;
TajoDataTypes.DataType returnType = getReturnTypes(scalarFuncInfo)[0];
signature = new FunctionSignature(CatalogProtos.FunctionType.UDF, scalarFuncInfo.funcName,
returnType, createParamTypes(scalarFuncInfo.paramNum));
PythonInvocationDesc invocationDesc = new PythonInvocationDesc(scalarFuncInfo.funcName, path.getPath(), true);
invocation.setPython(invocationDesc);
functionDescs.add(new FunctionDesc(signature, invocation, supplement));
} else {
AggFuncInfo aggFuncInfo = (AggFuncInfo) funcInfo;
if (isValidUdaf(aggFuncInfo)) {
TajoDataTypes.DataType returnType = getReturnTypes(aggFuncInfo.getFinalResultInfo)[0];
signature = new FunctionSignature(CatalogProtos.FunctionType.UDA, aggFuncInfo.funcName,
returnType, createParamTypes(aggFuncInfo.evalInfo.paramNum));
PythonInvocationDesc invocationDesc = new PythonInvocationDesc(aggFuncInfo.className, path.getPath(), false);
invocation.setPythonAggregation(invocationDesc);
functionDescs.add(new FunctionDesc(signature, invocation, supplement));
}
}
} catch (Throwable t) {
// ignore invalid functions
LOG.warn(t);
}
}
return functionDescs;
}
private static boolean isValidUdaf(AggFuncInfo aggFuncInfo) {
if (aggFuncInfo.className != null && aggFuncInfo.evalInfo != null && aggFuncInfo.mergeInfo != null
&& aggFuncInfo.getPartialResultInfo != null && aggFuncInfo.getFinalResultInfo != null) {
return true;
}
return false;
}
private static TajoDataTypes.DataType[] createParamTypes(int paramNum) {
TajoDataTypes.DataType[] paramTypes = new TajoDataTypes.DataType[paramNum];
for (int i = 0; i < paramNum; i++) {
paramTypes[i] = TajoDataTypes.DataType.newBuilder().setType(TajoDataTypes.Type.ANY).build();
}
return paramTypes;
}
private static TajoDataTypes.DataType[] getReturnTypes(ScalarFuncInfo scalarFuncInfo) {
TajoDataTypes.DataType[] returnTypes = new TajoDataTypes.DataType[scalarFuncInfo.returnTypes.length];
for (int i = 0; i < scalarFuncInfo.returnTypes.length; i++) {
returnTypes[i] = CatalogUtil.newSimpleDataType(TajoDataTypes.Type.valueOf(scalarFuncInfo.returnTypes[i]));
}
return returnTypes;
}
private static final Pattern pSchema = Pattern.compile("^\\s*\\W+output_type.*");
private static final Pattern pDef = Pattern.compile("^\\s*def\\s+(\\w+)\\s*.+");
private static final Pattern pClass = Pattern.compile("class.*");
private interface FunctionInfo {
}
private static class AggFuncInfo implements FunctionInfo {
String className;
String funcName;
ScalarFuncInfo evalInfo;
ScalarFuncInfo mergeInfo;
ScalarFuncInfo getPartialResultInfo;
ScalarFuncInfo getFinalResultInfo;
}
private static class ScalarFuncInfo implements FunctionInfo {
String[] returnTypes;
String funcName;
int paramNum;
public ScalarFuncInfo(String[] quotedSchemaStrings, String funcName, int paramNum) {
this.returnTypes = new String[quotedSchemaStrings.length];
for (int i = 0; i < quotedSchemaStrings.length; i++) {
String quotedString = quotedSchemaStrings[i].trim();
String[] tokens = quotedString.substring(1, quotedString.length()-1).split(":");
this.returnTypes[i] = tokens.length == 1 ? tokens[0].toUpperCase() : tokens[1].toUpperCase();
}
this.funcName = funcName;
this.paramNum = paramNum;
}
}
// TODO: python parser must be improved.
private static List<FunctionInfo> getFunctions(InputStream is) throws IOException {
List<FunctionInfo> functions = new ArrayList<>();
InputStreamReader in = new InputStreamReader(is, Charset.defaultCharset());
BufferedReader br = new BufferedReader(in);
String line = br.readLine();
String[] quotedSchemaStrings = null;
AggFuncInfo aggFuncInfo = null;
while (line != null) {
try {
if (pSchema.matcher(line).matches()) {
int start = line.indexOf("(") + 1; //drop brackets
int end = line.lastIndexOf(")");
quotedSchemaStrings = line.substring(start, end).trim().split(",");
} else if (pDef.matcher(line).matches()) {
boolean isUdaf = aggFuncInfo != null && line.indexOf("def ") > 0;
int nameStart = line.indexOf("def ") + "def ".length();
int nameEnd = line.indexOf('(');
int signatureEnd = line.indexOf(')');
String[] params = line.substring(nameEnd + 1, signatureEnd).split(",");
int paramNum;
if (params.length == 1) {
paramNum = params[0].equals("") ? 0 : 1;
} else {
paramNum = params.length;
}
if (params[0].trim().equals("self")) {
paramNum--;
}
String functionName = line.substring(nameStart, nameEnd).trim();
quotedSchemaStrings = quotedSchemaStrings == null ? new String[]{"'blob'"} : quotedSchemaStrings;
if (isUdaf) {
if (functionName.equals("eval")) {
aggFuncInfo.evalInfo = new ScalarFuncInfo(quotedSchemaStrings, functionName, paramNum);
} else if (functionName.equals("merge")) {
aggFuncInfo.mergeInfo = new ScalarFuncInfo(quotedSchemaStrings, functionName, paramNum);
} else if (functionName.equals("get_partial_result")) {
aggFuncInfo.getPartialResultInfo = new ScalarFuncInfo(quotedSchemaStrings, functionName, paramNum);
} else if (functionName.equals("get_final_result")) {
aggFuncInfo.getFinalResultInfo = new ScalarFuncInfo(quotedSchemaStrings, functionName, paramNum);
}
} else {
aggFuncInfo = null;
functions.add(new ScalarFuncInfo(quotedSchemaStrings, functionName, paramNum));
}
quotedSchemaStrings = null;
} else if (pClass.matcher(line).matches()) {
// UDAF
if (aggFuncInfo != null) {
functions.add(aggFuncInfo);
}
aggFuncInfo = new AggFuncInfo();
int classNameStart = line.indexOf("class ") + "class ".length();
int classNameEnd = line.indexOf("(");
if (classNameEnd < 0) {
classNameEnd = line.indexOf(":");
}
aggFuncInfo.className = line.substring(classNameStart, classNameEnd).trim();
aggFuncInfo.funcName = aggFuncInfo.className.toLowerCase();
}
} catch (Throwable t) {
// ignore unexpected function and source lines
LOG.warn(t);
}
line = br.readLine();
}
if (aggFuncInfo != null) {
functions.add(aggFuncInfo);
}
br.close();
in.close();
return functions;
}
private static final String PYTHON_LANGUAGE = "python";
private static final String TAJO_UTIL_NAME = "tajo_util.py";
private static final String CONTROLLER_NAME = "controller.py";
private static final String BASE_DIR = FileUtils.getTempDirectoryPath() + File.separator + "tajo-" + System.getProperty("user.name") + File.separator + "python";
private static final String PYTHON_CONTROLLER_JAR_PATH = "/python/" + CONTROLLER_NAME; // Relative to root of tajo jar.
private static final String PYTHON_TAJO_UTIL_JAR_PATH = "/python/" + TAJO_UTIL_NAME; // Relative to root of tajo jar.
// Indexes for arguments being passed to external process
enum COMMAND_IDX {
UDF_LANGUAGE,
PATH_TO_CONTROLLER_FILE,
UDF_FILE_NAME,
UDF_FILE_PATH,
PATH_TO_FILE_CACHE,
OUT_SCHEMA,
FUNCTION_OR_CLASS_NAME,
FUNCTION_TYPE,
}
private Configuration systemConf;
private Process process; // Handle to the external execution of python functions
private InputHandler inputHandler;
private OutputHandler outputHandler;
private DataOutputStream stdin; // stdin of the process
private InputStream stdout; // stdout of the process
private InputStream stderr; // stderr of the process
private final FunctionSignature functionSignature;
private final PythonInvocationDesc invocationDesc;
private Schema inSchema;
private Schema outSchema;
private int[] projectionCols;
private final CSVLineSerDe lineSerDe = new CSVLineSerDe();
private final TableMeta pipeMeta = CatalogUtil.newTableMeta("TEXT");
private final Tuple EMPTY_INPUT = new VTuple(0);
private final Schema EMPTY_SCHEMA = SchemaFactory.newV1();
public PythonScriptEngine(FunctionDesc functionDesc) {
if (!functionDesc.getInvocation().hasPython() && !functionDesc.getInvocation().hasPythonAggregation()) {
throw new IllegalStateException("Function type must be 'python'");
}
functionSignature = functionDesc.getSignature();
invocationDesc = functionDesc.getInvocation().getPython();
setSchema();
}
public PythonScriptEngine(FunctionDesc functionDesc, boolean firstPhase, boolean lastPhase) {
if (!functionDesc.getInvocation().hasPython() && !functionDesc.getInvocation().hasPythonAggregation()) {
throw new IllegalStateException("Function type must be 'python'");
}
functionSignature = functionDesc.getSignature();
invocationDesc = functionDesc.getInvocation().getPython();
this.firstPhase = firstPhase;
this.lastPhase = lastPhase;
setSchema();
}
@Override
public void start(Configuration systemConf) throws IOException {
this.systemConf = systemConf;
startUdfController();
setStreams();
createInputHandlers();
if (LOG.isDebugEnabled()) {
LOG.debug("PythonScriptExecutor starts up");
}
}
@Override
public void shutdown() {
FileUtil.cleanup(LOG, stdin, stdout, stderr, inputHandler, outputHandler);
stdin = null;
stdout = stderr = null;
inputHandler = null;
outputHandler = null;
try {
int exitCode = process.waitFor();
if (systemConf.get(TajoConstants.TEST_KEY, Boolean.FALSE.toString()).equalsIgnoreCase(Boolean.TRUE.toString())) {
LOG.warn("Process exit code: " + exitCode);
} else {
if (LOG.isDebugEnabled()) {
LOG.debug("Process exit code: " + exitCode);
}
}
} catch (InterruptedException e) {
LOG.warn(e.getMessage(), e);
}
if (LOG.isDebugEnabled()) {
LOG.debug("PythonScriptExecutor shuts down");
}
}
private void startUdfController() throws IOException {
ProcessBuilder processBuilder = StreamingUtil.createProcess(buildCommand());
process = processBuilder.start();
}
/**
* Build a command to execute an external process.
* @return
* @throws IOException
*/
private String[] buildCommand() throws IOException {
String[] command = new String[8];
String funcName = invocationDesc.getName();
String filePath = invocationDesc.getPath();
command[COMMAND_IDX.UDF_LANGUAGE.ordinal()] = PYTHON_LANGUAGE;
command[COMMAND_IDX.PATH_TO_CONTROLLER_FILE.ordinal()] = getControllerPath();
int lastSeparator = filePath.lastIndexOf(File.separator) + 1;
String fileName = filePath.substring(lastSeparator);
fileName = fileName.endsWith(FILE_EXTENSION) ? fileName.substring(0, fileName.length()-3) : fileName;
command[COMMAND_IDX.UDF_FILE_NAME.ordinal()] = fileName;
command[COMMAND_IDX.UDF_FILE_PATH.ordinal()] = lastSeparator <= 0 ? "." : filePath.substring(0, lastSeparator - 1);
String fileCachePath = systemConf.get(TajoConf.ConfVars.PYTHON_CODE_DIR.keyname());
if (fileCachePath == null) {
throw new IOException(TajoConf.ConfVars.PYTHON_CODE_DIR.keyname() + " must be set.");
}
command[COMMAND_IDX.PATH_TO_FILE_CACHE.ordinal()] = "'" + fileCachePath + "'";
command[COMMAND_IDX.OUT_SCHEMA.ordinal()] = outSchema.getColumn(0).getDataType().getType().name().toLowerCase();
command[COMMAND_IDX.FUNCTION_OR_CLASS_NAME.ordinal()] = funcName;
command[COMMAND_IDX.FUNCTION_TYPE.ordinal()] = invocationDesc.isScalarFunction() ? "UDF" : "UDAF";
return command;
}
private void setSchema() {
if (invocationDesc.isScalarFunction()) {
TajoDataTypes.DataType[] paramTypes = functionSignature.getParamTypes();
inSchema = SchemaFactory.newV1();
for (int i = 0; i < paramTypes.length; i++) {
inSchema.addColumn(new Column("in_" + i, paramTypes[i]));
}
outSchema = SchemaFactory.newV1(new Column[]{new Column("out", functionSignature.getReturnType())});
} else {
// UDAF
if (firstPhase) {
// first phase
TajoDataTypes.DataType[] paramTypes = functionSignature.getParamTypes();
inSchema = SchemaFactory.newV1();
for (int i = 0; i < paramTypes.length; i++) {
inSchema.addColumn(new Column("in_" + i, paramTypes[i]));
}
outSchema = SchemaFactory.newV1(new Column[]{new Column("json", TajoDataTypes.Type.TEXT)});
} else if (lastPhase) {
inSchema = SchemaFactory.newV1(new Column[]{new Column("json", TajoDataTypes.Type.TEXT)});
outSchema = SchemaFactory.newV1(new Column[]{new Column("out", functionSignature.getReturnType())});
} else {
// intermediate phase
inSchema = outSchema = SchemaFactory.newV1(new Column[]{new Column("json", TajoDataTypes.Type.TEXT)});
}
}
projectionCols = new int[outSchema.size()];
for (int i = 0; i < outSchema.size(); i++) {
projectionCols[i] = i;
}
}
private void createInputHandlers() throws IOException {
setSchema();
TextLineSerializer serializer = lineSerDe.createSerializer(pipeMeta);
serializer.init();
this.inputHandler = new InputHandler(serializer);
inputHandler.bindTo(stdin);
TextLineDeserializer deserializer = lineSerDe.createDeserializer(outSchema, pipeMeta, projectionCols);
deserializer.init();
this.outputHandler = new OutputHandler(deserializer);
outputHandler.bindTo(stdout);
}
/**
* Get the standard input, output, and error streams of the external process
*
* @throws IOException
*/
private void setStreams() {
stdout = new DataInputStream(new BufferedInputStream(process.getInputStream()));
stdin = new DataOutputStream(new BufferedOutputStream(process.getOutputStream()));
stderr = new DataInputStream(new BufferedInputStream(process.getErrorStream()));
}
private static final File pythonScriptBaseDir = new File(PythonScriptEngine.getBaseDirPath());
private static final File pythonScriptControllerCopy = new File(PythonScriptEngine.getControllerPath());
private static final File pythonScriptUtilCopy = new File(PythonScriptEngine.getTajoUtilPath());
public static void initPythonScriptEngineFiles() throws IOException {
if (!pythonScriptBaseDir.exists()) {
pythonScriptBaseDir.mkdirs();
}
// Controller and util should be always overwritten.
PythonScriptEngine.loadController(pythonScriptControllerCopy);
PythonScriptEngine.loadTajoUtil(pythonScriptUtilCopy);
}
public static void loadController(File controllerCopy) throws IOException {
try (InputStream controllerInputStream = PythonScriptEngine.class.getResourceAsStream(PYTHON_CONTROLLER_JAR_PATH)) {
FileUtils.copyInputStreamToFile(controllerInputStream, controllerCopy);
}
}
public static void loadTajoUtil(File utilCopy) throws IOException {
try (InputStream utilInputStream = PythonScriptEngine.class.getResourceAsStream(PYTHON_TAJO_UTIL_JAR_PATH)) {
FileUtils.copyInputStreamToFile(utilInputStream, utilCopy);
}
}
public static String getBaseDirPath() {
LOG.info("Python base dir is " + BASE_DIR);
return BASE_DIR;
}
/**
* Find the path to the controller file for the streaming language.
*
* @return
* @throws IOException
*/
public static String getControllerPath() {
return BASE_DIR + File.separator + CONTROLLER_NAME;
}
public static String getTajoUtilPath() {
return BASE_DIR + File.separator + TAJO_UTIL_NAME;
}
/**
* Call Python scalar functions.
*
* @param input input tuple
* @return evaluated result datum
*/
@Override
public Datum callScalarFunc(Tuple input) {
try {
inputHandler.putNext(input, inSchema);
stdin.flush();
} catch (Throwable e) {
throwException(stderr, new RuntimeException("Failed adding input to inputQueue", e));
}
Datum result = null;
try {
Tuple next = outputHandler.getNext();
if (next != null) {
result = next.asDatum(0);
} else {
throw new RuntimeException("Cannot get output result from python controller");
}
} catch (Throwable e) {
throwException(stderr, new RuntimeException("Problem getting output: " + e.getMessage(), e));
}
return result;
}
/**
* Call Python aggregation functions.
*
* @param functionContext python function context
* @param input input tuple
*/
@Override
public void callAggFunc(FunctionContext functionContext, Tuple input) {
String methodName;
if (firstPhase) {
// eval
methodName = "eval";
} else {
// merge
methodName = "merge";
}
try {
inputHandler.putNext(methodName, input, inSchema);
stdin.flush();
} catch (Throwable e) {
throwException(stderr, new RuntimeException("Failed adding input to inputQueue while executing "
+ methodName + " with " + input, e));
}
try {
outputHandler.getNext();
} catch (Exception e) {
throwException(stderr, new RuntimeException("Problem getting output: " + e.getMessage(), e));
}
}
/**
* Get the standard error streams of the external process and throw the exception
*
* @throws RuntimeException
*/
private void throwException(InputStream stderr, RuntimeException e) throws RuntimeException {
try {
if (stderr.available() > 0) {
byte[] bytes = new byte[Math.min(stderr.available(), 100 * StorageUnit.KB)];
IOUtils.readFully(stderr, bytes);
String message = new String(bytes, Charset.defaultCharset());
throw new RuntimeException("Python exception caused by: " + message, e);
} else {
throw e;
}
} catch (IOException ioe) {
throw new RuntimeException(ioe.getMessage(), ioe);
}
}
/**
* Restore the intermediate result in Python UDAF with the snapshot stored in the function context.
*
* @param functionContext
*/
public void updatePythonSideContext(PythonAggFunctionContext functionContext) throws IOException {
try {
inputHandler.putNext("update_context", functionContext);
stdin.flush();
} catch (Throwable e) {
throwException(stderr, new RuntimeException("Failed adding input to inputQueue", e));
}
try {
outputHandler.getNext();
} catch (Throwable e) {
throwException(stderr, new RuntimeException("Problem getting output: " + e.getMessage(), e));
}
}
/**
* Get the snapshot of the intermediate result in the Python UDAF.
*
* @param functionContext
*/
public void updateJavaSideContext(PythonAggFunctionContext functionContext) throws IOException {
try {
inputHandler.putNext("get_context", EMPTY_INPUT, EMPTY_SCHEMA);
stdin.flush();
} catch (Throwable e) {
throwException(stderr, new RuntimeException("Failed adding input to inputQueue", e));
}
try {
outputHandler.getNext(functionContext);
} catch (Throwable e) {
throwException(stderr, new RuntimeException("Problem getting output: " + e.getMessage(), e));
}
}
/**
* Get intermediate result after the first stage.
*
* @param functionContext
* @return
*/
@Override
public String getPartialResult(FunctionContext functionContext) {
try {
inputHandler.putNext("get_partial_result", EMPTY_INPUT, EMPTY_SCHEMA);
stdin.flush();
} catch (Throwable e) {
throwException(stderr, new RuntimeException("Failed adding input to inputQueue", e));
}
String result = null;
try {
result = outputHandler.getPartialResultString();
} catch (Throwable e) {
throwException(stderr, new RuntimeException("Problem getting output: " + e.getMessage(), e));
}
return result;
}
/**
* Get final result after the last stage.
*
* @param functionContext
* @return
*/
@Override
public Datum getFinalResult(FunctionContext functionContext) {
try {
inputHandler.putNext("get_final_result", EMPTY_INPUT, EMPTY_SCHEMA);
stdin.flush();
} catch (Exception e) {
throwException(stderr, new RuntimeException("Failed adding input to inputQueue", e));
}
Datum result = null;
try {
Tuple next = outputHandler.getNext();
if (next != null) {
result = next.asDatum(0);
} else {
throw new RuntimeException("Cannot get output result from python controller");
}
} catch (Exception e) {
throwException(stderr, new RuntimeException("Problem getting output: " + e.getMessage(), e));
}
return result;
}
}