[SYSTEMDS-3575] Column Group get Compression Info
This commit allows one to extract the compression info,
from a column group. The implementation is basic and does not
consider if the user want to get information about different potential
compression plans for the individual column groups.
A follow up task is to extract more information to enable
morphing [SYSTEMDS-3578]
diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
index 3c4751e..c31fd69 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
@@ -227,6 +227,23 @@
}
/**
+ * Get the column group allocated and associated with a specific column Id;
+ *
+ * There is some search involved in this since we do not know where to look for the column and which Column group
+ * contains the value.
+ *
+ * @param id The column id or number we try to find
+ * @return The column group for that column
+ */
+ public AColGroup getColGroupForColumn(int id) {
+ for(AColGroup g : _colGroups) {
+ if(g.getColIndices().contains(id))
+ return g;
+ }
+ return null;
+ }
+
+ /**
* Decompress block into a MatrixBlock
*
* @param k degree of parallelism
@@ -409,7 +426,7 @@
AColGroup cg = ColGroupUncompressed.create(uncompressed);
allocateColGroup(cg);
// update non zeros, if not fully correct in compressed block
- nonZeros = cg.getNumberNonZeros(rlen);
+ nonZeros = cg.getNumberNonZeros(rlen);
// Clear the soft reference to the decompressed version,
// since the one column group is perfectly,
diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
index 1a7241a..8a6478f 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
@@ -457,6 +457,19 @@
private Pair<MatrixBlock, CompressionStatistics> recompress(CompressedMatrixBlock cmb) {
LOG.debug("Recompressing an already compressed MatrixBlock");
LOG.warn("Not Implemented Recompress yet");
+
+ classifyPhase();
+ // informationExtractor = ComEstFactory.createEstimator(mb, compSettings, k);
+
+ // compressionGroups = informationExtractor.computeCompressedSizeInfos(k);
+
+ // _stats.estimatedSizeCols = compressionGroups.memoryEstimate();
+ // _stats.estimatedCostCols = costEstimator.getCost(compressionGroups);
+
+ // logPhase();
+
+
+
return new ImmutablePair<>(cmb, null);
// _stats.originalSize = cmb.getInMemorySize();
// CompressedMatrixBlock combined = CLALibCombineGroups.combine(cmb, k);
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDC.java
index 24e187f..964dc08 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDC.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDC.java
@@ -22,6 +22,8 @@
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.colgroup.offset.AOffset;
+import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
+import org.apache.sysds.runtime.compress.estim.EstimationFactors;
/**
* Column group that sparsely encodes the dictionary values. The idea is that all values is encoded with indexes except
@@ -52,4 +54,12 @@
public AOffset getOffsets() {
return _indexes;
}
+
+ public abstract int getNumberOffsets();
+
+ @Override
+ public final CompressedSizeInfoColGroup getCompressionInfo(int nRow) {
+ EstimationFactors ef = new EstimationFactors(getNumValues(), _numRows, getNumberOffsets(), _dict.getSparsity());
+ return new CompressedSizeInfoColGroup(_colIndexes, ef, nRow, getCompType());
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java
index 6e63e3b..4ce3946 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java
@@ -24,6 +24,8 @@
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.colgroup.offset.AIterator;
import org.apache.sysds.runtime.compress.colgroup.offset.AOffset;
+import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
+import org.apache.sysds.runtime.compress.estim.EstimationFactors;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -216,8 +218,16 @@
return _indexes;
}
+ public abstract int getNumberOffsets();
+
@Override
public double[] getDefaultTuple() {
return new double[_colIndexes.size()];
}
+
+ @Override
+ public final CompressedSizeInfoColGroup getCompressionInfo(int nRow) {
+ EstimationFactors ef = new EstimationFactors(getNumValues(), _numRows, getNumberOffsets(), _dict.getSparsity());
+ return new CompressedSizeInfoColGroup(_colIndexes, ef, nRow, getCompType());
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java
index 544e61b..f885790 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java
@@ -38,6 +38,7 @@
import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme;
import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
+import org.apache.sysds.runtime.compress.estim.EstimationFactors;
import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory;
import org.apache.sysds.runtime.compress.estim.encoding.IEncode;
import org.apache.sysds.runtime.data.DenseBlock;
@@ -554,7 +555,8 @@
@Override
public CompressedSizeInfoColGroup getCompressionInfo(int nRow) {
- throw new NotImplementedException();
+ EstimationFactors ef = new EstimationFactors(getNumValues(), _data.size(), _data.size(), _dict.getSparsity());
+ return new CompressedSizeInfoColGroup(_colIndexes, ef, nRow, getCompType());
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java
index bb0abe8..39399b5 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java
@@ -36,6 +36,7 @@
import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme;
import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
+import org.apache.sysds.runtime.compress.estim.EstimationFactors;
import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory;
import org.apache.sysds.runtime.compress.estim.encoding.IEncode;
import org.apache.sysds.runtime.functionobjects.Builtin;
@@ -468,7 +469,8 @@
@Override
public CompressedSizeInfoColGroup getCompressionInfo(int nRow) {
- throw new NotImplementedException();
+ EstimationFactors ef = new EstimationFactors(getNumValues(), _data.size(), _data.size(), _dict.getSparsity());
+ return new CompressedSizeInfoColGroup(_colIndexes, ef, nRow, getCompType());
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
index 25e7d5b..37817d4 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
@@ -24,7 +24,6 @@
import java.io.IOException;
import java.util.Arrays;
-import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
@@ -41,7 +40,6 @@
import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory;
import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme;
import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator;
-import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory;
import org.apache.sysds.runtime.compress.estim.encoding.IEncode;
import org.apache.sysds.runtime.functionobjects.Builtin;
@@ -636,11 +634,6 @@
}
@Override
- public CompressedSizeInfoColGroup getCompressionInfo(int nRow) {
- throw new NotImplementedException();
- }
-
- @Override
public IEncode getEncoding() {
return EncodingFactory.create(_data, _indexes, _numRows);
}
@@ -666,6 +659,11 @@
}
@Override
+ public int getNumberOffsets() {
+ return _data.size();
+ }
+
+ @Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(super.toString());
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java
index edca22a..d05930d 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java
@@ -39,7 +39,6 @@
import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory;
import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme;
import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator;
-import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory;
import org.apache.sysds.runtime.compress.estim.encoding.IEncode;
import org.apache.sysds.runtime.functionobjects.Builtin;
@@ -498,11 +497,6 @@
}
@Override
- public CompressedSizeInfoColGroup getCompressionInfo(int nRow) {
- throw new NotImplementedException();
- }
-
- @Override
public IEncode getEncoding() {
return EncodingFactory.create(_data, _indexes, _numRows);
}
@@ -523,6 +517,11 @@
}
@Override
+ public int getNumberOffsets() {
+ return _data.size();
+ }
+
+ @Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(super.toString());
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
index d05adbc..f4c11f2 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
@@ -39,7 +39,6 @@
import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory;
import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme;
import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator;
-import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory;
import org.apache.sysds.runtime.compress.estim.encoding.IEncode;
import org.apache.sysds.runtime.functionobjects.Builtin;
@@ -599,13 +598,13 @@
}
@Override
- public CompressedSizeInfoColGroup getCompressionInfo(int nRow) {
- throw new NotImplementedException();
+ public IEncode getEncoding() {
+ return EncodingFactory.create(new MapToZero(getCounts()[0]), _indexes, _numRows);
}
@Override
- public IEncode getEncoding() {
- return EncodingFactory.create(new MapToZero(getCounts()[0]), _indexes, _numRows);
+ public int getNumberOffsets() {
+ return getCounts()[0];
}
@Override
@@ -625,7 +624,7 @@
@Override
protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) {
return ColGroupSDCSingle.create(newColIndex, getNumRows(), _dict.reorder(reordering),
- ColGroupUtils.reorderDefault(_defaultTuple, reordering), _indexes, getCachedCounts());
+ ColGroupUtils.reorderDefault(_defaultTuple, reordering), _indexes, getCachedCounts());
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java
index d882fba..8a90b2a 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java
@@ -40,7 +40,6 @@
import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory;
import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme;
import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator;
-import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory;
import org.apache.sysds.runtime.compress.estim.encoding.IEncode;
import org.apache.sysds.runtime.data.DenseBlock;
@@ -874,13 +873,13 @@
}
@Override
- public CompressedSizeInfoColGroup getCompressionInfo(int nRow) {
- throw new NotImplementedException();
+ public IEncode getEncoding() {
+ return EncodingFactory.create(new MapToZero(getCounts()[0]), _indexes, _numRows);
}
@Override
- public IEncode getEncoding() {
- return EncodingFactory.create(new MapToZero(getCounts()[0]), _indexes, _numRows);
+ public int getNumberOffsets() {
+ return getCounts()[0];
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
index 457926a..68eb214 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
@@ -24,7 +24,6 @@
import java.io.IOException;
import java.util.Arrays;
-import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
@@ -41,7 +40,6 @@
import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory;
import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme;
import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator;
-import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory;
import org.apache.sysds.runtime.compress.estim.encoding.IEncode;
import org.apache.sysds.runtime.data.DenseBlock;
@@ -806,8 +804,8 @@
}
@Override
- public CompressedSizeInfoColGroup getCompressionInfo(int nRow) {
- throw new NotImplementedException();
+ public int getNumberOffsets() {
+ return _data.size();
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
index bb2633a..ffa9656 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
@@ -29,6 +29,8 @@
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory;
+import org.apache.sysds.runtime.compress.CompressionSettings;
+import org.apache.sysds.runtime.compress.CompressionSettingsBuilder;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictLibMatrixMult;
@@ -38,6 +40,9 @@
import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme;
import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
+import org.apache.sysds.runtime.compress.estim.EstimationFactors;
+import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory;
+import org.apache.sysds.runtime.compress.estim.encoding.IEncode;
import org.apache.sysds.runtime.compress.utils.Util;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.data.DenseBlock;
@@ -828,7 +833,12 @@
@Override
public CompressedSizeInfoColGroup getCompressionInfo(int nRow) {
- throw new NotImplementedException();
+ final IEncode map = EncodingFactory.createFromMatrixBlock(_data, false,
+ ColIndexFactory.create(_data.getNumColumns()));
+ final int _numRows = _data.getNumRows();
+ final CompressionSettings _cs = new CompressionSettingsBuilder().create();// default settings
+ final EstimationFactors em = map.extractFacts(_numRows, _data.getSparsity(), _data.getSparsity(), _cs);
+ return new CompressedSizeInfoColGroup(_colIndexes, em, _cs.validCompressions, map);
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ArrayIndex.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ArrayIndex.java
index c2acadc..711236c 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ArrayIndex.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ArrayIndex.java
@@ -191,7 +191,6 @@
for(int i = 1; i < cols.length; i++)
if(cols[i - 1] > cols[i])
return false;
-
return true;
}
@@ -203,6 +202,12 @@
return ColIndexFactory.create(ret);
}
+ @Override
+ public boolean contains(int i) {
+ int id = Arrays.binarySearch(cols, 0, cols.length, i);
+ return id >= 0;
+ }
+
protected class ArrayIterator implements IIterate {
int id = 0;
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/IColIndex.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/IColIndex.java
index ff0acab..8b73abf 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/IColIndex.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/IColIndex.java
@@ -170,6 +170,14 @@
*/
public IColIndex sort();
+ /**
+ * Analyze if this column group contain the given column id
+ *
+ * @param i id to search for
+ * @return if it is contained
+ */
+ public boolean contains(int i);
+
/** A Class for slice results containing indexes for the slicing of dictionaries, and the resulting column index */
public static class SliceResult {
/** Start index to slice inside the dictionary */
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/RangeIndex.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/RangeIndex.java
index 717c7b8..7705c58 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/RangeIndex.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/RangeIndex.java
@@ -26,18 +26,40 @@
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.utils.IntArrayList;
+/**
+ * A Range index that contain a lower and upper bound of the indexes that is symbolize.
+ *
+ * The upper bound is not inclusive
+ */
public class RangeIndex extends AColIndex {
+ /** Lower bound inclusive */
private final int l;
- private final int u; // not inclusive
+ /** Upper bound not inclusive */
+ private final int u;
+ /**
+ * Construct an range index from 0 until the given nCol, not inclusive
+ *
+ * @param nCol The upper index not included
+ */
public RangeIndex(int nCol) {
- l = 0;
- u = nCol;
+ this(0, nCol);
}
+ /** Construct an range index */
+
+ /**
+ * Construct an range index with lower and upper values given.
+ *
+ * @param l lower index
+ * @param u Upper index
+ */
public RangeIndex(int l, int u) {
this.l = l;
this.u = u;
+
+ if(l >= u)
+ throw new DMLCompressionException("Invalid construction of Range Index with l: " + l + " u: " + u);
}
@Override
@@ -213,6 +235,11 @@
throw new DMLCompressionException("range is always sorted");
}
+ @Override
+ public boolean contains(int i) {
+ return l <= i && i < u;
+ }
+
protected class RangeIterator implements IIterate {
int cl = l;
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/SingleIndex.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/SingleIndex.java
index b9bef4f..5d10b2a 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/SingleIndex.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/SingleIndex.java
@@ -126,6 +126,11 @@
}
@Override
+ public boolean contains(int i) {
+ return i == idx;
+ }
+
+ @Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(this.getClass().getSimpleName());
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/TwoIndex.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/TwoIndex.java
index fae8dc7..9e3e848 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/TwoIndex.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/TwoIndex.java
@@ -155,6 +155,11 @@
}
@Override
+ public boolean contains(int i) {
+ return i == id1 || i == id2;
+ }
+
+ @Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(this.getClass().getSimpleName());
diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/AComEst.java b/src/main/java/org/apache/sysds/runtime/compress/estim/AComEst.java
index 153786f..145682f 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/estim/AComEst.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/estim/AComEst.java
@@ -222,7 +222,7 @@
protected abstract CompressedSizeInfoColGroup combine(IColIndex combinedColumns, CompressedSizeInfoColGroup g1,
CompressedSizeInfoColGroup g2, int maxDistinct);
- private List<CompressedSizeInfoColGroup> CompressedSizeInfoColGroup(int clen, int k) {
+ protected List<CompressedSizeInfoColGroup> CompressedSizeInfoColGroup(int clen, int k) {
if(k <= 1)
return CompressedSizeInfoColGroupSingleThread(clen);
else
diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressed.java b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressed.java
new file mode 100644
index 0000000..7f8ccfe
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressed.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.compress.estim;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.runtime.compress.CompressionSettings;
+import org.apache.sysds.runtime.compress.colgroup.AColGroup;
+import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
+
+public class ComEstCompressed extends AComEst {
+
+ final CompressedMatrixBlock cData;
+
+ protected ComEstCompressed(CompressedMatrixBlock data, CompressionSettings compSettings) {
+ super(data, compSettings);
+ cData = data;
+ }
+
+ @Override
+ protected List<CompressedSizeInfoColGroup> CompressedSizeInfoColGroup(int clen, int k) {
+ List<CompressedSizeInfoColGroup> ret = new ArrayList<CompressedSizeInfoColGroup>();
+ final int nRow = cData.getNumRows();
+ for(AColGroup g : cData.getColGroups()) {
+ ret.add(g.getCompressionInfo(nRow));
+ }
+ return ret;
+ }
+
+ @Override
+ public CompressedSizeInfoColGroup getColGroupInfo(IColIndex colIndexes, int estimate, int nrUniqueUpperBound) {
+
+ // final IEncode map =
+ throw new UnsupportedOperationException("Unimplemented method 'getColGroupInfo'");
+ }
+
+ @Override
+ public CompressedSizeInfoColGroup getDeltaColGroupInfo(IColIndex colIndexes, int estimate, int nrUniqueUpperBound) {
+ throw new UnsupportedOperationException("Unimplemented method 'getDeltaColGroupInfo'");
+ }
+
+ @Override
+ protected int worstCaseUpperBound(IColIndex columns) {
+ if(columns.size() == 1) {
+ int id = columns.get(0);
+ AColGroup g = cData.getColGroupForColumn(id);
+ return g.getNumValues();
+ }
+ else
+ throw new UnsupportedOperationException("Unimplemented method 'worstCaseUpperBound'");
+ }
+
+ @Override
+ protected CompressedSizeInfoColGroup combine(IColIndex combinedColumns, CompressedSizeInfoColGroup g1,
+ CompressedSizeInfoColGroup g2, int maxDistinct) {
+ throw new UnsupportedOperationException("Unimplemented method 'combine'");
+ }
+
+}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstFactory.java b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstFactory.java
index a0b0e6f..070788b 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstFactory.java
@@ -21,6 +21,7 @@
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -36,6 +37,9 @@
* @return A new CompressionSizeEstimator used to extract information of column groups
*/
public static AComEst createEstimator(MatrixBlock data, CompressionSettings cs, int k) {
+ if(data instanceof CompressedMatrixBlock)
+ return createCompressedEstimator((CompressedMatrixBlock) data, cs);
+
final int nRows = cs.transposed ? data.getNumColumns() : data.getNumRows();
final int nCols = cs.transposed ? data.getNumRows() : data.getNumColumns();
final double sparsity = data.getSparsity();
@@ -54,14 +58,12 @@
* @param k The parallelization degree
* @return A new CompressionSizeEstimator used to extract information of column groups
*/
- public static AComEst createEstimator(MatrixBlock data, CompressionSettings cs, int sampleSize,
- int k) {
+ public static AComEst createEstimator(MatrixBlock data, CompressionSettings cs, int sampleSize, int k) {
final int nRows = cs.transposed ? data.getNumColumns() : data.getNumRows();
return createEstimator(data, cs, sampleSize, k, nRows);
}
- private static AComEst createEstimator(MatrixBlock data, CompressionSettings cs, int sampleSize,
- int k, int nRows) {
+ private static AComEst createEstimator(MatrixBlock data, CompressionSettings cs, int sampleSize, int k, int nRows) {
if(sampleSize >= (double) nRows * 0.8) // if sample size is larger than 80% use entire input as sample.
return createExactEstimator(data, cs);
else
@@ -73,8 +75,12 @@
return new ComEstExact(data, cs);
}
- private static ComEstSample createSampleEstimator(MatrixBlock data, CompressionSettings cs,
- int sampleSize, int k) {
+ private static ComEstCompressed createCompressedEstimator(CompressedMatrixBlock data, CompressionSettings cs) {
+ LOG.debug("Using Compressed Estimator");
+ return new ComEstCompressed(data, cs);
+ }
+
+ private static ComEstSample createSampleEstimator(MatrixBlock data, CompressionSettings cs, int sampleSize, int k) {
LOG.debug("Using sample size: " + sampleSize);
return new ComEstSample(data, cs, sampleSize, k);
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java
index 57c9f8b..8de7360 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java
@@ -79,7 +79,7 @@
_sizes.put(bestCompressionType, _minSize);
}
- protected CompressedSizeInfoColGroup(IColIndex columns, EstimationFactors facts,
+ public CompressedSizeInfoColGroup(IColIndex columns, EstimationFactors facts,
Set<CompressionType> validCompressionTypes, IEncode map) {
_cols = columns;
_facts = facts;