blob: 2b0eb1fd9faf51b764f7b5ca5c14d12ae209187a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.runtime.codegen;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.codegen.SpoofCellwise.AggOp;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.KahanFunction;
import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.functionobjects.KahanPlusSq;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
import org.apache.sysds.runtime.instructions.cp.KahanObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.UtilFunctions;
/**
 * Abstract operator that evaluates several full aggregates in a single pass
 * over the cells of the first input matrix. One aggregate is computed per
 * configured {@link AggOp}, and the results are returned as a dense
 * 1 x (number of aggregates) matrix block.
 *
 * <p>Subclasses supply the per-cell logic by implementing the abstract
 * {@code genexec} method, which updates the aggregate buffer {@code c}
 * in place.
 */
public abstract class SpoofMultiAggregate extends SpoofOperator implements Serializable
{
	private static final long serialVersionUID = -6164871955591089349L;

	/** Aggregation operations; position k corresponds to output column k. */
	private final AggOp[] _aggOps;
	/** If true, zero cells of the main input are skipped during the scan. */
	private final boolean _sparseSafe;

	/**
	 * Creates a multi-aggregate operator.
	 *
	 * @param sparseSafe whether zero cells of the main input may be skipped
	 * @param aggOps aggregation operations, one per output value
	 */
	public SpoofMultiAggregate(boolean sparseSafe, AggOp... aggOps) {
		_sparseSafe = sparseSafe;
		_aggOps = aggOps;
	}

	/**
	 * @return the aggregation operations (the internal array, not a copy)
	 */
	public AggOp[] getAggOps() {
		return _aggOps;
	}

	/**
	 * @return true if zero cells of the main input are skipped
	 */
	public boolean isSparseSafe() {
		return _sparseSafe;
	}

	/**
	 * Returns "MA" followed by the second dot-separated component of the
	 * runtime class name.
	 *
	 * <p>Note: this indexes component [1] unconditionally, so a class name
	 * with fewer than two dot-separated components fails with an
	 * {@link ArrayIndexOutOfBoundsException}.
	 */
	@Override
	public String getSpoofType() {
		return "MA" + getClass().getName().split("\\.")[1];
	}

