blob: f0b4695439a302a574555ecbdaba828dae8ee261 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.parser;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.parser.LanguageException.LanguageErrorCodes;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import java.util.HashMap;
public class RelationalExpression extends Expression
{
private Expression _left;
private Expression _right;
private RelationalOp _opcode;
public RelationalExpression(RelationalOp bop) {
_opcode = bop;
setFilename("MAIN SCRIPT");
setBeginLine(0);
setBeginColumn(0);
setEndLine(0);
setEndColumn(0);
setText(null);
}
public RelationalExpression(RelationalOp bop, ParseInfo parseInfo) {
_opcode = bop;
setParseInfo(parseInfo);
}
@Override
public Expression rewriteExpression(String prefix) {
RelationalExpression newExpr = new RelationalExpression(this._opcode, this);
newExpr.setLeft(_left.rewriteExpression(prefix));
newExpr.setRight(_right.rewriteExpression(prefix));
return newExpr;
}
public RelationalOp getOpCode(){
return _opcode;
}
public void setLeft(Expression l){
_left = l;
// update script location information --> left expression is BEFORE in script
if (_left != null) {
setParseInfo(_left);
}
}
public void setRight(Expression r){
_right = r;
// update script location information --> right expression is AFTER in script
if (_right != null) {
setParseInfo(_right);
}
}
public Expression getLeft(){
return _left;
}
public Expression getRight(){
return _right;
}
/**
* Validate parse tree : Process Relational Expression
*/
@Override
public void validateExpression(HashMap<String,DataIdentifier> ids, HashMap<String, ConstIdentifier> constVars, boolean conditional)
{
//check for functions calls in expression
if (_left instanceof FunctionCallIdentifier){
raiseValidateError("user-defined function calls not supported in relational expressions",
false, LanguageException.LanguageErrorCodes.UNSUPPORTED_EXPRESSION);
}
if (_right instanceof FunctionCallIdentifier){
raiseValidateError("user-defined function calls not supported in relational expressions",
false, LanguageException.LanguageErrorCodes.UNSUPPORTED_EXPRESSION);
}
// handle <NUMERIC> == <BOOLEAN> --> convert <BOOLEAN> to numeric value
if ((_left != null && _left instanceof BooleanIdentifier)
|| (_right != null && _right instanceof BooleanIdentifier)) {
if ((_left instanceof IntIdentifier || _left instanceof DoubleIdentifier) || _right instanceof IntIdentifier
|| _right instanceof DoubleIdentifier) {
if (_left instanceof BooleanIdentifier) {
if (((BooleanIdentifier) _left).getValue())
this.setLeft(new IntIdentifier(1, _left));
else
this.setLeft(new IntIdentifier(0, _left));
} else if (_right instanceof BooleanIdentifier) {
if (((BooleanIdentifier) _right).getValue())
this.setRight(new IntIdentifier(1, _right));
else
this.setRight(new IntIdentifier(0, _right));
}
}
}
//recursive validate
_left.validateExpression(ids, constVars, conditional);
if( _right !=null )
_right.validateExpression(ids, constVars, conditional);
//constant propagation (precondition for more complex constant folding rewrite)
if( !conditional ) {
if( _left instanceof DataIdentifier && constVars.containsKey(((DataIdentifier) _left).getName()) )
_left = constVars.get(((DataIdentifier) _left).getName());
if( _right instanceof DataIdentifier && constVars.containsKey(((DataIdentifier) _right).getName()) )
_right = constVars.get(((DataIdentifier) _right).getName());
}
String outputName = getTempName();
DataIdentifier output = new DataIdentifier(outputName);
output.setParseInfo(this);
boolean isLeftMatrix = (_left.getOutput() != null && _left.getOutput().getDataType() == DataType.MATRIX);
boolean isRightMatrix = (_right.getOutput() != null && _right.getOutput().getDataType() == DataType.MATRIX);
boolean isLeftFrame = (_left.getOutput() != null && _left.getOutput().getDataType() == DataType.FRAME);
boolean isRightFrame = (_right.getOutput() != null && _right.getOutput().getDataType() == DataType.FRAME);
if(isLeftMatrix || isRightMatrix) {
// Added to support matrix relational comparison
if(isLeftMatrix && isRightMatrix) {
checkMatchingDimensions(_left, _right, true);
}
MatrixCharacteristics dims = getBinaryMatrixCharacteristics(_left, _right);
output.setDataType(DataType.MATRIX);
output.setDimensions(dims.getRows(), dims.getCols());
output.setBlocksize(dims.getBlocksize());
//since SystemDS only supports double matrices, the value type is forced to
//double; once we support boolean matrices this needs to change
output.setValueType(ValueType.FP64);
}
else if(isLeftFrame && isRightFrame) {
output.setDataType(DataType.FRAME);
output.setDimensions(_left.getOutput().getDim1(), _left.getOutput().getDim2());
output.setValueType(ValueType.BOOLEAN);
}
else if( isLeftFrame || isRightFrame ) {
raiseValidateError("Unsupported relational expression for mixed types "
+_left.getOutput().getDataType().name()+" "+_right.getOutput().getDataType().name());
}
else {
output.setBooleanProperties();
}
this.setOutput(output);
}
/**
* This is same as the function from BuiltinFunctionExpression which is called by ppred
*
* @param expr1 expression 1
* @param expr2 expression 2
* @param allowsMV ?
*/
private void checkMatchingDimensions(Expression expr1, Expression expr2, boolean allowsMV)
{
if (expr1 != null && expr2 != null) {
// if any matrix has unknown dimensions, simply return
if( expr1.getOutput().getDim1() == -1 || expr2.getOutput().getDim1() == -1
||expr1.getOutput().getDim2() == -1 || expr2.getOutput().getDim2() == -1 )
{
return;
}
else if( (!allowsMV && expr1.getOutput().getDim1() != expr2.getOutput().getDim1())
|| (allowsMV && expr1.getOutput().getDim1() != expr2.getOutput().getDim1() && expr2.getOutput().getDim1() != 1)
|| (!allowsMV && expr1.getOutput().getDim2() != expr2.getOutput().getDim2())
|| (allowsMV && expr1.getOutput().getDim2() != expr2.getOutput().getDim2() && expr2.getOutput().getDim2() != 1) )
{
raiseValidateError("Mismatch in matrix dimensions of parameters for function "
+ this.getOpCode(), false, LanguageErrorCodes.INVALID_PARAMETERS);
}
}
}
@Override
public String toString(){
String leftString;
String rightString;
if (_left instanceof StringIdentifier) {
leftString = "\"" + _left.toString() + "\"";
} else {
leftString = _left.toString();
}
if (_right instanceof StringIdentifier) {
rightString = "\"" + _right.toString() + "\"";
} else {
rightString = _right.toString();
}
return "(" + leftString + " " + _opcode.toString() + " "
+ rightString + ")";
}
@Override
public VariableSet variablesRead() {
VariableSet result = new VariableSet();
result.addVariables(_left.variablesRead());
result.addVariables(_right.variablesRead());
return result;
}
@Override
public VariableSet variablesUpdated() {
VariableSet result = new VariableSet();
result.addVariables(_left.variablesUpdated());
result.addVariables(_right.variablesUpdated());
return result;
}
}