// blob: 13c8e10fca454526534ef8d961d8595babc19040 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.hops.codegen.cplan;
import org.apache.sysds.hops.codegen.SpoofCompiler;
import org.apache.sysds.hops.codegen.SpoofCompiler.GeneratorAPI;
import org.apache.sysds.hops.codegen.SpoofFusedOp.SpoofOutputDimsType;
import org.apache.sysds.hops.codegen.cplan.CNodeBinary.BinType;
import org.apache.sysds.hops.codegen.template.TemplateUtils;
import org.apache.sysds.runtime.codegen.SpoofRowwise.RowType;
import org.apache.sysds.runtime.util.UtilFunctions;
import java.util.ArrayList;
/**
 * Code-generation template node for row-wise fused operators.
 *
 * <p>The node fills a language-specific source template (obtained via
 * {@code getLanguageTemplate}) with the dense and sparse bodies generated from
 * its output node, plus the row type, the constant number of output columns,
 * and the number of intermediate vectors.</p>
 */
public class CNodeRow extends CNodeTpl
{
	/** Java source template; the {@code %...%} placeholders are filled in by {@link #codegen}. */
	protected static final String JAVA_TEMPLATE =
		"package codegen;\n"
		+ "import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;\n"
		+ "import org.apache.sysds.runtime.codegen.SpoofOperator.SideInput;\n"
		+ "import org.apache.sysds.runtime.codegen.SpoofRowwise;\n"
		+ "import org.apache.sysds.runtime.codegen.SpoofRowwise.RowType;\n"
		+ "import org.apache.commons.math3.util.FastMath;\n"
		+ "\n"
		+ "public final class %TMP% extends SpoofRowwise { \n"
		+ " public %TMP%() {\n"
		+ " super(RowType.%TYPE%, %CONST_DIM2%, %TB1%, %VECT_MEM%);\n"
		+ " }\n"
		+ " protected void genexec(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix, int rix) { \n"
		+ "%BODY_dense%"
		+ " }\n"
		+ " protected void genexec(double[] avals, int[] aix, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int alen, int len, long grix, int rix) { \n"
		+ "%BODY_sparse%"
		+ " }\n"
		+ "}\n";

	// Output statements appended after the generated body, per row type and target API.
	private static final String TEMPLATE_ROWAGG_OUT = " c[rix] = %IN%;\n";
	private static final String TEMPLATE_FULLAGG_OUT = " c[0] += %IN%;\n";
	private static final String TEMPLATE_NOAGG_OUT = " LibSpoofPrimitives.vectWrite(%IN%, c, ci, %LEN%);\n";
	private static final String TEMPLATE_NOAGG_CONST_OUT_CUDA = "\t\tvectWrite(%IN%, c.vals(0), 0, ci, %LEN%);\n";
	private static final String TEMPLATE_NOAGG_OUT_CUDA = "\t\tvectWrite(%IN%, c.vals(0), 0, ci, %LEN%);\n";
	// debug variant of the CUDA row-aggregate output statement (kept for reference):
	// private static final String TEMPLATE_ROWAGG_OUT_CUDA = "\t\tif(threadIdx.x == 0){\n\t\t\t*(c.vals(rix)) = %IN%;\n//printf(\"rix=%d TMP7=%f TMP8=%f %IN%=%f\\n\",rix, TMP7, TMP8,%IN%);\n}\n";
	private static final String TEMPLATE_ROWAGG_OUT_CUDA = "\t\tif(threadIdx.x == 0){\n\t\t\t*(c.vals(rix)) = %IN%;\n\t\t}\n";
	// debug variant of the CUDA full-aggregate output statement (kept for reference):
	// private static final String TEMPLATE_FULLAGG_OUT_CUDA =
	// "\t\tif(threadIdx.x == 0) {\n\t\t\tT old = atomicAdd(c.vals(0), %IN%);\n//\t\t\tprintf(\"bid=%d full_agg add %f to %f\\n\",blockIdx.x, %IN%, old);\n\t\t}\n";
	private static final String TEMPLATE_FULLAGG_OUT_CUDA =
		"\t\tif(threadIdx.x == 0) {\n\t\tT old = atomicAdd(c.vals(0), %IN%);\n\t\t}\n";

