/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.hops;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOp4;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.lops.LopsException;
import org.apache.sysds.lops.WeightedCrossEntropy;
import org.apache.sysds.lops.WeightedCrossEntropy.WCeMMType;
import org.apache.sysds.lops.WeightedCrossEntropyR;
import org.apache.sysds.lops.WeightedDivMM;
import org.apache.sysds.lops.WeightedDivMM.WDivMMType;
import org.apache.sysds.lops.WeightedDivMMR;
import org.apache.sysds.lops.WeightedSigmoid;
import org.apache.sysds.lops.WeightedSigmoid.WSigmoidType;
import org.apache.sysds.lops.WeightedSigmoidR;
import org.apache.sysds.lops.WeightedSquaredLoss;
import org.apache.sysds.lops.WeightedSquaredLoss.WeightsType;
import org.apache.sysds.lops.WeightedSquaredLossR;
import org.apache.sysds.lops.WeightedUnaryMM;
import org.apache.sysds.lops.WeightedUnaryMM.WUMMType;
import org.apache.sysds.lops.WeightedUnaryMMR;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
/**
* Note: this hop should be called AggQuaternaryOp for consistency with AggUnaryOp and AggBinaryOp;
* however, since a real QuaternaryOp does not exist yet, we leave the name as is for now.
*/
public class QuaternaryOp extends MultiThreadedHop
{
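//The fused quaternary patterns handled by this hop, in DML-level notation
//(illustrative sketches, assuming factored matrices U/V; the exact matched
//patterns are defined by the corresponding hop rewrites):
//  wsloss:   sum(W * (X - U %*% t(V))^2)                    -> scalar
//  wsigmoid: W * sigmoid(U %*% t(V)), w/ log/minus variants -> matrix
//  wdivmm:   e.g., t(t(U) %*% (W / (U %*% t(V))))           -> matrix
//  wcemm:    sum(X * log(U %*% t(V) [+ eps]))               -> scalar
//  wumm:     W * f(U %*% t(V)) or W / f(...), f unary       -> matrix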
//config influencing distributed (Spark) operator selection (for testing purposes only)
public static boolean FORCE_REPLICATION = false;
private OpOp4 _op = null;
//wsloss-specific attributes
private boolean _postWeights = false;
//wsigmoid-specific attributes
private boolean _logout = false;
private boolean _minusin = false;
//wdivmm-specific attributes
private int _baseType = -1;
private boolean _mult = false;
private boolean _minus = false;
//wumm-specific attributes
private boolean _umult = false;
private OpOp1 _uop = null;
private OpOp2 _sop = null;
private QuaternaryOp() {
//default constructor for clone
}
/**
* Constructor for wsloss.
*
* @param l hop name
* @param dt data type
* @param vt value type
* @param o the Hop.OpOp4
* @param inX high-level operator X
* @param inU high-level operator U
* @param inV high-level operator V
* @param inW high-level operator W
* @param post post weights
*/
public QuaternaryOp(String l, DataType dt, ValueType vt, OpOp4 o,
Hop inX, Hop inU, Hop inV, Hop inW, boolean post)
{
this(l, dt, vt, o, inX, inU, inV);
getInput().add(3, inW);
inW.getParent().add(this);
_postWeights = post;
}
/**
* Constructor for wsigmoid.
*
* @param l hop name
* @param dt data type
* @param vt value type
* @param o the Hop.OpOp4
* @param inX high-level operator X
* @param inU high-level operator U
* @param inV high-level operator V
* @param flag1 logout
* @param flag2 minusin
*/
public QuaternaryOp(String l, DataType dt, ValueType vt, OpOp4 o,
Hop inX, Hop inU, Hop inV, boolean flag1, boolean flag2)
{
this(l, dt, vt, o, inX, inU, inV);
_logout = flag1;
_minusin = flag2;
}
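/**
* Constructor for wdivmm and wcemm (w/ optional fourth input, e.g., scalar epsilon).
*
* @param l hop name
* @param dt data type
* @param vt value type
* @param o the Hop.OpOp4
* @param inX high-level operator X
* @param inU high-level operator U
* @param inV high-level operator V
* @param inW high-level operator W (optional fourth input, may be null)
* @param baseType base type, see checkWDivMMType/checkWCeMMType
* @param flag1 mult
* @param flag2 minus
*/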
public QuaternaryOp(String l, DataType dt, ValueType vt, OpOp4 o,
Hop inX, Hop inU, Hop inV, Hop inW, int baseType, boolean flag1, boolean flag2)
{
this(l, dt, vt, o, inX, inU, inV);
if( inW != null ) { //four inputs
getInput().add(3, inW);
inW.getParent().add(this);
}
_baseType = baseType;
_mult = flag1;
_minus = flag2;
}
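/**
* Constructor for wumm.
*
* @param l hop name
* @param dt data type
* @param vt value type
* @param o the Hop.OpOp4
* @param inW high-level operator W
* @param inU high-level operator U
* @param inV high-level operator V
* @param umult true for multiply, false for divide (see WUMMType)
* @param uop unary operation (may be null)
* @param sop scalar operation, mapped to POW2/MULT2 if uop is null
*/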
public QuaternaryOp(String l, DataType dt, ValueType vt, OpOp4 o,
Hop inW, Hop inU, Hop inV, boolean umult, OpOp1 uop, OpOp2 sop)
{
this(l, dt, vt, o, inW, inU, inV);
_umult = umult;
_uop = uop;
_sop = sop;
}
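/**
* Common constructor for all quaternary operators w/ three core inputs;
* wires inputs X/U/V (or W/U/V) and the corresponding parent links.
*
* @param l hop name
* @param dt data type
* @param vt value type
* @param o the Hop.OpOp4
* @param inX high-level operator X (or W for wumm)
* @param inU high-level operator U
* @param inV high-level operator V
*/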
public QuaternaryOp(String l, DataType dt, ValueType vt, OpOp4 o, Hop inX, Hop inU, Hop inV) {
super(l, dt, vt);
_op = o;
getInput().add(0, inX);
getInput().add(1, inU);
getInput().add(2, inV);
inX.getParent().add(this);
inU.getParent().add(this);
inV.getParent().add(this);
}
@Override
public void checkArity() {
HopsException.check(_input.size() == 3 || _input.size() == 4, this,
"should have arity 3 or 4 but has arity %d", _input.size());
}
public OpOp4 getOp(){
return _op;
}
@Override
public boolean isGPUEnabled() {
return false;
}
@Override
public boolean isMultiThreadedOpType() {
return true;
}
@Override
public Lop constructLops()
{
//return already created lops
if( getLops() != null )
return getLops();
try
{
ExecType et = optFindExecType();
switch( _op ) {
case WSLOSS: {
WeightsType wtype = checkWeightsType();
if( et == ExecType.CP )
constructCPLopsWeightedSquaredLoss(wtype);
else if( et == ExecType.SPARK )
constructSparkLopsWeightedSquaredLoss(wtype);
else
throw new HopsException("Unsupported quaternaryop-wsloss exec type: "+et);
break;
}
case WSIGMOID:{
WSigmoidType wtype = checkWSigmoidType();
if( et == ExecType.CP )
constructCPLopsWeightedSigmoid(wtype);
else if( et == ExecType.SPARK )
constructSparkLopsWeightedSigmoid(wtype);
else
throw new HopsException("Unsupported quaternaryop-wsigmoid exec type: "+et);
break;
}
case WDIVMM:{
WDivMMType wtype = checkWDivMMType();
if( et == ExecType.CP )
constructCPLopsWeightedDivMM(wtype);
else if( et == ExecType.SPARK )
constructSparkLopsWeightedDivMM(wtype);
else
throw new HopsException("Unsupported quaternaryop-wdivmm exec type: "+et);
break;
}
case WCEMM:{
WCeMMType wtype = checkWCeMMType();
if( et == ExecType.CP )
constructCPLopsWeightedCeMM(wtype);
else if( et == ExecType.SPARK )
constructSparkLopsWeightedCeMM(wtype);
else
throw new HopsException("Unsupported quaternaryop-wcemm exec type: "+et);
break;
}
case WUMM:{
WUMMType wtype = _umult ? WUMMType.MULT : WUMMType.DIV;
if( et == ExecType.CP )
constructCPLopsWeightedUMM(wtype);
else if( et == ExecType.SPARK )
constructSparkLopsWeightedUMM(wtype);
else
throw new HopsException("Unsupported quaternaryop-wumm exec type: "+et);
break;
}
default:
throw new HopsException(this.printErrorLocation() + "Unknown QuaternaryOp (" + _op + ") while constructing Lops");
}
}
catch(LopsException e) {
throw new HopsException(this.printErrorLocation() + "error constructing lops for QuaternaryOp.", e);
}
//add reblock/checkpoint lops if necessary
constructAndSetLopsDataFlowProperties();
return getLops();
}
@Override
public String getOpString() {
return "q(" + _op.toString() + ")";
}
@Override
public boolean allowsAllExecTypes() {
return true;
}
private void constructCPLopsWeightedSquaredLoss(WeightsType wtype)
{
WeightedSquaredLoss wsloss = new WeightedSquaredLoss(
getInput().get(0).constructLops(),
getInput().get(1).constructLops(),
getInput().get(2).constructLops(),
getInput().get(3).constructLops(),
getDataType(), getValueType(), wtype, ExecType.CP);
//set degree of parallelism
int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
wsloss.setNumThreads(k);
setOutputDimensions( wsloss );
setLineNumbers( wsloss );
setLops( wsloss );
}
private void constructSparkLopsWeightedSquaredLoss(WeightsType wtype)
{
//NOTE: the common case for wsloss is factors U/V with a rank in the 10s to 100s; the current runtime only
//supports single block outer products (U/V rank <= blocksize, i.e., 1000 by default); we enforce this
//by applying the hop rewrite for Weighted Squared Loss only if this constraint holds.
//Notes: Any broadcast needs to fit twice in local memory because we partition the input in cp,
//and needs to fit once in executor broadcast memory. The 2GB broadcast constraint is no longer
//required because the max_int byte buffer constraint has been fixed in Spark 1.4
double memBudgetExec = SparkExecutionContext.getBroadcastMemoryBudget();
double memBudgetLocal = OptimizerUtils.getLocalMemBudget();
Hop X = getInput().get(0);
Hop U = getInput().get(1);
Hop V = getInput().get(2);
Hop W = getInput().get(3);
//Spark operator selection, part 1
double m1Size = OptimizerUtils.estimateSize(U.getDim1(), U.getDim2()); //size U
double m2Size = OptimizerUtils.estimateSize(V.getDim1(), V.getDim2()); //size V
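//Example (illustrative, dense FP64): factors U and V of 1M x 100 cells are
//estimated at ~800MB each, so m1Size+m2Size (~1.6GB) must fit in the executor
//broadcast budget, and each factor twice in the local budget, for the
//map-side variant below.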
boolean isMapWsloss = (!wtype.hasFourInputs() && m1Size+m2Size < memBudgetExec
&& 2*m1Size < memBudgetLocal && 2*m2Size < memBudgetLocal);
if( !FORCE_REPLICATION && isMapWsloss ) //broadcast
{
//map-side wsloss always with broadcast
Lop wsloss = new WeightedSquaredLoss( X.constructLops(), U.constructLops(), V.constructLops(),
W.constructLops(), DataType.SCALAR, ValueType.FP64, wtype, ExecType.SPARK);
setOutputDimensions(wsloss);
setLineNumbers(wsloss);
setLops(wsloss);
}
else //general case
{
//Spark operator selection, part 2
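//Note: prefer to broadcast (cache) U first, and V only if it fits alone or
//alongside U; factors that are not cached are passed to the reduce-side
//operator via data replication instead.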
boolean cacheU = !FORCE_REPLICATION && (m1Size < memBudgetExec && 2*m1Size < memBudgetLocal);
boolean cacheV = !FORCE_REPLICATION && ((!cacheU && m2Size < memBudgetExec )
|| (cacheU && m1Size+m2Size < memBudgetExec)) && 2*m2Size < memBudgetLocal;
//reduce-side wsloss with or without broadcast
Lop wsloss = new WeightedSquaredLossR(
X.constructLops(), U.constructLops(), V.constructLops(), W.constructLops(),
DataType.SCALAR, ValueType.FP64, wtype, cacheU, cacheV, ExecType.SPARK);
setOutputDimensions(wsloss);
setLineNumbers(wsloss);
setLops(wsloss);
}
}
private void constructCPLopsWeightedSigmoid(WSigmoidType wtype) {
WeightedSigmoid wsig = new WeightedSigmoid(
getInput().get(0).constructLops(),
getInput().get(1).constructLops(),
getInput().get(2).constructLops(),
getDataType(), getValueType(), wtype, ExecType.CP);
//set degree of parallelism
int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
wsig.setNumThreads(k);
setOutputDimensions( wsig );
setLineNumbers( wsig );
setLops( wsig );
}
private void constructSparkLopsWeightedSigmoid( WSigmoidType wtype )
{
//NOTE: the common case for wsigmoid is factors U/V with a rank in the 10s to 100s; the current runtime only
//supports single block outer products (U/V rank <= blocksize, i.e., 1000 by default); we enforce this
//by applying the hop rewrite for Weighted Sigmoid only if this constraint holds.
//Notes: Any broadcast needs to fit twice in local memory because we partition the input in cp,
//and needs to fit once in executor broadcast memory. The 2GB broadcast constraint is no longer
//required because the max_int byte buffer constraint has been fixed in Spark 1.4
double memBudgetExec = SparkExecutionContext.getBroadcastMemoryBudget();
double memBudgetLocal = OptimizerUtils.getLocalMemBudget();
Hop X = getInput().get(0);
Hop U = getInput().get(1);
Hop V = getInput().get(2);
//Spark operator selection, part 1
double m1Size = OptimizerUtils.estimateSize(U.getDim1(), U.getDim2()); //size U
double m2Size = OptimizerUtils.estimateSize(V.getDim1(), V.getDim2()); //size V
boolean isMapWsig = (m1Size+m2Size < memBudgetExec
&& 2*m1Size<memBudgetLocal && 2*m2Size<memBudgetLocal);
if( !FORCE_REPLICATION && isMapWsig ) //broadcast
{
//map-side wsig always with broadcast
Lop wsigmoid = new WeightedSigmoid( X.constructLops(), U.constructLops(), V.constructLops(),
DataType.MATRIX, ValueType.FP64, wtype, ExecType.SPARK);
setOutputDimensions(wsigmoid);
setLineNumbers(wsigmoid);
setLops( wsigmoid );
}
else //general case
{
//Spark operator selection, part 2
boolean cacheU = !FORCE_REPLICATION && (m1Size < memBudgetExec && 2*m1Size < memBudgetLocal);
boolean cacheV = !FORCE_REPLICATION && ((!cacheU && m2Size < memBudgetExec )
|| (cacheU && m1Size+m2Size < memBudgetExec)) && 2*m2Size < memBudgetLocal;
//reduce-side wsig with or without broadcast
Lop wsigmoid = new WeightedSigmoidR(
X.constructLops(), U.constructLops(), V.constructLops(),
DataType.MATRIX, ValueType.FP64, wtype, cacheU, cacheV, ExecType.SPARK);
setOutputDimensions(wsigmoid);
setLineNumbers(wsigmoid);
setLops(wsigmoid);
}
}
private void constructCPLopsWeightedDivMM(WDivMMType wtype)
{
WeightedDivMM wdiv = new WeightedDivMM(
getInput().get(0).constructLops(),
getInput().get(1).constructLops(),
getInput().get(2).constructLops(),
getInput().get(3).constructLops(),
getDataType(), getValueType(), wtype, ExecType.CP);
//set degree of parallelism
int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
wdiv.setNumThreads(k);
setOutputDimensions( wdiv );
setLineNumbers( wdiv );
setLops( wdiv );
}
private void constructSparkLopsWeightedDivMM( WDivMMType wtype )
{
//NOTE: the common case for wdivmm is factors U/V with a rank in the 10s to 100s; the current runtime only
//supports single block outer products (U/V rank <= blocksize, i.e., 1000 by default); we enforce this
//by applying the hop rewrite for Weighted DivMM only if this constraint holds.
//Notes: Any broadcast needs to fit twice in local memory because we partition the input in cp,
//and needs to fit once in executor broadcast memory. The 2GB broadcast constraint is no longer
//required because the max_int byte buffer constraint has been fixed in Spark 1.4
double memBudgetExec = SparkExecutionContext.getBroadcastMemoryBudget();
double memBudgetLocal = OptimizerUtils.getLocalMemBudget();
Hop W = getInput().get(0);
Hop U = getInput().get(1);
Hop V = getInput().get(2);
Hop X = getInput().get(3);
//Spark operator selection, part 1
double m1Size = OptimizerUtils.estimateSize(U.getDim1(), U.getDim2()); //size U
double m2Size = OptimizerUtils.estimateSize(V.getDim1(), V.getDim2()); //size V
boolean isMapWdivmm = ((!wtype.hasFourInputs() || wtype.hasScalar()) && m1Size+m2Size < memBudgetExec
&& 2*m1Size<memBudgetLocal && 2*m2Size<memBudgetLocal);
if( !FORCE_REPLICATION && isMapWdivmm ) //broadcast
{
//map-side wdivmm always with broadcast
Lop wdivmm = new WeightedDivMM( W.constructLops(), U.constructLops(), V.constructLops(),
X.constructLops(), DataType.MATRIX, ValueType.FP64, wtype, ExecType.SPARK);
setOutputDimensions(wdivmm);
setLineNumbers(wdivmm);
setLops( wdivmm );
}
else //general case
{
//Spark operator selection, part 2
boolean cacheU = !FORCE_REPLICATION && (m1Size < memBudgetExec && 2*m1Size < memBudgetLocal);
boolean cacheV = !FORCE_REPLICATION && ((!cacheU && m2Size < memBudgetExec )
|| (cacheU && m1Size+m2Size < memBudgetExec)) && 2*m2Size < memBudgetLocal;
//reduce-side wdivmm with or without broadcast
Lop wdivmm = new WeightedDivMMR(
W.constructLops(), U.constructLops(), V.constructLops(), X.constructLops(),
DataType.MATRIX, ValueType.FP64, wtype, cacheU, cacheV, ExecType.SPARK);
setOutputDimensions(wdivmm);
setLineNumbers(wdivmm);
setLops(wdivmm);
}
}
private void constructCPLopsWeightedCeMM(WCeMMType wtype)
{
WeightedCrossEntropy wcemm = new WeightedCrossEntropy(
getInput().get(0).constructLops(),
getInput().get(1).constructLops(),
getInput().get(2).constructLops(),
getInput().get(3).constructLops(),
getDataType(), getValueType(), wtype, ExecType.CP);
//set degree of parallelism
int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
wcemm.setNumThreads(k);
setOutputDimensions( wcemm );
setLineNumbers( wcemm );
setLops( wcemm );
}
private void constructSparkLopsWeightedCeMM(WCeMMType wtype)
{
//NOTE: the common case for wcemm is factors U/V with a rank in the 10s to 100s; the current runtime only
//supports single block outer products (U/V rank <= blocksize, i.e., 1000 by default); we enforce this
//by applying the hop rewrite for Weighted Cross Entropy only if this constraint holds.
//Notes: Any broadcast needs to fit twice in local memory because we partition the input in cp,
//and needs to fit once in executor broadcast memory. The 2GB broadcast constraint is no longer
//required because the max_int byte buffer constraint has been fixed in Spark 1.4
double memBudgetExec = SparkExecutionContext.getBroadcastMemoryBudget();
double memBudgetLocal = OptimizerUtils.getLocalMemBudget();
Hop X = getInput().get(0);
Hop U = getInput().get(1);
Hop V = getInput().get(2);
Hop eps = getInput().get(3);
//Spark operator selection, part 1
double m1Size = OptimizerUtils.estimateSize(U.getDim1(), U.getDim2()); //size U
double m2Size = OptimizerUtils.estimateSize(V.getDim1(), V.getDim2()); //size V
boolean isMapWcemm = (m1Size+m2Size < memBudgetExec
&& 2*m1Size < memBudgetLocal && 2*m2Size < memBudgetLocal);
if( !FORCE_REPLICATION && isMapWcemm ) //broadcast
{
//map-side wcemm always with broadcast
Lop wcemm = new WeightedCrossEntropy( X.constructLops(), U.constructLops(), V.constructLops(),
eps.constructLops(), DataType.SCALAR, ValueType.FP64, wtype, ExecType.SPARK);
setOutputDimensions(wcemm);
setLineNumbers(wcemm);
setLops(wcemm);
}
else //general case
{
//Spark operator selection, part 2
boolean cacheU = !FORCE_REPLICATION && (m1Size < memBudgetExec && 2*m1Size < memBudgetLocal);
boolean cacheV = !FORCE_REPLICATION && ((!cacheU && m2Size < memBudgetExec )
|| (cacheU && m1Size+m2Size < memBudgetExec)) && 2*m2Size < memBudgetLocal;
//reduce-side wcemm with or without broadcast
Lop wcemm = new WeightedCrossEntropyR(
X.constructLops(), U.constructLops(), V.constructLops(), eps.constructLops(),
DataType.SCALAR, ValueType.FP64, wtype, cacheU, cacheV, ExecType.SPARK);
setOutputDimensions(wcemm);
setLineNumbers(wcemm);
setLops(wcemm);
}
}
private void constructCPLopsWeightedUMM(WUMMType wtype) {
OpOp1 uop = (_uop != null) ? _uop :
	(_sop == OpOp2.POW) ? OpOp1.POW2 : OpOp1.MULT2;
WeightedUnaryMM wumm = new WeightedUnaryMM(
getInput().get(0).constructLops(),
getInput().get(1).constructLops(),
getInput().get(2).constructLops(),
getDataType(), getValueType(), wtype, uop, ExecType.CP);
//set degree of parallelism
int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
wumm.setNumThreads(k);
setOutputDimensions( wumm );
setLineNumbers( wumm );
setLops( wumm );
}
private void constructSparkLopsWeightedUMM( WUMMType wtype )
{
//NOTE: the common case for wumm is factors U/V with a rank in the 10s to 100s; the current runtime only
//supports single block outer products (U/V rank <= blocksize, i.e., 1000 by default); we enforce this
//by applying the hop rewrite for Weighted UnaryMM only if this constraint holds.
OpOp1 uop = (_uop != null) ? _uop :
	(_sop == OpOp2.POW) ? OpOp1.POW2 : OpOp1.MULT2;
//Notes: Any broadcast needs to fit twice in local memory because we partition the input in cp,
//and needs to fit once in executor broadcast memory. The 2GB broadcast constraint is no longer
//required because the max_int byte buffer constraint has been fixed in Spark 1.4
double memBudgetExec = SparkExecutionContext.getBroadcastMemoryBudget();
double memBudgetLocal = OptimizerUtils.getLocalMemBudget();
Hop X = getInput().get(0);
Hop U = getInput().get(1);
Hop V = getInput().get(2);
//Spark operator selection, part 1
double m1Size = OptimizerUtils.estimateSize(U.getDim1(), U.getDim2()); //size U
double m2Size = OptimizerUtils.estimateSize(V.getDim1(), V.getDim2()); //size V
boolean isMapWumm = (m1Size+m2Size < memBudgetExec
	&& 2*m1Size < memBudgetLocal && 2*m2Size < memBudgetLocal);
if( !FORCE_REPLICATION && isMapWumm ) //broadcast
{
//map-side wumm always with broadcast
Lop wumm = new WeightedUnaryMM( X.constructLops(), U.constructLops(),
V.constructLops(), DataType.MATRIX, ValueType.FP64, wtype, uop, ExecType.SPARK);
setOutputDimensions(wumm);
setLineNumbers(wumm);
setLops( wumm );
}
else //general case
{
//Spark operator selection, part 2
boolean cacheU = !FORCE_REPLICATION && (m1Size < memBudgetExec && 2*m1Size < memBudgetLocal);
boolean cacheV = !FORCE_REPLICATION && ((!cacheU && m2Size < memBudgetExec )
|| (cacheU && m1Size+m2Size < memBudgetExec)) && 2*m2Size < memBudgetLocal;
//reduce-side wumm with or without broadcast
Lop wumm = new WeightedUnaryMMR(
X.constructLops(), U.constructLops(), V.constructLops(),
DataType.MATRIX, ValueType.FP64, wtype, uop, cacheU, cacheV, ExecType.SPARK);
setOutputDimensions(wumm);
setLineNumbers(wumm);
setLops(wumm);
}
}
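/**
* Determines the wsloss weights type: PRE or POST if a real (non-literal)
* weights input exists, POST_NZ for post-weighting w/ literal weights,
* and NONE otherwise.
*
* @return the weights type
*/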
private WeightsType checkWeightsType()
{
WeightsType ret = WeightsType.NONE;
if( !(getInput().get(3) instanceof LiteralOp) ){
if( _postWeights )
ret = WeightsType.POST;
else
ret = WeightsType.PRE;
}
else if( _postWeights ){
ret = WeightsType.POST_NZ;
}
return ret;
}
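/**
* Maps the logout/minusin flags to the wsigmoid type
* (BASIC, LOG, MINUS, or LOG_MINUS).
*
* @return the wsigmoid type
*/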
private WSigmoidType checkWSigmoidType() {
if( _logout && _minusin )
return WSigmoidType.LOG_MINUS;
else if( _logout )
return WSigmoidType.LOG;
else if( _minusin )
return WSigmoidType.MINUS;
else
return WSigmoidType.BASIC;
}
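/**
* Maps the base type (0 basic, 1 left, 2 right, 3 left w/ epsilon,
* 4 right w/ epsilon) and the mult/minus flags to the wdivmm type;
* a matrix-valued fourth input selects the minus-4 variants.
*
* @return the wdivmm type, or null for an unknown base type
*/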
private WDivMMType checkWDivMMType()
{
switch( _baseType )
{
case 0: //BASIC
return WDivMMType.MULT_BASIC;
case 1: //LEFT
if( getInput().get(3).getDataType()==DataType.MATRIX )
return WDivMMType.MULT_MINUS_4_LEFT;
else if( _minus )
return WDivMMType.MULT_MINUS_LEFT;
else
return _mult ? WDivMMType.MULT_LEFT : WDivMMType.DIV_LEFT;
case 2: //RIGHT
if( getInput().get(3).getDataType()==DataType.MATRIX )
return WDivMMType.MULT_MINUS_4_RIGHT;
else if( _minus )
return WDivMMType.MULT_MINUS_RIGHT;
else
return _mult ? WDivMMType.MULT_RIGHT : WDivMMType.DIV_RIGHT;
case 3: //LEFT w/EPS
return WDivMMType.DIV_LEFT_EPS;
case 4: //RIGHT w/EPS
return WDivMMType.DIV_RIGHT_EPS;
}
return null;
}
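/**
* @return BASIC_EPS if an epsilon input exists (base type 1), BASIC otherwise
*/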
private WCeMMType checkWCeMMType() {
return _baseType == 1 ? WCeMMType.BASIC_EPS : WCeMMType.BASIC;
}
@Override
protected double computeOutputMemEstimate( long dim1, long dim2, long nnz )
{
switch( _op ) {
case WSLOSS: //always scalar output
case WCEMM:
return OptimizerUtils.DOUBLE_SIZE;
case WSIGMOID:
case WDIVMM:
case WUMM:
double sp = OptimizerUtils.getSparsity(dim1, dim2, nnz);
return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sp);
default:
return 0;
}
}
@Override
protected double computeIntermediateMemEstimate( long dim1, long dim2, long nnz ) {
//no intermediates
return 0;
}
@Override
protected DataCharacteristics inferOutputCharacteristics( MemoTable memo )
{
DataCharacteristics ret = null;
switch( _op ) {
case WSLOSS: //always scalar output
break; //null
case WSIGMOID:
case WUMM: {
DataCharacteristics mcW = memo.getAllInputStats(getInput().get(0));
ret = new MatrixCharacteristics(mcW.getRows(), mcW.getCols(), -1, mcW.getNonZeros());
break;
}
case WDIVMM: {
if( _baseType == 0 ){ //basic
ret = memo.getAllInputStats(getInput().get(0));
}
else if( _baseType == 1 || _baseType == 3 ) { //left (w/ transpose or w/ epsilon)
DataCharacteristics mcV = memo.getAllInputStats(getInput().get(2));
ret = mcV.setNonZeros(-1);
}
else { //right
DataCharacteristics mcU = memo.getAllInputStats(getInput().get(1));
ret = mcU.setNonZeros(-1);
}
break;
}
default:
throw new RuntimeException("Memory for operation (" + _op + ") cannot be estimated.");
}
return ret;
}
@Override
protected ExecType optFindExecType()
{
checkAndSetForcedPlatform();
if( _etypeForced != null )
{
_etype = _etypeForced;
}
else
{
if ( OptimizerUtils.isMemoryBasedOptLevel() ) {
_etype = findExecTypeByMemEstimate();
}
else if ( (getInput().get(0).areDimsBelowThreshold()
&& getInput().get(1).areDimsBelowThreshold()
&& getInput().get(2).areDimsBelowThreshold()
&& getInput().get(3).areDimsBelowThreshold()) )
_etype = ExecType.CP;
else
_etype = ExecType.SPARK;
//check for valid CP dimensions and matrix size
checkAndSetInvalidCPDimsAndSize();
}
//mark for recompile (forever)
setRequiresRecompileIfNecessary();
return _etype;
}
@Override
public void refreshSizeInformation()
{
switch( _op ) {
case WSLOSS:
//do nothing: always scalar
break;
case WSIGMOID:
case WUMM: {
Hop inW = getInput().get(0);
setDim1( inW.getDim1() );
setDim2( inW.getDim2() );
setNnz( inW.getNnz() );
break;
}
case WDIVMM: {
if( _baseType == 0 ) { //basic
Hop inW = getInput().get(0);
setDim1( inW.getDim1() );
setDim2( inW.getDim2() );
setNnz( inW.getNnz() );
}
else if( _baseType == 1 || _baseType == 3 ){ //left (w/ transpose or w/ epsilon)
Hop inV = getInput().get(2);
setDim1( inV.getDim1() );
setDim2( inV.getDim2() );
setNnz( -1 ); //reset
}
else { //right
Hop inU = getInput().get(1);
setDim1( inU.getDim1() );
setDim2( inU.getDim2() );
setNnz( -1 ); //reset
}
break;
}
default:
break;
}
}
@Override
public Object clone() throws CloneNotSupportedException
{
QuaternaryOp ret = new QuaternaryOp();
//copy generic attributes
ret.clone(this, false);
//copy specific attributes
ret._op = _op;
ret._postWeights = _postWeights;
ret._logout = _logout;
ret._minusin = _minusin;
ret._baseType = _baseType;
ret._mult = _mult;
ret._minus = _minus;
ret._umult = _umult;
ret._uop = _uop;
ret._sop = _sop;
ret._maxNumThreads = _maxNumThreads;
return ret;
}
@Override
public boolean compare( Hop that )
{
if( !(that instanceof QuaternaryOp) )
return false;
QuaternaryOp that2 = (QuaternaryOp)that;
//compare basic inputs and weights (always existing)
boolean ret = (_op == that2._op
&& getInput().size() == that2.getInput().size()
&& getInput().get(0) == that2.getInput().get(0)
&& getInput().get(1) == that2.getInput().get(1)
&& getInput().get(2) == that2.getInput().get(2) );
//compare 4th input if present (equal input sizes checked above)
if( ret && getInput().size()==4 )
ret &= (getInput().get(3) == that2.getInput().get(3));
//compare specific parameters
ret &= _postWeights == that2._postWeights;
ret &= _logout == that2._logout;
ret &= _minusin == that2._minusin;
ret &= _baseType == that2._baseType;
ret &= _mult == that2._mult;
ret &= _minus == that2._minus;
ret &= _umult == that2._umult;
ret &= _uop == that2._uop;
ret &= _sop == that2._sop;
ret &= _maxNumThreads == that2._maxNumThreads;
return ret;
}
}