blob: 5850d45877b7d684d25e3596bd8c9344007082af [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.parser;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.FunctionOp.FunctionType;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.ValueType;
/**
 * Statement block that holds exactly one {@link FunctionStatement}: the definition of either a
 * DML-bodied function or an {@link ExternalFunctionStatement}. The function body is a list of
 * nested statement blocks owned by that statement; this block itself carries no HOPs.
 */
public class FunctionStatementBlock extends StatementBlock
{
	private boolean _recompileOnce = false;

	/**
	 * Validates the function definition held by this block.
	 *
	 * <p>All matrix input parameters must have value type DOUBLE. For external functions the
	 * specified attributes are validated before the body. The nested statement blocks are then
	 * validated in order, threading the variable set and constants from one block into the next.
	 * For DML-bodied functions every declared output must be defined by the body and agree with
	 * the signature's data and value types; an INT scalar returned for a DOUBLE scalar output is
	 * cast to DOUBLE (together with a tracked integer constant of the same name).</p>
	 *
	 * <p>TODO: DRB: This needs to be changed to reflect:
	 * (1) Default values for variables -- need to add R styled check here to make sure that once
	 * vars with default values start, they keep going to the right;
	 * (2) The other parameters for External Functions.</p>
	 *
	 * @param dmlProg the program containing this function
	 * @param ids variables visible when validation of the body starts
	 * @param constVars constant values visible when validation of the body starts
	 * @param conditional forwarded to nested blocks and to {@code raiseValidateError}
	 *        (presumably: report some errors as warnings when true -- confirm in StatementBlock)
	 * @return the variable set returned by validating the last nested block, or {@code ids}
	 *         unchanged if the body is empty
	 * @throws LanguageException if the block holds more than one statement or a type check fails
	 * @throws ParseException propagated from validation of nested blocks
	 * @throws IOException propagated from validation of nested blocks
	 */
	@Override
	public VariableSet validate(DMLProgram dmlProg, VariableSet ids, HashMap<String,ConstIdentifier> constVars, boolean conditional)
		throws LanguageException, ParseException, IOException
	{
		if (_statements.size() > 1){
			LOG.error(this.printBlockErrorLocation() + "FunctionStatementBlock should have only 1 statement (FunctionStatement)");
			throw new LanguageException(this.printBlockErrorLocation() + "FunctionStatementBlock should have only 1 statement (FunctionStatement)");
		}
		FunctionStatement fstmt = (FunctionStatement) _statements.get(0);
		boolean isExternal = fstmt instanceof ExternalFunctionStatement;

		validateInputParams(fstmt);

		//external functions: validate specified attributes and attribute values before the body
		if (isExternal)
			((ExternalFunctionStatement) fstmt).validateParameters(this);

		//validate nested statement blocks in order, threading variables and constants through
		this._dmlProg = dmlProg;
		for (StatementBlock sb : fstmt.getBody()) {
			ids = sb.validate(dmlProg, ids, constVars, conditional);
			constVars = sb.getConstOut();
		}

		//DML-bodied functions: record boundary constants and check outputs against the signature
		if (!isExternal) {
			if (fstmt.getBody().size() > 0) {
				_constVarsIn.putAll(fstmt.getBody().get(0).getConstIn());
				//copy the last block's constants even when the body consists of a single block
				_constVarsOut.putAll(fstmt.getBody().get(fstmt.getBody().size()-1).getConstOut());
			}
			validateReturnValues(fstmt, ids, constVars, conditional);
		}
		return ids;
	}

	/**
	 * Rejects matrix input parameters whose value type is not DOUBLE.
	 *
	 * @param fstmt the function statement whose inputs are checked
	 * @throws LanguageException for the first offending input parameter
	 */
	private void validateInputParams(FunctionStatement fstmt)
		throws LanguageException
	{
		for (DataIdentifier inputValue : fstmt.getInputParams()) {
			//check all input matrices have value type double
			if (inputValue.getDataType() == DataType.MATRIX && inputValue.getValueType() != ValueType.DOUBLE) {
				raiseValidateError("for function " + fstmt.getName() + ", input variable " + inputValue.getName()
					+ " has an unsupported value type of " + inputValue.getValueType() + ".", false);
			}
		}
	}

	/**
	 * Checks each declared output of a DML-bodied function against the variables defined after
	 * validating its body.
	 *
	 * @param fstmt the function statement
	 * @param ids variables after validating the body; updated when a return value is cast
	 * @param constVars constants after validating the body; updated when an integer constant is cast
	 * @param conditional forwarded to {@code raiseValidateError} for undefined outputs and
	 *        data type mismatches
	 * @throws LanguageException if an output is undefined or has an incompatible type
	 */
	private void validateReturnValues(FunctionStatement fstmt, VariableSet ids,
		HashMap<String,ConstIdentifier> constVars, boolean conditional)
		throws LanguageException
	{
		String fname = fstmt.getName();
		for (DataIdentifier returnValue : fstmt.getOutputParams()) {
			DataIdentifier curr = ids.getVariable(returnValue.getName());
			if (curr == null) {
				raiseValidateError("for function " + fname + ", return variable " + returnValue.getName() + " must be defined in function ", conditional);
				//if the error was only reported rather than thrown, there is no variable to
				//type-check; continuing would dereference null
				continue;
			}
			if (curr.getDataType() == DataType.UNKNOWN) {
				raiseValidateError("for function " + fname + ", return variable " + curr.getName() + " data type of " + curr.getDataType() + " may not match data type in function signature of " + returnValue.getDataType(), true);
			}
			if (curr.getValueType() == ValueType.UNKNOWN) {
				raiseValidateError("for function " + fname + ", return variable " + curr.getName() + " value type of " + curr.getValueType() + " may not match value type in function signature of " + returnValue.getValueType(), true);
			}
			if (curr.getDataType() != DataType.UNKNOWN && !curr.getDataType().equals(returnValue.getDataType())) {
				raiseValidateError("for function " + fname + ", return variable " + curr.getName() + " data type of " + curr.getDataType() + " does not match data type in function signature of " + returnValue.getDataType(), conditional);
			}
			if (curr.getValueType() != ValueType.UNKNOWN && !curr.getValueType().equals(returnValue.getValueType())) {
				reconcileValueType(fname, curr, returnValue, ids, constVars);
			}
		}
	}

	/**
	 * Handles a return variable whose known value type differs from the declared one.
	 *
	 * <p>An INT scalar returned for a DOUBLE scalar output is cast to DOUBLE. A non-scalar
	 * mismatch, any other value type returned for a DOUBLE scalar, and any mismatch against an INT
	 * scalar output throw.</p>
	 *
	 * @param fname name of the function (for messages)
	 * @param curr the variable defined by the function body
	 * @param returnValue the output declared in the function signature
	 * @param ids variable set updated with the cast variable
	 * @param constVars constants; an integer constant for {@code curr} is replaced by a double one
	 * @throws LanguageException if the value types cannot be reconciled
	 */
	private void reconcileValueType(String fname, DataIdentifier curr, DataIdentifier returnValue,
		VariableSet ids, HashMap<String,ConstIdentifier> constVars)
		throws LanguageException
	{
		boolean bothScalar = curr.getDataType() == DataType.SCALAR && returnValue.getDataType() == DataType.SCALAR;
		if (!bothScalar) {
			throw valueTypeMismatch(fname, curr, returnValue,
				" and cannot safely cast " + curr.getValueType() + " as " + returnValue.getValueType());
		}

		if (returnValue.getValueType() == ValueType.DOUBLE) {
			if (curr.getValueType() != ValueType.INT) {
				// cannot convert non-INT scalars to DOUBLE
				throw valueTypeMismatch(fname, curr, returnValue, " and cannot safely cast value");
			}
			//keep a tracked integer constant consistent with the new DOUBLE value type
			ConstIdentifier constValue = constVars.get(curr.getName());
			if (constValue instanceof IntIdentifier) {
				IntIdentifier currIntValue = (IntIdentifier) constValue;
				DoubleIdentifier currDoubleValue = new DoubleIdentifier(currIntValue.getValue(),
					curr.getFilename(), curr.getBeginLine(), curr.getBeginColumn(),
					curr.getEndLine(), curr.getEndColumn());
				constVars.put(curr.getName(), currDoubleValue);
			}
			LOG.warn(curr.printWarningLocation() + "for function " + fname
				+ ", return variable " + curr.getName() + " value type of "
				+ curr.getValueType() + " does not match value type in function signature of "
				+ returnValue.getValueType() + " but was safely cast");
			curr.setValueType(ValueType.DOUBLE);
			ids.addVariable(curr.getName(), curr);
		}
		else if (returnValue.getValueType() == ValueType.INT) {
			// cannot narrow any other scalar value type to INT
			throw valueTypeMismatch(fname, curr, returnValue,
				" and cannot safely cast " + curr.getValueType() + " as " + returnValue.getValueType());
		}
		// NOTE(review): scalar mismatches against other declared value types (e.g. BOOLEAN,
		// STRING) are not rejected here -- confirm whether that is intended
	}

	/**
	 * Logs the error for an incompatible return value type and returns the matching exception
	 * for the caller to throw, so the logged and thrown messages are always identical.
	 *
	 * @param fname name of the function
	 * @param curr the variable defined by the function body
	 * @param returnValue the output declared in the function signature
	 * @param reason suffix appended to the common message
	 * @return the exception to throw
	 */
	private LanguageException valueTypeMismatch(String fname, DataIdentifier curr,
		DataIdentifier returnValue, String reason)
	{
		String msg = curr.printErrorLocation() + "for function " + fname
			+ ", return variable " + curr.getName() + " value type of " + curr.getValueType()
			+ " does not match value type in function signature of " + returnValue.getValueType()
			+ reason;
		LOG.error(msg);
		return new LanguageException(msg);
	}

	/**
	 * Determines how calls to this function are executed.
	 *
	 * @return {@code DML} for DML-bodied functions; for external functions {@code EXTERNAL_MEM}
	 *         if the exec-type attribute equals {@code IN_MEMORY}, {@code EXTERNAL_FILE} for any
	 *         other specified exec type, and {@code UNKNOWN} if no exec type is specified
	 */
	public FunctionType getFunctionOpType()
	{
		FunctionStatement fstmt = (FunctionStatement) _statements.get(0);
		if (!(fstmt instanceof ExternalFunctionStatement))
			return FunctionType.DML;

		ExternalFunctionStatement efstmt = (ExternalFunctionStatement) fstmt;
		String execType = efstmt.getOtherParams().get(ExternalFunctionStatement.EXEC_TYPE);
		if (execType == null)
			return FunctionType.UNKNOWN;
		return execType.equals(ExternalFunctionStatement.IN_MEMORY)
			? FunctionType.EXTERNAL_MEM : FunctionType.EXTERNAL_FILE;
	}

	/**
	 * Forward live-variable pass over the function body.
	 *
	 * <p>Resets {@code _read} and {@code _gen}, then runs the forward pass on each nested block in
	 * order, accumulating read, updated, generated (unless already killed by a prior block) and
	 * killed variables. Kill sets of while and for blocks are not accumulated.</p>
	 *
	 * @param activeInPassed variables live on entry
	 * @return the live-out set: the forward result of the last block plus all updated variables
	 * @throws LanguageException if the block holds more than one statement
	 */
	public VariableSet initializeforwardLV(VariableSet activeInPassed) throws LanguageException {
		if (_statements.size() > 1){
			LOG.error(this.printBlockErrorLocation() + "FunctionStatementBlock should have only 1 statement (FunctionStatement)");
			throw new LanguageException(this.printBlockErrorLocation() + "FunctionStatementBlock should have only 1 statement (FunctionStatement)");
		}
		FunctionStatement fstmt = (FunctionStatement)_statements.get(0);
		_read = new VariableSet();
		_gen = new VariableSet();
		VariableSet current = new VariableSet();
		current.addVariables(activeInPassed);
		for( StatementBlock sb : fstmt.getBody() )
		{
			current = sb.initializeforwardLV(current);
			// a variable generated by this block is generated by the function unless a prior
			// block of the function body already killed (i.e., set) it
			for (String varName : sb._gen.getVariableNames()){
				if (!_kill.getVariableNames().contains(varName)){
					_gen.addVariable(varName, sb._gen.getVariable(varName));
				}
			}
			_read.addVariables(sb._read);
			_updated.addVariables(sb._updated);
			// only add kill variables for statement blocks guaranteed to execute
			if (!(sb instanceof WhileStatementBlock) && !(sb instanceof ForStatementBlock) ){
				_kill.addVariables(sb._kill);
			}
		}
		// live-out includes variables from the passed live-in and those updated in the function body
		_liveOut = new VariableSet();
		_liveOut.addVariables(current);
		_liveOut.addVariables(_updated);
		return _liveOut;
	}

	/**
	 * Backward live-variable pass: runs {@code analyze} on the nested blocks from last to first.
	 *
	 * @param loPassed variables live on exit of the function body
	 * @return a copy of the live set at the start of the function body
	 * @throws LanguageException propagated from nested blocks
	 */
	public VariableSet initializebackwardLV(VariableSet loPassed) throws LanguageException{
		FunctionStatement fstmt = (FunctionStatement)_statements.get(0);
		VariableSet lo = new VariableSet();
		lo.addVariables(loPassed);
		// analyze each statement block of the function body in reverse order
		int numBlocks = fstmt.getBody().size();
		for (int i = numBlocks - 1; i >= 0; i--){
			lo = fstmt.getBody().get(i).analyze(lo);
		}
		VariableSet loReturn = new VariableSet();
		loReturn.addVariables(lo);
		return loReturn;
	}

	/**
	 * Returns the HOPs of this block, which must be empty or null.
	 *
	 * @return the (empty or null) HOP list
	 * @throws HopsException if any HOPs are attached to this block
	 */
	public ArrayList<Hop> get_hops() throws HopsException {
		if (_hops != null && _hops.size() > 0){
			LOG.error(this.printBlockErrorLocation() + "there should be no HOPs associated with the FunctionStatementBlock");
			throw new HopsException(this.printBlockErrorLocation() + "there should be no HOPs associated with the FunctionStatementBlock");
		}
		return _hops;
	}

	/**
	 * Not supported for function blocks; use {@link #analyze(VariableSet, VariableSet)}.
	 *
	 * @param loPassed ignored
	 * @return never returns normally
	 * @throws LanguageException always
	 */
	public VariableSet analyze(VariableSet loPassed) throws LanguageException{
		LOG.error(this.printBlockErrorLocation() + "Both liveIn and liveOut variables need to be specified for liveness analysis for FunctionStatementBlock");
		throw new LanguageException(this.printBlockErrorLocation() + "Both liveIn and liveOut variables need to be specified for liveness analysis for FunctionStatementBlock");
	}

	/**
	 * Liveness analysis given both live-in and live-out sets.
	 *
	 * <p>Restricts {@code _liveOut} to the previously computed live-out variables that are either
	 * in {@code loPassed} or generated by the body, re-runs the backward pass over the body, and
	 * sets {@code _liveIn} to {@code liPassed} (kill variables are not removed).</p>
	 *
	 * @param liPassed variables live on entry
	 * @param loPassed variables live on exit
	 * @return a copy of the resulting live-in set
	 * @throws LanguageException propagated from nested blocks
	 */
	public VariableSet analyze(VariableSet liPassed, VariableSet loPassed) throws LanguageException{
		VariableSet candidateLO = new VariableSet();
		candidateLO.addVariables(loPassed);
		candidateLO.addVariables(_gen);
		VariableSet origLiveOut = new VariableSet();
		origLiveOut.addVariables(_liveOut);
		_liveOut = new VariableSet();
		for (String name : candidateLO.getVariableNames()){
			if (origLiveOut.containsVariable(name)){
				_liveOut.addVariable(name, candidateLO.getVariable(name));
			}
		}
		initializebackwardLV(_liveOut);
		// Cannot remove kill variables
		_liveIn = new VariableSet();
		_liveIn.addVariables(liPassed);
		VariableSet liveInReturn = new VariableSet();
		liveInReturn.addVariables(_liveIn);
		return liveInReturn;
	}

	/**
	 * Sets the recompile-once flag stored on this block. Its consumers are outside this file
	 * (presumably it requests a one-time recompilation of the function body -- confirm).
	 *
	 * @param flag new flag value
	 */
	public void setRecompileOnce( boolean flag ) {
		_recompileOnce = flag;
	}

	/**
	 * @return the recompile-once flag (default {@code false})
	 */
	public boolean isRecompileOnce() {
		return _recompileOnce;
	}
}