	private RowType _type = null; //access pattern
	private long _constDim2 = -1; //constant number of output columns
	private int _numVectors = -1; //number of intermediate vectors
	private boolean _tb1 = false; //set by codegen: output contains a VECT_MATRIXMULT binary

	/**
	 * Creates a row template node over the given inputs and output.
	 *
	 * @param inputs template inputs
	 * @param output root node of the template's operator DAG
	 */
	public CNodeRow(ArrayList<CNode> inputs, CNode output ) {
		super(inputs, output);
	}

	/**
	 * Sets the row type and invalidates the cached hash code.
	 *
	 * @param type row type
	 */
	public void setRowType(RowType type) {
		_type = type;
		_hash = 0;
	}

	/**
	 * @return the row type of this template
	 */
	public RowType getRowType() {
		return _type;
	}

	/**
	 * Sets the number of intermediate vectors and invalidates the cached hash code.
	 *
	 * @param num number of intermediate vectors
	 */
	public void setNumVectorIntermediates(int num) {
		_numVectors = num;
		_hash = 0;
	}

	/**
	 * @return the number of intermediate vectors
	 */
	public int getNumVectorIntermediates() {
		return _numVectors;
	}

	/**
	 * Sets the constant number of output columns and invalidates the cached hash code.
	 *
	 * @param dim2 constant number of output columns
	 */
	public void setConstDim2(long dim2) {
		_constDim2 = dim2;
		_hash = 0;
	}

	/**
	 * @return the constant number of output columns (initially -1)
	 */
	public long getConstDim2() {
		return _constDim2;
	}

	/**
	 * Renames the first input to {@code "a"} (the input matrix) and delegates
	 * the renaming of the remaining inputs, starting at index 1, to the
	 * inherited {@code renameInputs(inputs, startIndex)}.
	 */
	@Override
	public void renameInputs() {
		rRenameDataNode(_output, _inputs.get(0), "a"); // input matrix
		renameInputs(_inputs, 1);
	}

	/**
	 * Generates the operator source by filling the language template for the
	 * given API. Both the dense and the sparse body are always generated; the
	 * {@code sparse} flag is ignored.
	 *
	 * @param sparse ignored
	 * @param _api target generator API; also stored in the {@code api} field
	 * @return the generated source code
	 */
	@Override
	public String codegen(boolean sparse, GeneratorAPI _api) {
		api = _api;

		// note: ignore sparse flag, generate both
		String tmp = getLanguageTemplate(this, api);

		//generate dense/sparse bodies
		String tmpDense = _output.codegen(false, api) + getOutputStatement(_output.getVarname());
		_output.resetGenerated();
		String tmpSparse = _output.codegen(true, api) + getOutputStatement(_output.getVarname());
		_output.resetGenerated();

		// non-Java templates carry their placeholders inside comments
		String varName = createVarname();
		tmp = tmp.replace(api.isJava()?"%TMP%":"//%TMP%", varName);
		if( !api.isJava() )
			tmp = tmp.replace("/*%TMP%*/SPOOF_OP_NAME", varName);
		String prefix = api.isJava()? "" : "//";
		tmp = tmp.replace(prefix+"%BODY_dense%", tmpDense);
		tmp = tmp.replace(prefix+"%BODY_sparse%", tmpSparse);

		//replace outputs
		tmp = api.isJava() ? tmp.replace("%OUT%", "c") :
			tmp.replace("%OUT%", "c.vals(0)");
		tmp = tmp.replace("%POSOUT%", "0");

		//replace size information
		tmp = tmp.replace("%LEN%", "a.cols()");

		//replace colvector information and number of vector intermediates
		tmp = tmp.replace("%TYPE%", _type.name());
		tmp = tmp.replace("%CONST_DIM2%", String.valueOf(_constDim2));
		_tb1 = TemplateUtils.containsBinary(_output, BinType.VECT_MATRIXMULT);
		tmp = tmp.replace("%TB1%", String.valueOf(_tb1));

		if(api == GeneratorAPI.CUDA && _numVectors > 0) {
			tmp = tmp.replace("//%HAS_TEMP_VECT%", ": public TempStorageImpl<T, NUM_TMP_VECT, TMP_VECT_LEN>");
			tmp = tmp.replace("/*%INIT_TEMP_VECT%*/", ", TempStorageImpl<T, NUM_TMP_VECT, TMP_VECT_LEN>(tmp_stor)");
		}
		else {
			tmp = tmp.replace("//%HAS_TEMP_VECT%", "");
			tmp = tmp.replace("/*%INIT_TEMP_VECT%*/", "");
		}
		tmp = tmp.replace("%VECT_MEM%", String.valueOf(_numVectors));

		return tmp;
	}

