blob: bfec9ffe3b44f828e5c5df2192792dbe5cdc1711 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.lops;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.CorrectionLocationType;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.Direction;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.hops.AggBinaryOp.SparkAggType;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.lops.LopProperties.ExecType;
/**
* Lop to perform a partial aggregation. It was introduced to do some initial
* aggregation operations on blocks in the mapper/reducer.
*/
public class PartialAggregate extends Lop
{
	private AggOp operation;
	private Direction direction;

	//optional attribute for CP num threads
	private int _numThreads = -1;

	//optional attribute for spark exec type
	private SparkAggType _aggtype = SparkAggType.MULTI_BLOCK;

	/**
	 * Constructor to setup a partial aggregate operation with an explicit
	 * number of threads (emitted as the trailing attribute of CP instructions).
	 *
	 * @param input low-level operator
	 * @param op aggregate operation type
	 * @param direct partial aggregate direction type
	 * @param dt data type
	 * @param vt value type
	 * @param et execution type
	 * @param k number of threads
	 */
	public PartialAggregate( Lop input, AggOp op, Direction direct, DataType dt, ValueType vt, ExecType et, int k)
	{
		super(Lop.Type.PartialAggregate, dt, vt);
		init(input, op, direct, dt, vt, et);
		_numThreads = k;
	}

	/**
	 * Constructor to setup a partial aggregate operation with an explicit
	 * spark aggregation type (emitted as the trailing attribute of Spark
	 * instructions).
	 *
	 * @param input low-level operator
	 * @param op aggregate operation type
	 * @param direct partial aggregate direction type
	 * @param dt data type
	 * @param vt value type
	 * @param aggtype spark aggregation type
	 * @param et execution type
	 */
	public PartialAggregate( Lop input, AggOp op, Direction direct, DataType dt, ValueType vt, SparkAggType aggtype, ExecType et)
	{
		super(Lop.Type.PartialAggregate, dt, vt);
		init(input, op, direct, dt, vt, et);
		_aggtype = aggtype;
	}

	/**
	 * Shared constructor logic: stores the operation and direction, links
	 * this lop with its input (in both directions), and sets the lop
	 * properties for the given execution type.
	 *
	 * @param input low-level operator
	 * @param op aggregate operation type
	 * @param direct partial aggregate direction type
	 * @param dt data type (not used here; already passed to the super constructor)
	 * @param vt value type (not used here; already passed to the super constructor)
	 * @param et execution type
	 */
	private void init(Lop input, AggOp op, Direction direct, DataType dt, ValueType vt, ExecType et) {
		operation = op;
		direction = direct;
		addInput(input);
		input.addOutput(this);
		lps.setProperties(inputs, et);
	}

	/**
	 * This method computes the location of "correction" terms in the output
	 * produced by PartialAgg instruction.
	 *
	 * When computing the stable sum, "correction" refers to the compensation as
	 * defined by the original Kahan algorithm. When computing the stable mean,
	 * "correction" refers to two extra values (the running mean, count)
	 * produced by each Mapper i.e., by each PartialAgg instruction.
	 *
	 * This method is invoked during hop-to-lop translation, while creating the
	 * corresponding Aggregate lop
	 *
	 * Computed information is encoded in the PartialAgg instruction so that the
	 * appropriate aggregate operator is used at runtime (see:
	 * dml.runtime.matrix.operator.AggregateOperator.java and dml.runtime.matrix)
	 *
	 * @return correct location
	 */
	public CorrectionLocationType getCorrectionLocation() {
		return getCorrectionLocation(operation, direction);
	}

