| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysds.hops; |
| |
| import org.apache.sysds.api.DMLScript; |
| import org.apache.sysds.common.Types.AggOp; |
| import org.apache.sysds.common.Types.DataType; |
| import org.apache.sysds.common.Types.Direction; |
| import org.apache.sysds.common.Types.OpOp1; |
| import org.apache.sysds.common.Types.OpOp2; |
| import org.apache.sysds.common.Types.ValueType; |
| import org.apache.sysds.hops.AggBinaryOp.SparkAggType; |
| import org.apache.sysds.hops.rewrite.HopRewriteUtils; |
| import org.apache.sysds.lops.Lop; |
| import org.apache.sysds.lops.LopProperties.ExecType; |
| import org.apache.sysds.lops.PartialAggregate; |
| import org.apache.sysds.lops.TernaryAggregate; |
| import org.apache.sysds.lops.UAggOuterChain; |
| import org.apache.sysds.lops.UnaryCP; |
| import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext; |
| import org.apache.sysds.runtime.meta.DataCharacteristics; |
| import org.apache.sysds.runtime.meta.MatrixCharacteristics; |
| |
| |
| // Aggregate unary (cell) operation: e.g., sum(X), colSums(X), rowSums(X) |
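| // Illustrative mapping of DML expressions to (AggOp, Direction), consistent |
| // with refreshSizeInformation below: |
| //   sum(X)     -> (SUM, RowCol), scalar output |
| //   colSums(X) -> (SUM, Col),    1 x ncol(X) output |
| //   rowSums(X) -> (SUM, Row),    nrow(X) x 1 output |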
| |
| public class AggUnaryOp extends MultiThreadedHop |
| { |
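| //if enabled, single-block column/row aggregates (see requiresAggregation) |
| //may skip the final aggregation step in distributed operations |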
| private static final boolean ALLOW_UNARYAGG_WO_FINAL_AGG = true; |
| |
| private AggOp _op; |
| private Direction _direction; |
| |
| private AggUnaryOp() { |
| //default constructor for clone |
| } |
| |
| public AggUnaryOp(String l, DataType dt, ValueType vt, AggOp o, Direction idx, Hop inp) |
| { |
| super(l, dt, vt); |
| _op = o; |
| _direction = idx; |
| getInput().add(0, inp); |
| inp.getParent().add(this); |
| } |
| |
| @Override |
| public void checkArity() { |
| HopsException.check(_input.size() == 1, this, "should have arity 1 but has arity %d", _input.size()); |
| } |
| |
| public AggOp getOp() |
| { |
| return _op; |
| } |
| |
| public void setOp(AggOp op) |
| { |
| _op = op; |
| } |
| |
| public Direction getDirection() |
| { |
| return _direction; |
| } |
| |
| public void setDirection(Direction direction) |
| { |
| _direction = direction; |
| } |
| |
| |
| @Override |
| public boolean isGPUEnabled() { |
| if(!DMLScript.USE_ACCELERATOR) |
| return false; |
| |
| try { |
| if( isTernaryAggregateRewriteApplicable() || isUnaryAggregateOuterCPRewriteApplicable() ) { |
| return false; |
| } |
| else if( ((_direction == Direction.RowCol || _direction == Direction.Row || _direction == Direction.Col) |
| && (_op == AggOp.SUM || _op == AggOp.SUM_SQ || _op == AggOp.MAX |
| || _op == AggOp.MIN || _op == AggOp.MEAN || _op == AggOp.VAR)) |
| || (_op == AggOp.PROD && _direction == Direction.RowCol) ) { |
| return true; |
| } |
| } catch (HopsException e) { |
| throw new RuntimeException(e); |
| } |
| return false; |
| } |
| |
| @Override |
| public Lop constructLops() |
| { |
| //return already created lops |
| if( getLops() != null ) |
| return getLops(); |
| |
| try |
| { |
| ExecType et = optFindExecType(); |
| Hop input = getInput().get(0); |
| |
| if ( et == ExecType.CP || et == ExecType.GPU ) |
| { |
| Lop agg1 = null; |
| if( isTernaryAggregateRewriteApplicable() ) { |
| agg1 = constructLopsTernaryAggregateRewrite(et); |
| } |
| else if( isUnaryAggregateOuterCPRewriteApplicable() ) |
| { |
| BinaryOp binput = (BinaryOp)getInput().get(0); |
| agg1 = new UAggOuterChain( binput.getInput().get(0).constructLops(), |
| binput.getInput().get(1).constructLops(), _op, _direction, |
| binput.getOp(), DataType.MATRIX, getValueType(), ExecType.CP); |
| PartialAggregate.setDimensionsBasedOnDirection(agg1, getDim1(), getDim2(), input.getBlocksize(), _direction); |
| |
| if (getDataType() == DataType.SCALAR) { |
| UnaryCP unary1 = new UnaryCP(agg1, OpOp1.CAST_AS_SCALAR, |
| getDataType(), getValueType()); |
| unary1.getOutputParameters().setDimensions(0, 0, 0, -1); |
| setLineNumbers(unary1); |
| agg1 = unary1; |
| } |
| |
| } |
| else { //general case |
| int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); |
| agg1 = new PartialAggregate(input.constructLops(), |
| _op, _direction, getDataType(),getValueType(), et, k); |
| } |
| |
| setOutputDimensions(agg1); |
| setLineNumbers(agg1); |
| setLops(agg1); |
| |
| if (getDataType() == DataType.SCALAR) { |
| agg1.getOutputParameters().setDimensions(1, 1, getBlocksize(), getNnz()); |
| } |
| } |
| else if( et == ExecType.SPARK ) |
| { |
| //unary aggregate |
| if( isTernaryAggregateRewriteApplicable() ) |
| { |
| Lop aggregate = constructLopsTernaryAggregateRewrite(et); |
| setOutputDimensions(aggregate); //0x0 (scalar) |
| setLineNumbers(aggregate); |
| setLops(aggregate); |
| } |
| else if( isUnaryAggregateOuterSPRewriteApplicable() ) |
| { |
| BinaryOp binput = (BinaryOp)getInput().get(0); |
| Lop transform1 = new UAggOuterChain( binput.getInput().get(0).constructLops(), |
| binput.getInput().get(1).constructLops(), _op, _direction, |
| binput.getOp(), DataType.MATRIX, getValueType(), ExecType.SPARK); |
| PartialAggregate.setDimensionsBasedOnDirection(transform1, getDim1(), getDim2(), input.getBlocksize(), _direction); |
| setLineNumbers(transform1); |
| setLops(transform1); |
| |
| if (getDataType() == DataType.SCALAR) { |
| UnaryCP unary1 = new UnaryCP(transform1, |
| OpOp1.CAST_AS_SCALAR, getDataType(), getValueType()); |
| unary1.getOutputParameters().setDimensions(0, 0, 0, -1); |
| setLineNumbers(unary1); |
| setLops(unary1); |
| } |
| |
| } |
| else //default |
| { |
| boolean needAgg = requiresAggregation(input, _direction); |
| SparkAggType aggtype = getSparkUnaryAggregationType(needAgg); |
| |
| PartialAggregate aggregate = new PartialAggregate(input.constructLops(), |
| _op, _direction, input.getDataType(), getValueType(), aggtype, et); |
| aggregate.setDimensionsBasedOnDirection(getDim1(), getDim2(), input.getBlocksize()); |
| setLineNumbers(aggregate); |
| setLops(aggregate); |
| |
| if (getDataType() == DataType.SCALAR) { |
| UnaryCP unary1 = new UnaryCP(aggregate, |
| OpOp1.CAST_AS_SCALAR, getDataType(), getValueType()); |
| unary1.getOutputParameters().setDimensions(0, 0, 0, -1); |
| setLineNumbers(unary1); |
| setLops(unary1); |
| } |
| } |
| } |
| } |
| catch (Exception e) { |
| throw new HopsException(this.printErrorLocation() + "In AggUnary Hop, error constructing Lops", e); |
| } |
| |
| //add reblock/checkpoint lops if necessary |
| constructAndSetLopsDataFlowProperties(); |
| |
| //return created lops |
| return getLops(); |
| } |
| |
| |
| |
| @Override |
| public String getOpString() { |
| //ua - unary aggregate, for consistency with runtime |
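| //e.g., a full sum prints as ua(+RC) and rowSums as ua(+R), assuming the |
| //usual AggOp/Direction toString encodings |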
| return"ua(" + _op.toString() + _direction.toString() + ")"; |
| } |
| |
| @Override |
| public boolean allowsAllExecTypes() |
| { |
| return true; |
| } |
| |
| @Override |
| protected double computeOutputMemEstimate( long dim1, long dim2, long nnz ) |
| { |
| double sparsity = -1; |
| if (isGPUEnabled()) { |
| // The GPU version (for the time being) only does dense outputs |
| sparsity = 1.0; |
| } else { |
| sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz); |
| } |
| |
| return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity); |
| } |
| |
| @Override |
| protected double computeIntermediateMemEstimate( long dim1, long dim2, long nnz ) |
| { |
| //default: no additional memory required |
| double val = 0; |
| |
| double sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz); |
| |
| switch( _op ) //see LibMatrixAgg for runtime operations |
| { |
| case MAX: |
| case MIN: |
| //worst-case: column-wise, sparse (temp int count arrays) |
| if( _direction == Direction.Col ) |
| val = dim2 * OptimizerUtils.INT_SIZE; |
| break; |
| case SUM: |
| case SUM_SQ: |
| //worst-case correction LASTROW / LASTCOLUMN |
| if( _direction == Direction.Col ) //(potentially sparse) |
| val = OptimizerUtils.estimateSizeExactSparsity(2, dim2, sparsity); |
| else if( _direction == Direction.Row ) //(always dense) |
| val = OptimizerUtils.estimateSizeExactSparsity(dim1, 2, 1.0); |
| break; |
| case MEAN: |
| //worst-case correction LASTTWOROWS / LASTTWOCOLUMNS |
| if( _direction == Direction.Col ) //(potentially sparse) |
| val = OptimizerUtils.estimateSizeExactSparsity(3, dim2, sparsity); |
| else if( _direction == Direction.Row ) //(always dense) |
| val = OptimizerUtils.estimateSizeExactSparsity(dim1, 3, 1.0); |
| break; |
| case VAR: |
| //worst-case correction LASTFOURROWS / LASTFOURCOLUMNS |
| if (isGPUEnabled()) { |
| // The GPU implementation only operates on dense data |
| // It allocates 2 dense blocks to help with these ops: |
| // Assume Y = var(X), colVars(X), or rowVars(X) |
| // 1. Y = mean/rowMeans/colMeans(X) <-- Y is a scalar or row-vector or col-vector |
| // 2. temp1 = X - Y <-- temp1 is a matrix of size(X) |
| // 3. temp2 = temp1 ^ 2 <-- temp2 is a matrix of size(X) |
| // 4. temp3 = sum/rowSums/colSums(temp2) <-- temp3 is a scalar or a row-vector or col-vector |
| // 5. Y = temp3 / (size(X) or nrow(X) or ncol(X)) <-- Y is a scalar or a row-vector or col-vector |
| |
| long in1dim1 = getInput().get(0).getDim1(); |
| long in1dim2 = getInput().get(0).getDim2(); |
| |
| val = 2 * OptimizerUtils.estimateSize(in1dim1, in1dim2); // For temp1 & temp2 |
| if (_direction == Direction.Col){ |
| val += OptimizerUtils.estimateSize(1, in1dim2); // For temp3 (colVars -> 1 x ncol) |
| } else if (_direction == Direction.Row){ |
| val += OptimizerUtils.estimateSize(in1dim1, 1); // For temp3 (rowVars -> nrow x 1) |
| } |
| |
| } else if( _direction == Direction.Col ) { //(potentially sparse) |
| val = OptimizerUtils.estimateSizeExactSparsity(5, dim2, sparsity); |
| } else if( _direction == Direction.Row ) { //(always dense) |
| val = OptimizerUtils.estimateSizeExactSparsity(dim1, 5, 1.0); |
| } |
| break; |
| case MAXINDEX: |
| case MININDEX: |
| Hop hop = getInput().get(0); |
| if(isUnaryAggregateOuterCPRewriteApplicable()) |
| val = 3 * OptimizerUtils.estimateSizeExactSparsity(1, hop.getDim2(), 1.0); |
| else |
| //worst-case correction LASTCOLUMN |
| val = OptimizerUtils.estimateSizeExactSparsity(dim1, 2, 1.0); |
| break; |
| default: |
| //no intermediate memory consumption |
| val = 0; |
| } |
| |
| return val; |
| } |
| |
| @Override |
| protected DataCharacteristics inferOutputCharacteristics( MemoTable memo ) { |
| DataCharacteristics ret = null; |
| Hop input = getInput().get(0); |
| DataCharacteristics dc = memo.getAllInputStats(input); |
| if( _direction == Direction.Col && dc.colsKnown() ) |
| ret = new MatrixCharacteristics(1, dc.getCols(), -1, -1); |
| else if( _direction == Direction.Row && dc.rowsKnown() ) |
| ret = new MatrixCharacteristics(dc.getRows(), 1, -1, -1); |
| return ret; |
| } |
| |
| |
| @Override |
| protected ExecType optFindExecType() { |
| |
| checkAndSetForcedPlatform(); |
| |
| ExecType REMOTE = ExecType.SPARK; |
| |
| //forced / memory-based / threshold-based decision |
| if( _etypeForced != null ) |
| { |
| _etype = _etypeForced; |
| } |
| else |
| { |
| if ( OptimizerUtils.isMemoryBasedOptLevel() ) |
| { |
| _etype = findExecTypeByMemEstimate(); |
| } |
| // Choose CP, if the input dimensions are below threshold or if the input is a vector |
| else if ( getInput().get(0).areDimsBelowThreshold() || getInput().get(0).isVector() ) |
| { |
| _etype = ExecType.CP; |
| } |
| else |
| { |
| _etype = REMOTE; |
| } |
| |
| //check for valid CP dimensions and matrix size |
| checkAndSetInvalidCPDimsAndSize(); |
| } |
| |
| //spark-specific decision refinement (execute unary aggregate w/ spark input and |
| //single parent also in spark because it's likely cheap and reduces data transfer) |
| if( _etype == ExecType.CP && _etypeForced != ExecType.CP |
| && !(getInput().get(0) instanceof DataOp) //input is not checkpoint |
| && (getInput().get(0).getParent().size()==1 //uagg is only parent, or |
| || !requiresAggregation(getInput().get(0), _direction)) //w/o agg |
| && getInput().get(0).optFindExecType() == ExecType.SPARK ) |
| { |
| //pull unary aggregate into spark |
| _etype = ExecType.SPARK; |
| } |
| |
| //mark for recompile (forever) |
| setRequiresRecompileIfNecessary(); |
| |
| return _etype; |
| } |
| |
| private static boolean requiresAggregation( Hop input, Direction dir ) |
| { |
| if( !ALLOW_UNARYAGG_WO_FINAL_AGG ) |
| return true; //shortcut disabled, always require final aggregation |
| |
| boolean noAggRequired = |
| ( input.getDim1()>1 && input.getDim1()<=input.getBlocksize() && dir==Direction.Col ) //e.g., colSums(X) with nrow(X)<=1000 |
| ||( input.getDim2()>1 && input.getDim2()<=input.getBlocksize() && dir==Direction.Row ); //e.g., rowSums(X) with ncol(X)<=1000 |
| |
| return !noAggRequired; |
| } |
| |
| private SparkAggType getSparkUnaryAggregationType( boolean agg ) |
| { |
| if( !agg ) |
| return SparkAggType.NONE; |
| |
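| //single-block aggregation applies if the aggregate output fits in one block, |
| //e.g., a scalar sum(X) or colSums(X) with ncol(X) <= blocksize |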
| if( getDataType()==DataType.SCALAR //in case of scalars the block dims are not set |
| || dimsKnown() && getDim1()<=getBlocksize() && getDim2()<=getBlocksize() ) |
| return SparkAggType.SINGLE_BLOCK; |
| else |
| return SparkAggType.MULTI_BLOCK; |
| } |
| |
| private boolean isTernaryAggregateRewriteApplicable() |
| { |
| boolean ret = false; |
| |
| //note: the ternary aggregate rewrite is currently disabled on the GPU |
| //backend (TODO: add GPU support for ternary aggregates) |
| if(DMLScript.USE_ACCELERATOR) |
| return false; |
| |
| //currently we support only sum over binary multiply but potentially |
| //it can be generalized to any RC aggregate over two common binary operations |
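| //illustrative DML patterns covered by this rewrite: |
| //  sum(A*B*C), colSums(A*B*C) -> ternary aggregate over A, B, C |
| //  sum(X^3)                   -> ternary aggregate with X as all three inputs |
| //  sum(A*B)                   -> ternary aggregate with literal 1 as third input |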
| if( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES && _op == AggOp.SUM && |
| (_direction == Direction.RowCol || _direction == Direction.Col) ) |
| { |
| Hop input1 = getInput().get(0); |
| if (input1.getParent().size() == 1 |
| && input1 instanceof BinaryOp) { //sum single consumer |
| BinaryOp binput1 = (BinaryOp)input1; |
| |
| if (binput1.getOp() == OpOp2.POW |
| && binput1.getInput().get(1) instanceof LiteralOp) { |
| LiteralOp lit = (LiteralOp)binput1.getInput().get(1); |
| ret = HopRewriteUtils.getIntValueSafe(lit) == 3; |
| } |
| else if (binput1.getOp() == OpOp2.MULT ) { |
| Hop input11 = input1.getInput().get(0); |
| Hop input12 = input1.getInput().get(1); |
| |
| if (input11 instanceof BinaryOp && ((BinaryOp) input11).getOp() == OpOp2.MULT) { |
| //ternary, arbitrary matrices but no mv/outer operations. |
| ret = HopRewriteUtils.isEqualSize(input11.getInput().get(0), input1) && HopRewriteUtils |
| .isEqualSize(input11.getInput().get(1), input1) && HopRewriteUtils |
| .isEqualSize(input12, input1); |
| } else if (input12 instanceof BinaryOp && ((BinaryOp) input12).getOp() == OpOp2.MULT) { |
| //ternary, arbitrary matrices but no mv/outer operations. |
| ret = HopRewriteUtils.isEqualSize(input12.getInput().get(0), input1) && HopRewriteUtils |
| .isEqualSize(input12.getInput().get(1), input1) && HopRewriteUtils |
| .isEqualSize(input11, input1); |
| } else { |
| //binary, arbitrary matrices but no mv/outer operations. |
| ret = HopRewriteUtils.isEqualSize(input11, input12); |
| } |
| } |
| } |
| } |
| return ret; |
| } |
| |
| private static boolean isCompareOperator(OpOp2 opOp2) |
| { |
| return (opOp2 == OpOp2.LESS || opOp2 == OpOp2.LESSEQUAL |
| || opOp2 == OpOp2.GREATER || opOp2 == OpOp2.GREATEREQUAL |
| || opOp2 == OpOp2.EQUAL || opOp2 == OpOp2.NOTEQUAL); |
| } |
| |
| @Override |
| public boolean isMultiThreadedOpType() { |
| return true; |
| } |
| |
| /** |
| * Checks whether the unary aggregate outer rewrite is applicable for Spark: |
| * there must be sufficient memory locally (twice the size of the second matrix, |
| * for the original and the sorted data) and remotely (the size of the second, |
| * sorted matrix). |
| * @return true if sufficient memory is available |
| */ |
| private boolean isUnaryAggregateOuterSPRewriteApplicable() |
| { |
| boolean ret = false; |
| Hop input = getInput().get(0); |
| |
| if( input instanceof BinaryOp && ((BinaryOp)input).isOuter() ) |
| { |
| //note: both cases (partitioned matrix, and sorted double array), require to |
| //fit the broadcast twice into the local memory budget. Also, the memory |
| //constraint only needs to take the rhs into account because the output is |
| //guaranteed to be an aggregate of <=16KB |
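| //e.g., a dense 10^7 x 1 FP64 vector is ~80MB and thus has to fit twice |
| //(~160MB) into both the local and the broadcast memory budget |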
| |
| Hop right = input.getInput().get(1); |
| |
| double size = right.dimsKnown() ? |
| OptimizerUtils.estimateSize(right.getDim1(), right.getDim2()) : //dims known and estimate fits |
| right.getOutputMemEstimate(); //dims unknown but worst-case estimate fits |
| |
| if(_op == AggOp.MAXINDEX || _op == AggOp.MININDEX){ |
| double memBudgetExec = SparkExecutionContext.getBroadcastMemoryBudget(); |
| double memBudgetLocal = OptimizerUtils.getLocalMemBudget(); |
| |
| //basic requirement: the broadcast needs to fit twice in the remote broadcast |
| //memory and the local memory budget because we have to create a partitioned |
| //broadcast and hand it over to the spark context as an in-memory object |
| ret = ( 2*size < memBudgetExec && 2*size < memBudgetLocal ); |
| |
| } else { |
| if( OptimizerUtils.checkSparkBroadcastMemoryBudget(size) ) { |
| ret = true; |
| } |
| } |
| |
| } |
| |
| return ret; |
| } |
| |
| |
| |
| /** |
| * Checks whether this is one of the operations supported by the LibMatrixOuterAgg |
| * library: the input must be an outer binary operation, the aggregate must be of |
| * type SUM, MAXINDEX (rowIndexMax), or MININDEX (rowIndexMin), and the binary |
| * operator must be one of the six comparison operators <, <=, >, >=, ==, and !=. |
| * |
| * @return true if the unary aggregate outer CP rewrite is applicable |
| */ |
| private boolean isUnaryAggregateOuterCPRewriteApplicable() { |
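| //illustrative DML pattern: rowIndexMax(outer(v1, v2, "<")) or sum(outer(v1, v2, "<=")) |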
| boolean ret = false; |
| Hop input = getInput().get(0); |
| if(( input instanceof BinaryOp && ((BinaryOp)input).isOuter() ) |
| && (_op == AggOp.MAXINDEX || _op == AggOp.MININDEX || _op == AggOp.SUM) |
| && (isCompareOperator(((BinaryOp)input).getOp()))) |
| ret = true; |
| return ret; |
| } |
| |
| private Lop constructLopsTernaryAggregateRewrite(ExecType et) |
| { |
| BinaryOp input1 = (BinaryOp)getInput().get(0); |
| Hop input11 = input1.getInput().get(0); |
| Hop input12 = input1.getInput().get(1); |
| |
| Lop in1 = null, in2 = null, in3 = null; |
| boolean handled = false; |
| |
| if (input1.getOp() == OpOp2.POW) { |
| assert(HopRewriteUtils.isLiteralOfValue(input12, 3)) : "this case can only occur with a power of 3"; |
| in1 = input11.constructLops(); |
| in2 = in1; |
| in3 = in1; |
| handled = true; |
| } else if (input11 instanceof BinaryOp ) { |
| BinaryOp b11 = (BinaryOp)input11; |
| switch( b11.getOp() ) { |
| case MULT: // A*B*C case |
| in1 = input11.getInput().get(0).constructLops(); |
| in2 = input11.getInput().get(1).constructLops(); |
| in3 = input12.constructLops(); |
| handled = true; |
| break; |
| case POW: // A*A*B case |
| Hop b112 = b11.getInput().get(1); |
| if ( !(input12 instanceof BinaryOp && ((BinaryOp)input12).getOp()==OpOp2.MULT) |
| && HopRewriteUtils.isLiteralOfValue(b112, 2) ) { |
| in1 = b11.getInput().get(0).constructLops(); |
| in2 = in1; |
| in3 = input12.constructLops(); |
| handled = true; |
| } |
| break; |
| default: break; |
| } |
| } else if( input12 instanceof BinaryOp ) { |
| BinaryOp b12 = (BinaryOp)input12; |
| switch (b12.getOp()) { |
| case MULT: // A*B*C case |
| in1 = input11.constructLops(); |
| in2 = input12.getInput().get(0).constructLops(); |
| in3 = input12.getInput().get(1).constructLops(); |
| handled = true; |
| break; |
| case POW: // A*B*B case |
| Hop b122 = b12.getInput().get(1); |
| if ( HopRewriteUtils.isLiteralOfValue(b122, 2) ) { |
| in1 = b12.getInput().get(0).constructLops(); |
| in2 = in1; |
| in3 = input11.constructLops(); |
| handled = true; |
| } |
| break; |
| default: break; |
| } |
| } |
| |
| if (!handled) { |
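| //fallback, e.g., sum(A*B): compile a ternary aggregate with literal 1 as third input |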
| in1 = input11.constructLops(); |
| in2 = input12.constructLops(); |
| in3 = new LiteralOp(1).constructLops(); |
| } |
| |
| //create new ternary aggregate operator |
| int k = OptimizerUtils.getConstrainedNumThreads( _maxNumThreads ); |
| // The execution type of a unary aggregate instruction should depend on the execution type of inputs to avoid OOM |
| // Since we only support matrix-vector and not vector-matrix, checking the execution type of input1 should suffice. |
| ExecType et_input = input1.optFindExecType(); |
| // Because ternary aggregates are not supported on the GPU |
| et_input = et_input == ExecType.GPU ? ExecType.CP : et_input; |
| |
| return new TernaryAggregate(in1, in2, in3, AggOp.SUM, |
| OpOp2.MULT, _direction, getDataType(), ValueType.FP64, et_input, k); |
| } |
| |
| @Override |
| public void refreshSizeInformation() |
| { |
| if (getDataType() != DataType.SCALAR) |
| { |
| Hop input = getInput().get(0); |
| if ( _direction == Direction.Col ) //colwise computations |
| { |
| setDim1(1); |
| setDim2(input.getDim2()); |
| } |
| else if ( _direction == Direction.Row ) |
| { |
| setDim1(input.getDim1()); |
| setDim2(1); |
| } |
| } |
| } |
| |
| @Override |
| public boolean isTransposeSafe() |
| { |
| boolean ret = (_direction == Direction.RowCol) && //full aggregate |
| (_op == AggOp.SUM || _op == AggOp.SUM_SQ || //valid aggregation functions |
| _op == AggOp.MIN || _op == AggOp.MAX || |
| _op == AggOp.PROD || _op == AggOp.MEAN || |
| _op == AggOp.VAR); |
| //note: trace and maxindex are not transpose-safe. |
| |
| return ret; |
| } |
| |
| @Override |
| public Object clone() throws CloneNotSupportedException |
| { |
| AggUnaryOp ret = new AggUnaryOp(); |
| |
| //copy generic attributes |
| ret.clone(this, false); |
| |
| //copy specific attributes |
| ret._op = _op; |
| ret._direction = _direction; |
| ret._maxNumThreads = _maxNumThreads; |
| |
| return ret; |
| } |
| |
| @Override |
| public boolean compare( Hop that ) |
| { |
| if( !(that instanceof AggUnaryOp) ) |
| return false; |
| |
| AggUnaryOp that2 = (AggUnaryOp)that; |
| return ( _op == that2._op |
| && _direction == that2._direction |
| && _maxNumThreads == that2._maxNumThreads |
| && getInput().get(0) == that2.getInput().get(0)); |
| } |
| } |