| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysds.hops.codegen.cplan; |
| |
| import java.util.ArrayList; |
| |
| import org.apache.sysds.hops.codegen.template.TemplateUtils; |
| import org.apache.sysds.common.Types.DataType; |
| import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; |
| import org.apache.sysds.runtime.util.UtilFunctions; |
| import org.apache.sysds.hops.codegen.SpoofCompiler.GeneratorAPI; |
| |
| public abstract class CNode |
| { |
| private static final IDSequence _seqVar = new IDSequence(); |
| private static final IDSequence _seqID = new IDSequence(); |
| |
| protected final long _ID; |
| protected ArrayList<CNode> _inputs = null; |
| protected CNode _output = null; |
| protected boolean _visited = false; |
| protected boolean _generated = false; |
| protected String _genVar = null; |
| protected long _rows = -1; |
| protected long _cols = -1; |
| protected DataType _dataType; |
| protected boolean _literal = false; |
| |
| //cached hash to allow memoization in DAG structures and repeated |
| //recursive hash computation over all inputs (w/ reset on updates) |
| protected int _hash = 0; |
| |
| public CNode() { |
| _ID = _seqID.getNextID(); |
| _inputs = new ArrayList<>(); |
| _generated = false; |
| } |
| |
| public long getID() { |
| return _ID; |
| } |
| |
| public ArrayList<CNode> getInput() { |
| return _inputs; |
| } |
| |
| public boolean isGenerated() { |
| return _generated; |
| } |
| |
| public void resetGenerated() { |
| if( isGenerated() ) |
| for( CNode cn : _inputs ) |
| cn.resetGenerated(); |
| _generated = false; |
| } |
| |
| public String createVarname() { |
| _genVar = "TMP"+_seqVar.getNextID(); |
| return _genVar; |
| } |
| |
| public String getVarname() { |
| return _genVar; |
| } |
| |
| public String getVarname(GeneratorAPI api) { return getVarname(); } |
| |
| |
| public String getVectorLength() { |
| if( getVarname().startsWith("a") ) |
| return "len"; |
| else if( getVarname().startsWith("b") ) |
| return getVarname()+".clen"; |
| else if( _dataType==DataType.MATRIX ) |
| return getVarname()+".length"; |
| return ""; |
| } |
| |
| public String getClassname() { |
| return getVarname(); |
| } |
| |
| public void resetHash() { |
| _hash = 0; |
| } |
| |
| public void setNumRows(long rows) { |
| _rows = rows; |
| } |
| |
| public long getNumRows() { |
| return _rows; |
| } |
| |
| public void setNumCols(long cols) { |
| _cols = cols; |
| } |
| |
| public long getNumCols() { |
| return _cols; |
| } |
| |
| public DataType getDataType() { |
| return _dataType; |
| } |
| |
| public void setDataType(DataType dt) { |
| _dataType = dt; |
| _hash = 0; |
| } |
| |
| public boolean isLiteral() { |
| return _literal; |
| } |
| |
| public void setLiteral(boolean literal) { |
| _literal = literal; |
| _hash = 0; |
| } |
| |
| public CNode getOutput() { |
| return _output; |
| } |
| |
| public void setOutput(CNode output) { |
| _output = output; |
| _hash = 0; |
| } |
| |
| public boolean isVisited() { |
| return _visited; |
| } |
| |
| public void setVisited() { |
| setVisited(true); |
| } |
| |
| public void setVisited(boolean flag) { |
| _visited = flag; |
| } |
| |
| public void resetVisitStatus() { |
| if( !isVisited() ) |
| return; |
| for( CNode h : getInput() ) |
| h.resetVisitStatus(); |
| setVisited(false); |
| } |
| |
| public abstract String codegen(boolean sparse, GeneratorAPI api); |
| |
| public abstract void setOutputDims(); |
| |
| /////////////////////////////////////// |
| // Functionality for plan cache |
| |
| //note: genvar/generated changed on codegen and not considered, |
| //rows and cols also not include to increase reuse potential |
| |
| @Override |
| public int hashCode() { |
| if( _hash == 0 ) { |
| //include inputs, partitioned by matrices and scalars to increase |
| //reuse in case of interleaved inputs (see CNodeTpl.renameInputs) |
| int h = 1; |
| for( CNode c : _inputs ) |
| if( c.getDataType()==DataType.MATRIX ) |
| h = UtilFunctions.intHashCode(h, c.hashCode()); |
| for( CNode c : _inputs ) |
| if( c.getDataType()!=DataType.MATRIX ) |
| h = UtilFunctions.intHashCode(h, c.hashCode()); |
| h = UtilFunctions.intHashCode(h, (_output!=null)?_output.hashCode():0); |
| h = UtilFunctions.intHashCode(h, (_dataType!=null)?_dataType.hashCode():0); |
| h = UtilFunctions.intHashCode(h, Boolean.hashCode(_literal)); |
| _hash = h; |
| } |
| return _hash; |
| } |
| |
| @Override |
| public boolean equals(Object that) { |
| if( !(that instanceof CNode) ) |
| return false; |
| |
| CNode cthat = (CNode) that; |
| boolean ret = _inputs.size() == cthat._inputs.size(); |
| for( int i=0; i<_inputs.size() && ret; i++ ) |
| ret &= _inputs.get(i).equals(cthat._inputs.get(i)); |
| return ret |
| && (_output == cthat._output || _output.equals(cthat._output)) |
| && _dataType == cthat._dataType |
| && _literal == cthat._literal; |
| } |
| |
| protected String replaceUnaryPlaceholders(String tmp, String varj, boolean vectIn) { |
| //replace sparse and dense inputs |
| tmp = tmp.replace("%IN1v%", varj+"vals"); |
| tmp = tmp.replace("%IN1i%", varj+"ix"); |
| tmp = tmp.replace("%IN1%", |
| (vectIn && TemplateUtils.isMatrix(_inputs.get(0))) ? varj + ".values(rix)" : |
| (vectIn && TemplateUtils.isRowVector(_inputs.get(0)) ? varj + ".values(0)" : varj)); |
| |
| //replace start position of main input |
| String spos = (_inputs.get(0) instanceof CNodeData |
| && _inputs.get(0).getDataType().isMatrix()) ? !varj.startsWith("b") ? |
| varj+"i" : TemplateUtils.isMatrix(_inputs.get(0)) ? varj + ".pos(rix)" : "0" : "0"; |
| |
| tmp = tmp.replace("%POS1%", spos); |
| tmp = tmp.replace("%POS2%", spos); |
| |
| //replace length |
| if( _inputs.get(0).getDataType().isMatrix() ) |
| tmp = tmp.replace("%LEN%", _inputs.get(0).getVectorLength()); |
| |
| return tmp; |
| } |
| |
| protected CodeTemplate getLanguageTemplateClass(CNode caller, GeneratorAPI api) { |
| switch (api) { |
| case CUDA: |
| if(caller instanceof CNodeCell) |
| return new org.apache.sysds.hops.codegen.cplan.cuda.CellWise(); |
| else if (caller instanceof CNodeUnary) |
| return new org.apache.sysds.hops.codegen.cplan.cuda.Unary(); |
| else if (caller instanceof CNodeBinary) |
| return new org.apache.sysds.hops.codegen.cplan.cuda.Binary(); |
| else if (caller instanceof CNodeTernary) |
| return new org.apache.sysds.hops.codegen.cplan.cuda.Ternary(); |
| else |
| return null; |
| case JAVA: |
| if(caller instanceof CNodeCell) |
| return new org.apache.sysds.hops.codegen.cplan.java.CellWise(); |
| else if (caller instanceof CNodeUnary) |
| return new org.apache.sysds.hops.codegen.cplan.java.Unary(); |
| else if (caller instanceof CNodeBinary) |
| return new org.apache.sysds.hops.codegen.cplan.java.Binary(); |
| else if (caller instanceof CNodeTernary) |
| return new org.apache.sysds.hops.codegen.cplan.java.Ternary(); |
| |
| else |
| return null; |
| default: |
| throw new RuntimeException("API not supported by code generator: " + api.toString()); |
| } |
| } |
| |
| public abstract boolean isSupported(GeneratorAPI api); |
| } |