	/**
	 * Computes the location of the correction terms for the given aggregate
	 * operation and direction.
	 *
	 * @param operation aggregate operation type
	 * @param direction aggregate direction type
	 * @return correction location; {@code NONE} for operations that are not
	 *         explicitly handled below
	 * @throws LopsException if a sum, mean, or variance operation is combined
	 *         with a direction other than Row, Col, or RowCol
	 */
	public static CorrectionLocationType getCorrectionLocation(AggOp operation, Direction direction) {
		switch (operation) {
			case SUM:
			case SUM_SQ:
			case TRACE:
				// colSums: corrections will be present as a last row in the
				// result; rowSums, sum: corrections will be present as a last
				// column in the result
				return selectCorrectionLocation(direction,
					CorrectionLocationType.LASTROW,
					CorrectionLocationType.LASTCOLUMN);

			case MEAN:
				// Computation of stable mean requires each mapper to output both
				// the running mean as well as the count
				// colMeans: last row is correction 2nd last is count
				// rowMeans, mean: last column is correction 2nd last is count
				return selectCorrectionLocation(direction,
					CorrectionLocationType.LASTTWOROWS,
					CorrectionLocationType.LASTTWOCOLUMNS);

			case VAR:
				// Computation of stable variance requires each mapper to
				// output the running variance, the running mean, the
				// count, a correction term for the squared deviations
				// from the sample mean (m2), and a correction term for
				// the mean. These values collectively allow all other
				// necessary intermediates to be reconstructed, and the
				// variance will be output by our unary aggregate framework.
				// Thus, our outputs will be:
				// { var | mean, count, m2 correction, mean correction }
				// colVars: each element is a column;
				// var, rowVars: each element is a row.
				return selectCorrectionLocation(direction,
					CorrectionLocationType.LASTFOURROWS,
					CorrectionLocationType.LASTFOURCOLUMNS);

			case MAXINDEX:
			case MININDEX:
				// independent of the direction
				return CorrectionLocationType.LASTCOLUMN;

			default:
				return CorrectionLocationType.NONE;
		}
	}

	/**
	 * Picks the row-wise or column-wise correction location according to the
	 * aggregate direction.
	 *
	 * @param direction aggregate direction type
	 * @param colLoc location returned for column aggregates (Col)
	 * @param rowLoc location returned for row and full aggregates (Row, RowCol)
	 * @return the location matching the direction
	 * @throws LopsException if the direction is none of Row, Col, RowCol
	 */
	private static CorrectionLocationType selectCorrectionLocation(Direction direction,
		CorrectionLocationType colLoc, CorrectionLocationType rowLoc)
	{
		switch (direction) {
			case Col:
				return colLoc;
			case Row:
			case RowCol:
				return rowLoc;
			default:
				throw new LopsException("PartialAggregate.getCorrectionLocation() - "
					+ "Unknown aggregate direction: " + direction);
		}
	}

	/**
	 * Sets the output dimensions of this lop according to its aggregate
	 * direction.
	 *
	 * @param dim1 number of rows of the input
	 * @param dim2 number of columns of the input
	 * @param blen block size
	 */
	public void setDimensionsBasedOnDirection(long dim1, long dim2, long blen) {
		setDimensionsBasedOnDirection(this, dim1, dim2, blen, direction);
	}

	/**
	 * Sets the output dimensions of the given lop according to the aggregate
	 * direction: (dim1 x 1) for Row, (1 x dim2) for Col, and (1 x 1) for RowCol.
	 *
	 * @param lop low-level operator whose output parameters are updated
	 * @param dim1 number of rows of the input
	 * @param dim2 number of columns of the input
	 * @param blen block size
	 * @param dir aggregate direction type
	 * @throws LopsException if the direction is unknown or if setting the
	 *         dimensions fails with a HopsException
	 */
	public static void setDimensionsBasedOnDirection(Lop lop, long dim1, long dim2, long blen, Direction dir)
	{
		// NOTE(review): if LopsException is a subtype of HopsException, the
		// unknown-direction exception thrown inside the try is re-wrapped by
		// the catch below - confirm this is intended.
		try {
			if (dir == Direction.Row)
				lop.outParams.setDimensions(dim1, 1, blen, -1);
			else if (dir == Direction.Col)
				lop.outParams.setDimensions(1, dim2, blen, -1);
			else if (dir == Direction.RowCol)
				lop.outParams.setDimensions(1, 1, blen, -1);
			else
				throw new LopsException("In PartialAggregate Lop, Unknown aggregate direction " + dir);
		} catch (HopsException e) {
			throw new LopsException("In PartialAggregate Lop, error setting dimensions based on direction", e);
		}
	}

	@Override
	public String toString() {
		return "Partial Aggregate " + operation;
	}

	private String getOpcode() {
		return getOpcode(operation, direction);
	}

