blob: 5500ddbfe19b76be1dff9d614efa75c042fedc60 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.hops.codegen.cplan;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.lang3.StringUtils;
import org.apache.sysds.hops.codegen.template.TemplateUtils;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.runtime.util.DnnUtils;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.hops.codegen.SpoofCompiler.GeneratorAPI;
/**
 * Code generation node for n-ary operations, i.e., operations over a
 * variable number of inputs. Each {@link NaryType} maps to a call of a
 * {@code LibSpoofPrimitives} vector primitive that materializes its result
 * into a temporary {@code double[]}.
 */
public class CNodeNary extends CNode
{
	/**
	 * Supported n-ary operation types, together with the source code
	 * templates that are emitted for them.
	 */
	public enum NaryType {
		VECT_CBIND,
		VECT_MAX_POOL,
		VECT_AVG_POOL,
		VECT_IM2COL,
		VECT_CONV2DMM;

		/**
		 * Indicates if the given string is the name of an n-ary type.
		 *
		 * @param value candidate enum constant name
		 * @return true if a constant with exactly this name exists
		 */
		public static boolean contains(String value) {
			for( NaryType bt : values() )
				if( bt.name().equals(value) )
					return true;
			return false;
		}

		/**
		 * Creates the code template of this operation. The returned string
		 * still contains placeholders ({@code %TMP%}, and for the dnn
		 * operations {@code %IN1%}, {@code %POS1%}, {@code %LEN%}, etc.)
		 * that the caller substitutes.
		 *
		 * @param sparseGen if true, the sparse variant of the template is
		 *        generated (where one exists)
		 * @param len length of the output vector, used for cbind allocation
		 * @param inputs inputs of the n-ary node
		 * @param api target generator api (not consulted by the templates)
		 * @return code template including trailing newline
		 */
		public String getTemplate(boolean sparseGen, long len, ArrayList<CNode> inputs, GeneratorAPI api) {
			switch (this) {
				case VECT_CBIND:
					return getCbindTemplate(sparseGen, len, inputs);
				case VECT_MAX_POOL:
				case VECT_AVG_POOL: {
					String vectName = (this==VECT_MAX_POOL) ? "Maxpool" : "Avgpool";
					String paramStr = getDnnParameterString(inputs, true);
					return sparseGen ?
						" double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN1i%, %POS1%, alen, len, "+paramStr+");\n" :
						" double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %POS1%, %LEN%, "+paramStr+");\n";
				}
				case VECT_IM2COL: {
					String paramStr = getDnnParameterString(inputs, true);
					return sparseGen ?
						" double[] %TMP% = LibSpoofPrimitives.vectIm2colWrite(%IN1v%, %IN1i%, %POS1%, alen, len, "+paramStr+");\n" :
						" double[] %TMP% = LibSpoofPrimitives.vectIm2colWrite(%IN1%, %POS1%, %LEN%, "+paramStr+");\n";
				}
				case VECT_CONV2DMM: {
					//note: second input is passed before the first input
					return " double[] %TMP% = LibSpoofPrimitives.vectConv2dmmWrite(%IN2%, %IN1%, %POS2%, %POS1%, %LEN%, "
						+ getDnnParameterString(inputs, false) +");\n";
				}
				default:
					throw new RuntimeException("Invalid nary type: "+this.toString());
			}
		}

		/**
		 * Builds the cbind template: allocation of the output vector followed
		 * by one write per input at a running column offset.
		 */
		private static String getCbindTemplate(boolean sparseGen, long len, List<CNode> inputs) {
			StringBuilder sb = new StringBuilder();
			sb.append(" double[] %TMP% = LibSpoofPrimitives.allocVector("+len+", true); //nary cbind\n");
			int off = 0;
			for( CNode input : inputs ) {
				String varj = input.getVarname();
				if( input.getDataType()==DataType.MATRIX ) {
					boolean dataInput = input instanceof CNodeData;
					//NOTE(review): prefixes "a" and "b" presumably denote the
					//main input and side inputs, respectively - confirm
					boolean sideInput = varj.startsWith("b");
					boolean sparseInput = sparseGen && dataInput && varj.startsWith("a");
					String pos = !dataInput ? "0" :
						sideInput ? varj + ".pos(rix)" : varj+"i";
					if( sparseInput ) {
						sb.append(" LibSpoofPrimitives.vectWrite("+varj+"vals, %TMP%, "
							+varj+"ix, "+pos+", "+off+", "+input._cols+");\n");
					}
					else {
						sb.append(" LibSpoofPrimitives.vectWrite("+(sideInput?varj+".values(rix)":varj)
							+", %TMP%, "+pos+", "+off+", "+input._cols+");\n");
					}
					off += input._cols;
				}
				else { //e.g., col vectors -> scalars
					sb.append(" %TMP%["+off+"] = "+varj+";\n");
					off ++;
				}
			}
			return sb.toString();
		}