	/**
	 * Builds the statement that writes the result variable to the output,
	 * depending on the row type and the current target API.
	 *
	 * @param varName name of the variable holding the result
	 * @return the output statement, or an empty string for row types without
	 *         a dedicated output statement (the column aggregates)
	 */
	private String getOutputStatement(String varName) {
		switch( _type ) {
			case NO_AGG:
				if(api == GeneratorAPI.CUDA)
					// literal replace: no regex semantics intended for the placeholder
					return TEMPLATE_NOAGG_OUT_CUDA
						.replace("%IN%", varName + ".vals(0)")
						.replace("%LEN%", _output.getVarname()+".length");
				// intentional fall through: non-CUDA NO_AGG is handled like NO_AGG_B1/NO_AGG_CONST
			case NO_AGG_B1:
			case NO_AGG_CONST:
				if(api == GeneratorAPI.JAVA)
					return TEMPLATE_NOAGG_OUT
						.replace("%IN%", varName)
						.replace("%LEN%", _output.getVarname()+".length");
				else
					return TEMPLATE_NOAGG_CONST_OUT_CUDA
						.replace("%IN%", varName + ".vals(0)")
						.replace("%LEN%", _output.getVarname()+".length");
			case FULL_AGG:
				if(api == GeneratorAPI.JAVA)
					return TEMPLATE_FULLAGG_OUT.replace("%IN%", varName);
				else
					return TEMPLATE_FULLAGG_OUT_CUDA.replace("%IN%", varName);
			case ROW_AGG:
				if(api == GeneratorAPI.JAVA)
					return TEMPLATE_ROWAGG_OUT.replace("%IN%", varName);
				else
					return TEMPLATE_ROWAGG_OUT_CUDA.replace("%IN%", varName);
			default:
				return ""; //_type.isColumnAgg()
		}
	}

	/**
	 * No-op in this template.
	 */
	@Override
	public void setOutputDims() {
		// TODO Auto-generated method stub
	}

	/**
	 * Maps the row type to the corresponding output dimension type.
	 *
	 * @return the output dimension type
	 * @throws RuntimeException if the row type has no mapping
	 */
	@Override
	public SpoofOutputDimsType getOutputDimType() {
		switch( _type ) {
			case NO_AGG: return SpoofOutputDimsType.INPUT_DIMS;
			case NO_AGG_B1: return SpoofOutputDimsType.ROW_RANK_DIMS;
			case NO_AGG_CONST: return SpoofOutputDimsType.INPUT_DIMS_CONST2;
			case FULL_AGG: return SpoofOutputDimsType.SCALAR;
			case ROW_AGG: return SpoofOutputDimsType.ROW_DIMS;
			case COL_AGG: return SpoofOutputDimsType.COLUMN_DIMS_COLS; //row vector
			case COL_AGG_T: return SpoofOutputDimsType.COLUMN_DIMS_ROWS; //column vector
			case COL_AGG_B1: return SpoofOutputDimsType.COLUMN_RANK_DIMS;
			case COL_AGG_B1_T: return SpoofOutputDimsType.COLUMN_RANK_DIMS_T;
			case COL_AGG_B1R: return SpoofOutputDimsType.RANK_DIMS_COLS;
			case COL_AGG_CONST: return SpoofOutputDimsType.VECT_CONST2;
			default:
				throw new RuntimeException("Unsupported row type: "+_type.toString());
		}
	}

