| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysds.lops; |
| |
| import org.apache.sysds.common.Types.AggOp; |
| import org.apache.sysds.common.Types.CorrectionLocationType; |
| import org.apache.sysds.common.Types.DataType; |
| import org.apache.sysds.common.Types.Direction; |
| import org.apache.sysds.common.Types.ValueType; |
| import org.apache.sysds.hops.AggBinaryOp.SparkAggType; |
| import org.apache.sysds.hops.HopsException; |
| import org.apache.sysds.lops.LopProperties.ExecType; |
| |
| |
| /** |
| * Lop to perform a partial aggregation. It was introduced to do some initial |
| * aggregation operations on blocks in the mapper/reducer. |
| */ |
| |
| public class PartialAggregate extends Lop |
| { |
| private AggOp operation; |
| private Direction direction; |
| |
| //optional attribute for CP num threads |
| private int _numThreads = -1; |
| |
| //optional attribute for spark exec type |
| private SparkAggType _aggtype = SparkAggType.MULTI_BLOCK; |
| |
| public PartialAggregate( Lop input, AggOp op, Direction direct, DataType dt, ValueType vt, ExecType et, int k) |
| { |
| super(Lop.Type.PartialAggregate, dt, vt); |
| init(input, op, direct, dt, vt, et); |
| _numThreads = k; |
| } |
| |
| public PartialAggregate( Lop input, AggOp op, Direction direct, DataType dt, ValueType vt, SparkAggType aggtype, ExecType et) |
| { |
| super(Lop.Type.PartialAggregate, dt, vt); |
| init(input, op, direct, dt, vt, et); |
| _aggtype = aggtype; |
| } |
| |
| /** |
| * Constructor to setup a partial aggregate operation. |
| * |
| * @param input low-level operator |
| * @param op aggregate operation type |
| * @param direct partial aggregate directon type |
| * @param dt data type |
| * @param vt value type |
| * @param et execution type |
| */ |
| private void init(Lop input, AggOp op, Direction direct, DataType dt, ValueType vt, ExecType et) { |
| operation = op; |
| direction = direct; |
| addInput(input); |
| input.addOutput(this); |
| lps.setProperties(inputs, et); |
| } |
| |
| /** |
| * This method computes the location of "correction" terms in the output |
| * produced by PartialAgg instruction. |
| * |
| * When computing the stable sum, "correction" refers to the compensation as |
| * defined by the original Kahan algorithm. When computing the stable mean, |
| * "correction" refers to two extra values (the running mean, count) |
| * produced by each Mapper i.e., by each PartialAgg instruction. |
| * |
| * This method is invoked during hop-to-lop translation, while creating the |
| * corresponding Aggregate lop |
| * |
| * Computed information is encoded in the PartialAgg instruction so that the |
| * appropriate aggregate operator is used at runtime (see: |
| * dml.runtime.matrix.operator.AggregateOperator.java and dml.runtime.matrix) |
| * |
| * @return correct location |
| */ |
| public CorrectionLocationType getCorrectionLocation() { |
| return getCorrectionLocation(operation, direction); |
| } |
| |
| public static CorrectionLocationType getCorrectionLocation(AggOp operation, Direction direction) { |
| CorrectionLocationType loc; |
| |
| switch (operation) { |
| case SUM: |
| case SUM_SQ: |
| case TRACE: |
| switch (direction) { |
| case Col: |
| // colSums: corrections will be present as a last row in the |
| // result |
| loc = CorrectionLocationType.LASTROW; |
| break; |
| case Row: |
| case RowCol: |
| // rowSums, sum: corrections will be present as a last column in |
| // the result |
| loc = CorrectionLocationType.LASTCOLUMN; |
| break; |
| default: |
| throw new LopsException("PartialAggregate.getCorrectionLocation() - " |
| + "Unknown aggregate direction: " + direction); |
| } |
| break; |
| |
| case MEAN: |
| // Computation of stable mean requires each mapper to output both |
| // the running mean as well as the count |
| switch (direction) { |
| case Col: |
| // colMeans: last row is correction 2nd last is count |
| loc = CorrectionLocationType.LASTTWOROWS; |
| break; |
| case Row: |
| case RowCol: |
| // rowMeans, mean: last column is correction 2nd last is count |
| loc = CorrectionLocationType.LASTTWOCOLUMNS; |
| break; |
| default: |
| throw new LopsException("PartialAggregate.getCorrectionLocation() - " |
| + "Unknown aggregate direction: " + direction); |
| } |
| break; |
| |
| case VAR: |
| // Computation of stable variance requires each mapper to |
| // output the running variance, the running mean, the |
| // count, a correction term for the squared deviations |
| // from the sample mean (m2), and a correction term for |
| // the mean. These values collectively allow all other |
| // necessary intermediates to be reconstructed, and the |
| // variance will output by our unary aggregate framework. |
| // Thus, our outputs will be: |
| // { var | mean, count, m2 correction, mean correction } |
| switch (direction) { |
| case Col: |
| // colVars: { var | mean, count, m2 correction, mean correction }, |
| // where each element is a column. |
| loc = CorrectionLocationType.LASTFOURROWS; |
| break; |
| case Row: |
| case RowCol: |
| // var, rowVars: { var | mean, count, m2 correction, mean correction }, |
| // where each element is a row. |
| loc = CorrectionLocationType.LASTFOURCOLUMNS; |
| break; |
| default: |
| throw new LopsException("PartialAggregate.getCorrectionLocation() - " |
| + "Unknown aggregate direction: " + direction); |
| } |
| break; |
| |
| case MAXINDEX: |
| case MININDEX: |
| loc = CorrectionLocationType.LASTCOLUMN; |
| break; |
| |
| default: |
| loc = CorrectionLocationType.NONE; |
| } |
| return loc; |
| } |
| |
| public void setDimensionsBasedOnDirection(long dim1, long dim2, long blen) { |
| setDimensionsBasedOnDirection(this, dim1, dim2, blen, direction); |
| } |
| |
| public static void setDimensionsBasedOnDirection(Lop lop, long dim1, long dim2, long blen, Direction dir) |
| { |
| try { |
| if (dir == Direction.Row) |
| lop.outParams.setDimensions(dim1, 1, blen, -1); |
| else if (dir == Direction.Col) |
| lop.outParams.setDimensions(1, dim2, blen, -1); |
| else if (dir == Direction.RowCol) |
| lop.outParams.setDimensions(1, 1, blen, -1); |
| else |
| throw new LopsException("In PartialAggregate Lop, Unknown aggregate direction " + dir); |
| } catch (HopsException e) { |
| throw new LopsException("In PartialAggregate Lop, error setting dimensions based on direction", e); |
| } |
| } |
| |
| @Override |
| public String toString() { |
| return "Partial Aggregate " + operation; |
| } |
| |
| private String getOpcode() { |
| return getOpcode(operation, direction); |
| } |
| |
| /** |
| * Instruction generation for for CP and Spark |
| */ |
| @Override |
| public String getInstructions(String input1, String output) |
| { |
| StringBuilder sb = new StringBuilder(); |
| sb.append( getExecType() ); |
| |
| sb.append( OPERAND_DELIMITOR ); |
| sb.append( getOpcode() ); |
| |
| sb.append( OPERAND_DELIMITOR ); |
| sb.append( getInputs().get(0).prepInputOperand(input1) ); |
| |
| sb.append( OPERAND_DELIMITOR ); |
| sb.append( prepOutputOperand(output) ); |
| |
| //exec-type specific attributes |
| sb.append( OPERAND_DELIMITOR ); |
| if( getExecType() == ExecType.SPARK ) |
| sb.append( _aggtype ); |
| else if( getExecType() == ExecType.CP ) |
| sb.append( _numThreads ); |
| |
| return sb.toString(); |
| } |
| |
| public static String getOpcode(AggOp op, Direction dir) |
| { |
| switch( op ) |
| { |
| case SUM: { |
| // instructions that use kahanSum are similar to ua+,uar+,uac+ |
| // except that they also produce correction values along with partial |
| // sums. |
| if( dir == Direction.RowCol ) |
| return "uak+"; |
| else if( dir == Direction.Row ) |
| return "uark+"; |
| else if( dir == Direction.Col ) |
| return "uack+"; |
| break; |
| } |
| |
| case SUM_SQ: { |
| if( dir == Direction.RowCol ) |
| return "uasqk+"; |
| else if( dir == Direction.Row ) |
| return "uarsqk+"; |
| else if( dir == Direction.Col ) |
| return "uacsqk+"; |
| break; |
| } |
| |
| case MEAN: { |
| if( dir == Direction.RowCol ) |
| return "uamean"; |
| else if( dir == Direction.Row ) |
| return "uarmean"; |
| else if( dir == Direction.Col ) |
| return "uacmean"; |
| break; |
| } |
| |
| case VAR: { |
| if( dir == Direction.RowCol ) |
| return "uavar"; |
| else if( dir == Direction.Row ) |
| return "uarvar"; |
| else if( dir == Direction.Col ) |
| return "uacvar"; |
| break; |
| } |
| |
| case PROD: { |
| switch( dir ) { |
| case RowCol: return "ua*"; |
| case Row: return "uar*"; |
| case Col: return "uac*"; |
| } |
| } |
| |
| case SUM_PROD: { |
| switch( dir ) { |
| case RowCol: return "ua+*"; |
| case Row: return "uar+*"; |
| case Col: return "uac+*"; |
| } |
| } |
| |
| case MAX: { |
| if( dir == Direction.RowCol ) |
| return "uamax"; |
| else if( dir == Direction.Row ) |
| return "uarmax"; |
| else if( dir == Direction.Col ) |
| return "uacmax"; |
| break; |
| } |
| |
| case MIN: { |
| if( dir == Direction.RowCol ) |
| return "uamin"; |
| else if( dir == Direction.Row ) |
| return "uarmin"; |
| else if( dir == Direction.Col ) |
| return "uacmin"; |
| break; |
| } |
| |
| case MAXINDEX:{ |
| if( dir == Direction.Row ) |
| return "uarimax"; |
| break; |
| } |
| |
| case MININDEX: { |
| if( dir == Direction.Row ) |
| return "uarimin"; |
| break; |
| } |
| |
| case TRACE: { |
| if( dir == Direction.RowCol ) |
| return "uaktrace"; |
| break; |
| } |
| |
| case COUNT_DISTINCT: { |
| if(dir == Direction.RowCol ) |
| return "uacd"; |
| break; |
| } |
| |
| case COUNT_DISTINCT_APPROX: { |
| if(dir == Direction.RowCol ) |
| return "uacdap"; |
| break; |
| } |
| } |
| |
| //should never come here for normal compilation |
| throw new UnsupportedOperationException("Instruction is not defined for PartialAggregate operation " + op); |
| } |
| |
| } |