		/**
		 * Indicates if this type is a vector primitive.
		 *
		 * @return true for all currently defined n-ary types
		 */
		public boolean isVectorPrimitive() {
			return this == VECT_CBIND || this == VECT_MAX_POOL || this == VECT_AVG_POOL
				|| this == VECT_IM2COL || this == VECT_CONV2DMM;
		}
	}

	private final NaryType _type;

	/**
	 * Creates an n-ary node over the given inputs and derives its output
	 * dimensions and data type.
	 *
	 * @param inputs input nodes in operation-specific order
	 * @param type n-ary operation type
	 */
	public CNodeNary( CNode[] inputs, NaryType type ) {
		for( CNode in : inputs )
			_inputs.add(in);
		_type = type;
		setOutputDims();
	}

	public NaryType getType() {
		return _type;
	}

	/**
	 * Generates the code of all inputs followed by the code of this
	 * operation. Subsequent calls return the empty string because the
	 * node is marked as generated.
	 *
	 * @param sparse if true, sparse code generation is requested
	 * @param api target generator api
	 * @return generated source code, or "" if already generated
	 */
	@Override
	public String codegen(boolean sparse, GeneratorAPI api) {
		if( isGenerated() )
			return "";

		StringBuilder sb = new StringBuilder();

		//generate children
		for(CNode in : _inputs)
			sb.append(in.codegen(sparse, api));

		//generate nary operation (use sparse template, if data input)
		CNode in1 = _inputs.get(0);
		boolean lsparse = sparse && (in1 instanceof CNodeData
			&& in1.getVarname().startsWith("a")
			&& !in1.isLiteral());
		String var = createVarname();
		String tmp = _type.getTemplate(lsparse, _cols, _inputs, api);
		tmp = tmp.replace("%TMP%", var);

		//replace sparse and dense inputs; the second input is only
		//accessed where needed to support single-input nary operations
		String varj1 = in1.getVarname();
		if( _type == NaryType.VECT_CONV2DMM ) {
			String varj2 = _inputs.get(1).getVarname();
			tmp = replaceBinaryPlaceholders(tmp, new String[]{varj1,varj2}, false);
		}
		else {
			tmp = replaceUnaryPlaceholders(tmp, varj1, false);
		}
		sb.append(tmp);

		//mark as generated
		_generated = true;

		return sb.toString();
	}

	@Override
	public String toString() {
		switch(_type) {
			case VECT_CBIND: return "n(cbind)";
			case VECT_MAX_POOL: return "n(maxpool)";
			case VECT_AVG_POOL: return "n(avgpool)";
			case VECT_IM2COL: return "n(im2col)";
			case VECT_CONV2DMM: return "n(conv2dmm)";
			default:
				//same "n(" prefix as all explicitly handled types
				return "n("+_type.name().toLowerCase()+")";
		}
	}

	/**
	 * Derives the output rows, columns, and data type from the operation
	 * type and the inputs. For the dnn operations, the dimensions are
	 * parsed from the variable names of fixed input positions.
	 */
	@Override
	public void setOutputDims() {
		switch(_type) {
			case VECT_CBIND:
				_rows = _inputs.get(0)._rows;
				_cols = 0;
				for(CNode in : _inputs)
					_cols += in._cols;
				_dataType = DataType.MATRIX;
				break;
			case VECT_MAX_POOL:
			case VECT_AVG_POOL: { //only stride 1, pad 0
				int C = Integer.parseInt(_inputs.get(6).getVarname());
				int H = Integer.parseInt(_inputs.get(7).getVarname());
				int W = Integer.parseInt(_inputs.get(8).getVarname());
				int R = Integer.parseInt(_inputs.get(11).getVarname());
				int S = Integer.parseInt(_inputs.get(12).getVarname());
				long P = DnnUtils.getP(H, R, 1, 0);
				long Q = DnnUtils.getQ(W, S, 1, 0);
				_rows = _inputs.get(0)._rows; //N
				_cols = C * P * Q;
				_dataType = DataType.MATRIX;
				break;
			}
			case VECT_IM2COL:
				_rows = 1;
				_cols = -1;
				_dataType = DataType.MATRIX;
				break;
			case VECT_CONV2DMM: {
				//parameter positions are shifted by one compared to pooling
				int H = Integer.parseInt(_inputs.get(8).getVarname());
				int W = Integer.parseInt(_inputs.get(9).getVarname());
				int K = Integer.parseInt(_inputs.get(10).getVarname());
				int R = Integer.parseInt(_inputs.get(12).getVarname());
				int S = Integer.parseInt(_inputs.get(13).getVarname());
				long P = DnnUtils.getP(H, R, 1, 0);
				long Q = DnnUtils.getQ(W, S, 1, 0);
				_rows = _inputs.get(0)._rows; //N
				_cols = K * P * Q;
				_dataType = DataType.MATRIX;
				break;
			}
			default:
				//fail early instead of silently leaving dimensions unset
				throw new RuntimeException("Invalid nary type: "+_type.toString());
		}
	}

