blob: ad24ced2355f5b899eb07a3261e56c7a35f525e1 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.runtime.instructions.cp;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.DMLScriptException;
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.lineage.Lineage;
import org.apache.sysds.runtime.lineage.LineageCache;
import org.apache.sysds.runtime.lineage.LineageCacheConfig;
import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
import org.apache.sysds.runtime.lineage.LineageCacheStatistics;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.utils.Statistics;
/**
 * CP instruction that invokes a DML-bodied function ({@code fcall}).
 * <p>
 * The instruction binds the caller's input operands to the formal parameters of the
 * target {@link FunctionProgramBlock}, executes that block in a fresh execution context,
 * and binds the declared function outputs back to the caller's output variable names.
 * When lineage-based multi-level reuse is enabled and the function is deterministic,
 * outputs may be served from, and are afterwards stored into, the lineage cache.
 */
public class FunctionCallCPInstruction extends CPInstruction {
	private static final Log LOG = LogFactory.getLog(FunctionCallCPInstruction.class.getName());
	/** Name of the called function. */
	private final String _functionName;
	/** Namespace of the called function. */
	private final String _namespace;
	/** Flag passed through to {@code getFunctionProgramBlock}; presumably selects an optimized variant — confirm. */
	private final boolean _opt;
	/** Caller-side input operands, positionally aligned with {@link #_funArgNames}. */
	private final CPOperand[] _boundInputs;
	/** Names of {@link #_boundInputs}, cached for pinning/unpinning. */
	private final List<String> _boundInputNames;
	/** Formal argument names the inputs are bound to (same order as the inputs). */
	private final List<String> _funArgNames;
	/** Caller-side variable names receiving the function outputs (positional). */
	private final List<String> _boundOutputNames;

	/**
	 * Creates a function call instruction.
	 *
	 * @param namespace namespace of the called function
	 * @param functName name of the called function
	 * @param opt flag passed through to {@code getFunctionProgramBlock} on execution
	 * @param boundInputs caller-side input operands, positionally aligned with {@code funArgNames}
	 * @param funArgNames formal argument names the inputs are bound to
	 * @param boundOutputNames caller-side variable names receiving the function outputs
	 * @param istr the original instruction string
	 */
	public FunctionCallCPInstruction(String namespace, String functName, boolean opt,
		CPOperand[] boundInputs, List<String> funArgNames, List<String> boundOutputNames, String istr) {
		super(CPType.FCall, null, functName, istr);
		_functionName = functName;
		_namespace = namespace;
		_opt = opt;
		_boundInputs = boundInputs;
		_boundInputNames = Arrays.stream(boundInputs).map(CPOperand::getName)
			.collect(Collectors.toCollection(ArrayList::new));
		_funArgNames = funArgNames;
		_boundOutputNames = boundOutputNames;
	}

	public String getFunctionName() {
		return _functionName;
	}

	public String getNamespace() {
		return _namespace;
	}

	/**
	 * Parses an {@code fcall} instruction string.
	 * <p>
	 * Schema: fcall, fnamespace, fname, opt, num inputs, num outputs,
	 * inputs (name=value pairs), outputs.
	 *
	 * @param str the instruction string
	 * @return the parsed instruction
	 */
	public static FunctionCallCPInstruction parseInstruction(String str) {
		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
		String namespace = parts[1];
		String functionName = parts[2];
		boolean opt = Boolean.parseBoolean(parts[3]);
		int numInputs = Integer.parseInt(parts[4]);
		int numOutputs = Integer.parseInt(parts[5]);
		CPOperand[] boundInputs = new CPOperand[numInputs];
		List<String> funArgNames = new ArrayList<>();
		List<String> boundOutputNames = new ArrayList<>();
		for (int i = 0; i < numInputs; i++) {
			// each input is "formalName=operand"; split only at the first '='
			String[] nameValue = IOUtilFunctions.splitByFirst(parts[6 + i], "=");
			boundInputs[i] = new CPOperand(nameValue[1]);
			funArgNames.add(nameValue[0]);
		}
		for (int i = 0; i < numOutputs; i++)
			boundOutputNames.add(parts[6 + numInputs + i]);
		return new FunctionCallCPInstruction(namespace, functionName,
			opt, boundInputs, funArgNames, boundOutputNames, str);
	}

	@Override
	public Instruction preprocessInstruction(ExecutionContext ec) {
		//default pre-process behavior
		return super.preprocessInstruction(ec);
	}