	/**
	 * Creates a copy that shares the inputs and output of this node and
	 * carries over all template properties (row type, constant number of
	 * output columns, number of intermediate vectors, and the tb1 flag), so
	 * that the copy is equal to this node with respect to {@link #equals}.
	 *
	 * @return the copied template node
	 */
	@Override
	public CNodeTpl clone() {
		CNodeRow tmp = new CNodeRow(_inputs, _output);
		tmp.setRowType(_type);
		// _constDim2 takes part in equals/hashCode and codegen, so it must be copied too
		tmp.setConstDim2(_constDim2);
		tmp.setNumVectorIntermediates(_numVectors);
		tmp._tb1 = _tb1;
		return tmp;
	}

	/**
	 * Computes the hash code from the super class hash, the row type, the
	 * constant number of output columns, and the number of intermediate
	 * vectors. The value is cached in {@code _hash}; 0 marks it as not yet
	 * computed, and the setters of this class reset it.
	 */
	@Override
	public int hashCode() {
		if( _hash == 0 ) {
			int h = UtilFunctions.intHashCode(super.hashCode(), _type.hashCode());
			h = UtilFunctions.intHashCode(h, Long.hashCode(_constDim2));
			_hash = UtilFunctions.intHashCode(h, Integer.hashCode(_numVectors));
		}
		return _hash;
	}

	/**
	 * Two row templates are equal if the super class considers them equal,
	 * their row type, number of intermediate vectors, and constant number of
	 * output columns match, and {@code equalInputReferences} holds for their
	 * outputs and inputs.
	 */
	@Override
	public boolean equals(Object o) {
		if(!(o instanceof CNodeRow))
			return false;

		CNodeRow that = (CNodeRow)o;
		return super.equals(o)
			&& _type == that._type
			&& _numVectors == that._numVectors
			&& _constDim2 == that._constDim2
			&& equalInputReferences(
				_output, that._output, _inputs, that._inputs);
	}

	/**
	 * @return a short description containing the row type and the number of
	 *         intermediate vectors
	 */
	@Override
	public String getTemplateInfo() {
		StringBuilder sb = new StringBuilder();
		sb.append("SPOOF ROWAGGREGATE [type=");
		sb.append(_type.name());
		sb.append(", reqVectMem=");
		sb.append(_numVectors);
		sb.append("]");
		return sb.toString();
	}

	/**
	 * @param api target generator API
	 * @return true if the API is CUDA or JAVA and the output node supports it
	 */
	@Override
	public boolean isSupported(GeneratorAPI api) {
		return (api == GeneratorAPI.CUDA || api == GeneratorAPI.JAVA) && _output.isSupported(api);
	}

	/**
	 * Compiles the given source through the native {@code compile_nvrtc} call
	 * when the API is CUDA.
	 *
	 * @param api target generator API
	 * @param src generated source code
	 * @return the result of the native compile call for CUDA, -1 for any other API
	 */
	public int compile(GeneratorAPI api, String src) {
		if(api == GeneratorAPI.CUDA)
			return compile_nvrtc(SpoofCompiler.native_contexts.get(api), _genVar, src, _type.getValue(), _constDim2,
				_numVectors, _tb1);
		return -1;
	}

	private native int compile_nvrtc(long context, String name, String src, int type, long constDim2, int numVectors,
		boolean TB1);
}