	@Override
	public int hashCode() {
		//lazily computed and cached (0 marks not yet computed)
		if( _hash == 0 ) {
			_hash = UtilFunctions.intHashCode(
				super.hashCode(), _type.hashCode());
		}
		return _hash;
	}

	@Override
	public boolean equals(Object o) {
		if( !(o instanceof CNodeNary) )
			return false;
		CNodeNary that = (CNodeNary) o;
		return super.equals(that)
			&& _type == that._type;
	}

	/**
	 * Indicates if this node can be generated for the given api, which
	 * requires the java api and that all inputs are supported as well.
	 *
	 * @param api target generator api
	 * @return true if this node and all its inputs are supported
	 */
	@Override
	public boolean isSupported(GeneratorAPI api) {
		if( api != GeneratorAPI.JAVA )
			return false;
		for( CNode in : _inputs )
			if( !in.isSupported(api) )
				return false;
		return true;
	}

	/**
	 * Constructs the trailing parameter string of the dnn primitives by
	 * parsing the dimensions from the variable names of the inputs.
	 *
	 * @param inputs inputs of the n-ary node
	 * @param unary if false, all parameter positions are shifted by one
	 * @return string "rix, " followed by the comma-separated values
	 *         C, P, Q, K, R, S, H, W
	 */
	private static String getDnnParameterString(List<CNode> inputs, boolean unary) {
		int off = unary ? 0 : 1;

		//extract and derive individual parameters
		int C = Integer.parseInt(inputs.get(off+6).getVarname());
		int H = Integer.parseInt(inputs.get(off+7).getVarname());
		int W = Integer.parseInt(inputs.get(off+8).getVarname());
		int K = Integer.parseInt(inputs.get(off+9).getVarname());
		int R = Integer.parseInt(inputs.get(off+11).getVarname());
		int S = Integer.parseInt(inputs.get(off+12).getVarname());
		int P = (int) DnnUtils.getP(H, R, 1, 0);
		int Q = (int) DnnUtils.getQ(W, S, 1, 0);

		//construct parameter string
		return "rix, " + StringUtils.join(
			new int[]{C, P, Q, K, R, S, H, W}, ',');
	}

	/**
	 * Substitutes the placeholders of the first two inputs
	 * ({@code %INjv%}, {@code %INji%}, {@code %INj%}, {@code %POSj%} for
	 * j in 1,2) and, if the first input is a matrix, {@code %LEN%}.
	 *
	 * @param tmp code template
	 * @param vars variable names of the first and second input
	 * @param vectIn not used by this implementation
	 * @return template with substituted placeholders
	 */
	private String replaceBinaryPlaceholders(String tmp, String[] vars, boolean vectIn) {
		//replace sparse and dense inputs
		for( int j=0; j<2; j++ ) {
			String varj = vars[j];
			CNode inj = _inputs.get(j);
			String id = String.valueOf(j+1);

			//replace sparse and dense inputs
			tmp = tmp.replace("%IN"+id+"v%", varj+"vals");
			tmp = tmp.replace("%IN"+id+"i%", varj+"ix");
			tmp = tmp.replace("%IN"+id+"%",
				varj.startsWith("b") ? varj + ".values(rix)" : varj );

			//replace start position of main input
			tmp = tmp.replace("%POS"+id+"%", getNaryStartPos(inj, varj));
		}

		//replace length
		if( _inputs.get(0).getDataType().isMatrix() )
			tmp = tmp.replace("%LEN%", _inputs.get(0).getVectorLength());

		return tmp;
	}

	/**
	 * Determines the start position expression of an input: "0" unless the
	 * input is a matrix data node, in which case it is the row index
	 * variable or, for "b"-prefixed inputs of non-conv2dmm operations,
	 * the position lookup of the current row.
	 */
	private String getNaryStartPos(CNode input, String varname) {
		if( !(input instanceof CNodeData && input.getDataType().isMatrix()) )
			return "0";
		if( !varname.startsWith("b") )
			return varname+"i";
		return (TemplateUtils.isMatrix(input) && _type!=NaryType.VECT_CONV2DMM) ?
			varname + ".pos(rix)" : "0";
	}
}