[MINOR] Performance of sparse matrix - row vector binary cell operations
This patch improves the performance of element-wise binary operations
with a sparse matrix input, a dense row vector, and a dense output by
adding a dedicated kernel.
For a 1M x 1K sparse matrix with sp=0.1 and X - colMeans(X), this patch
improved performance from 11.3s to 5.9s (single-threaded). This is now
in the ballpark of dense operations of similar shape.
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java
index 34464a6..f44f179 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java
@@ -214,6 +214,9 @@
//note: m2 vector and hence always dense
if( !m1.sparse && !m2.sparse && !ret.sparse ) //DENSE all
safeBinaryMVDense(m1, m2, ret, op);
+ else if( m1.sparse && !m2.sparse && !ret.sparse
+ && atype == BinaryAccessType.MATRIX_ROW_VECTOR)
+ safeBinaryMVSparseDenseRow(m1, m2, ret, op);
else if( m1.sparse ) //SPARSE m1
safeBinaryMVSparse(m1, m2, ret, op);
else if( !m1.sparse && !m2.sparse && ret.sparse && op.fn instanceof Multiply
@@ -340,7 +343,7 @@
int len = dc.blockSize(bi);
for( int i=0, ix=0; i<len; i++, ix+=clen )
for( int j=0; j<clen; j++ ) {
- c[ix+j] = op.fn.execute( a[ix+j], ((b!=null) ? b[j] : 0) );
+ c[ix+j] = op.fn.execute( a[ix+j], ((b!=null) ? b[j] : 0) );
nnz += (c[ix+j] != 0) ? 1 : 0;
}
}
@@ -350,6 +353,52 @@
ret.nonZeros = nnz;
}
+ /**
+  * Dedicated kernel for element-wise op(m1, m2) with sparse m1, dense row vector m2, and dense ret.
+  * Every processed row is seeded with op(0, m2) and then overwritten at the non-zero positions of m1.
+  */
+ private static void safeBinaryMVSparseDenseRow(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op) {
+     boolean skipEmpty = (op.fn instanceof Multiply);
+     int rlen = m1.rlen;
+     int clen = m1.clen;
+     SparseBlock a = m1.sparseBlock;
+     double[] b = m2.getDenseBlockValues();
+     DenseBlock c = ret.allocateDenseBlock().getDenseBlock();
+     //early abort on skip and empty
+     if( skipEmpty && (m1.isEmptyBlock(false) || m2.isEmptyBlock(false) ) )
+         return; // skip entire empty block
+     //b may be null for an empty m2 (a sibling dense kernel guards b!=null as well): use zeros
+     if( b == null )
+         b = new double[clen];
+     //prepare op(0, m2) vector once for all rows (stays all-zero if empty cells are skipped)
+     double[] tmp = new double[clen];
+     if( !skipEmpty ) {
+         for( int i=0; i<clen; i++ )
+             tmp[i] = op.fn.execute(0, b[i]);
+     }
+     long nnz = 0;
+     for( int i=0; i<rlen; i++ ) {
+         if( skipEmpty && (a==null || a.isEmpty(i)) )
+             continue; //skip empty rows
+         //set prepared empty row vector into output
+         double[] cvals = c.values(i);
+         int cpos = c.pos(i);
+         System.arraycopy(tmp, 0, cvals, cpos, clen);
+         //overwrite row cells with existing sparse lhs values
+         if( a!=null && !a.isEmpty(i) ) {
+             int apos = a.pos(i);
+             int alen = a.size(i);
+             int[] aix = a.indexes(i);
+             double[] avals = a.values(i);
+             for( int j=apos; j<apos+alen; j++ )
+                 cvals[cpos+aix[j]] = op.fn.execute(avals[j], b[aix[j]]);
+         }
+         //compute row nnz with temporal locality
+         nnz += UtilFunctions.computeNnz(cvals, cpos, clen);
+     }
+     ret.nonZeros = nnz;
+ }
+
private static void safeBinaryMVSparse(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op) {
boolean isMultiply = (op.fn instanceof Multiply);
boolean skipEmpty = (isMultiply);
@@ -369,8 +418,7 @@
if( atype == BinaryAccessType.MATRIX_COL_VECTOR )
{
- for( int i=0; i<rlen; i++ )
- {
+ for( int i=0; i<rlen; i++ ) {
double v2 = m2.quickGetValue(i, 0);
if( (skipEmpty && (a==null || a.isEmpty(i) || v2 == 0 ))
@@ -379,62 +427,46 @@
continue; //skip empty rows
}
- if( isMultiply && v2==1 ) //ROW COPY
- {
+ if( isMultiply && v2==1 ) { //ROW COPY
if( a != null && !a.isEmpty(i) )
ret.appendRow(i, a.get(i));
}
- else //GENERAL CASE
- {
+ else { //GENERAL CASE
int lastIx = -1;
- if( a != null && !a.isEmpty(i) )
- {
+ if( a != null && !a.isEmpty(i) ) {
int apos = a.pos(i);
int alen = a.size(i);
int[] aix = a.indexes(i);
double[] avals = a.values(i);
for( int j=apos; j<apos+alen; j++ ) {
//empty left
- for( int k = lastIx+1; !skipEmpty&&k<aix[j]; k++ ){
- double v = op.fn.execute( 0, v2 );
- ret.appendValue(i, k, v);
- }
+ fillZeroValues(op, v2, ret, skipEmpty, i, lastIx+1, aix[j]);
//actual value
double v = op.fn.execute( avals[j], v2 );
ret.appendValue(i, aix[j], v);
lastIx = aix[j];
}
}
-
//empty left
- for( int k = lastIx+1; !skipEmpty&&k<clen; k++ ){
- double v = op.fn.execute( 0, v2 );
- ret.appendValue(i, k, v);
- }
+ fillZeroValues(op, v2, ret, skipEmpty, i, lastIx+1, clen);
}
}
}
else if( atype == BinaryAccessType.MATRIX_ROW_VECTOR )
{
- for( int i=0; i<rlen; i++ )
- {
+ for( int i=0; i<rlen; i++ ) {
if( skipEmpty && (a==null || a.isEmpty(i)) )
continue; //skip empty rows
int lastIx = -1;
- if( a!=null && !a.isEmpty(i) )
- {
+ if( a!=null && !a.isEmpty(i) ) {
int apos = a.pos(i);
int alen = a.size(i);
int[] aix = a.indexes(i);
double[] avals = a.values(i);
for( int j=apos; j<apos+alen; j++ ) {
//empty left
- for( int k=lastIx+1; !skipEmpty&&k<aix[j]; k++ ){
- double v2 = m2.quickGetValue(0, k);
- double v = op.fn.execute( 0, v2 );
- ret.appendValue(i, k, v);
- }
+ fillZeroValues(op, m2, ret, skipEmpty, i, lastIx+1, aix[j]);
//actual value
double v2 = m2.quickGetValue(0, aix[j]);
double v = op.fn.execute( avals[j], v2 );
@@ -442,18 +474,32 @@
lastIx = aix[j];
}
}
-
//empty left
- for( int k=lastIx+1; !skipEmpty&&k<clen; k++ ){
- double v2 = m2.quickGetValue(0, k);
- double v = op.fn.execute( 0, v2 );
- ret.appendValue(i, k, v);
- }
+ fillZeroValues(op, m2, ret, skipEmpty, i, lastIx+1, clen);
}
}
//no need to recomputeNonZeros since maintained in append value
}
+
+ /** Appends op(0, v2) to row rpos of ret for columns cpos (incl.) to len (excl.); no-op if skipEmpty. */
+ private static void fillZeroValues(BinaryOperator op, double v2, MatrixBlock ret, boolean skipEmpty, int rpos, int cpos, int len) {
+     if( !skipEmpty ) {
+         for( int col=cpos; col<len; col++ ) {
+             ret.appendValue(rpos, col, op.fn.execute(0, v2));
+         }
+     }
+ }
+
+ /** Appends op(0, m2[0,k]) to row rpos of ret for each column k in cpos (incl.) to len (excl.); no-op if skipEmpty. */
+ private static void fillZeroValues(BinaryOperator op, MatrixBlock m2, MatrixBlock ret, boolean skipEmpty, int rpos, int cpos, int len) {
+     if( !skipEmpty ) {
+         for( int col=cpos; col<len; col++ ) {
+             double rhs = m2.quickGetValue(0, col);
+             ret.appendValue(rpos, col, op.fn.execute(0, rhs));
+         }
+     }
+ }
private static void safeBinaryMVDenseSparseMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op) {
if( m1.isEmptyBlock(false) || m2.isEmptyBlock(false) )