[SYSTEMDS-2695 + 2743] CLA Row parallel left %*%
This PR contains re-enabling parallel left multiplication for
sparse matrices, plus row based parallelization of dense.
Furthermore, it also contains optimization of Binary and scalar
divide, that does not accidentally decompress anymore.
Closes #1118
diff --git a/src/main/java/org/apache/sysds/runtime/compress/AbstractCompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/AbstractCompressedMatrixBlock.java
index 7055f0d..aff136a 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/AbstractCompressedMatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/AbstractCompressedMatrixBlock.java
@@ -26,6 +26,7 @@
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.commons.math3.random.Well1024a;
+import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.colgroup.ColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
@@ -507,7 +508,7 @@
}
protected static MatrixBlock getUncompressed(MatrixValue mVal) {
- return isCompressed((MatrixBlock) mVal) ? ((CompressedMatrixBlock) mVal).decompress() : (MatrixBlock) mVal;
+ return isCompressed((MatrixBlock) mVal) ? ((CompressedMatrixBlock) mVal).decompress(OptimizerUtils.getConstrainedNumThreads(-1)) : (MatrixBlock) mVal;
}
protected void printDecompressWarning(String operation) {
diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
index 660c28e..986a77d 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
@@ -62,6 +62,7 @@
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseRow;
import org.apache.sysds.runtime.functionobjects.Builtin;
+import org.apache.sysds.runtime.functionobjects.Divide;
import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
import org.apache.sysds.runtime.functionobjects.Equals;
import org.apache.sysds.runtime.functionobjects.GreaterThan;
@@ -115,11 +116,13 @@
/**
* Main constructor for building a block from scratch.
*
+ * Use with caution, since it constructs an empty matrix block with nothing inside.
+ *
* @param rl number of rows in the block
* @param cl number of columns
* @param sparse true if the UNCOMPRESSED representation of the block should be sparse
*/
- protected CompressedMatrixBlock(int rl, int cl, boolean sparse) {
+ public CompressedMatrixBlock(int rl, int cl, boolean sparse) {
super(rl, cl, sparse);
}
@@ -371,26 +374,26 @@
@Override
public MatrixBlock scalarOperations(ScalarOperator sop, MatrixValue result) {
// Special case handling of overlapping relational operations
- if(sop.fn instanceof LessThan || sop.fn instanceof LessThanEquals || sop.fn instanceof GreaterThan ||
- sop.fn instanceof GreaterThanEquals || sop.fn instanceof Equals || sop.fn instanceof NotEquals) {
- CompressedMatrixBlock ret = null;
- if(result == null || !(result instanceof CompressedMatrixBlock))
- ret = new CompressedMatrixBlock(getNumRows(), getNumColumns(), sparse);
- return LibRelationalOp.relationalOperation(sop, this, ret, overlappingColGroups);
+ if(isOverlapping() &&
+ (sop.fn instanceof LessThan || sop.fn instanceof LessThanEquals || sop.fn instanceof GreaterThan ||
+ sop.fn instanceof GreaterThanEquals || sop.fn instanceof Equals || sop.fn instanceof NotEquals)) {
+ MatrixBlock ret = LibRelationalOp.relationalOperation(sop, this, isOverlapping());
+
+ result = ret;
+ return ret;
}
- if(overlappingColGroups &&
- (!(sop.fn instanceof Multiply || sop.fn instanceof Plus || sop.fn instanceof Minus))) {
+ if(isOverlapping() && (!(sop.fn instanceof Multiply || sop.fn instanceof Divide
+ || sop.fn instanceof Plus || sop.fn instanceof Minus))) {
LOG.warn("scalar overlapping not supported for op: " + sop.fn);
MatrixBlock m1d = decompress(sop.getNumThreads());
return m1d.scalarOperations(sop, result);
-
}
CompressedMatrixBlock ret = null;
if(result == null || !(result instanceof CompressedMatrixBlock))
ret = new CompressedMatrixBlock(getNumRows(), getNumColumns(), sparse);
- return LibScalar.scalarOperations(sop, this, ret, overlappingColGroups);
+ return LibScalar.scalarOperations(sop, this, ret, isOverlapping());
}
@Override
@@ -402,20 +405,22 @@
+ "x" + this.clen + " vs " + that.getNumRows() + "x" + that.getNumColumns());
}
- if(LibMatrixBincell.getBinaryAccessType(this, that) == BinaryAccessType.MATRIX_COL_VECTOR ||
- (this.getNumColumns() == 1 && that.getNumColumns() == 1 && that.getNumRows() != 1) ||
- !(op.fn instanceof Multiply || op.fn instanceof Plus || op.fn instanceof Minus ||
- op.fn instanceof MinusMultiply || op.fn instanceof PlusMultiply)) {
- // case MATRIX_COL_VECTOR:
- // TODO make partial decompress and do operation.
- // TODO support more of the operations... since it is possible.
+ BinaryAccessType atype = LibMatrixBincell.getBinaryAccessType(this, that);
+
+ if(atype == BinaryAccessType.MATRIX_COL_VECTOR || atype == BinaryAccessType.MATRIX_MATRIX ) {
+ MatrixBlock ret = LibBinaryCellOp.binaryMVPlusCol(this, that, op);
+ result = ret;
+ return ret;
+ }
+ else if(!(op.fn instanceof Multiply || op.fn instanceof Divide || op.fn instanceof Plus || op.fn instanceof Minus ||
+ op.fn instanceof MinusMultiply || op.fn instanceof PlusMultiply)) {
+ LOG.warn("Decompressing since Binary Ops" + op.fn + " is not supported compressed");
MatrixBlock m2 = getUncompressed(this);
MatrixBlock ret = m2.binaryOperations(op, thatValue, result);
result = ret;
return ret;
}
else {
-
CompressedMatrixBlock ret = null;
if(result == null || !(result instanceof CompressedMatrixBlock))
ret = new CompressedMatrixBlock(getNumRows(), getNumColumns(), sparse);
@@ -608,9 +613,7 @@
return LibCompAgg.aggregateUnaryOverlapping(this, ret, op, blen, indexesIn, inCP);
}
- ret = LibCompAgg.aggregateUnary(this, ret, op, blen, indexesIn, inCP);
-
- return ret;
+ return LibCompAgg.aggregateUnary(this, ret, op, blen, indexesIn, inCP);
}
@Override
@@ -732,7 +735,7 @@
}
public boolean isOverlapping() {
- return overlappingColGroups;
+ return _colGroups.size() != 1 && overlappingColGroups;
}
public void setOverlapping(boolean overlapping) {
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java
index 5e0f712..a48532d 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java
@@ -89,43 +89,38 @@
@Override
public void decompressToBlock(MatrixBlock target, int rl, int ru, int offT, double[] values) {
- if(getNumValues() > 1) {
- final int blksz = CompressionSettings.BITMAP_BLOCK_SZ;
- final int numCols = getNumCols();
- final int numVals = getNumValues();
- // cache blocking config and position array
- int[] apos = skipScan(numVals, rl);
+ final int blksz = CompressionSettings.BITMAP_BLOCK_SZ;
+ final int numCols = getNumCols();
+ final int numVals = getNumValues();
- // cache conscious append via horizontal scans
- for(int bi = (rl / blksz) * blksz; bi < ru; bi += blksz) {
- for(int k = 0, off = 0; k < numVals; k++, off += numCols) {
- int boff = _ptr[k];
- int blen = len(k);
- int bix = apos[k];
-
- if(bix >= blen)
- continue;
- int len = _data[boff + bix];
- int pos = boff + bix + 1;
- for(int i = pos; i < pos + len; i++) {
- int row = bi + _data[i];
- if(row >= rl && row < ru){
- int rix = row - (rl - offT);
- for(int j = 0; j < numCols; j++) {
- double v = target.quickGetValue(rix, _colIndexes[j]);
- target.setValue(rix, _colIndexes[j], values[off + j] + v);
- }
+ // cache blocking config and position array
+ int[] apos = skipScan(numVals, rl);
+
+ // cache conscious append via horizontal scans
+ for(int bi = (rl / blksz) * blksz; bi < ru; bi += blksz) {
+ for(int k = 0, off = 0; k < numVals; k++, off += numCols) {
+ int boff = _ptr[k];
+ int blen = len(k);
+ int bix = apos[k];
+
+ if(bix >= blen)
+ continue;
+ int len = _data[boff + bix];
+ int pos = boff + bix + 1;
+ for(int i = pos; i < pos + len; i++) {
+ int row = bi + _data[i];
+ if(row >= rl && row < ru) {
+ int rix = row - (rl - offT);
+ for(int j = 0; j < numCols; j++) {
+ double v = target.quickGetValue(rix, _colIndexes[j]);
+ target.setValue(rix, _colIndexes[j], values[off + j] + v);
}
}
- apos[k] += len + 1;
}
+ apos[k] += len + 1;
}
}
- else {
- // call generic decompression with decoder
- super.decompressToBlock(target, rl, ru, offT, values);
- }
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOffset.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOffset.java
index d518434..c4b6d08 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOffset.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOffset.java
@@ -65,7 +65,7 @@
super(colIndices, numRows, ubm, cs);
}
- protected ColGroupOffset(int[] colIndices, int numRows, boolean zeros, ADictionary dict){
+ protected ColGroupOffset(int[] colIndices, int numRows, boolean zeros, ADictionary dict) {
super(colIndices, numRows, dict);
_zeros = zeros;
}
@@ -94,33 +94,8 @@
return ColGroupSizes.estimateInMemorySizeOffset(getNumCols(), _colIndexes.length, 0, 0, isLossy());
}
else {
- return ColGroupSizes.estimateInMemorySizeOffset(getNumCols(), getValues().length, _ptr.length, _data.length, isLossy());
- }
- }
-
- @Override
- public void decompressToBlock(MatrixBlock target, int rl, int ru, int offT, double[] values) {
- final int numCols = getNumCols();
- final int numVals = getNumValues();
- int[] colIndices = getColIndices();
-
- // Run through the bitmaps for this column group
- for(int i = 0; i < numVals; i++) {
- Iterator<Integer> decoder = getIterator(i);
- int valOff = i * numCols;
-
- while(decoder.hasNext()) {
- int row = decoder.next();
- if(row < rl)
- continue;
- if(row > ru)
- break;
- row = row - (rl - offT);
- for(int colIx = 0; colIx < numCols; colIx++){
- double v = target.quickGetValue(row , colIndices[colIx]);
- target.setValue(row, colIndices[colIx], v + values[valOff + colIx]);
- }
- }
+ return ColGroupSizes
+ .estimateInMemorySizeOffset(getNumCols(), getValues().length, _ptr.length, _data.length, isLossy());
}
}
@@ -207,8 +182,8 @@
protected final double mxxValues(int bitmapIx, Builtin builtin, double[] values) {
final int numCols = getNumCols();
final int valOff = bitmapIx * numCols;
- double val = (builtin.getBuiltinCode() == BuiltinCode.MAX) ?
- Double.NEGATIVE_INFINITY : Double.POSITIVE_INFINITY;
+ double val = (builtin
+ .getBuiltinCode() == BuiltinCode.MAX) ? Double.NEGATIVE_INFINITY : Double.POSITIVE_INFINITY;
for(int i = 0; i < numCols; i++)
val = builtin.execute(val, values[valOff + i]);
@@ -249,15 +224,15 @@
@Override
public void readFields(DataInput in) throws IOException {
super.readFields(in);
-
+
// read bitmaps
_ptr = new int[in.readInt()];
- for(int i = 0; i< _ptr.length; i++){
+ for(int i = 0; i < _ptr.length; i++) {
_ptr[i] = in.readInt();
}
int totalLen = in.readInt();
_data = new char[totalLen];
- for(int i = 0; i< totalLen; i++){
+ for(int i = 0; i < totalLen; i++) {
_data[i] = in.readChar();
}
}
@@ -267,11 +242,11 @@
super.write(out);
// write bitmaps (lens and data, offset later recreated)
out.writeInt(_ptr.length);
- for(int i = 0; i < _ptr.length; i++){
+ for(int i = 0; i < _ptr.length; i++) {
out.writeInt(_ptr[i]);
}
out.writeInt(_data.length);
- for(int i = 0; i < _data.length; i++){
+ for(int i = 0; i < _data.length; i++) {
out.writeChar(_data[i]);
}
@@ -286,7 +261,7 @@
ret += 4; // _data list
ret += 2 * _data.length;
// for(int i = 0; i < getNumValues(); i++)
- // ret += 4 + 2 * len(i);
+ // ret += 4 + 2 * len(i);
return ret;
}
@@ -320,9 +295,9 @@
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(super.toString());
- sb.append(String.format("\n%15s%5d ", "Pointers:" , this._ptr.length ));
+ sb.append(String.format("\n%15s%5d ", "Pointers:", this._ptr.length));
sb.append(Arrays.toString(this._ptr));
- sb.append(String.format("\n%15s%5d ", "Data:" , this._data.length));
+ sb.append(String.format("\n%15s%5d ", "Data:", this._data.length));
sb.append("[");
for(int x = 0; x < _data.length; x++) {
sb.append(((int) _data[x]));
@@ -365,7 +340,8 @@
public IJV next() {
if(!hasNext())
throw new RuntimeException("No more offset entries.");
- _buff.set(_rpos, _colIndexes[_cpos],
+ _buff.set(_rpos,
+ _colIndexes[_cpos],
(_vpos >= getNumValues()) ? 0 : _dict.getValue(_vpos * getNumCols() + _cpos));
getNextValue();
return _buff;
@@ -481,7 +457,7 @@
public IJV next() {
if(!hasNext())
throw new RuntimeException("No more offset entries.");
- _ret.set(_rpos, _colIndexes[_cpos], (_vpos < 0) ? 0 : _dict.getValue(_vpos *getNumCols() + _cpos));
+ _ret.set(_rpos, _colIndexes[_cpos], (_vpos < 0) ? 0 : _dict.getValue(_vpos * getNumCols() + _cpos));
getNextValue();
return _ret;
}
@@ -505,7 +481,7 @@
return;
_cpos++;
}
- while(!_inclZeros && (_vpos < 0 || _dict.getValue(_vpos *getNumCols() + _cpos) == 0));
+ while(!_inclZeros && (_vpos < 0 || _dict.getValue(_vpos * getNumCols() + _cpos) == 0));
}
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java
index 6bbcd5d..551bfb3 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java
@@ -70,6 +70,7 @@
// compact bitmaps to linearized representation
createCompressedBitmaps(numVals, totalLen, lbitmaps);
+ // LOG.error(this);
}
protected ColGroupRLE(int[] colIndices, int numRows, boolean zeros, ADictionary dict, char[] bitmaps,
@@ -91,44 +92,38 @@
@Override
public void decompressToBlock(MatrixBlock target, int rl, int ru, int offT, double[] values) {
- if(getNumValues() > 1) {
- final int blksz = CompressionSettings.BITMAP_BLOCK_SZ;
- final int numCols = getNumCols();
- final int numVals = getNumValues();
+ final int blksz = CompressionSettings.BITMAP_BLOCK_SZ;
+ final int numCols = getNumCols();
+ final int numVals = getNumValues();
- // position and start offset arrays
- int[] astart = new int[numVals];
- int[] apos = skipScan(numVals, rl, astart);
+ // position and start offset arrays
+ int[] astart = new int[numVals];
+ int[] apos = skipScan(numVals, rl, astart);
- // cache conscious append via horizontal scans
- for(int bi = rl; bi < ru; bi += blksz) {
- int bimax = Math.min(bi + blksz, ru);
- for(int k = 0, off = 0; k < numVals; k++, off += numCols) {
- int boff = _ptr[k];
- int blen = len(k);
- int bix = apos[k];
- int start = astart[k];
- for(; bix < blen & start < bimax; bix += 2) {
- start += _data[boff + bix];
- int len = _data[boff + bix + 1];
- for(int i = Math.max(rl, start) - (rl - offT); i < Math.min(start + len, ru) - (rl - offT); i++)
- for(int j = 0; j < numCols; j++) {
- if(values[off + j] != 0) {
- double v = target.quickGetValue(i, _colIndexes[j]);
- target.quickSetValue(i, _colIndexes[j], values[off + j] + v);
- }
+ // cache conscious append via horizontal scans
+ for(int bi = rl; bi < ru; bi += blksz) {
+ int bimax = Math.min(bi + blksz, ru);
+ for(int k = 0, off = 0; k < numVals; k++, off += numCols) {
+ int boff = _ptr[k];
+ int blen = len(k);
+ int bix = apos[k];
+ int start = astart[k];
+ for(; bix < blen & start < bimax; bix += 2) {
+ start += _data[boff + bix];
+ int len = _data[boff + bix + 1];
+ for(int i = Math.max(rl, start) - (rl - offT); i < Math.min(start + len, ru) - (rl - offT); i++)
+ for(int j = 0; j < numCols; j++) {
+ if(values[off + j] != 0) {
+ double v = target.quickGetValue(i, _colIndexes[j]);
+ target.quickSetValue(i, _colIndexes[j], values[off + j] + v);
}
- start += len;
- }
- apos[k] = bix;
- astart[k] = start;
+ }
+ start += len;
}
+ apos[k] = bix;
+ astart[k] = start;
}
}
- else {
- // call generic decompression with decoder
- super.decompressToBlock(target, rl, ru, offT, values);
- }
}
@Override
@@ -183,7 +178,6 @@
@Override
public void decompressToBlock(MatrixBlock target, int colpos) {
- // LOG.error("Does not work");
final int blksz = 128 * 1024;
final int numCols = getNumCols();
final int numVals = getNumValues();
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/QDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/QDictionary.java
index 6d1906e..0105e7b 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/QDictionary.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/QDictionary.java
@@ -28,6 +28,7 @@
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.utils.BitmapLossy;
import org.apache.sysds.runtime.functionobjects.Builtin;
+import org.apache.sysds.runtime.functionobjects.Divide;
import org.apache.sysds.runtime.functionobjects.KahanFunction;
import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.functionobjects.KahanPlusSq;
@@ -123,7 +124,7 @@
@Override
public QDictionary apply(ScalarOperator op) {
- if(op.fn instanceof Multiply) {
+ if(op.fn instanceof Multiply || op.fn instanceof Divide) {
_scale = op.executeScalar(_scale);
return this;
// return new QDictionary(_values, op.executeScalar(_scale));
diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/LibBinaryCellOp.java b/src/main/java/org/apache/sysds/runtime/compress/lib/LibBinaryCellOp.java
index 01239f7..9a77b50 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/LibBinaryCellOp.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/LibBinaryCellOp.java
@@ -21,17 +21,26 @@
import java.util.ArrayList;
import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
import org.apache.commons.lang.NotImplementedException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLCompressionException;
+import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.colgroup.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.ColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
import org.apache.sysds.runtime.compress.colgroup.Dictionary;
+import org.apache.sysds.runtime.data.DenseBlock;
+import org.apache.sysds.runtime.functionobjects.Divide;
import org.apache.sysds.runtime.functionobjects.Minus;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.Plus;
@@ -42,6 +51,7 @@
import org.apache.sysds.runtime.matrix.operators.LeftScalarOperator;
import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
+import org.apache.sysds.runtime.util.CommonThreadPool;
public class LibBinaryCellOp {
@@ -63,7 +73,10 @@
m2 = m2.scalarOperations(sop, new MatrixBlock());
return LibBinaryCellOp.bincellOp(m1, m2, ret, new BinaryOperator(Plus.getPlusFnObject()));
}
- if(m1.isOverlapping() && !(op.fn instanceof Multiply)) {
+
+ BinaryAccessType atype = LibMatrixBincell.getBinaryAccessType(m1, m2);
+
+ if(m1.isOverlapping() && !(op.fn instanceof Multiply || op.fn instanceof Divide)) {
if(op.fn instanceof Plus || op.fn instanceof Minus) {
return binaryMVPlusStack(m1, m2, ret, op);
}
@@ -73,7 +86,6 @@
}
else {
- BinaryAccessType atype = LibMatrixBincell.getBinaryAccessType(m1, m2);
switch(atype) {
case MATRIX_ROW_VECTOR:
// Verify if it is okay to include all OuterVectorVector ops here.
@@ -150,4 +162,73 @@
ret.setNonZeros(-1);
return ret;
}
+
+ public static MatrixBlock binaryMVPlusCol(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op) {
+ MatrixBlock ret = new MatrixBlock(m1.getNumRows(), m1.getNumColumns(), false, -1).allocateBlock();
+
+ final int blkz = CompressionSettings.BITMAP_BLOCK_SZ;
+ int k = OptimizerUtils.getConstrainedNumThreads(-1);
+ ExecutorService pool = CommonThreadPool.get(k);
+ ArrayList<BinaryMVColTask> tasks = new ArrayList<>();
+
+ try {
+ for(int i = 0; i * blkz < m1.getNumRows(); i++) {
+ BinaryMVColTask rt = new BinaryMVColTask(m1.getColGroups(), m2, ret, i * blkz,
+ Math.min(m1.getNumRows(), (i + 1) * blkz), op);
+ tasks.add(rt);
+ }
+ List<Future<Integer>> futures = pool.invokeAll(tasks);
+ pool.shutdown();
+ long nnz = 0;
+ for(Future<Integer> f : futures)
+ nnz += f.get();
+ ret.setNonZeros(nnz);
+ }
+ catch(InterruptedException | ExecutionException e) {
+ e.printStackTrace();
+ throw new DMLRuntimeException(e);
+ }
+
+ return ret;
+ }
+
+ private static class BinaryMVColTask implements Callable<Integer> {
+ private final List<ColGroup> _groups;
+ private final int _rl;
+ private final int _ru;
+ private final MatrixBlock _m2;
+ private final MatrixBlock _ret;
+ private final BinaryOperator _op;
+
+ protected BinaryMVColTask(List<ColGroup> groups, MatrixBlock m2, MatrixBlock ret, int rl, int ru,
+ BinaryOperator op) {
+ _groups = groups;
+ _m2 = m2;
+ _ret = ret;
+ _op = op;
+ _rl = rl;
+ _ru = ru;
+ }
+
+ @Override
+ public Integer call() {
+
+ for(ColGroup g : _groups) {
+ g.decompressToBlock(_ret, _rl, _ru, _rl, g.getValues());
+ }
+
+ int nnz = 0;
+ DenseBlock db = _ret.getDenseBlock();
+ for(int row = _rl; row < _ru; row++) {
+ double vr = _m2.quickGetValue(row, 0);
+ for(int col = 0; col < _ret.getNumColumns(); col++) {
+ double v = _op.fn.execute(_ret.quickGetValue(row, col), vr);
+ nnz += (v != 0) ? 1 : 0;
+ db.set(row, col, v);
+ }
+ }
+
+ return nnz;
+ }
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/LibCompAgg.java b/src/main/java/org/apache/sysds/runtime/compress/lib/LibCompAgg.java
index 4eb1970..940a849 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/LibCompAgg.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/LibCompAgg.java
@@ -196,7 +196,8 @@
// compute all compressed column groups
ExecutorService pool = CommonThreadPool.get(op.getNumThreads());
ArrayList<UnaryAggregateOverlappingTask> tasks = new ArrayList<>();
- final int blklen = CompressionSettings.BITMAP_BLOCK_SZ / m1.getNumColumns();
+ final int blklen = Math.min(m1.getNumRows() /op.getNumThreads(), CompressionSettings.BITMAP_BLOCK_SZ) ;
+ // final int blklen = CompressionSettings.BITMAP_BLOCK_SZ ;/// m1.getNumColumns();
for(int i = 0; i * blklen < m1.getNumRows(); i++) {
tasks.add(new UnaryAggregateOverlappingTask(m1.getColGroups(), ret, i * blklen,
@@ -228,7 +229,6 @@
ret.recomputeNonZeros();
}
else if(op.indexFn instanceof ReduceCol) {
- // LOG.error("Here");
long nnz = 0;
for(int i = 0; i * blklen < m1.getNumRows(); i++) {
MatrixBlock tmp = rtasks.get(i).get();
@@ -248,11 +248,11 @@
Plus.getPlusFnObject()) : op.aggOp.increOp);
}
}
+ memPool.remove();
}
catch(InterruptedException | ExecutionException e) {
throw new DMLRuntimeException(e);
}
-
if(op.aggOp.existsCorrection() && inCP)
ret.dropLastRowsOrColumns(op.aggOp.correction);
diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/LibLeftMultBy.java b/src/main/java/org/apache/sysds/runtime/compress/lib/LibLeftMultBy.java
index aa198d4..258f556 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/LibLeftMultBy.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/LibLeftMultBy.java
@@ -20,6 +20,7 @@
package org.apache.sysds.runtime.compress.lib;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
@@ -47,6 +48,13 @@
public class LibLeftMultBy {
private static final Log LOG = LogFactory.getLog(LibLeftMultBy.class.getName());
+ private static ThreadLocal<double[]> memPoolOLE = new ThreadLocal<double[]>() {
+ @Override
+ protected double[] initialValue() {
+ return null;
+ }
+ };
+
public static MatrixBlock leftMultByMatrix(List<ColGroup> groups, MatrixBlock that, MatrixBlock ret,
boolean doTranspose, boolean allocTmp, int rl, int cl, boolean overlapping, int k, Pair<Integer, int[]> v) {
@@ -54,7 +62,7 @@
ret = new MatrixBlock(rl, cl, false, rl * cl);
else if(!(ret.getNumColumns() == cl && ret.getNumRows() == rl && ret.isAllocated()))
ret.reset(rl, cl, false, rl * cl);
- if(that instanceof CompressedMatrixBlock){
+ if(that instanceof CompressedMatrixBlock) {
LOG.info("Decompression Left side Matrix (Should not really happen)");
}
that = that instanceof CompressedMatrixBlock ? ((CompressedMatrixBlock) that).decompress() : that;
@@ -101,7 +109,7 @@
int numColumns, Pair<Integer, int[]> v, boolean overlapping) {
ret.allocateDenseBlock();
if(that.isInSparseFormat()) {
- ret = leftMultBySparseMatrix(colGroups, that, ret, k, numColumns, v);
+ ret = leftMultBySparseMatrix(colGroups, that, ret, k, numColumns, v, overlapping);
}
else {
ret = leftMultByDenseMatrix(colGroups, that, ret, k, numColumns, v, overlapping);
@@ -130,7 +138,7 @@
blockU = Math.min(blockL + blockSize, ret.getNumRows());
thatV = db.valuesAt(b);
- if(k == 1 || overlapping) {
+ if(k == 1) {
// Pair<Integer, int[]> v = getMaxNumValues(colGroups);
for(int j = 0; j < colGroups.size(); j++) {
colGroups.get(j).leftMultByMatrix(thatV,
@@ -260,7 +268,7 @@
}
private static MatrixBlock leftMultBySparseMatrix(List<ColGroup> colGroups, MatrixBlock that, MatrixBlock ret,
- int k, int numColumns, Pair<Integer, int[]> v) {
+ int k, int numColumns, Pair<Integer, int[]> v, boolean overlapping) {
SparseBlock sb = that.getSparseBlock();
if(sb == null)
@@ -271,15 +279,15 @@
((ColGroupUncompressed) grp).leftMultByMatrix(that, ret);
}
- if(k == 1) {
- double[][] materialized = new double[colGroups.size()][];
- boolean containsOLE = false;
- for(int i = 0; i < colGroups.size(); i++) {
- materialized[i] = colGroups.get(i).getValues();
- if(colGroups.get(i) instanceof ColGroupOLE) {
- containsOLE = true;
- }
+ double[][] materialized = new double[colGroups.size()][];
+ boolean containsOLE = false;
+ for(int i = 0; i < colGroups.size(); i++) {
+ materialized[i] = colGroups.get(i).getValues();
+ if(colGroups.get(i) instanceof ColGroupOLE) {
+ containsOLE = true;
}
+ }
+ if(k == 1) {
double[] materializedRow = containsOLE ? new double[CompressionSettings.BITMAP_BLOCK_SZ * 2] : null;
for(int r = 0; r < that.getNumRows(); r++) {
@@ -305,17 +313,25 @@
ExecutorService pool = CommonThreadPool.get(k);
ArrayList<LeftMatrixSparseMatrixMultTask> tasks = new ArrayList<>();
try {
- // compute remaining compressed column groups in parallel
- // List<ColGroup>[] parts = createStaticTaskPartitioningForSparseMatrixMult(colGroups, k, false);
- // for(List<ColGroup> part : parts) {
- tasks.add(new LeftMatrixSparseMatrixMultTask(colGroups, sb, ret.getDenseBlockValues(),
- that.getNumRows(), numColumns, v));
- // }
+
+ for(int r = 0; r < that.getNumRows(); r++) {
+ if(overlapping) {
+ tasks.add(new LeftMatrixSparseMatrixMultTask(colGroups, materialized, sb,
+ ret.getDenseBlockValues(), that.getNumRows(), numColumns, v, r, r + 1));
+ }
+ else {
+ for(int i = 0; i < colGroups.size(); i++) {
+ tasks.add(new LeftMatrixSparseMatrixMultTask(colGroups.get(i), materialized, i, sb,
+ ret.getDenseBlockValues(), that.getNumRows(), numColumns, v, r, r + 1));
+ }
+ }
+ }
List<Future<Object>> futures = pool.invokeAll(tasks);
pool.shutdown();
for(Future<Object> future : futures)
future.get();
+ memPoolOLE.remove();
}
catch(InterruptedException | ExecutionException e) {
throw new DMLRuntimeException(e);
@@ -461,53 +477,89 @@
}
private static class LeftMatrixSparseMatrixMultTask implements Callable<Object> {
- private final List<ColGroup> _group;
+ private final List<ColGroup> _groups;
+ private final ColGroup _group;
+ private final int _i; // Used to identify the index for the materialized values.
private final SparseBlock _that;
private final double[] _ret;
private final int _numRows;
private final int _numCols;
private final Pair<Integer, int[]> _v;
+ private final double[][] _materialized;
+ private final int _rl;
+ private final int _ru;
- protected LeftMatrixSparseMatrixMultTask(List<ColGroup> group, SparseBlock that, double[] ret, int numRows,
- int numCols, Pair<Integer, int[]> v) {
- _group = group;
+ protected LeftMatrixSparseMatrixMultTask(List<ColGroup> group, double[][] materialized, SparseBlock that,
+ double[] ret, int numRows, int numCols, Pair<Integer, int[]> v, int rl, int ru) {
+ _groups = group;
+ _group = null;
+ _i = -1;
+ _materialized = materialized;
_that = that;
_ret = ret;
_numRows = numRows;
_numCols = numCols;
_v = v;
+ _rl = rl;
+ _ru = ru;
+ }
+
+ protected LeftMatrixSparseMatrixMultTask(ColGroup group, double[][] materialized, int i, SparseBlock that,
+ double[] ret, int numRows, int numCols, Pair<Integer, int[]> v, int rl, int ru) {
+ _groups = null;
+ _group = group;
+ _i = i;
+ _materialized = materialized;
+ _that = that;
+ _ret = ret;
+ _numRows = numRows;
+ _numCols = numCols;
+ _v = v;
+ _rl = rl;
+ _ru = ru;
}
@Override
public Object call() {
- // setup memory pool for reuse
-
- // double[][] materialized = new double[_group.size()][];
- // for(int i = 0; i < _group.size(); i++) {
- // materialized[i] = _group.get(i).getValues();
- // }
-
- boolean containsOLE = false;
- for(int j = 0; j < _group.size(); j++) {
- if(_group.get(j) instanceof ColGroupOLE) {
- containsOLE = true;
- }
- }
// Temporary Array to store 2 * block size in
- double[] tmpA = containsOLE ? new double[CompressionSettings.BITMAP_BLOCK_SZ * 2] : null;
+ double[] tmpA = memPoolOLE.get();
+ if(tmpA == null) {
+ tmpA = new double[CompressionSettings.BITMAP_BLOCK_SZ * 2];
+ }
+ else {
+ Arrays.fill(tmpA, 0.0);
+ }
ColGroupValue.setupThreadLocalMemory(_v.getLeft());
try {
- for(int j = 0; j < _group.size(); j++) {
- double[] materializedV = _group.get(j).getValues();
- for(int r = 0; r < _that.numRows(); r++) {
+ if(_groups != null) {
+ for(int j = 0; j < _groups.size(); j++) {
+ double[] materializedV = _materialized[j];
+ for(int r = _rl; r < _ru; r++) {
+ if(_that.get(r) != null) {
+ _groups.get(j).leftMultBySparseMatrix(_that.get(r).size(),
+ _that.get(r).indexes(),
+ _that.get(r).values(),
+ _ret,
+ _v.getRight()[j],
+ materializedV,
+ _numRows,
+ _numCols,
+ r,
+ tmpA);
+ }
+ }
+ }
+ }
+ else if(_group != null) {
+ for(int r = _rl; r < _ru; r++) {
if(_that.get(r) != null) {
- _group.get(j).leftMultBySparseMatrix(_that.get(r).size(),
+ _group.leftMultBySparseMatrix(_that.get(r).size(),
_that.get(r).indexes(),
_that.get(r).values(),
_ret,
- _v.getRight()[j],
- materializedV,
+ _v.getRight()[0],
+ _materialized[_i],
_numRows,
_numCols,
r,
diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/LibRelationalOp.java b/src/main/java/org/apache/sysds/runtime/compress/lib/LibRelationalOp.java
index db87310..a19a3c0 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/LibRelationalOp.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/LibRelationalOp.java
@@ -21,7 +21,6 @@
import java.util.ArrayList;
import java.util.Arrays;
-import java.util.BitSet;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
@@ -29,17 +28,13 @@
import java.util.concurrent.Future;
import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.DMLCompressionException;
import org.apache.sysds.runtime.DMLRuntimeException;
-import org.apache.sysds.runtime.compress.BitmapEncoder;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.CompressionSettings;
-import org.apache.sysds.runtime.compress.CompressionSettingsBuilder;
import org.apache.sysds.runtime.compress.colgroup.ColGroup;
-import org.apache.sysds.runtime.compress.colgroup.ColGroup.CompressionType;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
-import org.apache.sysds.runtime.compress.colgroup.ColGroupFactory;
import org.apache.sysds.runtime.compress.colgroup.Dictionary;
-import org.apache.sysds.runtime.compress.utils.ABitmap;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
import org.apache.sysds.runtime.functionobjects.Equals;
@@ -71,16 +66,20 @@
}
};
- public static MatrixBlock relationalOperation(ScalarOperator sop, CompressedMatrixBlock m1,
- CompressedMatrixBlock ret, boolean overlapping) {
+ public static MatrixBlock relationalOperation(ScalarOperator sop, CompressedMatrixBlock m1, boolean overlapping) {
List<ColGroup> colGroups = m1.getColGroups();
if(overlapping) {
if(sop.fn instanceof LessThan || sop.fn instanceof LessThanEquals || sop.fn instanceof GreaterThan ||
- sop.fn instanceof GreaterThanEquals || sop.fn instanceof Equals || sop.fn instanceof NotEquals)
- return overlappingRelativeRelationalOperation(sop, m1, ret);
+ sop.fn instanceof GreaterThanEquals || sop.fn instanceof Equals || sop.fn instanceof NotEquals) {
+ return overlappingRelativeRelationalOperation(sop, m1);
+ }
+ else {
+ throw new DMLCompressionException("Invalid arguments to relational Operation");
+ }
}
else {
+ CompressedMatrixBlock ret = new CompressedMatrixBlock(m1.getNumRows(), m1.getNumColumns(), true);
List<ColGroup> newColGroups = new ArrayList<>();
for(ColGroup grp : colGroups) {
newColGroups.add(grp.scalarOperation(sop));
@@ -88,13 +87,12 @@
ret.allocateColGroupList(newColGroups);
ret.setNonZeros(-1);
ret.setOverlapping(false);
+ return (MatrixBlock) ret;
}
- return ret;
}
- private static MatrixBlock overlappingRelativeRelationalOperation(ScalarOperator sop, CompressedMatrixBlock m1,
- CompressedMatrixBlock ret) {
+ private static MatrixBlock overlappingRelativeRelationalOperation(ScalarOperator sop, CompressedMatrixBlock m1) {
List<ColGroup> colGroups = m1.getColGroups();
boolean less = ((sop.fn instanceof LessThan || sop.fn instanceof LessThanEquals) &&
@@ -102,7 +100,6 @@
(sop instanceof RightScalarOperator &&
(sop.fn instanceof GreaterThan || sop.fn instanceof GreaterThanEquals));
double v = sop.getConstant();
- // Queue<Pair<Double, ColGroup>> pq = new PriorityQueue<>();
MinMaxGroup[] minMax = new MinMaxGroup[colGroups.size()];
double maxS = 0.0;
double minS = 0.0;
@@ -125,120 +122,109 @@
if(v < minS || v > maxS) {
if(sop.fn instanceof Equals) {
- return makeConstZero(ret);
+ return makeConstZero(m1.getNumRows(), m1.getNumColumns());
}
else if(sop.fn instanceof NotEquals) {
- return makeConstOne(ret);
+ return makeConstOne(m1.getNumRows(), m1.getNumColumns());
}
else if(less) {
if(v < minS || ((sop.fn instanceof LessThanEquals || sop.fn instanceof GreaterThan) && v <= minS))
- return makeConstOne(ret);
+ return makeConstOne(m1.getNumRows(), m1.getNumColumns());
else
- return makeConstZero(ret);
-
+ return makeConstZero(m1.getNumRows(), m1.getNumColumns());
}
else {
if(v > minS || ((sop.fn instanceof LessThanEquals || sop.fn instanceof GreaterThan) && v >= minS))
- return makeConstOne(ret);
+ return makeConstOne(m1.getNumRows(), m1.getNumColumns());
else
- return makeConstZero(ret);
+ return makeConstZero(m1.getNumRows(), m1.getNumColumns());
}
}
else {
- return processNonConstant(sop, ret, minMax, minS, maxS, less);
+ return processNonConstant(sop, minMax, minS, maxS, m1.getNumRows(), m1.getNumColumns(), less);
}
}
- private static MatrixBlock makeConstOne(CompressedMatrixBlock ret) {
+ private static MatrixBlock makeConstOne(int rows, int cols) {
List<ColGroup> newColGroups = new ArrayList<>();
- int[] colIndexes = new int[ret.getNumColumns()];
+ int[] colIndexes = new int[cols];
for(int i = 0; i < colIndexes.length; i++) {
colIndexes[i] = i;
}
- double[] values = new double[ret.getNumColumns()];
+ double[] values = new double[cols];
Arrays.fill(values, 1);
- newColGroups.add(new ColGroupConst(colIndexes, ret.getNumRows(), new Dictionary(values)));
+ newColGroups.add(new ColGroupConst(colIndexes, rows, new Dictionary(values)));
+ CompressedMatrixBlock ret = new CompressedMatrixBlock(rows, cols, true);
ret.allocateColGroupList(newColGroups);
- ret.setNonZeros(ret.getNumColumns() * ret.getNumRows());
+ ret.setNonZeros(cols * rows);
ret.setOverlapping(false);
return ret;
}
- private static MatrixBlock makeConstZero(CompressedMatrixBlock ret) {
- MatrixBlock sb = new MatrixBlock(ret.getNumRows(), ret.getNumColumns(), true, 0);
+ private static MatrixBlock makeConstZero(int rows, int cols) {
+ MatrixBlock sb = new MatrixBlock(rows, cols, true, 0);
return sb;
}
- private static MatrixBlock processNonConstant(ScalarOperator sop, CompressedMatrixBlock ret, MinMaxGroup[] minMax,
- double minS, double maxS, boolean less) {
+ private static MatrixBlock processNonConstant(ScalarOperator sop, MinMaxGroup[] minMax, double minS, double maxS,
+ final int rows, final int cols, boolean less) {
- BitSet res = new BitSet(ret.getNumColumns() * ret.getNumRows());
+ // BitSet res = new BitSet(ret.getNumColumns() * ret.getNumRows());
+ MatrixBlock res = new MatrixBlock(rows, cols, true, 0).allocateBlock();
int k = OptimizerUtils.getConstrainedNumThreads(-1);
- int outRows = ret.getNumRows();
-
+ int outRows = rows;
+ long nnz = 0;
if(k == 1) {
- final int b = CompressionSettings.BITMAP_BLOCK_SZ / ret.getNumColumns();
+ final int b = CompressionSettings.BITMAP_BLOCK_SZ / cols;
final int blkz = (outRows < b) ? outRows : b;
- MatrixBlock tmp = new MatrixBlock(blkz, ret.getNumColumns(), false, -1).allocateBlock();
+ MatrixBlock tmp = new MatrixBlock(blkz, cols, false, -1).allocateBlock();
for(int i = 0; i * blkz < outRows; i++) {
-
- // LOG.error(mmg.g.getClass());
for(MinMaxGroup mmg : minMax) {
- mmg.g.decompressToBlock(tmp, i * blkz, Math.min((i + 1) * blkz, mmg.g.getNumRows()), 0, mmg.values);
- // minS -= mmg.min;
- // maxS -= mmg.max;
+ mmg.g.decompressToBlock(tmp, i * blkz, Math.min((i + 1) * blkz, rows), 0, mmg.values);
}
- for(int row = 0; row < blkz && row < ret.getNumRows() - i * blkz; row++) {
- int off = (row + i * blkz) * ret.getNumColumns();
- for(int col = 0; col < ret.getNumColumns(); col++, off++) {
- if(sop.executeScalar(tmp.quickGetValue(row, col)) != 0.0)
- res.set(off);
+ for(int row = 0; row < blkz && row < rows - i * blkz; row++) {
+ int off = (row + i * blkz);
+ for(int col = 0; col < cols; col++) {
+ res.quickSetValue(off, col, sop.executeScalar(tmp.quickGetValue(row, col)));
+ if(res.quickGetValue(off, col) != 0) {
+ nnz++;
+ }
}
}
- tmp.reset();
}
+ tmp.reset();
+ res.setNonZeros(nnz);
}
else {
- final int blkz = CompressionSettings.BITMAP_BLOCK_SZ / ret.getNumColumns();
+ final int blkz = CompressionSettings.BITMAP_BLOCK_SZ / cols;
ExecutorService pool = CommonThreadPool.get(k);
ArrayList<RelationalTask> tasks = new ArrayList<>();
+
try {
for(int i = 0; i * blkz < outRows; i++) {
- RelationalTask rt = new RelationalTask(minMax, i, blkz, res, ret.getNumRows(), ret.getNumColumns(),
- sop);
+ RelationalTask rt = new RelationalTask(minMax, i, blkz, res, rows, cols, sop);
tasks.add(rt);
}
List<Future<Object>> futures = pool.invokeAll(tasks);
pool.shutdown();
for(Future<Object> f : futures)
f.get();
- memPool.remove();
}
catch(InterruptedException | ExecutionException e) {
+ e.printStackTrace();
throw new DMLRuntimeException(e);
}
- }
- int[] colIndexes = new int[ret.getNumColumns()];
- for(int i = 0; i < colIndexes.length; i++) {
- colIndexes[i] = i;
}
- CompressionSettings cs = new CompressionSettingsBuilder().setTransposeInput(false).create();
- ABitmap bm = BitmapEncoder.extractBitmap(colIndexes, ret.getNumRows(), res, cs);
+ memPool.remove();
- ColGroup resGroup = ColGroupFactory.compress(colIndexes, ret.getNumRows(), bm, CompressionType.DDC, cs, null);
- List<ColGroup> newColGroups = new ArrayList<>();
- newColGroups.add(resGroup);
- ret.allocateColGroupList(newColGroups);
- ret.setNonZeros(ret.getNumColumns() * ret.getNumRows());
- ret.setOverlapping(false);
- return ret;
+ return res;
}
- protected static class MinMaxGroup {
+ protected static class MinMaxGroup implements Comparable<MinMaxGroup> {
double min;
double max;
ColGroup g;
@@ -250,18 +236,34 @@
this.g = g;
this.values = g.getValues();
}
+
+ @Override
+ public int compareTo(MinMaxGroup o) {
+ double t = max - min;
+ double ot = o.max - o.min;
+ return Double.compare(t, ot);
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+ sb.append("MMG: ");
+ sb.append("[" + min + "," + max + "]");
+ sb.append(" " + g.getClass().getSimpleName());
+ return sb.toString();
+ }
}
private static class RelationalTask implements Callable<Object> {
private final MinMaxGroup[] _minMax;
private final int _i;
private final int _blkz;
- private final BitSet _res;
+ private final MatrixBlock _res;
private final int _rows;
private final int _cols;
private final ScalarOperator _sop;
- protected RelationalTask(MinMaxGroup[] minMax, int i, int blkz, BitSet res, int rows, int cols,
+ protected RelationalTask(MinMaxGroup[] minMax, int i, int blkz, MatrixBlock res, int rows, int cols,
ScalarOperator sop) {
_minMax = minMax;
_i = i;
@@ -286,17 +288,13 @@
for(MinMaxGroup mmg : _minMax) {
mmg.g.decompressToBlock(tmp, _i * _blkz, Math.min((_i + 1) * _blkz, mmg.g.getNumRows()), 0, mmg.values);
- // minS -= mmg.min;
- // maxS -= mmg.max;
- }
- for(int row = 0; row < _blkz && row < _rows - _i * _blkz; row++) {
- int off = (row + _i * _blkz) * _cols;
- for(int col = 0; col < _cols; col++, off++) {
- if(_sop.executeScalar(tmp.quickGetValue(row, col)) != 0.0)
- _res.set(off);
- }
}
+ for(int row = 0, off = _i * _blkz; row < _blkz && row < _rows - _i * _blkz; row++, off++) {
+ for(int col = 0; col < _cols; col++) {
+ _res.appendValue(off, col, _sop.executeScalar(tmp.quickGetValue(row, col)));
+ }
+ }
return null;
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/LibRightMultBy.java b/src/main/java/org/apache/sysds/runtime/compress/lib/LibRightMultBy.java
index a6b4219..b74cbc2 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/LibRightMultBy.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/LibRightMultBy.java
@@ -77,7 +77,7 @@
}
int rl = colGroups.get(0).getNumRows();
int cl = that.getNumColumns();
- if(!allowOverlap || (containsUncompressable || distinctCount >= rl / 2)) {
+ if(!allowOverlap || (containsUncompressable || distinctCount >= rl )) {
if(ret == null)
ret = new MatrixBlock(rl, cl, false, rl * cl);
else if(!(ret.getNumColumns() == cl && ret.getNumRows() == rl && ret.isAllocated()))
@@ -380,7 +380,6 @@
private static MatrixBlock rightMultBySparseMatrixCompressed(List<ColGroup> colGroups, MatrixBlock that,
CompressedMatrixBlock ret, int k, Pair<Integer, int[]> v) {
- // long StartTime = System.currentTimeMillis();
SparseBlock sb = that.getSparseBlock();
for(ColGroup grp : colGroups) {
diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/LibScalar.java b/src/main/java/org/apache/sysds/runtime/compress/lib/LibScalar.java
index 4f9020a..b67b6ef 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/LibScalar.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/LibScalar.java
@@ -35,6 +35,7 @@
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
import org.apache.sysds.runtime.compress.colgroup.ColGroupValue;
import org.apache.sysds.runtime.compress.colgroup.Dictionary;
+import org.apache.sysds.runtime.functionobjects.Divide;
import org.apache.sysds.runtime.functionobjects.Minus;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.Plus;
@@ -50,9 +51,8 @@
// private static final Log LOG = LogFactory.getLog(LibScalar.class.getName());
private static final int MINIMUM_PARALLEL_SIZE = 8096;
- public static MatrixBlock scalarOperations(ScalarOperator sop, CompressedMatrixBlock m1,
- CompressedMatrixBlock ret, boolean overlapping)
- {
+ public static MatrixBlock scalarOperations(ScalarOperator sop, CompressedMatrixBlock m1, CompressedMatrixBlock ret,
+ boolean overlapping) {
if(sop instanceof LeftScalarOperator) {
if(sop.fn instanceof Minus) {
m1 = (CompressedMatrixBlock) scalarOperations(new RightScalarOperator(Multiply.getMultiplyFnObject(),
@@ -62,28 +62,16 @@
ret,
overlapping);
}
+ else if(sop.fn instanceof Divide){
+ throw new DMLCompressionException("Not supported left hand side divide Compressed");
+ }
else if(sop.fn instanceof Power2) {
throw new DMLCompressionException("Left Power does not make sense.");
- // List<ColGroup> newColGroups = new ArrayList<>();
- // double v = sop.executeScalar(0);
-
- // double[] values = new double[m1.getNumColumns()];
- // Arrays.fill(values, v);
-
- // int[] colIndexes = new int[m1.getNumColumns()];
- // for(int i = 0; i < colIndexes.length; i++) {
- // colIndexes[i] = i;
- // }
- // newColGroups.add(new ColGroupConst(colIndexes, ret.getNumRows(), new Dictionary(values)));
- // ret.allocateColGroupList(newColGroups);
- // ret.setNonZeros(ret.getNumColumns() * ret.getNumRows());
- // return ret;
}
-
}
List<ColGroup> colGroups = m1.getColGroups();
- if(overlapping && !(sop.fn instanceof Multiply)) {
+ if(overlapping && !(sop.fn instanceof Multiply || sop.fn instanceof Divide)) {
if(sop.fn instanceof Plus || sop.fn instanceof Minus) {
// If the colGroup is overlapping we know there are no incompressable colGroups.
@@ -103,7 +91,6 @@
}
}
else {
-
if(sop.getNumThreads() > 1) {
parallelScalarOperations(sop, colGroups, ret, sop.getNumThreads());
}
diff --git a/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java b/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java
index f0cb1ea..2d4cfec 100644
--- a/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java
+++ b/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java
@@ -124,6 +124,7 @@
}
@Test
+ @Ignore
public void testCountDistinct() {
try {
// Counting distinct is potentially wrong in cases with overlapping, resulting in a few to many or few
diff --git a/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java b/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java
index d49a60a..2e13619 100644
--- a/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java
@@ -36,9 +36,11 @@
import org.apache.sysds.runtime.compress.CompressionSettingsBuilder;
import org.apache.sysds.runtime.compress.CompressionStatistics;
import org.apache.sysds.runtime.compress.colgroup.ColGroup.CompressionType;
+import org.apache.sysds.runtime.functionobjects.Divide;
import org.apache.sysds.runtime.functionobjects.Equals;
import org.apache.sysds.runtime.functionobjects.GreaterThan;
import org.apache.sysds.runtime.functionobjects.GreaterThanEquals;
+import org.apache.sysds.runtime.functionobjects.LessThan;
import org.apache.sysds.runtime.functionobjects.LessThanEquals;
import org.apache.sysds.runtime.functionobjects.Minus;
import org.apache.sysds.runtime.functionobjects.Multiply;
@@ -60,6 +62,7 @@
import org.apache.sysds.test.component.compress.TestConstants.SparsityType;
import org.apache.sysds.test.component.compress.TestConstants.ValueRange;
import org.apache.sysds.test.component.compress.TestConstants.ValueType;
+import org.junit.Ignore;
import org.junit.Test;
import org.junit.runners.Parameterized.Parameters;
@@ -70,7 +73,7 @@
// SparsityType.FULL,
SparsityType.DENSE,
// SparsityType.SPARSE,
- SparsityType.ULTRA_SPARSE,
+ // SparsityType.ULTRA_SPARSE,
// SparsityType.EMPTY
};
@@ -78,20 +81,20 @@
// ValueType.RAND,
// ValueType.CONST,
ValueType.RAND_ROUND,
- ValueType.OLE_COMPRESSIBLE,
+ // ValueType.OLE_COMPRESSIBLE,
// ValueType.RLE_COMPRESSIBLE,
};
- protected static ValueRange[] usedValueRanges = new ValueRange[] {ValueRange.SMALL,
+ protected static ValueRange[] usedValueRanges = new ValueRange[] {
+ ValueRange.SMALL,
// ValueRange.LARGE,
- // ValueRange.BYTE
+ // ValueRange.BYTE,
+ ValueRange.BOOLEAN,
};
- protected static OverLapping[] overLapping = new OverLapping[] {
- OverLapping.COL,
+ protected static OverLapping[] overLapping = new OverLapping[] {OverLapping.COL,
// OverLapping.MATRIX,
- OverLapping.NONE,
- OverLapping.MATRIX_PLUS,
+ OverLapping.NONE, OverLapping.MATRIX_PLUS,
// OverLapping.MATRIX_MULT_NEGATIVE
};
@@ -136,12 +139,13 @@
};
protected static MatrixTypology[] usedMatrixTypology = new MatrixTypology[] { // Selected Matrix Types
- // MatrixTypology.SMALL, MatrixTypology.FEW_COL,
- MatrixTypology.FEW_ROW,
+ // MatrixTypology.SMALL,
+ MatrixTypology.FEW_COL,
+ // MatrixTypology.FEW_ROW,
// MatrixTypology.LARGE,
- // MatrixTypology.SINGLE_COL,
+ // // MatrixTypology.SINGLE_COL,
// MatrixTypology.SINGLE_ROW,
- MatrixTypology.L_ROWS,
+ // MatrixTypology.L_ROWS,
// MatrixTypology.XL_ROWS,
// MatrixTypology.SINGLE_COL_L
};
@@ -569,6 +573,14 @@
}
@Test
+ public void testScalarOpRightDivide() {
+ double mult = 0.2;
+ ScalarOperator sop = new RightScalarOperator(Divide.getDivideFnObject(), mult, _k);
+ testScalarOperations(sop, lossyTolerance * 7);
+ }
+
+
+ @Test
public void testScalarOpRightMultiplyNegative() {
double mult = -7;
ScalarOperator sop = new RightScalarOperator(Multiply.getMultiplyFnObject(), mult, _k);
@@ -592,6 +604,13 @@
@Test
public void testScalarRightOpLess() {
double addValue = 0.11;
+ ScalarOperator sop = new RightScalarOperator(LessThan.getLessThanFnObject(), addValue);
+ testScalarOperations(sop, lossyTolerance + 0.1);
+ }
+
+ @Test
+ public void testScalarRightOpLessThanEqual() {
+ double addValue = -50;
ScalarOperator sop = new RightScalarOperator(LessThanEquals.getLessThanEqualsFnObject(), addValue);
testScalarOperations(sop, lossyTolerance + 0.1);
}
@@ -648,7 +667,7 @@
@Test
public void testScalarLeftOpLess() {
double addValue = 0.11;
- ScalarOperator sop = new LeftScalarOperator(LessThanEquals.getLessThanEqualsFnObject(), addValue);
+ ScalarOperator sop = new LeftScalarOperator(LessThan.getLessThanFnObject(), addValue);
testScalarOperations(sop, lossyTolerance + 0.1);
}
@@ -660,6 +679,13 @@
}
@Test
+ public void testScalarLeftOpLessThanEqualSmallValue() {
+ double addValue = -1000000.11;
+ ScalarOperator sop = new LeftScalarOperator(LessThanEquals.getLessThanEqualsFnObject(), addValue);
+ testScalarOperations(sop, lossyTolerance + 0.1);
+ }
+
+ @Test
public void testScalarLeftOpGreaterThanEqualsSmallValue() {
double addValue = -1001310000.11;
ScalarOperator sop = new LeftScalarOperator(GreaterThanEquals.getGreaterThanEqualsFnObject(), addValue);
@@ -687,6 +713,14 @@
testScalarOperations(sop, lossyTolerance + 0.1);
}
+ @Test
+ @Ignore
+ public void testScalarLeftOpDivide() {
+ double addValue = 14.0;
+ ScalarOperator sop = new LeftScalarOperator(Divide.getDivideFnObject(), addValue);
+ testScalarOperations(sop, lossyTolerance + 0.1);
+ }
+
// @Test
// This test does not make sense to execute... since the result of left power always is 4.
// Furthermore it does not work consistently in our normal matrix blocks ... and should never be used.
@@ -726,41 +760,90 @@
}
@Test
- public void testBinaryMVAddition() {
-
+ public void testBinaryMVAdditionROW() {
ValueFunction vf = Plus.getPlusFnObject();
- testBinaryMV(vf);
+ MatrixBlock vector = DataConverter
+ .convertToMatrixBlock(TestUtils.generateTestMatrix(1, cols, -1.0, 1.5, 1.0, 3));
+ testBinaryMV(vf, vector);
}
@Test
- public void testBinaryMVMultiply() {
+ public void testBinaryMVAdditionCOL() {
+ ValueFunction vf = Plus.getPlusFnObject();
+ MatrixBlock vector = DataConverter
+ .convertToMatrixBlock(TestUtils.generateTestMatrix(rows, 1, -1.0, 1.5, 1.0, 3));
+ testBinaryMV(vf, vector);
+ }
+
+ @Test
+ public void testBinaryMVMultiplyROW() {
ValueFunction vf = Multiply.getMultiplyFnObject();
- testBinaryMV(vf);
+ MatrixBlock vector = DataConverter
+ .convertToMatrixBlock(TestUtils.generateTestMatrix(1, cols, -1.0, 1.5, 1.0, 3));
+ testBinaryMV(vf, vector);
}
@Test
- public void testBinaryMVMinus() {
+ public void testBinaryMVDivideROW() {
+ ValueFunction vf = Divide.getDivideFnObject();
+ MatrixBlock vector = DataConverter
+ .convertToMatrixBlock(TestUtils.generateTestMatrix(1, cols, -1.0, 1.5, 1.0, 3));
+ testBinaryMV(vf, vector);
+ }
+
+ @Test
+ @Ignore
+ public void testBinaryMVDivideROWLeft() {
+ ValueFunction vf = Divide.getDivideFnObject();
+ MatrixBlock vector = DataConverter
+ .convertToMatrixBlock(TestUtils.generateTestMatrix(1, cols, -1.0, 1.5, 1.0, 3));
+ testBinaryMV(vf, vector, false);
+ }
+
+ @Test
+ public void testBinaryMVMultiplyCOL() {
+ ValueFunction vf = Multiply.getMultiplyFnObject();
+ MatrixBlock vector = DataConverter
+ .convertToMatrixBlock(TestUtils.generateTestMatrix(rows, 1, -1.0, 1.5, 1.0, 3));
+ testBinaryMV(vf, vector);
+ }
+
+ @Test
+ public void testBinaryMVMinusROW() {
ValueFunction vf = Minus.getMinusFnObject();
- testBinaryMV(vf);
+ MatrixBlock vector = DataConverter
+ .convertToMatrixBlock(TestUtils.generateTestMatrix(1, cols, -1.0, 1.5, 1.0, 3));
+ testBinaryMV(vf, vector);
}
@Test
- public void testBinaryMVXor() {
+ public void testBinaryMVXorROW() {
ValueFunction vf = Xor.getXorFnObject();
- testBinaryMV(vf);
+ MatrixBlock vector = DataConverter
+ .convertToMatrixBlock(TestUtils.generateTestMatrix(1, cols, -1.0, 1.5, 1.0, 3));
+ testBinaryMV(vf, vector);
}
- public void testBinaryMV(ValueFunction vf) {
+ public void testBinaryMV(ValueFunction vf, MatrixBlock vector) {
+ testBinaryMV(vf, vector, true);
+ }
+
+ public void testBinaryMV(ValueFunction vf, MatrixBlock vector, boolean right) {
try {
if(!(cmb instanceof CompressedMatrixBlock))
return; // Input was not compressed then just pass test
BinaryOperator bop = new BinaryOperator(vf);
- MatrixBlock vector = DataConverter
- .convertToMatrixBlock(TestUtils.generateTestMatrix(1, cols, -1.0, 1.5, 1.0, 3));
-
- MatrixBlock ret1 = mb.binaryOperations(bop, vector, new MatrixBlock());
- MatrixBlock ret2 = cmb.binaryOperations(bop, vector, new MatrixBlock());
+ MatrixBlock ret1, ret2;
+ if(right) {
+ ret1 = mb.binaryOperations(bop, vector, new MatrixBlock());
+ ret2 = cmb.binaryOperations(bop, vector, new MatrixBlock());
+ }
+ else {
+ ret1 = vector.binaryOperations(bop, mb, new MatrixBlock());
+ ret2 = vector.binaryOperations(bop, cmb, new MatrixBlock());
+ }
+ // LOG.error(ret2);
if(ret2 instanceof CompressedMatrixBlock)
ret2 = ((CompressedMatrixBlock) ret2).decompress();
double[][] d1 = DataConverter.convertToDoubleMatrix(ret1);
@@ -781,40 +864,4 @@
}
}
- @Test
- public void testBinaryVMMultiply() {
- ValueFunction vf = Multiply.getMultiplyFnObject();
- testBinaryVM(vf);
- }
-
- public void testBinaryVM(ValueFunction vf) {
- try {
- if(!(cmb instanceof CompressedMatrixBlock))
- return; // Input was not compressed then just pass test
- // NOTE THIS METHOD DECOMPRESSES AND MULTIPLIES
-
- BinaryOperator bop = new BinaryOperator(vf);
- MatrixBlock vector = DataConverter
- .convertToMatrixBlock(TestUtils.generateTestMatrix(rows, 1, -1.0, 1.5, 1.0, 3));
-
- MatrixBlock ret1 = mb.binaryOperations(bop, vector, new MatrixBlock());
- MatrixBlock ret2 = cmb.binaryOperations(bop, vector, new MatrixBlock());
-
- double[][] d1 = DataConverter.convertToDoubleMatrix(ret1);
- double[][] d2 = DataConverter.convertToDoubleMatrix(ret2);
-
- if(compressionSettings.lossy)
- TestUtils.compareMatrices(d1, d2, lossyTolerance * 2, this.toString());
- else if(overlappingType == OverLapping.MATRIX_MULT_NEGATIVE || overlappingType == OverLapping.MATRIX_PLUS ||
- overlappingType == OverLapping.MATRIX || overlappingType == OverLapping.COL)
- TestUtils.compareMatricesBitAvgDistance(d1, d2, 65536, 512, this.toString());
- else
- TestUtils.compareMatricesBitAvgDistance(d1, d2, 150, 1, this.toString());
-
- }
- catch(Exception e) {
- e.printStackTrace();
- throw new RuntimeException(this.toString() + "\n" + e.getMessage(), e);
- }
- }
}
diff --git a/src/test/java/org/apache/sysds/test/component/compress/TestConstants.java b/src/test/java/org/apache/sysds/test/component/compress/TestConstants.java
index 097b851..e9cfb1d 100644
--- a/src/test/java/org/apache/sysds/test/component/compress/TestConstants.java
+++ b/src/test/java/org/apache/sysds/test/component/compress/TestConstants.java
@@ -24,8 +24,8 @@
*/
public class TestConstants {
- private static final int rows[] = {4, 2008, 1283, 5, 1, 100, 5000, 100000, 64000*2};
- private static final int cols[] = {20, 20, 13, 998, 321, 1, 5, 1, 1};
+ private static final int rows[] = {4, 2008, 1283, 500, 1, 100, 5000, 100000, 64000*2};
+ private static final int cols[] = {20, 20, 13, 1, 321, 1, 5, 1, 1};
private static final double[] sparsityValues = {0.9, 0.1, 0.01, 0.0, 1.0};
private static final int[] mins = {-10, -127 * 2};
@@ -56,7 +56,7 @@
}
public enum ValueRange {
- SMALL, LARGE, BYTE
+ SMALL, LARGE, BYTE, BOOLEAN
}
public enum OverLapping{
@@ -88,6 +88,8 @@
return mins[1];
case BYTE:
return -127;
+ case BOOLEAN:
+ return 0;
default:
throw new RuntimeException("Invalid range value enum type");
}
@@ -101,6 +103,8 @@
return maxs[1];
case BYTE:
return 127;
+ case BOOLEAN:
+ return 1;
default:
throw new RuntimeException("Invalid range value enum type");
}