	/**
	 * Executes the function call.
	 * <p>
	 * Steps: (1) resolve the function program block, (2) try to reuse all outputs from the
	 * lineage cache, (3) bind inputs to formal parameters (converting scalars to the declared
	 * value type), (4) execute the function in a fresh execution context while the caller's
	 * inputs are pinned, (5) bind the declared outputs back to the caller's variables, and
	 * (6) store the outputs in the lineage cache if multi-level reuse is enabled.
	 *
	 * @param ec the caller's execution context
	 * @throws DMLRuntimeException if the call does not match the function signature, an input
	 *         variable does not exist, an output was not assigned, or function execution fails
	 * @throws DMLScriptException rethrown unchanged if raised by the function body
	 */
	@Override
	public void processInstruction(ExecutionContext ec) {
		if( LOG.isTraceEnabled() ){
			LOG.trace("Executing instruction : " + toString());
		}
		// get the function program block (stored in the Program object)
		FunctionProgramBlock fpb = ec.getProgram().getFunctionProgramBlock(_namespace, _functionName, _opt);

		// sanity check number of function parameters
		if( _boundInputs.length < fpb.getInputParams().size() ) {
			throw new DMLRuntimeException("fcall "+_functionName+": "
				+ "Number of bound input parameters does not match the function signature "
				+ "("+_boundInputs.length+", but "+fpb.getInputParams().size()+" expected)");
		}

		// check if function outputs can be reused from cache
		LineageItem[] liInputs = DMLScript.LINEAGE && LineageCacheConfig.isMultiLevelReuse() ?
			LineageItemUtils.getLineage(ec, _boundInputs) : null;
		if (!fpb.isNondeterministic() && reuseFunctionOutputs(liInputs, fpb, ec))
			return; //only if all the outputs are found in cache

		// create bindings to formal parameters for given function call
		// These are the bindings passed to the FunctionProgramBlock for function execution
		LocalVariableMap functionVariables = new LocalVariableMap();
		Lineage lineage = DMLScript.LINEAGE ? new Lineage() : null;
		for( int i=0; i<_boundInputs.length; i++) {
			//error handling non-existing variables
			CPOperand input = _boundInputs[i];
			if( !input.isLiteral() && !ec.containsVariable(input.getName()) ) {
				throw new DMLRuntimeException("Input variable '"+input.getName()+"' not existing on call of " +
					DMLProgram.constructFunctionKey(_namespace, _functionName) + " (line "+getLineNum()+").");
			}
			//get input matrix/frame/scalar
			String argName = _funArgNames.get(i);
			DataIdentifier currFormalParam = fpb.getInputParam(argName);
			if( currFormalParam == null ) {
				throw new DMLRuntimeException("fcall "+_functionName+": Non-existing named "
					+ "function argument: '"+argName+"' (line "+getLineNum()+").");
			}
			Data value = ec.getVariable(input);
			//graceful value type conversion for scalar inputs with wrong type
			if( value.getDataType() == DataType.SCALAR
				&& value.getValueType() != currFormalParam.getValueType() )
			{
				value = ScalarObjectFactory.createScalarObject(
					currFormalParam.getValueType(), (ScalarObject)value);
			}
			//set input parameter
			functionVariables.put(currFormalParam.getName(), value);
			//map lineage to function arguments
			if( lineage != null ) {
				LineageItem litem = ec.getLineageItem(input);
				lineage.set(currFormalParam.getName(), (litem!=null) ?
					litem : ec.getLineage().getOrCreate(input));
			}
		}

		// Pin the input variables so that they do not get deleted
		// from pb's symbol table at the end of execution of function
		boolean[] pinStatus = ec.pinVariables(_boundInputNames);
		ExecutionContext fn_ec;
		LocalVariableMap retVars;
		long t0, t1;
		try {
			// Create a symbol table under a new execution context for the function invocation,
			// and copy the function arguments into the created table.
			fn_ec = ExecutionContextFactory.createContext(false, false, ec.getProgram());
			if (DMLScript.USE_ACCELERATOR) {
				fn_ec.setGPUContexts(ec.getGPUContexts());
				fn_ec.getGPUContext(0).initializeThread();
			}
			fn_ec.setVariables(functionVariables);
			fn_ec.setLineage(lineage);

			// execute the function block (timed only if lineage reuse is active)
			t0 = !ReuseCacheType.isNone() ? System.nanoTime() : 0;
			try {
				fpb._functionName = this._functionName;
				fpb._namespace = this._namespace;
				fpb.execute(fn_ec);
			}
			catch (DMLScriptException e) {
				throw e;
			}
			catch (Exception e){
				String fname = DMLProgram.constructFunctionKey(_namespace, _functionName);
				throw new DMLRuntimeException("error executing function " + fname, e);
			}
			t1 = !ReuseCacheType.isNone() ? System.nanoTime() : 0;

			// cleanup all returned variables w/o binding; this must run while the inputs
			// are still pinned, because unbound formal parameters can reference the
			// caller's input data objects
			HashSet<String> expectRetVars = new HashSet<>();
			for(DataIdentifier di : fpb.getOutputParams())
				expectRetVars.add(di.getName());
			retVars = fn_ec.getVariables();
			for( String varName : new ArrayList<>(retVars.keySet()) ) {
				if( expectRetVars.contains(varName) )
					continue;
				//cleanup unexpected return values to avoid leaks
				fn_ec.cleanupDataObject(fn_ec.removeVariable(varName));
			}
		}
		finally {
			// Unpin the pinned variables, also when execution fails, so the caller's
			// inputs are not left pinned after an exception
			ec.unpinVariables(_boundInputNames, pinStatus);
		}

		// add the updated binding for each return variable to the variables in original symbol table
		// (with robustness for unbound outputs, i.e., function calls without assignment)
		int numOutputs = Math.min(_boundOutputNames.size(), fpb.getOutputParams().size());
		for (int i=0; i< numOutputs; i++) {
			String boundVarName = _boundOutputNames.get(i);
			String retVarName = fpb.getOutputParams().get(i).getName();
			Data boundValue = retVars.get(retVarName);
			if (boundValue == null)
				throw new DMLRuntimeException("fcall "+_functionName+": "
					+boundVarName + " was not assigned a return value");
			//cleanup existing data bound to output variable name
			//(skip if the function returned the very same object)
			Data exdata = ec.removeVariable(boundVarName);
			if( exdata != boundValue )
				ec.cleanupDataObject(exdata);
			//add/replace data in symbol table
			ec.setVariable(boundVarName, boundValue);
			//map lineage of function returns back to calling site
			if( lineage != null ) //unchanged ref
				ec.getLineage().set(boundVarName, lineage.get(retVarName));
		}

		//update lineage cache with the functions outputs
		if (DMLScript.LINEAGE && LineageCacheConfig.isMultiLevelReuse() && !fpb.isNondeterministic()) {
			LineageCache.putValue(fpb.getOutputParams(), liInputs,
				getCacheFunctionName(_functionName, fpb), fn_ec, t1-t0);
			//FIXME: send _boundOutputNames instead of fpb.getOutputParams as
			//those are already replaced by boundoutput names in the lineage map.
		}
	}

