blob: cd7f0cc8644f7c0e6a47ac0a1df630e372007f70 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.runtime.controlprogram;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.common.Types.FunctionBlock;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.hops.recompile.Recompiler.ResetType;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.DMLScriptException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.util.ProgramConverter;
import org.apache.sysds.utils.Statistics;
public class FunctionProgramBlock extends ProgramBlock implements FunctionBlock
{
public String _functionName;
public String _namespace;
protected ArrayList<ProgramBlock> _childBlocks;
protected ArrayList<DataIdentifier> _inputParams;
protected ArrayList<DataIdentifier> _outputParams;
private boolean _recompileOnce = false;
private boolean _nondeterministic = false;
public FunctionProgramBlock( Program prog, List<DataIdentifier> inputParams, List<DataIdentifier> outputParams) {
super(prog);
_childBlocks = new ArrayList<>();
_inputParams = new ArrayList<>();
for (DataIdentifier id : inputParams)
_inputParams.add(new DataIdentifier(id));
_outputParams = new ArrayList<>();
for (DataIdentifier id : outputParams)
_outputParams.add(new DataIdentifier(id));
}
public DataIdentifier getInputParam(String name) {
return _inputParams.stream()
.filter(d -> d.getName().equals(name))
.findFirst().orElse(null);
}
public List<String> getInputParamNames() {
return _inputParams.stream().map(d -> d.getName()).collect(Collectors.toList());
}
public ArrayList<DataIdentifier> getInputParams(){
return _inputParams;
}
public ArrayList<DataIdentifier> getOutputParams(){
return _outputParams;
}
public void addProgramBlock(ProgramBlock childBlock) {
_childBlocks.add(childBlock);
}
public void setChildBlocks(ArrayList<ProgramBlock> pbs) {
_childBlocks = pbs;
}
@Override
public ArrayList<ProgramBlock> getChildBlocks() {
return _childBlocks;
}
@Override
public boolean isNested() {
return true;
}
@Override
public void execute(ExecutionContext ec)
{
//dynamically recompile entire function body (according to function inputs)
try {
if( ConfigurationManager.isDynamicRecompilation()
&& isRecompileOnce()
&& ParForProgramBlock.RESET_RECOMPILATION_FLAGs )
{
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
//note: it is important to reset the recompilation flags here
// (1) it is safe to reset recompilation flags because a 'recompile_once'
// function will be recompiled for every execution.
// (2) without reset, there would be no benefit in recompiling the entire function
LocalVariableMap tmp = (LocalVariableMap) ec.getVariables().clone();
boolean codegen = ConfigurationManager.isCodegenEnabled();
boolean singlenode = DMLScript.getGlobalExecMode() == ExecMode.SINGLE_NODE;
ResetType reset = (codegen || singlenode) ? ResetType.RESET_KNOWN_DIMS : ResetType.RESET;
Recompiler.recompileProgramBlockHierarchy(_childBlocks, tmp, _tid, reset);
if( DMLScript.STATISTICS ){
long t1 = System.nanoTime();
Statistics.incrementFunRecompileTime(t1-t0);
Statistics.incrementFunRecompiles();
}
}
}
catch(Exception ex) {
throw new DMLRuntimeException("Error recompiling function body.", ex);
}
// for each program block
try {
for (int i=0 ; i < this._childBlocks.size() ; i++) {
_childBlocks.get(i).execute(ec);
}
}
catch (DMLScriptException e) {
throw e;
}
catch (Exception e){
throw new DMLRuntimeException(this.printBlockErrorLocation() + "Error evaluating function program block", e);
}
// check return values
checkOutputParameters(ec.getVariables());
}
protected void checkOutputParameters( LocalVariableMap vars )
{
for( DataIdentifier diOut : _outputParams ) {
String varName = diOut.getName();
Data dat = vars.get( varName );
if( dat == null )
LOG.error("Function output "+ varName +" is missing.");
else if( dat.getDataType() != diOut.getDataType() )
LOG.warn("Function output "+ varName +" has wrong data type: "+dat.getDataType()+".");
else if( dat.getValueType() != diOut.getValueType() )
LOG.warn("Function output "+ varName +" has wrong value type: "+dat.getValueType()+".");
}
}
public void setRecompileOnce( boolean flag ) {
_recompileOnce = flag;
}
public boolean isRecompileOnce() {
return _recompileOnce;
}
public void setNondeterministic(boolean flag) {
_nondeterministic = flag;
}
public boolean isNondeterministic() {
return _nondeterministic;
}
@Override
public FunctionBlock cloneFunctionBlock() {
return ProgramConverter
.createDeepCopyFunctionProgramBlock(this, new HashSet<>(), new HashSet<>());
}
@Override
public String printBlockErrorLocation(){
return "ERROR: Runtime error in function program block generated from function statement block between lines " + _beginLine + " and " + _endLine + " -- ";
}
}