	/**
	 * Instruction generation for CP and Spark.
	 *
	 * The instruction consists of the execution type, the opcode, the input
	 * operand, the output operand, and one exec-type specific attribute
	 * (aggregation type for Spark, number of threads for CP), separated by
	 * the operand delimiter. For any other execution type the trailing
	 * attribute is left empty.
	 */
	@Override
	public String getInstructions(String input1, String output)
	{
		StringBuilder sb = new StringBuilder();
		sb.append( getExecType() );

		sb.append( OPERAND_DELIMITOR );
		sb.append( getOpcode() );

		sb.append( OPERAND_DELIMITOR );
		sb.append( getInputs().get(0).prepInputOperand(input1) );

		sb.append( OPERAND_DELIMITOR );
		sb.append( prepOutputOperand(output) );

		//exec-type specific attributes
		sb.append( OPERAND_DELIMITOR );
		if( getExecType() == ExecType.SPARK )
			sb.append( _aggtype );
		else if( getExecType() == ExecType.CP )
			sb.append( _numThreads );

		return sb.toString();
	}

	/**
	 * Returns the instruction opcode for the given aggregate operation and
	 * direction.
	 *
	 * @param op aggregate operation type
	 * @param dir aggregate direction type
	 * @return opcode string, e.g., "uak+" for a full sum
	 * @throws UnsupportedOperationException if no instruction is defined for
	 *         the given combination of operation and direction
	 */
	public static String getOpcode(AggOp op, Direction dir)
	{
		// Every case ends in an explicit break, so an unsupported direction
		// can never fall through into the opcodes of another operation.
		String opcode = null;
		switch( op )
		{
			case SUM:
				// instructions that use kahanSum are similar to ua+,uar+,uac+
				// except that they also produce correction values along with partial
				// sums.
				opcode = selectOpcode(dir, "uak+", "uark+", "uack+");
				break;
			case SUM_SQ:
				opcode = selectOpcode(dir, "uasqk+", "uarsqk+", "uacsqk+");
				break;
			case MEAN:
				opcode = selectOpcode(dir, "uamean", "uarmean", "uacmean");
				break;
			case VAR:
				opcode = selectOpcode(dir, "uavar", "uarvar", "uacvar");
				break;
			case PROD:
				opcode = selectOpcode(dir, "ua*", "uar*", "uac*");
				break;
			case SUM_PROD:
				opcode = selectOpcode(dir, "ua+*", "uar+*", "uac+*");
				break;
			case MAX:
				opcode = selectOpcode(dir, "uamax", "uarmax", "uacmax");
				break;
			case MIN:
				opcode = selectOpcode(dir, "uamin", "uarmin", "uacmin");
				break;
			case MAXINDEX:
				// only defined row-wise
				opcode = selectOpcode(dir, null, "uarimax", null);
				break;
			case MININDEX:
				// only defined row-wise
				opcode = selectOpcode(dir, null, "uarimin", null);
				break;
			case TRACE:
				// only defined as full aggregate
				opcode = selectOpcode(dir, "uaktrace", null, null);
				break;
			case COUNT_DISTINCT:
				// only defined as full aggregate
				opcode = selectOpcode(dir, "uacd", null, null);
				break;
			case COUNT_DISTINCT_APPROX:
				// only defined as full aggregate
				opcode = selectOpcode(dir, "uacdap", null, null);
				break;
			default:
				// no instruction defined for this operation
				break;
		}

		//should never come here for normal compilation
		if( opcode == null )
			throw new UnsupportedOperationException("Instruction is not defined for PartialAggregate operation " + op);
		return opcode;
	}

	/**
	 * Picks the opcode variant that matches the aggregate direction.
	 *
	 * @param dir aggregate direction type
	 * @param all opcode of the full aggregate (RowCol), or null if undefined
	 * @param row opcode of the row aggregate (Row), or null if undefined
	 * @param col opcode of the column aggregate (Col), or null if undefined
	 * @return matching opcode, or null if the direction is null, unknown, or
	 *         has no opcode defined
	 */
	private static String selectOpcode(Direction dir, String all, String row, String col) {
		if( dir == Direction.RowCol )
			return all;
		if( dir == Direction.Row )
			return row;
		if( dir == Direction.Col )
			return col;
		return null;
	}
}