blob: d94fcce418b21e0c9d557b1ea3b6d189d7aca759 [file] [log] [blame]
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.sysds.runtime.instructions.spark.utils;
import org.apache.sysds.common.Types.CorrectionLocationType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.KahanObject;
import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction.RDDUAggFunction2;
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
/**
 * Collection of utility methods for aggregating binary block rdds. As a general
 * policy, always call stable algorithms which maintain corrections over blocks
 * per key. The performance overhead over a simple reduceByKey is roughly 7-10%,
 * which is acceptable.
 */
public class RDDAggregateUtils
/**
 * Internal configuration to use tree aggregation (treeReduce w/ depth=2) in the
 * single-block sum of a {@code JavaRDD<MatrixBlock>}. This is currently disabled
 * because it was 2x slower than a simple single-block reduce due to the additional
 * overhead for shuffling.
 */
private static final boolean TREE_AGGREGATION = false;
public static MatrixBlock sumStable( JavaPairRDD<MatrixIndexes, MatrixBlock> in ) {
return sumStable( in.values() );
public static MatrixBlock sumStable( JavaRDD<MatrixBlock> in )
//stable sum of all blocks with correction block per function instance
return in.treeReduce(
new SumSingleBlockFunction(true) );
else { //DEFAULT
//reduce-all aggregate via fold instead of reduce to allow
//for update in-place w/o deep copy of left-hand-side blocks
return in.fold(
new MatrixBlock(),
new SumSingleBlockFunction(false));
public static JavaPairRDD<MatrixIndexes, MatrixBlock> sumByKeyStable(JavaPairRDD<MatrixIndexes, MatrixBlock> in) {
return sumByKeyStable(in, in.getNumPartitions(), true);
public static JavaPairRDD<MatrixIndexes, MatrixBlock> sumByKeyStable(JavaPairRDD<MatrixIndexes, MatrixBlock> in,
boolean deepCopyCombiner) {
return sumByKeyStable(in, in.getNumPartitions(), deepCopyCombiner);
public static JavaPairRDD<MatrixIndexes, MatrixBlock> sumByKeyStable(JavaPairRDD<MatrixIndexes, MatrixBlock> in,
int numPartitions, boolean deepCopyCombiner)
//stable sum of blocks per key, by passing correction blocks along with aggregates
JavaPairRDD<MatrixIndexes, CorrMatrixBlock> tmp =
in.combineByKey( new CreateCorrBlockCombinerFunction(deepCopyCombiner),
new MergeSumBlockValueFunction(deepCopyCombiner),
new MergeSumBlockCombinerFunction(deepCopyCombiner), numPartitions );
//strip-off correction blocks from
JavaPairRDD<MatrixIndexes, MatrixBlock> out =
tmp.mapValues( new ExtractMatrixBlock() );
//return the aggregate rdd
return out;
public static JavaPairRDD<MatrixIndexes, Double> sumCellsByKeyStable( JavaPairRDD<MatrixIndexes, Double> in ) {
return sumCellsByKeyStable(in, in.getNumPartitions());
public static JavaPairRDD<MatrixIndexes, Double> sumCellsByKeyStable( JavaPairRDD<MatrixIndexes, Double> in, int numParts )
//stable sum of blocks per key, by passing correction blocks along with aggregates
JavaPairRDD<MatrixIndexes, KahanObject> tmp =
in.combineByKey( new CreateCellCombinerFunction(),
new MergeSumCellValueFunction(),
new MergeSumCellCombinerFunction(), numParts);
//strip-off correction blocks from
JavaPairRDD<MatrixIndexes, Double> out =
tmp.mapValues( new ExtractDoubleCell() );
//return the aggregate rdd
return out;
* Single block aggregation over pair rdds with corrections for numerical stability.
* @param in matrix as {@code JavaPairRDD<MatrixIndexes, MatrixBlock>}
* @param aop aggregate operator
* @return matrix block
public static MatrixBlock aggStable( JavaPairRDD<MatrixIndexes, MatrixBlock> in, AggregateOperator aop ) {
return aggStable( in.values(), aop );
* Single block aggregation over rdds with corrections for numerical stability.
* @param in matrix as {@code JavaRDD<MatrixBlock>}
* @param aop aggregate operator
* @return matrix block
public static MatrixBlock aggStable( JavaRDD<MatrixBlock> in, AggregateOperator aop )
//stable aggregate of all blocks with correction block per function instance
//reduce-all aggregate via fold instead of reduce to allow
//for update in-place w/o deep copy of left-hand-side blocks
return in.fold(
new MatrixBlock(),
new AggregateSingleBlockFunction(aop) );
* Single block aggregation over pair rdds with corrections for numerical stability.
* @param in tensor as {@code JavaPairRDD<TensorIndexes, TensorBlock>}
* @param aop aggregate operator
* @return tensor block
public static TensorBlock aggStableTensor(JavaPairRDD<TensorIndexes, TensorBlock> in, AggregateOperator aop) {
return aggStableTensor(in.values(), aop);
* Single block aggregation over rdds with corrections for numerical stability.
* @param in tensor as {@code JavaRDD<TensorBlock>}
* @param aop aggregate operator
* @return tensor block
public static TensorBlock aggStableTensor(JavaRDD<TensorBlock> in, AggregateOperator aop )
//stable aggregate of all blocks with correction block per function instance
//reduce-all aggregate via fold instead of reduce to allow
//for update in-place w/o deep copy of left-hand-side blocks
return in.fold(
new TensorBlock(),
new AggregateSingleTensorBlockFunction(aop) );
public static JavaPairRDD<MatrixIndexes, MatrixBlock> aggByKeyStable( JavaPairRDD<MatrixIndexes, MatrixBlock> in,
AggregateOperator aop) {
return aggByKeyStable(in, aop, in.getNumPartitions(), true);
public static JavaPairRDD<MatrixIndexes, MatrixBlock> aggByKeyStable( JavaPairRDD<MatrixIndexes, MatrixBlock> in,
AggregateOperator aop, boolean deepCopyCombiner ) {
return aggByKeyStable(in, aop, in.getNumPartitions(), deepCopyCombiner);
public static JavaPairRDD<MatrixIndexes, MatrixBlock> aggByKeyStable( JavaPairRDD<MatrixIndexes, MatrixBlock> in,
AggregateOperator aop, int numPartitions, boolean deepCopyCombiner )
//stable sum of blocks per key, by passing correction blocks along with aggregates
JavaPairRDD<MatrixIndexes, CorrMatrixBlock> tmp =
in.combineByKey( new CreateCorrBlockCombinerFunction(deepCopyCombiner),
new MergeAggBlockValueFunction(aop),
new MergeAggBlockCombinerFunction(aop), numPartitions );
//strip-off correction blocks from
JavaPairRDD<MatrixIndexes, MatrixBlock> out =
tmp.mapValues( new ExtractMatrixBlock() );
//return the aggregate rdd
return out;
public static double max(JavaPairRDD<MatrixIndexes, MatrixBlock> in) {
AggregateUnaryOperator auop = InstructionUtils.parseBasicAggregateUnaryOperator("uamax");
MatrixBlock tmp = aggStable( RDDUAggFunction2(auop, -1)), auop.aggOp);
return tmp.quickGetValue(0, 0);
* Merges disjoint data of all blocks per key.
* Note: The behavior of this method is undefined for both sparse and dense data if the
* assumption of disjoint data is violated.
* @param in matrix as {@code JavaPairRDD<MatrixIndexes, MatrixBlock>}
* @return matrix as {@code JavaPairRDD<MatrixIndexes, MatrixBlock>}
public static JavaPairRDD<MatrixIndexes, MatrixBlock> mergeByKey( JavaPairRDD<MatrixIndexes, MatrixBlock> in ) {
return mergeByKey(in, in.getNumPartitions(), true);
* Merges disjoint data of all blocks per key.
* Note: The behavior of this method is undefined for both sparse and dense data if the
* assumption of disjoint data is violated.
* @param in matrix as {@code JavaPairRDD<MatrixIndexes, MatrixBlock>}
* @param deepCopyCombiner indicator if the createCombiner functions needs to deep copy the input block
* @return matrix as {@code JavaPairRDD<MatrixIndexes, MatrixBlock>}
public static JavaPairRDD<MatrixIndexes, MatrixBlock> mergeByKey( JavaPairRDD<MatrixIndexes, MatrixBlock> in,
boolean deepCopyCombiner ) {
return mergeByKey(in, in.getNumPartitions(), deepCopyCombiner);
* Merges disjoint data of all blocks per key.
* Note: The behavior of this method is undefined for both sparse and dense data if the
* assumption of disjoint data is violated.
* @param in matrix as {@code JavaPairRDD<MatrixIndexes, MatrixBlock>}
* @param numPartitions number of output partitions
* @param deepCopyCombiner indicator if the createCombiner functions needs to deep copy the input block
* @return matrix as {@code JavaPairRDD<MatrixIndexes, MatrixBlock>}
public static JavaPairRDD<MatrixIndexes, MatrixBlock> mergeByKey( JavaPairRDD<MatrixIndexes, MatrixBlock> in,
int numPartitions, boolean deepCopyCombiner )
//use combine by key to avoid unnecessary deep block copies, i.e.
//create combiner block once and merge remaining blocks in-place.
return in.combineByKey(
new CreateBlockCombinerFunction(deepCopyCombiner),
new MergeBlocksFunction(false),
new MergeBlocksFunction(false), numPartitions );
* Merges disjoint data of all blocks per key.
* Note: The behavior of this method is undefined for both sparse and dense data if the
* assumption of disjoint data is violated.
* @param in matrix as {@code JavaPairRDD<MatrixIndexes, RowMatrixBlock>}
* @return matrix as {@code JavaPairRDD<MatrixIndexes, MatrixBlock>}
public static JavaPairRDD<MatrixIndexes, MatrixBlock> mergeRowsByKey( JavaPairRDD<MatrixIndexes, RowMatrixBlock> in )
return in.combineByKey( new CreateRowBlockCombinerFunction(),
new MergeRowBlockValueFunction(),
new MergeBlocksFunction(false) );
private static class CreateCorrBlockCombinerFunction implements Function<MatrixBlock, CorrMatrixBlock>
private static final long serialVersionUID = -3666451526776017343L;
private final boolean _deep;
public CreateCorrBlockCombinerFunction(boolean deep) {
_deep = deep;
public CorrMatrixBlock call(MatrixBlock arg0)
throws Exception
//deep copy to allow update in-place
return new CorrMatrixBlock(
_deep ? new MatrixBlock(arg0) : arg0);
private static class MergeSumBlockValueFunction implements Function2<CorrMatrixBlock, MatrixBlock, CorrMatrixBlock>
private static final long serialVersionUID = 3703543699467085539L;
private AggregateOperator _op = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), CorrectionLocationType.NONE);
private final boolean _deep;
public MergeSumBlockValueFunction(boolean deep) {
_deep = deep;
public CorrMatrixBlock call(CorrMatrixBlock arg0, MatrixBlock arg1)
throws Exception
if( arg1.isEmptyBlock(false) )
return arg0;
//get current block and correction
MatrixBlock value = arg0.getValue();
MatrixBlock corr = arg0.getCorrection();
//correction block allocation on demand
if( corr == null && !arg1.isEmptyBlock(false) )
corr = new MatrixBlock(value.getNumRows(), value.getNumColumns(), false);
//aggregate other input and maintain corrections
//(existing value and corr are used in place)
OperationsOnMatrixValues.incrementalAggregation(value, corr, arg1, _op, false, _deep);
return arg0.set(value, corr);
private static class MergeSumBlockCombinerFunction implements Function2<CorrMatrixBlock, CorrMatrixBlock, CorrMatrixBlock>
private static final long serialVersionUID = 7664941774566119853L;
private AggregateOperator _op = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), CorrectionLocationType.NONE);
private final boolean _deep;
public MergeSumBlockCombinerFunction(boolean deep) {
_deep = deep;
public CorrMatrixBlock call(CorrMatrixBlock arg0, CorrMatrixBlock arg1)
throws Exception
//get current block and correction
MatrixBlock value1 = arg0.getValue();
MatrixBlock value2 = arg1.getValue();
MatrixBlock corr = arg0.getCorrection();
//correction block allocation on demand (but use second if exists)
if( corr == null ) {
corr = (arg1.getCorrection()!=null) ? arg1.getCorrection() :
value2.isEmptyBlock(false) || (!_deep && value1.isEmptyBlock(false)) ? null :
new MatrixBlock(value1.getNumRows(), value1.getNumColumns(), false);
//aggregate other input and maintain corrections
//(existing value and corr are used in place)
OperationsOnMatrixValues.incrementalAggregation(value1, corr, value2, _op, false, _deep);
return arg0.set(value1, corr);
private static class CreateBlockCombinerFunction implements Function<MatrixBlock, MatrixBlock>
private static final long serialVersionUID = 1987501624176848292L;
private final boolean _deep;
public CreateBlockCombinerFunction(boolean deep) {
_deep = deep;
public MatrixBlock call(MatrixBlock arg0)
throws Exception
//create deep copy of given block
return _deep ? new MatrixBlock(arg0) : arg0;
private static class CreateRowBlockCombinerFunction implements Function<RowMatrixBlock, MatrixBlock>
private static final long serialVersionUID = 2866598914232118425L;
public MatrixBlock call(RowMatrixBlock arg0)
throws Exception
//create new target block and copy row into it
MatrixBlock row = arg0.getValue();
MatrixBlock out = new MatrixBlock(arg0.getLen(), row.getNumColumns(), true);
out.copy(arg0.getRow(), arg0.getRow(), 0, row.getNumColumns()-1, row, false);
return out;
private static class MergeRowBlockValueFunction implements Function2<MatrixBlock, RowMatrixBlock, MatrixBlock>
private static final long serialVersionUID = -803689998683298516L;
public MatrixBlock call(MatrixBlock arg0, RowMatrixBlock arg1)
throws Exception
//copy row into existing target block
MatrixBlock row = arg1.getValue();
MatrixBlock out = arg0; //in-place update
out.copy(arg1.getRow(), arg1.getRow(), 0, row.getNumColumns()-1, row, true);
return out;
private static class CreateCellCombinerFunction implements Function<Double, KahanObject>
private static final long serialVersionUID = 3697505233057172994L;
public KahanObject call(Double arg0)
throws Exception
return new KahanObject(arg0, 0.0);
private static class MergeSumCellValueFunction implements Function2<KahanObject, Double, KahanObject>
private static final long serialVersionUID = 468335171573184825L;
public KahanObject call(KahanObject arg0, Double arg1)
throws Exception
//get reused kahan plus object
KahanPlus kplus = KahanPlus.getKahanPlusFnObject();
//compute kahan plus and keep correction
kplus.execute2(arg0, arg1);
return arg0;
private static class MergeSumCellCombinerFunction implements Function2<KahanObject, KahanObject, KahanObject>
private static final long serialVersionUID = 8726716909849119657L;
public KahanObject call(KahanObject arg0, KahanObject arg1)
throws Exception
//get reused kahan plus object
KahanPlus kplus = KahanPlus.getKahanPlusFnObject();
//compute kahan plus and keep correction
kplus.execute2(arg0, arg1._sum);
return arg0;
private static class MergeAggBlockValueFunction implements Function2<CorrMatrixBlock, MatrixBlock, CorrMatrixBlock>
private static final long serialVersionUID = 389422125491172011L;
private AggregateOperator _op = null;
public MergeAggBlockValueFunction(AggregateOperator aop)
_op = aop;
public CorrMatrixBlock call(CorrMatrixBlock arg0, MatrixBlock arg1)
throws Exception
//get current block and correction
MatrixBlock value = arg0.getValue();
MatrixBlock corr = arg0.getCorrection();
//correction block allocation on demand
if( corr == null && _op.existsCorrection() ){
corr = new MatrixBlock(value.getNumRows(), value.getNumColumns(), false);
//aggregate other input and maintain corrections
//(existing value and corr are used in place)
OperationsOnMatrixValues.incrementalAggregation(value, corr, arg1, _op, true);
OperationsOnMatrixValues.incrementalAggregation(value, null, arg1, _op, true);
return new CorrMatrixBlock(value, corr);
private static class MergeAggBlockCombinerFunction implements Function2<CorrMatrixBlock, CorrMatrixBlock, CorrMatrixBlock>
private static final long serialVersionUID = 4803711632648880797L;
private AggregateOperator _op = null;
public MergeAggBlockCombinerFunction(AggregateOperator aop)
_op = aop;
public CorrMatrixBlock call(CorrMatrixBlock arg0, CorrMatrixBlock arg1)
throws Exception
//get current block and correction
MatrixBlock value1 = arg0.getValue();
MatrixBlock value2 = arg1.getValue();
MatrixBlock corr = arg0.getCorrection();
//correction block allocation on demand (but use second if exists)
if( corr == null && _op.existsCorrection()) {
corr = (arg1.getCorrection()!=null)?arg1.getCorrection():
new MatrixBlock(value1.getNumRows(), value1.getNumColumns(), false);
//aggregate other input and maintain corrections
//(existing value and corr are used in place)
OperationsOnMatrixValues.incrementalAggregation(value1, corr, value2, _op, true);
OperationsOnMatrixValues.incrementalAggregation(value1, null, value2, _op, true);
return new CorrMatrixBlock(value1, corr);
private static class ExtractMatrixBlock implements Function<CorrMatrixBlock, MatrixBlock> {
private static final long serialVersionUID = 5242158678070843495L;
public MatrixBlock call(CorrMatrixBlock arg0) throws Exception {
return arg0.getValue();
private static class ExtractDoubleCell implements Function<KahanObject, Double> {
private static final long serialVersionUID = -2873241816558275742L;
public Double call(KahanObject arg0) throws Exception {
return arg0._sum;
* This aggregate function uses kahan+ with corrections to aggregate input blocks; it is meant for
* reduce all operations where we can reuse the same correction block independent of the input
* block indexes. Note that this aggregation function does not apply to embedded corrections.
private static class SumSingleBlockFunction implements Function2<MatrixBlock, MatrixBlock, MatrixBlock>
private static final long serialVersionUID = 1737038715965862222L;
private AggregateOperator _op = null;
private MatrixBlock _corr = null;
private boolean _deep = false;
public SumSingleBlockFunction(boolean deep) {
_op = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), CorrectionLocationType.NONE);
_deep = deep;
public MatrixBlock call(MatrixBlock arg0, MatrixBlock arg1)
throws Exception
//prepare combiner block
if( arg0.getNumRows() <= 0 || arg0.getNumColumns() <= 0 ) {
return arg0;
else if( arg1.getNumRows() <= 0 || arg1.getNumColumns() <= 0 ) {
return arg0;
//create correction block (on demand)
if( _corr == null ) {
_corr = new MatrixBlock(arg0.getNumRows(), arg0.getNumColumns(), false);
//aggregate other input (in-place if possible)
MatrixBlock out = _deep ? new MatrixBlock(arg0) : arg0;
out, _corr, arg1, _op, false);
return out;
* Note: currently we always include the correction and use a subsequent maptopair to
* drop them at the end because during aggregation we dont know if we produce an
* intermediate or the final aggregate.
private static class AggregateSingleBlockFunction implements Function2<MatrixBlock, MatrixBlock, MatrixBlock>
private static final long serialVersionUID = -3672377410407066396L;
private AggregateOperator _op = null;
private MatrixBlock _corr = null;
public AggregateSingleBlockFunction( AggregateOperator op ) {
_op = op;
public MatrixBlock call(MatrixBlock arg0, MatrixBlock arg1)
throws Exception
//prepare combiner block
if( arg0.getNumRows() == 0 && arg0.getNumColumns() == 0) {
return arg0;
else if( arg1.getNumRows() == 0 && arg1.getNumColumns() == 0 ) {
return arg0;
//create correction block (on demand)
if( _op.existsCorrection() && _corr == null ) {
_corr = new MatrixBlock(arg0.getNumRows(), arg0.getNumColumns(), false);
//aggregate second input (in-place)
arg0, _op.existsCorrection() ? _corr : null, arg1, _op, true);
return arg0;
* Note: currently we always include the correction and use a subsequent maptopair to
* drop them at the end because during aggregation we dont know if we produce an
* intermediate or the final aggregate.
private static class AggregateSingleTensorBlockFunction implements Function2<TensorBlock, TensorBlock, TensorBlock>
private static final long serialVersionUID = 5665180309149919945L;
private AggregateOperator _op = null;
public AggregateSingleTensorBlockFunction( AggregateOperator op ) {
_op = op;
public TensorBlock call(TensorBlock arg0, TensorBlock arg1)
throws Exception
//prepare combiner block
if( arg0.isEmpty()) {
return arg1;
else if( arg1.isEmpty() ) {
return arg0;
// TODO remove once KahanPlus is completely replaced by plus
if (_op.increOp.fn instanceof KahanPlus) {
_op = new AggregateOperator(0, Plus.getPlusFnObject());
//aggregate second input (in-place)
// TODO support DataTensor
arg0.getBasicTensor().incrementalAggregate(_op, arg1.getBasicTensor());
return arg0;
private static class MergeBlocksFunction implements Function2<MatrixBlock, MatrixBlock, MatrixBlock>
private static final long serialVersionUID = -8881019027250258850L;
private boolean _deep = false;
public MergeBlocksFunction() {
//by default deep copy first argument
public MergeBlocksFunction(boolean deep) {
_deep = deep;
public MatrixBlock call(MatrixBlock b1, MatrixBlock b2)
throws Exception
long b1nnz = b1.getNonZeros();
long b2nnz = b2.getNonZeros();
// sanity check input dimensions
if (b1.getNumRows() != b2.getNumRows() || b1.getNumColumns() != b2.getNumColumns()) {
throw new DMLRuntimeException("Mismatched block sizes for: "
+ b1.getNumRows() + " " + b1.getNumColumns() + " "
+ b2.getNumRows() + " " + b2.getNumColumns());
// execute merge (never pass by reference)
MatrixBlock ret = _deep ? new MatrixBlock(b1) : b1;
ret.merge(b2, false, false, _deep);
// sanity check output number of non-zeros
if (ret.getNonZeros() != b1nnz + b2nnz) {
throw new DMLRuntimeException("Number of non-zeros does not match: "
+ ret.getNonZeros() + " != " + b1nnz + " + " + b2nnz);
return ret;