blob: 39b0f6f719c15c3f16d55bc9bad2a5903fad03aa [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.hops.codegen.cplan.java;
import org.apache.sysds.hops.codegen.cplan.CNodeBinary.BinType;
import org.apache.sysds.hops.codegen.cplan.CNodeTernary;
import org.apache.sysds.hops.codegen.cplan.CNodeUnary;
import org.apache.sysds.hops.codegen.cplan.CodeTemplate;
import org.apache.sysds.runtime.codegen.SpoofCellwise;
/**
 * Java source-code templates for binary codegen operations.
 *
 * <p>Each template is a single generated statement containing placeholders
 * such as {@code %TMP%}, {@code %IN1%}, {@code %IN2%}, {@code %POS1%},
 * {@code %POS2%}, {@code %LEN%} and {@code %OUT%}. This class only selects
 * and returns the template text; it does not substitute the placeholders
 * (presumably the caller does so -- confirm against the call site).
 *
 * <p>Only the {@link BinType}-based {@code getTemplate} overload is
 * supported; all other overloads throw a {@link RuntimeException}.
 */
public class Binary implements CodeTemplate {

	/**
	 * Returns the Java statement template for the given binary operation.
	 *
	 * @param type         binary operation to generate code for
	 * @param sparseLhs    if true, selects the variant that reads the first
	 *                     input via {@code %IN1v%} / {@code %IN1i%} and {@code alen}
	 * @param sparseRhs    if true (and the left-hand sparse variant was not
	 *                     chosen), selects the variant that reads the second
	 *                     input via {@code %IN2v%} / {@code %IN2i%} and {@code alen}
	 * @param scalarVector if true, selects the variant in which {@code %IN1%}
	 *                     is passed without a position and the second input is
	 *                     addressed through {@code %POS2%}; only consulted for
	 *                     the {@code VECT_*_ADD} and {@code VECT_*_SCALAR} types
	 * @param scalarInput  only consulted for {@code VECT_CBIND}; if true, the
	 *                     two-argument {@code vectCbindWrite} template is used
	 * @return the statement template, terminated by a newline
	 * @throws RuntimeException if {@code type} has no template here
	 */
	@Override
	public String getTemplate(BinType type, boolean sparseLhs, boolean sparseRhs, boolean scalarVector,
		boolean scalarInput) {
		switch (type) {
			case DOT_PRODUCT:
				return sparseLhs ? " double %TMP% = LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" :
					" double %TMP% = LibSpoofPrimitives.dotProduct(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
			case VECT_MATRIXMULT:
				return sparseLhs ? " double[] %TMP% = LibSpoofPrimitives.vectMatrixMult(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, len);\n" :
					" double[] %TMP% = LibSpoofPrimitives.vectMatrixMult(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
			case VECT_OUTERMULT_ADD:
				return sparseLhs ? " LibSpoofPrimitives.vectOuterMultAdd(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" :
					sparseRhs ? " LibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2v%, %OUT%, %POS1%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" :
					" LibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2%, %OUT%, %POS1%, %POS2%, %POSOUT%, %LEN1%, %LEN2%);\n";

			//vector-scalar-add operations
			case VECT_MULT_ADD:
			case VECT_DIV_ADD:
			case VECT_MINUS_ADD:
			case VECT_PLUS_ADD:
			case VECT_POW_ADD:
			case VECT_XOR_ADD:
			case VECT_MIN_ADD:
			case VECT_MAX_ADD:
			case VECT_EQUAL_ADD:
			case VECT_NOTEQUAL_ADD:
			case VECT_LESS_ADD:
			case VECT_LESSEQUAL_ADD:
			case VECT_GREATER_ADD:
			case VECT_GREATEREQUAL_ADD:
			case VECT_CBIND_ADD: {
				//primitive name is spliced into the method name, i.e. vect<Name>Add
				String vectName = type.getVectorPrimitiveName();
				if( scalarVector )
					//NOTE(review): this branch tests sparseLhs but emits the %IN2v%/%IN2i%
					//template, whereas the scalar-vector *_SCALAR branch below tests
					//sparseRhs -- confirm against the caller whether sparseRhs is intended
					return sparseLhs ? " LibSpoofPrimitives.vect"+vectName+"Add(%IN1%, %IN2v%, %OUT%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN%);\n" :
						" LibSpoofPrimitives.vect"+vectName+"Add(%IN1%, %IN2%, %OUT%, %POS2%, %POSOUT%, %LEN%);\n";
				else
					return sparseLhs ? " LibSpoofPrimitives.vect"+vectName+"Add(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POSOUT%, alen, %LEN%);\n" :
						" LibSpoofPrimitives.vect"+vectName+"Add(%IN1%, %IN2%, %OUT%, %POS1%, %POSOUT%, %LEN%);\n";
			}

			//vector-scalar operations
			case VECT_MULT_SCALAR:
			case VECT_DIV_SCALAR:
			case VECT_MINUS_SCALAR:
			case VECT_PLUS_SCALAR:
			case VECT_POW_SCALAR:
			case VECT_XOR_SCALAR:
			case VECT_BITWAND_SCALAR:
			case VECT_MIN_SCALAR:
			case VECT_MAX_SCALAR:
			case VECT_EQUAL_SCALAR:
			case VECT_NOTEQUAL_SCALAR:
			case VECT_LESS_SCALAR:
			case VECT_LESSEQUAL_SCALAR:
			case VECT_GREATER_SCALAR:
			case VECT_GREATEREQUAL_SCALAR: {
				//primitive name is spliced into the method name, i.e. vect<Name>Write
				String vectName = type.getVectorPrimitiveName();
				if( scalarVector )
					return sparseRhs ? " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2v%, %IN2i%, %POS2%, alen, %LEN%);\n" :
						" double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS2%, %LEN%);\n";
				else
					return sparseLhs ? " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%);\n" :
						" double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS1%, %LEN%);\n";
			}

			case VECT_CBIND:
				if( scalarInput )
					return " double[] %TMP% = LibSpoofPrimitives.vectCbindWrite(%IN1%, %IN2%);\n";
				else
					return sparseLhs ?
						" double[] %TMP% = LibSpoofPrimitives.vectCbindWrite(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%);\n" :
						" double[] %TMP% = LibSpoofPrimitives.vectCbindWrite(%IN1%, %IN2%, %POS1%, %LEN%);\n";

			//vector-vector operations
			case VECT_MULT:
			case VECT_DIV:
			case VECT_MINUS:
			case VECT_PLUS:
			case VECT_XOR:
			case VECT_BITWAND:
			case VECT_BIASADD:
			case VECT_BIASMULT:
			case VECT_MIN:
			case VECT_MAX:
			case VECT_EQUAL:
			case VECT_NOTEQUAL:
			case VECT_LESS:
			case VECT_LESSEQUAL:
			case VECT_GREATER:
			case VECT_GREATEREQUAL: {
				String vectName = type.getVectorPrimitiveName();
				//sparse lhs takes precedence over sparse rhs
				return sparseLhs ?
					" double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, %LEN%);\n" :
					sparseRhs ?
					" double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2v%, %POS1%, %IN2i%, %POS2%, alen, %LEN%);\n" :
					" double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
			}

			//scalar-scalar operations
			case MULT:
				return " double %TMP% = %IN1% * %IN2%;\n";
			case DIV:
				return " double %TMP% = %IN1% / %IN2%;\n";
			case PLUS:
				return " double %TMP% = %IN1% + %IN2%;\n";
			case MINUS:
				return " double %TMP% = %IN1% - %IN2%;\n";
			case MODULUS:
				return " double %TMP% = LibSpoofPrimitives.mod(%IN1%, %IN2%);\n";
			case INTDIV:
				return " double %TMP% = LibSpoofPrimitives.intDiv(%IN1%, %IN2%);\n";
			case LESS:
				return " double %TMP% = (%IN1% < %IN2%) ? 1 : 0;\n";
			case LESSEQUAL:
				return " double %TMP% = (%IN1% <= %IN2%) ? 1 : 0;\n";
			case GREATER:
				return " double %TMP% = (%IN1% > %IN2%) ? 1 : 0;\n";
			case GREATEREQUAL:
				return " double %TMP% = (%IN1% >= %IN2%) ? 1 : 0;\n";
			case EQUAL:
				return " double %TMP% = (%IN1% == %IN2%) ? 1 : 0;\n";
			case NOTEQUAL:
				return " double %TMP% = (%IN1% != %IN2%) ? 1 : 0;\n";
			case MIN:
				return " double %TMP% = Math.min(%IN1%, %IN2%);\n";
			case MAX:
				return " double %TMP% = Math.max(%IN1%, %IN2%);\n";
			case LOG:
				//logarithm of %IN1% to base %IN2% via change of base
				return " double %TMP% = Math.log(%IN1%)/Math.log(%IN2%);\n";
			case LOG_NZ:
				return " double %TMP% = (%IN1% == 0) ? 0 : Math.log(%IN1%)/Math.log(%IN2%);\n";
			case POW:
				return " double %TMP% = Math.pow(%IN1%, %IN2%);\n";
			case MINUS1_MULT:
				return " double %TMP% = 1 - %IN1% * %IN2%;\n";
			case MINUS_NZ:
				return " double %TMP% = (%IN1% != 0) ? %IN1% - %IN2% : 0;\n";
			case XOR:
				return " double %TMP% = ( (%IN1% != 0) != (%IN2% != 0) ) ? 1 : 0;\n";
			case BITWAND:
				return " double %TMP% = LibSpoofPrimitives.bwAnd(%IN1%, %IN2%);\n";
			case SEQ_RIX:
				return " double %TMP% = %IN1% + grix * %IN2%;\n"; //0-based global rix

			default:
				//report the offending operation type, not this template object
				throw new RuntimeException("Invalid binary type: "+type);
		}
	}

	/**
	 * Unsupported for binary templates.
	 *
	 * @throws RuntimeException always
	 */
	@Override
	public String getTemplate() {
		throw wrongTemplateCall();
	}

	/**
	 * Unsupported for binary templates.
	 *
	 * @param ct ignored
	 * @throws RuntimeException always
	 */
	@Override
	public String getTemplate(SpoofCellwise.CellType ct) {
		throw wrongTemplateCall();
	}

	/**
	 * Unsupported for binary templates.
	 *
	 * @param type   ignored
	 * @param sparse ignored
	 * @throws RuntimeException always
	 */
	@Override
	public String getTemplate(CNodeUnary.UnaryType type, boolean sparse) {
		throw wrongTemplateCall();
	}

	/**
	 * Unsupported for binary templates.
	 *
	 * @param type   ignored
	 * @param sparse ignored
	 * @throws RuntimeException always
	 */
	@Override
	public String getTemplate(CNodeTernary.TernaryType type, boolean sparse) {
		throw wrongTemplateCall();
	}

	/** Builds the exception shared by all unsupported getTemplate overloads. */
	private RuntimeException wrongTemplateCall() {
		return new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
	}
}