	@Override
	public void postprocessInstruction(ExecutionContext ec) {
		//default post-process behavior
		super.postprocessInstruction(ec);
	}

	@Override
	public void printMe() {
		LOG.debug("ExternalBuiltInFunction: " + this.toString());
	}

	/**
	 * @return the caller-side variable names receiving the function outputs
	 */
	public List<String> getBoundOutputParamNames() {
		return _boundOutputNames;
	}

	/**
	 * Returns a copy of this instruction's string with the function name field replaced.
	 * <p>
	 * The field at index 3 of {@code instString} is compared against {@code pattern}.
	 * NOTE(review): this is one position later than the name index used in
	 * {@link #parseInstruction}; presumably {@code instString} carries an extra leading
	 * field (e.g., the exec type) — confirm.
	 *
	 * @param pattern the function name to replace (exact match)
	 * @param replace the new function name
	 * @return the rebuilt instruction string (this instruction itself is not modified)
	 */
	public String updateInstStringFunctionName(String pattern, String replace)
	{
		//split current instruction
		String[] parts = instString.split(Lop.OPERAND_DELIMITOR);
		if( parts[3].equals(pattern) )
			parts[3] = replace;
		//construct modified instruction
		return String.join(Lop.OPERAND_DELIMITOR, parts);
	}

	public CPOperand[] getInputs(){
		return _boundInputs;
	}

	/**
	 * Tries to bind all function outputs from the lineage cache.
	 *
	 * @param liInputs lineage items of the bound inputs (may be null if lineage is disabled)
	 * @param fpb the called function program block
	 * @param ec the caller's execution context, receiving reused outputs
	 * @return true if the outputs were reused, i.e., the function need not be executed
	 */
	private boolean reuseFunctionOutputs(LineageItem[] liInputs, FunctionProgramBlock fpb, ExecutionContext ec) {
		//prepare lineage cache probing
		String funcName = getCacheFunctionName(_functionName, fpb);
		int numOutputs = Math.min(_boundOutputNames.size(), fpb.getOutputParams().size());
		//reuse of function outputs
		boolean reuse = LineageCache.reuse(
			_boundOutputNames, fpb.getOutputParams(), numOutputs, liInputs, funcName, ec);
		//statistics maintenance
		if (reuse && DMLScript.STATISTICS) {
			//decrement the call count for this function
			Statistics.maintainCPFuncCallStats(getExtendedOpcode());
			LineageCacheStatistics.incrementFuncHits();
		}
		return reuse;
	}

	/**
	 * Returns the function name used as lineage cache key. For thread-specific function
	 * copies, the trailing thread suffix is stripped so all copies share cache entries.
	 *
	 * @param fname the (possibly thread-suffixed) function name
	 * @param fpb the called function program block
	 * @return the name without thread suffix, or {@code fname} unchanged if the block has
	 *         no thread ID or the expected suffix is not present
	 */
	private static String getCacheFunctionName(String fname, FunctionProgramBlock fpb) {
		if( !fpb.hasThreadID() )
			return fname;
		int pos = fname.lastIndexOf(Lop.CP_CHILD_THREAD+fpb.getThreadID());
		//robustness: avoid substring(0,-1) if the thread suffix is missing
		return (pos >= 0) ? fname.substring(0, pos) : fname;
	}
}