| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysds.hops.codegen.cplan; |
| |
import java.util.Arrays;
import java.util.Locale;

import org.apache.commons.lang.StringUtils;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.hops.codegen.SpoofCompiler.GeneratorAPI;
import org.apache.sysds.hops.codegen.template.TemplateUtils;
import org.apache.sysds.runtime.util.UtilFunctions;
| |
| public class CNodeBinary extends CNode { |
| |
| public enum BinType { |
| //matrix multiplication operations |
| DOT_PRODUCT, VECT_MATRIXMULT, VECT_OUTERMULT_ADD, |
| //vector-scalar-add operations |
| VECT_MULT_ADD, VECT_DIV_ADD, VECT_MINUS_ADD, VECT_PLUS_ADD, |
| VECT_POW_ADD, VECT_MIN_ADD, VECT_MAX_ADD, |
| VECT_EQUAL_ADD, VECT_NOTEQUAL_ADD, VECT_LESS_ADD, |
| VECT_LESSEQUAL_ADD, VECT_GREATER_ADD, VECT_GREATEREQUAL_ADD, |
| VECT_CBIND_ADD, VECT_XOR_ADD, |
| //vector-scalar operations |
| VECT_MULT_SCALAR, VECT_DIV_SCALAR, VECT_MINUS_SCALAR, VECT_PLUS_SCALAR, |
| VECT_POW_SCALAR, VECT_MIN_SCALAR, VECT_MAX_SCALAR, |
| VECT_EQUAL_SCALAR, VECT_NOTEQUAL_SCALAR, VECT_LESS_SCALAR, |
| VECT_LESSEQUAL_SCALAR, VECT_GREATER_SCALAR, VECT_GREATEREQUAL_SCALAR, |
| VECT_CBIND, |
| VECT_XOR_SCALAR, VECT_BITWAND_SCALAR, |
| //vector-vector operations |
| VECT_MULT, VECT_DIV, VECT_MINUS, VECT_PLUS, VECT_MIN, VECT_MAX, VECT_EQUAL, |
| VECT_NOTEQUAL, VECT_LESS, VECT_LESSEQUAL, VECT_GREATER, VECT_GREATEREQUAL, |
| VECT_XOR, VECT_BITWAND, |
| VECT_BIASADD, VECT_BIASMULT, |
| //scalar-scalar operations |
| MULT, DIV, PLUS, MINUS, MODULUS, INTDIV, |
| LESS, LESSEQUAL, GREATER, GREATEREQUAL, EQUAL,NOTEQUAL, |
| MIN, MAX, AND, OR, XOR, LOG, LOG_NZ, POW, |
| BITWAND, |
| SEQ_RIX, |
| MINUS1_MULT, MINUS_NZ; |
| |
| public static boolean contains(String value) { |
| return Arrays.stream(values()).anyMatch(bt -> bt.name().equals(value)); |
| } |
| |
| public boolean isCommutative() { |
| boolean ssComm = (this==EQUAL || this==NOTEQUAL |
| || this==PLUS || this==MULT || this==MIN || this==MAX |
| || this==OR || this==AND || this==XOR || this==BITWAND); |
| boolean vsComm = (this==VECT_EQUAL_SCALAR || this==VECT_NOTEQUAL_SCALAR |
| || this==VECT_PLUS_SCALAR || this==VECT_MULT_SCALAR |
| || this==VECT_MIN_SCALAR || this==VECT_MAX_SCALAR |
| || this==VECT_XOR_SCALAR || this==VECT_BITWAND_SCALAR ); |
| boolean vvComm = (this==VECT_EQUAL || this==VECT_NOTEQUAL |
| || this==VECT_PLUS || this==VECT_MULT || this==VECT_MIN || this==VECT_MAX |
| || this==VECT_XOR || this==BinType.VECT_BITWAND); |
| return ssComm || vsComm || vvComm; |
| } |
| |
| public boolean isVectorPrimitive() { |
| return isVectorScalarPrimitive() |
| || isVectorVectorPrimitive() |
| || isVectorMatrixPrimitive(); |
| } |
| public boolean isVectorScalarPrimitive() { |
| return this == VECT_DIV_SCALAR || this == VECT_MULT_SCALAR |
| || this == VECT_MINUS_SCALAR || this == VECT_PLUS_SCALAR |
| || this == VECT_POW_SCALAR |
| || this == VECT_MIN_SCALAR || this == VECT_MAX_SCALAR |
| || this == VECT_EQUAL_SCALAR || this == VECT_NOTEQUAL_SCALAR |
| || this == VECT_LESS_SCALAR || this == VECT_LESSEQUAL_SCALAR |
| || this == VECT_GREATER_SCALAR || this == VECT_GREATEREQUAL_SCALAR |
| || this == VECT_CBIND |
| || this == VECT_XOR_SCALAR || this == VECT_BITWAND_SCALAR; |
| } |
| public boolean isVectorVectorPrimitive() { |
| return this == VECT_DIV || this == VECT_MULT |
| || this == VECT_MINUS || this == VECT_PLUS |
| || this == VECT_MIN || this == VECT_MAX |
| || this == VECT_EQUAL || this == VECT_NOTEQUAL |
| || this == VECT_LESS || this == VECT_LESSEQUAL |
| || this == VECT_GREATER || this == VECT_GREATEREQUAL |
| || this == VECT_XOR || this == VECT_BITWAND |
| || this == VECT_BIASADD || this == VECT_BIASMULT; |
| } |
| public boolean isVectorMatrixPrimitive() { |
| return this == VECT_MATRIXMULT |
| || this == VECT_OUTERMULT_ADD; |
| } |
| public BinType getVectorAddPrimitive() { |
| return BinType.valueOf("VECT_"+getVectorPrimitiveName().toUpperCase()+"_ADD"); |
| } |
| public String getVectorPrimitiveName() { |
| String [] tmp = this.name().split("_"); |
| return StringUtils.capitalize(tmp[1].toLowerCase()); |
| } |
| } |
| |
| private final BinType _type; |
| |
| public CNodeBinary( CNode in1, CNode in2, BinType type ) { |
| //canonicalize commutative matrix-scalar operations |
| //to increase reuse potential |
| if( type.isCommutative() && in1 instanceof CNodeData |
| && in1.getDataType()==DataType.SCALAR ) { |
| CNode tmp = in1; |
| in1 = in2; |
| in2 = tmp; |
| } |
| |
| _inputs.add(in1); |
| _inputs.add(in2); |
| _type = type; |
| setOutputDims(); |
| } |
| |
| public BinType getType() { |
| return _type; |
| } |
| |
| @Override |
| public String codegen(boolean sparse, GeneratorAPI api) { |
| if( isGenerated() ) |
| return ""; |
| |
| StringBuilder sb = new StringBuilder(); |
| |
| //generate children |
| sb.append(_inputs.get(0).codegen(sparse, api)); |
| sb.append(_inputs.get(1).codegen(sparse, api)); |
| |
| //generate binary operation (use sparse template, if data input) |
| boolean lsparseLhs = sparse && _inputs.get(0) instanceof CNodeData |
| && _inputs.get(0).getVarname().startsWith("a"); |
| boolean lsparseRhs = sparse && _inputs.get(1) instanceof CNodeData |
| && _inputs.get(1).getVarname().startsWith("a"); |
| boolean scalarInput = _inputs.get(0).getDataType().isScalar(); |
| boolean scalarVector = (_inputs.get(0).getDataType().isScalar() |
| && _inputs.get(1).getDataType().isMatrix()); |
| String var = createVarname(); |
| // String tmp = _type.getTemplate(api, lang, lsparseLhs, lsparseRhs, scalarVector, scalarInput); |
| String tmp = getLanguageTemplateClass(this, api).getTemplate(_type, lsparseLhs, lsparseRhs, scalarVector, scalarInput); |
| |
| tmp = tmp.replace("%TMP%", var); |
| |
| //replace input references and start indexes |
| for( int j=0; j<2; j++ ) { |
| String varj = _inputs.get(j).getVarname(api); |
| |
| //replace sparse and dense inputs |
| tmp = tmp.replace("%IN"+(j+1)+"v%", varj+"vals"); |
| tmp = tmp.replace("%IN"+(j+1)+"i%", varj+"ix"); |
| tmp = tmp.replace("%IN"+(j+1)+"%", |
| varj.startsWith("b") ? varj + ".values(rix)" : varj ); |
| |
| //replace start position of main input |
| tmp = tmp.replace("%POS"+(j+1)+"%", (_inputs.get(j) instanceof CNodeData |
| && _inputs.get(j).getDataType().isMatrix()) ? (!varj.startsWith("b")) ? varj+"i" : |
| (TemplateUtils.isMatrix(_inputs.get(j)) && _type!=BinType.VECT_MATRIXMULT) ? |
| varj + ".pos(rix)" : "0" : "0"); |
| } |
| //replace length information (e.g., after matrix mult) |
| if( _type == BinType.VECT_OUTERMULT_ADD ) { |
| for( int j=0; j<2; j++ ) |
| tmp = tmp.replace("%LEN"+(j+1)+"%", _inputs.get(j).getVectorLength()); |
| } |
| else { //general case |
| CNode mInput = getIntermediateInputVector(); |
| if( mInput != null ) |
| tmp = tmp.replace("%LEN%", mInput.getVectorLength()); |
| } |
| |
| sb.append(tmp); |
| |
| //mark as generated |
| _generated = true; |
| |
| return sb.toString(); |
| } |
| |
| private CNode getIntermediateInputVector() { |
| for( int i=0; i<2; i++ ) |
| if( getInput().get(i).getDataType().isMatrix() ) |
| return getInput().get(i); |
| return null; |
| } |
| |
| @Override |
| public String toString() { |
| switch(_type) { |
| case DOT_PRODUCT: return "b(dot)"; |
| case VECT_MATRIXMULT: return "b(vmm)"; |
| case VECT_OUTERMULT_ADD: return "b(voma)"; |
| case VECT_MULT_ADD: return "b(vma)"; |
| case VECT_DIV_ADD: return "b(vda)"; |
| case VECT_MINUS_ADD: return "b(vmia)"; |
| case VECT_PLUS_ADD: return "b(vpa)"; |
| case VECT_POW_ADD: return "b(vpowa)"; |
| case VECT_MIN_ADD: return "b(vmina)"; |
| case VECT_MAX_ADD: return "b(vmaxa)"; |
| case VECT_EQUAL_ADD: return "b(veqa)"; |
| case VECT_NOTEQUAL_ADD: return "b(vneqa)"; |
| case VECT_LESS_ADD: return "b(vlta)"; |
| case VECT_LESSEQUAL_ADD: return "b(vltea)"; |
| case VECT_GREATEREQUAL_ADD: return "b(vgtea)"; |
| case VECT_GREATER_ADD: return "b(vgta)"; |
| case VECT_CBIND_ADD: return "b(vcbinda)"; |
| case VECT_MULT_SCALAR: return "b(vm)"; |
| case VECT_DIV_SCALAR: return "b(vd)"; |
| case VECT_MINUS_SCALAR: return "b(vmi)"; |
| case VECT_PLUS_SCALAR: return "b(vp)"; |
| case VECT_XOR_SCALAR: return "v(vxor)"; |
| case VECT_POW_SCALAR: return "b(vpow)"; |
| case VECT_MIN_SCALAR: return "b(vmin)"; |
| case VECT_MAX_SCALAR: return "b(vmax)"; |
| case VECT_EQUAL_SCALAR: return "b(veq)"; |
| case VECT_NOTEQUAL_SCALAR: return "b(vneq)"; |
| case VECT_LESS_SCALAR: return "b(vlt)"; |
| case VECT_LESSEQUAL_SCALAR: return "b(vlte)"; |
| case VECT_GREATEREQUAL_SCALAR: return "b(vgte)"; |
| case VECT_GREATER_SCALAR: return "b(vgt)"; |
| case VECT_MULT: return "b(v2m)"; |
| case VECT_DIV: return "b(v2d)"; |
| case VECT_MINUS: return "b(v2mi)"; |
| case VECT_PLUS: return "b(v2p)"; |
| case VECT_XOR: return "b(v2xor)"; |
| case VECT_MIN: return "b(v2min)"; |
| case VECT_MAX: return "b(v2max)"; |
| case VECT_EQUAL: return "b(v2eq)"; |
| case VECT_NOTEQUAL: return "b(v2neq)"; |
| case VECT_LESS: return "b(v2lt)"; |
| case VECT_LESSEQUAL: return "b(v2lte)"; |
| case VECT_GREATEREQUAL: return "b(v2gte)"; |
| case VECT_GREATER: return "b(v2gt)"; |
| case VECT_CBIND: return "b(cbind)"; |
| case VECT_BIASADD: return "b(vbias+)"; |
| case VECT_BIASMULT: return "b(vbias*)"; |
| case MULT: return "b(*)"; |
| case DIV: return "b(/)"; |
| case PLUS: return "b(+)"; |
| case MINUS: return "b(-)"; |
| case POW: return "b(^)"; |
| case MODULUS: return "b(%%)"; |
| case INTDIV: return "b(%/%)"; |
| case LESS: return "b(<)"; |
| case LESSEQUAL: return "b(<=)"; |
| case GREATER: return "b(>)"; |
| case GREATEREQUAL: return "b(>=)"; |
| case EQUAL: return "b(==)"; |
| case NOTEQUAL: return "b(!=)"; |
| case OR: return "b(|)"; |
| case AND: return "b(&)"; |
| case XOR: return "b(xor)"; |
| case BITWAND: return "b(bitwAnd)"; |
| case SEQ_RIX: return "b(seqr)"; |
| case MINUS1_MULT: return "b(1-*)"; |
| case MINUS_NZ: return "b(-nz)"; |
| default: return "b("+_type.name().toLowerCase()+")"; |
| } |
| } |
| |
| @Override |
| public void setOutputDims() |
| { |
| switch(_type) { |
| //VECT |
| case VECT_MULT_ADD: |
| case VECT_DIV_ADD: |
| case VECT_MINUS_ADD: |
| case VECT_PLUS_ADD: |
| case VECT_POW_ADD: |
| case VECT_MIN_ADD: |
| case VECT_MAX_ADD: |
| case VECT_EQUAL_ADD: |
| case VECT_NOTEQUAL_ADD: |
| case VECT_LESS_ADD: |
| case VECT_LESSEQUAL_ADD: |
| case VECT_GREATER_ADD: |
| case VECT_GREATEREQUAL_ADD: |
| case VECT_CBIND_ADD: |
| case VECT_XOR_ADD: |
| boolean vectorScalar = _inputs.get(1).getDataType()==DataType.SCALAR; |
| _rows = _inputs.get(vectorScalar ? 0 : 1)._rows; |
| _cols = _inputs.get(vectorScalar ? 0 : 1)._cols; |
| _dataType = DataType.MATRIX; |
| break; |
| |
| case VECT_CBIND: |
| _rows = _inputs.get(0)._rows; |
| _cols = _inputs.get(0)._cols+1; |
| _dataType = DataType.MATRIX; |
| break; |
| |
| case VECT_OUTERMULT_ADD: |
| _rows = _inputs.get(0)._cols; |
| _cols = _inputs.get(1)._cols; |
| _dataType = DataType.MATRIX; |
| break; |
| |
| case VECT_DIV_SCALAR: |
| case VECT_MULT_SCALAR: |
| case VECT_MINUS_SCALAR: |
| case VECT_PLUS_SCALAR: |
| case VECT_XOR_SCALAR: |
| case VECT_BITWAND_SCALAR: |
| case VECT_POW_SCALAR: |
| case VECT_MIN_SCALAR: |
| case VECT_MAX_SCALAR: |
| case VECT_EQUAL_SCALAR: |
| case VECT_NOTEQUAL_SCALAR: |
| case VECT_LESS_SCALAR: |
| case VECT_LESSEQUAL_SCALAR: |
| case VECT_GREATER_SCALAR: |
| case VECT_GREATEREQUAL_SCALAR: |
| |
| case VECT_DIV: |
| case VECT_MULT: |
| case VECT_MINUS: |
| case VECT_PLUS: |
| case VECT_XOR: |
| case VECT_BITWAND: |
| case VECT_MIN: |
| case VECT_MAX: |
| case VECT_EQUAL: |
| case VECT_NOTEQUAL: |
| case VECT_LESS: |
| case VECT_LESSEQUAL: |
| case VECT_GREATER: |
| case VECT_GREATEREQUAL: |
| case VECT_BIASADD: |
| case VECT_BIASMULT: |
| boolean scalarVector = (_inputs.get(0).getDataType()==DataType.SCALAR); |
| _rows = _inputs.get(scalarVector ? 1 : 0)._rows; |
| _cols = _inputs.get(scalarVector ? 1 : 0)._cols; |
| _dataType= DataType.MATRIX; |
| break; |
| |
| case VECT_MATRIXMULT: |
| _rows = _inputs.get(0)._rows; |
| _cols = _inputs.get(1)._cols; |
| _dataType = DataType.MATRIX; |
| break; |
| |
| case DOT_PRODUCT: |
| |
| //SCALAR Arithmetic |
| case MULT: |
| case DIV: |
| case PLUS: |
| case MINUS: |
| case MINUS1_MULT: |
| case MINUS_NZ: |
| case MODULUS: |
| case INTDIV: |
| //SCALAR Comparison |
| case LESS: |
| case LESSEQUAL: |
| case GREATER: |
| case GREATEREQUAL: |
| case EQUAL: |
| case NOTEQUAL: |
| //SCALAR LOGIC |
| case MIN: |
| case MAX: |
| case AND: |
| case OR: |
| case XOR: |
| case BITWAND: |
| case LOG: |
| case LOG_NZ: |
| case POW: |
| case SEQ_RIX: |
| _rows = 0; |
| _cols = 0; |
| _dataType= DataType.SCALAR; |
| break; |
| } |
| } |
| |
| @Override |
| public int hashCode() { |
| if( _hash == 0 ) { |
| _hash = UtilFunctions.intHashCode( |
| super.hashCode(), _type.hashCode()); |
| } |
| return _hash; |
| } |
| |
| @Override |
| public boolean equals(Object o) { |
| if( !(o instanceof CNodeBinary) ) |
| return false; |
| |
| CNodeBinary that = (CNodeBinary) o; |
| return super.equals(that) |
| && _type == that._type; |
| } |
| |
| @Override |
| public boolean isSupported(GeneratorAPI api) { |
| boolean is_supported = (api == GeneratorAPI.CUDA || api == GeneratorAPI.JAVA); |
| int i = 0; |
| while(is_supported && i < _inputs.size()) { |
| CNode in = _inputs.get(i++); |
| is_supported = in.isSupported(api); |
| } |
| return is_supported; |
| } |
| } |