blob: d50e4b09a904a6309c5bf50728220fb25cc054d1 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.hops.codegen.cplan.cpp;
import java.util.Locale;

import org.apache.commons.lang.StringUtils;
import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
import org.apache.sysds.hops.codegen.cplan.CNodeTernary;
import org.apache.sysds.hops.codegen.cplan.CNodeUnary;
import org.apache.sysds.hops.codegen.cplan.CodeTemplate;
import org.apache.sysds.runtime.codegen.SpoofCellwise;

import static org.apache.sysds.runtime.matrix.data.LibMatrixNative.isSinglePrecision;
/**
 * Supplies C++ code templates for unary operations ({@link CNodeUnary.UnaryType}).
 * Only {@link #getTemplate(CNodeUnary.UnaryType, boolean)} is supported; every other
 * {@code getTemplate} overload throws a {@link RuntimeException}.
 */
public class Unary implements CodeTemplate {

	/**
	 * Returns the C++ code template for the given unary operation.
	 *
	 * <p>Templates contain placeholder tokens ({@code %TMP%}, {@code %IN1%}, {@code %IN1v%},
	 * {@code %IN1i%}, {@code %POS1%}, {@code %LEN%}) that are filled in later, presumably by
	 * the CNode code generator. Scalar math functions are emitted as their {@code float}
	 * variants ({@code expf}, {@code sqrtf}, ...) when {@code isSinglePrecision()} is true and
	 * as the plain {@code double} variants otherwise; all other templates are identical in
	 * both modes.
	 *
	 * @param type   the unary operation type
	 * @param sparse whether the input is sparse; only consulted for row aggregates,
	 *               vector primitives and {@code LOOKUP_R}
	 * @return the code template
	 * @throws RuntimeException if {@code type} has no template
	 */
	@Override
	public String getTemplate(CNodeUnary.UnaryType type, boolean sparse) {
		final boolean single = isSinglePrecision();
		// <math.h> suffix that selects the float overload (e.g. expf vs. exp)
		final String f = single ? "f" : "";
		switch( type ) {
			case ROW_SUMS:
			case ROW_SUMSQS:
			case ROW_MINS:
			case ROW_MAXS:
			case ROW_MEANS:
			case ROW_COUNTNNZS:
				return rowAggTemplate(type, sparse);
			case VECT_EXP:
			case VECT_POW2:
			case VECT_MULT2:
			case VECT_SQRT:
			case VECT_LOG:
			case VECT_ABS:
			case VECT_ROUND:
			case VECT_CEIL:
			case VECT_FLOOR:
			case VECT_SIGN:
			case VECT_SIN:
			case VECT_COS:
			case VECT_TAN:
			case VECT_ASIN:
			case VECT_ACOS:
			case VECT_ATAN:
			case VECT_SINH:
			case VECT_COSH:
			case VECT_TANH:
			case VECT_CUMSUM:
			case VECT_CUMMIN:
			case VECT_CUMMAX:
			case VECT_SPROP:
			case VECT_SIGMOID:
				return vectTemplate(type, sparse);
			case EXP:
				return mathCall("exp", f);
			case LOOKUP_R:
				return sparse ?
					" T %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, 0);\n" :
					" T %TMP% = getValue(%IN1%, rix);\n";
			case LOOKUP_C:
				return " T %TMP% = getValue(%IN1%, n, 0, cix);\n";
			case LOOKUP_RC:
				return " T %TMP% = getValue(%IN1%, n, rix, cix);\n";
			case LOOKUP0:
				return " T %TMP% = %IN1%[0];\n";
			case POW2:
				return " T %TMP% = %IN1% * %IN1%;\n";
			case MULT2:
				return " T %TMP% = %IN1% + %IN1%;\n";
			case ABS:
				return mathCall("fabs", f);
			case SIN:
				return mathCall("sin", f);
			case COS:
				return mathCall("cos", f);
			case TAN:
				return mathCall("tan", f);
			case ASIN:
				return mathCall("asin", f);
			case ACOS:
				return mathCall("acos", f);
			case ATAN:
				return mathCall("atan", f);
			case SINH:
				return mathCall("sinh", f);
			case COSH:
				return mathCall("cosh", f);
			case TANH:
				return mathCall("tanh", f);
			case SIGN:
				// Mirrors Java's Math.signum: +-0 and NaN are returned unchanged, any other value
				// maps to +-1. (A bare signbit() test would wrongly map 0 to 1 and NaN to +-1.)
				return " T %TMP% = (%IN1% == 0 || isnan(%IN1%)) ? %IN1% : copysign" + f
					+ "(" + (single ? "1.0f" : "1.0") + ", %IN1%);\n";
			case SQRT:
				return mathCall("sqrt", f);
			case LOG:
				return mathCall("log", f);
			case ROUND:
				// NOTE(review): C round() rounds halfway cases away from zero (round(-2.5) == -3);
				// confirm this matches the intended DML semantics for negative halves.
				return mathCall("round", f);
			case CEIL:
				return mathCall("ceil", f);
			case FLOOR:
				return mathCall("floor", f);
			case SPROP:
				return " T %TMP% = %IN1% * (1 - %IN1%);\n";
			case SIGMOID:
				return " T %TMP% = 1 / (1 + exp" + f + "(-%IN1%));\n";
			case LOG_NZ:
				return " T %TMP% = (%IN1%==0) ? 0 : log" + f + "(%IN1%);\n";
			default:
				throw new RuntimeException("Invalid unary type: " + type);
		}
	}

	/** Builds the row-aggregate template (ROW_SUMS, ROW_MINS, ...) calling {@code vect<Name>}. */
	private static String rowAggTemplate(CNodeUnary.UnaryType type, boolean sparse) {
		// Strip the "ROW_" prefix and trailing plural "S", e.g. ROW_SUMSQS -> "Sumsq".
		// Locale.ROOT keeps the mapping stable (e.g. "MIN" must not become "mın" in Turkish).
		String name = type.name();
		String vectName = StringUtils.capitalize(
			name.substring(4, name.length() - 1).toLowerCase(Locale.ROOT));
		return sparse ?
			" T %TMP% = LibSpoofPrimitives.vect" + vectName + "(%IN1v%, %IN1i%, %POS1%, alen, len);\n" :
			" T %TMP% = LibSpoofPrimitives.vect" + vectName + "(%IN1%, %POS1%, %LEN%);\n";
	}

	/** Builds the vector-primitive template (VECT_*) calling {@code vect<Name>Write}. */
	private static String vectTemplate(CNodeUnary.UnaryType type, boolean sparse) {
		String vectName = type.getVectorPrimitiveName();
		return sparse ?
			" T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1v%, %IN1i%, %POS1%, alen, len);\n" :
			" T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %POS1%, %LEN%);\n";
	}

	/** Builds a template applying the C math function {@code fn + suffix} to the single input. */
	private static String mathCall(String fn, String suffix) {
		return " T %TMP% = " + fn + suffix + "(%IN1%);\n";
	}

	@Override
	public String getTemplate() {
		throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
	}

	@Override
	public String getTemplate(SpoofCellwise.CellType ct) {
		throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
	}

	@Override
	public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs, boolean sparseRhs, boolean scalarVector, boolean scalarInput) {
		throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
	}

	@Override
	public String getTemplate(CNodeTernary.TernaryType type, boolean sparse) {
		throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
	}
}