blob: 6419dbcedef01280552e87dfbbe9fa72c8b895f5 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.hops;
import java.util.Locale;

import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.OpOpN;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.lops.Nary;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
/**
* The NaryOp Hop allows for a variable number of operands. Functionality
* such as 'printf' (overloaded into the existing print function) is an example
* of an operation that potentially takes a variable number of operands.
*
*/
public class NaryOp extends Hop {
protected OpOpN _op = null;
protected NaryOp() {
}
/**
* NaryOp constructor.
*
* @param name
* the target name, typically set by the DMLTranslator when
* constructing Hops. (For example, 'parsertemp1'.)
* @param dataType
* the target data type (SCALAR for printf)
* @param valueType
* the target value type (STRING for printf)
* @param op
* the operation type (such as PRINTF)
* @param inputs
* a variable number of input Hops
*/
public NaryOp(String name, DataType dataType, ValueType valueType,
OpOpN op, Hop... inputs) {
super(name, dataType, valueType);
_op = op;
for (int i = 0; i < inputs.length; i++) {
getInput().add(i, inputs[i]);
inputs[i].getParent().add(this);
}
refreshSizeInformation();
}
/** MultipleOp may have any number of inputs. */
@Override
public void checkArity() {}
public OpOpN getOp() {
return _op;
}
@Override
public String getOpString() {
return "m(" + _op.name().toLowerCase() + ")";
}
@Override
public boolean isGPUEnabled() {
return false;
}
/**
* Construct the corresponding Lops for this Hop
*/
@Override
public Lop constructLops() {
// reuse existing lop
if (getLops() != null)
return getLops();
try {
Lop[] inLops = new Lop[getInput().size()];
for (int i = 0; i < getInput().size(); i++)
inLops[i] = getInput().get(i).constructLops();
ExecType et = optFindExecType();
Nary multipleCPLop = new Nary(_op, getDataType(), getValueType(), inLops, et);
setOutputDimensions(multipleCPLop);
setLineNumbers(multipleCPLop);
setLops(multipleCPLop);
}
catch (Exception e) {
throw new HopsException(this.printErrorLocation() + "error constructing Lops for NaryOp -- \n ", e);
}
// add reblock/checkpoint lops if necessary
constructAndSetLopsDataFlowProperties();
return getLops();
}
@Override
public boolean allowsAllExecTypes() {
return false;
}
@Override
public void computeMemEstimate(MemoTable memo) {
//overwrites default hops behavior
super.computeMemEstimate(memo);
//specific case for function call
if( _op == OpOpN.EVAL ) {
_memEstimate = OptimizerUtils.INT_SIZE;
_outputMemEstimate = OptimizerUtils.INT_SIZE;
_processingMemEstimate = 0;
}
}
@Override
protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) {
double sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz);
return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
}
@Override
protected ExecType optFindExecType() {
checkAndSetForcedPlatform();
ExecType REMOTE = ExecType.SPARK;
//forced / memory-based / threshold-based decision
if( _etypeForced != null ) {
_etype = _etypeForced;
}
else
{
if ( OptimizerUtils.isMemoryBasedOptLevel() )
_etype = findExecTypeByMemEstimate();
// Choose CP, if the input dimensions are below threshold or if the input is a vector
else if ( areDimsBelowThreshold() )
_etype = ExecType.CP;
else
_etype = REMOTE;
//check for valid CP dimensions and matrix size
checkAndSetInvalidCPDimsAndSize();
}
//mark for recompile (forever)
setRequiresRecompileIfNecessary();
//ensure cp exec type for single-node operations
if ( _op == OpOpN.PRINTF || _op == OpOpN.EVAL || _op == OpOpN.LIST
//TODO: cbind/rbind of lists only support in CP right now
|| (_op == OpOpN.CBIND && getInput().get(0).getDataType().isList())
|| (_op == OpOpN.RBIND && getInput().get(0).getDataType().isList())
|| _op.isCellOp() && getInput().stream().allMatch(h -> h.getDataType().isScalar()))
_etype = ExecType.CP;
return _etype;
}
@Override
protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) {
return 0;
}
@Override
@SuppressWarnings("incomplete-switch")
protected DataCharacteristics inferOutputCharacteristics(MemoTable memo) {
if( !getDataType().isScalar() ) {
DataCharacteristics[] dc = memo.getAllInputStats(getInput());
switch( _op ) {
case CBIND: return new MatrixCharacteristics(
HopRewriteUtils.getMaxInputDim(dc, true),
HopRewriteUtils.getSumValidInputDims(dc, false), -1,
HopRewriteUtils.getSumValidInputNnz(dc, true));
case RBIND: return new MatrixCharacteristics(
HopRewriteUtils.getSumValidInputDims(dc, true),
HopRewriteUtils.getMaxInputDim(dc, false), -1,
HopRewriteUtils.getSumValidInputNnz(dc, true));
case MIN:
case MAX:
case PLUS: return new MatrixCharacteristics(
HopRewriteUtils.getMaxInputDim(this, true),
HopRewriteUtils.getMaxInputDim(this, false), -1, -1);
case LIST:
return new MatrixCharacteristics(getInput().size(), 1, -1, -1);
}
}
return null; //do nothing
}
@Override
public void refreshSizeInformation() {
switch( _op ) {
case CBIND:
if( !getInput().get(0).getDataType().isList() ) {
setDim1(HopRewriteUtils.getMaxInputDim(this, true));
setDim2(HopRewriteUtils.getSumValidInputDims(this, false));
setNnz(HopRewriteUtils.getSumValidInputNnz(this));
}
break;
case RBIND:
if( !getInput().get(0).getDataType().isList() ) {
setDim1(HopRewriteUtils.getSumValidInputDims(this, true));
setDim2(HopRewriteUtils.getMaxInputDim(this, false));
setNnz(HopRewriteUtils.getSumValidInputNnz(this));
}
break;
case MIN:
case MAX:
case PLUS:
setDim1(getDataType().isScalar() ? 0 : HopRewriteUtils.getMaxInputDim(this, true));
setDim2(getDataType().isScalar() ? 0 : HopRewriteUtils.getMaxInputDim(this, false));
break;
case LIST:
setDim1(getInput().size());
setDim2(1);
case PRINTF:
case EVAL:
//do nothing:
}
}
@Override
public Object clone() throws CloneNotSupportedException {
NaryOp multipleOp = new NaryOp();
// copy generic attributes
multipleOp.clone(this, false);
// copy specific attributes
multipleOp._op = _op;
return multipleOp;
}
@Override
public boolean compare(Hop that) {
if (!(that instanceof NaryOp) || _op == OpOpN.PRINTF)
return false;
NaryOp that2 = (NaryOp) that;
boolean ret = (_op == that2._op
&& getInput().size() == that2.getInput().size());
for( int i=0; i<getInput().size() && ret; i++ )
ret &= (getInput().get(i) == that2.getInput().get(i));
return ret;
}
}