blob: 6a3a6471318e494e64f03d7c23122bbcdc53f624 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.hops.codegen.cplan;
import java.util.ArrayList;
import org.apache.sysds.hops.codegen.SpoofFusedOp.SpoofOutputDimsType;
import org.apache.sysds.lops.MMTSJ;
import org.apache.sysds.runtime.codegen.SpoofOuterProduct.OutProdType;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.hops.codegen.SpoofCompiler.GeneratorAPI;
public class CNodeOuterProduct extends CNodeTpl
{
private static final String TEMPLATE =
"package codegen;\n"
+ "import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;\n"
+ "import org.apache.sysds.runtime.codegen.SpoofOperator.SideInput;\n"
+ "import org.apache.sysds.runtime.codegen.SpoofOuterProduct;\n"
+ "import org.apache.sysds.runtime.codegen.SpoofOuterProduct.OutProdType;\n"
+ "import org.apache.commons.math3.util.FastMath;\n"
+ "\n"
+ "public final class %TMP% extends SpoofOuterProduct { \n"
+ " public %TMP%() {\n"
+ " super(OutProdType.%TYPE%);\n"
+ " }\n"
+ " protected void genexecDense(double a, double[] a1, int a1i, double[] a2, int a2i, SideInput[] b, double[] scalars, double[] c, int ci, int m, int n, int len, int rix, int cix) { \n"
+ "%BODY_dense%"
+ " }\n"
+ " protected double genexecCellwise(double a, double[] a1, int a1i, double[] a2, int a2i, SideInput[] b, double[] scalars, int m, int n, int len, int rix, int cix) { \n"
+ "%BODY_cellwise%"
+ " return %OUT_cellwise%;\n"
+ " }\n"
+ "}\n";
private OutProdType _type = null;
public MMTSJ.MMTSJType getMMTSJtype() {
return _mmtsj;
}
MMTSJ.MMTSJType _mmtsj;
private boolean _transposeOutput = false;
public CNodeOuterProduct(ArrayList<CNode> inputs, CNode output, MMTSJ.MMTSJType mmtsj) {
super(inputs,output);
_mmtsj = mmtsj;
// In case of a self transpose we need to add a duplicate input
if(_mmtsj != MMTSJ.MMTSJType.NONE)
_inputs.add(1,inputs.get(1));
}
@Override
public void renameInputs() {
rRenameDataNode(_output, _inputs.get(0), "a");
rRenameDataNode(_output, _inputs.get(1), "a1"); // u
rRenameDataNode(_output, _inputs.get(2), "a2"); // v
renameInputs(_inputs, 3);
}
@Override
public String codegen(boolean sparse, GeneratorAPI api) {
// note: ignore sparse flag, generate both
String tmp = TEMPLATE;
//generate dense/sparse bodies
String tmpDense = _output.codegen(false, api);
_output.resetGenerated();
tmp = tmp.replace("%TMP%", createVarname());
if(_type == OutProdType.LEFT_OUTER_PRODUCT || _type == OutProdType.RIGHT_OUTER_PRODUCT) {
tmp = tmp.replace("%BODY_dense%", tmpDense);
tmp = tmp.replace("%OUT%", "c");
tmp = tmp.replace("%BODY_cellwise%", "");
tmp = tmp.replace("%OUT_cellwise%", "0");
}
else {
tmp = tmp.replace("%BODY_dense%", "");
tmp = tmp.replace("%BODY_cellwise%", tmpDense);
tmp = tmp.replace("%OUT_cellwise%", _output.getVarname());
}
//replace size information
tmp = tmp.replace("%LEN%", "len");
tmp = tmp.replace("%POSOUT%", "ci");
tmp = tmp.replace("%TYPE%", _type.toString());
return tmp;
}
public void setOutProdType(OutProdType type) {
_type = type;
_hash = 0;
}
public OutProdType getOutProdType() {
return _type;
}
@Override
public void setOutputDims() {
}
public void setTransposeOutput(boolean transposeOutput) {
_transposeOutput = transposeOutput;
_hash = 0;
}
public boolean isTransposeOutput() {
return _transposeOutput;
}
@Override
public SpoofOutputDimsType getOutputDimType() {
switch( _type ) {
case LEFT_OUTER_PRODUCT:
return SpoofOutputDimsType.COLUMN_RANK_DIMS;
case RIGHT_OUTER_PRODUCT:
return SpoofOutputDimsType.ROW_RANK_DIMS;
case CELLWISE_OUTER_PRODUCT:
return SpoofOutputDimsType.INPUT_DIMS;
case AGG_OUTER_PRODUCT:
return SpoofOutputDimsType.SCALAR;
default:
throw new RuntimeException("Unsupported outer product type: "+_type.toString());
}
}
@Override
public CNodeTpl clone() {
return new CNodeOuterProduct(_inputs, _output, _mmtsj);
}
@Override
public int hashCode() {
if( _hash == 0 ) {
int h = UtilFunctions.intHashCode(super.hashCode(), _type.hashCode());
h = UtilFunctions.intHashCode(h, Boolean.hashCode(_transposeOutput));
_hash = h;
}
return _hash;
}
@Override
public boolean equals(Object o) {
if(!(o instanceof CNodeOuterProduct))
return false;
CNodeOuterProduct that = (CNodeOuterProduct)o;
return super.equals(that)
&& _type == that._type
&& _transposeOutput == that._transposeOutput
&& equalInputReferences(
_output, that._output, _inputs, that._inputs);
}
@Override
public String getTemplateInfo() {
StringBuilder sb = new StringBuilder();
sb.append("SPOOF OUTERPRODUCT [type=");
sb.append(_type.name());
sb.append(", to="+_transposeOutput);
sb.append("]");
return sb.toString();
}
@Override
public boolean isSupported(GeneratorAPI api) {
boolean is_supported = (api == GeneratorAPI.JAVA);
int i = 0;
while(is_supported && i < _inputs.size()) {
CNode in = _inputs.get(i++);
is_supported = in.isSupported(api);
}
return is_supported;
}
}