	/** Single-threaded execution with a global row offset of 0. */
	@Override
	public MatrixBlock execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, MatrixBlock out) {
		return execute(inputs, scalarObjects, out, 1, 0);
	}

	/** Execution with parallelism {@code k} and a global row offset of 0. */
	@Override
	public MatrixBlock execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, MatrixBlock out, int k) {
		return execute(inputs, scalarObjects, out, k, 0);
	}

	/**
	 * Computes all configured aggregates over the first input matrix.
	 *
	 * @param inputs input matrices; the first one is scanned cell by cell,
	 *        all inputs are handed to {@code prepInputMatrices}
	 * @param scalarObjects scalar inputs, converted via {@code prepInputScalars}
	 * @param out output block; reset to a dense 1 x (number of aggregates) block
	 * @param k requested degree of parallelism; values {@code <= 1}, or an input
	 *        size below {@code PAR_NUMCELL_THRESHOLD}, lead to serial execution
	 * @param rix row offset added to every local row index to form the global
	 *        row index passed to {@code genexec}
	 * @return {@code out}, holding the aggregate values
	 * @throws RuntimeException if {@code inputs} is null or empty
	 * @throws DMLRuntimeException if the multi-threaded execution fails
	 */
	public MatrixBlock execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, MatrixBlock out, int k, long rix) {
		//sanity check
		if( inputs==null || inputs.size() < 1 )
			throw new RuntimeException("Invalid input arguments.");

		final boolean sparseSafe = isSparseSafe();
		long inputSize = sparseSafe ?
			getTotalInputNnz(inputs) : getTotalInputSize(inputs);
		if( inputSize < PAR_NUMCELL_THRESHOLD ) {
			k = 1; //serial execution
		}

		//result allocation and preparations
		out.reset(1, _aggOps.length, false);
		out.allocateDenseBlock();
		double[] c = out.getDenseBlockValues(); //1x<num_agg>
		setInitialOutputValues(c);

		//input preparation
		SideInput[] b = prepInputMatrices(inputs);
		double[] scalars = prepInputScalars(scalarObjects);
		final MatrixBlock main = inputs.get(0);
		final int m = main.getNumRows();
		final int n = main.getNumColumns();

		if( k <= 1 ) //SINGLE-THREADED
		{
			if( !main.isInSparseFormat() )
				executeDense(main.getDenseBlock(), b, scalars, c, m, n, sparseSafe, 0, m, rix);
			else
				executeSparse(main.getSparseBlock(), b, scalars, c, m, n, sparseSafe, 0, m, rix);
		}
		else //MULTI-THREADED
		{
			ExecutorService pool = null;
			try {
				//acquired inside the try so that failures are wrapped as before
				pool = CommonThreadPool.get(k);
				ArrayList<ParAggTask> tasks = new ArrayList<>();
				int nk = UtilFunctions.roundToNext(Math.min(8*k,m/32), k);
				int blklen = (int)(Math.ceil((double)m/nk));
				//long arithmetic guards against int overflow for very large m
				for( int i=0; i<nk && (long)i*blklen<m; i++ ) {
					int rl = i*blklen;
					int ru = (int)Math.min((long)rl+blklen, m);
					//forward rix so that parallel tasks see the same global
					//row indexes as the single-threaded path
					tasks.add(new ParAggTask(main, b, scalars,
						m, n, sparseSafe, rl, ru, rix));
				}
				//execute tasks (blocks until all tasks have completed)
				List<Future<double[]>> taskret = pool.invokeAll(tasks);
				//aggregate partial results
				ArrayList<double[]> pret = new ArrayList<>();
				for( Future<double[]> task : taskret )
					pret.add(task.get());
				aggregatePartialResults(c, pret);
			}
			catch(Exception ex) {
				throw new DMLRuntimeException(ex);
			}
			finally {
				//always release the pool, also on task or aggregation failures
				if( pool != null )
					pool.shutdown();
			}
		}

		//post-processing
		out.recomputeNonZeros();
		out.examSparsity();
		return out;
	}

	/**
	 * Scans rows [rl, ru) of a dense main input and calls {@code genexec}
	 * for every cell. A null block is treated as all zeros: it is scanned
	 * with value 0 if the operator is not sparse-safe and skipped otherwise.
	 */
	private void executeDense(DenseBlock a, SideInput[] b, double[] scalars, double[] c, int m, int n, boolean sparseSafe, int rl, int ru, long rix)
	{
		SideInput[] lb = createSparseSideInputs(b);

		//core dense aggregation operation
		if( a == null && !sparseSafe ) {
			for( int i=rl; i<ru; i++ )
				for( int j=0; j<n; j++ )
					genexec( 0, lb, scalars, c, m, n, rix+i, i, j );
		}
		else if( a != null ) {
			for( int i=rl; i<ru; i++ ) {
				double[] avals = a.values(i);
				int aix = a.pos(i);
				for( int j=0; j<n; j++ )
					genexec( avals[aix+j], lb, scalars, c, m, n, rix+i, i, j );
			}
		}
	}

	/**
	 * Scans rows [rl, ru) of a sparse main input and calls {@code genexec}
	 * for every non-zero cell and, if the operator is not sparse-safe, also
	 * for every zero cell (in increasing column order within each row).
	 */
	private void executeSparse(SparseBlock sblock, SideInput[] b, double[] scalars,
		double[] c, int m, int n, boolean sparseSafe, int rl, int ru, long rix)
	{
		if( sblock == null && sparseSafe )
			return;

		SideInput[] lb = createSparseSideInputs(b);

		//note: sequential scan algorithm for both sparse-safe and -unsafe
		//in order to avoid binary search for sparse-unsafe
		for(int i=rl; i<ru; i++) {
			int lastj = -1;
			//handle non-empty rows
			if( sblock != null && !sblock.isEmpty(i) ) {
				int apos = sblock.pos(i);
				int alen = sblock.size(i);
				int[] aix = sblock.indexes(i);
				double[] avals = sblock.values(i);
				for(int k=apos; k<apos+alen; k++) {
					//process zeros before current non-zero
					if( !sparseSafe )
						for(int j=lastj+1; j<aix[k]; j++)
							genexec(0, lb, scalars, c, m, n, rix+i, i, j);
					//process current non-zero
					lastj = aix[k];
					genexec(avals[k], lb, scalars, c, m, n, rix+i, i, lastj);
				}
			}
			//process empty rows or remaining zeros
			if( !sparseSafe )
				for(int j=lastj+1; j<n; j++)
					genexec(0, lb, scalars, c, m, n, rix+i, i, j);
		}
	}

	/**
	 * Local execution where the global row index equals the local row index;
	 * delegates to the abstract {@code genexec} with {@code grix == rix}.
	 */
	protected final void genexec( double a, SideInput[] b,
		double[] scalars, double[] c, int m, int n, int rix, int cix) {
		genexec(a, b, scalars, c, m, n, rix, rix, cix);
	}

	/**
	 * Per-cell operation with an additional global row index (used for
	 * distributed execution, where the local block is offset in the matrix).
	 *
	 * @param a value of the main input at (rix, cix)
	 * @param b side inputs
	 * @param scalars scalar inputs
	 * @param c aggregate buffer to update, one entry per aggregate
	 * @param m number of rows of the main input
	 * @param n number of columns of the main input
	 * @param grix global row index (row offset plus local row index)
	 * @param rix local row index
	 * @param cix column index
	 */
	protected abstract void genexec( double a, SideInput[] b,
		double[] scalars, double[] c, int m, int n, long grix, int rix, int cix);

	/** Fills {@code c} with the initial value of each configured aggregate. */
	private void setInitialOutputValues(double[] c) {
		for( int k=0; k<_aggOps.length; k++ )
			c[k] = getInitialValue(_aggOps[k]);
	}

	/**
	 * Returns the start value of an aggregate: 0 for SUM and SUM_SQ,
	 * +infinity for MIN, -infinity for MAX, and 0 for any other operation.
	 *
	 * @param aggop aggregation operation
	 * @return initial aggregate value
	 */
	public static double getInitialValue(AggOp aggop) {
		switch( aggop ) {
			case SUM:
			case SUM_SQ: return 0;
			case MIN: return Double.POSITIVE_INFINITY;
			case MAX: return Double.NEGATIVE_INFINITY;
		}
		return 0;
	}

	/**
	 * Merges the per-task partial aggregates into {@code c}.
	 *
	 * <p>For aggregates backed by a {@link KahanFunction}, the partials are
	 * added up via {@link KahanPlus} starting from 0 and the result overwrites
	 * {@code c[k]}; plain addition is used even for SUM_SQ because the
	 * partials handed in here are already aggregated values. For all other
	 * aggregates, the partials are folded into the existing {@code c[k]}.
	 */
	private void aggregatePartialResults(double[] c, ArrayList<double[]> pret) {
		ValueFunction[] vfun = getAggFunctions(_aggOps);
		for( int k=0; k<_aggOps.length; k++ ) {
			if( vfun[k] instanceof KahanFunction ) {
				KahanObject kbuff = new KahanObject(0, 0);
				KahanPlus kplus = KahanPlus.getKahanPlusFnObject();
				for(double[] tmp : pret)
					kplus.execute2(kbuff, tmp[k]);
				c[k] = kbuff._sum;
			}
			else {
				for(double[] tmp : pret)
					c[k] = vfun[k].execute(c[k], tmp[k]);
			}
		}
	}

	/**
	 * Merges the aggregates in row 0 of {@code b} into row 0 of {@code c},
	 * position by position, using the function that matches each aggregate.
	 *
	 * @param aggOps aggregation operations, one per column
	 * @param c target block, updated in place
	 * @param b block with the partial aggregates to merge in
	 */
	public static void aggregatePartialResults(AggOp[] aggOps, MatrixBlock c, MatrixBlock b) {
		ValueFunction[] vfun = getAggFunctions(aggOps);
		for( int k=0; k< aggOps.length; k++ ) {
			if( vfun[k] instanceof KahanFunction ) {
				KahanObject kbuff = new KahanObject(c.quickGetValue(0, k), 0);
				KahanPlus kplus = KahanPlus.getKahanPlusFnObject();
				kplus.execute2(kbuff, b.quickGetValue(0, k));
				c.quickSetValue(0, k, kbuff._sum);
			}
			else {
				double cval = c.quickGetValue(0, k);
				double bval = b.quickGetValue(0, k);
				c.quickSetValue(0, k, vfun[k].execute(cval, bval));
			}
		}
	}

	/**
	 * Maps each aggregation operation to its function object.
	 *
	 * @param aggOps aggregation operations
	 * @return function objects, aligned with {@code aggOps}
	 * @throws RuntimeException for operations other than SUM, SUM_SQ, MIN, MAX
	 */
	public static ValueFunction[] getAggFunctions(AggOp[] aggOps) {
		ValueFunction[] fun = new ValueFunction[aggOps.length];
		for( int i=0; i<aggOps.length; i++ ) {
			switch( aggOps[i] ) {
				case SUM: fun[i] = KahanPlus.getKahanPlusFnObject(); break;
				case SUM_SQ: fun[i] = KahanPlusSq.getKahanPlusSqFnObject(); break;
				case MIN: fun[i] = Builtin.getBuiltinFnObject(BuiltinCode.MIN); break;
				case MAX: fun[i] = Builtin.getBuiltinFnObject(BuiltinCode.MAX); break;
				default:
					throw new RuntimeException("Unsupported "
						+ "aggregation type: "+aggOps[i].name());
			}
		}
		return fun;
	}

	/**
	 * Task that aggregates the row range [rl, ru) of the main input into a
	 * task-local buffer, which is returned as the partial result.
	 */
	private class ParAggTask implements Callable<double[]>
	{
		private final MatrixBlock _a;
		private final SideInput[] _b;
		private final double[] _scalars;
		private final int _rlen;
		private final int _clen;
		private final boolean _safe;
		private final int _rl;
		private final int _ru;
		/** Row offset used to derive global row indexes. */
		private final long _rix;

		protected ParAggTask( MatrixBlock a, SideInput[] b, double[] scalars,
			int rlen, int clen, boolean safe, int rl, int ru, long rix ) {
			_a = a;
			_b = b;
			_scalars = scalars;
			_rlen = rlen;
			_clen = clen;
			_safe = safe;
			_rl = rl;
			_ru = ru;
			_rix = rix;
		}

		@Override
		public double[] call() {
			//task-local buffer, so tasks never write to shared output
			double[] c = new double[_aggOps.length];
			setInitialOutputValues(c);
			if( !_a.isInSparseFormat() )
				executeDense(_a.getDenseBlock(), _b, _scalars, c, _rlen, _clen, _safe, _rl, _ru, _rix);
			else
				executeSparse(_a.getSparseBlock(), _b, _scalars, c, _rlen, _clen, _safe, _rl, _ru, _rix);
			return c;
		}
	}
}