blob: 0e3dba71f769ae49bb4650992e91a2889bc1ce4a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.parser;
import java.util.HashMap;
public class BinaryExpression extends Expression
{
private Expression _left;
private Expression _right;
private BinaryOp _opcode;
public Expression rewriteExpression(String prefix) throws LanguageException{
BinaryExpression newExpr = new BinaryExpression(this._opcode,
this.getFilename(), this.getBeginLine(), this.getBeginColumn(), this.getEndLine(), this.getEndColumn());
newExpr.setLeft(_left.rewriteExpression(prefix));
newExpr.setRight(_right.rewriteExpression(prefix));
return newExpr;
}
public BinaryExpression(BinaryOp bop) {
_kind = Kind.BinaryOp;
_opcode = bop;
setFilename("MAIN SCRIPT");
setBeginLine(0);
setBeginColumn(0);
setEndLine(0);
setEndColumn(0);
}
public BinaryExpression(BinaryOp bop, String filename, int beginLine, int beginColumn, int endLine, int endColumn) {
_kind = Kind.BinaryOp;
_opcode = bop;
setFilename(filename);
setBeginLine(beginLine);
setBeginColumn(beginColumn);
setEndLine(endLine);
setEndColumn(endColumn);
}
public BinaryOp getOpCode() {
return _opcode;
}
public void setLeft(Expression l) {
_left = l;
// update script location information --> left expression is BEFORE in script
if (_left != null){
setFilename(_left.getFilename());
setBeginLine(_left.getBeginLine());
setBeginColumn(_left.getBeginColumn());
}
}
public void setRight(Expression r) {
_right = r;
// update script location information --> right expression is AFTER in script
if (_right != null){
setFilename(_right.getFilename());
setBeginLine(_right.getEndLine());
setBeginColumn(_right.getEndColumn());
}
}
public Expression getLeft() {
return _left;
}
public Expression getRight() {
return _right;
}
/**
* Validate parse tree : Process Binary Expression in an assignment
* statement
*
* @throws LanguageException
*/
@Override
public void validateExpression(HashMap<String, DataIdentifier> ids, HashMap<String, ConstIdentifier> constVars, boolean conditional)
throws LanguageException
{
//recursive validate
if (_left instanceof FunctionCallIdentifier || _right instanceof FunctionCallIdentifier){
raiseValidateError("user-defined function calls not supported in binary expressions",
false, LanguageException.LanguageErrorCodes.UNSUPPORTED_EXPRESSION);
}
_left.validateExpression(ids, constVars, conditional);
_right.validateExpression(ids, constVars, conditional);
//constant propagation (precondition for more complex constant folding rewrite)
if( _left instanceof DataIdentifier && constVars.containsKey(((DataIdentifier) _left).getName()) )
_left = constVars.get(((DataIdentifier) _left).getName());
if( _right instanceof DataIdentifier && constVars.containsKey(((DataIdentifier) _right).getName()) )
_right = constVars.get(((DataIdentifier) _right).getName());
String outputName = getTempName();
DataIdentifier output = new DataIdentifier(outputName);
output.setAllPositions(this.getFilename(), this.getBeginLine(), this.getBeginColumn(), this.getEndLine(), this.getEndColumn());
output.setDataType(computeDataType(this.getLeft(), this.getRight(), true));
ValueType resultVT = computeValueType(this.getLeft(), this.getRight(), true);
// Override the computed value type, if needed
if (this.getOpCode() == Expression.BinaryOp.POW
|| this.getOpCode() == Expression.BinaryOp.DIV) {
resultVT = ValueType.DOUBLE;
}
output.setValueType(resultVT);
checkAndSetDimensions(output, conditional);
if (this.getOpCode() == Expression.BinaryOp.MATMULT) {
if ((this.getLeft().getOutput().getDataType() != DataType.MATRIX) || (this.getRight().getOutput().getDataType() != DataType.MATRIX)) {
// remove exception for now
// throw new LanguageException(
// "Matrix multiplication not supported for scalars",
// LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
}
if (this.getLeft().getOutput().getDim2() != -1
&& this.getRight().getOutput().getDim1() != -1
&& this.getLeft().getOutput().getDim2() != this.getRight()
.getOutput().getDim1())
{
raiseValidateError("invalid dimensions for matrix multiplication (k1="+this.getLeft().getOutput().getDim2()+", k2="+this.getRight().getOutput().getDim1()+")",
conditional, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
}
output.setDimensions(this.getLeft().getOutput().getDim1(), this
.getRight().getOutput().getDim2());
}
this.setOutput(output);
}
private void checkAndSetDimensions(DataIdentifier output, boolean conditional)
throws LanguageException {
Identifier left = this.getLeft().getOutput();
Identifier right = this.getRight().getOutput();
Identifier pivot = null;
Identifier aux = null;
if (left.getDataType() == DataType.MATRIX) {
pivot = left;
if (right.getDataType() == DataType.MATRIX) {
aux = right;
}
} else if (right.getDataType() == DataType.MATRIX) {
pivot = right;
}
if ((pivot != null) && (aux != null))
{
//check dimensions binary operations (if dims known)
if (isSameDimensionBinaryOp(this.getOpCode()) && pivot.dimsKnown() && aux.dimsKnown() )
{
if( (pivot.getDim1() != aux.getDim1() && aux.getDim1()>1) //number of rows must always be equivalent if not row vector
|| (pivot.getDim2() != aux.getDim2() && aux.getDim2()>1)) //number of cols must be equivalent if not col vector
{
raiseValidateError("Mismatch in dimensions for operation "+ this.toString(), conditional, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
}
}
}
//set dimension information
if (pivot != null) {
output.setDimensions(pivot.getDim1(), pivot.getDim2());
}
}
public String toString() {
return "(" + _left.toString() + " " + _opcode.toString() + " "
+ _right.toString() + ")";
}
@Override
public VariableSet variablesRead() {
VariableSet result = new VariableSet();
result.addVariables(_left.variablesRead());
result.addVariables(_right.variablesRead());
return result;
}
@Override
public VariableSet variablesUpdated() {
VariableSet result = new VariableSet();
result.addVariables(_left.variablesUpdated());
result.addVariables(_right.variablesUpdated());
return result;
}
public static boolean isSameDimensionBinaryOp(BinaryOp op) {
return (op == BinaryOp.PLUS) || (op == BinaryOp.MINUS)
|| (op == BinaryOp.MULT) || (op == BinaryOp.DIV)
|| (op == BinaryOp.MODULUS) || (op == BinaryOp.INTDIV)
|| (op == BinaryOp.POW);
}
}