[SYSTEMDS-3490] Compressed Transform Encode
Transform encode fused with compression: produces a compressed output
directly from the frame input, depending on the transformations applied.
Initial results are very promising, transforming single-threaded at the
same speed as our tuned multi-threaded version.
This commit contains the bare minimum for the transform encode, and
following commits will add more transformation pipelines.
Currently supported are recode-to-dummy, recode, and pass-through, all in
very naive implementations.
Also contained is an IdentityDictionary implementation that allows
one to specify that the compressed dictionary is simply the identity
matrix. Its allocation is very small: an object and an integer specifying
the number of rows and columns contained in the identity matrix.
To make the initial implementation efficient, a materialized MatrixBlock
dictionary is cached behind a soft reference and used for all operations
that the IdentityDictionary does not yet support natively.
Closes #1772
diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java
index 0573da2..ab81ff4 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -75,12 +75,12 @@
* Value types (int, double, string, boolean, unknown).
*/
public enum ValueType {
- UINT8, // Used for parsing in UINT values from numpy.
+ UINT4, UINT8, // Used for parsing in UINT values from numpy.
FP32, FP64, INT32, INT64, BOOLEAN, STRING, UNKNOWN,
CHARACTER;
public boolean isNumeric() {
- return this == UINT8 || this == INT32 || this == INT64 || this == FP32 || this == FP64;
+ return this == UINT8 || this == INT32 || this == INT64 || this == FP32 || this == FP64 || this== UINT4;
}
public boolean isUnknown() {
return this == UNKNOWN;
@@ -92,6 +92,7 @@
switch(this) {
case FP32:
case FP64: return "DOUBLE";
+ case UINT4:
case UINT8:
case INT32:
case INT64: return "INT";
@@ -107,6 +108,7 @@
case "FP32": return FP32;
case "FP64":
case "DOUBLE": return FP64;
+ case "UINT4": return UINT4;
case "UINT8": return UINT8;
case "INT32": return INT32;
case "INT64":
diff --git a/src/main/java/org/apache/sysds/conf/DMLConfig.java b/src/main/java/org/apache/sysds/conf/DMLConfig.java
index dad670e..46580c2 100644
--- a/src/main/java/org/apache/sysds/conf/DMLConfig.java
+++ b/src/main/java/org/apache/sysds/conf/DMLConfig.java
@@ -86,6 +86,7 @@
public static final String COMPRESSED_COCODE = "sysds.compressed.cocode";
public static final String COMPRESSED_COST_MODEL= "sysds.compressed.costmodel";
public static final String COMPRESSED_TRANSPOSE = "sysds.compressed.transpose";
+ public static final String COMPRESSED_TRANSFORMENCODE = "sysds.compressed.transformencode";
public static final String NATIVE_BLAS = "sysds.native.blas";
public static final String NATIVE_BLAS_DIR = "sysds.native.blas.directory";
public static final String DAG_LINEARIZATION = "sysds.compile.linearization";
@@ -167,6 +168,7 @@
_defaultVals.put(COMPRESSED_COCODE, "AUTO");
_defaultVals.put(COMPRESSED_COST_MODEL, "AUTO");
_defaultVals.put(COMPRESSED_TRANSPOSE, "auto");
+ _defaultVals.put(COMPRESSED_TRANSFORMENCODE, "false");
_defaultVals.put(DAG_LINEARIZATION, DagLinearization.DEPTH_FIRST.name());
_defaultVals.put(CODEGEN, "false" );
_defaultVals.put(CODEGEN_API, GeneratorAPI.JAVA.name() );
@@ -450,7 +452,7 @@
CP_PARALLEL_OPS, CP_PARALLEL_IO, PARALLEL_ENCODE, NATIVE_BLAS, NATIVE_BLAS_DIR,
COMPRESSED_LINALG, COMPRESSED_LOSSY, COMPRESSED_VALID_COMPRESSIONS, COMPRESSED_OVERLAPPING,
COMPRESSED_SAMPLING_RATIO, COMPRESSED_SOFT_REFERENCE_COUNT,
- COMPRESSED_COCODE, COMPRESSED_TRANSPOSE, DAG_LINEARIZATION,
+ COMPRESSED_COCODE, COMPRESSED_TRANSPOSE, COMPRESSED_TRANSFORMENCODE, DAG_LINEARIZATION,
CODEGEN, CODEGEN_API, CODEGEN_COMPILER, CODEGEN_OPTIMIZER, CODEGEN_PLANCACHE, CODEGEN_LITERALS,
STATS_MAX_WRAP_LEN, LINEAGECACHESPILL, COMPILERASSISTED_RW, BUFFERPOOL_LIMIT, MEMORY_MANAGER,
PRINT_GPU_MEMORY_INFO, AVAILABLE_GPUS, SYNCHRONIZE_GPU, EAGER_CUDA_FREE, FLOATING_POINT_PRECISION,
diff --git a/src/main/java/org/apache/sysds/hops/LiteralOp.java b/src/main/java/org/apache/sysds/hops/LiteralOp.java
index 75bc73d..5d3f06b 100644
--- a/src/main/java/org/apache/sysds/hops/LiteralOp.java
+++ b/src/main/java/org/apache/sysds/hops/LiteralOp.java
@@ -246,6 +246,7 @@
switch( getValueType() ) {
case BOOLEAN:
return String.valueOf(value_boolean);
+ case UINT4:
case UINT8:
case INT32:
case INT64:
diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index 09896c1..80e9f75 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -615,6 +615,7 @@
case INT64:
case INT32:
case UINT8:
+ case UINT4:
case BOOLEAN:
output.setValueType(ValueType.INT64);
break;
diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
index 0a0c4b8..7f47f6d 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
@@ -399,10 +399,11 @@
@Override
public void write(DataOutput out) throws IOException {
+ // LOG.error(this);
if(nonZeros > 0 && getExactSizeOnDisk() > MatrixBlock.estimateSizeOnDisk(rlen, clen, nonZeros)) {
// If the size of this matrixBlock is smaller in uncompressed format, then
// decompress and save inside an uncompressed column group.
- MatrixBlock uncompressed = getUncompressed("for smaller serialization");
+ MatrixBlock uncompressed = getUncompressed("smaller serialization size");
ColGroupUncompressed cg = (ColGroupUncompressed) ColGroupUncompressed.create(uncompressed);
allocateColGroup(cg);
nonZeros = cg.getNumberNonZeros(rlen);
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
index 9eee973..9154f1a 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
@@ -646,7 +646,10 @@
// count distinct items frequencies
for(int j = apos; j < alen; j++)
- map.increment(vals[j]);
+ if(!Double.isNaN(vals[j]))
+ map.increment(vals[j]);
+ else
+ map.increment(0);
DCounts[] entries = map.extractValues();
Arrays.sort(entries, Comparator.comparing(x -> -x.count));
@@ -668,7 +671,10 @@
else {
final AMapToData mapToData = MapToFactory.create((alen - apos), entries.length);
for(int j = apos; j < alen; j++)
+ if(!Double.isNaN(vals[j]))
mapToData.set(j - apos, map.get(vals[j]));
+ else
+ mapToData.set(j - apos, map.get(0.0));
return ColGroupSDCZeros.create(cols, nRow, Dictionary.create(dict), offsets, mapToData, counts);
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
index dd9557d..0651e78 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
@@ -44,7 +44,7 @@
protected static final Log LOG = LogFactory.getLog(ADictionary.class.getName());
public static enum DictType {
- Delta, Dict, MatrixBlock, UInt8;
+ Delta, Dict, MatrixBlock, UInt8, Identity;
}
/**
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java
index a777201..7437df1 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java
@@ -41,7 +41,7 @@
static final Log LOG = LogFactory.getLog(DictionaryFactory.class.getName());
public enum Type {
- FP64_DICT, MATRIX_BLOCK_DICT, INT8_DICT
+ FP64_DICT, MATRIX_BLOCK_DICT, INT8_DICT, IDENTITY
}
public static ADictionary read(DataInput in) throws IOException {
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java
new file mode 100644
index 0000000..23d5c36
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java
@@ -0,0 +1,584 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.compress.colgroup.dictionary;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.lang.ref.SoftReference;
+import java.util.Arrays;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.sysds.runtime.compress.DMLCompressionException;
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.functionobjects.Builtin;
+import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
+import org.apache.sysds.runtime.functionobjects.ValueFunction;
+import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
+import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
+
+public class IdentityDictionary extends ADictionary {
+
+ private static final long serialVersionUID = 2535887782150955098L;
+
+ // final private MatrixBlock _data;
+ private final int nRowCol;
+
+ private SoftReference<MatrixBlockDictionary> cache = null;
+
+ /**
+ * Create a Identity matrix dictionary. It behaves as if allocated a Sparse Matrix block but exploits that the
+ * structure is known to have certain properties.
+ *
+ * @param nRowCol the number of rows and columns in this identity matrix.
+ */
+ public IdentityDictionary(int nRowCol) {
+ if(nRowCol <= 0)
+ throw new DMLCompressionException("Invalid Identity Dictionary");
+ this.nRowCol = nRowCol;
+ }
+
+ @Override
+ public double[] getValues() {
+ LOG.warn("Should not call getValues on Identity Dictionary");
+
+ double[] ret = new double[nRowCol * nRowCol];
+ for(int i = 0; i < nRowCol; i++) {
+ ret[(i * nRowCol) + i] = 1;
+ }
+ return ret;
+ }
+
+ @Override
+ public double getValue(int i) {
+ // i is a row-major cell index into the nRowCol x nRowCol identity; any index outside of it reads as 0.
+ if(i < 0) // negative indices would otherwise yield row == col (e.g. both -1) and wrongly report 1
+ return 0;
+ final int row = i / nRowCol;
+ final int col = i % nRowCol;
+ return row < nRowCol && row == col ? 1 : 0; // only in-range diagonal cells hold 1
+ }
+
+ @Override
+ public final double getValue(int r, int c, int nCol) {
+ return r == c ? 1 : 0;
+ }
+
+ @Override
+ public long getInMemorySize() {
+ return 4 + 4 + 8; // int + padding + softReference
+ }
+
+ public static long getInMemorySize(int numberColumns) {
+ return 4 + 4 + 8;
+ }
+
+ @Override
+ public double aggregate(double init, Builtin fn) {
+ if(fn.getBuiltinCode() == BuiltinCode.MAX)
+ return fn.execute(init, 1); // the largest value in any identity matrix is 1
+ else if(fn.getBuiltinCode() == BuiltinCode.MIN)
+ return fn.execute(init, nRowCol > 1 ? 0 : 1); // a 1x1 identity contains no zero, so its minimum is 1
+ else
+ throw new NotImplementedException();
+ }
+
+ @Override
+ public double aggregateWithReference(double init, Builtin fn, double[] reference, boolean def) {
+ return getMBDict().aggregateWithReference(init, fn, reference, def);
+ }
+
+ @Override
+ public double[] aggregateRows(Builtin fn, int nCol) {
+ double[] ret = new double[nRowCol];
+ Arrays.fill(ret, nRowCol > 1 ? fn.execute(1, 0) : 1); // every row holds a single 1 and, unless 1x1, zeros otherwise
+ return ret;
+ }
+
+ @Override
+ public double[] aggregateRowsWithDefault(Builtin fn, double[] defaultTuple) {
+ return getMBDict().aggregateRowsWithDefault(fn, defaultTuple);
+ }
+
+ @Override
+ public double[] aggregateRowsWithReference(Builtin fn, double[] reference) {
+ return getMBDict().aggregateRowsWithReference(fn, reference);
+ }
+
+ @Override
+ public void aggregateCols(double[] c, Builtin fn, int[] colIndexes) {
+ for(int i = 0; i < nRowCol; i++) {
+ final int idx = colIndexes[i];
+ c[idx] = fn.execute(c[idx], 0);
+ c[idx] = fn.execute(c[idx], 1);
+ }
+ }
+
+ @Override
+ public void aggregateColsWithReference(double[] c, Builtin fn, int[] colIndexes, double[] reference, boolean def) {
+ getMBDict().aggregateColsWithReference(c, fn, colIndexes, reference, def);
+ }
+
+ @Override
+ public ADictionary applyScalarOp(ScalarOperator op) {
+ return getMBDict().applyScalarOp(op);
+ }
+
+ @Override
+ public ADictionary applyScalarOpAndAppend(ScalarOperator op, double v0, int nCol) {
+
+ return getMBDict().applyScalarOpAndAppend(op, v0, nCol);
+ }
+
+ @Override
+ public ADictionary applyUnaryOp(UnaryOperator op) {
+ return getMBDict().applyUnaryOp(op);
+ }
+
+ @Override
+ public ADictionary applyUnaryOpAndAppend(UnaryOperator op, double v0, int nCol) {
+ return getMBDict().applyUnaryOpAndAppend(op, v0, nCol);
+ }
+
+ @Override
+ public ADictionary applyScalarOpWithReference(ScalarOperator op, double[] reference, double[] newReference) {
+ return getMBDict().applyScalarOpWithReference(op, reference, newReference);
+ }
+
+ @Override
+ public ADictionary applyUnaryOpWithReference(UnaryOperator op, double[] reference, double[] newReference) {
+ return getMBDict().applyUnaryOpWithReference(op, reference, newReference);
+ }
+
+ @Override
+ public ADictionary binOpLeft(BinaryOperator op, double[] v, int[] colIndexes) {
+ return getMBDict().binOpLeft(op, v, colIndexes);
+ }
+
+ @Override
+ public ADictionary binOpLeftAndAppend(BinaryOperator op, double[] v, int[] colIndexes) {
+ return getMBDict().binOpLeftAndAppend(op, v, colIndexes);
+ }
+
+ @Override
+ public ADictionary binOpLeftWithReference(BinaryOperator op, double[] v, int[] colIndexes, double[] reference,
+ double[] newReference) {
+ return getMBDict().binOpLeftWithReference(op, v, colIndexes, reference, newReference);
+
+ }
+
+ @Override
+ public ADictionary binOpRight(BinaryOperator op, double[] v, int[] colIndexes) {
+ return getMBDict().binOpRight(op, v, colIndexes);
+ }
+
+ @Override
+ public ADictionary binOpRightAndAppend(BinaryOperator op, double[] v, int[] colIndexes) {
+ return getMBDict().binOpRightAndAppend(op, v, colIndexes);
+ }
+
+ @Override
+ public ADictionary binOpRight(BinaryOperator op, double[] v) {
+ return getMBDict().binOpRight(op, v);
+ }
+
+ @Override
+ public ADictionary binOpRightWithReference(BinaryOperator op, double[] v, int[] colIndexes, double[] reference,
+ double[] newReference) {
+ return getMBDict().binOpRightWithReference(op, v, colIndexes, reference, newReference);
+ }
+
+ @Override
+ public ADictionary clone() {
+ return new IdentityDictionary(nRowCol);
+ }
+
+ @Override
+ public DictType getDictType() {
+ return DictType.Identity;
+ }
+
+ @Override
+ public int getNumberOfValues(int ncol) {
+ return nRowCol;
+ }
+
+ @Override
+ public double[] sumAllRowsToDouble(int nrColumns) {
+ double[] ret = new double[nRowCol];
+ Arrays.fill(ret, 1);
+ return ret;
+ }
+
+ @Override
+ public double[] sumAllRowsToDoubleWithDefault(double[] defaultTuple) {
+ double[] ret = new double[nRowCol];
+ Arrays.fill(ret, 1);
+ for(int i = 0; i < defaultTuple.length; i++)
+ ret[i] += defaultTuple[i];
+ return ret;
+ }
+
+ @Override
+ public double[] sumAllRowsToDoubleWithReference(double[] reference) {
+ double[] ret = new double[nRowCol];
+ Arrays.fill(ret, 1);
+ for(int i = 0; i < reference.length; i++)
+ ret[i] += reference[i] * nRowCol;
+ return ret;
+ }
+
+ @Override
+ public double[] sumAllRowsToDoubleSq(int nrColumns) {
+ double[] ret = new double[nRowCol];
+ Arrays.fill(ret, 1);
+ return ret;
+ }
+
+ @Override
+ public double[] sumAllRowsToDoubleSqWithDefault(double[] defaultTuple) {
+ return getMBDict().sumAllRowsToDoubleSqWithDefault(defaultTuple);
+ }
+
+ @Override
+ public double[] sumAllRowsToDoubleSqWithReference(double[] reference) {
+ return getMBDict().sumAllRowsToDoubleSqWithReference(reference);
+ }
+
+ @Override
+ public double[] productAllRowsToDouble(int nCol) {
+ return nRowCol == 1 ? new double[] {1} : new double[nRowCol]; // rows of a larger identity contain a zero, so their product is 0
+ }
+
+ @Override
+ public double[] productAllRowsToDoubleWithDefault(double[] defaultTuple) {
+ return new double[nRowCol];
+ }
+
+ @Override
+ public double[] productAllRowsToDoubleWithReference(double[] reference) {
+ return getMBDict().productAllRowsToDoubleWithReference(reference);
+ }
+
+ @Override
+ public void colSum(double[] c, int[] counts, int[] colIndexes) {
+ for(int i = 0; i < colIndexes.length; i++) {
+ // very nice...
+ final int idx = colIndexes[i];
+ c[idx] = counts[i];
+ }
+ }
+
+ @Override
+ public void colSumSq(double[] c, int[] counts, int[] colIndexes) {
+ colSum(c, counts, colIndexes);
+ }
+
+ @Override
+ public void colProduct(double[] res, int[] counts, int[] colIndexes) {
+ for(int i = 0; i < colIndexes.length; i++) {
+ res[colIndexes[i]] = 0;
+ }
+ }
+
+ @Override
+ public void colProductWithReference(double[] res, int[] counts, int[] colIndexes, double[] reference) {
+ getMBDict().colProductWithReference(res, counts, colIndexes, reference);
+
+ }
+
+ @Override
+ public void colSumSqWithReference(double[] c, int[] counts, int[] colIndexes, double[] reference) {
+ getMBDict().colSumSqWithReference(c, counts, colIndexes, reference);
+ }
+
+ @Override
+ public double sum(int[] counts, int ncol) {
+ double s = 0.0;
+ for(int v : counts)
+ s += v;
+ return s;
+ }
+
+ @Override
+ public double sumSq(int[] counts, int ncol) {
+ return sum(counts, ncol);
+ }
+
+ @Override
+ public double sumSqWithReference(int[] counts, double[] reference) {
+ return getMBDict().sumSqWithReference(counts, reference);
+ }
+
+ @Override
+ public ADictionary sliceOutColumnRange(int idxStart, int idxEnd, int previousNumberOfColumns) {
+ return getMBDict().sliceOutColumnRange(idxStart, idxEnd, previousNumberOfColumns);
+ }
+
+ @Override
+ public boolean containsValue(double pattern) {
+ return pattern == 1.0 || (pattern == 0.0 && nRowCol > 1); // a 1x1 identity holds only the value 1
+ }
+
+ @Override
+ public boolean containsValueWithReference(double pattern, double[] reference) {
+ return getMBDict().containsValueWithReference(pattern, reference);
+ }
+
+ @Override
+ public long getNumberNonZeros(int[] counts, int nCol) {
+ return (long) sum(counts, nCol);
+ }
+
+ @Override
+ public long getNumberNonZerosWithReference(int[] counts, double[] reference, int nRows) {
+ return getMBDict().getNumberNonZerosWithReference(counts, reference, nRows);
+ }
+
+ @Override
+ public void addToEntry(final double[] v, final int fr, final int to, final int nCol) {
+ getMBDict().addToEntry(v, fr, to, nCol);
+ }
+
+ @Override
+ public void addToEntry(final double[] v, final int fr, final int to, final int nCol, int rep) {
+ getMBDict().addToEntry(v, fr, to, nCol, rep);
+ }
+
+ @Override
+ public void addToEntryVectorized(double[] v, int f1, int f2, int f3, int f4, int f5, int f6, int f7, int f8, int t1,
+ int t2, int t3, int t4, int t5, int t6, int t7, int t8, int nCol) {
+ getMBDict().addToEntryVectorized(v, f1, f2, f3, f4, f5, f6, f7, f8, t1, t2, t3, t4, t5, t6, t7, t8, nCol);
+ }
+
+ @Override
+ public ADictionary subtractTuple(double[] tuple) {
+ return getMBDict().subtractTuple(tuple);
+ }
+
+ public MatrixBlockDictionary getMBDict() {
+ return getMBDict(nRowCol);
+ }
+
+ @Override
+ public MatrixBlockDictionary getMBDict(int nCol) {
+ if(cache != null) {
+ MatrixBlockDictionary r = cache.get();
+ if(r != null)
+ return r;
+ }
+ MatrixBlockDictionary ret = createMBDict();
+ cache = new SoftReference<>(ret);
+ return ret;
+ }
+
+ private MatrixBlockDictionary createMBDict() {
+ MatrixBlock identity = new MatrixBlock(nRowCol, nRowCol, true);
+ for(int i = 0; i < nRowCol; i++)
+ identity.quickSetValue(i, i, 1.0);
+
+ return new MatrixBlockDictionary(identity);
+ }
+
+ @Override
+ public String getString(int colIndexes) {
+ return "IdentityMatrix of size: " + nRowCol;
+ }
+
+ @Override
+ public String toString() {
+ return "IdentityMatrix of size: " + nRowCol;
+ }
+
+ @Override
+ public ADictionary scaleTuples(int[] scaling, int nCol) {
+ return getMBDict().scaleTuples(scaling, nCol);
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeByte(DictionaryFactory.Type.IDENTITY.ordinal());
+ out.writeInt(nRowCol);
+ }
+
+ public static IdentityDictionary read(DataInput in) throws IOException {
+ return new IdentityDictionary(in.readInt());
+ }
+
+ @Override
+ public long getExactSizeOnDisk() {
+ return 1 + 4;
+ }
+
+ @Override
+ public ADictionary preaggValuesFromDense(final int numVals, final int[] colIndexes, final int[] aggregateColumns,
+ final double[] b, final int cut) {
+ return getMBDict().preaggValuesFromDense(numVals, colIndexes, aggregateColumns, b, cut);
+ }
+
+ @Override
+ public ADictionary replace(double pattern, double replace, int nCol) {
+ if(containsValue(pattern))
+ return getMBDict().replace(pattern, replace, nCol);
+ else
+ return this;
+ }
+
+ @Override
+ public ADictionary replaceWithReference(double pattern, double replace, double[] reference) {
+ if(containsValueWithReference(pattern, reference))
+ return getMBDict().replaceWithReference(pattern, replace, reference);
+ else
+ return this;
+ }
+
+ @Override
+ public void product(double[] ret, int[] counts, int nCol) {
+ getMBDict().product(ret, counts, nCol);
+ }
+
+ @Override
+ public void productWithDefault(double[] ret, int[] counts, double[] def, int defCount) {
+ getMBDict().productWithDefault(ret, counts, def, defCount);
+ }
+
+ @Override
+ public void productWithReference(double[] ret, int[] counts, double[] reference, int refCount) {
+ getMBDict().productWithReference(ret, counts, reference, refCount);
+ }
+
+ @Override
+ public CM_COV_Object centralMoment(CM_COV_Object ret, ValueFunction fn, int[] counts, int nRows) {
+ return getMBDict().centralMoment(ret, fn, counts, nRows);
+ }
+
+ @Override
+ public CM_COV_Object centralMomentWithDefault(CM_COV_Object ret, ValueFunction fn, int[] counts, double def,
+ int nRows) {
+ return getMBDict().centralMomentWithDefault(ret, fn, counts, def, nRows);
+ }
+
+ @Override
+ public CM_COV_Object centralMomentWithReference(CM_COV_Object ret, ValueFunction fn, int[] counts, double reference,
+ int nRows) {
+ return getMBDict().centralMomentWithReference(ret, fn, counts, reference, nRows);
+ }
+
+ @Override
+ public ADictionary rexpandCols(int max, boolean ignore, boolean cast, int nCol) {
+ return getMBDict().rexpandCols(max, ignore, cast, nCol);
+ }
+
+ @Override
+ public ADictionary rexpandColsWithReference(int max, boolean ignore, boolean cast, int reference) {
+ return getMBDict().rexpandColsWithReference(max, ignore, cast, reference);
+ }
+
+ @Override
+ public double getSparsity() {
+ // non-zeros / n cells
+ // nRowCol / (nRowCol * nRowCol)
+ // simplifies to
+ return 1.0d / (double) nRowCol;
+ }
+
+ @Override
+ public void multiplyScalar(double v, double[] ret, int off, int dictIdx, int[] cols) {
+ getMBDict().multiplyScalar(v, ret, off, dictIdx, cols);
+ }
+
+ @Override
+ protected void TSMMWithScaling(int[] counts, int[] rows, int[] cols, MatrixBlock ret) {
+ getMBDict().TSMMWithScaling(counts, rows, cols, ret);
+ }
+
+ @Override
+ protected void MMDict(ADictionary right, int[] rowsLeft, int[] colsRight, MatrixBlock result) {
+ getMBDict().MMDict(right, rowsLeft, colsRight, result);
+ // should replace with add to right to output cells.
+ }
+
+ @Override
+ protected void MMDictDense(double[] left, int[] rowsLeft, int[] colsRight, MatrixBlock result) {
+ getMBDict().MMDictDense(left, rowsLeft, colsRight, result);
+ // should replace with add to right to output cells.
+ }
+
+ @Override
+ protected void MMDictSparse(SparseBlock left, int[] rowsLeft, int[] colsRight, MatrixBlock result) {
+ getMBDict().MMDictSparse(left, rowsLeft, colsRight, result);
+ }
+
+ @Override
+ protected void TSMMToUpperTriangle(ADictionary right, int[] rowsLeft, int[] colsRight, MatrixBlock result) {
+ getMBDict().TSMMToUpperTriangle(right, rowsLeft, colsRight, result);
+ }
+
+ @Override
+ protected void TSMMToUpperTriangleDense(double[] left, int[] rowsLeft, int[] colsRight, MatrixBlock result) {
+ getMBDict().TSMMToUpperTriangleDense(left, rowsLeft, colsRight, result);
+ }
+
+ @Override
+ protected void TSMMToUpperTriangleSparse(SparseBlock left, int[] rowsLeft, int[] colsRight, MatrixBlock result) {
+ getMBDict().TSMMToUpperTriangleSparse(left, rowsLeft, colsRight, result);
+ }
+
+ @Override
+ protected void TSMMToUpperTriangleScaling(ADictionary right, int[] rowsLeft, int[] colsRight, int[] scale,
+ MatrixBlock result) {
+ getMBDict().TSMMToUpperTriangleScaling(right, rowsLeft, colsRight, scale, result);
+ }
+
+ @Override
+ protected void TSMMToUpperTriangleDenseScaling(double[] left, int[] rowsLeft, int[] colsRight, int[] scale,
+ MatrixBlock result) {
+ getMBDict().TSMMToUpperTriangleDenseScaling(left, rowsLeft, colsRight, scale, result);
+ }
+
+ @Override
+ protected void TSMMToUpperTriangleSparseScaling(SparseBlock left, int[] rowsLeft, int[] colsRight, int[] scale,
+ MatrixBlock result) {
+
+ getMBDict().TSMMToUpperTriangleSparseScaling(left, rowsLeft, colsRight, scale, result);
+ }
+
+ @Override
+ public boolean equals(ADictionary o) {
+ if(o instanceof IdentityDictionary)
+ return ((IdentityDictionary) o).nRowCol == nRowCol;
+
+ MatrixBlock mb = getMBDict().getMatrixBlock();
+ if(o instanceof MatrixBlockDictionary)
+ return mb.equals(((MatrixBlockDictionary) o).getMatrixBlock());
+ else if(o instanceof Dictionary) {
+ if(mb.isInSparseFormat())
+ return mb.getSparseBlock().equals(((Dictionary) o)._values, nRowCol);
+ final double[] dv = mb.getDenseBlockValues();
+ return Arrays.equals(dv, ((Dictionary) o)._values);
+ }
+
+ return false;
+ }
+
+}
diff --git a/src/main/java/org/apache/sysds/runtime/data/LibTensorAgg.java b/src/main/java/org/apache/sysds/runtime/data/LibTensorAgg.java
index 7136bdb..6ea603e 100644
--- a/src/main/java/org/apache/sysds/runtime/data/LibTensorAgg.java
+++ b/src/main/java/org/apache/sysds/runtime/data/LibTensorAgg.java
@@ -246,6 +246,7 @@
}
case INT64:
case INT32:
+ case UINT4:
case UINT8: {
DenseBlock a = in.getDenseBlock();
long sum = 0;
diff --git a/src/main/java/org/apache/sysds/runtime/data/TensorBlock.java b/src/main/java/org/apache/sysds/runtime/data/TensorBlock.java
index c9f8b6c..5047ee2 100644
--- a/src/main/java/org/apache/sysds/runtime/data/TensorBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/data/TensorBlock.java
@@ -648,6 +648,8 @@
long size = 8 + 1;
if (!bt.isSparse()) {
switch (bt._vt) {
+ case UINT4:
+ size += getLength() / 2 + getLength() % 2; break; // half a byte per cell, rounded up; break so the UINT8 size is not added on top
case UINT8:
size += 1 * getLength(); break;
case INT32:
diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java
index 3e04d54..69c3a99 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java
@@ -71,7 +71,6 @@
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
-import org.apache.sysds.runtime.transform.encode.ColumnEncoderRecode;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.DMVUtils;
import org.apache.sysds.runtime.util.EMAUtils;
@@ -88,9 +87,6 @@
/** Buffer size variable: 1M elements, size of default matrix block */
public static final int BUFFER_SIZE = 1 * 1000 * 1000;
- /** internal configuration */
- private static final boolean REUSE_RECODE_MAPS = true;
-
/** The schema of the data frame as an ordered list of value types */
private ValueType[] _schema = null;
@@ -169,18 +165,17 @@
}
/**
- * allocate a FrameBlock with the given data arrays.
+ * allocate a FrameBlock with the given data arrays.
*
- * The data is in row major, making the first dimension number of rows.
- * second number of columns.
+ * The data is in row major, making the first dimension number of rows. second number of columns.
*
* @param schema the schema to allocate
- * @param names The names of the column
- * @param data The data.
+ * @param names The names of the column
+ * @param data The data.
*/
public FrameBlock(ValueType[] schema, String[] names, String[][] data) {
_schema = schema;
- if(names != null){
+ if(names != null) {
_colnames = names;
if(schema.length != names.length)
throw new DMLRuntimeException("Invalid FrameBlock construction, invalid schema and names combination");
@@ -821,9 +816,11 @@
.map(x -> x.getInMemorySize()).reduce(0L, Long::sum);
}).get();
pool.shutdown();
+
}
catch(InterruptedException | ExecutionException e) {
pool.shutdown();
+ LOG.error(e);
for(Array<?> aa : _coldata)
size += aa.getInMemorySize();
}
@@ -831,6 +828,7 @@
else {
for(Array<?> aa : _coldata)
size += aa.getInMemorySize();
+
}
}
return size;
@@ -1187,34 +1185,8 @@
* @param col is the column # from frame data which contains Recode map generated earlier.
* @return map of token and code for every element in the input column of a frame containing Recode map
*/
- public HashMap<String, Long> getRecodeMap(int col) {
- // probe cache for existing map
- if(REUSE_RECODE_MAPS) {
- SoftReference<HashMap<String, Long>> tmp = _coldata[col].getCache();
- HashMap<String, Long> map = (tmp != null) ? tmp.get() : null;
- if(map != null)
- return map;
- }
-
- // construct recode map
- HashMap<String, Long> map = new HashMap<>();
- Array<?> ldata = _coldata[col];
- int nRow = _coldata[0].size();
- if(nRow != _nRow)
- throw new DMLRuntimeException("Invalid intermediate size:" + nRow + " " + _nRow);
- for(int i = 0; i < getNumRows(); i++) {
- Object val = ldata.get(i);
- if(val != null) {
- String[] tmp = ColumnEncoderRecode.splitRecodeMapEntry(val.toString());
- map.put(tmp[0], Long.parseLong(tmp[1]));
- }
- }
-
- // put created map into cache
- if(REUSE_RECODE_MAPS)
- _coldata[col].setCache(new SoftReference<>(map));
-
- return map;
+ public HashMap<Object, Long> getRecodeMap(int col) {
+ return _coldata[col].getRecodeMap();
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java
index 2be9e10..e706672 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java
@@ -21,6 +21,7 @@
import java.lang.ref.SoftReference;
import java.util.HashMap;
+import java.util.Iterator;
import org.apache.commons.lang.NotImplementedException;
import org.apache.commons.logging.Log;
@@ -37,9 +38,11 @@
*/
public abstract class Array<T> implements Writable {
protected static final Log LOG = LogFactory.getLog(Array.class.getName());
+ /** internal configuration */
+ private static final boolean REUSE_RECODE_MAPS = true;
/** A soft reference to a memorization of this arrays mapping, used in transformEncode */
- protected SoftReference<HashMap<String, Long>> _rcdMapCache = null;
+ protected SoftReference<HashMap<T, Long>> _rcdMapCache = null;
/** The current allocated number of elements in this Array */
protected int _size;
@@ -59,7 +62,7 @@
*
* @return The cached object
*/
- public final SoftReference<HashMap<String, Long>> getCache() {
+ public final SoftReference<HashMap<T, Long>> getCache() {
return _rcdMapCache;
}
@@ -68,10 +71,43 @@
*
* @param m The element to cache.
*/
- public final void setCache(SoftReference<HashMap<String, Long>> m) {
+ public final void setCache(SoftReference<HashMap<T, Long>> m) {
_rcdMapCache = m;
}
+ /**
+  * Get the value-to-code recode map of this array, served from the soft-reference cache whenever it still holds one.
+  *
+  * @return The cached map if present, otherwise a map freshly built by createRecodeMap
+  */
+ public HashMap<T, Long> getRecodeMap() {
+ 	if(!REUSE_RECODE_MAPS) // caching switched off: build a new map on every call
+ 		return createRecodeMap();
+
+ 	final SoftReference<HashMap<T, Long>> ref = getCache();
+ 	final HashMap<T, Long> cached = (ref == null) ? null : ref.get();
+ 	if(cached != null)
+ 		return cached;
+ 	// nothing cached (or the reference was cleared): build the map once and remember it
+ 	final HashMap<T, Long> fresh = createRecodeMap();
+ 	setCache(new SoftReference<>(fresh));
+ 	return fresh;
+ }
+
+
+ /** Build a map from every distinct non-null value to an id, numbered from 0 in order of first occurrence. */
+ protected HashMap<T, Long> createRecodeMap(){
+ 	final HashMap<T, Long> codes = new HashMap<>();
+ 	for(int row = 0; row < size(); row++) {
+ 		final T cell = get(row);
+ 		if(cell != null && !codes.containsKey(cell)) // null cells never receive an id
+ 			codes.put(cell, (long) codes.size()); // ids are dense, so the next id equals the current map size
+ 	}
+ 	return codes;
+ }
+
+
+
/**
* Get the number of elements in the array, this does not necessarily reflect the current allocated size.
*
@@ -306,6 +342,15 @@
return null;
}
+ /**
+ * Analyze if the array contains null values. This default implementation always returns false; array types that
+ * can hold null entries override it with an actual check.
+ * @return If the array contains null.
+ */
+ public boolean containsNull(){
+ return false;
+ }
+
public Array<?> changeTypeWithNulls(ValueType t) {
final ABooleanArray nulls = getNulls();
if(nulls == null)
@@ -321,6 +366,7 @@
return new OptionalArray<Float>(changeTypeFloat(), nulls);
case FP64:
return new OptionalArray<Double>(changeTypeDouble(), nulls);
+ case UINT4:
case UINT8:
throw new NotImplementedException();
case INT32:
@@ -354,6 +400,7 @@
return changeTypeFloat();
case FP64:
return changeTypeDouble();
+ case UINT4:
case UINT8:
throw new NotImplementedException();
case INT32:
@@ -520,4 +567,26 @@
return this.getClass().getSimpleName();
}
+
+ /** Create an iterator over the elements of this array, positioned before the first element. */
+ public ArrayIterator getIterator(){
+ 	return new ArrayIterator();
+ }
+
+ /** Iterator over the enclosing array; getIndex() is the index last returned by next(), -1 before the first call. */
+ public class ArrayIterator implements Iterator<T> {
+ 	int index = -1;
+ 	public int getIndex(){
+ 		return index;
+ 	}
+ 	@Override
+ 	public boolean hasNext() {
+ 		return index < size()-1;
+ 	}
+ 	@Override
+ 	public T next() {
+ 		if(!hasNext()) throw new java.util.NoSuchElementException("No element after index " + index);
+ 		return get(++index);
+ 	}
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java
index 88c6ff2..8af5623 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java
@@ -82,6 +82,7 @@
return Array.baseMemoryCost() + (long) MemoryEstimates.longArrayCost(_numRows);
case FP64:
return Array.baseMemoryCost() + (long) MemoryEstimates.doubleArrayCost(_numRows);
+ case UINT4:
case UINT8:
case INT32:
return Array.baseMemoryCost() + (long) MemoryEstimates.intArrayCost(_numRows);
@@ -111,6 +112,7 @@
return new OptionalArray<>(new BitSetArray(nRow), true);
else
return new OptionalArray<>(new BooleanArray(new boolean[nRow]), true);
+ case UINT4:
case UINT8:
case INT32:
return new OptionalArray<>(new IntegerArray(new int[nRow]), true);
@@ -140,6 +142,7 @@
switch(v) {
case BOOLEAN:
return allocateBoolean(nRow);
+ case UINT4:
case UINT8:
case INT32:
return new IntegerArray(new int[nRow]);
@@ -261,8 +264,9 @@
return FloatArray.parseFloat(s);
case FP64:
return DoubleArray.parseDouble(s);
- case INT32:
+ case UINT4:
case UINT8:
+ case INT32:
return IntegerArray.parseInt(s);
case INT64:
return LongArray.parseLong(s);
diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java
index 32a9c86..fe85c25 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java
@@ -426,6 +426,11 @@
}
@Override
+ public boolean containsNull(){
+ return !_n.isAllTrue(); // NOTE(review): presumably _n flags the non-null positions, so any false flag means a null - confirm
+ }
+
+ @Override
public String toString() {
StringBuilder sb = new StringBuilder(_size + 2);
sb.append(super.toString() + "<" + _a.getClass().getSimpleName() + ">:[");
diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java
index 250b876..862014b 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java
@@ -24,6 +24,7 @@
import java.io.IOException;
import java.util.Arrays;
import java.util.BitSet;
+import java.util.HashMap;
import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.common.Types.ValueType;
@@ -32,6 +33,7 @@
import org.apache.sysds.runtime.frame.data.lib.FrameUtil;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.matrix.data.Pair;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderRecode;
import org.apache.sysds.utils.MemoryEstimates;
public class StringArray extends Array<String> {
@@ -580,6 +582,14 @@
return false;
return true;
}
+
+ @Override
+ public boolean containsNull(){
+ 	for(final String entry : _data) // walks the whole backing array (_data.length), independent of size()
+ 		if(entry == null)
+ 			return true;
+ 	return false;
+ }
@Override
public Array<String> select(int[] indices) {
@@ -605,6 +615,28 @@
}
@Override
+ protected HashMap<String, Long> createRecodeMap(){
+ 	// First try to read every entry as a serialized recode entry (split into token and code by
+ 	// ColumnEncoderRecode.splitRecodeMapEntry); if any entry fails to parse, fall back to the generic map.
+ 	try{
+ 		final HashMap<String, Long> parsed = new HashMap<>();
+ 		final int len = size();
+ 		int row = 0;
+ 		Object cell;
+ 		// stop at the first null entry, the entries after it are not inspected
+ 		while(row < len && (cell = get(row++)) != null) {
+ 			final String[] parts = ColumnEncoderRecode.splitRecodeMapEntry(cell.toString());
+ 			parsed.put(parts[0], Long.parseLong(parts[1]));
+ 		}
+ 		return parsed;
+ 	}
+ 	catch(Exception e){
+ 		return super.createRecodeMap();
+ 	}
+ }
+
+
+ @Override
public String toString() {
StringBuilder sb = new StringBuilder(_size * 5 + 2);
sb.append(super.toString() + ":[");
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java
index 6ec120c..4a10ae0 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java
@@ -152,7 +152,7 @@
MultiColumnEncoder encoder = EncoderFactory
.createEncoder(spec, colnames, fo.getSchema(), (int) fo.getNumColumns(), meta);
mcOut.setDimension(mcIn.getRows() - ((omap != null) ? omap.getNumRmRows() : 0),
- (int) fo.getNumColumns() + encoder.getNumExtraCols());
+ (int) encoder.getNumOutCols());
Broadcast<MultiColumnEncoder> bmeta = sec.getSparkContext().broadcast(encoder);
Broadcast<TfOffsetMap> bomap = (omap != null) ? sec.getSparkContext().broadcast(omap) : null;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
index f4c29ea..e5b8fea 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
@@ -502,7 +502,7 @@
MultiColumnEncoder encoder = EncoderFactory
.createEncoder(params.get("spec"), colnames, fo.getSchema(), (int) fo.getNumColumns(), meta);
mcOut.setDimension(mcIn.getRows() - ((omap != null) ? omap.getNumRmRows() : 0),
- (int) fo.getNumColumns() + encoder.getNumExtraCols());
+ (int)encoder.getNumOutCols());
Broadcast<MultiColumnEncoder> bmeta = sec.getSparkContext().broadcast(encoder);
Broadcast<TfOffsetMap> bomap = (omap != null) ? sec.getSparkContext().broadcast(omap) : null;
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
index 5ffd101..610e0cc 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
@@ -206,6 +206,11 @@
// do nothing
}
+ public int getDomainSize(){
+ return 1; // default of one; ColumnEncoderDummycode overrides this with its own domain size
+ }
+
+
/**
* Partial build of internal data structures (e.g., in distributed spark operations).
*
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
index 3809af8..b532dc0 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
@@ -387,6 +387,15 @@
}
}
+ /** Describe this encoder as "SimpleClassName: colID". */
+ @Override
+ public String toString() {
+ 	// Plain concatenation produces the same text as a chain of
+ 	// StringBuilder appends: class name, separator, column id.
+ 	final String type = getClass().getSimpleName();
+ 	return type + ": " + _colID;
+ }
+
public enum BinMethod {
INVALID, EQUI_WIDTH, EQUI_HEIGHT
}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
index 1060aa2..a033bfa 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
@@ -107,12 +107,16 @@
@Override
public void build(CacheBlock<?> in, Map<Integer, double[]> equiHeightMaxs) {
- for(ColumnEncoder columnEncoder : _columnEncoders)
- if(columnEncoder instanceof ColumnEncoderBin && ((ColumnEncoderBin) columnEncoder).getBinMethod() == ColumnEncoderBin.BinMethod.EQUI_HEIGHT) {
- columnEncoder.build(in, equiHeightMaxs.get(columnEncoder.getColID()));
- } else {
- columnEncoder.build(in);
- }
+ if(equiHeightMaxs == null)
+ build(in);
+ else{
+ for(ColumnEncoder columnEncoder : _columnEncoders)
+ if(columnEncoder instanceof ColumnEncoderBin && ((ColumnEncoderBin) columnEncoder).getBinMethod() == ColumnEncoderBin.BinMethod.EQUI_HEIGHT) {
+ columnEncoder.build(in, equiHeightMaxs.get(columnEncoder.getColID()));
+ } else {
+ columnEncoder.build(in);
+ }
+ }
}
@Override
@@ -321,9 +325,7 @@
sb.append("CompositeEncoder(").append(_columnEncoders.size()).append("):\n");
for(ColumnEncoder columnEncoder : _columnEncoders) {
sb.append("-- ");
- sb.append(columnEncoder.getClass().getSimpleName());
- sb.append(": ");
- sb.append(columnEncoder._colID);
+ sb.append(columnEncoder);
sb.append("\n");
}
return sb.toString();
@@ -410,6 +412,28 @@
}).collect(Collectors.toSet());
}
+ @Override
+ public int getDomainSize() {
+ 	// largest domain size among the sub-encoders; an empty encoder list fails with NoSuchElementException
+ 	return _columnEncoders.stream().mapToInt(ColumnEncoder::getDomainSize).max().getAsInt();
+ }
+
+
+ public boolean isRecodeToDummy(){
+ 	if(_columnEncoders.size() != 2) // only the exact pair recode followed by dummycode qualifies
+ 		return false;
+ 	return _columnEncoders.get(0) instanceof ColumnEncoderRecode && _columnEncoders.get(1) instanceof ColumnEncoderDummycode;
+ }
+
+ public boolean isRecode(){
+ 	final boolean single = _columnEncoders.size() == 1; // a lone sub-encoder, checked before it is accessed
+ 	return single && _columnEncoders.get(0) instanceof ColumnEncoderRecode;
+ }
+
+ public boolean isPassThrough(){
+ 	final boolean exactlyOne = _columnEncoders.size() == 1; // guards the get(0) below
+ 	return exactlyOne && _columnEncoders.get(0) instanceof ColumnEncoderPassThrough;
+ }
private static class ColumnCompositeUpdateDCTask implements Callable<Object> {
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
index 970df3a..f30743f 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
@@ -266,10 +266,22 @@
return result;
}
+ @Override
public int getDomainSize() {
return _domainSize;
}
+ /** Describe this encoder as "SimpleClassName: colID --- DomainSize : domainSize". */
+ @Override
+ public String toString() {
+ 	final String type = getClass().getSimpleName();
+ 	// first the part shared with the other encoders ...
+ 	final String head = type + ": " + _colID;
+ 	// ... then the dummycode specific domain size
+ 	final String tail = " --- DomainSize : " + _domainSize;
+ 	return head + tail;
+ }
+
private static class DummycodeSparseApplyTask extends ColumnApplyTask<ColumnEncoderDummycode> {
protected DummycodeSparseApplyTask(ColumnEncoderDummycode encoder, MatrixBlock input,
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java
index cfa69d1..12e3f80 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java
@@ -155,6 +155,15 @@
_K = in.readLong();
}
+ /** Describe this encoder as "SimpleClassName: colID". */
+ @Override
+ public String toString() {
+ 	// class name of the concrete encoder, then the id of its column
+ 	final String encoderName = getClass().getSimpleName();
+ 	final String text = encoderName + ": " + _colID;
+ 	return text;
+ }
+
public static class FeatureHashSparseApplyTask extends ColumnApplyTask<ColumnEncoderFeatureHash>{
public FeatureHashSparseApplyTask(ColumnEncoderFeatureHash encoder, CacheBlock<?> input,
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
index 63c2746..9d775a7 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
@@ -145,6 +145,15 @@
// do nothing
}
+ /** Describe this encoder as "SimpleClassName: colID". */
+ @Override
+ public String toString() {
+ 	// Same layout as produced by appending the three parts to a
+ 	// StringBuilder one after the other.
+ 	final String simpleName = getClass().getSimpleName();
+ 	return simpleName + ": " + _colID;
+ }
+
public static class PassThroughSparseApplyTask extends ColumnApplyTask<ColumnEncoderPassThrough>{
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
index 799ed37..eb7e706 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
@@ -47,18 +47,19 @@
public static boolean SORT_RECODE_MAP = false;
// recode maps and custom map for partial recode maps
- private HashMap<String, Long> _rcdMap = new HashMap<>();
+ private HashMap<Object, Long> _rcdMap;
private HashSet<Object> _rcdMapPart = null;
public ColumnEncoderRecode(int colID) {
super(colID);
+ _rcdMap = new HashMap<>();
}
public ColumnEncoderRecode() {
this(-1);
}
- private ColumnEncoderRecode(int colID, HashMap<String, Long> rcdMap) {
+ protected ColumnEncoderRecode(int colID, HashMap<Object, Long> rcdMap) {
super(colID);
_rcdMap = rcdMap;
}
@@ -75,7 +76,7 @@
return constructRecodeMapEntry(token, code, sb);
}
- private static String constructRecodeMapEntry(String token, Long code, StringBuilder sb) {
+ private static String constructRecodeMapEntry(Object token, Long code, StringBuilder sb) {
sb.setLength(0); // reset reused string builder
return sb.append(token).append(Lop.DATATYPE_PREFIX).append(code.longValue()).toString();
}
@@ -93,7 +94,7 @@
return new String[] {value.substring(0, pos), value.substring(pos + 1)};
}
- public HashMap<String, Long> getCPRecodeMaps() {
+ public HashMap<Object, Long> getCPRecodeMaps() {
return _rcdMap;
}
@@ -105,15 +106,15 @@
sortCPRecodeMaps(_rcdMap);
}
- private static void sortCPRecodeMaps(HashMap<String, Long> map) {
- String[] keys = map.keySet().toArray(new String[0]);
+ private static void sortCPRecodeMaps(HashMap<Object, Long> map) {
+ Object[] keys = map.keySet().toArray(new Object[0]);
Arrays.sort(keys);
map.clear();
- for(String key : keys)
+ for(Object key : keys)
putCode(map, key);
}
- private static void makeRcdMap(CacheBlock<?> in, HashMap<String, Long> map, int colID, int startRow, int blk) {
+ private static void makeRcdMap(CacheBlock<?> in, HashMap<Object, Long> map, int colID, int startRow, int blk) {
for(int row = startRow; row < getEndIndex(in.getNumRows(), startRow, blk); row++){
String key = in.getString(row, colID - 1);
if(key != null && !key.isEmpty() && !map.containsKey(key))
@@ -124,9 +125,8 @@
}
}
- private long lookupRCDMap(String key) {
- Long tmp = _rcdMap.get(key);
- return (tmp != null) ? tmp : -1;
+ private long lookupRCDMap(Object key) {
+ return _rcdMap.getOrDefault(key, -1L); // -1 tells the caller that the key has no code yet
}
public void computeRCDMapSizeEstimate(CacheBlock<?> in, int[] sampleIndices) {
@@ -202,7 +202,7 @@
* @param map column map
* @param key key for the new entry
*/
- protected static void putCode(HashMap<String, Long> map, String key) {
+ protected static void putCode(HashMap<Object, Long> map, Object key) {
map.put(key, (long) (map.size() + 1));
}
@@ -270,10 +270,10 @@
assert other._colID == _colID;
// merge together overlapping columns
ColumnEncoderRecode otherRec = (ColumnEncoderRecode) other;
- HashMap<String, Long> otherMap = otherRec._rcdMap;
+ HashMap<Object, Long> otherMap = otherRec._rcdMap;
if(otherMap != null) {
// for each column, add all non present recode values
- for(Map.Entry<String, Long> entry : otherMap.entrySet()) {
+ for(Map.Entry<Object, Long> entry : otherMap.entrySet()) {
if(lookupRCDMap(entry.getKey()) == -1) {
// key does not yet exist
putCode(_rcdMap, entry.getKey());
@@ -305,7 +305,7 @@
// create compact meta data representation
StringBuilder sb = new StringBuilder(); // for reuse
int rowID = 0;
- for(Entry<String, Long> e : _rcdMap.entrySet()) {
+ for(Entry<Object, Long> e : _rcdMap.entrySet()) {
meta.set(rowID++, _colID - 1, // 1-based
constructRecodeMapEntry(e.getKey(), e.getValue(), sb));
}
@@ -330,8 +330,9 @@
public void writeExternal(ObjectOutput out) throws IOException {
super.writeExternal(out);
out.writeInt(_rcdMap.size());
- for(Entry<String, Long> e : _rcdMap.entrySet()) {
- out.writeUTF(e.getKey());
+
+ for(Entry<Object, Long> e : _rcdMap.entrySet()) {
+ out.writeUTF(e.getKey().toString());
out.writeLong(e.getValue());
}
}
@@ -362,10 +363,21 @@
return Objects.hash(_rcdMap);
}
- public HashMap<String, Long> getRcdMap() {
+ public HashMap<Object, Long> getRcdMap() {
return _rcdMap;
}
+ /** Describe this encoder as "SimpleClassName: colID --- map: recodeMap". */
+ @Override
+ public String toString() {
+ 	final String type = getClass().getSimpleName();
+ 	// first the part shared with the other encoders ...
+ 	final String head = type + ": " + _colID;
+ 	// ... then the complete recode map (printed as "null" when absent)
+ 	final String tail = " --- map: " + _rcdMap;
+ 	return head + tail;
+ }
+
private static class RecodeSparseApplyTask extends ColumnApplyTask<ColumnEncoderRecode>{
public RecodeSparseApplyTask(ColumnEncoderRecode encoder, CacheBlock<?> input, MatrixBlock out, int outputCol) {
@@ -416,9 +428,9 @@
}
@Override
- public HashMap<String, Long> call() throws Exception {
+ public Object call() throws Exception {
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
- HashMap<String, Long> partialMap = new HashMap<>();
+ HashMap<Object, Long> partialMap = new HashMap<>();
makeRcdMap(_input, partialMap, _colID, _startRow, _blockSize);
synchronized(_partialMaps) {
_partialMaps.put(_startRow, partialMap);
@@ -448,11 +460,11 @@
@Override
public Object call() throws Exception {
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
- HashMap<String, Long> rcdMap = _encoder.getRcdMap();
+ HashMap<Object, Long> rcdMap = _encoder.getRcdMap();
_partialMaps.forEach((start_row, map) -> {
((HashMap<?, ?>) map).forEach((k, v) -> {
- if(!rcdMap.containsKey((String) k))
- putCode(rcdMap, (String) k);
+ if(!rcdMap.containsKey(k))
+ putCode(rcdMap, k);
});
});
_encoder._rcdMap = rcdMap;
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java
new file mode 100644
index 0000000..b4bcd3c
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.transform.encode;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.runtime.compress.colgroup.AColGroup;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC;
+import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
+import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
+import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary;
+import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
+import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
+import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
+import org.apache.sysds.runtime.compress.utils.Util;
+import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.frame.data.columns.Array;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+
+public class CompressedEncode {
+ protected static final Log LOG = LogFactory.getLog(CompressedEncode.class.getName());
+
+ private final MultiColumnEncoder enc;
+ private final FrameBlock in;
+
+ private CompressedEncode(MultiColumnEncoder enc, FrameBlock in) { // private: instances are created by the static encode method
+ this.enc = enc;
+ this.in = in;
+ }
+
+ public static MatrixBlock encode(MultiColumnEncoder enc, FrameBlock in) {
+ return new CompressedEncode(enc, in).apply(); // apply() builds one column group per encoder and wraps them in a CompressedMatrixBlock
+ }
+
+ /**
+  * Encode each column with its composite encoder into a column group and assemble the groups into one block.
+  * @return The CompressedMatrixBlock built from the groups, with its non-zero count recomputed
+  */
+ private MatrixBlock apply() {
+ 	final List<ColumnEncoderComposite> composites = enc.getColumnEncoders();
+ 	final List<AColGroup> colGroups = new ArrayList<>(composites.size());
+ 	composites.forEach(ce -> colGroups.add(encode(ce)));
+ 	// every group is created with column indexes starting at 0, move them to their final output position
+ 	final int nCol = shiftGroups(colGroups);
+ 	final MatrixBlock ret = new CompressedMatrixBlock(in.getNumRows(), nCol, -1, false, colGroups);
+ 	ret.recomputeNonZeros();
+ 	logging(ret);
+ 	return ret;
+ }
+
+ /**
+  * Move each column group behind the columns of all groups before it, replacing the list entries in place.
+  * @param groups The column groups in output order, the list must not be empty
+  * @return The total number of columns over all groups
+  */
+ private int shiftGroups(List<AColGroup> groups) {
+ 	int offset = groups.get(0).getColIndices().length; // the first group keeps its indexes
+ 	for(int g = 1; g < groups.size(); g++) {
+ 		final AColGroup shifted = groups.get(g).shiftColIndices(offset);
+ 		groups.set(g, shifted);
+ 		offset += shifted.getColIndices().length;
+ 	}
+ 	return offset;
+ }
+
+ private AColGroup encode(ColumnEncoderComposite c) {
+ 	// dispatch on the supported encoder combinations, everything else is rejected
+ 	if(c.isRecodeToDummy())
+ 		return recodeToDummy(c);
+ 	if(c.isRecode())
+ 		return recode(c);
+ 	if(c.isPassThrough())
+ 		return passThrough(c);
+ 	throw new NotImplementedException("Not supporting : " + c);
+ }
+
+ /**
+  * Encode a recode + dummycode column as a DDC group over an identity dictionary.
+  *
+  * @param c The composite encoder; its first sub-encoder is replaced by a recode encoder holding the built map
+  * @return The column group over the columns from Util.genColsIndices(0, number of distinct values)
+  */
+ @SuppressWarnings("unchecked")
+ private AColGroup recodeToDummy(ColumnEncoderComposite c) {
+ 	final int colId = c._colID;
+ 	final Array<?> column = in.getColumn(colId - 1); // _colID is 1-based
+ 	final HashMap<?, Long> codes = column.getRecodeMap();
+
+ 	final int[] colIndexes = Util.genColsIndices(0, codes.size());
+ 	final ADictionary dict = new IdentityDictionary(colIndexes.length);
+ 	final AMapToData rowToCode = createMappingAMapToData(column, codes);
+
+ 	// keep the map on the encoder - NOTE(review): presumably so the meta data can be derived later, confirm
+ 	c.getEncoders().set(0, new ColumnEncoderRecode(colId, (HashMap<Object, Long>) codes));
+ 	return ColGroupDDC.create(colIndexes, dict, rowToCode, null);
+ }
+
+ /**
+  * Encode a recoded column as a single-column DDC group whose dictionary holds the values 1 to #distinct.
+  *
+  * @param c The composite encoder; its first sub-encoder is replaced by a recode encoder holding the built map
+  * @return The column group for the recoded column
+  */
+ @SuppressWarnings("unchecked")
+ private AColGroup recode(ColumnEncoderComposite c) {
+ 	final int colId = c._colID;
+ 	final Array<?> column = in.getColumn(colId - 1); // _colID is 1-based
+ 	final HashMap<?, Long> codes = column.getRecodeMap();
+ 	final int nDistinct = codes.size();
+
+ 	// dictionary row i holds i + 1 - NOTE(review): the map stored on the encoder keeps the raw codes, confirm the offset
+ 	final MatrixBlock oneToN = new MatrixBlock(nDistinct, 1, false);
+ 	for(int i = 0; i < nDistinct; i++)
+ 		oneToN.quickSetValue(i, 0, i + 1);
+ 	final ADictionary dict = MatrixBlockDictionary.create(oneToN);
+
+ 	final AMapToData rowToCode = createMappingAMapToData(column, codes);
+ 	c.getEncoders().set(0, new ColumnEncoderRecode(colId, (HashMap<Object, Long>) codes));
+ 	return ColGroupDDC.create(new int[1], dict, rowToCode, null);
+ }
+
+ /**
+  * Encode a pass-through column as a single-column DDC group whose dictionary holds the distinct cell values.
+  *
+  * @param c The composite encoder of the column
+  * @return The column group for the column
+  */
+ @SuppressWarnings("unchecked")
+ private AColGroup passThrough(ColumnEncoderComposite c) {
+ 	final Array<?> a = in.getColumn(c._colID - 1); // _colID is 1-based
+ 	// Work on a private copy: the array caches its recode map, and the null key added below must not leak into it.
+ 	final HashMap<Object, Long> map = new HashMap<>((HashMap<Object, Long>) a.getRecodeMap());
+ 	final double[] vals = new double[map.size() + (a.containsNull() ? 1 : 0)];
+ 	for(int i = 0; i < a.size(); i++) {
+ 		final Object v = a.get(i);
+ 		if(!map.containsKey(v)) // first cell without a code (a null cell): give null the next free code
+ 			map.put(null, (long) map.size());
+ 		vals[map.get(v).intValue()] = a.getAsDouble(i);
+ 	}
+ 	final ADictionary d = Dictionary.create(vals);
+ 	final AMapToData m = createMappingAMapToData(a, map);
+ 	return ColGroupDDC.create(new int[1], d, m, null);
+ }
+
+ /** Map each row to the code of its cell; null cells are mapped only when the map holds a code for null. */
+ private AMapToData createMappingAMapToData(Array<?> a, HashMap<?, Long> map) {
+ 	final AMapToData m = MapToFactory.create(in.getNumRows(), map.size());
+ 	final Array<?>.ArrayIterator it = a.getIterator();
+ 	while(it.hasNext()) {
+ 		final Object v = it.next();
+ 		if(v != null || map.containsKey(null)) // passThrough registers a null key; without one, null rows stay as created
+ 			m.set(it.getIndex(), map.get(v).intValue());
+ 	}
+ 	return m;
+ }
+
+ /** Log, on debug level only, the uncompressed dense/sparse size estimates, the compressed size and their ratios. */
+ private void logging(MatrixBlock mb) {
+ 	if(!LOG.isDebugEnabled())
+ 		return;
+ 	final long dense = mb.estimateSizeDenseInMemory();
+ 	final long sparse = mb.estimateSizeSparseInMemory();
+ 	final long compressed = mb.estimateSizeInMemory();
+ 	LOG.debug(String.format("Uncompressed transform encode Dense size: %16d", dense));
+ 	LOG.debug(String.format("Uncompressed transform encode Sparse size: %16d", sparse));
+ 	LOG.debug(String.format("Compressed transform encode size: %16d", compressed));
+ 	// divide as double: the estimates are integral (printed with %d) and integer division would truncate the ratios
+ 	LOG.debug(String.format("Compression ratio: %10.3f", (double) Math.min(dense, sparse) / compressed));
+ 	LOG.debug(String.format("Dense ratio: %10.3f", (double) dense / compressed));
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
index 22190e5..59c1a3d 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
@@ -42,6 +42,7 @@
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.estim.ComEstSample;
@@ -90,11 +91,21 @@
}
public MatrixBlock encode(CacheBlock<?> in, int k) {
- MatrixBlock out;
+ return encode(in, k, false);
+ }
+
+ public MatrixBlock encode(CacheBlock<?> in, boolean compressedOut) {
+ return encode(in, 1, compressedOut); // delegates with k = 1, i.e. a single thread
+ }
+
+ public MatrixBlock encode(CacheBlock<?> in, int k, boolean compressedOut){
+
deriveNumRowPartitions(in, k);
try {
- if(k > 1 && !MULTI_THREADED_STAGES && !hasLegacyEncoder()) {
- out = new MatrixBlock();
+ if(isCompressedTransformEncode(in, compressedOut))
+ return CompressedEncode.encode(this, (FrameBlock ) in);
+ else if(k > 1 && !MULTI_THREADED_STAGES && !hasLegacyEncoder()) {
+ MatrixBlock out = new MatrixBlock();
DependencyThreadPool pool = new DependencyThreadPool(k);
LOG.debug("Encoding with full DAG on " + k + " Threads");
try {
@@ -106,6 +117,7 @@
}
pool.shutdown();
outputMatrixPostProcessing(out);
+ return out;
}
else {
LOG.debug("Encoding with staged approach on: " + k + " Threads");
@@ -123,16 +135,20 @@
}
// apply meta data
t0 = System.nanoTime();
- out = apply(in, k);
+ MatrixBlock out = apply(in, k);
t1 = System.nanoTime();
LOG.debug("Elapsed time for apply phase: "+ ((double) t1 - t0) / 1000000 + " ms");
+ return out;
}
}
catch(Exception ex) {
LOG.error("Failed transform-encode frame with \n" + this);
throw ex;
}
- return out;
+ }
+
+ protected List<ColumnEncoderComposite> getEncoders() {
+ return _columnEncoders; // the internal list itself, not a copy
}
/* TASK DETAILS:
@@ -245,21 +261,7 @@
}
public void build(CacheBlock<?> in, int k) {
- if(hasLegacyEncoder() && !(in instanceof FrameBlock))
- throw new DMLRuntimeException("LegacyEncoders do not support non FrameBlock Inputs");
- if(!_partitionDone) //happens if this method is directly called
- deriveNumRowPartitions(in, k);
- if(k > 1) {
- buildMT(in, k);
- }
- else {
- for(ColumnEncoderComposite columnEncoder : _columnEncoders) {
- columnEncoder.build(in);
- columnEncoder.updateAllDCEncoders();
- }
- }
- if(hasLegacyEncoder())
- legacyBuild((FrameBlock) in);
+ build(in, k, null);
}
public void build(CacheBlock<?> in, int k, Map<Integer, double[]> equiHeightBinMaxs) {
@@ -317,7 +319,7 @@
boolean hasUDF = _columnEncoders.stream().anyMatch(e -> e.hasEncoder(ColumnEncoderUDF.class));
for(ColumnEncoderComposite columnEncoder : _columnEncoders)
columnEncoder.updateAllDCEncoders();
- int numCols = in.getNumColumns() + getNumExtraCols();
+ int numCols = getNumOutCols();
long estNNz = (long) in.getNumRows() * (hasUDF ? numCols : (long) in.getNumColumns());
boolean sparse = MatrixBlock.evalSparseFormatInMemory(in.getNumRows(), numCols, estNNz) && !hasUDF;
MatrixBlock out = new MatrixBlock(in.getNumRows(), numCols, sparse, estNNz);
@@ -654,6 +656,8 @@
long t0 = System.nanoTime();
if(_meta != null)
return _meta;
+ if(meta == null)
+ meta = new FrameBlock(_columnEncoders.size(), ValueType.STRING);
this.allocateMetaData(meta);
if (k > 1) {
try {
@@ -854,24 +858,11 @@
return getEncoderTypes(-1);
}
- public int getNumExtraCols() {
- List<ColumnEncoderDummycode> dc = getColumnEncoders(ColumnEncoderDummycode.class);
- if(dc.isEmpty()) {
- return 0;
- }
- if(dc.stream().anyMatch(e -> e.getDomainSize() < 0)) {
- throw new DMLRuntimeException("Trying to get extra columns when DC encoders are not ready");
- }
- return dc.stream().map(ColumnEncoderDummycode::getDomainSize).mapToInt(i -> i).sum() - dc.size();
- }
-
- public int getNumExtraCols(IndexRange ixRange) {
- List<ColumnEncoderDummycode> dc = getColumnEncoders(ColumnEncoderDummycode.class).stream()
- .filter(dce -> ixRange.inColRange(dce._colID)).collect(Collectors.toList());
- if(dc.isEmpty()) {
- return 0;
- }
- return dc.stream().map(ColumnEncoderDummycode::getDomainSize).mapToInt(i -> i).sum() - dc.size();
+ public int getNumOutCols() {
+ 	// the output width is the sum of the domain sizes of all column encoders
+ 	return _columnEncoders.stream()
+ 		.mapToInt(ColumnEncoderComposite::getDomainSize)
+ 		.sum();
}
public <T extends ColumnEncoder> boolean containsEncoderForID(int colID, Class<T> type) {
@@ -998,6 +989,11 @@
return hasLegacyEncoder(EncoderMVImpute.class) || hasLegacyEncoder(EncoderOmit.class);
}
+ public boolean isCompressedTransformEncode(CacheBlock<?> in, boolean enabled){
+ 	final boolean on = enabled || ConfigurationManager.getDMLConfig().getBooleanValue(DMLConfig.COMPRESSED_TRANSFORMENCODE);
+ 	return on && in instanceof FrameBlock && _colOffset == 0; // only frame inputs without a column offset qualify
+ }
+
public <T extends LegacyEncoder> boolean hasLegacyEncoder(Class<T> type) {
if(type.equals(EncoderMVImpute.class))
return _legacyMVImpute != null;
@@ -1027,6 +1023,22 @@
_legacyMVImpute.shiftCols(_colOffset);
}
+ /**
+  * List the class name, the legacy MV-impute encoder and every column encoder, one per line.
+  */
+ @Override
+ public String toString() {
+ 	final StringBuilder out = new StringBuilder(this.getClass().getSimpleName());
+ 	// NOTE(review): labelled "Is Legacy" but prints the _legacyMVImpute object (or null) - confirm intended
+ 	out.append("\nIs Legacy: ").append(_legacyMVImpute);
+ 	out.append("\nEncoders:\n");
+
+ 	for(ColumnEncoderComposite ce : _columnEncoders)
+ 		out.append(ce).append("\n");
+
+ 	return out.toString();
+ }
+
/*
* Currently, not in use will be integrated in the future
*/
@@ -1081,7 +1093,7 @@
@Override
public Object call() throws Exception {
boolean hasUDF = _encoder.getColumnEncoders().stream().anyMatch(e -> e.hasEncoder(ColumnEncoderUDF.class));
- int numCols = _input.getNumColumns() + _encoder.getNumExtraCols();
+ int numCols = _encoder.getNumOutCols();
boolean hasDC = _encoder.getColumnEncoders(ColumnEncoderDummycode.class).size() > 0;
long estNNz = (long) _input.getNumRows() * (hasUDF ? numCols : (long) _input.getNumColumns());
boolean sparse = MatrixBlock.evalSparseFormatInMemory(_input.getNumRows(), numCols, estNNz) && !hasUDF;
diff --git a/src/main/java/org/apache/sysds/runtime/util/DataConverter.java b/src/main/java/org/apache/sysds/runtime/util/DataConverter.java
index b060541..987f85a 100644
--- a/src/main/java/org/apache/sysds/runtime/util/DataConverter.java
+++ b/src/main/java/org/apache/sysds/runtime/util/DataConverter.java
@@ -1119,6 +1119,7 @@
sb.append(dfFormat(df, value));
break;
case UINT8:
+ case UINT4:
case INT32:
case INT64:
sb.append(tb.get(ix));
diff --git a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
index ea77173..5568603 100644
--- a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
+++ b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
@@ -485,6 +485,7 @@
switch( vt ) {
case STRING: return in;
case BOOLEAN: return Boolean.parseBoolean(in);
+ case UINT4:
case UINT8:
case INT32: return Integer.parseInt(in);
case INT64: return Long.parseLong(in);
diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java b/src/test/java/org/apache/sysds/test/TestUtils.java
index ece9c77..bade9dd 100644
--- a/src/test/java/org/apache/sysds/test/TestUtils.java
+++ b/src/test/java/org/apache/sysds/test/TestUtils.java
@@ -842,6 +842,10 @@
}
public static void compareFrames(FrameBlock expected, FrameBlock actual, boolean checkMeta) {
+ if(expected == null && actual == null)
+ return;
+ assertTrue("Expected frame was null pointer", expected != null);
+ assertTrue("Actual frame was null pointer", actual != null);
assertEquals("Number of columns and rows are not equivalent", expected.getNumRows(), actual.getNumRows());
assertEquals("Number of columns and rows are not equivalent", expected.getNumColumns(), actual.getNumColumns());
@@ -2417,6 +2421,7 @@
*/
public static Object generateRandomValueFromValueType(ValueType valueType, Random random){
switch (valueType){
+ case UINT4: return random.nextInt(16);
case UINT8: return random.nextInt(256);
case FP32: return random.nextFloat();
case FP64: return random.nextDouble();
diff --git a/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java b/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java
index c68332e..cc6c351 100644
--- a/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java
+++ b/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java
@@ -1150,7 +1150,7 @@
@Test
public void mappingCache() {
- Array<?> a = new StringArray(new String[] {"1", null});
+ Array<String> a = new StringArray(new String[] {"1", null});
assertEquals(null, a.getCache());
a.setCache(new SoftReference<HashMap<String, Long>>(null));
assertTrue(null != a.getCache());
diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/transformCompressed.java b/src/test/java/org/apache/sysds/test/component/frame/transform/transformCompressed.java
new file mode 100644
index 0000000..343aaf0
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/frame/transform/transformCompressed.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.frame.transform;
+
+import static org.junit.Assert.fail;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.transform.encode.EncoderFactory;
+import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+public class transformCompressed {
+	protected static final Log LOG = LogFactory.getLog(transformCompressed.class.getName());
+
+	private final FrameBlock data;
+
+	public transformCompressed() {
+		try {
+			// one column drawn with the UINT4 generator (presumably 100 rows, seed 231), then declared INT32
+			data = TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT4}, 231);
+			data.setSchema(new ValueType[] {ValueType.INT32});
+		}
+		catch(Exception e) {
+			e.printStackTrace();
+			fail(e.getMessage());
+			throw e; // never reached because fail() throws; lets the compiler accept the final field 'data'
+		}
+	}
+
+	@Test
+	public void testRecode() {
+		test("{recode:[C1]}");
+	}
+
+	@Test
+	public void testDummyCode() {
+		test("{dummycode:[C1]}");
+	}
+
+	// @Test
+	// public void testBin() {
+	// test("{ids:true, bin:[{id:1, method:equi-width, numbins:4}]}");
+	// }
+
+	// @Test
+	// public void testBin2() {
+	// test("{ids:true, bin:[{id:1, method:equi-width, numbins:100}]}");
+	// }
+
+	// @Test
+	// public void testBin3() {
+	// test("{ids:true, bin:[{id:1, method:equi-width, numbins:2}]}");
+	// }
+
+	// @Test
+	// public void testBin4() {
+	// test("{ids:true, bin:[{id:1, method:equi-height, numbins:2}]}");
+	// }
+
+	// @Test
+	// public void testBin5() {
+	// test("{ids:true, bin:[{id:1, method:equi-height, numbins:10}]}");
+	// }
+
+	public void test(String spec) {
+		try {
+			// encode the same frame twice; encode(data, true) presumably requests compressed output
+			FrameBlock meta = null;
+			MultiColumnEncoder encoderCompressed = EncoderFactory.createEncoder(spec, data.getColumnNames(),
+				data.getNumColumns(), meta);
+			MatrixBlock outCompressed = encoderCompressed.encode(data, true);
+			FrameBlock outCompressedMD = encoderCompressed.getMetaData(null);
+			MultiColumnEncoder encoderNormal = EncoderFactory.createEncoder(spec, data.getColumnNames(),
+				data.getNumColumns(), meta);
+			MatrixBlock outNormal = encoderNormal.encode(data);
+			FrameBlock outNormalMD = encoderNormal.getMetaData(null);
+
+			// dumps only help while diagnosing a failure, so log them at debug instead of error level
+			LOG.debug(outNormal);
+			LOG.debug(outCompressed);
+			LOG.debug(outCompressedMD);
+			LOG.debug(outNormalMD);
+
+			TestUtils.compareMatrices(outNormal, outCompressed, 0, "Not Equal after apply");
+			TestUtils.compareFrames(outNormalMD, outCompressedMD, true);
+		}
+		catch(Exception e) {
+			e.printStackTrace();
+			fail(e.getMessage());
+		}
+	}
+}