| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysds.hops.codegen.cplan; |
| |
| import java.util.ArrayList; |
| import java.util.Arrays; |
| |
| import org.apache.sysds.hops.Hop; |
| import org.apache.sysds.common.Types.AggOp; |
| import org.apache.sysds.hops.codegen.SpoofFusedOp.SpoofOutputDimsType; |
| import org.apache.sysds.runtime.util.CollectionUtils; |
| import org.apache.sysds.runtime.util.UtilFunctions; |
| |
| public class CNodeMultiAgg extends CNodeTpl |
| { |
| private static final String TEMPLATE = |
| "package codegen;\n" |
| + "import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;\n" |
| + "import org.apache.sysds.runtime.codegen.SpoofCellwise;\n" |
| + "import org.apache.sysds.runtime.codegen.SpoofCellwise.AggOp;\n" |
| + "import org.apache.sysds.runtime.codegen.SpoofMultiAggregate;\n" |
| + "import org.apache.sysds.runtime.codegen.SpoofOperator.SideInput;\n" |
| + "import org.apache.commons.math3.util.FastMath;\n" |
| + "\n" |
| + "public final class %TMP% extends SpoofMultiAggregate { \n" |
| + " public %TMP%() {\n" |
| + " super(%SPARSE_SAFE%, %AGG_OP%);\n" |
| + " }\n" |
| + " protected void genexec(double a, SideInput[] b, double[] scalars, double[] c, " |
| + "int m, int n, long grix, int rix, int cix) { \n" |
| + "%BODY_dense%" |
| + " }\n" |
| + "}\n"; |
| private static final String TEMPLATE_OUT_SUM = " c[%IX%] += %IN%;\n"; |
| private static final String TEMPLATE_OUT_SUMSQ = " c[%IX%] += %IN% * %IN%;\n"; |
| private static final String TEMPLATE_OUT_MIN = " c[%IX%] = Math.min(c[%IX%], %IN%);\n"; |
| private static final String TEMPLATE_OUT_MAX = " c[%IX%] = Math.max(c[%IX%], %IN%);\n"; |
| |
| private ArrayList<CNode> _outputs = null; |
| private ArrayList<AggOp> _aggOps = null; |
| private ArrayList<Hop> _roots = null; |
| private boolean _sparseSafe = false; |
| |
| public CNodeMultiAgg(ArrayList<CNode> inputs, ArrayList<CNode> outputs) { |
| super(inputs, null); |
| _outputs = outputs; |
| } |
| |
| public ArrayList<CNode> getOutputs() { |
| return _outputs; |
| } |
| |
| @Override |
| public void resetVisitStatusOutputs() { |
| for( CNode output : _outputs ) |
| output.resetVisitStatus(); |
| } |
| |
| public void setAggOps(ArrayList<AggOp> aggOps) { |
| _aggOps = aggOps; |
| _hash = 0; |
| } |
| |
| public ArrayList<AggOp> getAggOps() { |
| return _aggOps; |
| } |
| |
| public void setRootNodes(ArrayList<Hop> roots) { |
| _roots = roots; |
| } |
| |
| public ArrayList<Hop> getRootNodes() { |
| return _roots; |
| } |
| |
| public void setSparseSafe(boolean flag) { |
| _sparseSafe = flag; |
| } |
| |
| public boolean isSparseSafe() { |
| return _sparseSafe; |
| } |
| |
| @Override |
| public void renameInputs() { |
| rRenameDataNode(_outputs, _inputs.get(0), "a"); // input matrix |
| renameInputs(_outputs, _inputs, 1); |
| } |
| |
| @Override |
| public String codegen(boolean sparse) { |
| // note: ignore sparse flag, generate both |
| String tmp = TEMPLATE; |
| |
| //generate dense/sparse bodies |
| StringBuilder sb = new StringBuilder(); |
| for( CNode out : _outputs ) |
| sb.append(out.codegen(false)); |
| for( CNode out : _outputs ) |
| out.resetGenerated(); |
| |
| //append output assignments |
| for( int i=0; i<_outputs.size(); i++ ) { |
| CNode out = _outputs.get(i); |
| String tmpOut = getAggTemplate(i); |
| //get variable name (w/ handling of direct consumption of inputs) |
| String varName = (out instanceof CNodeData && ((CNodeData)out).getHopID()== |
| ((CNodeData)_inputs.get(0)).getHopID()) ? "a" : out.getVarname(); |
| tmpOut = tmpOut.replace("%IN%", varName); |
| tmpOut = tmpOut.replace("%IX%", String.valueOf(i)); |
| sb.append(tmpOut); |
| } |
| |
| //replace class name and body |
| tmp = tmp.replace("%TMP%", createVarname()); |
| tmp = tmp.replace("%BODY_dense%", sb.toString()); |
| |
| //replace meta data information |
| String aggList = ""; |
| for( AggOp aggOp : _aggOps ) { |
| aggList += !aggList.isEmpty() ? "," : ""; |
| aggList += "AggOp."+aggOp.name(); |
| } |
| tmp = tmp.replace("%AGG_OP%", aggList); |
| tmp = tmp.replace("%SPARSE_SAFE%", |
| String.valueOf(isSparseSafe())); |
| |
| return tmp; |
| } |
| |
| @Override |
| public void setOutputDims() { |
| |
| } |
| |
| @Override |
| public SpoofOutputDimsType getOutputDimType() { |
| return SpoofOutputDimsType.MULTI_SCALAR; |
| } |
| |
| @Override |
| public CNodeTpl clone() { |
| CNodeMultiAgg ret = new CNodeMultiAgg(_inputs, _outputs); |
| ret.setAggOps(getAggOps()); |
| return ret; |
| } |
| |
| @Override |
| public int hashCode() { |
| if( _hash == 0 ) { |
| int h = super.hashCode(); |
| for( int i=0; i<_outputs.size(); i++ ) { |
| h = UtilFunctions.intHashCode(h, UtilFunctions.intHashCode( |
| _outputs.get(i).hashCode(), _aggOps.get(i).hashCode())); |
| } |
| _hash = h; |
| } |
| return _hash; |
| } |
| |
| @Override |
| public boolean equals(Object o) { |
| if(!(o instanceof CNodeMultiAgg)) |
| return false; |
| CNodeMultiAgg that = (CNodeMultiAgg)o; |
| return super.equals(o) |
| && CollectionUtils.equals(_aggOps, that._aggOps) |
| && equalInputReferences( |
| _outputs, that._outputs, _inputs, that._inputs); |
| } |
| |
| @Override |
| public String getTemplateInfo() { |
| StringBuilder sb = new StringBuilder(); |
| sb.append("SPOOF MULTIAGG [aggOps="); |
| sb.append(Arrays.toString(_aggOps.toArray(new AggOp[0]))); |
| sb.append("]"); |
| return sb.toString(); |
| } |
| |
| private String getAggTemplate(int pos) { |
| switch( _aggOps.get(pos) ) { |
| case SUM: return TEMPLATE_OUT_SUM; |
| case SUM_SQ: return TEMPLATE_OUT_SUMSQ; |
| case MIN: return TEMPLATE_OUT_MIN; |
| case MAX: return TEMPLATE_OUT_MAX; |
| default: |
| return null; |
| } |
| } |
| } |